[
  {
    "path": ".dockerignore",
    "content": "aml\ntarget\nserver/transformers\nserver/flash-attention\ncmake-build-debug/\ncmake-build-release/\nDockerfile*\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/bug-report.yml",
    "content": "name: \"\\U0001F41B Bug Report\"\ndescription: Submit a bug report to help us improve text-generation-inference\nbody:\n  - type: textarea\n    id: system-info\n    attributes:\n      label: System Info\n      description: |\n        Please share your system info with us (`text-generation-launcher --env` if installed locally).\n        The full command line used that causes issues:\n        OS version:\n        Rust version (if self-compiling, `cargo version`):\n        Model being used (`curl 127.0.0.1:8080/info | jq`):\n          If local model please explicit the kind of model and/or equivalents.\n        Hardware used (GPUs, how many, on which cloud) (`nvidia-smi`):\n        Deployment specificities (Kubernetes, EKS, AKS, any particular deployments):\n        The current version being used:\n\n      placeholder: text-generation-inference version, platform, python version, ...\n    validations:\n      required: true\n\n  - type: checkboxes\n    id: information-scripts-examples\n    attributes:\n      label: Information\n      description: 'The problem arises when using:'\n      options:\n        - label: \"Docker\"\n        - label: \"The CLI directly\"\n\n  - type: checkboxes\n    id: information-tasks\n    attributes:\n      label: Tasks\n      description: \"The thing I am working on is:\"\n      options:\n        - label: \"An officially supported command\"\n        - label: \"My own modifications\"\n\n  - type: textarea\n    id: reproduction\n    validations:\n      required: true\n    attributes:\n      label: Reproduction\n      description: |\n        Please provide a code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet.\n        If you have code snippets, error messages, stack traces please provide them here as well.\n        Important! Use code tags to correctly format your code. 
See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting\n        Do not use screenshots, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.\n\n      placeholder: |\n        Steps to reproduce the behavior:\n\n          1.\n          2.\n          3.\n\n\n  - type: textarea\n    id: expected-behavior\n    validations:\n      required: true\n    attributes:\n      label: Expected behavior\n      description: \"A clear and concise description of what you would expect to happen.\"\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/config.yml",
    "content": "blank_issues_enabled: true\nversion: 2.1\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature-request.yml",
    "content": "name: \"\\U0001F680 Feature request\"\ndescription: Submit a proposal/request for a new text-generation-inference feature\nlabels: [ \"feature\" ]\nbody:\n  - type: textarea\n    id: feature-request\n    validations:\n      required: true\n    attributes:\n      label: Feature request\n      description: |\n        A clear and concise description of the feature proposal. Please provide a link to the paper and code in case they exist.\n\n  - type: textarea\n    id: motivation\n    validations:\n      required: true\n    attributes:\n      label: Motivation\n      description: |\n        Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too.\n\n\n  - type: textarea\n    id: contribution\n    validations:\n      required: true\n    attributes:\n      label: Your contribution\n      description: |\n        Is there any way that you could help, e.g. by submitting a PR? Make sure to read the CONTRIBUTING.MD [readme](https://github.com/huggingface/text-generation-inference/blob/main/CONTRIBUTING.md)\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/new-model-addition.yml",
    "content": "name: \"\\U0001F31F New model addition\"\ndescription: Submit a proposal/request to implement a new model\nlabels: [ \"New model\" ]\n\nbody:\n  - type: textarea\n    id: description-request\n    validations:\n      required: true\n    attributes:\n      label: Model description\n      description: |\n        Put any and all important information relative to the model\n\n  - type: checkboxes\n    id: information-tasks\n    attributes:\n      label: Open source status\n      description: |\n          Please note that if the model implementation isn't available or if the weights aren't open-source, we are less likely to implement it in `transformers`.\n      options:\n        - label: \"The model implementation is available\"\n        - label: \"The model weights are available\"\n\n  - type: textarea\n    id: additional-info\n    attributes:\n      label: Provide useful links for the implementation\n      description: |\n        Please provide information regarding the implementation, the weights, and the authors.\n        Please mention the authors by @gh-username if you're aware of their usernames.\n"
  },
  {
    "path": ".github/PULL_REQUEST_TEMPLATE.md",
    "content": "# What does this PR do?\n\n<!--\nCongratulations! You've made it this far! You're not quite done yet though.\n\nOnce merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution.\n\nThen, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change.\n\nOnce you're done, someone will review your PR shortly (see the section \"Who can review?\" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost.\n-->\n\n<!-- Remove if not applicable -->\n\nFixes # (issue)\n\n\n## Before submitting\n- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).\n- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),\n      Pull Request section?\n- [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link\n      to it if that's the case.\n- [ ] Did you make sure to update the documentation with your changes? Here are the\n      [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and\n      [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).\n- [ ] Did you write any new necessary tests?\n\n\n## Who can review?\n\nAnyone in the community is free to review the PR once the tests have passed. 
Feel free to tag\nmembers/contributors who may be interested in your PR.\n\n<!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @\n\n\n@OlivierDehaene OR @Narsil\n\n -->\n"
  },
  {
    "path": ".github/workflows/autodocs.yaml",
    "content": "name: Automatic Documentation for Launcher\n\non:\n  pull_request:\n\njobs:\n  update_docs:\n    runs-on: ubuntu-latest\n\n    steps:\n    - name: Checkout code\n      uses: actions/checkout@v2\n\n    - name: Set up Rust\n      uses: actions-rs/toolchain@v1\n      with:\n        profile: minimal\n        toolchain: stable\n\n    - name: Install Protocol Buffers compiler\n      run: |\n        sudo apt-get update\n        sudo apt-get install -y protobuf-compiler libprotobuf-dev\n\n    - name: Install Launcher\n      id: install-launcher\n      run: cargo install --path launcher/\n\n    - name: Install router\n      id: install-router\n      run: cargo install --path backends/v3/\n\n    - uses: actions/setup-node@v4\n      with:\n        node-version: 22\n\n    - name: Set up Python\n      uses: actions/setup-python@v2\n      with:\n        python-version: '3.x'\n\n    - name: Check that documentation is up-to-date\n      run: |\n        npm install -g @redocly/cli@1.34.2\n        python update_doc.py --check\n"
  },
  {
    "path": ".github/workflows/build.yaml",
    "content": "name: Build and push docker image to internal registry\n\non:\n  workflow_call:\n    inputs:\n      hardware:\n        type: string\n        description: Hardware\n        # options:\n        # - cuda\n        # - cuda-trtllm\n        # - rocm\n        # - intel\n        required: true\n      release-tests:\n        description: \"Run release integration tests\"\n        required: true\n        default: false\n        type: boolean\n\njobs:\n  build-and-push:\n    outputs:\n      docker_image: ${{ steps.final.outputs.docker_image }}\n      docker_volume: ${{ steps.final.outputs.docker_volume }}\n      docker_devices: ${{ steps.final.outputs.docker_devices }}\n      runs_on: ${{ steps.final.outputs.runs_on }}\n      label_extension: ${{ steps.final.outputs.label_extension }}\n      extra_pytest: ${{ steps.final.outputs.extra_pytest }}\n    concurrency:\n      group: ${{ github.workflow }}-build-and-push-image-${{ inputs.hardware }}-${{ github.head_ref || github.run_id }}\n      cancel-in-progress: true\n    runs-on:\n      group: aws-highmemory-64-plus-priv\n    permissions:\n      contents: write\n      packages: write\n      id-token: write\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n      - name: Inject slug/short variables\n        uses: rlespinasse/github-slug-action@v4.4.1\n      - name: Inject required variables for sccache to interact with Github Actions Cache\n        uses: actions/github-script@v7\n        with:\n          script: |\n            core.exportVariable('ACTIONS_RESULTS_URL', process.env.ACTIONS_RESULTS_URL || '');\n            core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || '');\n\n      - name: Extract TensorRT-LLM version\n        run: |\n          echo \"TENSORRT_LLM_VERSION=$(grep -oP '([a-z,0-9]{40})' $GITHUB_WORKSPACE/backends/trtllm/cmake/trtllm.cmake)\" >> $GITHUB_ENV\n          echo \"TensorRT-LLM version: ${{ env.TENSORRT_LLM_VERSION }}\"\n      
- name: Construct hardware variables\n        shell: bash\n        run: |\n          case ${{ inputs.hardware }} in\n            cuda)\n                export dockerfile=\"Dockerfile\"\n                export label_extension=\"\"\n                export docker_volume=\"/mnt/cache\"\n                export docker_devices=\"\"\n                export runs_on=\"aws-g6-12xl-plus-priv-cache\"\n                export platform=\"\"\n                export extra_pytest=\"\"\n                export target=\"\"\n                ;;\n            cuda-trtllm)\n                export dockerfile=\"Dockerfile_trtllm\"\n                export label_extension=\"-trtllm\"\n                export docker_volume=\"/mnt/cache\"\n                export docker_devices=\"\"\n                export runs_on=\"ubuntu-latest\"\n                export platform=\"\"\n                export extra_pytest=\"\"\n                if [[ \"${GITHUB_REF}\" == refs/tags/* ]]; then\n                  export build_type=\"release\";\n                  export target=\"\";\n                else\n                  export build_type=\"dev\";\n                  export target=\"ci-runtime\";\n                fi\n                ;;\n            rocm)\n                export dockerfile=\"Dockerfile_amd\"\n                export label_extension=\"-rocm\"\n                export docker_devices=\"/dev/kfd,/dev/dri\"\n                export docker_volume=\"/mnt\"\n                # This runner was deactivated.\n                export runs_on=\"ubuntu-latest\"\n                export platform=\"\"\n                export extra_pytest=\"-k test_flash_gemma_gptq_load\"\n                export target=\"\"\n                ;;\n            intel-xpu)\n                export dockerfile=\"Dockerfile_intel\"\n                export label_extension=\"-intel-xpu\"\n                export docker_devices=\"\"\n                export docker_volume=\"/mnt/cache\"\n                export runs_on=\"ubuntu-latest\"\n                export 
platform=\"xpu\"\n                export extra_pytest=\"\"\n                export target=\"\"\n                ;;\n            intel-cpu)\n                export dockerfile=\"Dockerfile_intel\"\n                export label_extension=\"-intel-cpu\"\n                export docker_devices=\"none\"\n                export docker_volume=\"/mnt/cache\"\n                # export runs_on=\"ubuntu-latest\"\n                export runs_on=\"aws-highmemory-32-plus-priv\"\n                export platform=\"cpu\"\n                export extra_pytest=\"-k test_flash_gemma_simple\"\n                export target=\"\"\n                ;;\n            neuron)\n                export dockerfile=\"Dockerfile.neuron\"\n                export label_extension=\"-neuron\"\n                export docker_devices=\"/dev/neuron0\"\n                export docker_volume=\"/mnt/cache\"\n                export runs_on=\"aws-inf2-8xlarge\"\n                export platform=\"cpu\"\n                export extra_pytest=\"--neuron\"\n                export target=\"\"\n                ;;\n            gaudi)\n                export dockerfile=\"Dockerfile_gaudi\"\n                export label_extension=\"-gaudi\"\n                export docker_volume=\"/mnt/cache\"\n                export docker_devices=\"\"\n                export runs_on=\"itac-bm-emr-gaudi3-dell-2gaudi\"\n                export platform=\"\"\n                export extra_pytest=\"--gaudi\"\n                export target=\"\"\n          esac\n          echo $dockerfile\n          echo \"Dockerfile=${dockerfile}\"\n          echo $label_extension\n          echo $docker_devices\n          echo $runs_on\n          echo $platform\n          echo \"DOCKERFILE=${dockerfile}\" >> $GITHUB_ENV\n          echo \"LABEL_EXTENSION=${label_extension}\" >> $GITHUB_ENV\n          echo \"PLATFORM=${platform}\" >> $GITHUB_ENV\n          echo \"DOCKER_VOLUME=${docker_volume}\" >> $GITHUB_ENV\n          echo \"DOCKER_DEVICES=${docker_devices}\" >> 
$GITHUB_ENV\n          echo \"RUNS_ON=${runs_on}\" >> $GITHUB_ENV\n          echo \"EXTRA_PYTEST=${extra_pytest}\" >> $GITHUB_ENV\n          echo REGISTRY_MIRROR=$REGISTRY_MIRROR >> $GITHUB_ENV\n          echo \"TARGET=${target}\" >> $GITHUB_ENV\n          echo \"BUILD_TYPE=${build_type}\" >> $GITHUB_ENV\n      - name: Initialize Docker Buildx\n        uses: docker/setup-buildx-action@v3\n        with:\n          install: true\n          buildkitd-config: /tmp/buildkitd.toml\n      - name: Login to internal Container Registry\n        if: github.event_name != 'pull_request'\n        uses: docker/login-action@v3\n        with:\n          username: ${{ secrets.REGISTRY_USERNAME }}\n          password: ${{ secrets.REGISTRY_PASSWORD }}\n          registry: registry.internal.huggingface.tech\n      - name: Login to GitHub Container Registry\n        if: github.event_name != 'pull_request'\n        uses: docker/login-action@v3\n        with:\n          registry: ghcr.io\n          username: ${{ github.actor }}\n          password: ${{ secrets.GITHUB_TOKEN }}\n      - name: Login to Docker Hub Container Registry\n        uses: docker/login-action@v3\n        with:\n          registry: docker.io\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_PASSWORD }}\n      - name: configure aws credentials\n        id: aws-creds\n        uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502\n        with:\n          role-to-assume: ${{ secrets.AWS_ROLE_GITHUB_BUILDX_CACHE }}\n          role-duration-seconds: 18000\n          aws-region: us-east-1\n          output-credentials: true\n      # If pull request\n      - name: Extract metadata (tags, labels) for Docker\n        if: ${{ github.event_name == 'pull_request' }}\n        id: meta-pr\n        uses: docker/metadata-action@v5\n        with:\n          images: |\n            docker.io/huggingface/text-generation-inference-ci\n          tags: |\n         
   type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL_EXTENSION }}\n      # If main, release or tag\n      - name: Extract metadata (tags, labels) for Docker\n        if: ${{ github.event_name != 'pull_request' }}\n        id: meta\n        uses: docker/metadata-action@v4.3.0\n        with:\n          flavor: |\n            latest=false\n          images: |\n            registry.internal.huggingface.tech/api-inference/community/text-generation-inference\n            ghcr.io/huggingface/text-generation-inference\n          tags: |\n            type=semver,pattern={{version}}${{ env.LABEL_EXTENSION }}\n            type=semver,pattern={{major}}.{{minor}}${{ env.LABEL_EXTENSION }}\n            type=raw,value=latest${{ env.LABEL_EXTENSION }},enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }}\n            type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL_EXTENSION }}\n      - name: Build and push Docker image\n        id: build-and-push\n        uses: docker/build-push-action@v4\n        env: \n          DOCKER_BUILD_SUMMARY: false\n        with:\n          context: .\n          file: ${{ env.DOCKERFILE }}\n          push: true\n          platforms: 'linux/amd64'\n          build-args: |\n            GIT_SHA=${{ env.GITHUB_SHA }}\n            DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL_EXTENSION }}\n            PLATFORM=${{ env.PLATFORM }}\n            build_type=${{ env.BUILD_TYPE }}\n            sccache_gha_enabled=on\n          secrets: |\n            actions_results_url=${{ env.ACTIONS_RESULTS_URL }}\n            actions_runtime_token=${{ env.ACTIONS_RUNTIME_TOKEN }}\n          target: ${{ env.TARGET }}\n          tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}\n          labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}\n          cache-from: type=s3,region=us-east-1,bucket=${{ vars.AWS_S3BUCKET_GITHUB_BUILDX_CACHE 
}},name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ steps.aws-creds.outputs.aws-access-key-id }},secret_access_key=${{ steps.aws-creds.outputs.aws-secret-access-key }},session_token=${{ steps.aws-creds.outputs.aws-session-token }},mode=max\n          cache-to: type=s3,region=us-east-1,bucket=${{ vars.AWS_S3BUCKET_GITHUB_BUILDX_CACHE }},name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ steps.aws-creds.outputs.aws-access-key-id }},secret_access_key=${{ steps.aws-creds.outputs.aws-secret-access-key }},session_token=${{ steps.aws-creds.outputs.aws-session-token }},mode=max\n      - name: Final\n        id: final\n        run: |\n\n          if [ \"${{ github.event_name }}\" = \"pull_request\" ]; then\n            echo \"docker_image=docker.io/huggingface/text-generation-inference-ci:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL_EXTENSION }}\" >> \"$GITHUB_OUTPUT\"\n          else\n            echo \"docker_image=ghcr.io/huggingface/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL_EXTENSION }}\" >> \"$GITHUB_OUTPUT\"\n          fi\n          echo \"docker_devices=${{ env.DOCKER_DEVICES }}\" >> \"$GITHUB_OUTPUT\"\n          echo \"docker_volume=${{ env.DOCKER_VOLUME }}\" >> \"$GITHUB_OUTPUT\"\n          echo \"runs_on=${{ env.RUNS_ON }}\" >> \"$GITHUB_OUTPUT\"\n          echo \"label_extension=${{ env.LABEL_EXTENSION }}\" >> \"$GITHUB_OUTPUT\"\n          echo \"extra_pytest=${{ env.EXTRA_PYTEST }}\" >> \"$GITHUB_OUTPUT\"\n  precompile_neuron_models:\n    concurrency:\n      group: ${{ github.workflow }}-${{ github.job }}-${{ needs.build-and-push.outputs.label_extension }}-${{ github.head_ref || github.run_id }}\n      cancel-in-progress: true\n    needs: build-and-push\n    if: needs.build-and-push.outputs.label_extension == '-neuron'\n    runs-on:\n      group: ${{ needs.build-and-push.outputs.runs_on }}\n    env:\n      PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref 
== 'refs/heads/main' || inputs.release-tests == true) && '--release' || '--release' }}\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n      - name: Inject slug/short variables\n        uses: rlespinasse/github-slug-action@v4.4.1\n      - name: Set up Python\n        uses: actions/setup-python@v4\n        with:\n          python-version: \"3.11\"\n      - name: Install\n        run: |\n          make install-integration-tests\n      - name: Export neuron models\n        run: |\n          export DOCKER_IMAGE=${{ needs.build-and-push.outputs.docker_image }}\n          echo $DOCKER_IMAGE\n          docker pull $DOCKER_IMAGE\n          export HF_TOKEN=${{ secrets.HF_TOKEN_NEURON }}\n          python integration-tests/fixtures/neuron/export_models.py\n  integration_tests:\n    concurrency:\n      group: ${{ github.workflow }}-${{ github.job }}-${{ needs.build-and-push.outputs.label_extension }}-${{ github.head_ref || github.run_id }}\n      cancel-in-progress: true\n    needs: [precompile_neuron_models, build-and-push]\n    if: ${{ always() && !contains(needs.*.result, 'failure') && !contains(needs.*.result, 'cancelled') && needs.build-and-push.outputs.runs_on != 'ubuntu-latest' }}\n    runs-on:\n      group: ${{ needs.build-and-push.outputs.runs_on }}\n    env:\n      PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '--release' }}\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n      - name: Inject slug/short variables\n        uses: rlespinasse/github-slug-action@v4.4.1\n      - name: Set up Python\n        uses: actions/setup-python@v4\n        with:\n          python-version: \"3.11\"\n      - name: Install\n        run: |\n          make install-integration-tests\n      - name: Run tests\n        run: |\n          export DOCKER_VOLUME=${{ needs.build-and-push.outputs.docker_volume }}\n          export 
DOCKER_IMAGE=${{ needs.build-and-push.outputs.docker_image }}\n          export DOCKER_DEVICES=${{ needs.build-and-push.outputs.docker_devices }}\n          export EXTRA_PYTEST=\"${{ needs.build-and-push.outputs.extra_pytest }}\"\n          export HF_TOKEN=${{ secrets.HF_TOKEN }}\n          echo $DOCKER_IMAGE\n          docker pull $DOCKER_IMAGE\n          pytest -s -vv integration-tests ${PYTEST_FLAGS} ${EXTRA_PYTEST}\n\n  backend_trtllm_cxx_tests:\n    needs: build-and-push\n    if: needs.build-and-push.outputs.label_extension == '-trtllm'\n    concurrency:\n      group: ${{ github.workflow }}-${{ github.job }}-trtllm-${{ github.head_ref || github.run_id }}\n      cancel-in-progress: true\n    runs-on:\n      group: aws-g6-12xl-plus-priv-cache\n    container:\n      image: ${{ needs.build-and-push.outputs.docker_image }}\n      credentials:\n        username: ${{ secrets.DOCKERHUB_USERNAME }}\n        password: ${{ secrets.DOCKERHUB_PASSWORD }}\n      options: --gpus all --shm-size=8g\n\n    steps:\n      - name: Run C++/CUDA tests\n        if: ${{ env.LABEL_EXTENSION == 'ci-runtime' }}\n        run: /usr/local/tgi/bin/tgi_trtllm_backend_tests\n"
  },
  {
    "path": ".github/workflows/build_documentation.yaml",
    "content": "name: Build documentation\n\non:\n  push:\n    paths:\n      - \"docs/source/**\"\n    branches:\n      - main\n      - doc-builder*\n      - v*-release\n\njobs:\n   build:\n    uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main\n    with:\n      commit_sha: ${{ github.sha }}\n      package: text-generation-inference\n      additional_args: --not_python_module\n    secrets:\n      hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}\n"
  },
  {
    "path": ".github/workflows/build_pr_documentation.yaml",
    "content": "name: Build PR Documentation\n\non:\n  pull_request:\n    paths:\n      - \"docs/source/**\"\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}\n  cancel-in-progress: true\n\njobs:\n  build:\n    uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main\n    with:\n      commit_sha: ${{ github.event.pull_request.head.sha }}\n      pr_number: ${{ github.event.number }}\n      package: text-generation-inference\n      additional_args: --not_python_module\n"
  },
  {
    "path": ".github/workflows/ci_build.yaml",
    "content": "name: CI build\n\non:\n  push:\n    branches:\n      - 'main'\n    tags:\n      - 'v*'\n  pull_request:\n    paths:\n      - \".github/workflows/build.yaml\"\n      - \"integration-tests/**\"\n      - \"backends/**\"\n      - \"server/**\"\n      - \"proto/**\"\n      - \"router/**\"\n      - \"launcher/**\"\n      - \"Cargo.lock\"\n      - \"rust-toolchain.toml\"\n      - \"Dockerfile\"\n      - \"Dockerfile_amd\"\n      - \"Dockerfile_intel\"\n      - \"Dockerfile.neuron\"\n      - \"Dockerfile_gaudi\"\n    branches:\n      - \"main\"\n  workflow_dispatch:\n    inputs:\n      release-tests:\n        description: \"Run release integration tests\"\n        required: true\n        default: false\n        type: boolean\n\njobs:\n  build:\n    strategy:\n      # super important if you want to see all results, even if one fails\n      # fail-fast is true by default\n      fail-fast: false\n      matrix:\n        hardware: [\"cuda\", \"cuda-trtllm\", \"rocm\", \"intel-xpu\", \"intel-cpu\", \"neuron\", \"gaudi\"]\n    uses: ./.github/workflows/build.yaml # calls the one above ^\n    permissions:\n      contents: write\n      packages: write\n      id-token: write\n    with:\n      hardware: ${{ matrix.hardware }}\n      # https://github.com/actions/runner/issues/2206\n      release-tests: ${{ inputs.release-tests == true }}\n    secrets: inherit\n"
  },
  {
    "path": ".github/workflows/client-tests.yaml",
    "content": "name: Python Client Tests\n\non:\n  pull_request:\n    paths:\n      - \".github/workflows/client-tests.yaml\"\n      - \"clients/python/**\"\n\njobs:\n  run_tests:\n    runs-on: ubuntu-latest\n\n    steps:\n      - uses: actions/checkout@v2\n      - name: Set up Python\n        uses: actions/setup-python@v1\n        with:\n          python-version: 3.9\n      - name: Install\n        run: |\n          cd clients/python && pip install .\n      - name: Run tests\n        run: |\n          pip install pytest pytest-asyncio\n          export HF_TOKEN=${{ secrets.HF_TOKEN }}\n          make python-client-tests\n"
  },
  {
    "path": ".github/workflows/codeql.yml",
    "content": "---\nname: CodeQL Security Analysis For Github Actions\n\non:\n  push:\n    branches: [\"main\"]\n  workflow_dispatch:\n  # pull_request:\n\njobs:\n  codeql:\n    name: CodeQL Analysis\n    uses: huggingface/security-workflows/.github/workflows/codeql-reusable.yml@v1.2.0\n    permissions:\n      security-events: write\n      packages: read\n      actions: read\n      contents: read\n    with:\n      languages: '[\"actions\"]'\n      queries: 'security-extended,security-and-quality'\n      runner: 'ubuntu-latest' #optional if need custom runner\n      use-runner-group: false #optional\n\n      # if need to use runner group:\n      # runner: 'cpu-low'\n      # use-runner-group: true\n"
  },
  {
    "path": ".github/workflows/integration_tests.yaml",
    "content": "name: Integration tests\n\non:\n  workflow_call:\n    inputs:\n      docker_image:\n        type: string\n        description: Hardware\n        required: true\n      docker_devices:\n        type: string\n        description: Hardware\n      runs_on:\n        type: string\n        required: true\n        description: Hardware to run integration tests\njobs:\n  integration_tests:\n    concurrency:\n      group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}\n      cancel-in-progress: true\n    runs-on: ${{ inputs.runs_on }}\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n      - name: Inject slug/short variables\n        uses: rlespinasse/github-slug-action@v4.4.1\n      - name: Set up Python\n        uses: actions/setup-python@v4\n        with:\n          python-version: 3.9\n      - name: Install\n        run: |\n          make install-integration-tests\n      - name: Run tests\n        run: |\n          export DOCKER_VOLUME=/mnt/cache\n          export DOCKER_IMAGE=${{ inputs.docker_image }}\n          export DOCKER_DEVICES=${{ inputs.docker_devices }}\n          export HF_TOKEN=${{ secrets.HF_TOKEN }}\n          pytest -s -vv integration-tests\n"
  },
  {
    "path": ".github/workflows/load_test.yaml",
    "content": "name: Nightly load test\n\non:\n  schedule:\n    - cron: '0 0 * * 1-5'\n  workflow_call:\n  workflow_dispatch:\n\n  pull_request:\n    paths:\n      - \".github/workflows/load_test.yaml\"\n\nenv:\n  AWS_DEFAULT_REGION: us-east-1\n  AWS_ACCESS_KEY_ID: ${{ secrets.S3_AWS_ACCESS_KEY_ID }}\n  AWS_SECRET_ACCESS_KEY: ${{ secrets.S3_AWS_SECRET_ACCESS_KEY }}\n\njobs:\n  load-tests:\n    concurrency:\n      group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}\n      cancel-in-progress: true\n    runs-on:\n      group: aws-g6-12xl-plus-priv-cache\n    env:\n      DOCKER_VOLUME: /cache\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v3\n\n      - name: Install Python 3.11\n        uses: actions/setup-python@v2\n        with:\n          python-version: 3.11\n\n      - name: Install poetry\n        run: |\n          curl -sSL https://install.python-poetry.org | python3 -\n          export PATH=\"$HOME/.local/bin:$PATH\"\n          poetry --version\n\n      - name: Run bench test\n        run: |\n          export PATH=\"$HOME/.local/bin:$PATH\"\n          cd load_tests\n          poetry install\n          poetry run python benchmarks.py --sha ${{ github.sha }} --results-file \"s3://text-generation-inference-ci/benchmarks/ci/${{ github.sha }}.parquet\"\n        shell: bash\n        env:\n          HF_TOKEN: ${{ secrets.HF_TOKEN_BENCHMARK }}\n"
  },
  {
    "path": ".github/workflows/nix_build.yaml",
    "content": "name: \"Nix Build Docker image\"\non:\n  pull_request:\n  push:\n    branches:\n      - 'main'\n    tags:\n      - 'v*'\nconcurrency:\n  group: nix-image-${{ github.workflow }}-${{ github.head_ref || github.run_id }}\n  cancel-in-progress: true\n\njobs:\n  build_nix_image:\n    runs-on:\n      group: aws-highmemory-32-plus-priv\n    steps:\n    - uses: actions/checkout@v4\n    - uses: cachix/install-nix-action@v27\n      with:\n        nix_path: nixpkgs=channel:nixos-unstable\n    - uses: cachix/cachix-action@v14\n      with:\n        name: huggingface\n        # If you chose signing key for write access\n        # authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}'\n      env:\n        USER: github_runner\n    - name: Build\n      run: nix build .#dockerImage\n    - name: Initialize Docker Buildx\n      uses: docker/setup-buildx-action@v3\n      with:\n        install: true\n        buildkitd-config: /tmp/buildkitd.toml\n    - name: Inject slug/short variables\n      uses: rlespinasse/github-slug-action@v4.4.1\n    - name: Login to internal Container Registry\n      # if: github.event_name != 'pull_request'\n      uses: docker/login-action@v3\n      with:\n        username: ${{ secrets.REGISTRY_USERNAME }}\n        password: ${{ secrets.REGISTRY_PASSWORD }}\n        registry: registry.internal.huggingface.tech\n    - name: Push to docker\n      run: |\n        if [ \"${{ github.event_name }}\" = \"pull_request\" ]; then\n          export TAG=nix-sha-${{ env.GITHUB_SHA_SHORT }}\n        else\n          export TAG=${{ github.ref_name }}-nix\n        fi\n        export IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:$TAG\n        nix-shell -p skopeo --command \"skopeo --insecure-policy copy docker-archive:$(readlink -f ./result) docker://$IMAGE --dest-compress-format zstd\"\n"
  },
  {
    "path": ".github/workflows/nix_cache.yaml",
    "content": "name: \"Cache devshells\"\non:\n  pull_request:\n    paths:\n      - \"flake.nix\"\n      - \"flake.lock\"\n      - \"nix/**\"\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}\n  cancel-in-progress: true\n\njobs:\n  tests:\n    runs-on:\n      group: aws-highmemory-32-plus-priv\n    steps:\n      - uses: actions/checkout@v4\n      - uses: cachix/install-nix-action@v27\n        with:\n          nix_path: nixpkgs=channel:nixos-unstable\n      - uses: cachix/cachix-action@v14\n        with:\n          name: huggingface\n          # If you chose signing key for write access\n          #authToken: \"${{ secrets.CACHIX_AUTH_TOKEN }}\"\n        env:\n          USER: github_runner\n      - name: Build impure devshell\n        run: nix build .\\#devShells.x86_64-linux.impure\n      - name: Build impure devshell (CUDA dev)\n        run: nix build .\\#devShells.x86_64-linux.impureWithCuda\n      # Pure shell dependencies are covered by Nix tests.\n      # - name: Build pure devshell\n      #   run: nix build .\\#devShells.x86_64-linux.pure\n"
  },
  {
    "path": ".github/workflows/nix_tests.yaml",
    "content": "name: \"Nix Tests\"\non:\n  pull_request:\n    paths:\n      - \".github/workflows/nix_tests.yaml\"\n      - \"server/**\"\n      - \"proto/**\"\n      - \"router/**\"\n      - \"launcher/**\"\n      - \"backends/**\"\n      - \"Cargo.lock\"\n      - \"rust-toolchain.toml\"\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}\n  cancel-in-progress: true\n\njobs:\n  tests:\n    runs-on:\n      group: aws-highmemory-32-plus-priv\n    steps:\n    - uses: actions/checkout@v4\n    - uses: cachix/install-nix-action@v27\n      with:\n        nix_path: nixpkgs=channel:nixos-unstable\n    - uses: cachix/cachix-action@v14\n      with:\n        name: huggingface\n        # If you chose signing key for write access\n        #authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}'\n      env:\n        USER: github_runner\n    - name: Nix info\n      run: nix-shell -p nix-info --run \"nix-info -m\"\n    - name: Build\n      run: nix develop .#test --command echo \"Ok\"\n    - name: Pre-commit tests.\n      run: nix develop .#test --command pre-commit run --all-files\n    - name: Python tests.\n      run: nix develop .#test --command python -m pytest server/tests/\n      env:\n        HF_TOKEN: ${{ secrets.HF_TOKEN }}\n    - name: Rust tests.\n      run: nix develop .#test --command cargo test\n"
  },
  {
    "path": ".github/workflows/stale.yaml",
    "content": "name: 'Close stale issues and PRs'\non:\n  schedule:\n    - cron: '30 1 * * *'\n\njobs:\n  stale:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/stale@v8\n        with:\n          stale-issue-message: 'This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.'\n          days-before-stale: 30\n          days-before-close: 5\n"
  },
  {
    "path": ".github/workflows/tests.yaml",
    "content": "name: Server Tests\n\non:\n  pull_request:\n    paths:\n      - \".github/workflows/tests.yaml\"\n      - \"server/**\"\n      - \"proto/**\"\n      - \"router/**\"\n      - \"launcher/**\"\n      - \"backends/**\"\n      - \"Cargo.lock\"\n      - \"rust-toolchain.toml\"\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}\n  cancel-in-progress: true\n\njobs:\n  run_tests:\n    runs-on:\n      group: aws-highmemory-32-plus-priv\n    steps:\n      - uses: actions/checkout@v4\n      - name: Set up Python\n        uses: actions/setup-python@v4\n        id: python\n        with:\n          python-version: 3.11\n      - uses: dtolnay/rust-toolchain@1.85.0\n        with:\n          components: rustfmt, clippy\n      - name: Install Protoc\n        uses: arduino/setup-protoc@v1\n      - name: Clean unused files\n        run: |\n          sudo rm -rf /usr/local/lib/android # will release about 10 GB if you don't need Android\n          sudo rm -rf /usr/share/dotnet # will release about 20GB if you don't need .NET\n      - name: Install\n        run: |\n          sudo apt update\n          sudo apt install python3.11-dev -y\n          pip install -U pip uv\n          uv venv\n          source ./.venv/bin/activate\n          make install-cpu\n      - name: Download locked kernels\n        run: |\n          source ./.venv/bin/activate\n          kernels download server\n      - name: Run server tests\n        run: |\n          source ./.venv/bin/activate\n          uv pip install pytest\n          export HF_TOKEN=${{ secrets.HF_TOKEN }}\n          pytest -s -vv server/tests\n      - name: Pre-commit checks\n        run: |\n          pip install pre-commit\n          pre-commit install\n          pre-commit run --all-files\n      - name: Run Rust tests\n        run: |\n          cargo test\n      - name: Run Rust tests with google feature\n        run: |\n          cargo test --features google\n"
  },
  {
    "path": ".github/workflows/trufflehog.yaml",
    "content": "on:\n  push:\n\nname: Secret Leaks\n\npermissions:\n  contents: read\n\njobs:\n  trufflehog:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          fetch-depth: 0\n      - name: Secret Scanning\n        uses: trufflesecurity/trufflehog@853e1e8d249fd1e29d0fcc7280d29b03df3d643d\n        with:\n          # exclude buggy postgres detector that is causing false positives and not relevant to our codebase\n          extra_args: --results=verified,unknown --exclude-detectors=postgres\n"
  },
  {
    "path": ".github/workflows/upload_pr_documentation.yaml",
    "content": "name: Upload PR Documentation\n\non:\n  workflow_run:\n    workflows: [\"Build PR Documentation\"]\n    types:\n      - completed\n\njobs:\n  build:\n    uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main\n    with:\n      package_name: text-generation-inference\n    secrets:\n      hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}\n      comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}\n"
  },
  {
    "path": ".gitignore",
    "content": ".idea\ntarget\nrouter/tokenizer.json\n*__pycache__*\n\nbackends/v2/src/client/pb\nbackends/v3/src/client/pb\nbackends/client/src/v2/pb\nbackends/client/src/v3/pb\n\n# ROCm auto-generated files\n*.hip\nserver/exllamav2\nserver/exllama_kernels/exllama_kernels/hip/\nserver/exllama_kernels/exllama_kernels/hip_func/\n*_hip.cuh\nserver/exllama_kernels/exllama_kernels/hip_buffers.cuh\nserver/exllama_kernels/exllama_kernels/exllama_ext_hip.cpp\n\ndata/\nload_tests/*.json\nserver/fbgemmm\n\n.direnv/\n.venv/\n\n# Gaudi auto-generated files\nhl-smi_log*.txt\n.graph_dumps\nout\nhqt_output\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "repos:\n-   repo: https://github.com/pre-commit/pre-commit-hooks\n    rev: v4.5.0\n    hooks:\n    -   id: check-yaml\n    -   id: end-of-file-fixer\n        exclude: crate-hashes.json\n    -   id: trailing-whitespace\n        exclude: docs/source/reference/launcher.md\n-   repo: https://github.com/psf/black\n    rev: 24.2.0\n    hooks:\n    -   id: black\n-   repo: https://github.com/doublify/pre-commit-rust\n    rev: v1.0\n    hooks:\n    -   id: cargo-check\n    -   id: fmt\n    -   id: clippy\n-   repo: https://github.com/astral-sh/ruff-pre-commit\n    rev: v0.3.0\n    hooks:\n    -   id: ruff\n        args: [--fix, --exit-non-zero-on-fix]\n"
  },
  {
    "path": ".redocly.lint-ignore.yaml",
    "content": "# This file instructs Redocly's linter to ignore the rules contained for specific parts of your API.\n# See https://redoc.ly/docs/cli/ for more information.\ndocs/openapi.json:\n  no-empty-servers:\n    - '#/openapi'\n  spec:\n    - >-\n      #/components/schemas/GenerateParameters/properties/best_of/exclusiveMinimum\n    - >-\n      #/components/schemas/GenerateParameters/properties/frequency_penalty/exclusiveMinimum\n    - '#/components/schemas/GenerateParameters/properties/grammar/nullable'\n    - >-\n      #/components/schemas/GenerateParameters/properties/repetition_penalty/exclusiveMinimum\n    - '#/components/schemas/GenerateParameters/properties/seed/exclusiveMinimum'\n    - >-\n      #/components/schemas/GenerateParameters/properties/temperature/exclusiveMinimum\n    - '#/components/schemas/GenerateParameters/properties/top_k/exclusiveMinimum'\n    - >-\n      #/components/schemas/GenerateParameters/properties/top_n_tokens/exclusiveMinimum\n    - '#/components/schemas/GenerateParameters/properties/top_p/exclusiveMinimum'\n    - >-\n      #/components/schemas/GenerateParameters/properties/typical_p/exclusiveMinimum\n    - '#/components/schemas/GenerateResponse/properties/details/nullable'\n    - '#/components/schemas/StreamResponse/properties/details/nullable'\n    - '#/components/schemas/ChatRequest/properties/response_format/nullable'\n    - '#/components/schemas/ChatRequest/properties/stream_options/nullable'\n    - '#/components/schemas/ChatRequest/properties/tool_choice/nullable'\n    - '#/components/schemas/ToolChoice/nullable'\n    - '#/components/schemas/ChatCompletionComplete/properties/logprobs/nullable'\n    - '#/components/schemas/ChatCompletionChunk/properties/usage/nullable'\n    - '#/components/schemas/ChatCompletionChoice/properties/logprobs/nullable'\n  no-invalid-media-type-examples:\n    - '#/paths/~1/post/responses/422/content/application~1json/example'\n    - 
'#/paths/~1/post/responses/424/content/application~1json/example'\n    - '#/paths/~1/post/responses/429/content/application~1json/example'\n    - '#/paths/~1/post/responses/500/content/application~1json/example'\n    - '#/paths/~1generate/post/responses/422/content/application~1json/example'\n    - '#/paths/~1generate/post/responses/424/content/application~1json/example'\n    - '#/paths/~1generate/post/responses/429/content/application~1json/example'\n    - '#/paths/~1generate/post/responses/500/content/application~1json/example'\n    - >-\n      #/paths/~1generate_stream/post/responses/422/content/text~1event-stream/example\n    - >-\n      #/paths/~1generate_stream/post/responses/424/content/text~1event-stream/example\n    - >-\n      #/paths/~1generate_stream/post/responses/429/content/text~1event-stream/example\n    - >-\n      #/paths/~1generate_stream/post/responses/500/content/text~1event-stream/example\n    - '#/paths/~1tokenize/post/responses/404/content/application~1json/example'\n    - >-\n      #/paths/~1v1~1chat~1completions/post/responses/422/content/application~1json/example\n    - >-\n      #/paths/~1v1~1chat~1completions/post/responses/424/content/application~1json/example\n    - >-\n      #/paths/~1v1~1chat~1completions/post/responses/429/content/application~1json/example\n    - >-\n      #/paths/~1v1~1chat~1completions/post/responses/500/content/application~1json/example\n    - >-\n      #/paths/~1v1~1completions/post/responses/422/content/application~1json/example\n    - >-\n      #/paths/~1v1~1completions/post/responses/424/content/application~1json/example\n    - >-\n      #/paths/~1v1~1completions/post/responses/429/content/application~1json/example\n    - >-\n      #/paths/~1v1~1completions/post/responses/500/content/application~1json/example\n  operation-4xx-response:\n    - '#/paths/~1health/get/responses'\n    - '#/paths/~1info/get/responses'\n    - '#/paths/~1metrics/get/responses'\n  no-unused-components:\n    - 
'#/components/schemas/Completion'\n  security-defined:\n    - '#/paths/~1/post'\n    - '#/paths/~1generate/post'\n    - '#/paths/~1generate_stream/post'\n    - '#/paths/~1health/get'\n    - '#/paths/~1info/get'\n    - '#/paths/~1metrics/get'\n    - '#/paths/~1tokenize/post'\n    - '#/paths/~1v1~1chat~1completions/post'\n    - '#/paths/~1v1~1completions/post'\n    - '#/paths/~1v1~1models/get'\n"
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "content": "\n# Contributor Covenant Code of Conduct\n\n## Our Pledge\n\nWe as members, contributors, and leaders pledge to make participation in our\ncommunity a harassment-free experience for everyone, regardless of age, body\nsize, visible or invisible disability, ethnicity, sex characteristics, gender\nidentity and expression, level of experience, education, socio-economic status,\nnationality, personal appearance, race, caste, color, religion, or sexual\nidentity and orientation.\n\nWe pledge to act and interact in ways that contribute to an open, welcoming,\ndiverse, inclusive, and healthy community.\n\n## Our Standards\n\nExamples of behavior that contributes to a positive environment for our\ncommunity include:\n\n* Demonstrating empathy and kindness toward other people\n* Being respectful of differing opinions, viewpoints, and experiences\n* Giving and gracefully accepting constructive feedback\n* Accepting responsibility and apologizing to those affected by our mistakes,\n  and learning from the experience\n* Focusing on what is best not just for us as individuals, but for the overall\n  community\n\nExamples of unacceptable behavior include:\n\n* The use of sexualized language or imagery, and sexual attention or advances of\n  any kind\n* Trolling, insulting or derogatory comments, and personal or political attacks\n* Public or private harassment\n* Publishing others' private information, such as a physical or email address,\n  without their explicit permission\n* Other conduct which could reasonably be considered inappropriate in a\n  professional setting\n\n## Enforcement Responsibilities\n\nCommunity leaders are responsible for clarifying and enforcing our standards of\nacceptable behavior and will take appropriate and fair corrective action in\nresponse to any behavior that they deem inappropriate, threatening, offensive,\nor harmful.\n\nCommunity leaders have the right and responsibility to remove, edit, or reject\ncomments, commits, code, wiki 
edits, issues, and other contributions that are\nnot aligned to this Code of Conduct, and will communicate reasons for moderation\ndecisions when appropriate.\n\n## Scope\n\nThis Code of Conduct applies within all community spaces, and also applies when\nan individual is officially representing the community in public spaces.\nExamples of representing our community include using an official e-mail address,\nposting via an official social media account, or acting as an appointed\nrepresentative at an online or offline event.\n\n## Enforcement\n\nInstances of abusive, harassing, or otherwise unacceptable behavior may be\nreported to the community leaders responsible for enforcement at\nfeedback@huggingface.co.\nAll complaints will be reviewed and investigated promptly and fairly.\n\nAll community leaders are obligated to respect the privacy and security of the\nreporter of any incident.\n\n## Enforcement Guidelines\n\nCommunity leaders will follow these Community Impact Guidelines in determining\nthe consequences for any action they deem in violation of this Code of Conduct:\n\n### 1. Correction\n\n**Community Impact**: Use of inappropriate language or other behavior deemed\nunprofessional or unwelcome in the community.\n\n**Consequence**: A private, written warning from community leaders, providing\nclarity around the nature of the violation and an explanation of why the\nbehavior was inappropriate. A public apology may be requested.\n\n### 2. Warning\n\n**Community Impact**: A violation through a single incident or series of\nactions.\n\n**Consequence**: A warning with consequences for continued behavior. No\ninteraction with the people involved, including unsolicited interaction with\nthose enforcing the Code of Conduct, for a specified period of time. This\nincludes avoiding interactions in community spaces as well as external channels\nlike social media. Violating these terms may lead to a temporary or permanent\nban.\n\n### 3. 
Temporary Ban\n\n**Community Impact**: A serious violation of community standards, including\nsustained inappropriate behavior.\n\n**Consequence**: A temporary ban from any sort of interaction or public\ncommunication with the community for a specified period of time. No public or\nprivate interaction with the people involved, including unsolicited interaction\nwith those enforcing the Code of Conduct, is allowed during this period.\nViolating these terms may lead to a permanent ban.\n\n### 4. Permanent Ban\n\n**Community Impact**: Demonstrating a pattern of violation of community\nstandards, including sustained inappropriate behavior, harassment of an\nindividual, or aggression toward or disparagement of classes of individuals.\n\n**Consequence**: A permanent ban from any sort of public interaction within the\ncommunity.\n\n## Attribution\n\nThis Code of Conduct is adapted from the [Contributor Covenant][homepage],\nversion 2.1, available at\n[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].\n\nCommunity Impact Guidelines were inspired by\n[Mozilla's code of conduct enforcement ladder][Mozilla CoC].\n\nFor answers to common questions about this code of conduct, see the FAQ at\n[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at\n[https://www.contributor-covenant.org/translations][translations].\n\n[homepage]: https://www.contributor-covenant.org\n[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html\n[Mozilla CoC]: https://github.com/mozilla/diversity\n[FAQ]: https://www.contributor-covenant.org/faq\n[translations]: https://www.contributor-covenant.org/translations\n"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "<!---\nCopyright 2024 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n-->\n\n# Contribute to text-generation-inference\n\nEveryone is welcome to contribute, and we value everybody's contribution. Code\ncontributions are not the only way to help the community. Answering questions, helping\nothers, and improving the documentation are also immensely valuable.\n\nIt also helps us if you spread the word! Reference the library in blog posts\nabout the awesome projects it made possible, shout out on Twitter every time it has\nhelped you, or simply ⭐️ the repository to say thank you.\n\nHowever you choose to contribute, please be mindful and respect our\n[code of conduct](https://github.com/huggingface/text-generation-inference/blob/main/CODE_OF_CONDUCT.md).\n\n**This guide was heavily inspired by the awesome [scikit-learn guide to contributing](https://github.com/scikit-learn/scikit-learn/blob/main/CONTRIBUTING.md).**\n\n## Ways to contribute\n\nThere are several ways you can contribute to text-generation-inference.\n\n* Fix outstanding issues with the existing code.\n* Submit issues related to bugs or desired new features.\n* Contribute to the examples or to the documentation.\n\n> All contributions are equally valuable to the community. 
🥰\n\n## Fixing outstanding issues\n\nIf you notice an issue with the existing code and have a fix in mind, feel free to [start contributing](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request) and open\na Pull Request!\n\n## Submitting a bug-related issue or feature request\n\nDo your best to follow these guidelines when submitting a bug-related issue or a feature\nrequest. It will make it easier for us to come back to you quickly and with good\nfeedback.\n\n### Did you find a bug?\n\nThe text-generation-inference library is robust and reliable thanks to users who report the problems they encounter.\n\nBefore you report an issue, we would really appreciate it if you could **make sure the bug was not\nalready reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the\nlibrary itself, and not your code.\n\nOnce you've confirmed the bug hasn't already been reported, please include the following information in your issue so\nwe can quickly resolve it:\n\n* Your **OS type and version**, as well as your environment versions (versions of rust, python, and dependencies).\n* A short, self-contained, code snippet that allows us to reproduce the bug.\n* The *full* traceback if an exception is raised.\n* Attach any other additional information, like screenshots, you think may help.\n\nTo get the OS and software versions automatically, you can re-run the launcher with the `--env` flag:\n\n```bash\ntext-generation-launcher --env\n```\n\nThis will precede the launch of the model with the information relative to your environment. We recommend pasting\nthat in your issue report.\n\n### Do you want a new feature?\n\nIf there is a new feature you'd like to see in text-generation-inference, please open an issue and describe:\n\n1. What is the *motivation* behind this feature? Is it related to a problem or frustration with the library? 
Is it\n   a feature related to something you need for a project? Is it something you worked on and think it could benefit\n   the community?\n\n   Whatever it is, we'd love to hear about it!\n\n2. Describe your requested feature in as much detail as possible. The more you can tell us about it, the better\n   we'll be able to help you.\n3. Provide a *code snippet* that demonstrates the feature's usage.\n4. If the feature is related to a paper, please include a link.\n\nIf your issue is well written we're already 80% of the way there by the time you create it.\n\nWe have added [templates](https://github.com/huggingface/text-generation-inference/tree/main/.github/ISSUE_TEMPLATE)\nto help you get started with your issue.\n\n## Do you want to implement a new model?\n\nNew models are constantly released and if you want to implement a new model, please provide the following information:\n\n* A short description of the model and a link to the paper.\n* Link to the implementation if it is open-sourced.\n* Link to the model weights if they are available.\n\nIf you are willing to contribute the model yourself, let us know so we can help you add it to text-generation-inference!\n\n## Do you want to add documentation?\n\nWe're always looking for improvements to the documentation that make it more clear and accurate. Please let us know\nhow the documentation can be improved such as typos and any content that is missing, unclear or inaccurate. We'll be\nhappy to make the changes or help you make a contribution if you're interested!\n\n## I want to become a maintainer of the project. How do I get there?\n\nTGI is a project led and managed by Hugging Face as it powers our internal services. However, we are happy to have\nmotivated individuals from other organizations join us as maintainers with the goal of making TGI the best inference\nservice.\n\nIf you are such an individual (or organization), please reach out to us and let's collaborate.\n"
  },
  {
    "path": "Cargo.toml",
    "content": "[workspace]\nmembers = [\n    \"benchmark\",\n    \"backends/v2\",\n    \"backends/v3\",\n    \"backends/grpc-metadata\",\n    \"backends/trtllm\",\n    \"backends/llamacpp\",\n    \"launcher\",\n    \"router\"\n]\ndefault-members = [\n    \"benchmark\",\n    \"backends/v2\",\n    \"backends/v3\",\n    \"backends/grpc-metadata\",\n    # \"backends/trtllm\",\n    \"launcher\",\n    \"router\"\n]\nresolver = \"2\"\n\n[workspace.package]\nversion = \"3.3.6-dev0\"\nedition = \"2021\"\nauthors = [\"Olivier Dehaene\"]\nhomepage = \"https://github.com/huggingface/text-generation-inference\"\n\n[workspace.dependencies]\nbase64 = \"0.22.0\"\ntokenizers = { version = \"0.20.0\", features = [\"http\"] }\nhf-hub = { version = \"0.4.2\", features = [\"tokio\"] }\nmetrics = { version = \"0.23.0\" }\nmetrics-exporter-prometheus = { version = \"0.15.1\", features = [] }\nminijinja = { version = \"2.2.0\", features = [\"json\"] }\nminijinja-contrib = { version = \"2.0.2\", features = [\"pycompat\"] }\npyo3 = { version = \"0.22.2\", features = [\"auto-initialize\"] }\n\n[profile.release]\nincremental = true\n\n[profile.release-binary]\ninherits = \"release\"\ndebug = 1\nincremental = true\npanic = \"abort\"\n\n[profile.release-opt]\ninherits = \"release\"\ndebug = 0\nincremental = false\nlto = \"fat\"\nopt-level = 3\ncodegen-units = 1\n"
  },
  {
    "path": "Dockerfile",
    "content": "# Rust builder\nFROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef\nWORKDIR /usr/src\n\nARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse\n\nFROM chef AS planner\nCOPY Cargo.lock Cargo.lock\nCOPY Cargo.toml Cargo.toml\nCOPY rust-toolchain.toml rust-toolchain.toml\nCOPY proto proto\nCOPY benchmark benchmark\nCOPY router router\nCOPY backends backends\nCOPY launcher launcher\n\nRUN cargo chef prepare --recipe-path recipe.json\n\nFROM chef AS builder\n\nRUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \\\n    python3.11-dev\nRUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \\\n    curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \\\n    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \\\n    unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \\\n    rm -f $PROTOC_ZIP\n\nCOPY --from=planner /usr/src/recipe.json recipe.json\nRUN cargo chef cook --profile release-opt --recipe-path recipe.json\n\nARG GIT_SHA\nARG DOCKER_LABEL\n\nCOPY Cargo.lock Cargo.lock\nCOPY Cargo.toml Cargo.toml\nCOPY rust-toolchain.toml rust-toolchain.toml\nCOPY proto proto\nCOPY benchmark benchmark\nCOPY router router\nCOPY backends backends\nCOPY launcher launcher\nRUN cargo build --profile release-opt --frozen\n\n# Python builder\n# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile\nFROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS pytorch-install\nWORKDIR /usr/src/\n\n# NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. 
Context: https://github.com/huggingface/text-generation-inference/pull/2099\nARG PYTORCH_VERSION=2.7\nARG PYTHON_VERSION=3.11\n\n# Keep in sync with `server/pyproject.toml`\n# Automatically set by buildx\nARG TARGETPLATFORM\n\nRUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \\\n        build-essential \\\n        ca-certificates \\\n        ccache \\\n        curl \\\n        git && \\\n        rm -rf /var/lib/apt/lists/*\nCOPY --from=ghcr.io/astral-sh/uv:0.5.31 /uv /uvx /bin/\nENV PATH=\"$PATH:/root/.local/bin\"\nRUN uv python install ${PYTHON_VERSION}\nRUN uv venv --python ${PYTHON_VERSION} && uv pip install torch==${PYTORCH_VERSION} torchvision pip setuptools packaging\nENV VIRTUAL_ENV=/usr/src/.venv/\nENV PATH=\"$PATH:/usr/src/.venv/bin/\"\n\n# CUDA kernels builder image\nFROM pytorch-install AS kernel-builder\n\nARG MAX_JOBS=8\nENV TORCH_CUDA_ARCH_LIST=\"8.0;8.6;9.0+PTX\"\n\nRUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \\\n        ninja-build cmake \\\n        && rm -rf /var/lib/apt/lists/*\n\n# Build Flash Attention CUDA kernels\nFROM kernel-builder AS flash-att-builder\n\nWORKDIR /usr/src\n\nCOPY server/Makefile-flash-att Makefile\n\n# Build specific version of flash attention\nRUN . .venv/bin/activate && make build-flash-attention\n\n# Build Flash Attention v2 CUDA kernels\nFROM kernel-builder AS flash-att-v2-builder\n\nWORKDIR /usr/src\n\nCOPY server/Makefile-flash-att-v2 Makefile\n\n# Build specific version of flash attention v2\nRUN . .venv/bin/activate && make build-flash-attention-v2-cuda\n\n# Build Transformers exllama kernels\nFROM kernel-builder AS exllama-kernels-builder\nWORKDIR /usr/src\nCOPY server/exllama_kernels/ .\n\nRUN . 
.venv/bin/activate && python setup.py build\n\n# Build exllamav2 kernels\nFROM kernel-builder AS exllamav2-kernels-builder\nWORKDIR /usr/src\nCOPY server/Makefile-exllamav2/ Makefile\n\n# Build exllamav2\nRUN . .venv/bin/activate && make build-exllamav2\n\n# Build awq kernels\nFROM kernel-builder AS awq-kernels-builder\nWORKDIR /usr/src\nCOPY server/Makefile-awq Makefile\n# Build awq\nRUN . .venv/bin/activate && make build-awq\n\n# Build custom CUDA kernels\nFROM kernel-builder AS custom-kernels-builder\nWORKDIR /usr/src\nCOPY server/custom_kernels/ .\n# Build the custom kernels\nRUN . .venv/bin/activate && python setup.py build\n\n# Build mamba kernels\nFROM kernel-builder AS mamba-builder\nWORKDIR /usr/src\nCOPY server/Makefile-selective-scan Makefile\nRUN . .venv/bin/activate && make build-all\n\n# Build flashinfer\nFROM kernel-builder AS flashinfer-builder\nWORKDIR /usr/src\nCOPY server/Makefile-flashinfer Makefile\nRUN . 
.venv/bin/activate && make install-flashinfer\n\n# Text Generation Inference base image\nFROM nvidia/cuda:12.4.0-base-ubuntu22.04 AS base\n\n# Text Generation Inference base env\nENV HF_HOME=/data \\\n    HF_HUB_ENABLE_HF_TRANSFER=1 \\\n    PORT=80\n\nWORKDIR /usr/src\n\nRUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \\\n        libssl-dev \\\n        ca-certificates \\\n        make \\\n        curl \\\n        git \\\n        && rm -rf /var/lib/apt/lists/*\n\n# RUN curl -LsSf https://astral.sh/uv/install.sh | sh\n# ENV PATH=\"$PATH:/root/.local/bin\"\nCOPY --from=ghcr.io/astral-sh/uv:0.5.31 /uv /uvx /bin/\n# Install flash-attention dependencies\n# RUN pip install einops --no-cache-dir\n\n# Copy env with PyTorch installed\nCOPY --from=pytorch-install /usr/src/.venv /usr/src/.venv\nENV PYTHON_VERSION=3.11\nRUN uv python install ${PYTHON_VERSION}\nENV VIRTUAL_ENV=/usr/src/.venv/\nENV PATH=\"$PATH:/usr/src/.venv/bin/\"\n\n# Install server\nCOPY proto proto\nCOPY server server\nCOPY server/Makefile server/Makefile\nENV HF_KERNELS_CACHE=/kernels\nRUN cd server && \\\n\tuv sync --frozen --extra gen --extra bnb --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines --extra torch --no-install-project --active && \\\n    make gen-server-raw && \\\n    kernels download .\n\nRUN cd server && \\\n    uv sync --frozen --extra gen --extra bnb --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines --extra torch --active --python=${PYTHON_VERSION} && \\\n    uv pip install nvidia-nccl-cu12==2.25.1 && \\\n    pwd && \\\n    text-generation-server --help\n\n# Copy build artifacts from flash attention builder\nCOPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages\nCOPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-311 
/usr/src/.venv/lib/python3.11/site-packages\nCOPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages\n\n# Copy build artifacts from flash attention v2 builder\nCOPY --from=flash-att-v2-builder /usr/src/.venv/lib/python3.11/site-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so /usr/src/.venv/lib/python3.11/site-packages\n\n# Copy build artifacts from custom kernels builder\nCOPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages\n# Copy build artifacts from exllama kernels builder\nCOPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages\n# Copy build artifacts from exllamav2 kernels builder\nCOPY --from=exllamav2-kernels-builder /usr/src/exllamav2/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages\n# Copy build artifacts from awq kernels builder\nCOPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-311 /usr/src/.venv/lib/python3.11/site-packages\n# Copy build artifacts from mamba builder\nCOPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages\nCOPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-311/ /usr/src/.venv/lib/python3.11/site-packages\nCOPY --from=flashinfer-builder /usr/src/.venv/lib/python3.11/site-packages/flashinfer/ /usr/src/.venv/lib/python3.11/site-packages/flashinfer/\n\n\n# ENV LD_PRELOAD=/opt/conda/lib/python3.11/site-packages/nvidia/nccl/lib/libnccl.so.2\n# Required to find libpython within the rust binaries\n# This is needed because exl2 tries to load flash-attn\n# And fails with our builds.\nENV EXLLAMA_NO_FLASH_ATTN=1\n\n# Deps before the binaries\n# The binaries change on every build given we burn the SHA into them\n# The deps change less often.\nRUN 
apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \\\n        build-essential \\\n        g++ \\\n        && rm -rf /var/lib/apt/lists/*\n\n# Install benchmarker\nCOPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark\n# Install router\nCOPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router\n# Install launcher\nCOPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher\n\n\n# AWS Sagemaker compatible image\nFROM base AS sagemaker\n\nCOPY sagemaker-entrypoint.sh entrypoint.sh\nRUN chmod +x entrypoint.sh\n\nENTRYPOINT [\"./entrypoint.sh\"]\n\n# Final image\nFROM base\n\nCOPY ./tgi-entrypoint.sh /tgi-entrypoint.sh\nRUN chmod +x /tgi-entrypoint.sh\n\nENV LD_LIBRARY_PATH=\"$LD_LIBRARY_PATH:/root/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/\"\nENTRYPOINT [\"/tgi-entrypoint.sh\"]\n# CMD [\"--json-output\"]\n"
  },
  {
    "path": "Dockerfile.neuron",
    "content": "# Fetch and extract the TGI sources\nFROM alpine AS tgi\nRUN mkdir -p /tgi\n\n# Fetch the optimum-neuron sources directly to avoid relying on pypi deployments\nFROM alpine AS optimum-neuron\nRUN mkdir -p /optimum-neuron\nADD https://github.com/huggingface/optimum-neuron/archive/refs/tags/v0.3.0.tar.gz /optimum-neuron/sources.tar.gz\nRUN tar -C /optimum-neuron -xf /optimum-neuron/sources.tar.gz --strip-components=1\n\n# Build cargo components (adapted from TGI original Dockerfile)\n# Note: we cannot use the cargo-chef base image as it uses python 3.11\nFROM ubuntu:22.04 AS chef\n\nRUN apt-get update -y \\\n && apt-get install -y --no-install-recommends \\\n    curl ca-certificates build-essential \\\n    && rm -rf /var/lib/apt/lists/* \\\n    && apt-get clean\n\nRUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain 1.85.1 --profile minimal -y\nENV PATH=\"/root/.cargo/bin:${PATH}\"\nRUN cargo install cargo-chef --locked\n\nWORKDIR /usr/src\n\nFROM chef AS planner\nCOPY backends/neuron/Cargo.toml Cargo.toml\nCOPY Cargo.lock Cargo.lock\nCOPY rust-toolchain.toml rust-toolchain.toml\nCOPY proto proto\nCOPY router router\nCOPY backends backends\nCOPY launcher launcher\nRUN cargo chef prepare --recipe-path recipe.json\n\nFROM chef AS builder\n\nRUN apt-get update -y \\\n && apt-get install -y --no-install-recommends \\\n    unzip python3-dev libssl-dev pkg-config \\\n    && rm -rf /var/lib/apt/lists/* \\\n    && apt-get clean\n\nRUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \\\n    curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \\\n    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \\\n    unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \\\n    rm -f $PROTOC_ZIP\n\nCOPY backends/neuron/Cargo.toml Cargo.toml\nCOPY --from=planner /usr/src/recipe.json recipe.json\nRUN cargo chef cook --release --recipe-path recipe.json\n\nCOPY Cargo.lock Cargo.lock\nCOPY 
rust-toolchain.toml rust-toolchain.toml\nCOPY proto proto\nCOPY router router\nCOPY backends backends\nCOPY launcher launcher\nRUN cargo build --release\n\n# Python base image\nFROM ubuntu:22.04 AS base\n\nRUN apt-get update -y \\\n    && apt-get install -y --no-install-recommends \\\n    python3-pip \\\n    python3-setuptools \\\n    python-is-python3 \\\n    && rm -rf /var/lib/apt/lists/* \\\n    && apt-get clean\nRUN pip3 --no-cache-dir install --upgrade pip\n\n# Python server build image\nFROM base AS pyserver\n\nRUN apt-get update -y \\\n    && apt-get install -y --no-install-recommends \\\n    make \\\n    python3-venv \\\n    && rm -rf /var/lib/apt/lists/* \\\n    && apt-get clean\n\nRUN install -d /pyserver\nWORKDIR /pyserver\nCOPY backends/neuron/server server\nCOPY proto proto\nRUN pip3 install -r server/build-requirements.txt\nRUN VERBOSE=1 BUILDDIR=/pyserver/build PROTODIR=/pyserver/proto make -C server package\n\n# Neuron base image (used for deployment)\nFROM base AS neuron\n\n# Install system prerequisites\nRUN apt-get update -y \\\n    && apt-get install -y --no-install-recommends \\\n    gnupg2 \\\n    wget \\\n    python3-dev \\\n    libexpat1 \\\n    && rm -rf /var/lib/apt/lists/* \\\n    && apt-get clean\n\nRUN echo \"deb https://apt.repos.neuron.amazonaws.com jammy main\" > /etc/apt/sources.list.d/neuron.list\nRUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | apt-key add -\n\n# Install neuronx packages\nRUN apt-get update -y \\\n    && apt-get install -y --no-install-recommends \\\n    aws-neuronx-dkms=2.22.2.0 \\\n    aws-neuronx-collectives=2.26.43.0-47cc904ea \\\n    aws-neuronx-runtime-lib=2.26.42.0-2ff3b5c7d  \\\n    aws-neuronx-tools=2.24.54.0 \\\n    libxml2 \\\n    && rm -rf /var/lib/apt/lists/* \\\n    && apt-get clean\n\nENV PATH=\"/opt/bin/:/opt/aws/neuron/bin:${PATH}\"\n\n# Install manually torch CPU version to avoid pulling CUDA\nRUN pip3 install \\\n    torch==2.7.0 \\\n    
torchvision==0.22.0 \\\n    --index-url https://download.pytorch.org/whl/cpu\n\nRUN pip3 install \\\n    neuronx-cc==2.19.8089.0+8ab9f450 \\\n    torch-neuronx==2.7.0.2.8.6734+ac864f72 \\\n    neuronx-distributed==0.13.14393+b8569585 \\\n    libneuronxla==2.2.4410.0+835a67fb \\\n    --extra-index-url=https://pip.repos.neuron.amazonaws.com\n\n# Install HuggingFace packages\nRUN pip3 install \\\n    hf_transfer huggingface_hub\n\n# Install optimum-neuron\nCOPY --from=optimum-neuron /optimum-neuron optimum-neuron\nRUN pip3 install ./optimum-neuron\n\n# TGI base env\nENV HUGGINGFACE_HUB_CACHE=/tmp \\\n    HF_HUB_ENABLE_HF_TRANSFER=1 \\\n    PORT=80\n\n# Disable color logs as they are not supported by CloudWatch\nENV LOGURU_COLORIZE=NO\nENV LOG_COLORIZE=0\n\n# Install router\nCOPY --from=builder /usr/src/target/release/text-generation-router-v2 /usr/local/bin/text-generation-router\n# Install launcher\nCOPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher\n# Install python server\nCOPY --from=pyserver /pyserver/build/dist dist\nRUN pip install dist/text_generation_server*.tar.gz\n\n# Final image\nFROM neuron\n\nCOPY backends/neuron/tgi_entry_point.py /tgi_entry_point.py\nCOPY backends/neuron/tgi-entrypoint.sh /tgi-entrypoint.sh\nRUN chmod +x /tgi-entrypoint.sh\n\nENTRYPOINT [\"/tgi-entrypoint.sh\"]\n"
  },
  {
    "path": "Dockerfile.nix",
    "content": "# Build the image and get out the docker file:\n#\n# docker build -t tgi-nix-builder -f Dockerfile.nix .\n# docker run --log-driver=none tgi-nix-builder | docker load\n\nFROM nixos/nix:2.18.8 AS builder\nRUN echo \"experimental-features = nix-command flakes\" >> /etc/nix/nix.conf\nRUN nix profile install nixpkgs#cachix\nRUN cachix use huggingface\nWORKDIR /root\nADD . .\nRUN nix build .\nRUN mkdir /tmp/nix-store-closure\nRUN cp -R $(nix-store -qR result/) /tmp/nix-store-closure\n\nFROM ubuntu:24.04\n\nWORKDIR /app\n\n# Copy /nix/store\nCOPY --from=builder /tmp/nix-store-closure /nix/store\nCOPY --from=builder /root/result /app\nRUN ldconfig\nCMD [\"ldconfig\", \"/app/bin/text-generation-launcher\"]\n"
  },
  {
    "path": "Dockerfile_amd",
    "content": "# Rust builder\nFROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef\nWORKDIR /usr/src\n\nARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse\n\nFROM chef AS planner\nCOPY Cargo.lock Cargo.lock\nCOPY Cargo.toml Cargo.toml\nCOPY rust-toolchain.toml rust-toolchain.toml\nCOPY proto proto\nCOPY benchmark benchmark\nCOPY router router\nCOPY backends backends\nCOPY launcher launcher\nRUN cargo chef prepare --recipe-path recipe.json\n\nFROM chef AS builder\n\nRUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \\\n    python3.11-dev\nRUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \\\n    curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \\\n    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \\\n    unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \\\n    rm -f $PROTOC_ZIP\n\nCOPY --from=planner /usr/src/recipe.json recipe.json\nRUN cargo chef cook --profile release-opt --recipe-path recipe.json\n\nARG GIT_SHA\nARG DOCKER_LABEL\n\nCOPY Cargo.lock Cargo.lock\nCOPY Cargo.toml Cargo.toml\nCOPY rust-toolchain.toml rust-toolchain.toml\nCOPY proto proto\nCOPY benchmark benchmark\nCOPY router router\nCOPY backends backends\nCOPY launcher launcher\nRUN cargo build --profile release-opt --frozen\n\nFROM rocm/dev-ubuntu-22.04:6.3.1-complete AS base\n\nARG HIPBLASLT_BRANCH=\"4d40e36\"\nARG HIPBLAS_COMMON_BRANCH=\"7c1566b\"\nARG LEGACY_HIPBLASLT_OPTION=\nARG RCCL_BRANCH=\"648a58d\"\nARG RCCL_REPO=\"https://github.com/ROCm/rccl\"\nARG TRITON_BRANCH=\"e5be006\"\nARG TRITON_REPO=\"https://github.com/triton-lang/triton.git\"\nARG PYTORCH_BRANCH=\"3a585126\"\nARG PYTORCH_VISION_BRANCH=\"v0.19.1\"\nARG PYTORCH_REPO=\"https://github.com/pytorch/pytorch.git\"\nARG PYTORCH_VISION_REPO=\"https://github.com/pytorch/vision.git\"\nARG FA_BRANCH=\"b7d29fb\"\nARG FA_REPO=\"https://github.com/ROCm/flash-attention.git\"\nARG AITER_BRANCH=\"21d47a9\"\nARG 
AITER_REPO=\"https://github.com/ROCm/aiter.git\"\n\nENV PATH=/opt/rocm/llvm/bin:$PATH\nENV ROCM_PATH=/opt/rocm\nENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib:\nARG PYTORCH_ROCM_ARCH=gfx90a;gfx942\nENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}\n\nARG PYTHON_VERSION=3.11\n\nRUN mkdir -p /app\nWORKDIR /app\nENV DEBIAN_FRONTEND=noninteractive\n\n# Install Python and other dependencies\nRUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \\\n        build-essential \\\n        ca-certificates \\\n        ccache \\\n        curl \\\n        git \\\n        ninja-build \\\n        cmake \\\n        software-properties-common \\\n        python3.11-dev \\\n        python3.11-venv && \\\n        rm -rf /var/lib/apt/lists/*\n\nCOPY --from=ghcr.io/astral-sh/uv:0.5.31 /uv /uvx /bin/\nENV PATH=\"$PATH:/root/.local/bin\"\nRUN uv python install ${PYTHON_VERSION}\nRUN uv venv --python ${PYTHON_VERSION} && uv pip install pip setuptools packaging\nENV VIRTUAL_ENV=/usr/src/.venv/\nENV PATH=\"$PATH:/usr/src/.venv/bin/\"\n\nRUN . .venv/bin/activate && pip install -U packaging cmake ninja wheel setuptools pybind11 Cython\n\nFROM base AS build_hipblaslt\nARG HIPBLASLT_BRANCH\nARG HIPBLAS_COMMON_BRANCH\n# Set to \"--legacy_hipblas_direct\" for ROCm<=6.2\nARG LEGACY_HIPBLASLT_OPTION\nRUN git clone https://github.com/ROCm/hipBLAS-common.git\nRUN . .venv/bin/activate && cd hipBLAS-common \\\n    && git checkout ${HIPBLAS_COMMON_BRANCH} \\\n    && mkdir build \\\n    && cd build \\\n    && cmake .. \\\n    && make package \\\n    && dpkg -i ./*.deb\nRUN git clone https://github.com/ROCm/hipBLASLt\nRUN . 
.venv/bin/activate && cd hipBLASLt \\\n    && git checkout ${HIPBLASLT_BRANCH} \\\n    && ./install.sh -d --architecture ${PYTORCH_ROCM_ARCH} ${LEGACY_HIPBLASLT_OPTION} \\\n    && cd build/release \\\n    && make package\nRUN mkdir -p /app/install && cp /app/hipBLASLt/build/release/*.deb /app/hipBLAS-common/build/*.deb /app/install\n\nFROM base AS build_rccl\nARG RCCL_BRANCH\nARG RCCL_REPO\nRUN git clone ${RCCL_REPO}\nRUN . .venv/bin/activate && cd rccl \\\n    && git checkout ${RCCL_BRANCH} \\\n    && ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH}\nRUN mkdir -p /app/install && cp /app/rccl/build/release/*.deb /app/install\n\nFROM base AS build_triton\nARG TRITON_BRANCH\nARG TRITON_REPO\nRUN git clone ${TRITON_REPO}\nRUN . .venv/bin/activate && cd triton \\\n    && git checkout ${TRITON_BRANCH} \\\n    && cd python \\\n    && python3 setup.py bdist_wheel --dist-dir=dist\nRUN mkdir -p /app/install && cp /app/triton/python/dist/*.whl /app/install\n\nFROM base AS build_amdsmi\nRUN . .venv/bin/activate && cd /opt/rocm/share/amd_smi \\\n    && pip wheel . --wheel-dir=dist\nRUN mkdir -p /app/install && cp /opt/rocm/share/amd_smi/dist/*.whl /app/install\n\nFROM base AS build_pytorch\nARG PYTORCH_BRANCH\nARG PYTORCH_VISION_BRANCH\nARG PYTORCH_REPO\nARG PYTORCH_VISION_REPO\nARG FA_BRANCH\nARG FA_REPO\nRUN git clone ${PYTORCH_REPO} pytorch\nRUN . .venv/bin/activate && cd pytorch && git checkout ${PYTORCH_BRANCH} && \\\n    pip install -r requirements.txt && git submodule update --init --recursive \\\n    && python3 tools/amd_build/build_amd.py \\\n    && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \\\n    && pip install dist/*.whl\nRUN git clone ${PYTORCH_VISION_REPO} vision\nRUN . .venv/bin/activate && cd vision && git checkout ${PYTORCH_VISION_BRANCH} \\\n    && python3 setup.py bdist_wheel --dist-dir=dist \\\n    && pip install dist/*.whl\nRUN git clone ${FA_REPO}\nRUN . 
.venv/bin/activate && cd flash-attention \\\n    && git checkout ${FA_BRANCH} \\\n    && git submodule update --init \\\n    && MAX_JOBS=64 GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist\nRUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \\\n    && cp /app/vision/dist/*.whl /app/install \\\n    && cp /app/flash-attention/dist/*.whl /app/install\n\nFROM base AS final\nRUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \\\n    dpkg -i /install/*deb \\\n    && sed -i 's/, hipblaslt-dev \\(.*\\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \\\n    && sed -i 's/, hipblaslt \\(.*\\), hipfft/, hipfft/g' /var/lib/dpkg/status\nRUN --mount=type=bind,from=build_rccl,src=/app/install/,target=/install \\\n    dpkg -i /install/*deb \\\n    && sed -i 's/, rccl-dev \\(.*\\), rocalution/, rocalution/g' /var/lib/dpkg/status \\\n    && sed -i 's/, rccl \\(.*\\), rocalution/, rocalution/g' /var/lib/dpkg/status\nRUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \\\n    . .venv/bin/activate && \\\n    pip install /install/*.whl\nRUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \\\n    . .venv/bin/activate && \\\n    pip install /install/*.whl\nRUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \\\n    . .venv/bin/activate && \\\n    pip install /install/*.whl\n\nARG AITER_REPO\nARG AITER_BRANCH\nRUN git clone --recursive ${AITER_REPO}\nRUN . .venv/bin/activate && cd aiter \\\n    && git checkout ${AITER_BRANCH} \\\n    && git submodule update --init --recursive \\\n    && pip install -r requirements.txt \\\n    && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop && pip show aiter\n\nRUN rm -rf /var/lib/apt/lists/*\n\nFROM final AS kernel-builder\n# # Build vllm kernels\nFROM kernel-builder AS vllm-builder\n\nCOPY server/Makefile-vllm Makefile\nRUN . 
.venv/bin/activate && pip install setuptools_scm\n\n# Build specific version of vllm\nRUN . .venv/bin/activate && make build-vllm-rocm\n\n# Build Transformers CUDA kernels (gpt-neox and bloom)\nFROM kernel-builder AS custom-kernels-builder\nCOPY server/custom_kernels/ .\nRUN . .venv/bin/activate && python3 setup.py bdist_wheel --dist-dir=dist\n\n# Build exllama kernels\nFROM kernel-builder AS exllama-kernels-builder\nCOPY server/exllama_kernels/ .\nRUN . .venv/bin/activate && python3 setup.py bdist_wheel --dist-dir=dist\n\n# Build exllama v2 kernels\nFROM kernel-builder AS exllamav2-kernels-builder\nCOPY server/exllamav2_kernels/ .\nRUN . .venv/bin/activate && python3 setup.py bdist_wheel --dist-dir=dist\n\nFROM kernel-builder AS marlin-kernels\nENV MARLIN_KERNELS_BRANCH=v0.3.6\nENV VLLM_TARGET_DEVICE=rocm\nRUN . .venv/bin/activate && git clone https://github.com/danieldk/marlin-kernels.git && \\\n    cd marlin-kernels && \\\n    git checkout ${MARLIN_KERNELS_BRANCH} && \\\n    python3 setup.py bdist_wheel --dist-dir=dist\n\nFROM kernel-builder AS moe-kernels\nENV MOE_KERNELS_BRANCH=v0.8.2\nENV VLLM_TARGET_DEVICE=rocm\nRUN . 
.venv/bin/activate && git clone https://github.com/danieldk/moe-kernels.git && \\\n    cd moe-kernels && \\\n    git checkout ${MOE_KERNELS_BRANCH} && \\\n    python3 setup.py bdist_wheel --dist-dir=dist\n\nFROM final AS base-copy\n\n# Text Generation Inference base env\nENV HF_HOME=/data \\\n    HF_HUB_ENABLE_HF_TRANSFER=1 \\\n    PORT=80\n\nENV VIRTUAL_ENV=/app/.venv/\nENV PATH=\"$PATH:/app/.venv/bin/\"\n\n# Install server\nCOPY proto proto\nCOPY server server\nCOPY server/Makefile server/Makefile\nRUN cd server && \\\n    uv pip install grpcio-tools mypy-protobuf && \\\n    uv pip install -e \".[accelerate, compressed-tensors, peft, outlines]\" --no-cache-dir && \\\n    make gen-server-raw\nRUN cd server && \\\n    pwd && \\\n    text-generation-server --help\n\nRUN --mount=type=bind,from=vllm-builder,src=/app/vllm/dist,target=/install \\\n    uv pip install /install/*.whl\nRUN --mount=type=bind,from=custom-kernels-builder,src=/app/dist,target=/install \\\n    uv pip install /install/*.whl\nRUN --mount=type=bind,from=custom-kernels-builder,src=/app/dist,target=/install \\\n    uv pip install /install/*.whl\nRUN --mount=type=bind,from=exllama-kernels-builder,src=/app/dist,target=/install \\\n    uv pip install /install/*.whl\nRUN --mount=type=bind,from=exllamav2-kernels-builder,src=/app/dist,target=/install \\\n    uv pip install /install/*.whl\nRUN --mount=type=bind,from=marlin-kernels,src=/app/marlin-kernels/dist,target=/install \\\n    uv pip install /install/*.whl\nRUN --mount=type=bind,from=moe-kernels,src=/app/moe-kernels/dist,target=/install \\\n    uv pip install /install/*.whl\n\n# Install benchmarker\nCOPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark\n# Install router\nCOPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router\n# Install launcher\nCOPY --from=builder /usr/src/target/release-opt/text-generation-launcher 
/usr/local/bin/text-generation-launcher\n\n# AWS Sagemaker compatible image\nFROM base AS sagemaker\n\nCOPY sagemaker-entrypoint.sh entrypoint.sh\nRUN chmod +x entrypoint.sh\n\nENTRYPOINT [\"./entrypoint.sh\"]\n\n# Final image\nFROM base-copy\n\n# Set as recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm\nENV HIP_FORCE_DEV_KERNARG=1\n\n# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK.\n# However, Triton requires tuning for each prompt length, which is prohibitive.\nENV ROCM_USE_FLASH_ATTN_V2_TRITON=0\nENV ROCM_USE_CUSTOM_PAGED_ATTN=1\nENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0\nENV VLLM_MOE_PADDING=0\nENV ATTENTION=paged\nENV PREFIX_CACHING=0\nENV PREFILL_CHUNKING=0\nENV ROCM_USE_SKINNY_GEMM=1\n\nCOPY ./tgi-entrypoint.sh /tgi-entrypoint.sh\nRUN chmod +x /tgi-entrypoint.sh\n\nENTRYPOINT [\"/tgi-entrypoint.sh\"]\nENV LD_LIBRARY_PATH=\"$LD_LIBRARY_PATH:/root/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib\"\nENV PYTHONPATH=/app/.venv/lib/python3.11/site-packages\n# CMD [\"--json-output\"]\n"
  },
  {
    "path": "Dockerfile_gaudi",
    "content": "# Those arguments are required to build the image\nARG HABANA_VERSION=1.21.0\nARG PYTORCH_VERSION=2.6.0\n\n# Rust builder\nFROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef\nWORKDIR /usr/src\n\nARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse\n\nFROM chef AS planner\nCOPY Cargo.lock Cargo.lock\nCOPY Cargo.toml Cargo.toml\nCOPY rust-toolchain.toml rust-toolchain.toml\nCOPY proto proto\nCOPY benchmark benchmark\nCOPY router router\nCOPY backends backends\nCOPY launcher launcher\nRUN cargo chef prepare --recipe-path recipe.json\n\nFROM chef AS builder\n\nENV PYO3_PYTHON=\"/root/.local/bin/python\" \\\n    PYTHON_SYS_EXECUTABLE=\"/root/.local/bin/python\" \\\n    PYO3_PYTHON_VERSION=\"3.10\"\n\nRUN curl -LsSf https://astral.sh/uv/install.sh | sh \\\n    && . $HOME/.local/bin/env \\\n    && uv python install 3.10 --default --preview \\\n    && test -f /root/.local/bin/python || (echo \"Python 3.10 not found at /root/.local/bin/python\" && exit 1)\n\nRUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \\\n    curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \\\n    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \\\n    unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \\\n    rm -f $PROTOC_ZIP\n\nCOPY --from=planner /usr/src/recipe.json recipe.json\nRUN cargo chef cook --profile release-opt --recipe-path recipe.json\n\nARG GIT_SHA\nARG DOCKER_LABEL\n\nCOPY Cargo.toml Cargo.toml\nCOPY rust-toolchain.toml rust-toolchain.toml\nCOPY proto proto\nCOPY benchmark benchmark\nCOPY router router\nCOPY backends backends\nCOPY launcher launcher\nRUN cargo build --profile release-opt\n\n# Text Generation Inference base image\nARG HABANA_VERSION\nARG PYTORCH_VERSION\n\nFROM vault.habana.ai/gaudi-docker/${HABANA_VERSION}/ubuntu22.04/habanalabs/pytorch-installer-${PYTORCH_VERSION}:latest AS base\n\nENV ATTENTION=paged\nENV PREFIX_CACHING=0\nENV PREFILL_CHUNKING=0\nENV PT_HPU_LAZY_MODE=1\nENV PT_HPU_WEIGHT_SHARING=0\nENV 
VLLM_EXPONENTIAL_BUCKETING=true\n\n# Text Generation Inference base env\nENV HF_HOME=/data \\\n    HF_HUB_ENABLE_HF_TRANSFER=1 \\\n    PORT=80\n\n# Assert that Python 3.10 is installed as the launcher is compiled with Python 3.10\nRUN python3.10 --version || (echo \"Python 3.10 is not installed\" && exit 1)\n\n# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it\nRUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \\\n    dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb\n\nWORKDIR /usr/src\n\nRUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \\\n        libssl-dev \\\n        ca-certificates \\\n        make \\\n        curl \\\n        git \\\n        && rm -rf /var/lib/apt/lists/*\n\n# Install server\nCOPY proto proto\nCOPY backends/gaudi/server server\nCOPY backends/gaudi/server/Makefile server/Makefile\nARG HABANA_VERSION\nRUN cd server && \\\n    make gen-server && \\\n    pip install --no-deps -r requirements.txt && \\\n    bash ./dill-0.3.8-patch.sh && \\\n    pip install . 
--no-cache-dir\nRUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git@bmax_fix\nRUN pip install compressed-tensors==0.9.1\n\n# Install benchmarker\nCOPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark\n# Install router\nCOPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router\n# Install launcher\nCOPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher\n\n\n# AWS Sagemaker compatible image\nFROM base AS sagemaker\n\nCOPY sagemaker-entrypoint.sh entrypoint.sh\nRUN chmod +x entrypoint.sh\n\nENTRYPOINT [\"./entrypoint.sh\"]\n\n# Final image\nFROM base\n\nENV HF_HUB_ENABLE_HF_TRANSFER=1\nENV HABANA_VISIBLE_DEVICES=all\nENV OMPI_MCA_btl_vader_single_copy_mechanism=NONE\n\nCOPY backends/gaudi/tgi-entrypoint.sh /tgi-entrypoint.sh\nRUN chmod +x /tgi-entrypoint.sh\n\nENTRYPOINT [\"/tgi-entrypoint.sh\"]\nCMD [\"--json-output\"]\n"
  },
  {
    "path": "Dockerfile_intel",
    "content": "ARG PLATFORM=xpu\n\nFROM lukemathwalker/cargo-chef:latest-rust-1.85.1 AS chef\nWORKDIR /usr/src\n\nARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse\n\nFROM chef AS planner\nCOPY Cargo.lock Cargo.lock\nCOPY Cargo.toml Cargo.toml\nCOPY rust-toolchain.toml rust-toolchain.toml\nCOPY proto proto\nCOPY benchmark benchmark\nCOPY router router\nCOPY backends backends\nCOPY launcher launcher\nRUN cargo chef prepare --recipe-path recipe.json\n\nFROM chef AS builder\n\nRUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \\\n    python3.11-dev\nRUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \\\n    curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \\\n    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \\\n    unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \\\n    rm -f $PROTOC_ZIP\n\nCOPY --from=planner /usr/src/recipe.json recipe.json\nRUN cargo chef cook --profile release-opt --recipe-path recipe.json\n\nARG GIT_SHA\nARG DOCKER_LABEL\n\nCOPY Cargo.lock Cargo.lock\nCOPY Cargo.toml Cargo.toml\nCOPY rust-toolchain.toml rust-toolchain.toml\nCOPY proto proto\nCOPY benchmark benchmark\nCOPY router router\nCOPY backends backends\nCOPY launcher launcher\nRUN cargo build --profile release-opt --frozen\n\n\n# Text Generation Inference base image for Intel\n\nFROM intel/oneapi-basekit:2025.1.3-0-devel-ubuntu22.04 AS xpu\n\nUSER root\n\nARG MAMBA_VERSION=23.1.0-1\nARG PYTHON_VERSION='3.11.10'\n# Automatically set by buildx\nARG TARGETPLATFORM\nENV PATH=/opt/conda/bin:$PATH\n\n# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. 
Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.\n# Install mamba\n# translating Docker's TARGETPLATFORM into mamba arches\nRUN case ${TARGETPLATFORM} in \\\n         \"linux/arm64\")  MAMBA_ARCH=aarch64  ;; \\\n         *)              MAMBA_ARCH=x86_64   ;; \\\n    esac && \\\n    curl -fsSL -v -o ~/mambaforge.sh -O  \"https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh\"\nRUN chmod +x ~/mambaforge.sh && \\\n    bash ~/mambaforge.sh -b -p /opt/conda && \\\n    rm ~/mambaforge.sh\n\nRUN case ${TARGETPLATFORM} in \\\n         \"linux/arm64\")  exit 1 ;; \\\n         *)              /opt/conda/bin/conda update -y conda &&  \\\n                         /opt/conda/bin/conda install -y \"python=${PYTHON_VERSION}\" ;; \\\n    esac && \\\n    /opt/conda/bin/conda clean -ya\n\n# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it\nRUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \\\n    dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb\n\nRUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null\n\nRUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \\\n| gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo \"deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main\" | tee /etc/apt/sources.list.d/oneAPI.list\n\nRUN echo \"deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/intel-for-pytorch-gpu-dev all main\" > /tmp/intel-for-pytorch-gpu-dev.list\n\nRUN mv /tmp/intel-for-pytorch-gpu-dev.list /etc/apt/sources.list.d\n\nRUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-ocloc 
libnl-genl-3-200\n\n# Text Generation Inference base env\nENV HF_HOME=/data \\\n    HF_HUB_ENABLE_HF_TRANSFER=1 \\\n    PORT=80\n\n\n\n\nWORKDIR /usr/src\n\nRUN pip install torch==2.8.0 torchvision==0.23.0 --index-url https://download.pytorch.org/whl/xpu\n\n# Install server\nCOPY proto proto\nCOPY server server\nCOPY server/Makefile server/Makefile\nENV UV_SYSTEM_PYTHON=1\nRUN cd server && \\\n    make gen-server && \\\n    pip install -U pip uv && \\\n    uv pip install -e \".[accelerate, compressed-tensors, peft, outlines]\" --no-cache-dir\n\nENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/conda/lib\nENV CCL_ZE_IPC_EXCHANGE=sockets\nENV TORCH_LLM_ALLREDUCE=1\nENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0\nENV TORCH_DEVICE_BACKEND_AUTOLOAD=0\n\nRUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/intel_extension_for_pytorch-2.8.10%2Bxpu-cp311-cp311-linux_x86_64.whl\n# Install benchmarker\nCOPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark\n# Install router\nCOPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router\n# Install launcher\nCOPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher\n\n\n# Text Generation Inference base image for Intel-cpu\nFROM ubuntu:22.04 AS cpu\n\nRUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \\\n    curl \\\n    ca-certificates \\\n    make \\\n    g++-12 \\\n    gcc-12 \\\n    git \\\n    wget \\\n    cmake \\\n    libnuma-dev\n\nRUN update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-12 12\nRUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 12\nRUN update-alternatives --install /usr/bin/cc cc /usr/bin/gcc 30\nRUN update-alternatives --set cc /usr/bin/gcc\n\nRUN update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++ 30\nRUN update-alternatives --set c++ 
/usr/bin/g++\n\n\nENV HUGGINGFACE_HUB_CACHE=/data \\\n    HF_HUB_ENABLE_HF_TRANSFER=1 \\\n    PORT=80\n\nARG MAMBA_VERSION=23.1.0-1\nARG PYTHON_VERSION='3.11.10'\n# Automatically set by buildx\nARG TARGETPLATFORM\nENV PATH=/opt/conda/bin:$PATH\n\n# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.\n# Install mamba\n# translating Docker's TARGETPLATFORM into mamba arches\nRUN case ${TARGETPLATFORM} in \\\n         \"linux/arm64\")  MAMBA_ARCH=aarch64  ;; \\\n         *)              MAMBA_ARCH=x86_64   ;; \\\n    esac && \\\n    curl -fsSL -v -o ~/mambaforge.sh -O  \"https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh\"\nRUN chmod +x ~/mambaforge.sh && \\\n    bash ~/mambaforge.sh -b -p /opt/conda && \\\n    rm ~/mambaforge.sh\n\nRUN case ${TARGETPLATFORM} in \\\n         \"linux/arm64\")  exit 1 ;; \\\n         *)              /opt/conda/bin/conda update -y conda &&  \\\n                         /opt/conda/bin/conda install -y \"python=${PYTHON_VERSION}\" ;; \\\n    esac && \\\n    /opt/conda/bin/conda clean -ya\n\nRUN conda install -c conda-forge gperftools mkl\n\nRUN pip install torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cpu\nRUN pip install triton==3.2.0 py-libnuma\n\nWORKDIR /usr/src\n\nRUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/cpu/intel_extension_for_pytorch-2.7.0%2Bcpu-cp311-cp311-linux_x86_64.whl\nRUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/cpu/oneccl_bind_pt-2.7.0%2Bcpu-cp311-cp311-linux_x86_64.whl\n\n\nENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so\nENV CCL_ROOT=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch\nENV I_MPI_ROOT=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch\nENV 
FI_PROVIDER_PATH=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric\nENV LD_LIBRARY_PATH=/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.11/site-packages/oneccl_bindings_for_pytorch/lib\nENV LD_LIBRARY_PATH=\"$LD_LIBRARY_PATH:/opt/conda/lib/\"\n\n# Install server\nCOPY proto proto\nCOPY server server\nCOPY server/Makefile server/Makefile\nENV UV_SYSTEM_PYTHON=1\nRUN cd server && \\\n    make gen-server && \\\n    pip install -U pip uv && \\\n    uv pip install -e \".[accelerate, compressed-tensors, peft, outlines]\" --no-cache-dir\n\n# Install benchmarker\nCOPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark\n# Install router\nCOPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router\n# Install launcher\nCOPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher\n\nFROM ${PLATFORM} AS final\nENV ATTENTION=flashdecoding-ipex\nENV PREFIX_CACHING=1\nENV PREFILL_CHUNKING=1\nENV CUDA_GRAPHS=0\nENTRYPOINT [\"text-generation-launcher\"]\nCMD [\"--json-output\"]\n"
  },
  {
    "path": "Dockerfile_llamacpp",
    "content": "FROM nvidia/cuda:12.8.0-cudnn-devel-ubuntu24.04 AS deps\n\nARG llamacpp_version=b4827\nARG llamacpp_cuda=OFF\nARG llamacpp_native=ON\nARG llamacpp_cpu_arm_arch=native\nARG cuda_arch=75-real;80-real;86-real;89-real;90-real\n\nWORKDIR /opt/src\n\nENV DEBIAN_FRONTEND=noninteractive\nRUN apt update && apt upgrade -y && apt install -y \\\n    clang \\\n    cmake \\\n    curl \\\n    git \\\n    python3-dev \\\n    libssl-dev \\\n    pkg-config \\\n    tar\n\nADD https://github.com/ggml-org/llama.cpp/archive/refs/tags/${llamacpp_version}.tar.gz /opt/src/\nRUN mkdir -p llama.cpp \\\n && tar -xzf ${llamacpp_version}.tar.gz -C llama.cpp --strip-components=1 \\\n && cd llama.cpp \\\n && cmake -B build \\\n    -DCMAKE_INSTALL_PREFIX=/usr \\\n    -DCMAKE_INSTALL_LIBDIR=/usr/lib \\\n    -DCMAKE_C_COMPILER=clang \\\n    -DCMAKE_CXX_COMPILER=clang++ \\\n    -DCMAKE_CUDA_ARCHITECTURES=${cuda_arch} \\\n    -DGGML_CUDA=${llamacpp_cuda} \\\n    -DGGML_NATIVE=${llamacpp_native} \\\n    -DGGML_CPU_ARM_ARCH=${llamacpp_cpu_arm_arch} \\\n    -DLLAMA_BUILD_COMMON=OFF \\\n    -DLLAMA_BUILD_TESTS=OFF \\\n    -DLLAMA_BUILD_EXAMPLES=OFF \\\n    -DLLAMA_BUILD_SERVER=OFF \\\n && cmake --build build --parallel --config Release \\\n && cmake --install build\n\nWORKDIR /app\nCOPY rust-toolchain.toml rust-toolchain.toml\nRUN curl -sSf https://sh.rustup.rs | sh -s -- --no-modify-path --default-toolchain 1.85.1 --profile minimal -y\nENV PATH=\"/root/.cargo/bin:$PATH\"\nRUN cargo install cargo-chef --locked\n\nFROM deps AS planner\nCOPY . .\nRUN cargo chef prepare --recipe-path recipe.json\n\nFROM deps AS builder\nCOPY --from=planner /app/recipe.json recipe.json\nRUN cargo chef cook \\\n    --recipe-path recipe.json \\\n    --profile release \\\n    --package text-generation-router-llamacpp\nCOPY . 
.\nRUN cargo build \\\n    --profile release \\\n    --package text-generation-router-llamacpp --frozen\n\nFROM nvidia/cuda:12.8.0-cudnn-runtime-ubuntu24.04\nWORKDIR /app\n\nENV DEBIAN_FRONTEND=noninteractive\nRUN apt update && apt upgrade -y && apt install -y \\\n    python3-venv \\\n    python3-pip\n\nRUN python3 -m venv /venv\nENV PATH=\"/venv/bin:$PATH\"\n\nCOPY backends/llamacpp/requirements.txt requirements.txt\nCOPY --from=builder /opt/src/llama.cpp/gguf-py gguf-py\nCOPY --from=builder /opt/src/llama.cpp/convert_hf_to_gguf.py /bin/\n\nRUN pip3 install --no-cache-dir \\\n    -r requirements.txt \\\n    -e gguf-py\n\nCOPY --from=builder /usr/lib/libllama.so /usr/lib/\nCOPY --from=builder /usr/lib/libggml*.so /usr/lib/\nCOPY --from=builder /app/target/release/text-generation-router-llamacpp /usr/bin/\n\nENV HF_HUB_ENABLE_HF_TRANSFER=1\n\nENTRYPOINT [\"text-generation-router-llamacpp\"]\n"
  },
  {
    "path": "Dockerfile_trtllm",
    "content": "ARG cuda_arch_list=\"75-real;80-real;86-real;89-real;90-real;100-real;120-real\"\nARG cuda_base=12.8.0\nARG build_type=release\nARG ompi_version=4.1.7\nARG sccache_gha_enabled=off\nARG actions_results_url=\"\"\nARG actions_runtime_token=\"\"\n\n# CUDA dependent dependencies resolver stage\nFROM nvidia/cuda:${cuda_base}-cudnn-devel-ubuntu24.04 AS cuda-builder\n\nRUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \\\n    build-essential \\\n    cmake \\\n    curl \\\n    gcc-14  \\\n    g++-14 \\\n    git \\\n    git-lfs \\\n    lld \\\n    libssl-dev \\\n    libucx-dev \\\n    libasan8 \\\n    libubsan1 \\\n    ninja-build \\\n    pkg-config \\\n    pipx \\\n    python3 \\\n    python3-dev \\\n    python3-setuptools \\\n    tar \\\n    wget --no-install-recommends && \\\n    pipx ensurepath\n\nENV TGI_INSTALL_PREFIX=/usr/local/tgi\nENV TENSORRT_INSTALL_PREFIX=/usr/local/tensorrt\n\n# Install OpenMPI\nFROM cuda-builder AS mpi-builder\nWORKDIR /opt/src/mpi\n\nARG ompi_version\nENV OMPI_VERSION=${ompi_version}\nENV OMPI_TARBALL_FILENAME=openmpi-${OMPI_VERSION}.tar.bz2\nADD --checksum=sha256:54a33cb7ad81ff0976f15a6cc8003c3922f0f3d8ceed14e1813ef3603f22cd34 \\\n    https://download.open-mpi.org/release/open-mpi/v4.1/${OMPI_TARBALL_FILENAME} .\n\nRUN tar --strip-components=1 -xf ${OMPI_TARBALL_FILENAME} &&\\\n    ./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --with-slurm && \\\n    make -j all && \\\n    make install && \\\n    rm -rf ${OMPI_TARBALL_FILENAME}/..\n\n# Install TensorRT\nFROM cuda-builder AS trt-builder\nCOPY backends/trtllm/scripts/install_tensorrt.sh /opt/install_tensorrt.sh\nRUN chmod +x /opt/install_tensorrt.sh && \\\n    /opt/install_tensorrt.sh\n\n# Build Backend\nFROM cuda-builder AS tgi-builder\nWORKDIR /usr/src/text-generation-inference\n\n# Scoped global args reuse\nARG cuda_arch_list\nARG build_type\nARG sccache_gha_enabled\n\n# Install Rust\nENV PATH=\"/root/.cargo/bin:$PATH\"\nRUN curl 
--proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain 1.85.1 --profile minimal -y && \\\n    chmod -R a+w /root/.rustup && \\\n    chmod -R a+w /root/.cargo && \\\n    cargo install sccache --version \">=0.10.0\" --locked\n\nENV LD_LIBRARY_PATH=\"/usr/local/mpi/lib:$LD_LIBRARY_PATH\"\nENV PKG_CONFIG_PATH=\"/usr/local/mpi/lib/pkgconfig\"\nENV CMAKE_PREFIX_PATH=\"/usr/local/mpi:/usr/local/tensorrt\"\n\nENV USE_LLD_LINKER=ON\nENV CUDA_ARCH_LIST=${cuda_arch_list}\n\n# SCCACHE Specifics args - before finding a better, more generic, way...\nENV SCCACHE_GHA_ENABLED=${sccache_gha_enabled}\n\nCOPY Cargo.lock Cargo.lock\nCOPY Cargo.toml Cargo.toml\nCOPY rust-toolchain.toml rust-toolchain.toml\nCOPY router router\nCOPY backends backends\nCOPY benchmark benchmark\nCOPY launcher launcher\nCOPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt\nCOPY --from=mpi-builder /usr/local/mpi /usr/local/mpi\n\nENV RUSTC_WRAPPER=sccache\nENV CMAKE_INSTALL_PREFIX=$TGI_INSTALL_PREFIX\nRUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \\\n    --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \\\n    export CMAKE_C_COMPILER_LAUNCHER=sccache && \\\n    export CMAKE_CXX_COMPILER_LAUNCHER=sccache && \\\n    export CMAKE_CUDA_COMPILER_LAUNCHER=sccache && \\\n    mkdir $TGI_INSTALL_PREFIX && mkdir \"$TGI_INSTALL_PREFIX/include\" && mkdir \"$TGI_INSTALL_PREFIX/lib\" && \\\n    cargo build --profile ${build_type} --package text-generation-backends-trtllm --bin text-generation-backends-trtllm && \\\n    sccache --show-stats\n\nFROM nvidia/cuda:${cuda_base}-cudnn-runtime-ubuntu24.04 AS runtime\nRUN apt update && apt install -y libucx0 pipx python3-minimal python3-dev python3-pip python3-venv && \\\n    rm -rf /var/lib/{apt,dpkg,cache,log}/ && \\\n    pipx ensurepath && \\\n    pipx install --include-deps transformers tokenizers\n\nWORKDIR /usr/local/tgi/bin\n\nENV 
PATH=/root/.local/share/pipx/venvs/transformers/bin/:$PATH\nENV LD_LIBRARY_PATH=\"/usr/local/tgi/lib:/usr/local/mpi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH\"\nENV TOKENIZERS_PARALLELISM=false\nENV OMPI_MCA_plm_rsh_agent=\"\"\n\nCOPY --from=mpi-builder /usr/local/mpi /usr/local/mpi\nCOPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt\nCOPY --from=tgi-builder /usr/local/tgi /usr/local/tgi\nCOPY --from=tgi-builder /usr/src/text-generation-inference/target/release/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher\n\n# This is used only for the CI/CD\nFROM nvidia/cuda:${cuda_base}-cudnn-runtime-ubuntu24.04 AS ci-runtime\nRUN apt update && apt install -y libasan8 libubsan1 libucx0 pipx python3-minimal python3-dev python3-pip python3-venv && \\\n    rm -rf /var/lib/{apt,dpkg,cache,log}/ && \\\n    pipx ensurepath && \\\n    pipx install --include-deps transformers tokenizers\n\nWORKDIR /usr/local/tgi/bin\n\nENV PATH=/root/.local/share/pipx/venvs/transformers/bin/:$PATH\nENV LD_LIBRARY_PATH=\"/usr/local/tgi/lib:/usr/local/mpi/lib:/usr/local/tensorrt/lib:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH\"\nENV TOKENIZERS_PARALLELISM=false\nENV OMPI_MCA_plm_rsh_agent=\"\"\n\nCOPY --from=mpi-builder /usr/local/mpi /usr/local/mpi\nCOPY --from=trt-builder /usr/local/tensorrt /usr/local/tensorrt\nCOPY --from=tgi-builder /usr/local/tgi /usr/local/tgi\n\n# Basically we copy from target/debug instead of target/release\nCOPY --from=tgi-builder /usr/src/text-generation-inference/target/debug/text-generation-backends-trtllm /usr/local/tgi/bin/text-generation-launcher\n\n# This is the final image\nFROM runtime\n\nLABEL co.huggingface.vendor=\"Hugging Face Inc.\"\nLABEL org.opencontainers.image.authors=\"hardware@hf.co\"\nLABEL org.opencontainers.title=\"Text-Generation-Inference TensorRT-LLM Backend\"\n\nENTRYPOINT [\"./text-generation-launcher\"]\nCMD [\"--executor-worker\", \"/usr/local/tgi/bin/executorWorker\"]\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright 2022 Hugging Face\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "Makefile",
    "content": "install-server:\n\tcd server && make install\n\ninstall-server-cpu:\n\tcd server && make install-server\n\ninstall-router:\n\tcargo install --path backends/v3/\n\ninstall-launcher:\n\tcargo install --path launcher/\n\ninstall-benchmark:\n\tcargo install --path benchmark/\n\ninstall: install-server install-router install-launcher\n\n\ninstall-cpu: install-server-cpu install-router install-launcher\n\nserver-dev:\n\tcd server && make run-dev\n\nrouter-dev:\n\tcd router && cargo run -- --port 8080\n\nrust-tests: install-router install-launcher\n\tcargo test\n\ninstall-integration-tests:\n\tcd integration-tests && pip install -r requirements.txt\n\tcd clients/python && pip install .\n\nintegration-tests: install-integration-tests\n\tpytest -s -vv -m \"not private\" integration-tests\n\nupdate-integration-tests: install-integration-tests\n\tpytest -s -vv --snapshot-update integration-tests\n\npython-server-tests:\n\tHF_HUB_ENABLE_HF_TRANSFER=1 pytest -s -vv -m \"not private\" server/tests\n\npython-client-tests:\n\tpytest clients/python/tests\n\npython-tests: python-server-tests python-client-tests\n\nrun-falcon-7b-instruct:\n\ttext-generation-launcher --model-id tiiuae/falcon-7b-instruct --port 8080\n\nrun-falcon-7b-instruct-quantize:\n\ttext-generation-launcher --model-id tiiuae/falcon-7b-instruct --quantize bitsandbytes --port 8080\n\nclean:\n\trm -rf target aml\n\npreview_doc:\n\tdoc-builder preview text-generation-inference docs/source --not_python_module\n"
  },
  {
    "path": "README.md",
    "content": "> [!CAUTION]\n> text-generation-inference is now in maintenance mode. Going forward, we will accept pull requests for minor bug fixes, documentation improvements and lightweight maintenance tasks.\n>\n> TGI has initiated the movement for optimized inference engines to rely on `transformers` model architectures. This approach is now adopted by downstream inference engines, which we contribute to and recommend using going forward: [vllm](https://github.com/vllm-project/vllm), [SGLang](https://github.com/sgl-project/sglang), as well as local engines with inter-compatibility such as llama.cpp or MLX.\n\n<div align=\"center\">\n\n<a href=\"https://www.youtube.com/watch?v=jlMAX2Oaht0\">\n  <img width=560 alt=\"Making TGI deployment optimal\" src=\"https://huggingface.co/datasets/Narsil/tgi_assets/resolve/main/thumbnail.png\">\n</a>\n\n# Text Generation Inference\n\n<a href=\"https://github.com/huggingface/text-generation-inference\">\n  <img alt=\"GitHub Repo stars\" src=\"https://img.shields.io/github/stars/huggingface/text-generation-inference?style=social\">\n</a>\n<a href=\"https://huggingface.github.io/text-generation-inference\">\n  <img alt=\"Swagger API documentation\" src=\"https://img.shields.io/badge/API-Swagger-informational\">\n</a>\n\nA Rust, Python and gRPC server for text generation inference. 
Used in production at [Hugging Face](https://huggingface.co)\nto power Hugging Chat, the Inference API and Inference Endpoints.\n\n</div>\n\n## Table of contents\n\n  - [Get Started](#get-started)\n    - [Docker](#docker)\n    - [API documentation](#api-documentation)\n    - [Using a private or gated model](#using-a-private-or-gated-model)\n    - [A note on Shared Memory (shm)](#a-note-on-shared-memory-shm)\n    - [Distributed Tracing](#distributed-tracing)\n    - [Architecture](#architecture)\n    - [Local install](#local-install)\n    - [Local install (Nix)](#local-install-nix)\n  - [Optimized architectures](#optimized-architectures)\n  - [Run locally](#run-locally)\n    - [Run](#run)\n    - [Quantization](#quantization)\n  - [Develop](#develop)\n  - [Testing](#testing)\n\nText Generation Inference (TGI) is a toolkit for deploying and serving Large Language Models (LLMs). TGI enables high-performance text generation for the most popular open-source LLMs, including Llama, Falcon, StarCoder, BLOOM, GPT-NeoX, and [more](https://huggingface.co/docs/text-generation-inference/supported_models). 
TGI implements many features, such as:\n\n- Simple launcher to serve most popular LLMs\n- Production ready (distributed tracing with Open Telemetry, Prometheus metrics)\n- Tensor Parallelism for faster inference on multiple GPUs\n- Token streaming using Server-Sent Events (SSE)\n- Continuous batching of incoming requests for increased total throughput\n- [Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) compatible with Open AI Chat Completion API\n- Optimized transformers code for inference using [Flash Attention](https://github.com/HazyResearch/flash-attention) and [Paged Attention](https://github.com/vllm-project/vllm) on the most popular architectures\n- Quantization with :\n  - [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)\n  - [GPT-Q](https://arxiv.org/abs/2210.17323)\n  - [EETQ](https://github.com/NetEase-FuXi/EETQ)\n  - [AWQ](https://github.com/casper-hansen/AutoAWQ)\n  - [Marlin](https://github.com/IST-DASLab/marlin)\n  - [fp8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/)\n- [Safetensors](https://github.com/huggingface/safetensors) weight loading\n- Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)\n- Logits warper (temperature scaling, top-p, top-k, repetition penalty, more details see [transformers.LogitsProcessor](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.LogitsProcessor))\n- Stop sequences\n- Log probabilities\n- [Speculation](https://huggingface.co/docs/text-generation-inference/conceptual/speculation) ~2x latency\n- [Guidance/JSON](https://huggingface.co/docs/text-generation-inference/conceptual/guidance). 
Specify output format to speed up inference and make sure the output is valid according to some specs.\n- Custom Prompt Generation: Easily generate text by providing custom prompts to guide the model's output\n- Fine-tuning Support: Utilize fine-tuned models for specific tasks to achieve higher accuracy and performance\n\n### Hardware support\n\n- [Nvidia](https://github.com/huggingface/text-generation-inference/pkgs/container/text-generation-inference)\n- [AMD](https://github.com/huggingface/text-generation-inference/pkgs/container/text-generation-inference) (-rocm)\n- [Inferentia](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference)\n- [Intel GPU](https://github.com/huggingface/text-generation-inference/pull/1475)\n- [Gaudi](https://github.com/huggingface/tgi-gaudi)\n- [Google TPU](https://huggingface.co/docs/optimum-tpu/howto/serving)\n\n\n## Get Started\n\n### Docker\n\nFor a detailed starting guide, please see the [Quick Tour](https://huggingface.co/docs/text-generation-inference/quicktour). 
The easiest way of getting started is using the official Docker container:\n\n```shell\nmodel=HuggingFaceH4/zephyr-7b-beta\n# share a volume with the Docker container to avoid downloading weights every run\nvolume=$PWD/data\n\ndocker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \\\n    ghcr.io/huggingface/text-generation-inference:3.3.5 --model-id $model\n```\n\nAnd then you can make requests like\n\n```bash\ncurl 127.0.0.1:8080/generate_stream \\\n    -X POST \\\n    -d '{\"inputs\":\"What is Deep Learning?\",\"parameters\":{\"max_new_tokens\":20}}' \\\n    -H 'Content-Type: application/json'\n```\n\nYou can also use [TGI's Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) to obtain Open AI Chat Completion API compatible responses.\n\n```bash\ncurl localhost:8080/v1/chat/completions \\\n    -X POST \\\n    -d '{\n  \"model\": \"tgi\",\n  \"messages\": [\n    {\n      \"role\": \"system\",\n      \"content\": \"You are a helpful assistant.\"\n    },\n    {\n      \"role\": \"user\",\n      \"content\": \"What is deep learning?\"\n    }\n  ],\n  \"stream\": true,\n  \"max_tokens\": 20\n}' \\\n    -H 'Content-Type: application/json'\n```\n\n**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.\n\n**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/installation_amd#using-tgi-with-amd-gpus). 
To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.5-rocm --model-id $model` instead of the command above.\n\nTo see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):\n```\ntext-generation-launcher --help\n```\n\n### API documentation\n\nYou can consult the OpenAPI documentation of the `text-generation-inference` REST API using the `/docs` route.\nThe Swagger UI is also available at: [https://huggingface.github.io/text-generation-inference](https://huggingface.github.io/text-generation-inference).\n\n### Using a private or gated model\n\nYou have the option to utilize the `HF_TOKEN` environment variable for configuring the token employed by\n`text-generation-inference`. This allows you to gain access to protected resources.\n\nFor example, if you want to serve the gated Llama V2 model variants:\n\n1. Go to https://huggingface.co/settings/tokens\n2. Copy your CLI READ token\n3. Export `HF_TOKEN=<your CLI READ token>`\n\nor with Docker:\n\n```shell\nmodel=meta-llama/Meta-Llama-3.1-8B-Instruct\nvolume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run\ntoken=<your cli READ token>\n\ndocker run --gpus all --shm-size 1g -e HF_TOKEN=$token -p 8080:80 -v $volume:/data \\\n    ghcr.io/huggingface/text-generation-inference:3.3.5 --model-id $model\n```\n\n### A note on Shared Memory (shm)\n\n[`NCCL`](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/index.html) is a communication framework used by\n`PyTorch` to do distributed training/inference. 
`text-generation-inference` makes\nuse of `NCCL` to enable Tensor Parallelism to dramatically speed up inference for large language models.\n\nIn order to share data between the different devices of a `NCCL` group, `NCCL` might fall back to using the host memory if\npeer-to-peer using NVLink or PCI is not possible.\n\nTo allow the container to use 1G of Shared Memory and support SHM sharing, we add `--shm-size 1g` on the above command.\n\nIf you are running `text-generation-inference` inside `Kubernetes`, you can also add Shared Memory to the container by\ncreating a volume with:\n\n```yaml\n- name: shm\n  emptyDir:\n   medium: Memory\n   sizeLimit: 1Gi\n```\n\nand mounting it to `/dev/shm`.\n\nFinally, you can also disable SHM sharing by using the `NCCL_SHM_DISABLE=1` environment variable. However, note that\nthis will impact performance.\n\n### Distributed Tracing\n\n`text-generation-inference` is instrumented with distributed tracing using OpenTelemetry. You can use this feature\nby setting the address to an OTLP collector with the `--otlp-endpoint` argument. The default service name can be\noverridden with the `--otlp-service-name` argument.\n\n### Architecture\n\n![TGI architecture](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/TGI.png)\n\nDetailed blogpost by Adyen on TGI inner workings: [LLM inference at scale with TGI (Martin Iglesias Goyanes - Adyen, 2024)](https://www.adyen.com/knowledge-hub/llm-inference-at-scale-with-tgi)\n\n### Local install\n\nYou can also opt to install `text-generation-inference` locally.\n\nFirst clone the repository and change directory into it:\n\n```shell\ngit clone https://github.com/huggingface/text-generation-inference\ncd text-generation-inference\n```\n\nThen [install Rust](https://rustup.rs/) and create a Python virtual environment with at least\nPython 3.9, e.g. 
using `conda` or `python venv`:\n\n```shell\ncurl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh\n\n#using conda\nconda create -n text-generation-inference python=3.11\nconda activate text-generation-inference\n\n#using python venv\npython3 -m venv .venv\nsource .venv/bin/activate\n```\n\nYou may also need to install Protoc.\n\nOn Linux:\n\n```shell\nPROTOC_ZIP=protoc-21.12-linux-x86_64.zip\ncurl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP\nsudo unzip -o $PROTOC_ZIP -d /usr/local bin/protoc\nsudo unzip -o $PROTOC_ZIP -d /usr/local 'include/*'\nrm -f $PROTOC_ZIP\n```\n\nOn MacOS, using Homebrew:\n\n```shell\nbrew install protobuf\n```\n\nThen run:\n\n```shell\nBUILD_EXTENSIONS=True make install # Install repository and HF/transformer fork with CUDA kernels\ntext-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2\n```\n\n**Note:** on some machines, you may also need the OpenSSL libraries and gcc. On Linux machines, run:\n\n```shell\nsudo apt-get install libssl-dev gcc -y\n```\n\n### Local install (Nix)\n\nAnother option is to install `text-generation-inference` locally using [Nix](https://nixos.org). Currently,\nwe only support Nix on x86_64 Linux with CUDA GPUs. When using Nix, all dependencies can\nbe pulled from a binary cache, removing the need to build them locally.\n\nFirst follow the instructions to [install Cachix and enable the Hugging Face cache](https://app.cachix.org/cache/huggingface).\nSetting up the cache is important, otherwise Nix will build many of the dependencies\nlocally, which can take hours.\n\nAfter that you can run TGI with `nix run`:\n\n```shell\ncd text-generation-inference\nnix run --extra-experimental-features nix-command --extra-experimental-features flakes . 
-- --model-id meta-llama/Llama-3.1-8B-Instruct\n```\n\n**Note:** when you are using Nix on a non-NixOS system, you have to [make some symlinks](https://danieldk.eu/Nix-CUDA-on-non-NixOS-systems#make-runopengl-driverlib-and-symlink-the-driver-library)\nto make the CUDA driver libraries visible to Nix packages.\n\nFor TGI development, you can use the `impure` dev shell:\n\n```shell\nnix develop .#impure\n\n# Only needed the first time the devshell is started or after updating the protobuf.\n(\ncd server\nmkdir text_generation_server/pb || true\npython -m grpc_tools.protoc -I../proto/v3 --python_out=text_generation_server/pb \\\n       --grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb ../proto/v3/generate.proto\nfind text_generation_server/pb/ -type f -name \"*.py\" -print0 -exec sed -i -e 's/^\\(import.*pb2\\)/from . \\1/g' {} \\;\ntouch text_generation_server/pb/__init__.py\n)\n```\n\nAll development dependencies (cargo, Python, Torch, etc.) are available in this\ndev shell.\n\n## Optimized architectures\n\nTGI works out of the box to serve optimized implementations of all modern models. They can be found in [this list](https://huggingface.co/docs/text-generation-inference/supported_models).\n\nOther architectures are supported on a best-effort basis using:\n\n`AutoModelForCausalLM.from_pretrained(<model>, device_map=\"auto\")`\n\nor\n\n`AutoModelForSeq2SeqLM.from_pretrained(<model>, device_map=\"auto\")`\n\n\n\n## Run locally\n\n### Run\n\n```shell\ntext-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2\n```\n\n### Quantization\n\nYou can also run pre-quantized weights (AWQ, GPTQ, Marlin) or on-the-fly quantize weights with bitsandbytes, EETQ, fp8, to reduce the VRAM requirement:\n\n```shell\ntext-generation-launcher --model-id mistralai/Mistral-7B-Instruct-v0.2 --quantize\n```\n\n4-bit quantization is available using the [NF4 and FP4 data types from bitsandbytes](https://arxiv.org/pdf/2305.14314.pdf). 
It can be enabled by providing `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` as a command line argument to `text-generation-launcher`.\n\nRead more about quantization in the [Quantization documentation](https://huggingface.co/docs/text-generation-inference/en/conceptual/quantization).\n\n## Develop\n\n```shell\nmake server-dev\nmake router-dev\n```\n\n## Testing\n\n```shell\n# python\nmake python-server-tests\nmake python-client-tests\n# or both server and client tests\nmake python-tests\n# rust cargo tests\nmake rust-tests\n# integration tests\nmake integration-tests\n```\n"
  },
  {
    "path": "assets/tgi_grafana.json",
    "content": "{\n  \"__inputs\": [\n    {\n      \"name\": \"DS_PROMETHEUS_EKS API INFERENCE PROD\",\n      \"label\": \"Prometheus EKS API Inference Prod\",\n      \"description\": \"\",\n      \"type\": \"datasource\",\n      \"pluginId\": \"prometheus\",\n      \"pluginName\": \"Prometheus\"\n    }\n  ],\n  \"__elements\": {},\n  \"__requires\": [\n    {\n      \"type\": \"panel\",\n      \"id\": \"gauge\",\n      \"name\": \"Gauge\",\n      \"version\": \"\"\n    },\n    {\n      \"type\": \"grafana\",\n      \"id\": \"grafana\",\n      \"name\": \"Grafana\",\n      \"version\": \"10.0.2\"\n    },\n    {\n      \"type\": \"panel\",\n      \"id\": \"heatmap\",\n      \"name\": \"Heatmap\",\n      \"version\": \"\"\n    },\n    {\n      \"type\": \"datasource\",\n      \"id\": \"prometheus\",\n      \"name\": \"Prometheus\",\n      \"version\": \"1.0.0\"\n    },\n    {\n      \"type\": \"panel\",\n      \"id\": \"timeseries\",\n      \"name\": \"Time series\",\n      \"version\": \"\"\n    }\n  ],\n  \"annotations\": {\n    \"list\": [\n      {\n        \"builtIn\": 1,\n        \"datasource\": {\n          \"type\": \"grafana\",\n          \"uid\": \"-- Grafana --\"\n        },\n        \"enable\": true,\n        \"hide\": true,\n        \"iconColor\": \"rgba(0, 211, 255, 1)\",\n        \"name\": \"Annotations & Alerts\",\n        \"target\": {\n          \"limit\": 100,\n          \"matchAny\": false,\n          \"tags\": [],\n          \"type\": \"dashboard\"\n        },\n        \"type\": \"dashboard\"\n      }\n    ]\n  },\n  \"editable\": true,\n  \"fiscalYearStartMonth\": 0,\n  \"graphTooltip\": 2,\n  \"id\": 551,\n  \"links\": [],\n  \"liveNow\": false,\n  \"panels\": [\n    {\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"thresholds\"\n          },\n          
\"fieldMinMax\": false,\n          \"mappings\": [],\n          \"min\": 0,\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n              {\n                \"color\": \"green\",\n                \"value\": null\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 1000\n              }\n            ]\n          },\n          \"unit\": \"ms\"\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 7,\n        \"w\": 8,\n        \"x\": 0,\n        \"y\": 0\n      },\n      \"id\": 49,\n      \"options\": {\n        \"colorMode\": \"value\",\n        \"graphMode\": \"area\",\n        \"justifyMode\": \"auto\",\n        \"orientation\": \"auto\",\n        \"reduceOptions\": {\n          \"calcs\": [\n            \"mean\"\n          ],\n          \"fields\": \"\",\n          \"values\": false\n        },\n        \"showPercentChange\": false,\n        \"textMode\": \"auto\",\n        \"wideLayout\": true\n      },\n      \"pluginVersion\": \"10.4.2\",\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"(histogram_quantile(0.5, sum by (le) (rate(tgi_request_queue_duration_bucket{container=\\\"$service\\\"}[10m]))) * 1000) > 0\",\n          \"hide\": true,\n          \"instant\": false,\n          \"legendFormat\": \"__auto\",\n          \"range\": true,\n          \"refId\": \"B\"\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"(histogram_quantile(0.5, sum by (le) (rate(tgi_batch_inference_duration_bucket{method=\\\"prefill\\\", container=\\\"$service\\\"}[10m]))) * 1000) > 0\",\n      
    \"hide\": true,\n          \"instant\": false,\n          \"legendFormat\": \"__auto\",\n          \"range\": true,\n          \"refId\": \"C\"\n        },\n        {\n          \"datasource\": {\n            \"name\": \"Expression\",\n            \"type\": \"__expr__\",\n            \"uid\": \"__expr__\"\n          },\n          \"expression\": \"$B + $C\",\n          \"hide\": false,\n          \"refId\": \"D\",\n          \"type\": \"math\"\n        }\n      ],\n      \"title\": \"Time to first token\",\n      \"type\": \"stat\"\n    },\n    {\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"thresholds\"\n          },\n          \"mappings\": [],\n          \"min\": 0,\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n              {\n                \"color\": \"green\",\n                \"value\": null\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 80\n              }\n            ]\n          },\n          \"unit\": \"ms\"\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 7,\n        \"w\": 8,\n        \"x\": 9,\n        \"y\": 0\n      },\n      \"id\": 44,\n      \"options\": {\n        \"colorMode\": \"value\",\n        \"graphMode\": \"area\",\n        \"justifyMode\": \"auto\",\n        \"orientation\": \"auto\",\n        \"reduceOptions\": {\n          \"calcs\": [\n            \"mean\"\n          ],\n          \"fields\": \"\",\n          \"values\": false\n        },\n        \"showPercentChange\": false,\n        \"textMode\": \"auto\",\n        \"wideLayout\": true\n      },\n      \"pluginVersion\": \"10.4.2\",\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": 
\"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"(histogram_quantile(0.5, sum by (le) (rate(tgi_batch_forward_duration_bucket{method=\\\"decode\\\", container=\\\"$service\\\"}[10m]))) * 1000)>0\",\n          \"instant\": false,\n          \"range\": true,\n          \"refId\": \"A\"\n        }\n      ],\n      \"title\": \"Decode per-token latency\",\n      \"type\": \"stat\"\n    },\n    {\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"thresholds\"\n          },\n          \"mappings\": [],\n          \"min\": 0,\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n              {\n                \"color\": \"green\",\n                \"value\": null\n              }\n            ]\n          },\n          \"unit\": \"short\"\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 7,\n        \"w\": 7,\n        \"x\": 17,\n        \"y\": 0\n      },\n      \"id\": 45,\n      \"options\": {\n        \"colorMode\": \"value\",\n        \"graphMode\": \"area\",\n        \"justifyMode\": \"auto\",\n        \"orientation\": \"auto\",\n        \"reduceOptions\": {\n          \"calcs\": [\n            \"mean\"\n          ],\n          \"fields\": \"\",\n          \"values\": false\n        },\n        \"showPercentChange\": false,\n        \"textMode\": \"auto\",\n        \"wideLayout\": true\n      },\n      \"pluginVersion\": \"10.4.2\",\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"sum((rate(tgi_request_generated_tokens_sum{container=\\\"$service\\\"}[10m]) 
/ rate(tgi_request_generated_tokens_count{container=\\\"$service\\\"}[10m]))>0)\",\n          \"instant\": false,\n          \"range\": true,\n          \"refId\": \"A\"\n        }\n      ],\n      \"title\": \"Throughput (generated tok/s)\",\n      \"type\": \"stat\"\n    },\n    {\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"palette-classic\"\n          },\n          \"custom\": {\n            \"axisBorderShow\": false,\n            \"axisCenteredZero\": false,\n            \"axisColorMode\": \"text\",\n            \"axisLabel\": \"\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            \"drawStyle\": \"line\",\n            \"fillOpacity\": 0,\n            \"gradientMode\": \"none\",\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            },\n            \"showPoints\": \"never\",\n            \"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n              \"mode\": \"none\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n              {\n                \"color\": \"green\",\n                \"value\": null\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 80\n              }\n            ]\n          },\n          \"unit\": \"none\"\n        },\n        
\"overrides\": [\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p50\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"green\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          },\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p90\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"orange\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          },\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p99\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"red\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          }\n        ]\n      },\n      \"gridPos\": {\n        \"h\": 8,\n        \"w\": 12,\n        \"x\": 0,\n        \"y\": 7\n      },\n      \"id\": 48,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [\n            \"min\",\n            \"max\"\n          ],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": true\n        },\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"sort\": \"none\"\n        }\n      },\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.5, sum by (le) 
(rate(tgi_request_input_length_bucket{container=\\\"$service\\\"}[10m])))\",\n          \"legendFormat\": \"p50\",\n          \"range\": true,\n          \"refId\": \"A\"\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.9, sum by (le) (rate(tgi_request_input_length_bucket{container=\\\"$service\\\"}[10m])))\",\n          \"hide\": false,\n          \"legendFormat\": \"p90\",\n          \"range\": true,\n          \"refId\": \"B\"\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.99, sum by (le) (rate(tgi_request_input_length_bucket{container=\\\"$service\\\"}[10m])))\",\n          \"hide\": false,\n          \"legendFormat\": \"p99\",\n          \"range\": true,\n          \"refId\": \"C\"\n        }\n      ],\n      \"title\": \"Number of tokens per prompt\",\n      \"type\": \"timeseries\"\n    },\n    {\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"palette-classic\"\n          },\n          \"custom\": {\n            \"axisBorderShow\": false,\n            \"axisCenteredZero\": false,\n            \"axisColorMode\": \"text\",\n            \"axisLabel\": \"\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            \"drawStyle\": \"line\",\n            \"fillOpacity\": 0,\n            \"gradientMode\": \"none\",\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            
},\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            },\n            \"showPoints\": \"never\",\n            \"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n              \"mode\": \"none\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n              {\n                \"color\": \"green\",\n                \"value\": null\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 80\n              }\n            ]\n          },\n          \"unit\": \"none\"\n        },\n        \"overrides\": [\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p50\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"green\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          },\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p90\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"orange\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          },\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p99\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  
\"fixedColor\": \"red\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          }\n        ]\n      },\n      \"gridPos\": {\n        \"h\": 8,\n        \"w\": 12,\n        \"x\": 12,\n        \"y\": 7\n      },\n      \"id\": 30,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [\n            \"min\",\n            \"max\"\n          ],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": true\n        },\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"sort\": \"none\"\n        }\n      },\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.5, sum by (le) (rate(tgi_request_generated_tokens_bucket{container=\\\"$service\\\"}[10m])))\",\n          \"legendFormat\": \"p50\",\n          \"range\": true,\n          \"refId\": \"A\"\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.9, sum by (le) (rate(tgi_request_generated_tokens_bucket{container=\\\"$service\\\"}[10m])))\",\n          \"hide\": false,\n          \"legendFormat\": \"p90\",\n          \"range\": true,\n          \"refId\": \"B\"\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.99, sum by (le) (rate(tgi_request_generated_tokens_bucket{container=\\\"$service\\\"}[10m])))\",\n          \"hide\": false,\n          \"legendFormat\": \"p99\",\n          \"range\": true,\n   
       \"refId\": \"C\"\n        }\n      ],\n      \"title\": \"Number of generated tokens per request\",\n      \"type\": \"timeseries\"\n    },\n    {\n      \"collapsed\": false,\n      \"gridPos\": {\n        \"h\": 1,\n        \"w\": 24,\n        \"x\": 0,\n        \"y\": 15\n      },\n      \"id\": 20,\n      \"panels\": [],\n      \"title\": \"General\",\n      \"type\": \"row\"\n    },\n    {\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"palette-classic\"\n          },\n          \"custom\": {\n            \"axisBorderShow\": false,\n            \"axisCenteredZero\": false,\n            \"axisColorMode\": \"text\",\n            \"axisLabel\": \"\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            \"drawStyle\": \"line\",\n            \"fillOpacity\": 30,\n            \"gradientMode\": \"none\",\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            },\n            \"showPoints\": \"never\",\n            \"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n              \"mode\": \"none\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n              {\n                \"color\": \"green\",\n                \"value\": null\n              },\n              {\n                \"color\": \"red\",\n          
      \"value\": 80\n              }\n            ]\n          }\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 8,\n        \"w\": 6,\n        \"x\": 0,\n        \"y\": 16\n      },\n      \"id\": 4,\n      \"maxDataPoints\": 100,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [\n            \"min\",\n            \"max\"\n          ],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": true\n        },\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"sort\": \"none\"\n        }\n      },\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"sum(increase(tgi_request_success{container=\\\"$service\\\"}[1m]))\",\n          \"legendFormat\": \"Success\",\n          \"range\": true,\n          \"refId\": \"A\"\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"sum(increase(tgi_request_failure{container=\\\"$service\\\"}[1m])) by (err)\",\n          \"hide\": false,\n          \"legendFormat\": \"Error: {{err}}\",\n          \"range\": true,\n          \"refId\": \"B\"\n        }\n      ],\n      \"title\": \"Requests\",\n      \"type\": \"timeseries\"\n    },\n    {\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"palette-classic\"\n          },\n          \"custom\": {\n            \"axisBorderShow\": false,\n            \"axisCenteredZero\": false,\n            \"axisColorMode\": \"text\",\n         
   \"axisLabel\": \"\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            \"drawStyle\": \"line\",\n            \"fillOpacity\": 0,\n            \"gradientMode\": \"none\",\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            },\n            \"showPoints\": \"never\",\n            \"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n              \"mode\": \"none\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n              {\n                \"color\": \"green\",\n                \"value\": null\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 80\n              }\n            ]\n          },\n          \"unit\": \"s\"\n        },\n        \"overrides\": [\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p50\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"green\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          },\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p90\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"orange\",\n                
  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          },\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p99\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"red\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          }\n        ]\n      },\n      \"gridPos\": {\n        \"h\": 13,\n        \"w\": 9,\n        \"x\": 6,\n        \"y\": 16\n      },\n      \"id\": 6,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [\n            \"min\",\n            \"max\"\n          ],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": true\n        },\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"sort\": \"none\"\n        }\n      },\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.5, sum by (le) (rate(tgi_request_mean_time_per_token_duration_bucket{container=\\\"$service\\\"}[10m])))\",\n          \"legendFormat\": \"p50\",\n          \"range\": true,\n          \"refId\": \"A\"\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.9, sum by (le) (rate(tgi_request_mean_time_per_token_duration_bucket{container=\\\"$service\\\"}[10m])))\",\n          \"hide\": false,\n          \"legendFormat\": \"p90\",\n          \"range\": true,\n          \"refId\": \"B\"\n        },\n        {\n          \"datasource\": {\n            \"type\": 
\"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.99, sum by (le) (rate(tgi_request_mean_time_per_token_duration_bucket{container=\\\"$service\\\"}[10m])))\",\n          \"hide\": false,\n          \"legendFormat\": \"p99\",\n          \"range\": true,\n          \"refId\": \"C\"\n        }\n      ],\n      \"title\": \"Mean Time Per Token quantiles\",\n      \"type\": \"timeseries\"\n    },\n    {\n      \"cards\": {},\n      \"color\": {\n        \"cardColor\": \"#5794F2\",\n        \"colorScale\": \"linear\",\n        \"colorScheme\": \"interpolateSpectral\",\n        \"exponent\": 0.5,\n        \"min\": 0,\n        \"mode\": \"opacity\"\n      },\n      \"dataFormat\": \"tsbuckets\",\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"custom\": {\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            }\n          }\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 13,\n        \"w\": 9,\n        \"x\": 15,\n        \"y\": 16\n      },\n      \"heatmap\": {},\n      \"hideZeroBuckets\": false,\n      \"highlightCards\": true,\n      \"id\": 13,\n      \"legend\": {\n        \"show\": false\n      },\n      \"maxDataPoints\": 25,\n      \"options\": {\n        \"calculate\": false,\n        \"calculation\": {},\n        \"cellGap\": 2,\n        \"cellValues\": {},\n        \"color\": {\n          \"exponent\": 0.5,\n          \"fill\": \"#5794F2\",\n          \"min\": 0,\n          \"mode\": \"scheme\",\n          \"reverse\": false,\n          \"scale\": \"exponential\",\n          
\"scheme\": \"Spectral\",\n          \"steps\": 128\n        },\n        \"exemplars\": {\n          \"color\": \"rgba(255,0,255,0.7)\"\n        },\n        \"filterValues\": {\n          \"le\": 1e-9\n        },\n        \"legend\": {\n          \"show\": false\n        },\n        \"rowsFrame\": {\n          \"layout\": \"auto\"\n        },\n        \"showValue\": \"never\",\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"showColorScale\": false,\n          \"yHistogram\": false\n        },\n        \"yAxis\": {\n          \"axisPlacement\": \"left\",\n          \"decimals\": 1,\n          \"reverse\": false,\n          \"unit\": \"s\"\n        }\n      },\n      \"pluginVersion\": \"10.4.2\",\n      \"reverseYBuckets\": false,\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"exemplar\": true,\n          \"expr\": \"sum(increase(tgi_request_mean_time_per_token_duration_bucket{container=\\\"$service\\\"}[5m])) by (le)\",\n          \"format\": \"heatmap\",\n          \"interval\": \"\",\n          \"legendFormat\": \"{{ le }}\",\n          \"range\": true,\n          \"refId\": \"A\"\n        }\n      ],\n      \"title\": \"Mean Time Per Token\",\n      \"tooltip\": {\n        \"show\": true,\n        \"showHistogram\": false\n      },\n      \"type\": \"heatmap\",\n      \"xAxis\": {\n        \"show\": true\n      },\n      \"yAxis\": {\n        \"decimals\": 1,\n        \"format\": \"s\",\n        \"logBase\": 1,\n        \"show\": true\n      },\n      \"yBucketBound\": \"auto\"\n    },\n    {\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"palette-classic\"\n          },\n        
  \"custom\": {\n            \"axisBorderShow\": false,\n            \"axisCenteredZero\": false,\n            \"axisColorMode\": \"text\",\n            \"axisLabel\": \"\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            \"drawStyle\": \"line\",\n            \"fillOpacity\": 0,\n            \"gradientMode\": \"none\",\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            },\n            \"showPoints\": \"auto\",\n            \"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n              \"mode\": \"none\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"percentage\",\n            \"steps\": [\n              {\n                \"color\": \"green\",\n                \"value\": null\n              },\n              {\n                \"color\": \"orange\",\n                \"value\": 70\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 85\n              }\n            ]\n          }\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 5,\n        \"w\": 3,\n        \"x\": 0,\n        \"y\": 24\n      },\n      \"id\": 18,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": false\n        },\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"sort\": \"none\"\n        }\n      },\n      \"pluginVersion\": 
\"9.1.0\",\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"count(tgi_request_count{container=\\\"$service\\\"})\",\n          \"legendFormat\": \"Replicas\",\n          \"range\": true,\n          \"refId\": \"A\"\n        }\n      ],\n      \"title\": \"Number of replicas\",\n      \"type\": \"timeseries\"\n    },\n    {\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"percentage\",\n            \"steps\": [\n              {\n                \"color\": \"green\",\n                \"value\": null\n              },\n              {\n                \"color\": \"orange\",\n                \"value\": 70\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 85\n              }\n            ]\n          }\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 5,\n        \"w\": 3,\n        \"x\": 3,\n        \"y\": 24\n      },\n      \"id\": 32,\n      \"options\": {\n        \"minVizHeight\": 75,\n        \"minVizWidth\": 75,\n        \"orientation\": \"auto\",\n        \"reduceOptions\": {\n          \"calcs\": [\n            \"lastNotNull\"\n          ],\n          \"fields\": \"\",\n          \"values\": false\n        },\n        \"showThresholdLabels\": false,\n        \"showThresholdMarkers\": true,\n        \"sizing\": \"auto\"\n      },\n      \"pluginVersion\": \"10.4.2\",\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": 
\"code\",\n          \"expr\": \"sum(tgi_queue_size{container=\\\"$service\\\"})\",\n          \"legendFormat\": \"__auto\",\n          \"range\": true,\n          \"refId\": \"A\"\n        }\n      ],\n      \"title\": \"Queue Size\",\n      \"type\": \"gauge\"\n    },\n    {\n      \"collapsed\": false,\n      \"gridPos\": {\n        \"h\": 1,\n        \"w\": 24,\n        \"x\": 0,\n        \"y\": 29\n      },\n      \"id\": 26,\n      \"panels\": [],\n      \"title\": \"Batching\",\n      \"type\": \"row\"\n    },\n    {\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"palette-classic\"\n          },\n          \"custom\": {\n            \"axisBorderShow\": false,\n            \"axisCenteredZero\": false,\n            \"axisColorMode\": \"text\",\n            \"axisLabel\": \"\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            \"drawStyle\": \"bars\",\n            \"fillOpacity\": 50,\n            \"gradientMode\": \"none\",\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            },\n            \"showPoints\": \"never\",\n            \"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n              \"mode\": \"normal\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n              {\n                \"color\": 
\"green\",\n                \"value\": null\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 80\n              }\n            ]\n          }\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 5,\n        \"w\": 6,\n        \"x\": 0,\n        \"y\": 30\n      },\n      \"id\": 29,\n      \"maxDataPoints\": 40,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": false\n        },\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"sort\": \"none\"\n        }\n      },\n      \"pluginVersion\": \"9.1.0\",\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"avg(tgi_batch_current_max_tokens{container=\\\"$service\\\"})\",\n          \"legendFormat\": \"{{ pod }}\",\n          \"range\": true,\n          \"refId\": \"A\"\n        }\n      ],\n      \"title\": \"Max tokens per batch\",\n      \"type\": \"timeseries\"\n    },\n    {\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"palette-classic\"\n          },\n          \"custom\": {\n            \"axisBorderShow\": false,\n            \"axisCenteredZero\": false,\n            \"axisColorMode\": \"text\",\n            \"axisLabel\": \"\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            \"drawStyle\": \"line\",\n            \"fillOpacity\": 0,\n            \"gradientMode\": \"none\",\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n             
 \"viz\": false\n            },\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            },\n            \"showPoints\": \"never\",\n            \"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n              \"mode\": \"none\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n              {\n                \"color\": \"green\",\n                \"value\": null\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 80\n              }\n            ]\n          },\n          \"unit\": \"none\"\n        },\n        \"overrides\": [\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p50\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"green\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          },\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p90\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"orange\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          },\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p99\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                
\"value\": {\n                  \"fixedColor\": \"red\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          }\n        ]\n      },\n      \"gridPos\": {\n        \"h\": 9,\n        \"w\": 4,\n        \"x\": 6,\n        \"y\": 30\n      },\n      \"id\": 33,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [\n            \"min\",\n            \"max\"\n          ],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": true\n        },\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"sort\": \"none\"\n        }\n      },\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.5, sum by (le) (rate(tgi_request_skipped_tokens_bucket{container=\\\"$service\\\"}[10m])))\",\n          \"legendFormat\": \"p50\",\n          \"range\": true,\n          \"refId\": \"A\"\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.9, sum by (le) (rate(tgi_request_skipped_tokens_bucket{container=\\\"$service\\\"}[10m])))\",\n          \"hide\": false,\n          \"legendFormat\": \"p90\",\n          \"range\": true,\n          \"refId\": \"B\"\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.99, sum by (le) (rate(tgi_request_skipped_tokens_bucket{container=\\\"$service\\\"}[10m])))\",\n          \"hide\": false,\n          \"legendFormat\": \"p99\",\n      
    \"range\": true,\n          \"refId\": \"C\"\n        }\n      ],\n      \"title\": \"Speculated Tokens\",\n      \"type\": \"timeseries\"\n    },\n    {\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"palette-classic\"\n          },\n          \"custom\": {\n            \"axisBorderShow\": false,\n            \"axisCenteredZero\": false,\n            \"axisColorMode\": \"text\",\n            \"axisLabel\": \"\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            \"drawStyle\": \"line\",\n            \"fillOpacity\": 0,\n            \"gradientMode\": \"none\",\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            },\n            \"showPoints\": \"never\",\n            \"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n              \"mode\": \"none\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n              {\n                \"color\": \"green\",\n                \"value\": null\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 80\n              }\n            ]\n          },\n          \"unit\": \"none\"\n        },\n        \"overrides\": [\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p50\"\n 
           },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"green\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          },\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p90\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"orange\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          },\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p99\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"red\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          }\n        ]\n      },\n      \"gridPos\": {\n        \"h\": 9,\n        \"w\": 5,\n        \"x\": 10,\n        \"y\": 30\n      },\n      \"id\": 46,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [\n            \"min\",\n            \"max\"\n          ],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": true\n        },\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"sort\": \"none\"\n        }\n      },\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.5, sum by (le) (rate(tgi_request_input_length_bucket{container=\\\"$service\\\"}[10m])))\",\n          \"legendFormat\": \"p50\",\n          \"range\": true,\n          
\"refId\": \"A\"\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.9, sum by (le) (rate(tgi_request_input_length_bucket{container=\\\"$service\\\"}[10m])))\",\n          \"hide\": false,\n          \"legendFormat\": \"p90\",\n          \"range\": true,\n          \"refId\": \"B\"\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.99, sum by (le) (rate(tgi_request_input_length_bucket{container=\\\"$service\\\"}[10m])))\",\n          \"hide\": false,\n          \"legendFormat\": \"p99\",\n          \"range\": true,\n          \"refId\": \"C\"\n        }\n      ],\n      \"title\": \"Prompt Tokens\",\n      \"type\": \"timeseries\"\n    },\n    {\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"palette-classic\"\n          },\n          \"custom\": {\n            \"axisBorderShow\": false,\n            \"axisCenteredZero\": false,\n            \"axisColorMode\": \"text\",\n            \"axisLabel\": \"\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            \"drawStyle\": \"line\",\n            \"fillOpacity\": 0,\n            \"gradientMode\": \"none\",\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            
\"scaleDistribution\": {\n              \"type\": \"linear\"\n            },\n            \"showPoints\": \"never\",\n            \"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n              \"mode\": \"none\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n              {\n                \"color\": \"green\",\n                \"value\": null\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 80\n              }\n            ]\n          },\n          \"unit\": \"s\"\n        },\n        \"overrides\": [\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p50\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"green\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          },\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p90\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"orange\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          },\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p99\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"red\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          }\n        ]\n      },\n      
\"gridPos\": {\n        \"h\": 9,\n        \"w\": 9,\n        \"x\": 15,\n        \"y\": 30\n      },\n      \"id\": 8,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [\n            \"min\",\n            \"max\"\n          ],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": true\n        },\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"sort\": \"none\"\n        }\n      },\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.5, sum by (le) (rate(tgi_request_duration_bucket{container=\\\"$service\\\"}[10m])))\",\n          \"legendFormat\": \"p50\",\n          \"range\": true,\n          \"refId\": \"A\"\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.9, sum by (le) (rate(tgi_request_duration_bucket{container=\\\"$service\\\"}[10m])))\",\n          \"hide\": false,\n          \"legendFormat\": \"p90\",\n          \"range\": true,\n          \"refId\": \"B\"\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.99, sum by (le) (rate(tgi_request_duration_bucket{container=\\\"$service\\\"}[10m])))\",\n          \"hide\": false,\n          \"legendFormat\": \"p99\",\n          \"range\": true,\n          \"refId\": \"C\"\n        }\n      ],\n      \"title\": \"Latency quantiles\",\n      \"type\": \"timeseries\"\n    },\n    {\n      \"datasource\": {\n        \"type\": 
\"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"palette-classic\"\n          },\n          \"custom\": {\n            \"axisBorderShow\": false,\n            \"axisCenteredZero\": false,\n            \"axisColorMode\": \"text\",\n            \"axisLabel\": \"\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            \"drawStyle\": \"bars\",\n            \"fillOpacity\": 50,\n            \"gradientMode\": \"none\",\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            },\n            \"showPoints\": \"never\",\n            \"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n              \"mode\": \"normal\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n              {\n                \"color\": \"green\",\n                \"value\": null\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 80\n              }\n            ]\n          }\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 4,\n        \"w\": 6,\n        \"x\": 0,\n        \"y\": 35\n      },\n      \"id\": 27,\n      \"maxDataPoints\": 40,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": 
false\n        },\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"sort\": \"none\"\n        }\n      },\n      \"pluginVersion\": \"9.1.0\",\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"avg(tgi_batch_current_size{container=\\\"$service\\\"})\",\n          \"legendFormat\": \"{{ pod }}\",\n          \"range\": true,\n          \"refId\": \"A\"\n        }\n      ],\n      \"title\": \"Batch Size\",\n      \"type\": \"timeseries\"\n    },\n    {\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"palette-classic\"\n          },\n          \"custom\": {\n            \"axisBorderShow\": false,\n            \"axisCenteredZero\": false,\n            \"axisColorMode\": \"text\",\n            \"axisLabel\": \"\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            \"drawStyle\": \"line\",\n            \"fillOpacity\": 30,\n            \"gradientMode\": \"none\",\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            },\n            \"showPoints\": \"never\",\n            \"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n              \"mode\": \"none\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n   
       \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n              {\n                \"color\": \"green\",\n                \"value\": null\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 80\n              }\n            ]\n          }\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 9,\n        \"w\": 6,\n        \"x\": 0,\n        \"y\": 39\n      },\n      \"id\": 28,\n      \"maxDataPoints\": 100,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [\n            \"min\",\n            \"max\"\n          ],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": true\n        },\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"sort\": \"none\"\n        }\n      },\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"sum(increase(tgi_batch_concat{container=\\\"$service\\\"}[1m])) by (reason)\",\n          \"hide\": false,\n          \"legendFormat\": \"Reason: {{ reason }}\",\n          \"range\": true,\n          \"refId\": \"B\"\n        }\n      ],\n      \"title\": \"Concatenates\",\n      \"type\": \"timeseries\"\n    },\n    {\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"palette-classic\"\n          },\n          \"custom\": {\n            \"axisBorderShow\": false,\n            \"axisCenteredZero\": false,\n            \"axisColorMode\": \"text\",\n            \"axisLabel\": \"\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            \"drawStyle\": 
\"line\",\n            \"fillOpacity\": 0,\n            \"gradientMode\": \"none\",\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            },\n            \"showPoints\": \"never\",\n            \"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n              \"mode\": \"none\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n              {\n                \"color\": \"green\",\n                \"value\": null\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 80\n              }\n            ]\n          },\n          \"unit\": \"s\"\n        },\n        \"overrides\": [\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p50\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"green\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          },\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p90\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"orange\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          },\n          {\n            \"matcher\": 
{\n              \"id\": \"byName\",\n              \"options\": \"p99\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"red\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          }\n        ]\n      },\n      \"gridPos\": {\n        \"h\": 9,\n        \"w\": 9,\n        \"x\": 6,\n        \"y\": 39\n      },\n      \"id\": 31,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [\n            \"min\",\n            \"max\"\n          ],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": true\n        },\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"sort\": \"none\"\n        }\n      },\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.5, sum by (le) (rate(tgi_request_queue_duration_bucket{container=\\\"$service\\\"}[10m])))\",\n          \"legendFormat\": \"p50\",\n          \"range\": true,\n          \"refId\": \"A\"\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.9, sum by (le) (rate(tgi_request_queue_duration_bucket{container=\\\"$service\\\"}[10m])))\",\n          \"hide\": false,\n          \"legendFormat\": \"p90\",\n          \"range\": true,\n          \"refId\": \"B\"\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": 
\"histogram_quantile(0.99, sum by (le) (rate(tgi_request_queue_duration_bucket{container=\\\"$service\\\"}[10m])))\",\n          \"hide\": false,\n          \"legendFormat\": \"p99\",\n          \"range\": true,\n          \"refId\": \"C\"\n        }\n      ],\n      \"title\": \"Queue quantiles\",\n      \"type\": \"timeseries\"\n    },\n    {\n      \"collapsed\": false,\n      \"gridPos\": {\n        \"h\": 1,\n        \"w\": 24,\n        \"x\": 0,\n        \"y\": 48\n      },\n      \"id\": 22,\n      \"panels\": [],\n      \"title\": \"Prefill\",\n      \"type\": \"row\"\n    },\n    {\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"palette-classic\"\n          },\n          \"custom\": {\n            \"axisBorderShow\": false,\n            \"axisCenteredZero\": false,\n            \"axisColorMode\": \"text\",\n            \"axisLabel\": \"\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            \"drawStyle\": \"line\",\n            \"fillOpacity\": 0,\n            \"gradientMode\": \"none\",\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            },\n            \"showPoints\": \"never\",\n            \"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n              \"mode\": \"none\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n     
       \"steps\": [\n              {\n                \"color\": \"green\",\n                \"value\": null\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 80\n              }\n            ]\n          },\n          \"unit\": \"s\"\n        },\n        \"overrides\": [\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p50\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"green\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          },\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p90\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"orange\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          },\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p99\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"red\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          }\n        ]\n      },\n      \"gridPos\": {\n        \"h\": 11,\n        \"w\": 12,\n        \"x\": 0,\n        \"y\": 49\n      },\n      \"id\": 7,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [\n            \"min\",\n            \"max\"\n          ],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": true\n        },\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"sort\": \"none\"\n        }\n      
},\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.5, sum by (le) (rate(tgi_batch_inference_duration_bucket{method=\\\"prefill\\\", container=\\\"$service\\\"}[10m])))\",\n          \"legendFormat\": \"p50\",\n          \"range\": true,\n          \"refId\": \"A\"\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.9, sum by (le) (rate(tgi_batch_inference_duration_bucket{method=\\\"prefill\\\", container=\\\"$service\\\"}[10m])))\",\n          \"hide\": false,\n          \"legendFormat\": \"p90\",\n          \"range\": true,\n          \"refId\": \"B\"\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.99, sum by (le) (rate(tgi_batch_inference_duration_bucket{method=\\\"prefill\\\", container=\\\"$service\\\"}[10m])))\",\n          \"hide\": false,\n          \"legendFormat\": \"p99\",\n          \"range\": true,\n          \"refId\": \"C\"\n        }\n      ],\n      \"title\": \"Prefill Quantiles\",\n      \"type\": \"timeseries\"\n    },\n    {\n      \"cards\": {},\n      \"color\": {\n        \"cardColor\": \"#5794F2\",\n        \"colorScale\": \"linear\",\n        \"colorScheme\": \"interpolateSpectral\",\n        \"exponent\": 0.5,\n        \"min\": 0,\n        \"mode\": \"opacity\"\n      },\n      \"dataFormat\": \"tsbuckets\",\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n     
 },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"custom\": {\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            }\n          }\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 11,\n        \"w\": 12,\n        \"x\": 12,\n        \"y\": 49\n      },\n      \"heatmap\": {},\n      \"hideZeroBuckets\": false,\n      \"highlightCards\": true,\n      \"id\": 14,\n      \"legend\": {\n        \"show\": false\n      },\n      \"maxDataPoints\": 25,\n      \"options\": {\n        \"calculate\": false,\n        \"calculation\": {},\n        \"cellGap\": 2,\n        \"cellValues\": {},\n        \"color\": {\n          \"exponent\": 0.5,\n          \"fill\": \"#5794F2\",\n          \"min\": 0,\n          \"mode\": \"scheme\",\n          \"reverse\": false,\n          \"scale\": \"exponential\",\n          \"scheme\": \"Spectral\",\n          \"steps\": 128\n        },\n        \"exemplars\": {\n          \"color\": \"rgba(255,0,255,0.7)\"\n        },\n        \"filterValues\": {\n          \"le\": 1e-9\n        },\n        \"legend\": {\n          \"show\": false\n        },\n        \"rowsFrame\": {\n          \"layout\": \"auto\"\n        },\n        \"showValue\": \"never\",\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"showColorScale\": false,\n          \"yHistogram\": false\n        },\n        \"yAxis\": {\n          \"axisPlacement\": \"left\",\n          \"decimals\": 1,\n          \"reverse\": false,\n          \"unit\": \"s\"\n        }\n      },\n      \"pluginVersion\": \"10.4.2\",\n      \"reverseYBuckets\": false,\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          
\"editorMode\": \"code\",\n          \"exemplar\": true,\n          \"expr\": \"sum(increase(tgi_batch_inference_duration_bucket{method=\\\"prefill\\\", container=\\\"$service\\\"}[5m])) by (le)\",\n          \"format\": \"heatmap\",\n          \"interval\": \"\",\n          \"legendFormat\": \"{{ le }}\",\n          \"range\": true,\n          \"refId\": \"A\"\n        }\n      ],\n      \"title\": \"Prefill Latency\",\n      \"tooltip\": {\n        \"show\": true,\n        \"showHistogram\": false\n      },\n      \"type\": \"heatmap\",\n      \"xAxis\": {\n        \"show\": true\n      },\n      \"yAxis\": {\n        \"decimals\": 1,\n        \"format\": \"s\",\n        \"logBase\": 1,\n        \"show\": true\n      },\n      \"yBucketBound\": \"auto\"\n    },\n    {\n      \"collapsed\": false,\n      \"gridPos\": {\n        \"h\": 1,\n        \"w\": 24,\n        \"x\": 0,\n        \"y\": 60\n      },\n      \"id\": 24,\n      \"panels\": [],\n      \"title\": \"Decode\",\n      \"type\": \"row\"\n    },\n    {\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"palette-classic\"\n          },\n          \"custom\": {\n            \"axisBorderShow\": false,\n            \"axisCenteredZero\": false,\n            \"axisColorMode\": \"text\",\n            \"axisLabel\": \"\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            \"drawStyle\": \"line\",\n            \"fillOpacity\": 0,\n            \"gradientMode\": \"none\",\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            \"scaleDistribution\": {\n    
          \"type\": \"linear\"\n            },\n            \"showPoints\": \"never\",\n            \"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n              \"mode\": \"none\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n              {\n                \"color\": \"green\",\n                \"value\": null\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 80\n              }\n            ]\n          },\n          \"unit\": \"s\"\n        },\n        \"overrides\": [\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p50\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"green\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          },\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p90\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"orange\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          },\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p99\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"red\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          }\n        ]\n      },\n      \"gridPos\": {\n        \"h\": 11,\n      
  \"w\": 12,\n        \"x\": 0,\n        \"y\": 61\n      },\n      \"id\": 11,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [\n            \"min\",\n            \"max\"\n          ],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": true\n        },\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"sort\": \"none\"\n        }\n      },\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.5, sum by (le) (rate(tgi_batch_inference_duration_bucket{method=\\\"decode\\\", container=\\\"$service\\\"}[10m])))\",\n          \"legendFormat\": \"p50\",\n          \"range\": true,\n          \"refId\": \"A\"\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.9, sum by (le) (rate(tgi_batch_inference_duration_bucket{method=\\\"decode\\\", container=\\\"$service\\\"}[10m])))\",\n          \"hide\": false,\n          \"legendFormat\": \"p90\",\n          \"range\": true,\n          \"refId\": \"B\"\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.99, sum by (le) (rate(tgi_batch_inference_duration_bucket{method=\\\"decode\\\", container=\\\"$service\\\"}[10m])))\",\n          \"hide\": false,\n          \"legendFormat\": \"p99\",\n          \"range\": true,\n          \"refId\": \"C\"\n        }\n      ],\n      \"title\": \"Decode quantiles\",\n      \"type\": \"timeseries\"\n    },\n  
  {\n      \"cards\": {},\n      \"color\": {\n        \"cardColor\": \"#5794F2\",\n        \"colorScale\": \"linear\",\n        \"colorScheme\": \"interpolateSpectral\",\n        \"exponent\": 0.5,\n        \"min\": 0,\n        \"mode\": \"opacity\"\n      },\n      \"dataFormat\": \"tsbuckets\",\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"custom\": {\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            }\n          }\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 11,\n        \"w\": 12,\n        \"x\": 12,\n        \"y\": 61\n      },\n      \"heatmap\": {},\n      \"hideZeroBuckets\": false,\n      \"highlightCards\": true,\n      \"id\": 15,\n      \"legend\": {\n        \"show\": false\n      },\n      \"maxDataPoints\": 25,\n      \"options\": {\n        \"calculate\": false,\n        \"calculation\": {},\n        \"cellGap\": 2,\n        \"cellValues\": {},\n        \"color\": {\n          \"exponent\": 0.5,\n          \"fill\": \"#5794F2\",\n          \"min\": 0,\n          \"mode\": \"scheme\",\n          \"reverse\": false,\n          \"scale\": \"exponential\",\n          \"scheme\": \"Spectral\",\n          \"steps\": 128\n        },\n        \"exemplars\": {\n          \"color\": \"rgba(255,0,255,0.7)\"\n        },\n        \"filterValues\": {\n          \"le\": 1e-9\n        },\n        \"legend\": {\n          \"show\": false\n        },\n        \"rowsFrame\": {\n          \"layout\": \"auto\"\n        },\n        \"showValue\": \"never\",\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"showColorScale\": false,\n          \"yHistogram\": false\n        },\n        
\"yAxis\": {\n          \"axisPlacement\": \"left\",\n          \"decimals\": 1,\n          \"reverse\": false,\n          \"unit\": \"s\"\n        }\n      },\n      \"pluginVersion\": \"10.4.2\",\n      \"reverseYBuckets\": false,\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"exemplar\": true,\n          \"expr\": \"sum(increase(tgi_batch_inference_duration_bucket{method=\\\"decode\\\", container=\\\"$service\\\"}[5m])) by (le)\",\n          \"format\": \"heatmap\",\n          \"interval\": \"\",\n          \"legendFormat\": \"{{ le }}\",\n          \"range\": true,\n          \"refId\": \"A\"\n        }\n      ],\n      \"title\": \"Decode Latency\",\n      \"tooltip\": {\n        \"show\": true,\n        \"showHistogram\": false\n      },\n      \"type\": \"heatmap\",\n      \"xAxis\": {\n        \"show\": true\n      },\n      \"yAxis\": {\n        \"decimals\": 1,\n        \"format\": \"s\",\n        \"logBase\": 1,\n        \"show\": true\n      },\n      \"yBucketBound\": \"auto\"\n    },\n    {\n      \"collapsed\": false,\n      \"gridPos\": {\n        \"h\": 1,\n        \"w\": 24,\n        \"x\": 0,\n        \"y\": 72\n      },\n      \"id\": 43,\n      \"panels\": [],\n      \"title\": \"Debug\",\n      \"type\": \"row\"\n    },\n    {\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"palette-classic\"\n          },\n          \"custom\": {\n            \"axisBorderShow\": false,\n            \"axisCenteredZero\": false,\n            \"axisColorMode\": \"text\",\n            \"axisLabel\": \"\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            
\"drawStyle\": \"line\",\n            \"fillOpacity\": 0,\n            \"gradientMode\": \"none\",\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            },\n            \"showPoints\": \"never\",\n            \"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n              \"mode\": \"none\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n              {\n                \"color\": \"green\",\n                \"value\": null\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 80\n              }\n            ]\n          },\n          \"unit\": \"s\"\n        },\n        \"overrides\": [\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p50\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"green\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          },\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p90\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"orange\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          },\n          {\n            
\"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p99\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"red\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          }\n        ]\n      },\n      \"gridPos\": {\n        \"h\": 11,\n        \"w\": 6,\n        \"x\": 0,\n        \"y\": 73\n      },\n      \"id\": 38,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [\n            \"min\",\n            \"max\"\n          ],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": true\n        },\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"sort\": \"none\"\n        }\n      },\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.5, sum by (le) (rate(tgi_batch_forward_duration_bucket{method=\\\"decode\\\", container=\\\"$service\\\"}[10m])))\",\n          \"legendFormat\": \"p50\",\n          \"range\": true,\n          \"refId\": \"A\"\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.9, sum by (le) (rate(tgi_batch_forward_duration_bucket{method=\\\"decode\\\", container=\\\"$service\\\"}[10m])))\",\n          \"hide\": false,\n          \"legendFormat\": \"p90\",\n          \"range\": true,\n          \"refId\": \"B\"\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          
},\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.99, sum by (le) (rate(tgi_batch_forward_duration_bucket{method=\\\"decode\\\", container=\\\"$service\\\"}[10m])))\",\n          \"hide\": false,\n          \"legendFormat\": \"p99\",\n          \"range\": true,\n          \"refId\": \"C\"\n        }\n      ],\n      \"title\": \"Forward quantiles\",\n      \"type\": \"timeseries\"\n    },\n    {\n      \"cards\": {},\n      \"color\": {\n        \"cardColor\": \"#5794F2\",\n        \"colorScale\": \"linear\",\n        \"colorScheme\": \"interpolateSpectral\",\n        \"exponent\": 0.5,\n        \"min\": 0,\n        \"mode\": \"opacity\"\n      },\n      \"dataFormat\": \"tsbuckets\",\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"custom\": {\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            }\n          }\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 11,\n        \"w\": 6,\n        \"x\": 6,\n        \"y\": 73\n      },\n      \"heatmap\": {},\n      \"hideZeroBuckets\": false,\n      \"highlightCards\": true,\n      \"id\": 35,\n      \"legend\": {\n        \"show\": false\n      },\n      \"maxDataPoints\": 25,\n      \"options\": {\n        \"calculate\": false,\n        \"calculation\": {},\n        \"cellGap\": 2,\n        \"cellValues\": {},\n        \"color\": {\n          \"exponent\": 0.5,\n          \"fill\": \"#5794F2\",\n          \"min\": 0,\n          \"mode\": \"scheme\",\n          \"reverse\": false,\n          \"scale\": \"exponential\",\n          \"scheme\": \"Spectral\",\n          \"steps\": 128\n        },\n        \"exemplars\": {\n          
\"color\": \"rgba(255,0,255,0.7)\"\n        },\n        \"filterValues\": {\n          \"le\": 1e-9\n        },\n        \"legend\": {\n          \"show\": false\n        },\n        \"rowsFrame\": {\n          \"layout\": \"auto\"\n        },\n        \"showValue\": \"never\",\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"showColorScale\": false,\n          \"yHistogram\": false\n        },\n        \"yAxis\": {\n          \"axisPlacement\": \"left\",\n          \"decimals\": 1,\n          \"reverse\": false,\n          \"unit\": \"s\"\n        }\n      },\n      \"pluginVersion\": \"10.4.2\",\n      \"reverseYBuckets\": false,\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"exemplar\": true,\n          \"expr\": \"sum(increase(tgi_batch_forward_duration_bucket{method=\\\"decode\\\", container=\\\"$service\\\"}[5m])) by (le)\",\n          \"format\": \"heatmap\",\n          \"interval\": \"\",\n          \"legendFormat\": \"{{ le }}\",\n          \"range\": true,\n          \"refId\": \"A\"\n        }\n      ],\n      \"title\": \"Forward Latency\",\n      \"tooltip\": {\n        \"show\": true,\n        \"showHistogram\": false\n      },\n      \"type\": \"heatmap\",\n      \"xAxis\": {\n        \"show\": true\n      },\n      \"yAxis\": {\n        \"decimals\": 1,\n        \"format\": \"s\",\n        \"logBase\": 1,\n        \"show\": true\n      },\n      \"yBucketBound\": \"auto\"\n    },\n    {\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"palette-classic\"\n          },\n          \"custom\": {\n            \"axisBorderShow\": false,\n            \"axisCenteredZero\": 
false,\n            \"axisColorMode\": \"text\",\n            \"axisLabel\": \"\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            \"drawStyle\": \"line\",\n            \"fillOpacity\": 0,\n            \"gradientMode\": \"none\",\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            },\n            \"showPoints\": \"never\",\n            \"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n              \"mode\": \"none\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n              {\n                \"color\": \"green\",\n                \"value\": null\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 80\n              }\n            ]\n          },\n          \"unit\": \"s\"\n        },\n        \"overrides\": [\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p50\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"green\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          },\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p90\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n    
              \"fixedColor\": \"orange\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          },\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p99\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"red\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          }\n        ]\n      },\n      \"gridPos\": {\n        \"h\": 11,\n        \"w\": 6,\n        \"x\": 12,\n        \"y\": 73\n      },\n      \"id\": 34,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [\n            \"min\",\n            \"max\"\n          ],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": true\n        },\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"sort\": \"none\"\n        }\n      },\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.5, sum by (le) (rate(tgi_batch_decode_duration_bucket{method=\\\"decode\\\", container=\\\"$service\\\"}[10m])))\",\n          \"legendFormat\": \"p50\",\n          \"range\": true,\n          \"refId\": \"A\"\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.9, sum by (le) (rate(tgi_batch_decode_duration_bucket{method=\\\"decode\\\", container=\\\"$service\\\"}[10m])))\",\n          \"hide\": false,\n          \"legendFormat\": \"p90\",\n          \"range\": true,\n          \"refId\": 
\"B\"\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.99, sum by (le) (rate(tgi_batch_decode_duration_bucket{method=\\\"decode\\\", container=\\\"$service\\\"}[10m])))\",\n          \"hide\": false,\n          \"legendFormat\": \"p99\",\n          \"range\": true,\n          \"refId\": \"C\"\n        }\n      ],\n      \"title\": \"Token Decode quantiles\",\n      \"type\": \"timeseries\"\n    },\n    {\n      \"cards\": {},\n      \"color\": {\n        \"cardColor\": \"#5794F2\",\n        \"colorScale\": \"linear\",\n        \"colorScheme\": \"interpolateSpectral\",\n        \"exponent\": 0.5,\n        \"min\": 0,\n        \"mode\": \"opacity\"\n      },\n      \"dataFormat\": \"tsbuckets\",\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"custom\": {\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            }\n          }\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 11,\n        \"w\": 6,\n        \"x\": 18,\n        \"y\": 73\n      },\n      \"heatmap\": {},\n      \"hideZeroBuckets\": false,\n      \"highlightCards\": true,\n      \"id\": 40,\n      \"legend\": {\n        \"show\": false\n      },\n      \"maxDataPoints\": 25,\n      \"options\": {\n        \"calculate\": false,\n        \"calculation\": {},\n        \"cellGap\": 2,\n        \"cellValues\": {},\n        \"color\": {\n          \"exponent\": 0.5,\n          \"fill\": \"#5794F2\",\n          \"min\": 0,\n          \"mode\": \"scheme\",\n    
      \"reverse\": false,\n          \"scale\": \"exponential\",\n          \"scheme\": \"Spectral\",\n          \"steps\": 128\n        },\n        \"exemplars\": {\n          \"color\": \"rgba(255,0,255,0.7)\"\n        },\n        \"filterValues\": {\n          \"le\": 1e-9\n        },\n        \"legend\": {\n          \"show\": false\n        },\n        \"rowsFrame\": {\n          \"layout\": \"auto\"\n        },\n        \"showValue\": \"never\",\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"showColorScale\": false,\n          \"yHistogram\": false\n        },\n        \"yAxis\": {\n          \"axisPlacement\": \"left\",\n          \"decimals\": 1,\n          \"reverse\": false,\n          \"unit\": \"s\"\n        }\n      },\n      \"pluginVersion\": \"10.4.2\",\n      \"reverseYBuckets\": false,\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"exemplar\": true,\n          \"expr\": \"sum(increase(tgi_batch_decode_duration_bucket{method=\\\"decode\\\", container=\\\"$service\\\"}[5m])) by (le)\",\n          \"format\": \"heatmap\",\n          \"interval\": \"\",\n          \"legendFormat\": \"{{ le }}\",\n          \"range\": true,\n          \"refId\": \"A\"\n        }\n      ],\n      \"title\": \"Token Decode Latency\",\n      \"tooltip\": {\n        \"show\": true,\n        \"showHistogram\": false\n      },\n      \"type\": \"heatmap\",\n      \"xAxis\": {\n        \"show\": true\n      },\n      \"yAxis\": {\n        \"decimals\": 1,\n        \"format\": \"s\",\n        \"logBase\": 1,\n        \"show\": true\n      },\n      \"yBucketBound\": \"auto\"\n    },\n    {\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n    
      \"color\": {\n            \"mode\": \"palette-classic\"\n          },\n          \"custom\": {\n            \"axisBorderShow\": false,\n            \"axisCenteredZero\": false,\n            \"axisColorMode\": \"text\",\n            \"axisLabel\": \"\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            \"drawStyle\": \"line\",\n            \"fillOpacity\": 0,\n            \"gradientMode\": \"none\",\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            },\n            \"showPoints\": \"never\",\n            \"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n              \"mode\": \"none\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n              {\n                \"color\": \"green\",\n                \"value\": null\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 80\n              }\n            ]\n          },\n          \"unit\": \"s\"\n        },\n        \"overrides\": [\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p50\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"green\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          },\n          {\n            \"matcher\": {\n              \"id\": 
\"byName\",\n              \"options\": \"p90\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"orange\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          },\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p99\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"red\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          }\n        ]\n      },\n      \"gridPos\": {\n        \"h\": 11,\n        \"w\": 6,\n        \"x\": 0,\n        \"y\": 84\n      },\n      \"id\": 42,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [\n            \"min\",\n            \"max\"\n          ],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": true\n        },\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"sort\": \"none\"\n        }\n      },\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.5, sum by (le) (rate(tgi_batch_filter_duration_bucket{method=\\\"decode\\\", container=\\\"$service\\\"}[10m])))\",\n          \"legendFormat\": \"p50\",\n          \"range\": true,\n          \"refId\": \"A\"\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.9, sum by (le) 
(rate(tgi_batch_filter_duration_bucket{method=\\\"decode\\\", container=\\\"$service\\\"}[10m])))\",\n          \"hide\": false,\n          \"legendFormat\": \"p90\",\n          \"range\": true,\n          \"refId\": \"B\"\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.99, sum by (le) (rate(tgi_batch_filter_duration_bucket{method=\\\"decode\\\", container=\\\"$service\\\"}[10m])))\",\n          \"hide\": false,\n          \"legendFormat\": \"p99\",\n          \"range\": true,\n          \"refId\": \"C\"\n        }\n      ],\n      \"title\": \"Filter Batch quantiles\",\n      \"type\": \"timeseries\"\n    },\n    {\n      \"cards\": {},\n      \"color\": {\n        \"cardColor\": \"#5794F2\",\n        \"colorScale\": \"linear\",\n        \"colorScheme\": \"interpolateSpectral\",\n        \"exponent\": 0.5,\n        \"min\": 0,\n        \"mode\": \"opacity\"\n      },\n      \"dataFormat\": \"tsbuckets\",\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"custom\": {\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            }\n          }\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 11,\n        \"w\": 6,\n        \"x\": 6,\n        \"y\": 84\n      },\n      \"heatmap\": {},\n      \"hideZeroBuckets\": false,\n      \"highlightCards\": true,\n      \"id\": 39,\n      \"legend\": {\n        \"show\": false\n      },\n      \"maxDataPoints\": 25,\n      \"options\": {\n        \"calculate\": false,\n        
\"calculation\": {},\n        \"cellGap\": 2,\n        \"cellValues\": {},\n        \"color\": {\n          \"exponent\": 0.5,\n          \"fill\": \"#5794F2\",\n          \"min\": 0,\n          \"mode\": \"scheme\",\n          \"reverse\": false,\n          \"scale\": \"exponential\",\n          \"scheme\": \"Spectral\",\n          \"steps\": 128\n        },\n        \"exemplars\": {\n          \"color\": \"rgba(255,0,255,0.7)\"\n        },\n        \"filterValues\": {\n          \"le\": 1e-9\n        },\n        \"legend\": {\n          \"show\": false\n        },\n        \"rowsFrame\": {\n          \"layout\": \"auto\"\n        },\n        \"showValue\": \"never\",\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"showColorScale\": false,\n          \"yHistogram\": false\n        },\n        \"yAxis\": {\n          \"axisPlacement\": \"left\",\n          \"decimals\": 1,\n          \"reverse\": false,\n          \"unit\": \"s\"\n        }\n      },\n      \"pluginVersion\": \"10.4.2\",\n      \"reverseYBuckets\": false,\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"exemplar\": true,\n          \"expr\": \"sum(increase(tgi_batch_filter_duration_bucket{method=\\\"decode\\\", container=\\\"$service\\\"}[5m])) by (le)\",\n          \"format\": \"heatmap\",\n          \"interval\": \"\",\n          \"legendFormat\": \"{{ le }}\",\n          \"range\": true,\n          \"refId\": \"A\"\n        }\n      ],\n      \"title\": \"Filter Batch Latency\",\n      \"tooltip\": {\n        \"show\": true,\n        \"showHistogram\": false\n      },\n      \"type\": \"heatmap\",\n      \"xAxis\": {\n        \"show\": true\n      },\n      \"yAxis\": {\n        \"decimals\": 1,\n        \"format\": \"s\",\n        \"logBase\": 1,\n        \"show\": true\n      },\n      
\"yBucketBound\": \"auto\"\n    },\n    {\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"color\": {\n            \"mode\": \"palette-classic\"\n          },\n          \"custom\": {\n            \"axisBorderShow\": false,\n            \"axisCenteredZero\": false,\n            \"axisColorMode\": \"text\",\n            \"axisLabel\": \"\",\n            \"axisPlacement\": \"auto\",\n            \"barAlignment\": 0,\n            \"drawStyle\": \"line\",\n            \"fillOpacity\": 0,\n            \"gradientMode\": \"none\",\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"insertNulls\": false,\n            \"lineInterpolation\": \"linear\",\n            \"lineWidth\": 1,\n            \"pointSize\": 5,\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            },\n            \"showPoints\": \"never\",\n            \"spanNulls\": false,\n            \"stacking\": {\n              \"group\": \"A\",\n              \"mode\": \"none\"\n            },\n            \"thresholdsStyle\": {\n              \"mode\": \"off\"\n            }\n          },\n          \"mappings\": [],\n          \"thresholds\": {\n            \"mode\": \"absolute\",\n            \"steps\": [\n              {\n                \"color\": \"green\",\n                \"value\": null\n              },\n              {\n                \"color\": \"red\",\n                \"value\": 80\n              }\n            ]\n          },\n          \"unit\": \"s\"\n        },\n        \"overrides\": [\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p50\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                
\"value\": {\n                  \"fixedColor\": \"green\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          },\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p90\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"orange\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          },\n          {\n            \"matcher\": {\n              \"id\": \"byName\",\n              \"options\": \"p99\"\n            },\n            \"properties\": [\n              {\n                \"id\": \"color\",\n                \"value\": {\n                  \"fixedColor\": \"red\",\n                  \"mode\": \"fixed\"\n                }\n              }\n            ]\n          }\n        ]\n      },\n      \"gridPos\": {\n        \"h\": 11,\n        \"w\": 6,\n        \"x\": 12,\n        \"y\": 84\n      },\n      \"id\": 36,\n      \"options\": {\n        \"legend\": {\n          \"calcs\": [\n            \"min\",\n            \"max\"\n          ],\n          \"displayMode\": \"list\",\n          \"placement\": \"bottom\",\n          \"showLegend\": true\n        },\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"sort\": \"none\"\n        }\n      },\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.5, sum by (le) (rate(tgi_batch_concat_duration_bucket{method=\\\"decode\\\", container=\\\"$service\\\"}[10m])))\",\n          \"legendFormat\": \"p50\",\n          \"range\": true,\n          \"refId\": \"A\"\n        },\n        {\n          \"datasource\": {\n            \"type\": 
\"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.9, sum by (le) (rate(tgi_batch_concat_duration_bucket{method=\\\"decode\\\", container=\\\"$service\\\"}[10m])))\",\n          \"hide\": false,\n          \"legendFormat\": \"p90\",\n          \"range\": true,\n          \"refId\": \"B\"\n        },\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"expr\": \"histogram_quantile(0.99, sum by (le) (rate(tgi_batch_concat_duration_bucket{method=\\\"decode\\\", container=\\\"$service\\\"}[10m])))\",\n          \"hide\": false,\n          \"legendFormat\": \"p99\",\n          \"range\": true,\n          \"refId\": \"C\"\n        }\n      ],\n      \"title\": \"Batch Concat quantiles\",\n      \"type\": \"timeseries\"\n    },\n    {\n      \"cards\": {},\n      \"color\": {\n        \"cardColor\": \"#5794F2\",\n        \"colorScale\": \"linear\",\n        \"colorScheme\": \"interpolateSpectral\",\n        \"exponent\": 0.5,\n        \"min\": 0,\n        \"mode\": \"opacity\"\n      },\n      \"dataFormat\": \"tsbuckets\",\n      \"datasource\": {\n        \"type\": \"prometheus\",\n        \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n      },\n      \"fieldConfig\": {\n        \"defaults\": {\n          \"custom\": {\n            \"hideFrom\": {\n              \"legend\": false,\n              \"tooltip\": false,\n              \"viz\": false\n            },\n            \"scaleDistribution\": {\n              \"type\": \"linear\"\n            }\n          }\n        },\n        \"overrides\": []\n      },\n      \"gridPos\": {\n        \"h\": 11,\n        \"w\": 6,\n        \"x\": 18,\n        \"y\": 84\n      },\n      \"heatmap\": {},\n      \"hideZeroBuckets\": false,\n      
\"highlightCards\": true,\n      \"id\": 41,\n      \"legend\": {\n        \"show\": false\n      },\n      \"maxDataPoints\": 25,\n      \"options\": {\n        \"calculate\": false,\n        \"calculation\": {},\n        \"cellGap\": 2,\n        \"cellValues\": {},\n        \"color\": {\n          \"exponent\": 0.5,\n          \"fill\": \"#5794F2\",\n          \"min\": 0,\n          \"mode\": \"scheme\",\n          \"reverse\": false,\n          \"scale\": \"exponential\",\n          \"scheme\": \"Spectral\",\n          \"steps\": 128\n        },\n        \"exemplars\": {\n          \"color\": \"rgba(255,0,255,0.7)\"\n        },\n        \"filterValues\": {\n          \"le\": 1e-9\n        },\n        \"legend\": {\n          \"show\": false\n        },\n        \"rowsFrame\": {\n          \"layout\": \"auto\"\n        },\n        \"showValue\": \"never\",\n        \"tooltip\": {\n          \"mode\": \"single\",\n          \"showColorScale\": false,\n          \"yHistogram\": false\n        },\n        \"yAxis\": {\n          \"axisPlacement\": \"left\",\n          \"decimals\": 1,\n          \"reverse\": false,\n          \"unit\": \"s\"\n        }\n      },\n      \"pluginVersion\": \"10.4.2\",\n      \"reverseYBuckets\": false,\n      \"targets\": [\n        {\n          \"datasource\": {\n            \"type\": \"prometheus\",\n            \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n          },\n          \"editorMode\": \"code\",\n          \"exemplar\": true,\n          \"expr\": \"sum(increase(tgi_batch_concat_duration_bucket{method=\\\"decode\\\", container=\\\"$service\\\"}[5m])) by (le)\",\n          \"format\": \"heatmap\",\n          \"interval\": \"\",\n          \"legendFormat\": \"{{ le }}\",\n          \"range\": true,\n          \"refId\": \"A\"\n        }\n      ],\n      \"title\": \"Batch Concat latency\",\n      \"tooltip\": {\n        \"show\": true,\n        \"showHistogram\": false\n      },\n      \"type\": \"heatmap\",\n      
\"xAxis\": {\n        \"show\": true\n      },\n      \"yAxis\": {\n        \"decimals\": 1,\n        \"format\": \"s\",\n        \"logBase\": 1,\n        \"show\": true\n      },\n      \"yBucketBound\": \"auto\"\n    }\n  ],\n  \"refresh\": \"\",\n  \"schemaVersion\": 39,\n  \"tags\": [],\n  \"templating\": {\n    \"list\": [\n      {\n        \"current\": {\n          \"selected\": false,\n          \"text\": \"gpu-txt-gen-cohereforai-c4ai-command-r-plu-ba7f1\",\n          \"value\": \"gpu-txt-gen-cohereforai-c4ai-command-r-plu-ba7f1\"\n        },\n        \"datasource\": {\n          \"type\": \"prometheus\",\n          \"uid\": \"${DS_PROMETHEUS_EKS API INFERENCE PROD}\"\n        },\n        \"definition\": \"label_values(tgi_request_count, container)\",\n        \"hide\": 0,\n        \"includeAll\": false,\n        \"multi\": false,\n        \"name\": \"service\",\n        \"options\": [],\n        \"query\": {\n          \"query\": \"label_values(tgi_request_count, container)\",\n          \"refId\": \"StandardVariableQuery\"\n        },\n        \"refresh\": 1,\n        \"regex\": \"\",\n        \"skipUrlSync\": false,\n        \"sort\": 1,\n        \"type\": \"query\"\n      }\n    ]\n  },\n  \"time\": {\n    \"from\": \"now-30m\",\n    \"to\": \"now-30s\"\n  },\n  \"timepicker\": {\n    \"nowDelay\": \"30s\"\n  },\n  \"timezone\": \"\",\n  \"title\": \"Text Generation Inference\",\n  \"uid\": \"RHSk7EL4kdqsd\",\n  \"version\": 12,\n  \"weekStart\": \"\"\n}\n"
  },
  {
    "path": "backends/client/Cargo.toml",
    "content": "[package]\nname = \"text-generation-client\"\nversion.workspace = true\nedition.workspace = true\nauthors.workspace = true\nhomepage.workspace = true\n\n[dependencies]\nasync-trait = \"^0.1\"\nbase64 = { workspace = true }\nfutures = \"^0.3\"\ngrpc-metadata = { path = \"../grpc-metadata\" }\nprost = \"^0.12\"\nthiserror = \"^1.0\"\ntokio = { version = \"^1.32\", features = [\"sync\"] }\ntonic = \"^0.10\"\ntower = \"^0.4\"\ntracing = \"^0.1\"\n\n[build-dependencies]\ntonic-build = \"0.10.1\"\nprost-build = \"0.12.1\"\n"
  },
  {
    "path": "backends/client/build.rs",
    "content": "use std::fs;\n\nfn main() -> Result<(), Box<dyn std::error::Error>> {\n    println!(\"cargo:rerun-if-changed=../../proto/\");\n\n    fs::create_dir_all(\"src/v2/pb\").unwrap_or(());\n    let mut config = prost_build::Config::new();\n    config.protoc_arg(\"--experimental_allow_proto3_optional\");\n\n    tonic_build::configure()\n        .build_client(true)\n        .build_server(false)\n        .out_dir(\"src/v2/pb\")\n        .include_file(\"mod.rs\")\n        .compile_with_config(config, &[\"../../proto/generate.proto\"], &[\"../../proto\"])\n        .map_err(|e| match e.kind(){\n            std::io::ErrorKind::NotFound => {panic!(\"`protoc` not found, install libprotoc\")},\n            std::io::ErrorKind::Other => {panic!(\"`protoc` version unsupported, upgrade protoc: https://github.com/protocolbuffers/protobuf/releases\")},\n            e => {e}\n        }).unwrap_or_else(|e| panic!(\"protobuf compilation failed: {e}\"));\n\n    fs::create_dir_all(\"src/v3/pb\").unwrap_or(());\n    let mut config = prost_build::Config::new();\n    config.protoc_arg(\"--experimental_allow_proto3_optional\");\n\n    tonic_build::configure()\n        .build_client(true)\n        .build_server(false)\n        .out_dir(\"src/v3/pb\")\n        .include_file(\"mod.rs\")\n        .compile_with_config(config, &[\"../../proto/v3/generate.proto\"], &[\"../../proto\"])\n        .unwrap_or_else(|e| panic!(\"protobuf compilation failed: {e}\"));\n\n    Ok(())\n}\n"
  },
  {
    "path": "backends/client/src/lib.rs",
    "content": "//! Text Generation gRPC client library\n\npub mod v2;\npub mod v3;\n\nuse async_trait::async_trait;\nuse base64::{engine::general_purpose::STANDARD, Engine};\nuse thiserror::Error;\nuse tonic::transport;\nuse tonic::Status;\n\npub use v3::{Chunk, Image, Input, InputChunk};\n\n#[async_trait]\npub trait Health {\n    /// Check if a generate server is healthy by asking it to allocate a tensor on device\n    async fn device_health(&self) -> Result<()>;\n\n    /// Check if a generate server is healthy by doing a forward pass.\n    /// EXPENSIVE\n    async fn model_health(&self) -> Result<()>;\n}\n\n#[derive(Debug)]\npub struct ShardInfo {\n    pub requires_padding: bool,\n    pub dtype: String,\n    pub device_type: String,\n    pub window_size: Option<u32>,\n    pub speculate: u32,\n}\n\n#[derive(Error, Debug, Clone)]\npub enum ClientError {\n    #[error(\"Could not connect to Text Generation server: {0}\")]\n    Connection(String),\n    #[error(\"Server error: {0}\")]\n    Generation(String),\n    #[error(\"Sharded results are empty\")]\n    EmptyResults,\n}\n\nimpl From<Status> for ClientError {\n    fn from(err: Status) -> Self {\n        let err = Self::Generation(err.message().to_string());\n        tracing::error!(\"{err}\");\n        err\n    }\n}\n\nimpl From<transport::Error> for ClientError {\n    fn from(err: transport::Error) -> Self {\n        let err = Self::Connection(err.to_string());\n        tracing::error!(\"{err}\");\n        err\n    }\n}\n\n// Small convenience re-wrapping of `Chunk`.\nimpl From<Chunk> for InputChunk {\n    fn from(chunk: Chunk) -> Self {\n        InputChunk { chunk: Some(chunk) }\n    }\n}\n\n/// Convert input chunks to a stringly-typed input for backwards\n/// compat for backends that haven't implemented chunked inputs.\npub trait ChunksToString {\n    /// Convert chunks to string.\n    fn chunks_to_string(&self) -> String;\n}\n\nimpl ChunksToString for Vec<InputChunk> {\n    fn chunks_to_string(&self) -> String 
{\n        let mut output = String::new();\n        self.iter().for_each(|c| match &c.chunk {\n            Some(Chunk::Text(text)) => output.push_str(text),\n            Some(Chunk::Image(Image { data, mimetype })) => {\n                let encoded = STANDARD.encode(data);\n                output.push_str(&format!(\"![](data:{};base64,{})\", mimetype, encoded))\n            }\n            // We don't create empty chunks, so this should be unreachable.\n            None => unreachable!(\"Chunks should never be empty\"),\n        });\n        output\n    }\n}\n\nstatic WARMUP_IMAGE_BASE64 :&str = \"iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=\";\n\npub type Result<T> = std::result::Result<T, ClientError>;\n"
  },
  {
    "path": "backends/client/src/v2/client.rs",
    "content": "/// Single shard Client\nuse crate::v2::pb;\nuse crate::{ClientError, Result};\n\nuse crate::WARMUP_IMAGE_BASE64;\nuse grpc_metadata::InjectTelemetryContext;\nuse pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;\nuse pb::generate::v2::*;\nuse std::cmp::min;\nuse std::time::Duration;\nuse tonic::transport::{Channel, Uri};\nuse tracing::instrument;\n\n/// Text Generation Inference gRPC client\n#[derive(Debug, Clone)]\npub struct Client {\n    stub: TextGenerationServiceClient<Channel>,\n}\n\nimpl Client {\n    /// Returns a client connected to the given url\n    pub async fn connect(uri: Uri) -> Result<Self> {\n        let channel = Channel::builder(uri).connect().await?;\n\n        Ok(Self {\n            stub: TextGenerationServiceClient::new(channel),\n        })\n    }\n\n    /// Returns a client connected to the given unix socket\n    pub async fn connect_uds(path: String) -> Result<Self> {\n        let channel = Channel::from_shared(\"http://[::]:50051\".to_string())\n            .unwrap()\n            .connect_with_connector(tower::service_fn(move |_: Uri| {\n                tokio::net::UnixStream::connect(path.clone())\n            }))\n            .await?;\n\n        Ok(Self {\n            stub: TextGenerationServiceClient::new(channel),\n        })\n    }\n\n    /// Returns a list of uris or unix sockets of all shards\n    #[instrument(skip(self))]\n    pub async fn service_discovery(&mut self) -> Result<Vec<String>> {\n        let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();\n        let response = self.stub.service_discovery(request).await.map_err(|_| {\n            ClientError::Connection(\"Server does not support v2 interface\".to_string())\n        })?;\n        let urls = response\n            .into_inner()\n            .urls\n            .into_iter()\n            // Remove unix socket prefix\n            .map(|url| match url.strip_prefix(\"unix://\") {\n                None 
=> url,\n                Some(stripped_url) => stripped_url.to_string(),\n            })\n            .collect();\n        Ok(urls)\n    }\n\n    /// Get model info\n    #[instrument(skip(self))]\n    pub async fn info(&mut self) -> Result<InfoResponse> {\n        let request = tonic::Request::new(InfoRequest {}).inject_context();\n        let response = self.stub.info(request).await?.into_inner();\n        Ok(response)\n    }\n\n    /// Get model health\n    #[instrument(skip(self))]\n    pub async fn health(&mut self) -> Result<HealthResponse> {\n        let request = tonic::Request::new(HealthRequest {}).inject_context();\n        let response = self.stub.health(request).await?.into_inner();\n        Ok(response)\n    }\n\n    /// Clear the past generations cache\n    #[instrument(skip(self))]\n    pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {\n        let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();\n        self.stub.clear_cache(request).await?;\n        Ok(())\n    }\n\n    /// Filter a cached batch\n    #[instrument(skip(self))]\n    pub async fn filter_batch(\n        &mut self,\n        batch_id: u64,\n        request_ids: Vec<u64>,\n    ) -> Result<Option<CachedBatch>> {\n        let request = tonic::Request::new(FilterBatchRequest {\n            batch_id,\n            request_ids,\n        })\n        .inject_context();\n        let filtered_batch = self.stub.filter_batch(request).await?.into_inner();\n        Ok(filtered_batch.batch)\n    }\n\n    /// Warmup on a max size batch\n    ///\n    /// Returns the maximum amount of tokens supported by the hardware\n    #[instrument(skip_all)]\n    pub async fn warmup(\n        &mut self,\n        max_input_length: u32,\n        max_prefill_tokens: u32,\n        max_total_tokens: u32,\n        max_batch_size: Option<usize>,\n    ) -> Result<Option<u32>> {\n        let mut n_tokens = 0;\n        let mut requests = Vec::new();\n        // Create 
requests\n        while n_tokens < max_prefill_tokens {\n            let truncate = min(max_input_length, max_prefill_tokens - n_tokens);\n\n            let mut inputs = String::new();\n            inputs.push_str(&\"_test \".to_string().repeat(max_input_length as usize));\n            if n_tokens == 0 {\n                // 1 request is enough to test vision heads.\n                // Sending images on other queries messes up easily with truncation.\n                inputs.push_str(&format!(\n                    \"![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})\",\n                ));\n            }\n\n            requests.push(Request {\n                id: 0,\n                inputs,\n                // We truncate the input on the server side to be sure that it has the correct size\n                truncate,\n                // Set sampling parameters to also take these ops into account in the max memory\n                parameters: Some(NextTokenChooserParameters {\n                    temperature: 0.9,\n                    top_k: 10,\n                    top_p: 0.9,\n                    typical_p: 0.9,\n                    do_sample: false,\n                    seed: 0,\n                    repetition_penalty: 1.2,\n                    frequency_penalty: 0.1,\n                    watermark: true,\n                    grammar: String::new(),\n                    grammar_type: GrammarType::None as i32,\n                }),\n                stopping_parameters: Some(StoppingCriteriaParameters {\n                    max_new_tokens: max_total_tokens - truncate,\n                    stop_sequences: vec![],\n                    ignore_eos_token: true,\n                }),\n                prefill_logprobs: true,\n                top_n_tokens: 20,\n            });\n            n_tokens += max_input_length;\n\n            // Check max_batch_size\n            if Some(requests.len()) == max_batch_size {\n                break;\n            }\n        }\n\n        let 
batch = Batch {\n            id: 0,\n            size: requests.len() as u32,\n            requests,\n            max_tokens: 0,\n        };\n\n        let request = tonic::Request::new(WarmupRequest {\n            batch: Some(batch),\n            max_input_length,\n            max_prefill_tokens,\n            max_total_tokens,\n        })\n        .inject_context();\n        let response = self.stub.warmup(request).await?.into_inner();\n        Ok(response.max_supported_total_tokens)\n    }\n\n    /// Generate one token for each request in the given batch\n    ///\n    /// Returns Generation for each request in batch\n    /// and the next cached batch\n    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]\n    pub async fn prefill(\n        &mut self,\n        batch: Batch,\n    ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {\n        let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();\n        let response = self.stub.prefill(request).await?.into_inner();\n        Ok((\n            response.generations,\n            response.batch,\n            PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),\n        ))\n    }\n\n    /// Generate one token for each request in the given cached batches\n    ///\n    /// Returns Generation for each request in batches\n    /// and the next cached batch\n    #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]\n    pub async fn decode(\n        &mut self,\n        batches: Vec<CachedBatch>,\n    ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {\n        let request = tonic::Request::new(DecodeRequest { batches }).inject_context();\n        let response = self.stub.decode(request).await?.into_inner();\n        Ok((\n            response.generations,\n            response.batch,\n            DecodeTimings::new(\n                response.concat_ns,\n                
response.forward_ns,\n                response.decode_ns,\n                response.total_ns,\n            ),\n        ))\n    }\n}\n\npub struct PrefillTimings {\n    pub forward: Duration,\n    pub decode: Duration,\n    pub total: Duration,\n}\n\nimpl PrefillTimings {\n    fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {\n        Self {\n            forward: Duration::from_nanos(forward_ns),\n            decode: Duration::from_nanos(decode_ns),\n            total: Duration::from_nanos(total_ns),\n        }\n    }\n}\n\npub struct DecodeTimings {\n    pub concat: Option<Duration>,\n    pub forward: Duration,\n    pub decode: Duration,\n    pub total: Duration,\n}\n\nimpl DecodeTimings {\n    fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {\n        Self {\n            concat: concat_ns.map(Duration::from_nanos),\n            forward: Duration::from_nanos(forward_ns),\n            decode: Duration::from_nanos(decode_ns),\n            total: Duration::from_nanos(total_ns),\n        }\n    }\n}\n"
  },
  {
    "path": "backends/client/src/v2/mod.rs",
    "content": "#[allow(clippy::derive_partial_eq_without_eq)]\nmod pb;\n\nmod client;\nmod sharded_client;\n\npub use client::Client;\npub use pb::generate::v2::HealthResponse;\npub use pb::generate::v2::{\n    Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, InfoResponse,\n    NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,\n};\npub use sharded_client::ShardedClient;\n"
  },
  {
    "path": "backends/client/src/v2/sharded_client.rs",
    "content": "/// Multi shard Client\nuse crate::{v2, Health, ShardInfo};\nuse crate::{ClientError, Result};\n\nuse crate::v2::InfoResponse;\nuse async_trait::async_trait;\nuse futures::future::join_all;\nuse tonic::transport::Uri;\nuse tracing::instrument;\nuse v2::client::{DecodeTimings, PrefillTimings};\nuse v2::{\n    Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,\n    NextTokenChooserParameters, Request, StoppingCriteriaParameters,\n};\n\n#[derive(Debug, Clone)]\n/// Text Generation Inference gRPC multi client\npub struct ShardedClient {\n    clients: Vec<Client>,\n}\n\nimpl ShardedClient {\n    fn new(clients: Vec<Client>) -> Self {\n        Self { clients }\n    }\n\n    /// Create a new ShardedClient from a master client. The master client will communicate with\n    /// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.\n    async fn from_master_client(mut master_client: Client) -> Result<Self> {\n        // Get all uris/unix sockets from the master client\n        let uris = master_client.service_discovery().await?;\n        let futures = uris.into_iter().map(Client::connect_uds);\n        let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();\n        Ok(Self::new(clients?))\n    }\n\n    /// Returns a client connected to the given uri\n    pub async fn connect(uri: Uri) -> Result<Self> {\n        let master_client = Client::connect(uri).await?;\n        Self::from_master_client(master_client).await\n    }\n\n    /// Returns a client connected to the given unix socket\n    pub async fn connect_uds(path: String) -> Result<Self> {\n        let master_client = Client::connect_uds(path).await?;\n        Self::from_master_client(master_client).await\n    }\n\n    /// Get the model info\n    #[instrument(skip(self))]\n    pub async fn info(&mut self) -> Result<ShardInfo> {\n        let futures: Vec<_> = self\n            .clients\n            .iter_mut()\n            
.map(|client| client.info())\n            .collect();\n        join_all(futures).await.pop().unwrap().map(ShardInfo::from)\n    }\n\n    /// GRPC health check\n    #[instrument(skip(self))]\n    pub async fn health(&mut self) -> Result<HealthResponse> {\n        let futures: Vec<_> = self\n            .clients\n            .iter_mut()\n            .map(|client| client.health())\n            .collect();\n        join_all(futures).await.pop().unwrap()\n    }\n\n    /// Clear the past generations cache\n    #[instrument(skip(self))]\n    pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {\n        let futures: Vec<_> = self\n            .clients\n            .iter_mut()\n            .map(|client| client.clear_cache(batch_id))\n            .collect();\n        join_all(futures).await.into_iter().collect()\n    }\n\n    /// Filter a cached batch\n    #[instrument(skip(self))]\n    pub async fn filter_batch(\n        &mut self,\n        batch_id: u64,\n        request_ids: Vec<u64>,\n    ) -> Result<Option<CachedBatch>> {\n        let futures: Vec<_> = self\n            .clients\n            .iter_mut()\n            .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))\n            .collect();\n        // all shards return the same message\n        join_all(futures).await.pop().unwrap()\n    }\n\n    /// Warmup on a max size batch\n    ///\n    /// Returns the maximum amount of tokens supported by the hardware\n    #[instrument(skip(self))]\n    pub async fn warmup(\n        &mut self,\n        max_input_length: u32,\n        max_prefill_tokens: u32,\n        max_total_tokens: u32,\n        max_batch_size: Option<usize>,\n    ) -> Result<Option<u32>> {\n        let futures: Vec<_> = self\n            .clients\n            .iter_mut()\n            .map(|client| {\n                Box::pin(client.warmup(\n                    max_input_length,\n                    max_prefill_tokens,\n                    max_total_tokens,\n       
             max_batch_size,\n                ))\n            })\n            .collect();\n        // Take the minimum value\n        let results = join_all(futures)\n            .await\n            .into_iter()\n            .collect::<Result<Vec<Option<u32>>>>()?;\n        Ok(results.into_iter().flatten().min())\n    }\n\n    /// Generate one token for each request in the given batch\n    ///\n    /// Returns Generation for each request in batch\n    /// and the next cached batch\n    #[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]\n    pub async fn prefill(\n        &mut self,\n        batch: Batch,\n    ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {\n        let futures: Vec<_> = self\n            .clients\n            .iter_mut()\n            .map(|client| Box::pin(client.prefill(batch.clone())))\n            .collect();\n        #[allow(clippy::type_complexity)]\n        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =\n            join_all(futures).await.into_iter().collect();\n        let mut results = results?;\n\n        let (mut generations, next_batch, mut timings) =\n            results.pop().ok_or(ClientError::EmptyResults)?;\n\n        // Merge generations from different model shards\n        for (mut shard_generations, _, shard_timings) in results.into_iter() {\n            generations.append(&mut shard_generations);\n            // Return the timings of the slowest shard\n            if shard_timings.total > timings.total {\n                timings = shard_timings;\n            }\n        }\n        Ok((generations, next_batch, timings))\n    }\n\n    /// Generate one token for each request in the given cached batches\n    ///\n    /// Returns Generation for each request in batches\n    /// and the next cached batch\n    #[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]\n    pub async fn decode(\n        &mut self,\n        
batches: Vec<CachedBatch>,\n    ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {\n        let futures: Vec<_> = self\n            .clients\n            .iter_mut()\n            .map(|client| Box::pin(client.decode(batches.clone())))\n            .collect();\n        #[allow(clippy::type_complexity)]\n        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =\n            join_all(futures).await.into_iter().collect();\n        let mut results = results?;\n\n        let (mut generations, next_batch, mut timings) =\n            results.pop().ok_or(ClientError::EmptyResults)?;\n\n        // Merge generations from different model shards\n        for (mut shard_generations, _, shard_timings) in results.into_iter() {\n            generations.append(&mut shard_generations);\n            // Return the timings of the slowest shard\n            if shard_timings.total > timings.total {\n                timings = shard_timings;\n            }\n        }\n        Ok((generations, next_batch, timings))\n    }\n}\n\nimpl From<InfoResponse> for ShardInfo {\n    fn from(value: InfoResponse) -> Self {\n        Self {\n            requires_padding: value.requires_padding,\n            dtype: value.dtype,\n            device_type: value.device_type,\n            window_size: value.window_size,\n            speculate: value.speculate,\n        }\n    }\n}\n\n#[async_trait]\nimpl Health for ShardedClient {\n    async fn device_health(&self) -> Result<()> {\n        self.clone().health().await?;\n        Ok(())\n    }\n\n    async fn model_health(&self) -> Result<()> {\n        // Dummy batch of 1 token and 1 generated token\n        let liveness_request = Request {\n            id: u64::MAX,\n            inputs: \"liveness\".to_string(),\n            truncate: 10,\n            prefill_logprobs: false,\n            parameters: Some(NextTokenChooserParameters {\n                temperature: 1.0,\n                top_k: 0,\n                
top_p: 1.0,\n                typical_p: 1.0,\n                do_sample: false,\n                seed: 0,\n                repetition_penalty: 1.0,\n                frequency_penalty: 0.0,\n                watermark: false,\n                grammar: String::new(),\n                grammar_type: GrammarType::None as i32,\n            }),\n            stopping_parameters: Some(StoppingCriteriaParameters {\n                max_new_tokens: 1,\n                stop_sequences: vec![],\n                ignore_eos_token: false,\n            }),\n            top_n_tokens: 0,\n        };\n        let batch = Batch {\n            id: u64::MAX,\n            requests: vec![liveness_request],\n            size: 1,\n            max_tokens: 2,\n        };\n        self.clone().prefill(batch).await?;\n        Ok(())\n    }\n}\n"
  },
  {
    "path": "backends/client/src/v3/client.rs",
    "content": "use crate::v3::{pb, Chunk};\nuse crate::{ClientError, Result, WARMUP_IMAGE_BASE64};\n/// Single shard Client\nuse base64::engine::general_purpose::STANDARD;\nuse base64::Engine;\nuse grpc_metadata::InjectTelemetryContext;\nuse pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;\nuse pb::generate::v3::*;\nuse std::cmp::min;\nuse std::time::Duration;\nuse tonic::transport::{Channel, Uri};\nuse tracing::instrument;\n\n/// Text Generation Inference gRPC client\n#[derive(Debug, Clone)]\npub struct Client {\n    stub: TextGenerationServiceClient<Channel>,\n}\n\nimpl Client {\n    /// Returns a client connected to the given url\n    pub async fn connect(uri: Uri) -> Result<Self> {\n        let channel = Channel::builder(uri).connect().await?;\n\n        Ok(Self {\n            stub: TextGenerationServiceClient::new(channel),\n        })\n    }\n\n    /// Returns a client connected to the given unix socket\n    pub async fn connect_uds(path: String) -> Result<Self> {\n        let channel = Channel::from_shared(\"http://[::]:50051\".to_string())\n            .unwrap()\n            .connect_with_connector(tower::service_fn(move |_: Uri| {\n                tokio::net::UnixStream::connect(path.clone())\n            }))\n            .await?;\n\n        Ok(Self {\n            stub: TextGenerationServiceClient::new(channel),\n        })\n    }\n\n    /// Returns a list of uris or unix sockets of all shards\n    #[instrument(skip(self))]\n    pub async fn service_discovery(&mut self) -> Result<Vec<String>> {\n        let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();\n        let response = self.stub.service_discovery(request).await.map_err(|_| {\n            ClientError::Connection(\"Server does not support v3 interface\".to_string())\n        })?;\n        let urls = response\n            .into_inner()\n            .urls\n            .into_iter()\n            // Remove unix socket prefix\n            
.map(|url| match url.strip_prefix(\"unix://\") {\n                None => url,\n                Some(stripped_url) => stripped_url.to_string(),\n            })\n            .collect();\n        Ok(urls)\n    }\n\n    /// Get model info\n    #[instrument(skip(self))]\n    pub async fn info(&mut self) -> Result<InfoResponse> {\n        let request = tonic::Request::new(InfoRequest {}).inject_context();\n        let response = self.stub.info(request).await?.into_inner();\n        Ok(response)\n    }\n\n    /// Get model health\n    #[instrument(skip(self))]\n    pub async fn health(&mut self) -> Result<HealthResponse> {\n        let request = tonic::Request::new(HealthRequest {}).inject_context();\n        let response = self.stub.health(request).await?.into_inner();\n        Ok(response)\n    }\n\n    /// Clear the past generations cache\n    #[instrument(skip(self))]\n    pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {\n        let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();\n        self.stub.clear_cache(request).await?;\n        Ok(())\n    }\n\n    /// Filter a cached batch\n    #[instrument(skip(self))]\n    pub async fn filter_batch(\n        &mut self,\n        batch_id: u64,\n        request_ids: Vec<u64>,\n    ) -> Result<Option<CachedBatch>> {\n        let request = tonic::Request::new(FilterBatchRequest {\n            batch_id,\n            request_ids,\n        })\n        .inject_context();\n        let filtered_batch = self.stub.filter_batch(request).await?.into_inner();\n        Ok(filtered_batch.batch)\n    }\n\n    /// Warmup on a max size batch\n    ///\n    /// Returns the maximum amount of tokens supported by the hardware\n    #[instrument(skip_all)]\n    pub async fn warmup(\n        &mut self,\n        max_input_tokens: Option<u32>,\n        max_prefill_tokens: u32,\n        max_total_tokens: Option<u32>,\n        max_batch_size: Option<usize>,\n    ) -> Result<(Option<u32>, u32, 
u32)> {\n        let mut n_tokens = 0;\n        let mut requests = Vec::new();\n        // Create requests\n        while n_tokens < max_prefill_tokens {\n            let mut truncate = max_prefill_tokens - n_tokens;\n            if let Some(max_input_tokens) = max_input_tokens {\n                truncate = min(max_input_tokens, truncate);\n            }\n\n            let mut input_chunks = Vec::new();\n            input_chunks.push(Chunk::Text(\"_test \".to_string().repeat(truncate as usize)).into());\n            if n_tokens == 0 {\n                input_chunks.push(\n                    Chunk::Image(Image {\n                        // Safe unwrap, because we control the data.\n                        data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(),\n                        mimetype: \"image/jpeg;base64\".to_string(),\n                    })\n                    .into(),\n                );\n            }\n\n            // Send stringly-typed inputs for compatibility for backends that haven't\n            // been updated to support chunks.\n\n            let mut inputs = String::new();\n            inputs.push_str(&\"_test \".to_string().repeat(truncate as usize));\n            if n_tokens == 0 {\n                // 1 request is enough to test vision heads.\n                // Sending images on other queries messes up easily with truncation.\n                inputs.push_str(&format!(\n                    \"![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})\",\n                ));\n            }\n\n            let max_new_tokens = if let Some(max_total_tokens) = max_total_tokens {\n                max_total_tokens - truncate\n            } else {\n                1\n            };\n\n            requests.push(Request {\n                id: 0,\n                inputs,\n                input_chunks: Some(Input {\n                    chunks: input_chunks,\n                }),\n                // We truncate the input on the server side to be sure that it has the 
correct size\n                truncate,\n                // Most request will have that\n                add_special_tokens: true,\n                // Blocks and slots will be set on the server side if we use paged attention\n                blocks: vec![],\n                slots: vec![],\n                cache_len: 0,\n                chunk_len: None,\n                // Set sampling parameters to also take these ops into account in the max memory\n                parameters: Some(NextTokenChooserParameters {\n                    temperature: 0.9,\n                    top_k: 10,\n                    top_p: 0.9,\n                    typical_p: 0.9,\n                    do_sample: false,\n                    seed: 0,\n                    repetition_penalty: 1.2,\n                    frequency_penalty: 0.1,\n                    watermark: true,\n                    grammar: String::new(),\n                    grammar_type: GrammarType::None as i32,\n                }),\n                stopping_parameters: Some(StoppingCriteriaParameters {\n                    max_new_tokens,\n                    stop_sequences: vec![],\n                    ignore_eos_token: true,\n                }),\n                prefill_logprobs: true,\n                top_n_tokens: 20,\n                adapter_id: None,\n            });\n            n_tokens += truncate;\n\n            // Check max_batch_size\n            if Some(requests.len()) == max_batch_size {\n                break;\n            }\n        }\n\n        let batch = Batch {\n            id: 0,\n            size: requests.len() as u32,\n            requests,\n            max_tokens: max_input_tokens.unwrap_or(0),\n            max_blocks: 0,\n        };\n\n        let request = tonic::Request::new(WarmupRequest {\n            batch: Some(batch),\n            max_input_tokens,\n            max_prefill_tokens,\n            max_total_tokens,\n        })\n        .inject_context();\n        let response = 
self.stub.warmup(request).await?.into_inner();\n        Ok((\n            response.max_supported_total_tokens,\n            response.max_input_tokens,\n            response.max_total_tokens,\n        ))\n    }\n\n    /// Generate one token for each request in the given batch\n    ///\n    /// Returns Generation for each request in batch\n    /// and the next cached batch\n    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]\n    pub async fn prefill(\n        &mut self,\n        batch: Batch,\n        cached_batch: Option<CachedBatch>,\n    ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {\n        let request = tonic::Request::new(PrefillRequest {\n            batch: Some(batch),\n            cached_batch,\n        })\n        .inject_context();\n        let response = self.stub.prefill(request).await?.into_inner();\n        Ok((\n            response.generations,\n            response.batch,\n            PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),\n        ))\n    }\n\n    /// Generate one token for each request in the given cached batches\n    ///\n    /// Returns Generation for each request in batches\n    /// and the next cached batch\n    #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]\n    pub async fn decode(\n        &mut self,\n        batches: Vec<CachedBatch>,\n    ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {\n        let request = tonic::Request::new(DecodeRequest { batches }).inject_context();\n        let response = self.stub.decode(request).await?.into_inner();\n        Ok((\n            response.generations,\n            response.batch,\n            DecodeTimings::new(\n                response.concat_ns,\n                response.forward_ns,\n                response.decode_ns,\n                response.total_ns,\n            ),\n        ))\n    }\n}\n\npub struct PrefillTimings {\n    pub forward: 
Duration,\n    pub decode: Duration,\n    pub total: Duration,\n}\n\nimpl PrefillTimings {\n    fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {\n        Self {\n            forward: Duration::from_nanos(forward_ns),\n            decode: Duration::from_nanos(decode_ns),\n            total: Duration::from_nanos(total_ns),\n        }\n    }\n}\n\npub struct DecodeTimings {\n    pub concat: Option<Duration>,\n    pub forward: Duration,\n    pub decode: Duration,\n    pub total: Duration,\n}\n\nimpl DecodeTimings {\n    fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {\n        Self {\n            concat: concat_ns.map(Duration::from_nanos),\n            forward: Duration::from_nanos(forward_ns),\n            decode: Duration::from_nanos(decode_ns),\n            total: Duration::from_nanos(total_ns),\n        }\n    }\n}\n"
  },
  {
    "path": "backends/client/src/v3/mod.rs",
    "content": "#[allow(clippy::derive_partial_eq_without_eq)]\nmod pb;\n\nmod client;\nmod sharded_client;\n\npub use client::Client;\npub use pb::generate::v3::{\n    input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,\n    HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,\n    StoppingCriteriaParameters, Tokens,\n};\npub use sharded_client::ShardedClient;\n"
  },
  {
    "path": "backends/client/src/v3/sharded_client.rs",
    "content": "/// Multi shard Client\nuse crate::{v3, Health, ShardInfo};\nuse crate::{ClientError, Result};\n\nuse crate::v3::{Chunk, InfoResponse, Input};\nuse async_trait::async_trait;\nuse futures::future::join_all;\nuse tonic::transport::Uri;\nuse tracing::instrument;\nuse v3::client::{DecodeTimings, PrefillTimings};\nuse v3::{\n    Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,\n    NextTokenChooserParameters, Request, StoppingCriteriaParameters,\n};\n\n#[derive(Debug, Clone)]\n/// Text Generation Inference gRPC multi client\npub struct ShardedClient {\n    clients: Vec<Client>,\n}\n\nimpl ShardedClient {\n    fn new(clients: Vec<Client>) -> Self {\n        Self { clients }\n    }\n\n    /// Create a new ShardedClient from a master client. The master client will communicate with\n    /// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.\n    async fn from_master_client(mut master_client: Client) -> Result<Self> {\n        // Get all uris/unix sockets from the master client\n        let uris = master_client.service_discovery().await?;\n        let futures = uris.into_iter().map(Client::connect_uds);\n        let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();\n        Ok(Self::new(clients?))\n    }\n\n    /// Returns a client connected to the given uri\n    pub async fn connect(uri: Uri) -> Result<Self> {\n        let master_client = Client::connect(uri).await?;\n        Self::from_master_client(master_client).await\n    }\n\n    /// Returns a client connected to the given unix socket\n    pub async fn connect_uds(path: String) -> Result<Self> {\n        let master_client = Client::connect_uds(path).await?;\n        Self::from_master_client(master_client).await\n    }\n\n    /// Get the model info\n    #[instrument(skip(self))]\n    pub async fn info(&mut self) -> Result<ShardInfo> {\n        let futures: Vec<_> = self\n            .clients\n            
.iter_mut()\n            .map(|client| client.info())\n            .collect();\n        join_all(futures).await.pop().unwrap().map(ShardInfo::from)\n    }\n\n    /// GRPC health check\n    #[instrument(skip(self))]\n    pub async fn health(&mut self) -> Result<HealthResponse> {\n        let futures: Vec<_> = self\n            .clients\n            .iter_mut()\n            .map(|client| client.health())\n            .collect();\n        join_all(futures).await.pop().unwrap()\n    }\n\n    /// Clear the past generations cache\n    #[instrument(skip(self))]\n    pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {\n        let futures: Vec<_> = self\n            .clients\n            .iter_mut()\n            .map(|client| client.clear_cache(batch_id))\n            .collect();\n        join_all(futures).await.into_iter().collect()\n    }\n\n    /// Filter a cached batch\n    #[instrument(skip(self))]\n    pub async fn filter_batch(\n        &mut self,\n        batch_id: u64,\n        request_ids: Vec<u64>,\n    ) -> Result<Option<CachedBatch>> {\n        let futures: Vec<_> = self\n            .clients\n            .iter_mut()\n            .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))\n            .collect();\n        // all shards return the same message\n        join_all(futures).await.pop().unwrap()\n    }\n\n    /// Warmup on a max size batch\n    ///\n    /// Returns the maximum amount of tokens supported by the hardware\n    #[instrument(skip(self))]\n    pub async fn warmup(\n        &mut self,\n        max_input_length: Option<u32>,\n        max_prefill_tokens: u32,\n        max_total_tokens: Option<u32>,\n        max_batch_size: Option<usize>,\n    ) -> Result<(Option<u32>, u32, u32)> {\n        let futures: Vec<_> = self\n            .clients\n            .iter_mut()\n            .map(|client| {\n                Box::pin(client.warmup(\n                    max_input_length,\n                    
max_prefill_tokens,\n                    max_total_tokens,\n                    max_batch_size,\n                ))\n            })\n            .collect();\n        // Take the minimum value\n        let results = join_all(futures)\n            .await\n            .into_iter()\n            .collect::<Result<Vec<(Option<u32>, u32, u32)>>>()?;\n\n        // Take the minimum value\n        // Different shards hold different parts of vocab, might yield\n        // different available block size.\n        let min = results\n            .iter()\n            .min()\n            .expect(\"Expect at least 1 warmup result\");\n        Ok(*min)\n    }\n\n    /// Generate one token for each request in the given batch\n    ///\n    /// Returns Generation for each request in batch\n    /// and the next cached batch\n    #[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]\n    pub async fn prefill(\n        &mut self,\n        batch: Batch,\n        cached_batch: Option<CachedBatch>,\n    ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {\n        let futures: Vec<_> = self\n            .clients\n            .iter_mut()\n            .map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))\n            .collect();\n        #[allow(clippy::type_complexity)]\n        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =\n            join_all(futures).await.into_iter().collect();\n        let mut results = results?;\n\n        let (mut generations, next_batch, mut timings) =\n            results.pop().ok_or(ClientError::EmptyResults)?;\n\n        // Merge generations from different model shards\n        for (mut shard_generations, _, shard_timings) in results.into_iter() {\n            generations.append(&mut shard_generations);\n            // Return the timings of the slowest shard\n            if shard_timings.total > timings.total {\n                timings = shard_timings;\n            }\n     
   }\n        Ok((generations, next_batch, timings))\n    }\n\n    /// Generate one token for each request in the given cached batches\n    ///\n    /// Returns Generation for each request in batches\n    /// and the next cached batch\n    #[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]\n    pub async fn decode(\n        &mut self,\n        batches: Vec<CachedBatch>,\n    ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {\n        let futures: Vec<_> = self\n            .clients\n            .iter_mut()\n            .map(|client| Box::pin(client.decode(batches.clone())))\n            .collect();\n        #[allow(clippy::type_complexity)]\n        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =\n            join_all(futures).await.into_iter().collect();\n        let mut results = results?;\n\n        let (mut generations, next_batch, mut timings) =\n            results.pop().ok_or(ClientError::EmptyResults)?;\n\n        // Merge generations from different model shards\n        for (mut shard_generations, _, shard_timings) in results.into_iter() {\n            generations.append(&mut shard_generations);\n            // Return the timings of the slowest shard\n            if shard_timings.total > timings.total {\n                timings = shard_timings;\n            }\n        }\n        Ok((generations, next_batch, timings))\n    }\n}\n\nimpl From<InfoResponse> for ShardInfo {\n    fn from(value: InfoResponse) -> Self {\n        Self {\n            requires_padding: value.requires_padding,\n            dtype: value.dtype,\n            device_type: value.device_type,\n            window_size: value.window_size,\n            speculate: value.speculate,\n        }\n    }\n}\n\n#[async_trait]\nimpl Health for ShardedClient {\n    async fn device_health(&self) -> Result<()> {\n        self.clone().health().await?;\n        Ok(())\n    }\n\n    async fn 
model_health(&self) -> Result<()> {\n        // Dummy batch of 1 token and 1 generated token\n        let liveness_request = Request {\n            id: u64::MAX,\n            inputs: \"liveness\".to_string(),\n            input_chunks: Some(Input {\n                chunks: vec![Chunk::Text(\"liveness\".into()).into()],\n            }),\n            truncate: 10,\n            add_special_tokens: true,\n            prefill_logprobs: false,\n            parameters: Some(NextTokenChooserParameters {\n                temperature: 1.0,\n                top_k: 0,\n                top_p: 1.0,\n                typical_p: 1.0,\n                do_sample: false,\n                seed: 0,\n                repetition_penalty: 1.0,\n                frequency_penalty: 0.0,\n                watermark: false,\n                grammar: String::new(),\n                grammar_type: GrammarType::None as i32,\n            }),\n            stopping_parameters: Some(StoppingCriteriaParameters {\n                max_new_tokens: 1,\n                stop_sequences: vec![],\n                ignore_eos_token: false,\n            }),\n            top_n_tokens: 0,\n            // Block 0 is reserved for health checks\n            blocks: vec![0],\n            slots: (0..16).collect(),\n            cache_len: 0,\n            chunk_len: None,\n            adapter_id: None,\n        };\n        let batch = Batch {\n            id: u64::MAX,\n            requests: vec![liveness_request],\n            size: 1,\n            max_tokens: 2,\n            max_blocks: 1,\n        };\n        self.clone().prefill(batch, None).await?;\n        Ok(())\n    }\n}\n"
  },
  {
    "path": "backends/gaudi/Makefile",
    "content": "mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))\nmkfile_dir := $(dir $(mkfile_path))\nroot_dir := ${mkfile_dir}/../..\n\nHABANA_VERSION := 1.21.0\nPYTORCH_VERSION := 2.6.0\n\n.PHONY:\timage run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install\n\nimage:\n\tdocker build --ulimit nofile=4096 -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION)\n\nrun-local-dev-container:\n\t\tdocker run -it \\\n\t\t--runtime=habana \\\n\t\t--ipc=host \\\n\t\t--cap-add=sys_nice \\\n\t\t--net=host \\\n\t\t-e HABANA_VISIBLE_DEVICES=all \\\n\t\t-e OMPI_MCA_btl_vader_single_copy_mechanism=none \\\n\t\t-e PT_HPU_ENABLE_LAZY_COLLECTIVES=true \\\n\t\t-e HF_TOKEN=`cat /home/ubuntu/.cache/huggingface/token` \\\n\t\t-e LOG_LEVEL=debug \\\n\t\t-e PORT=8080 \\\n\t\t-v /home/ubuntu/.cache/huggingface:/data \\\n\t\t-v $(PWD):/text-generation-inference \\\n\t\t-w /text-generation-inference \\\n\t\tvault.habana.ai/gaudi-docker/$(HABANA_VERSION)/ubuntu22.04/habanalabs/pytorch-installer-$(PYTORCH_VERSION):latest\n\ninstall-dependencies:\n\tpip install git+https://github.com/HabanaAI/DeepSpeed.git@$(HABANA_VERSION)\n\tpip install outlines~=0.0.34\n\tcurl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y\n\ninstall-server:\n\tmake -C ${root_dir}/backends/gaudi/server install PROTO_PATH=../../../proto/v3\n\ninstall-router:\n\tmake -C ${root_dir} install-router\n\ninstall-launcher:\n\tmake -C ${root_dir} install-launcher\n\n# use source to load the rust in path\nlocal-dev-install: install-dependencies\n\tbash -c 'source \"$$HOME/.cargo/env\" && \\\n\t\tmake install-server && \\\n\t\tmake install-router && \\\n\t\tmake install-launcher'\n\n# In order to run the integration tests, you need to first build the image (make -C backends/gaudi image)\nrun-integration-tests:\n\tDOCKER_VOLUME=${root_dir}/data 
\\\n\tHF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \\\n    pytest --durations=0 -s -vv ${root_dir}/integration-tests --gaudi\n\nrun-integration-tests-with-all-models:\n\tDOCKER_VOLUME=${root_dir}/data \\\n\tHF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \\\n\tpytest --durations=0 -s -vv ${root_dir}/integration-tests --gaudi --gaudi-all-models\n\n# This is used to capture the expected outputs for the integration tests offering an easy way to add more models to the integration tests\ncapture-expected-outputs-for-integration-tests:\n\tpip install -U pip uv\n\tDOCKER_VOLUME=${root_dir}/data \\\n\tHF_TOKEN=`cat ${HOME}/.cache/huggingface/token` \\\n\tuv run pytest --durations=0 -sv ${root_dir}/backends/gaudi/server/integration-tests/capture_expected_outputs.py\n"
  },
  {
    "path": "backends/gaudi/README.md",
    "content": "# Text-generation-inference - Gaudi backend\n\n## Description\n\nThis is the TGI backend for Intel Gaudi. This backend is composed of the tgi server optimized for Gaudi hardware.\n\n## Build your own image\n\nThe simplest way to build TGI with the Gaudi backend is to use the provided `Makefile`:\n\nOption 1: From the project root directory:\n```bash\nmake -C backends/gaudi image\n```\n\nOption 2: From the Gaudi backend directory:\n```bash\ncd backends/gaudi\nmake image\n```\n\nYou can now run the server with the following command:\n\nOption 1: Sharded:\n```bash\nmodel=meta-llama/Llama-3.1-8B-Instruct\nhf_token=$(cat ${HOME}/.cache/huggingface/token)\nvolume=${HOME}/.cache/huggingface\n\ndocker run --runtime=habana --ipc=host --cap-add=sys_nice \\\n  -p 8080:80 -v $volume:/data \\\n  -e LOG_LEVEL=debug -e HF_TOKEN=$hf_token \\\n  tgi-gaudi --model-id $model \\\n  --sharded true --num-shard 8 \\\n  --max-input-tokens 512 --max-total-tokens 1024 --max-batch-size 8 --max-batch-prefill-tokens 2048\n```\n\nOption 2: Non-sharded:\n```bash\nmodel=meta-llama/Llama-3.1-8B-Instruct\nhf_token=$(cat ${HOME}/.cache/huggingface/token)\nvolume=${HOME}/.cache/huggingface\n\ndocker run --runtime=habana --ipc=host --cap-add=sys_nice \\\n  -p 8080:80 -v $volume:/data \\\n  -e LOG_LEVEL=debug -e HF_TOKEN=$hf_token \\\n  tgi-gaudi --model-id $model \\\n  --max-input-tokens 512 --max-total-tokens 1024 --max-batch-size 4 --max-batch-prefill-tokens 2048\n```\n\n## Contributing\n\n### Local Development\n\nThis is useful if you want to run the server locally for better debugging.\n```bash\nmake -C backends/gaudi run-local-dev-container\n```\n\nThen run the following command inside the container to install tgi for gaudi:\n```bash\nmake -C backends/gaudi local-dev-install\n```\n\nAdd rust to path:\n```bash\n. 
\"$HOME/.cargo/env\"\n```\n\nOption 1: Run the server (sharded model):\n```bash\nLOG_LEVEL=debug text-generation-launcher \\\n    --model-id meta-llama/Llama-3.1-8B-Instruct \\\n    --sharded true \\\n    --num-shard 8 \\\n    --max-input-tokens 512 \\\n    --max-total-tokens 1024 \\\n    --max-batch-size 8 \\\n    --max-batch-prefill-tokens 2048\n```\n\nOption 2: Run the server (non-sharded model):\n```bash\nLOG_LEVEL=debug text-generation-launcher \\\n    --model-id meta-llama/Llama-3.1-8B-Instruct \\\n    --max-input-tokens 512 \\\n    --max-total-tokens 1024 \\\n    --max-batch-size 4 \\\n    --max-batch-prefill-tokens 2048\n```\n\nYou can then test the server with the following curl command from another terminal (can be outside the container):\n```bash\ncurl 127.0.0.1:8080/generate \\\n     -X POST \\\n     -d '{\"inputs\":\"What is Deep Learning?\",\"parameters\":{\"max_new_tokens\":20}}' \\\n     -H 'Content-Type: application/json'\n```\n\n### Integration tests\n\nInstall the dependencies:\n```bash\npip install -r integration-tests/requirements.txt\n```\n\nTo run the integration tests, you need to first build the image:\n```bash\nmake -C backends/gaudi image\n```\n\nThen run the following command to run the integration tests (CI tests):\n```bash\nmake -C backends/gaudi run-integration-tests\n```\n\nTo run the integration tests with all models, you can run the following command:\n```bash\nmake -C backends/gaudi run-integration-tests-with-all-models\n```\n\nTo capture the expected outputs for the integration tests, you can run the following command:\n```bash\nmake -C backends/gaudi capture-expected-outputs-for-integration-tests\n```\n\n#### How the integration tests works\nThe integration tests works as follows:\n\n1. 
Start a tgi server in a container, similar to the command:\n```bash\ndocker run --runtime=habana --ipc=host --cap-add=sys_nice \\\n  -p 8080:80 -v $volume:/data \\\n  -e LOG_LEVEL=debug -e HF_TOKEN=$hf_token \\\n  tgi-gaudi --model-id $model \\\n  --max-input-tokens 512 --max-total-tokens 1024 --max-batch-size 4 --max-batch-prefill-tokens 2048\n```\n\n2. Do a /generate request to the server, similar to the command:\n```bash\ncurl 127.0.0.1:8080/generate \\\n     -X POST \\\n     -d '{\"inputs\":\"What is Deep Learning?\",\"parameters\":{\"max_new_tokens\":20}}' \\\n     -H 'Content-Type: application/json'\n```\n\n3. Check the output of the server against the expected output:\n```python\nassert curl_output == expected_output\n```\n\nThis is then repeated for a set of models and configurations.\n"
  },
  {
    "path": "backends/gaudi/examples/docker_commands/docker_commands.md",
    "content": "# Examples of Docker Commands for Gaudi Backend\n\nThis page gives a list of examples of docker run commands for some of the most popular models.\n\n> **Note:** The parameters are chosen for Gaudi2 hardware to maximize performance on this given hardware, please adjust the parameters based on your hardware. For example, if you are using Gaudi3, you may want to increase the batch size.\n\n## Default Precision (BF16)\n\n### Llama3.1-8B on 1 card (BF16)\n\n```bash\nmodel=meta-llama/Meta-Llama-3.1-8B-Instruct\nhf_token=YOUR_ACCESS_TOKEN\nvolume=$PWD/data   # share a volume with the Docker container to avoid downloading weights every run\n\ndocker run -p 8080:80 \\\n   --runtime=habana \\\n   --cap-add=sys_nice \\\n   --ipc=host \\\n   -v $volume:/data \\\n   -e HF_TOKEN=$hf_token \\\n   ghcr.io/huggingface/text-generation-inference:3.3.5-gaudi \\\n   --model-id $model \\\n   --max-input-tokens 1024 --max-total-tokens 2048 \\\n   --max-batch-prefill-tokens 2048 --max-batch-size 32 \\\n   --max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64\n```\n\n### Llama3.1-70B 8 cards (BF16)\n\n```bash\nmodel=meta-llama/Meta-Llama-3.1-70B-Instruct\nhf_token=YOUR_ACCESS_TOKEN\nvolume=$PWD/data   # share a volume with the Docker container to avoid downloading weights every run\n\ndocker run -p 8080:80 \\\n   --runtime=habana \\\n   --cap-add=sys_nice \\\n   --ipc=host \\\n   -v $volume:/data \\\n   -e HF_TOKEN=$hf_token \\\n   ghcr.io/huggingface/text-generation-inference:3.3.5-gaudi \\\n   --model-id $model \\\n   --sharded true --num-shard 8 \\\n   --max-input-tokens 1024 --max-total-tokens 2048 \\\n   --max-batch-prefill-tokens 4096 --max-batch-size 256 \\\n   --max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512\n```\n\n### Llava-v1.6-Mistral-7B on 1 card (BF16)\n\n```bash\nmodel=llava-hf/llava-v1.6-mistral-7b-hf\nvolume=$PWD/data   # share a volume with the Docker container to avoid downloading weights every 
run\n\ndocker run -p 8080:80 \\\n   --runtime=habana \\\n   --cap-add=sys_nice \\\n   --ipc=host \\\n   -v $volume:/data \\\n   ghcr.io/huggingface/text-generation-inference:3.3.5-gaudi \\\n   --model-id $model \\\n   --max-input-tokens 4096 --max-batch-prefill-tokens 16384 \\\n   --max-total-tokens 8192 --max-batch-size 4\n```\n\n## FP8 Precision\n\nYou can also set the KV cache dtype to FP8 when launching the server; `fp8_e4m3fn` is supported on Gaudi.\n\n### Llama3-8B on 1 card (FP8)\n\n```bash\nmodel=RedHatAI/Meta-Llama-3-8B-Instruct-FP8-KV\nhf_token=YOUR_ACCESS_TOKEN\nvolume=$PWD/data   # share a volume with the Docker container to avoid downloading weights every run\n\ndocker run -p 8080:80 \\\n   --runtime=habana \\\n   --cap-add=sys_nice \\\n   --ipc=host \\\n   -v $volume:/data \\\n   -e HF_TOKEN=$hf_token \\\n   ghcr.io/huggingface/text-generation-inference:3.3.5-gaudi \\\n   --model-id $model \\\n   --kv-cache-dtype fp8_e4m3fn \\\n   --max-input-tokens 1024 --max-total-tokens 2048 \\\n   --max-batch-prefill-tokens 2048 --max-batch-size 32 \\\n   --max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 64\n```\n\n### Llama3-70B on 8 cards (FP8)\n\n```bash\nmodel=RedHatAI/Meta-Llama-3-70B-Instruct-FP8\nhf_token=YOUR_ACCESS_TOKEN\nvolume=$PWD/data   # share a volume with the Docker container to avoid downloading weights every run\n\ndocker run -p 8080:80 \\\n   --runtime=habana \\\n   --cap-add=sys_nice \\\n   --ipc=host \\\n   -v $volume:/data \\\n   -e HF_TOKEN=$hf_token \\\n   ghcr.io/huggingface/text-generation-inference:3.3.5-gaudi \\\n   --model-id $model \\\n   --kv-cache-dtype fp8_e4m3fn \\\n   --sharded true --num-shard 8 \\\n   --max-input-tokens 1024 --max-total-tokens 2048 \\\n   --max-batch-prefill-tokens 4096 --max-batch-size 256 \\\n   --max-waiting-tokens 7 --waiting-served-ratio 1.2 --max-concurrent-requests 512\n```\n"
  },
  {
    "path": "backends/gaudi/server/.gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\ntext_generation_server/__pycache__/\ntext_generation_server/pb/__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#   
https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#pdm.lock\n#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it\n#   in version control.\n#   https://pdm.fming.dev/#use-with-ide\n.pdm.toml\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\ntransformers\nsafetensors\nflash-attention/\nflash-attention-v2/\nvllm/\nllm-awq/\neetq/\nmamba/\n"
  },
  {
    "path": "backends/gaudi/server/Makefile",
    "content": "include Makefile-flash-att\ninclude Makefile-flash-att-v2\ninclude Makefile-vllm\ninclude Makefile-awq\ninclude Makefile-eetq\ninclude Makefile-selective-scan\n\nPROTO_PATH ?= ../proto/v3\n\nunit-tests:\n\tpytest -s -vv -m \"not private\" tests\n\ngen-server:\n\t# Compile protos\n\tpip install grpcio-tools==1.62.2 mypy-protobuf==3.6.0 'types-protobuf' --no-cache-dir\n\tmkdir text_generation_server/pb || true\n\tpython -m grpc_tools.protoc -I$(PROTO_PATH) --python_out=text_generation_server/pb \\\n\t\t--grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb $(PROTO_PATH)/generate.proto\n\tfind text_generation_server/pb/ -type f -name \"*.py\" -print0 -exec sed -i -e 's/^\\(import.*pb2\\)/from . \\1/g' {} \\;\n\ttouch text_generation_server/pb/__init__.py\n\ninstall: gen-server\n\tpip install pip --upgrade\n\tpip install --no-deps -r requirements.txt\n\tpip install -e \".\"\n\nrun-dev:\n\tSAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded\n\ninstall-poetry:\n\tcurl -sSL https://install.python-poetry.org | python3 -\n\nupdate-lock:\n\trm poetry.lock\n\tpoetry lock --no-update\n\nexport-requirements:\n\tpoetry export -o requirements.txt --without-hashes\n"
  },
  {
    "path": "backends/gaudi/server/Makefile-awq",
    "content": "# Fork that adds only the correct stream to this kernel in order\n# to make cuda graphs work.\nawq_commit := bd1dc2d5254345cc76ab71894651fb821275bdd4\n\nawq:\n\trm -rf llm-awq\n\tgit clone https://github.com/huggingface/llm-awq\n\nbuild-awq: awq\n\tcd llm-awq/ && git fetch && git checkout $(awq_commit)\n\tcd llm-awq/awq/kernels && python setup.py build\n\ninstall-awq: build-awq\n\tpip uninstall awq_inference_engine -y || true\n\tcd llm-awq/awq/kernels && python setup.py install\n"
  },
  {
    "path": "backends/gaudi/server/Makefile-eetq",
    "content": "eetq_commit := 1657b1504faa359e2ce0ac02999439d7ac8c74c0\n\neetq:\n    # Clone eetq\n\tpip install packaging\n\tgit clone https://github.com/NetEase-FuXi/EETQ.git eetq\n\nbuild-eetq: eetq\n\tcd eetq && git fetch && git checkout $(eetq_commit) && git submodule update --init --recursive\n\tcd eetq && python setup.py build\n\ninstall-eetq: build-eetq\n\tcd eetq && python setup.py install\n"
  },
  {
    "path": "backends/gaudi/server/Makefile-fbgemm",
    "content": "fbgemm_commit := v0.8.0\n\nbuild-fbgemm:\n\t@if [ ! -d \"fbgemm\" ]; then \\\n\t\tgit clone https://github.com/pytorch/FBGEMM.git fbgemm; \\\n\tfi\n\tcd fbgemm && git fetch && git checkout $(fbgemm_commit)  && \\\n\tgit submodule update --init --recursive && \\\n\tcd fbgemm_gpu && \\\n\tpip install -r requirements.txt && \\\n\tCUDA_ARCH_LIST=\"8.0;9.0a\" NVCC_GENCODE=\"-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a\" TORCH_CUDA_ARCH_LIST=\"8.0;9.0a\" python setup.py --package_variant genai build\n\ninstall-fbgemm: build-fbgemm\n\tcd fbgemm/fbgemm_gpu &&  \\\n\tCUDA_ARCH_LIST=\"8.0;9.0a\" NVCC_GENCODE=\"-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a\" TORCH_CUDA_ARCH_LIST=\"8.0;9.0a\" python setup.py --package_variant genai install\n"
  },
  {
    "path": "backends/gaudi/server/Makefile-flash-att",
    "content": "flash_att_commit := 3a9bfd076f98746c73362328958dbc68d145fbec\n\nbuild-flash-attention:\n\tif [ ! -d 'flash-attention' ]; then \\\n\t\tpip install -U packaging ninja  --no-cache-dir && \\\n\t\tgit clone https://github.com/HazyResearch/flash-attention.git; \\\n\tfi\n\tcd flash-attention && git fetch && git checkout $(flash_att_commit) && \\\n\tMAX_JOBS=8 python setup.py build && cd csrc/layer_norm && python setup.py build && cd ../rotary && python setup.py build\n\ninstall-flash-attention: build-flash-attention\n\tcd flash-attention && git checkout $(flash_att_commit) && MAX_JOBS=8 python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install\n"
  },
  {
    "path": "backends/gaudi/server/Makefile-flash-att-v2",
    "content": "flash_att_v2_commit_cuda := v2.6.1\nflash_att_v2_commit_rocm := 2092111b9f975b3347c652ff7fabd431130256c4\n\nbuild-flash-attention-v2-cuda:\n\tpip install -U packaging wheel\n\tpip install flash-attn==$(flash_att_v2_commit_cuda)\n\ninstall-flash-attention-v2-cuda: build-flash-attention-v2-cuda\n\techo \"Flash v2 installed\"\n\nbuild-flash-attention-v2-rocm:\n\tif [ ! -d 'flash-attention-v2' ]; then \\\n\t\tpip install -U packaging ninja  --no-cache-dir && \\\n\t\tgit clone https://github.com/mht-sharma/flash-attention.git flash-attention-v2 && \\\n\t\tcd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm) && \\\n\t\tgit submodule update --init --recursive && GPU_ARCHS=\"gfx90a;gfx942\" PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\" python setup.py build; \\\n\tfi\n\ninstall-flash-attention-v2-rocm: build-flash-attention-v2-rocm\n\tcd flash-attention-v2 &&  \\\n\tGPU_ARCHS=\"gfx90a;gfx942\" PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\" python setup.py install\n"
  },
  {
    "path": "backends/gaudi/server/Makefile-selective-scan",
    "content": "selective_scan_commit := 2a3704fd47ba817b415627b06fd796b971fdc137\n\ncausal-conv1d:\n\trm -rf causal-conv1d\n\tgit clone https://github.com/Dao-AILab/causal-conv1d.git\n\nbuild-causal-conv1d: causal-conv1d\n\tcd causal-conv1d/ && git checkout v1.1.1 # known latest working version tag\n\tcd causal-conv1d/ && CAUSAL_CONV1D_FORCE_BUILD=TRUE python setup.py build\n\ninstall-causal-conv1d: build-causal-conv1d\n\tpip uninstall causal-conv1d -y || true\n\tcd causal-conv1d/ && pip install .\n\n# selective-scan depends on causal-conv1d\nselective-scan:\n\trm -rf mamba\n\tgit clone https://github.com/state-spaces/mamba.git mamba\n\nbuild-selective-scan: selective-scan\n\tcd mamba/ && git fetch && git checkout $(selective_scan_commit)\n\tcd mamba && python setup.py build\n\ninstall-selective-scan: install-causal-conv1d build-selective-scan\n\tpip uninstall selective-scan-cuda -y || true\n\tcd mamba && pip install .\n\nbuild-all: build-causal-conv1d build-selective-scan\n"
  },
  {
    "path": "backends/gaudi/server/Makefile-vllm",
    "content": "commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b\ncommit_rocm := 4e0929e6e4fa0a3d09d358715c288020ea9dc247\nbuild-vllm-cuda:\n\tif [ ! -d 'vllm' ]; then \\\n\t\tpip install -U ninja packaging --no-cache-dir && \\\n\t\tgit clone https://github.com/Narsil/vllm.git vllm; \\\n\tfi\n\tcd vllm  && git fetch origin && git checkout $(commit_cuda) && python setup.py build\n\ninstall-vllm-cuda: build-vllm-cuda\n\tcd vllm  && git fetch origin && git checkout $(commit_cuda) && pip install -e .\n\nbuild-vllm-rocm:\n\tif [ ! -d 'vllm' ]; then \\\n\t\tpip install -U ninja packaging --no-cache-dir && \\\n\t\tgit clone https://github.com/mht-sharma/vllm.git vllm; \\\n\tfi\n\tcd vllm && git fetch && git checkout $(commit_rocm) &&  \\\n\tPYTORCH_ROCM_ARCH=\"gfx90a;gfx942\" python setup.py build\n\ninstall-vllm-rocm: build-vllm-rocm\n\tcd vllm && git fetch && git checkout $(commit_rocm) && \\\n\tPYTORCH_ROCM_ARCH=\"gfx90a;gfx942\" pip install -e .\n"
  },
  {
    "path": "backends/gaudi/server/README.md",
    "content": "# Text Generation Inference Python gRPC Server\n\nA Python gRPC server for Text Generation Inference\n\n## Install\n\n```shell\nmake install\n```\n\n## Run\n\n```shell\nmake run-dev\n```\n"
  },
  {
    "path": "backends/gaudi/server/dill-0.3.7-patch.sh",
    "content": "#!/bin/bash\ngit clone -b dill-0.3.7 https://github.com/uqfoundation/dill.git\npushd dill\ncat <<EOF > dill-0.3.7.patch\ndiff --git a/dill/_dill.py b/dill/_dill.py\nindex d0cf543..f6eb662 100644\n--- a/dill/_dill.py\n+++ b/dill/_dill.py\n@@ -69,7 +69,15 @@ TypeType = type # 'new-style' classes #XXX: unregistered\n XRangeType = range\n from types import MappingProxyType as DictProxyType, new_class\n from pickle import DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, PickleError, PicklingError, UnpicklingError\n-import __main__ as _main_module\n+class _LazyMainModule(object):\n+    _module = None\n+    @property\n+    def module(self):\n+        if self._module is None:\n+            import __main__ as _m_module\n+            self._module = _m_module\n+        return self._module\n+_main_module = _LazyMainModule()\n import marshal\n import gc\n # import zlib\n@@ -353,7 +361,7 @@ class Pickler(StockPickler):\n         _fmode = kwds.pop('fmode', None)\n         _recurse = kwds.pop('recurse', None)\n         StockPickler.__init__(self, file, *args, **kwds)\n-        self._main = _main_module\n+        self._main = _main_module.module\n         self._diff_cache = {}\n         self._byref = settings['byref'] if _byref is None else _byref\n         self._strictio = False #_strictio\n@@ -435,12 +443,12 @@ class Unpickler(StockUnpickler):\n         settings = Pickler.settings\n         _ignore = kwds.pop('ignore', None)\n         StockUnpickler.__init__(self, *args, **kwds)\n-        self._main = _main_module\n+        self._main = _main_module.module\n         self._ignore = settings['ignore'] if _ignore is None else _ignore\n\n     def load(self): #NOTE: if settings change, need to update attributes\n         obj = StockUnpickler.load(self)\n-        if type(obj).__module__ == getattr(_main_module, '__name__', '__main__'):\n+        if type(obj).__module__ == getattr(self._main, '__name__', '__main__'):\n             if not self._ignore:\n                 # point obj 
class to main\n                 try: obj.__class__ = getattr(self._main, type(obj).__name__)\n@@ -1194,11 +1202,11 @@ def save_module_dict(pickler, obj):\n         logger.trace(pickler, \"D1: %s\", _repr_dict(obj)) # obj\n         pickler.write(bytes('c__builtin__\\n__main__\\n', 'UTF-8'))\n         logger.trace(pickler, \"# D1\")\n-    elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__):\n+    elif (not is_dill(pickler, child=False)) and (obj == _main_module.module.__dict__):\n         logger.trace(pickler, \"D3: %s\", _repr_dict(obj)) # obj\n         pickler.write(bytes('c__main__\\n__dict__\\n', 'UTF-8'))  #XXX: works in general?\n         logger.trace(pickler, \"# D3\")\n-    elif '__name__' in obj and obj != _main_module.__dict__ \\\\\n+    elif '__name__' in obj and obj != _main_module.module.__dict__ \\\\\n             and type(obj['__name__']) is str \\\\\n             and obj is getattr(_import_module(obj['__name__'],True), '__dict__', None):\n         logger.trace(pickler, \"D4: %s\", _repr_dict(obj)) # obj\ndiff --git a/dill/session.py b/dill/session.py\nindex 74234ab..1be8d89 100644\n--- a/dill/session.py\n+++ b/dill/session.py\n@@ -233,7 +233,7 @@ def dump_module(\n     protocol = settings['protocol']\n     main = module\n     if main is None:\n-        main = _main_module\n+        main = _main_module.module\n     elif isinstance(main, str):\n         main = _import_module(main)\n     if not isinstance(main, ModuleType):\n@@ -501,7 +501,7 @@ def load_module(\n             pass\n     assert loaded is main\n     _restore_modules(unpickler, main)\n-    if main is _main_module or main is module:\n+    if main is _main_module.module or main is module:\n         return None\n     else:\n         return main\n\nEOF\ngit apply dill-0.3.7.patch\npython -m pip install .\npopd\nrm -fr dill\n"
  },
  {
    "path": "backends/gaudi/server/dill-0.3.8-patch.sh",
    "content": "#!/bin/bash\ngit clone -b 0.3.8 https://github.com/uqfoundation/dill.git\npushd dill\ncat <<EOF > dill-0.3.8.patch\ndiff --git a/dill/_dill.py b/dill/_dill.py\nindex d42432f..1d251e6 100644\n--- a/dill/_dill.py\n+++ b/dill/_dill.py\n@@ -69,7 +69,15 @@ TypeType = type # 'new-style' classes #XXX: unregistered\n XRangeType = range\n from types import MappingProxyType as DictProxyType, new_class\n from pickle import DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, PickleError, PicklingError, UnpicklingError\n-import __main__ as _main_module\n+class _LazyMainModule(object):\n+    _module = None\n+    @property\n+    def module(self):\n+        if self._module is None:\n+            import __main__ as _m_module\n+            self._module = _m_module\n+        return self._module\n+_main_module = _LazyMainModule()\n import marshal\n import gc\n # import zlib\n@@ -355,7 +363,7 @@ class Pickler(StockPickler):\n         _fmode = kwds.pop('fmode', None)\n         _recurse = kwds.pop('recurse', None)\n         StockPickler.__init__(self, file, *args, **kwds)\n-        self._main = _main_module\n+        self._main = _main_module.module\n         self._diff_cache = {}\n         self._byref = settings['byref'] if _byref is None else _byref\n         self._strictio = False #_strictio\n@@ -437,12 +445,12 @@ class Unpickler(StockUnpickler):\n         settings = Pickler.settings\n         _ignore = kwds.pop('ignore', None)\n         StockUnpickler.__init__(self, *args, **kwds)\n-        self._main = _main_module\n+        self._main = _main_module.module\n         self._ignore = settings['ignore'] if _ignore is None else _ignore\n\n     def load(self): #NOTE: if settings change, need to update attributes\n         obj = StockUnpickler.load(self)\n-        if type(obj).__module__ == getattr(_main_module, '__name__', '__main__'):\n+        if type(obj).__module__ == getattr(self._main, '__name__', '__main__'):\n             if not self._ignore:\n                 # point obj class 
to main\n                 try: obj.__class__ = getattr(self._main, type(obj).__name__)\n@@ -1199,11 +1207,11 @@ def save_module_dict(pickler, obj):\n         logger.trace(pickler, \"D1: %s\", _repr_dict(obj)) # obj\n         pickler.write(bytes('c__builtin__\\n__main__\\n', 'UTF-8'))\n         logger.trace(pickler, \"# D1\")\n-    elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__):\n+    elif (not is_dill(pickler, child=False)) and (obj == _main_module.module.__dict__):\n         logger.trace(pickler, \"D3: %s\", _repr_dict(obj)) # obj\n         pickler.write(bytes('c__main__\\n__dict__\\n', 'UTF-8'))  #XXX: works in general?\n         logger.trace(pickler, \"# D3\")\n-    elif '__name__' in obj and obj != _main_module.__dict__ \\\\\n+    elif '__name__' in obj and obj != _main_module.module.__dict__ \\\\\n             and type(obj['__name__']) is str \\\\\n             and obj is getattr(_import_module(obj['__name__'],True), '__dict__', None):\n         logger.trace(pickler, \"D4: %s\", _repr_dict(obj)) # obj\ndiff --git a/dill/session.py b/dill/session.py\nindex e91068a..a921b43 100644\n--- a/dill/session.py\n+++ b/dill/session.py\n@@ -233,7 +233,7 @@ def dump_module(\n     protocol = settings['protocol']\n     main = module\n     if main is None:\n-        main = _main_module\n+        main = _main_module.module\n     elif isinstance(main, str):\n         main = _import_module(main)\n     if not isinstance(main, ModuleType):\n@@ -501,7 +501,7 @@ def load_module(\n             pass\n     assert loaded is main\n     _restore_modules(unpickler, main)\n-    if main is _main_module or main is module:\n+    if main is _main_module.module or main is module:\n         return None\n     else:\n         return main\n\nEOF\ngit apply dill-0.3.8.patch\npython -m pip install .\npopd\nrm -fr dill\n"
  },
  {
    "path": "backends/gaudi/server/pyproject.toml",
    "content": "[tool.poetry]\nname = \"text-generation-server\"\nversion = \"2.0.4\"\ndescription = \"Text Generation Inference Python gRPC Server\"\nauthors = [\"Olivier Dehaene <olivier@huggingface.co>\"]\n\n[tool.poetry.scripts]\ntext-generation-server = 'text_generation_server.cli:app'\n\n[tool.poetry.dependencies]\npython = \">=3.9,<3.13\"\nprotobuf = \"^5.0\"\ngrpcio = \"^1.71.1\"\ngrpcio-status = \"*\"\ngrpcio-reflection = \"*\"\ngrpc-interceptor = \"^0.15.0\"\ntyper = \"^0.15.0\"\nloguru = \"^0.7.3\"\nopentelemetry-api = \"^1.32.0\"\nopentelemetry-exporter-otlp = \"^1.32.0\"\nopentelemetry-instrumentation-grpc = \"^0.53b0\"\nhf-transfer = \"^0.1.9\"\nsentencepiece = \"^0.2.0\"\npeft = \"^0.15\"\ntransformers = \"^4.52.4\"\nnumpy = \"^1.26\"\naccelerate = \"^1.7.0\"\noutlines= { version = \"^0.0.36\", optional = true }\nprometheus-client = \"^0.21.1\"\npy-cpuinfo = \"^9.0.0\"\n\n[tool.poetry.group.dev.dependencies]\ngrpcio-tools = \"*\"\npytest = \"^8.3.5\"\n\n[tool.pytest.ini_options]\nmarkers = [\"private: marks tests as requiring an admin hf token (deselect with '-m \\\"not private\\\"')\"]\n\n[build-system]\nrequires = [\"poetry-core>=1.0.0\"]\nbuild-backend = \"poetry.core.masonry.api\"\n\n[tool.poetry.requires-plugins]\npoetry-plugin-export = \">=1.8\"\n"
  },
  {
    "path": "backends/gaudi/server/requirements.txt",
    "content": "accelerate==1.7.0 ; python_version >= \"3.9\" and python_version < \"3.13\"\nannotated-types==0.7.0 ; python_version >= \"3.9\" and python_version < \"3.13\"\nattrs==25.3.0 ; python_version >= \"3.9\" and python_version < \"3.13\"\ncertifi==2025.1.31 ; python_version >= \"3.9\" and python_version < \"3.13\"\ncharset-normalizer==3.4.1 ; python_version >= \"3.9\" and python_version < \"3.13\"\nclick==8.1.8 ; python_version >= \"3.9\" and python_version < \"3.13\"\ncloudpickle==3.1.1 ; python_version >= \"3.9\" and python_version < \"3.13\"\ncolorama==0.4.6 ; python_version >= \"3.9\" and python_version < \"3.13\" and platform_system == \"Windows\" or python_version >= \"3.9\" and python_version < \"3.13\" and sys_platform == \"win32\"\ndeprecated==1.2.18 ; python_version >= \"3.9\" and python_version < \"3.13\"\ndiffusers==0.31.0 ; python_version >= \"3.9\" and python_version < \"3.13\"\ndiskcache==5.6.3 ; python_version >= \"3.9\" and python_version < \"3.13\"\nfilelock==3.18.0 ; python_version >= \"3.9\" and python_version < \"3.13\"\nfsspec==2025.3.2 ; python_version >= \"3.9\" and python_version < \"3.13\"\ngoogleapis-common-protos==1.70.0 ; python_version >= \"3.9\" and python_version < \"3.13\"\ngrpc-interceptor==0.15.4 ; python_version >= \"3.9\" and python_version < \"3.13\"\ngrpcio-reflection==1.71.0 ; python_version >= \"3.9\" and python_version < \"3.13\"\ngrpcio-status==1.71.0 ; python_version >= \"3.9\" and python_version < \"3.13\"\ngrpcio==1.72.0rc1 ; python_version >= \"3.9\" and python_version < \"3.13\"\nhf-transfer==0.1.9 ; python_version >= \"3.9\" and python_version < \"3.13\"\nhuggingface-hub==0.30.2 ; python_version >= \"3.9\" and python_version < \"3.13\"\nidna==3.10 ; python_version >= \"3.9\" and python_version < \"3.13\"\nimportlib-metadata==8.6.1 ; python_version >= \"3.9\" and python_version < \"3.13\"\ninteregular==0.3.3 ; python_version >= \"3.9\" and python_version < \"3.13\"\njinja2==3.1.6 ; python_version >= \"3.9\" 
and python_version < \"3.13\"\njoblib==1.4.2 ; python_version >= \"3.9\" and python_version < \"3.13\"\njsonschema-specifications==2024.10.1 ; python_version >= \"3.9\" and python_version < \"3.13\"\njsonschema==4.23.0 ; python_version >= \"3.9\" and python_version < \"3.13\"\nlark==1.2.2 ; python_version >= \"3.9\" and python_version < \"3.13\"\nllvmlite==0.43.0 ; python_version >= \"3.9\" and python_version < \"3.13\"\nloguru==0.7.3 ; python_version >= \"3.9\" and python_version < \"3.13\"\nmarkdown-it-py==3.0.0 ; python_version >= \"3.9\" and python_version < \"3.13\"\nmarkupsafe==3.0.2 ; python_version >= \"3.9\" and python_version < \"3.13\"\nmdurl==0.1.2 ; python_version >= \"3.9\" and python_version < \"3.13\"\nmpmath==1.3.0 ; python_version >= \"3.9\" and python_version < \"3.13\"\nnest-asyncio==1.6.0 ; python_version >= \"3.9\" and python_version < \"3.13\"\nnetworkx==3.2.1 ; python_version >= \"3.9\" and python_version < \"3.13\"\nnumba==0.60.0 ; python_version >= \"3.9\" and python_version < \"3.13\"\nnumpy==1.26.4 ; python_version >= \"3.9\" and python_version < \"3.13\"\nopentelemetry-api==1.32.0 ; python_version >= \"3.9\" and python_version < \"3.13\"\nopentelemetry-exporter-otlp-proto-common==1.32.0 ; python_version >= \"3.9\" and python_version < \"3.13\"\nopentelemetry-exporter-otlp-proto-grpc==1.32.0 ; python_version >= \"3.9\" and python_version < \"3.13\"\nopentelemetry-exporter-otlp-proto-http==1.32.0 ; python_version >= \"3.9\" and python_version < \"3.13\"\nopentelemetry-exporter-otlp==1.32.0 ; python_version >= \"3.9\" and python_version < \"3.13\"\nopentelemetry-instrumentation-grpc==0.53b0 ; python_version >= \"3.9\" and python_version < \"3.13\"\nopentelemetry-instrumentation==0.53b0 ; python_version >= \"3.9\" and python_version < \"3.13\"\nopentelemetry-proto==1.32.0 ; python_version >= \"3.9\" and python_version < \"3.13\"\nopentelemetry-sdk==1.32.0 ; python_version >= \"3.9\" and python_version < 
\"3.13\"\nopentelemetry-semantic-conventions==0.53b0 ; python_version >= \"3.9\" and python_version < \"3.13\"\noptimum==1.24.0 ; python_version >= \"3.9\" and python_version < \"3.13\"\noutlines==0.0.36 ; python_version >= \"3.9\" and python_version < \"3.13\"\npackaging==24.2 ; python_version >= \"3.9\" and python_version < \"3.13\"\npeft==0.15.1 ; python_version >= \"3.9\" and python_version < \"3.13\"\npillow==11.2.1 ; python_version >= \"3.9\" and python_version < \"3.13\"\nprometheus-client==0.21.1 ; python_version >= \"3.9\" and python_version < \"3.13\"\nprotobuf==5.29.4 ; python_version >= \"3.9\" and python_version < \"3.13\"\npsutil==7.0.0 ; python_version >= \"3.9\" and python_version < \"3.13\"\npy-cpuinfo==9.0.0 ; python_version >= \"3.9\" and python_version < \"3.13\"\npydantic-core==2.33.1 ; python_version >= \"3.9\" and python_version < \"3.13\"\npydantic==2.11.3 ; python_version >= \"3.9\" and python_version < \"3.13\"\npygments==2.19.1 ; python_version >= \"3.9\" and python_version < \"3.13\"\npyyaml==6.0.2 ; python_version >= \"3.9\" and python_version < \"3.13\"\nreferencing==0.36.2 ; python_version >= \"3.9\" and python_version < \"3.13\"\nregex==2024.11.6 ; python_version >= \"3.9\" and python_version < \"3.13\"\nrequests==2.32.3 ; python_version >= \"3.9\" and python_version < \"3.13\"\nrich==14.0.0 ; python_version >= \"3.9\" and python_version < \"3.13\"\nrpds-py==0.24.0 ; python_version >= \"3.9\" and python_version < \"3.13\"\nsafetensors==0.5.3 ; python_version >= \"3.9\" and python_version < \"3.13\"\nscikit-learn==1.6.1 ; python_version >= \"3.9\" and python_version < \"3.13\"\nscipy==1.13.1 ; python_version >= \"3.9\" and python_version < \"3.13\"\nsentence-transformers==3.3.1 ; python_version >= \"3.9\" and python_version < \"3.13\"\nsentencepiece==0.2.0 ; python_version >= \"3.9\" and python_version < \"3.13\"\nsetuptools==78.1.0 ; python_version >= \"3.12\" and python_version < \"3.13\"\nshellingham==1.5.4 ; python_version >= 
\"3.9\" and python_version < \"3.13\"\nsympy==1.13.1 ; python_version >= \"3.9\" and python_version < \"3.13\"\nthreadpoolctl==3.6.0 ; python_version >= \"3.9\" and python_version < \"3.13\"\ntokenizers==0.21.1 ; python_version >= \"3.9\" and python_version < \"3.13\"\ntqdm==4.67.1 ; python_version >= \"3.9\" and python_version < \"3.13\"\ntransformers==4.52.4 ; python_version >= \"3.9\" and python_version < \"3.13\"\ntriton==3.2.0 ; python_version >= \"3.9\" and python_version < \"3.13\" and platform_system == \"Linux\" and platform_machine == \"x86_64\"\ntyper==0.15.2 ; python_version >= \"3.9\" and python_version < \"3.13\"\ntyping-extensions==4.13.2 ; python_version >= \"3.9\" and python_version < \"3.13\"\ntyping-inspection==0.4.0 ; python_version >= \"3.9\" and python_version < \"3.13\"\nurllib3==2.4.0 ; python_version >= \"3.9\" and python_version < \"3.13\"\nwin32-setctime==1.2.0 ; python_version >= \"3.9\" and python_version < \"3.13\" and sys_platform == \"win32\"\nwrapt==1.17.2 ; python_version >= \"3.9\" and python_version < \"3.13\"\nzipp==3.21.0 ; python_version >= \"3.9\" and python_version < \"3.13\"\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/__init__.py",
    "content": ""
  },
  {
    "path": "backends/gaudi/server/text_generation_server/adapters/__init__.py",
    "content": "# Origin:   https://github.com/predibase/lorax\n# Path:     lorax/server/lorax_server/adapters/__init__.py\n# License:  Apache License Version 2.0, January 2004\n\nfrom text_generation_server.adapters.weights import (\n    AdapterBatchData,\n    AdapterBatchMetadata,\n)\n\n__all__ = [\n    \"AdapterBatchData\",\n    \"AdapterBatchMetadata\",\n]\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/adapters/config.py",
    "content": "# Origin:   https://github.com/predibase/lorax\n# Path:     lorax/server/lorax_server/adapters/config.py\n# License:  Apache License Version 2.0, January 2004\n\nfrom abc import ABC, abstractmethod\nfrom dataclasses import dataclass\nfrom typing import Dict, Set, Tuple\n\nimport torch\n\nfrom text_generation_server.adapters.weights import AdapterWeights\n\n\n@dataclass\nclass ModuleMap:\n    module_name: str\n    module_weights: Dict[str, Tuple[torch.Tensor, str]]\n\n\n@dataclass\nclass AdapterConfig(ABC):\n    base_model_name_or_path: str\n\n    @abstractmethod\n    def map_weights_for_model(\n        self,\n        adapter_weights: Dict[int, AdapterWeights],\n        weight_names: Tuple[str],\n    ) -> Tuple[ModuleMap, Set[str]]:\n        pass\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/adapters/lora.py",
    "content": "# Origin:   https://github.com/predibase/lorax\n# Path:     lorax/server/lorax_server/adapters/lora.py\n# License:  Apache License Version 2.0, January 2004\n\nfrom collections import defaultdict\nfrom dataclasses import dataclass\nfrom typing import Dict, List, Optional, Set, Tuple, Type, Union\n\nimport torch\nfrom peft import LoraConfig as _LoraConfig\nfrom torch.distributed import ProcessGroup\n\nfrom text_generation_server.adapters.config import AdapterConfig, ModuleMap\n\nfrom text_generation_server.adapters.weights import (\n    AdapterBatchMetadata,\n    AdapterWeights,\n    BatchAdapterWeights,\n)\nfrom text_generation_server.utils.sgmv import (\n    BGMV_MAX_RANK,\n    MAX_RANK_CUSTOM,\n    get_tmp_tensors,\n    orient_for_rank,\n    pad_rank,\n    use_cutlass_shrink,\n)\n\n\ndef get_start_stop_idxs_for_rank(offset, size, rank, world_size):\n    block_size = size // world_size\n    start = offset + rank * block_size\n    stop = offset + (rank + 1) * block_size\n    return start, stop\n\n\ndef shard_on_dim(\n    t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup\n):\n    world_size = process_group.size()\n    rank = process_group.rank()\n\n    size = t.shape[dim]\n    start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)\n\n    if dim == 0:\n        tensor = t[start:stop]\n    elif dim == 1:\n        tensor = t[:, start:stop]\n    else:\n        raise NotImplementedError(\"Let's make that generic when needed\")\n\n    return tensor\n\n\ndef shard_lora_weights(\n    weights_a: List[torch.Tensor],\n    weights_b: List[torch.Tensor],\n    split_dim: int,\n    process_group: ProcessGroup,\n) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:\n    # [hidden_size, r]\n    weights_a = [\n        shard_on_dim(w, dim=split_dim, process_group=process_group) for w in weights_a\n    ]\n\n    # [r, hidden_size]\n    weights_b = [shard_on_dim(w, dim=1, process_group=process_group) for w in weights_b]\n\n    return 
weights_a, weights_b\n\n\n@dataclass\nclass LoraConfig(AdapterConfig):\n    r: int\n    target_modules: Optional[Union[List[str], str]]\n    fan_in_fan_out: bool\n    lora_alpha: int\n    use_rslora: bool\n\n    def map_weights_for_model(\n        self,\n        adapter_weights: Dict[int, AdapterWeights],\n        weight_names: Tuple[str],\n    ) -> Tuple[ModuleMap, Set[str]]:\n        adapter_weight_names = set()\n        module_map = {}\n        for weight_name in weight_names:\n            lora_a_name = f\"base_model.model.{weight_name}.lora_A.weight\"\n            lora_b_name = f\"base_model.model.{weight_name}.lora_B.weight\"\n            if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights:\n                continue\n\n            module_map[weight_name] = {\n                \"lora_A\": (adapter_weights[lora_a_name], lora_a_name),\n                \"lora_B\": (adapter_weights[lora_b_name], lora_b_name),\n            }\n            adapter_weight_names.add(lora_a_name)\n            adapter_weight_names.add(lora_b_name)\n        return module_map, adapter_weight_names\n\n    @classmethod\n    def load(cls, adapter_id: str, api_token: str) -> \"LoraConfig\":\n        hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)\n        return cls(\n            base_model_name_or_path=hf_config.base_model_name_or_path,\n            r=hf_config.r,\n            target_modules=hf_config.target_modules,\n            fan_in_fan_out=hf_config.fan_in_fan_out,\n            lora_alpha=hf_config.lora_alpha,\n            use_rslora=(\n                hf_config.use_rslora if hasattr(hf_config, \"use_rslora\") else False\n            ),\n        )\n\n\nclass LoraWeights(AdapterWeights):\n    \"\"\"LoRA weights for a single adapter merged across all layers.\"\"\"\n\n    def __init__(\n        self,\n        weights_a: List[torch.Tensor],\n        weights_b: List[torch.Tensor],\n        adapter_config: LoraConfig,\n    ):\n        self.lora_a_r = 
weights_a[0].size(1) if len(weights_a) > 0 else 1\n        self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1\n\n        self._use_cutlass_shrink = use_cutlass_shrink(self.lora_a_r)\n        self._is_transposed = False\n\n        # [num_layers, hidden_size, r]\n        weights_a = [orient_for_rank(w, w.size(1)).contiguous() for w in weights_a]\n        self._weights_a = torch.stack(weights_a)\n\n        # [num_layers, r, hidden_size]\n        self._weights_b = torch.stack(weights_b)\n\n        self.adapter_config = adapter_config\n\n    @property\n    def weights_a(self) -> torch.Tensor:\n        if self._is_transposed:\n            self._transpose_weights()\n        return self._weights_a\n\n    @property\n    def weights_b(self) -> torch.Tensor:\n        if self._is_transposed:\n            self._transpose_weights()\n        return self._weights_b\n\n    @property\n    def weights_a_t(self) -> torch.Tensor:\n        if not self._is_transposed:\n            self._transpose_weights()\n        return self._weights_a\n\n    @property\n    def weights_b_t(self) -> torch.Tensor:\n        if not self._is_transposed:\n            self._transpose_weights()\n        return self._weights_b\n\n    def _transpose_weights(self):\n        if self._use_cutlass_shrink:\n            # If we're not using the cutlass shrink, then both SGMV and BGMV use the same orientation\n            self._weights_a = self._weights_a.transpose(1, 2).contiguous()\n        self._weights_b = self._weights_b.transpose(1, 2).contiguous()\n        self._is_transposed = not self._is_transposed\n\n    @classmethod\n    def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:\n        return [BatchLoraWeights]\n\n    # prepare pre-loaded lora weights for use in the model.\n    #\n    # this method processes and organizes lora weights for a specific layer type across all layers:\n    # - uses `config` (LoraConfig) to apply lora-specific settings like scaling factor.\n    # - retrieves 
weights from `module_map` based on the `layer_type`.\n    # - processes `nlayers` number of layers.\n    # - converts weights to the specified `dtype`.\n    # - shards weights across `world_size` number of processes using the `process_group`.\n    # - maps weights to specific layers using `target_to_layer`.\n    # - tracks `unused_weight_names` to identify any unused weights.\n    #\n    # the method handles weight transposition, scaling, and padding to ensure compatibility\n    # with SGMV or BGMV operations.\n    @classmethod\n    def prepare_weights(\n        cls,\n        config: LoraConfig,\n        module_map: Dict[str, Dict],\n        layer_type: str,\n        unused_weight_names: Set[str],\n        nlayers: int,\n        dtype: torch.dtype,\n        world_size: int,\n        process_group: ProcessGroup,\n        target_to_layer: Dict[str, Tuple[str, torch.Tensor]],\n    ) -> Optional[AdapterWeights]:\n        lora_a_list = [None] * nlayers\n        lora_b_list = [None] * nlayers\n\n        for layer_id in range(nlayers):\n            key = (layer_id, layer_type)\n            weight_name, layer = target_to_layer[key]\n            base_weight = layer.base_layer.linear.weight\n            base_device = base_weight.device\n\n            if weight_name not in module_map:\n                # There is no LoRA weight for this layer type in the adapter\n                return None\n\n            lora_a, lora_a_name = module_map[weight_name][\"lora_A\"]\n            lora_a = lora_a.to(base_device, dtype)\n\n            lora_b, lora_b_name = module_map[weight_name][\"lora_B\"]\n            lora_b = lora_b.to(base_device, dtype)\n\n            scale = get_scaling_factor(\n                config.lora_alpha,\n                config.r,\n                uses_rslora=config.use_rslora,\n            )\n\n            unused_weight_names.discard(lora_a_name)\n            unused_weight_names.discard(lora_b_name)\n\n            # Merge scaling factor into lora_b due to 
associativity of matrix multiplication:\n            # (A * B) * C = A * (B * C)\n            lora_a_list[layer_id] = lora_a.transpose(0, 1)\n            lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale\n\n        # pad lora ranks to be compatible with sgmv\n        lora_a_list = [pad_rank(w, dim=1, world_size=world_size) for w in lora_a_list]\n        lora_b_list = [pad_rank(w, dim=0, world_size=world_size) for w in lora_b_list]\n\n        if lora_a_list:\n            # update rank if it was padded\n            padded_rank = lora_a_list[0].size(1)\n            config.r = padded_rank\n\n        return LoraWeights(\n            *shard_lora_weights(\n                weights_a=lora_a_list,\n                weights_b=lora_b_list,\n                split_dim=0 if layer_type in {\"o_proj\", \"down_proj\", \"lm_head\"} else 1,\n                process_group=process_group,\n            ),\n            config,\n        )\n\n\n@dataclass\nclass RankSegments:\n    rank: int\n\n    lora_a_ptr: torch.Tensor\n    lora_b_ptr: torch.Tensor\n\n    # prefill (sgmv)\n    tmp_shrink: torch.Tensor\n    tmp_expand: torch.Tensor\n    segment_starts: torch.Tensor\n    segment_ends: torch.Tensor\n\n    # decode (bgmv)\n    indices: torch.Tensor\n\n\n@dataclass\nclass BatchLoraWeights(BatchAdapterWeights):\n    lora_a: Dict[int, torch.Tensor]\n    lora_b: Dict[int, torch.Tensor]\n    adapter_index_configs: Dict[int, LoraConfig]\n    rank_data: Dict[int, RankSegments]\n    use_sgmv: bool\n\n    def has_adapter(self, adapter_index: int) -> bool:\n        return adapter_index in self.adapter_index_configs\n\n    def can_vectorize(self, pg: ProcessGroup) -> bool:\n        return all(\n            rank_data.rank // pg.size() <= MAX_RANK_CUSTOM\n            for rank_data in self.rank_data.values()\n        )\n\n    @classmethod\n    def load(\n        self,\n        adapter_weights: Dict[int, AdapterWeights],\n        meta: AdapterBatchMetadata,\n        prefill: bool,\n        
prefill_head_indices: Optional[torch.Tensor],\n    ) -> Optional[\"BatchLoraWeights\"]:\n        adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}\n        adapter_weights = {\n            k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)\n        }\n        if not adapter_weights:\n            return None\n\n        first_weights = next(iter(adapter_weights.values()))\n        device = first_weights.weights_a.device\n        segment_indices = meta.segment_indices\n\n        lora_a = {\n            idx: adapter_weights[idx].weights_a\n            for idx in segment_indices\n            if idx in adapter_weights\n        }\n        lora_b = {\n            idx: adapter_weights[idx].weights_b\n            for idx in segment_indices\n            if idx in adapter_weights\n        }\n\n        max_rank = max(\n            (\n                adapter_weights[idx].lora_a_r\n                for idx in segment_indices\n                if idx in adapter_weights\n            ),\n            default=0,\n        )\n\n        if prefill or max_rank > BGMV_MAX_RANK:\n            use_sgmv = True\n            lora_a_ptr = torch.tensor(\n                [\n                    (\n                        adapter_weights[idx].weights_a.data_ptr()\n                        if idx in adapter_weights\n                        else 0\n                    )\n                    for idx in segment_indices\n                ],\n                dtype=torch.int64,\n                device=device,\n            )\n            lora_b_ptr = torch.tensor(\n                [\n                    (\n                        adapter_weights[idx].weights_b.data_ptr()\n                        if idx in adapter_weights\n                        else 0\n                    )\n                    for idx in segment_indices\n                ],\n                dtype=torch.int64,\n                device=device,\n            )\n        else:\n            use_sgmv = 
False\n            lora_a_ptr = torch.tensor(\n                [\n                    (\n                        adapter_weights[idx].weights_a_t.data_ptr()\n                        if idx in adapter_weights\n                        else 0\n                    )\n                    for idx in segment_indices\n                ],\n                dtype=torch.int64,\n                device=device,\n            )\n            lora_b_ptr = torch.tensor(\n                [\n                    (\n                        adapter_weights[idx].weights_b_t.data_ptr()\n                        if idx in adapter_weights\n                        else 0\n                    )\n                    for idx in segment_indices\n                ],\n                dtype=torch.int64,\n                device=device,\n            )\n\n        adapter_index_configs = {\n            idx: adapter_weights[idx].adapter_config\n            for idx in segment_indices\n            if idx in adapter_weights\n        }\n\n        adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}\n\n        rank_indices = defaultdict(list)\n        for segment_idx, adapter_idx in enumerate(segment_indices):\n            if adapter_idx not in adapter_weights:\n                continue\n            rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)\n\n        if prefill_head_indices is not None:\n            j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]\n            for head_index in prefill_head_indices:\n                # j cannot go out of bounds as that would mean there are tokens without corresponding adapters\n                if head_index < meta.adapter_segments[j]:\n                    prefill_head_segment_ends[-1] += 1\n                else:\n                    prefill_head_segment_starts.append(prefill_head_segment_ends[-1])\n                    prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)\n                    j += 
1\n\n        rank_data = {}\n        for rank, indices in rank_indices.items():\n            tmp_shrink = None\n            tmp_expand = None\n            segment_starts = None\n            segment_ends = None\n            batch_indices = None\n\n            if use_sgmv:\n                lora_a_ptr_indices = lora_a_ptr[indices]\n                tmp_shrink, tmp_expand = get_tmp_tensors(\n                    lora_a_ptr_indices.size(0), rank, device\n                )\n                segment_starts = meta.adapter_segments[indices]\n                segment_ends = meta.adapter_segments[[i + 1 for i in indices]]\n                if prefill_head_indices is not None:\n                    for i, segment_index in enumerate(indices):\n                        segment_starts[i] = prefill_head_segment_starts[segment_index]\n                        segment_ends[i] = prefill_head_segment_ends[segment_index]\n            else:\n                rank_indices = set(indices)\n                batch_indices = [\n                    adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()\n                ]\n                batch_indices = [\n                    idx if idx in rank_indices else -1 for idx in batch_indices\n                ]\n                batch_indices = torch.tensor(\n                    batch_indices, dtype=torch.int64, device=device\n                )\n\n            rank_data[rank] = RankSegments(\n                rank=rank,\n                tmp_shrink=tmp_shrink,\n                tmp_expand=tmp_expand,\n                lora_a_ptr=lora_a_ptr[indices],\n                lora_b_ptr=lora_b_ptr[indices],\n                segment_starts=segment_starts,\n                segment_ends=segment_ends,\n                indices=batch_indices,\n            )\n\n        return BatchLoraWeights(\n            lora_a=lora_a,\n            lora_b=lora_b,\n            adapter_index_configs=adapter_index_configs,\n            rank_data=rank_data,\n            use_sgmv=use_sgmv,\n   
     )\n\n\ndef get_scaling_factor(\n    lora_alpha: int,\n    r: int,\n    uses_rslora: bool = False,\n) -> float:\n    \"\"\"Computes the scaling factor for the lora weights.\"\"\"\n    if uses_rslora:\n        return lora_alpha / (r**0.5)\n    return lora_alpha / r\n\n\ndef _convert_lora(v: AdapterWeights) -> AdapterWeights:\n    if hasattr(v, \"lora_weights\"):\n        return v.lora_weights\n    return v\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/adapters/weights.py",
    "content": "# Origin:   https://github.com/predibase/lorax\n# Path:     lorax/server/lorax_server/adapters/weights.py\n# License:  Apache License Version 2.0, January 2004\n\nfrom abc import ABC, abstractclassmethod\nfrom collections import defaultdict\nfrom dataclasses import dataclass\nfrom typing import Dict, List, Optional, Set, Type\n\nimport torch\n\n\n@dataclass\nclass AdapterBatchMetadata:\n    # [batch_size]\n    adapter_indices: torch.Tensor\n\n    # [num_adapters]\n    adapter_set: Set[int]\n\n    # [num_segments + 1]\n    adapter_segments: torch.Tensor\n\n    # [num_segments]\n    # maps from segment index to adapter index, i.e.:\n    # segment_indices[s] == adapter_indices[i]\n    segment_indices: List[int]\n\n\nclass AdapterWeights(ABC):\n    @abstractclassmethod\n    def get_batch_types(cls) -> List[Type[\"BatchAdapterWeights\"]]:\n        pass\n\n    @property\n    def speculative_tokens(self) -> int:\n        return 0\n\n\nclass BatchAdapterWeights(ABC):\n    @abstractclassmethod\n    def has_adapter(self, adapter_index: int) -> bool:\n        pass\n\n    @abstractclassmethod\n    def load(\n        cls,\n        adapter_weights: Dict[int, AdapterWeights],\n        meta: \"AdapterBatchMetadata\",\n        prefill: bool,\n        prefill_head_indices: torch.Tensor,\n    ) -> Optional[\"BatchAdapterWeights\"]:\n        pass\n\n\nclass LayerAdapterWeights:\n    \"\"\"Adapter weights that apply to a particular layer.\"\"\"\n\n    def __init__(self):\n        self.adapter_weights: Dict[int, AdapterWeights] = {}\n\n    def add_adapter(self, adapter_idx: int, weights: AdapterWeights):\n        self.adapter_weights[adapter_idx] = weights\n\n    def remove_adapter(self, adapter_idx: int):\n        if adapter_idx not in self.adapter_weights:\n            return\n        del self.adapter_weights[adapter_idx]\n\n    def is_empty(self) -> bool:\n        return len(self.adapter_weights) == 0\n\n    def get_data(\n        self,\n        meta: 
AdapterBatchMetadata,\n        prefill: bool,\n        prefill_head_indices: Optional[torch.Tensor],\n    ) -> Dict[str, BatchAdapterWeights]:\n        # bucket adapters by batch class\n        adapter_batch_types: Dict[\n            Type[BatchAdapterWeights], Dict[int, AdapterWeights]\n        ] = defaultdict(dict)\n        for adapter_index, adapter_weights in self.adapter_weights.items():\n            for batch_type in adapter_weights.get_batch_types():\n                adapter_batch_types[batch_type][adapter_index] = adapter_weights\n\n        batch_data = {}\n        for batch_type, adapter_weights in adapter_batch_types.items():\n            batched_weights = batch_type.load(\n                adapter_weights, meta, prefill, prefill_head_indices\n            )\n            if batched_weights is not None:\n                batch_data = batched_weights\n        return batch_data\n\n\n@dataclass\nclass AdapterBatchData:\n    meta: AdapterBatchMetadata\n\n    # layer type -> adapter type -> batch weight data\n    data: Dict[str, Dict[str, BatchAdapterWeights]]\n\n    prefill: bool\n\n    @staticmethod\n    def from_meta(\n        meta: AdapterBatchMetadata,\n        weights: Dict[str, LayerAdapterWeights],\n        prefill: bool,\n        prefill_head_indices: Optional[torch.Tensor],\n    ) -> \"AdapterBatchData\":\n        data = {}\n        for k, v in weights.items():\n            if v.is_empty():\n                continue\n            data[k] = v.get_data(\n                meta, prefill, prefill_head_indices if k == \"lm_head\" else None\n            )\n        return AdapterBatchData(meta=meta, data=data, prefill=prefill)\n\n    def ranks(self) -> Set[int]:\n        # TODO(travis): refactor to be less coupled to lora implementation\n        ranks = set()\n        for lora_data in self.data.values():\n            if lora_data is None:\n                continue\n\n            for rank_data in lora_data.rank_data.values():\n                
ranks.add(rank_data.rank)\n\n        return ranks\n\n    def layer_names(self) -> Set[str]:\n        return set(self.data.keys())\n\n    def adapter_keys(self) -> Set[str]:\n        adapter_keys = set()\n        for layer_data in self.data.values():\n            adapter_keys.update(layer_data.keys())\n        return adapter_keys\n\n    @property\n    def max_rank(self) -> int:\n        ranks = self.ranks()\n        return max(ranks) if len(ranks) > 0 else 0\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/cache.py",
    "content": "import torch\n\nfrom typing import Dict, Optional, TypeVar\n\nfrom text_generation_server.models.types import Batch\n\nB = TypeVar(\"B\", bound=Batch)\n\n\nclass Cache:\n    def __init__(self):\n        self.cache: Dict[int, B] = {}\n\n    def pop(self, batch_id: int) -> Optional[B]:\n        return self.cache.pop(batch_id, None)\n\n    def set(self, entry: B):\n        if entry is not None:\n            self.cache[entry.batch_id] = entry\n\n    def delete(self, batch_id: int):\n        batch = self.pop(batch_id)\n        if batch is not None:\n            del batch\n        if torch.cuda.is_available():\n            torch.cuda.empty_cache()\n\n    def clear(self):\n        keys = list(self.cache.keys())\n        for k in keys:\n            self.delete(k)\n\n    def __len__(self):\n        return len(self.cache.keys())\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/cli.py",
    "content": "import os\nimport sys\nimport typer\n\nfrom pathlib import Path\nfrom loguru import logger\nfrom typing import Optional\nfrom enum import Enum\nfrom huggingface_hub import hf_hub_download\nfrom text_generation_server.utils.adapter import parse_lora_adapters\n\n\napp = typer.Typer()\n\n\nclass Quantization(str, Enum):\n    gptq = \"gptq\"\n    awq = \"awq\"\n    fp8 = \"fp8\"\n    compressed_tensors = \"compressed-tensors\"\n\n\nclass Dtype(str, Enum):\n    float16 = \"float16\"\n    bloat16 = \"bfloat16\"\n\n\nclass KVCacheDtype(str, Enum):\n    fp8_e4m3fn = \"fp8_e4m3fn\"\n    fp8_e5m2 = \"fp8_e5m2\"\n\n\n@app.command()\ndef serve(\n    model_id: str,\n    revision: Optional[str] = None,\n    sharded: bool = False,\n    quantize: Optional[Quantization] = None,\n    speculate: Optional[int] = None,\n    dtype: Optional[Dtype] = None,\n    kv_cache_dtype: Optional[KVCacheDtype] = None,\n    trust_remote_code: bool = False,\n    uds_path: Path = \"/tmp/text-generation-server\",\n    logger_level: str = \"INFO\",\n    json_output: bool = False,\n    otlp_endpoint: Optional[str] = None,\n    otlp_service_name: str = \"text-generation-inference.server\",\n    max_input_tokens: Optional[int] = None,\n):\n    if sharded:\n        # assert (\n        #     os.getenv(\"RANK\", None) is not None\n        # ), \"RANK must be set when sharded is True\"\n        assert (\n            os.getenv(\"WORLD_SIZE\", None) is not None\n        ), \"WORLD_SIZE must be set when sharded is True\"\n        assert (\n            os.getenv(\"MASTER_ADDR\", None) is not None\n        ), \"MASTER_ADDR must be set when sharded is True\"\n        assert (\n            os.getenv(\"MASTER_PORT\", None) is not None\n        ), \"MASTER_PORT must be set when sharded is True\"\n\n    # Remove default handler\n    logger.remove()\n    logger.add(\n        sys.stdout,\n        format=\"{message}\",\n        filter=\"text_generation_server\",\n        level=logger_level,\n        
serialize=json_output,\n        backtrace=True,\n        diagnose=False,\n    )\n\n    # Import here after the logger is added to log potential import exceptions\n    from text_generation_server import server\n    from text_generation_server.tracing import setup_tracing\n\n    # Setup OpenTelemetry distributed tracing\n    if otlp_endpoint is not None:\n        setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)\n\n    lora_adapters = parse_lora_adapters(os.getenv(\"LORA_ADAPTERS\"))\n\n    # TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled\n    # and warn the user\n    if lora_adapters:\n        logger.warning(\"LoRA adapters enabled (experimental feature).\")\n\n        if \"CUDA_GRAPHS\" in os.environ:\n            logger.warning(\n                \"LoRA adapters incompatible with CUDA Graphs. Disabling CUDA Graphs.\"\n            )\n            global CUDA_GRAPHS\n            CUDA_GRAPHS = None\n\n    # Downgrade enum into str for easier management later on\n    quantize = None if quantize is None else quantize.value\n    dtype = \"bfloat16\" if dtype is None else dtype.value\n    kv_cache_dtype = None if kv_cache_dtype is None else kv_cache_dtype.value\n    logger.info(f\"quantize={quantize} kv_cache_dtype={kv_cache_dtype}\")\n    if dtype is not None and quantize not in {\n        None,\n        \"bitsandbytes\",\n        \"bitsandbytes-nf4\",\n        \"bitsandbytes-fp4\",\n        \"gptq\",\n        \"awq\",\n        \"fp8\",\n        \"compressed-tensors\",\n    }:\n        raise RuntimeError(\n            \"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model.\"\n        )\n    server.serve(\n        model_id,\n        lora_adapters,\n        revision,\n        sharded,\n        quantize,\n        speculate,\n        dtype,\n        kv_cache_dtype,\n        trust_remote_code,\n        uds_path,\n        max_input_tokens,\n    )\n\n\n@app.command()\ndef 
download_weights(\n    model_id: str,\n    revision: Optional[str] = None,\n    extension: str = \".safetensors\",\n    auto_convert: bool = True,\n    logger_level: str = \"INFO\",\n    json_output: bool = False,\n    trust_remote_code: bool = False,\n    merge_lora: bool = False,\n):\n    # Remove default handler\n    logger.remove()\n    logger.add(\n        sys.stdout,\n        format=\"{message}\",\n        filter=\"text_generation_server\",\n        level=logger_level,\n        serialize=json_output,\n        backtrace=True,\n        diagnose=False,\n    )\n\n    # Import here after the logger is added to log potential import exceptions\n    from text_generation_server import utils\n\n    # Test if files were already download\n    try:\n        utils.weight_files(model_id, revision, extension)\n        logger.info(\"Files are already present on the host. \" \"Skipping download.\")\n        return\n    # Local files not found\n    except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):\n        pass\n\n    is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(\n        \"WEIGHTS_CACHE_OVERRIDE\", None\n    ) is not None\n\n    if not is_local_model:\n        # TODO: maybe reverse the default value of merge_lora?\n        # currently by default we don't merge the weights with the base model\n        if merge_lora:\n            try:\n                hf_hub_download(\n                    model_id, revision=revision, filename=\"adapter_config.json\"\n                )\n                utils.download_and_unload_peft(\n                    model_id, revision, trust_remote_code=trust_remote_code\n                )\n                is_local_model = True\n                utils.weight_files(model_id, revision, extension)\n                return\n            except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):\n                pass\n        else:\n            try:\n                
utils.peft.download_peft(\n                    model_id, revision, trust_remote_code=trust_remote_code\n                )\n            except Exception:\n                pass\n\n        try:\n            import json\n\n            config = hf_hub_download(\n                model_id, revision=revision, filename=\"config.json\"\n            )\n            with open(config, \"r\") as f:\n                config = json.load(f)\n\n            base_model_id = config.get(\"base_model_name_or_path\", None)\n            if base_model_id and base_model_id != model_id:\n                try:\n                    logger.info(f\"Downloading parent model {base_model_id}\")\n                    download_weights(\n                        model_id=base_model_id,\n                        revision=\"main\",\n                        extension=extension,\n                        auto_convert=auto_convert,\n                        logger_level=logger_level,\n                        json_output=json_output,\n                        trust_remote_code=trust_remote_code,\n                    )\n                except Exception:\n                    pass\n        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):\n            pass\n\n        # Try to download weights from the hub\n        try:\n            filenames = utils.weight_hub_files(model_id, revision, extension)\n            utils.download_weights(filenames, model_id, revision)\n            # Successfully downloaded weights\n            return\n\n        # No weights found on the hub with this extension\n        except utils.EntryNotFoundError as e:\n            # Check if we want to automatically convert to safetensors or if we can use .bin weights instead\n            if not extension == \".safetensors\" or not auto_convert:\n                raise e\n\n    elif (Path(model_id) / \"adapter_config.json\").exists():\n        # Try to load as a local PEFT model\n        try:\n            utils.download_and_unload_peft(\n  
              model_id, revision, trust_remote_code=trust_remote_code\n            )\n            utils.weight_files(model_id, revision, extension)\n            return\n        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):\n            pass\n    elif (Path(model_id) / \"config.json\").exists():\n        # Try to load as a local Medusa model\n        try:\n            import json\n\n            config = Path(model_id) / \"config.json\"\n            with open(config, \"r\") as f:\n                config = json.load(f)\n\n            base_model_id = config.get(\"base_model_name_or_path\", None)\n            if base_model_id:\n                try:\n                    logger.info(f\"Downloading parent model {base_model_id}\")\n                    download_weights(\n                        model_id=base_model_id,\n                        revision=\"main\",\n                        extension=extension,\n                        auto_convert=auto_convert,\n                        logger_level=logger_level,\n                        json_output=json_output,\n                        trust_remote_code=trust_remote_code,\n                    )\n                except Exception:\n                    pass\n        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):\n            pass\n\n    # Try to see if there are local pytorch weights\n    try:\n        # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE\n        try:\n            local_pt_files = utils.weight_files(model_id, revision, \".bin\")\n        except Exception:\n            local_pt_files = utils.weight_files(model_id, revision, \".pt\")\n\n    # No local pytorch weights\n    except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):\n        if extension == \".safetensors\":\n            logger.warning(\n                f\"No safetensors weights found for model {model_id} at revision {revision}. 
\"\n                f\"Downloading PyTorch weights.\"\n            )\n\n        # Try to see if there are pytorch weights on the hub\n        pt_filenames = utils.weight_hub_files(model_id, revision, \".bin\")\n        # Download pytorch weights\n        local_pt_files = utils.download_weights(pt_filenames, model_id, revision)\n\n    if auto_convert:\n        if not trust_remote_code:\n            logger.warning(\n                \"🚨🚨BREAKING CHANGE in 2.0🚨🚨: Safetensors conversion is disabled without `--trust-remote-code` because \"\n                \"Pickle files are unsafe and can essentially contain remote code execution!\"\n                \"Please check for more information here: https://huggingface.co/docs/text-generation-inference/basic_tutorials/safety\",\n            )\n\n        logger.warning(\n            f\"No safetensors weights found for model {model_id} at revision {revision}. \"\n            f\"Converting PyTorch weights to safetensors.\"\n        )\n\n        # Safetensors final filenames\n        local_st_files = [\n            p.parent / f\"{p.stem.lstrip('pytorch_')}.safetensors\"\n            for p in local_pt_files\n        ]\n        try:\n            import transformers\n            import json\n\n            if is_local_model:\n                config_filename = os.path.join(model_id, \"config.json\")\n            else:\n                config_filename = hf_hub_download(\n                    model_id, revision=revision, filename=\"config.json\"\n                )\n            with open(config_filename, \"r\") as f:\n                config = json.load(f)\n            architecture = config[\"architectures\"][0]\n\n            class_ = getattr(transformers, architecture)\n\n            # Name for this varible depends on transformers version.\n            discard_names = getattr(class_, \"_tied_weights_keys\", [])\n\n        except Exception:\n            discard_names = []\n        # Convert pytorch weights to safetensors\n        
utils.convert_files(local_pt_files, local_st_files, discard_names)\n\n\n@app.command()\ndef quantize(\n    model_id: str,\n    output_dir: str,\n    revision: Optional[str] = None,\n    logger_level: str = \"INFO\",\n    json_output: bool = False,\n    trust_remote_code: bool = False,\n    upload_to_model_id: Optional[str] = None,\n    percdamp: float = 0.01,\n    act_order: bool = False,\n    groupsize: int = 128,\n):\n    if revision is None:\n        revision = \"main\"\n    download_weights(\n        model_id=model_id,\n        revision=revision,\n        logger_level=logger_level,\n        json_output=json_output,\n    )\n    from text_generation_server.layers.gptq.quantize import quantize\n\n    quantize(\n        model_id=model_id,\n        bits=4,\n        groupsize=groupsize,\n        output_dir=output_dir,\n        revision=revision,\n        trust_remote_code=trust_remote_code,\n        upload_to_model_id=upload_to_model_id,\n        percdamp=percdamp,\n        act_order=act_order,\n        sym=True,\n    )\n\n\nif __name__ == \"__main__\":\n    app()\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/interceptor.py",
    "content": "# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.\n\nimport torch\nimport grpc\n\nfrom google.rpc import status_pb2, code_pb2\nfrom grpc_status import rpc_status\nfrom grpc_interceptor.server import AsyncServerInterceptor\nfrom loguru import logger\nfrom typing import Callable, Any\nimport traceback\nimport os\n\n\nclass ExceptionInterceptor(AsyncServerInterceptor):\n    async def intercept(\n        self,\n        method: Callable,\n        request_or_iterator: Any,\n        context: grpc.ServicerContext,\n        method_name: str,\n    ) -> Any:\n        try:\n            response = method(request_or_iterator, context)\n            return await response\n        except Exception as err:\n            trace = \" \" + traceback.format_exc() if os.environ.get(\"DUMP_STACK\") else \"\"\n            method_name = method_name.split(\"/\")[-1]\n            logger.exception(f\"Method {method_name} encountered an error.\")\n\n            # Runtime Error cannot be recovered from\n            if isinstance(err, RuntimeError):\n                exit(1)\n\n            if torch.cuda.is_available():\n                torch.cuda.empty_cache()\n\n            from .utils.debug import dbg_trace\n\n            dbg_trace(\"EXCEPTION\", traceback.format_exc())\n            await context.abort_with_status(\n                rpc_status.to_status(\n                    status_pb2.Status(code=code_pb2.INTERNAL, message=str(err) + trace)\n                )\n            )\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/layers/__init__.py",
    "content": "from text_generation_server.layers.tensor_parallel import (\n    TensorParallelColumnLinear,\n    TensorParallelRowLinear,\n    TensorParallelEmbedding,\n)\nfrom text_generation_server.layers.linear import (\n    get_linear,\n    FastLinear,\n)\nfrom text_generation_server.layers.speculative import SpeculativeHead\n\n# Just to add the `load` methods.\nfrom text_generation_server.layers.layernorm import load_layer_norm\nfrom text_generation_server.layers.conv import load_conv2d\nfrom text_generation_server.layers.fp8 import Fp8Linear\n\nfrom text_generation_server.layers.lora import (\n    LoraLinear,\n    TensorParallelMultiAdapterLinear,\n    TensorParallelAdapterRowLinear,\n)\n\n__all__ = [\n    \"get_linear\",\n    \"FastLinear\",\n    \"TensorParallelColumnLinear\",\n    \"TensorParallelRowLinear\",\n    \"TensorParallelEmbedding\",\n    \"SpeculativeHead\",\n    \"LoraLinear\",\n    \"Fp8Linear\",\n    \"TensorParallelMultiAdapterLinear\",\n    \"TensorParallelAdapterRowLinear\",\n    \"load_layer_norm\",\n    \"load_conv2d\",\n]\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/layers/attention/__init__.py",
    "content": "from .common import (\n    Seqlen,\n    HPUPagedAttentionMetadata,\n    trim_attn_metadata,\n    trim_seqlen_metadata,\n    _async_h2d_tensor_copy,\n)\n\nfrom .hpu import (\n    SUPPORTS_WINDOWING,\n    attention,\n    paged_attention,\n    paged_attention_mla,\n    set_block_mapping,\n)\n\n\n# KVCache needs `reshape_and_cache`, so ensure that it is defined already.\nfrom .kv_cache import KVCache, get_kv_scales, KVCompressCache\n\n__all__ = [\n    \"attention\",\n    \"get_kv_scales\",\n    \"paged_attention\",\n    \"paged_attention_mla\",\n    \"set_block_mapping\",\n    \"SUPPORTS_WINDOWING\",\n    \"KVCache\",\n    \"KVCompressCache\",\n    \"Seqlen\",\n    \"HPUPagedAttentionMetadata\",\n    \"trim_seqlen_metadata\",\n    \"trim_attn_metadata\",\n    \"_async_h2d_tensor_copy\",\n]\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/layers/attention/common.py",
    "content": "from dataclasses import dataclass\nimport torch\nfrom typing import Optional, List, Dict\nimport collections\nimport torch.nn.functional as F\n\n_TYPE_CACHE = {}\n\n\n@dataclass\nclass HPUPagedAttentionMetadata:\n    \"\"\"Metadata for PagedAttention.\"\"\"\n\n    block_list: Optional[torch.Tensor]\n    block_mapping: Optional[torch.Tensor]\n    block_usage: Optional[torch.Tensor]\n    block_groups: Optional[torch.Tensor]\n    attn_bias: Optional[torch.Tensor]\n    slots_in_window_mask: Optional[torch.Tensor] = None\n    block_list_in_window: Optional[torch.Tensor] = None\n    block_mapping_in_window: Optional[torch.Tensor] = None\n    block_usage_in_window: Optional[torch.Tensor] = None\n    block_groups_in_window: Optional[torch.Tensor] = None\n    attn_bias_in_window: Optional[torch.Tensor] = None\n\n\ndef subtuple(\n    obj: object,\n    typename: str,\n    to_copy: List[str],\n    to_override: Optional[Dict[str, object]] = None,\n):\n    if obj is None:\n        return None\n    if to_override is None:\n        to_override = {}\n    fields = set(to_copy) | set(to_override.keys())\n    if isinstance(obj, dict):\n        values = {key: obj[key] for key in fields if key in obj}\n    else:\n        values = {f: to_override.get(f, getattr(obj, f)) for f in fields}\n    if typename not in _TYPE_CACHE:\n        _TYPE_CACHE[typename] = collections.namedtuple(typename, \" \".join(fields))\n    return _TYPE_CACHE[typename](**values)\n\n\ndef trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:\n    # NOTE(kzawora): To anyone working on this in the future:\n    # Trimming metadata is required when using HPUGraphs.\n    # Attention metadata is going to be hashed by PT bridge, and\n    # appropriate HPUGraphs will be matched based on all inputs' hash.\n\n    # Before you put more keys in here, make sure you know their\n    # value type and make sure you know how it's going to be hashed.\n    # You can find that information in input_hash 
function\n    # in habana_frameworks/torch/hpu/graphs.py. You can also hash\n    # it manually with torch.hpu.graphs.input_hash(attention_metadata)\n\n    # If you use primitive types here - they will get hashed based\n    # on their value. You *will* get lots of excessive graph captures\n    # (and an OOM eventually) if you decide to put something like\n    # seq_len int here.\n    # If you absolutely need a scalar, put it in a tensor. Tensors\n    # get hashed using their metadata, not their values:\n    # input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))\n    # input_hash(123) != input_hash(321)\n    # input_hash(\"abc\") != input_hash(\"cba\")\n    attention_metadata = subtuple(\n        metadata,\n        \"TrimmedAttentionMetadata\",\n        [\n            \"block_list\",\n            \"block_mapping\",\n            \"block_usage\",\n            \"block_groups\",\n            \"attn_bias\",\n            \"slots_in_window_mask\",\n            \"block_list_in_window\",\n            \"block_mapping_in_window\",\n            \"block_usage_in_window\",\n            \"block_groups_in_window\",\n            \"attn_bias_in_window\",\n        ],\n    )\n    return attention_metadata\n\n\n@dataclass\nclass Seqlen:\n    input_lengths: torch.Tensor\n    attn_mask: Optional[torch.Tensor] = None\n\n    def __init__(\n        self,\n        input_lengths,\n    ):\n        self.input_lengths = input_lengths\n\n    def clamp(self, max):\n        # Flash decoding doesn't need to clamp\n        return self\n\n    def make_sliding_window_bias(\n        self,\n        seq_lens: List[int],\n        window_size: Optional[int],\n        dtype: torch.dtype,\n        padded_input_len: Optional[int],\n        padded_bs: Optional[int],\n    ) -> List[torch.Tensor]:\n        attn_biases = []\n        for seq_len in seq_lens:\n            if seq_len != 0:\n                tensor = torch.full(\n                    (1, seq_len, seq_len),\n                    dtype=dtype,\n     
               fill_value=1,\n                )\n                shift = 0\n                mask = torch.tril(tensor, diagonal=shift).to(dtype)  # type: ignore\n                if window_size is not None:\n                    mask = torch.triu(mask, diagonal=shift - window_size + 1)\n                mask = F.pad(\n                    mask,\n                    (\n                        padded_input_len - seq_len,\n                        0,\n                        padded_input_len - seq_len,\n                        0,\n                        0,\n                        0,\n                    ),\n                    value=0,\n                )\n            else:\n                mask = torch.full(\n                    (1, padded_input_len, padded_input_len),\n                    dtype=dtype,\n                    fill_value=0,\n                )\n            attn_biases.append(mask)\n        attn_biases = torch.stack(attn_biases, dim=0)\n        return attn_biases.to(torch.bool)\n\n\ndef _async_h2d_tensor_copy(source, device=\"hpu\"):\n    if source is None:\n        return None\n    if source.device.type == \"hpu\":\n        return source\n    assert source.device.type == \"cpu\", \"Source tensor is not present in host memory!\"\n    target = torch.empty(source.shape, dtype=source.dtype, device=device)\n    target.copy_(source, non_blocking=True)\n    return target\n\n\ndef trim_seqlen_metadata(metadata: Seqlen) -> object:\n    # NOTE(kzawora): To anyone working on this in the future:\n    # Trimming metadata is required when using HPUGraphs.\n    # Attention metadata is going to be hashed by PT bridge, and\n    # appropriate HPUGraphs will be matched based on all inputs' hash.\n\n    # Before you put more keys in here, make sure you know their\n    # value type and make sure you know how it's going to be hashed.\n    # You can find that information in input_hash function\n    # in habana_frameworks/torch/hpu/graphs.py. 
You can also hash\n    # it manually with torch.hpu.graphs.input_hash(attention_metadata)\n\n    # If you use primitive types here - they will get hashed based\n    # on their value. You *will* get lots of excessive graph captures\n    # (and an OOM eventually) if you decide to put something like\n    # seq_len int here.\n    # If you absolutely need a scalar, put it in a tensor. Tensors\n    # get hashed using their metadata, not their values:\n    # input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))\n    # input_hash(123) != input_hash(321)\n    # input_hash(\"abc\") != input_hash(\"cba\")\n    attention_metadata = subtuple(\n        metadata,\n        \"TrimmedSeqlen\",\n        [\n            \"input_lengths\",\n            \"attn_mask\",\n        ],\n    )\n    return attention_metadata\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/layers/attention/hpu.py",
    "content": "import torch\nfrom text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata\nfrom typing import Optional\nfrom text_generation_server.layers.attention.kv_cache import KVCache, KVScales\nfrom vllm_hpu_extension import ops\nfrom vllm_hpu_extension.utils import Matmul\nfrom habana_frameworks.torch.hpex.kernels import FusedSDPA\nfrom vllm_hpu_extension.utils import ModuleFusedSDPA\nimport os\nfrom text_generation_server.models.globals import BLOCK_SIZE\nimport math\n\nSUPPORTS_WINDOWING = False\n\n\nclass FP8Matmul(torch.nn.Module):\n\n    def __init__(self, scale_other):\n        super().__init__()\n        self.scale_input = torch.tensor(1.0, dtype=torch.bfloat16, device=\"hpu\")\n        self.scale_other = scale_other\n\n    def quant_input(self, x, scale):\n        return torch.ops.hpu.cast_to_fp8_v2(\n            x, scale, False, False, torch.float8_e4m3fn\n        )[0]\n\n    def matmul_fp8(\n        self, x, other, out_dtype, scale_input_inv=None, scale_other_inv=None\n    ):\n        return torch.ops.hpu.fp8_gemm_v2(\n            A=x,\n            trans_A=False,\n            B=other,\n            trans_B=False,\n            D=None,\n            out_dtype=out_dtype,\n            A_scale_inv=scale_input_inv,\n            B_scale_inv=scale_other_inv,\n            bias=None,\n            accumulate=False,\n        )\n\n    def forward(self, input, other):\n        qinput = self.quant_input(input, self.scale_input)\n        qother = self.quant_input(other, self.scale_other)\n        output = self.matmul_fp8(\n            qinput,\n            qother,\n            out_dtype=torch.bfloat16,\n            scale_input_inv=1.0 / self.scale_input,\n            scale_other_inv=1.0 / self.scale_other,\n        )\n        return output\n\n\nclass FetchFromCache(torch.nn.Module):\n\n    def __init__(self, scale_inv):\n        super().__init__()\n        self.scale_inv = scale_inv\n\n    def forward(self, cache, blocks):\n        if 
os.environ.get(\"VLLM_CONTIGUOUS_PA\", \"true\").lower() == \"true\":\n            out = cache[: blocks.size(0)]\n        else:\n            out = cache.index_select(0, blocks)\n        if out.dtype == torch.float8_e4m3fn:\n            out = torch.ops.hpu.cast_from_fp8(out, self.scale_inv, torch.bfloat16)\n        return out\n\n\ndef attention(\n    *,\n    query: torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    kv_cache: KVCache,\n    kv_scales: KVScales,\n    seqlen: Seqlen,\n    softmax_scale: float,\n    window_size_left: int = -1,\n    causal: bool = True,\n    softcap: Optional[float] = None,\n):\n    fsdpa_op = ModuleFusedSDPA(FusedSDPA)\n    bs = seqlen.input_lengths.shape[0]\n    _, head_num, head_size = query.shape\n    _, kv_head_num, head_size = key.shape\n    query = query.view(bs, -1, head_num, head_size).transpose(1, 2)\n    key = key.view(bs, -1, kv_head_num, head_size).transpose(1, 2)\n    value = value.view(bs, -1, kv_head_num, head_size).transpose(1, 2)\n    attn_output = fsdpa_op(\n        query,\n        key,\n        value,\n        attn_mask=seqlen.attn_mask if window_size_left != -1 else None,\n        dropout_p=0.0,\n        is_causal=causal if window_size_left == -1 else False,\n        scale=softmax_scale,\n        softmax_mode=\"None\",\n        recompute_mode=None,\n        valid_sequence_lengths=seqlen.input_lengths if window_size_left == -1 else None,\n        padding_side=\"left\",\n    )\n    attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()\n    return attn_output\n\n\ndef set_block_mapping(hpu_attention_meta: HPUPagedAttentionMetadata, batch_size):\n    block_mapping = torch.nn.functional.one_hot(\n        hpu_attention_meta.block_groups, num_classes=batch_size\n    )\n    dtype = hpu_attention_meta.block_usage.dtype\n    device = hpu_attention_meta.block_usage.device\n    mask = torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int32).unsqueeze(0)\n    mask = mask >= 
hpu_attention_meta.block_usage.unsqueeze(-1)\n    attn_bias = torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -math.inf)\n    hpu_attention_meta = hpu_attention_meta._replace(\n        attn_bias=attn_bias, block_mapping=block_mapping.to(dtype)\n    )\n    if hpu_attention_meta.block_groups_in_window is not None:\n        block_mapping = torch.nn.functional.one_hot(\n            hpu_attention_meta.block_groups_in_window, num_classes=batch_size\n        )\n        attn_bias = torch.log(hpu_attention_meta.slots_in_window_mask.float())\n        hpu_attention_meta = hpu_attention_meta._replace(\n            attn_bias_in_window=attn_bias,\n            block_mapping_in_window=block_mapping.to(dtype),\n        )\n    return hpu_attention_meta\n\n\ndef paged_attention(\n    query: torch.Tensor,\n    kv_cache: KVCache,\n    kv_head_mapping: torch.Tensor,\n    softmax_scale: float,\n    seqlen: Seqlen,\n    *,\n    kv_scales: KVScales,\n    softcap: Optional[float] = None,\n    hpu_attention_meta: HPUPagedAttentionMetadata,\n    window_size_left: int = -1,\n):\n    batch_size, head_num, head_size = query.shape\n    fp8_kv = kv_cache.dtype == torch.float8_e4m3fn\n    output = ops.flat_pa(\n        query=query.view(batch_size, 1, head_num * head_size),\n        key_cache=kv_cache.key,\n        value_cache=kv_cache.value,\n        block_list=(\n            hpu_attention_meta.block_list\n            if window_size_left == -1\n            else hpu_attention_meta.block_list_in_window\n        ),\n        block_mapping=(\n            hpu_attention_meta.block_mapping\n            if window_size_left == -1\n            else hpu_attention_meta.block_mapping_in_window\n        ),\n        block_bias=(\n            hpu_attention_meta.attn_bias\n            if window_size_left == -1\n            else hpu_attention_meta.attn_bias_in_window\n        ),\n        block_groups=(\n            hpu_attention_meta.block_groups\n            if window_size_left == -1\n            else 
hpu_attention_meta.block_groups_in_window\n        ),\n        block_size=BLOCK_SIZE,\n        scale=softmax_scale,\n        matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),\n        matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),\n        batch2block_matmul_op=Matmul(),\n        block2batch_matmul_op=Matmul(),\n        keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu),\n        values_fetch_func=FetchFromCache(1.0 / kv_scales.value_scale_cpu),\n    )\n    # Reshape the output tensor.\n    return output.view(batch_size, head_num, head_size)\n\n\ndef paged_attention_mla(\n    query: torch.Tensor,\n    kv_cache: KVCache,\n    kv_head_mapping: torch.Tensor,\n    softmax_scale: float,\n    seqlen: Seqlen,\n    *,\n    kv_scales: KVScales,\n    softcap: Optional[float] = None,\n    hpu_attention_meta: HPUPagedAttentionMetadata,\n    kv_lora_rank: int = 0,\n):\n    batch_size, head_num, head_size = query.shape\n    fp8_kv = kv_cache.dtype == torch.float8_e4m3fn\n    output = ops.flat_pa_mla(\n        query=query,\n        key_cache=kv_cache.key,\n        value_cache=None,\n        block_list=hpu_attention_meta.block_list,\n        block_mapping=hpu_attention_meta.block_mapping,\n        block_bias=hpu_attention_meta.attn_bias,\n        block_groups=hpu_attention_meta.block_groups,\n        block_size=BLOCK_SIZE,\n        scale=softmax_scale,\n        matmul_qk_op=FP8Matmul(kv_scales.key_scale) if fp8_kv else Matmul(),\n        matmul_av_op=FP8Matmul(kv_scales.value_scale) if fp8_kv else Matmul(),\n        batch2block_matmul_op=Matmul(),\n        block2batch_matmul_op=Matmul(),\n        keys_fetch_func=FetchFromCache(1.0 / kv_scales.key_scale_cpu),\n        values_fetch_func=None,\n        kv_lora_rank=kv_lora_rank,\n    )\n    # Reshape the output tensor.\n    return output.view(batch_size, head_num, -1)\n\n\n__all__ = [\n    \"SUPPORTS_WINDOWING\",\n    \"attention\",\n    \"paged_attention\",\n    
\"paged_attention_mla\",\n    \"set_block_mapping\",\n]\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/layers/attention/kv_cache.py",
    "content": "from typing import Tuple\nfrom dataclasses import dataclass, field\n\nimport torch\n\nfrom text_generation_server.models.globals import BLOCK_SIZE\nfrom text_generation_server.utils.weights import Weights\n\n\n@dataclass\nclass KVScales:\n    \"\"\"\n    Key-value scales for FP8 KV cache.\n\n    This data class stores key and value scales both as a GPU tensor and\n    as a GPU float. This inconvenience is necessary because some functions\n    (e.g. scaling kernels) take scales as a GPU tensor, whereas others\n    (e.g. flashinfer) take scales as a CPU scalar.\n    \"\"\"\n\n    key_scale: torch.Tensor\n    value_scale: torch.Tensor\n    key_scale_cpu: float = field(init=False)\n    value_scale_cpu: float = field(init=False)\n\n    def __post_init__(self):\n        if self.key_scale.numel() != 1 or self.value_scale.numel() != 1:\n            raise ValueError(\"Key and value scales must be scalar tensors.\")\n\n        self.key_scale_cpu = self.key_scale.item()\n        self.value_scale_cpu = self.value_scale.item()\n\n\nclass KVCache:\n    \"\"\"\n    Key-value cache for attention layers.\n    \"\"\"\n\n    kv_cache: Tuple[torch.Tensor, torch.Tensor]\n\n    def __init__(\n        self,\n        *,\n        num_blocks: int,\n        num_heads: int,\n        head_size: int,\n        dtype: torch.dtype,\n        device: torch.device,\n    ):\n        \"\"\"Construct the key-value cache for a layer.\"\"\"\n        ## TODO FP8 kv cache support\n        if dtype is torch.float8_e5m2:\n            raise ValueError(\"torch.float8_e5m2 is not supported in hpu. 
\")\n\n        self.kv_cache = (\n            torch.zeros(\n                (num_blocks * BLOCK_SIZE, num_heads, head_size),\n                dtype=dtype,\n                device=device,\n            ),\n            torch.zeros(\n                (num_blocks * BLOCK_SIZE, num_heads, head_size),\n                dtype=dtype,\n                device=device,\n            ),\n        )\n\n    @property\n    def dtype(self):\n        \"\"\"Get the data type of the cache.\"\"\"\n        return self.kv_cache[0].dtype\n\n    @property\n    def key(self):\n        \"\"\"Get the key cache.\"\"\"\n\n        return self.kv_cache[0]\n\n    @property\n    def value(self):\n        \"\"\"Get the value cache.\"\"\"\n\n        return self.kv_cache[1]\n\n    def store(\n        self,\n        *,\n        key: torch.Tensor,\n        value: torch.Tensor,\n        slots: torch.Tensor,\n        kv_scales: KVScales,\n    ):\n        \"\"\"Store the key and value at the given slots.\"\"\"\n        ## TODO FP8 kv cache support\n\n        key_cache = self.kv_cache[0]\n        value_cache = self.kv_cache[1]\n\n        paged_reshape_and_cache(\n            key,\n            value,\n            key_cache,\n            value_cache,\n            slots,\n            kv_scales.key_scale,\n            kv_scales.value_scale,\n        )\n\n\nclass KVCompressCache(KVCache):\n    \"\"\"\n    Key-value cache for attention layers.\n    \"\"\"\n\n    kv_cache: torch.Tensor\n\n    def __init__(\n        self,\n        *,\n        num_blocks: int,\n        head_size: int,\n        dtype: torch.dtype,\n        device: torch.device,\n    ):\n        \"\"\"Construct the key-value cache for a layer.\"\"\"\n        ## TODO FP8 kv cache support\n        if dtype is torch.float8_e5m2:\n            raise ValueError(\"torch.float8_e5m2 is not supported in hpu. 
\")\n\n        self.kv_cache = torch.zeros(\n            (num_blocks * BLOCK_SIZE, 1, head_size),\n            dtype=dtype,\n            device=device,\n        )\n\n    @property\n    def dtype(self):\n        \"\"\"Get the data type of the cache.\"\"\"\n        return self.kv_cache.dtype\n\n    @property\n    def key(self):\n        \"\"\"Get the key cache.\"\"\"\n\n        return self.kv_cache\n\n    @property\n    def value(self):\n        \"\"\"Get the value cache.\"\"\"\n\n        return self.kv_cache\n\n    def store(\n        self,\n        *,\n        key: torch.Tensor,\n        value: torch.Tensor,\n        slots: torch.Tensor,\n        kv_scales: KVScales,\n    ):\n        \"\"\"Store the key and value at the given slots.\"\"\"\n        ## TODO FP8 kv cache support\n        if self.kv_cache.dtype == torch.float8_e4m3fn:\n            key = torch.ops.hpu.cast_to_fp8_v2(\n                key, kv_scales.key_scale, False, False, torch.float8_e4m3fn\n            )[0]\n        self.kv_cache.index_copy_(0, slots, key)\n\n\ndef paged_reshape_and_cache(\n    key: torch.Tensor,\n    value: torch.Tensor,\n    key_cache: torch.Tensor,\n    value_cache: torch.Tensor,\n    slots: torch.Tensor,\n    k_scale: torch.Tensor,\n    v_scale: torch.Tensor,\n):\n    if key_cache.dtype == torch.float8_e4m3fn:\n        key = torch.ops.hpu.cast_to_fp8_v2(\n            key, k_scale, False, False, torch.float8_e4m3fn\n        )[0]\n        value = torch.ops.hpu.cast_to_fp8_v2(\n            value, v_scale, False, False, torch.float8_e4m3fn\n        )[0]\n    key_cache.index_copy_(0, slots, key)\n    value_cache.index_copy_(0, slots, value)\n\n\ndef get_kv_scales(weights: Weights, prefix: str) -> KVScales:\n    \"\"\"Load KV cache scales.\"\"\"\n\n    key_scale = torch.tensor(1.0, dtype=torch.float32, device=weights.device)\n    value_scale = key_scale\n    if weights.has_tensor(f\"{prefix}.k_scale\") and weights.has_tensor(\n        f\"{prefix}.v_scale\"\n    ):\n        key_scale = 
weights.get_tensor(f\"{prefix}.k_scale\", to_dtype=False).float()\n        value_scale = weights.get_tensor(f\"{prefix}.v_scale\", to_dtype=False).float()\n    elif weights.has_tensor(f\"{prefix}.kv_scale\"):\n        # Fall back to older more coarse-grained scale when available.\n        key_scale = weights.get_tensor(f\"{prefix}.kv_scale\").float()\n        value_scale = key_scale\n\n    return KVScales(key_scale=key_scale, value_scale=value_scale)\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/layers/awq/conversion_utils.py",
    "content": "import torch\nfrom typing import List\n\n\nAWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]\nREVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]\n\n\ndef pack(imatrix: torch.Tensor, direction: str = \"column\"):\n    \"\"\"\n    Packs a 4-bit integer matrix into a packed 32-bit integer matrix.\n    Args:\n        imatrix (torch.Tensor): matrix of integers\n        direction (str): direction of packing, either \"column\" or \"row\"\n    Returns:\n        qmatrix (torch.Tensor): packed matrix of integers\n    \"\"\"\n    shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=imatrix.device)\n\n    imatrix = imatrix.to(torch.int8) & 0x0F  # eventually correct overflow\n\n    if direction == \"column\":\n        imatrix = imatrix.view(-1, imatrix.shape[1] // (32 // 4), (32 // 4))\n        qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1)\n\n    elif direction == \"row\":\n        imatrix = imatrix.view(imatrix.shape[0] // (32 // 4), (32 // 4), -1)\n        qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, :, None]).sum(dim=1)\n\n    qmatrix = qmatrix.to(torch.int32)\n\n    return qmatrix\n\n\ndef unpack(qmatrix: torch.Tensor, direction: str = \"column\"):\n    \"\"\"\n    Unpacks a 32-bit packed integer matrix into a 4-bit integer matrix.\n    Args:\n        qmatrix (torch.Tensor): matrix of packed integers\n        direction (str): direction of unpacking, either \"column\" or \"row\"\n    Returns:\n        imatrix (torch.Tensor): matrix of integers\n    \"\"\"\n    shifts = torch.arange(0, 32, 4, device=qmatrix.device)\n\n    if direction == \"column\":\n        imatrix = torch.bitwise_right_shift(\n            qmatrix[:, :, None], shifts[None, None, :]\n        ).view(qmatrix.shape[0], -1)\n\n    elif direction == \"row\":\n        imatrix = torch.bitwise_right_shift(\n            qmatrix[:, None, :], shifts[None, :, None]\n        ).view(-1, qmatrix.shape[-1])\n\n    imatrix = imatrix.to(torch.int8) & 0x0F  # eventually 
correct overflow\n\n    return imatrix\n\n\ndef apply_order(\n    imatrix: torch.Tensor,\n    direction: str = \"column\",\n    order: List[int] = AWQ_PACK_ORDER,\n):\n    \"\"\"\n    Applies the order to a 4-bit integer matrix.\n    Args:\n        imatrix (torch.Tensor): matrix of integers\n        direction (str): direction of applying order, either \"column\" or \"row\"\n        order (List[int]): order to apply, default is AWQ_PACK_ORDER\n    Returns:\n        imatrix (torch.Tensor): matrix of integers\n    \"\"\"\n    if direction == \"column\":\n        imatrix = imatrix.view(-1, (32 // 4))[:, order].view(imatrix.shape)\n    elif direction == \"row\":\n        imatrix = imatrix.view((32 // 4), -1)[order, :].view(imatrix.shape)\n\n    return imatrix\n\n\ndef fast_awq_to_gptq(qweight, qzeros):\n    # awq uses column packing for both weights and zeros\n    izeros = unpack(qzeros, direction=\"column\")\n    iweights = unpack(qweight, direction=\"column\")\n\n    # Reverse the order of the iweight and izeros tensors\n    izeros = apply_order(izeros, direction=\"column\", order=REVERSE_AWQ_PACK_ORDER)\n    iweights = apply_order(iweights, direction=\"column\", order=REVERSE_AWQ_PACK_ORDER)\n    # Subtract 1 from the izeros tensor (gptq adds 1 to the zeros)\n    izeros = izeros - 1\n    # exllama uses row packing for weights and column packing for zeros\n    qzeros = pack(izeros, direction=\"column\")\n    qweight = pack(iweights, direction=\"row\")\n\n    return qweight, qzeros\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/layers/awq/quantize/__init__.py",
    "content": "from .hpu import WQLinear\n\n__all__ = [\"WQLinear\"]\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/layers/awq/quantize/hpu.py",
    "content": "from typing import Optional\nimport torch\nimport torch.nn as nn\n\ntry:\n    import habana_frameworks.torch.hpu  # noqa: F401\n\n    convert_from_uint4 = torch.ops.hpu.convert_from_uint4\nexcept Exception as e:\n    hpu_import_exception = e\n\n    def error_raiser_hpu(*args, **kwargs):\n        raise ValueError(\n            f\"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}\"\n        )\n\n    convert_from_uint4 = error_raiser_hpu\n\nAWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]\n\n\ndef unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):\n    shifts = torch.arange(0, 32, bits, device=qzeros.device)\n\n    # unpacking columnwise\n    iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(\n        torch.int8  # smallest dtype available\n    )\n    iweights = iweights.view(iweights.shape[0], -1)\n\n    # unpacking columnwise\n    if qzeros is not None:\n        izeros = torch.bitwise_right_shift(\n            qzeros[:, :, None], shifts[None, None, :]\n        ).to(\n            torch.int8  # smallest dtype available\n        )\n        izeros = izeros.view(izeros.shape[0], -1)\n    else:\n        izeros = qzeros\n\n    return iweights, izeros\n\n\ndef reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int):\n    reverse_order_tensor = torch.arange(\n        iweights.shape[-1],\n        dtype=torch.int32,\n        device=izeros.device,\n    )\n    reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)\n    reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]\n    reverse_order_tensor = reverse_order_tensor.view(-1)\n\n    if izeros is not None:\n        izeros = izeros[:, reverse_order_tensor]\n    iweights = iweights[:, reverse_order_tensor]\n\n    return iweights, izeros\n\n\ndef unpack_weight_and_zeros(qweight, qzeros, bits):\n    # Unpack the qweight and qzeros tensors\n    iweight, izeros = 
unpack_awq(qweight, qzeros, bits)\n    # Reverse the order of the iweight and izeros tensors\n    iweight, izeros = reverse_awq_order(iweight, izeros, bits)\n\n    # overflow checks\n    iweight = torch.bitwise_and(iweight, (2**bits) - 1)\n    izeros = torch.bitwise_and(izeros, (2**bits) - 1)\n\n    return iweight, izeros\n\n\ndef pack_tensor(input, bits=4):\n    normal = input.to(torch.int32)\n    q = torch.zeros(\n        (normal.shape[0], normal.shape[1] // 32 * bits),\n        dtype=torch.int32,\n        device=input.device,\n    )\n    i = 0\n    col = 0\n    while col < q.shape[1]:\n        for j in range(i, i + (32 // bits)):\n            q[:, col] |= normal[:, j] << (bits * (j - i))\n        i += 32 // bits\n        col += 1\n    q = q.to(torch.int32)\n    return q\n\n\nclass WQLinear(nn.Module):\n    def __init__(\n        self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor]\n    ):\n        super().__init__()\n\n        if w_bit not in [4]:\n            raise NotImplementedError(\"Only 4-bit are supported for now.\")\n\n        self.in_features = qweight.shape[0]\n        self.out_features = qweight.shape[1] * 32 // w_bit\n\n        self.w_bit = w_bit\n        self.group_size = group_size if group_size != -1 else self.in_features\n        # quick sanity check (make sure aligment)\n        assert self.in_features % self.group_size == 0\n        assert self.out_features % (32 // self.w_bit) == 0\n\n        self.qweight = qweight\n        self.qzeros = qzeros\n        self.scales = scales\n        self.bias = bias\n        self._preprocessing()\n\n    def _preprocessing(self):\n        device = self.qweight.device\n        weight, zeros = unpack_weight_and_zeros(\n            self.qweight.cpu(), self.qzeros.cpu(), self.w_bit\n        )\n        self.qweight = pack_tensor(weight).to(device)\n        self.qzeros = pack_tensor(zeros).to(device)\n\n    @torch.no_grad()\n    def forward(self, x):\n        out_shape = x.shape[:-1] + 
(self.out_features,)\n        x = x.reshape(-1, x.shape[-1])\n        weights = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype)\n        outputs = torch.matmul(x, weights)\n\n        outputs = outputs + self.bias if self.bias is not None else outputs\n        outputs = outputs.reshape(out_shape)\n        return outputs\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/layers/bnb.py",
    "content": "from dataclasses import dataclass\n\nimport bitsandbytes as bnb\nimport torch\nfrom bitsandbytes.nn import Int8Params, Params4bit\nfrom text_generation_server.utils.weights import UnquantizedWeight\n\n\n@dataclass\nclass BNBWeight(UnquantizedWeight):\n    weight: torch.Tensor\n\n    def get_linear(self, bias: torch.Tensor):\n        return Linear8bitLt(self.weight, bias, has_fp16_weights=False, threshold=6.0)\n\n\nclass Linear8bitLt(torch.nn.Module):\n    def __init__(\n        self,\n        weight,\n        bias,\n        has_fp16_weights=True,\n        memory_efficient_backward=False,\n        threshold=0.0,\n        index=None,\n    ):\n        super().__init__()\n        assert (\n            not memory_efficient_backward\n        ), \"memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0\"\n        self.state = bnb.MatmulLtState()\n        self.index = index\n\n        # Necessary for stacked layers\n        self.state.threshold = threshold\n        self.state.has_fp16_weights = has_fp16_weights\n        self.state.memory_efficient_backward = memory_efficient_backward\n        if threshold > 0.0 and not has_fp16_weights:\n            self.state.use_pool = True\n\n        self.weight = Int8Params(\n            weight.data,\n            has_fp16_weights=has_fp16_weights,\n            requires_grad=has_fp16_weights,\n        )\n        self.weight.cuda(weight.device)\n        self.bias = bias\n\n    def init_8bit_state(self):\n        self.state.CB = self.weight.CB\n        self.state.SCB = self.weight.SCB\n        self.weight.CB = None\n        self.weight.SCB = None\n\n    def forward(self, x: torch.Tensor):\n        self.state.is_training = self.training\n        if self.weight.CB is not None:\n            self.init_8bit_state()\n\n        # weights are cast automatically as Int8Params, but the bias has to be cast manually\n        if self.bias is not None and self.bias.dtype != 
x.dtype:\n            self.bias.data = self.bias.data.to(x.dtype)\n\n        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)\n\n        if not self.state.has_fp16_weights:\n            if self.state.CB is not None and self.state.CxB is not None:\n                # we converted 8-bit row major to turing/ampere format in the first inference pass\n                # we no longer need the row-major weight\n                del self.state.CB\n                self.weight.data = self.state.CxB\n        return out\n\n\n@dataclass\nclass BNBFP4Weight(UnquantizedWeight):\n    weight: torch.Tensor\n\n    def get_linear(self, bias: torch.Tensor):\n        return Linear4bit(self.weight, bias, quant_type=\"fp4\")\n\n\n@dataclass\nclass BNBNF4Weight(UnquantizedWeight):\n    weight: torch.Tensor\n\n    def get_linear(self, bias: torch.Tensor):\n        return Linear4bit(self.weight, bias, quant_type=\"nf4\")\n\n\nclass Linear4bit(torch.nn.Module):\n    def __init__(self, weight, bias, quant_type):\n        super().__init__()\n        self.weight = Params4bit(\n            weight.data,\n            requires_grad=False,\n            compress_statistics=True,\n            quant_type=quant_type,\n        )\n        self.compute_dtype = None\n        self.weight.cuda(weight.device)\n        self.bias = bias\n\n    def forward(self, x: torch.Tensor):\n        # weights are cast automatically as Int8Params, but the bias has to be cast manually\n        if self.bias is not None and self.bias.dtype != x.dtype:\n            self.bias.data = self.bias.data.to(x.dtype)\n\n        if getattr(self.weight, \"quant_state\", None) is None:\n            print(\n                \"FP4 quantization state not initialized. 
Please call .cuda() or .to(device) on the LinearFP4 layer first.\"\n            )\n        inp_dtype = x.dtype\n        if self.compute_dtype is not None:\n            x = x.to(self.compute_dtype)\n\n        bias = None if self.bias is None else self.bias.to(self.compute_dtype)\n        out = bnb.matmul_4bit(\n            x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state\n        )\n\n        out = out.to(inp_dtype)\n\n        return out\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/layers/compressed_tensors/__init__.py",
    "content": "from .loader import CompressedTensorsLoader\n\n__all__ = [\"CompressedTensorsLoader\"]\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/layers/compressed_tensors/loader.py",
    "content": "from typing import Any, Dict, List, Union\n\nfrom compressed_tensors import QuantizationConfig, QuantizationStatus\nfrom compressed_tensors.config import CompressionFormat\nfrom compressed_tensors.quantization import (\n    QuantizationScheme,\n    QuantizationType,\n    find_name_or_class_matches,\n)\nfrom loguru import logger\nfrom pydantic import ValidationError\nfrom torch import nn\n\nfrom text_generation_server.layers.compressed_tensors.w8an_fp import W8ANFpLoader\nfrom text_generation_server.utils.log import log_once\nfrom text_generation_server.utils.weights import (\n    DefaultWeightsLoader,\n    UnquantizedWeight,\n    Weights,\n    WeightsLoader,\n)\n\n# compressed-tensors can match modules as quantization targets. However,\n# they need to be objects rather than classes or class names. Since we\n# need to match `Linear` targets, make an instance that can be re-used.\n_EMPTY_LINEAR: nn.Module = nn.Linear(0, 0)\n\n\nclass CompressedTensorsLoader(WeightsLoader):\n    \"\"\"Loader for checkpoints stored in the compressed-tensors format.\"\"\"\n\n    def __init__(self, config: Dict[str, Any]):\n        quantization_config_raw = config.get(\"quantization_config\")\n        if quantization_config_raw is None:\n            # `compression_config` was renamed to `quantization_config`; support\n            # retained for backward compatibility.\n            quantization_config_raw = config.get(\"compression_config\")\n        if quantization_config_raw is None:\n            raise ValueError(\n                \"Checkpoint does not have compressed-tensors configuration\"\n            )\n\n        try:\n            quantization_config = QuantizationConfig.model_validate(\n                quantization_config_raw\n            )\n        except ValidationError as e:\n            raise ValueError(\"Cannot parse compressed-tensors configuration\") from e\n\n        if quantization_config.quantization_status not in (\n            
QuantizationStatus.COMPRESSED,\n            QuantizationStatus.FROZEN,\n        ):\n            raise ValueError(\n                f\"Model quantization was not finished, status was: {quantization_config.quantization_status}\"\n            )\n\n        self.ignore = (\n            quantization_config.ignore if quantization_config.ignore is not None else []\n        )\n        self.loaders = self._get_target_loaders(quantization_config)\n\n        for target, loader in self.loaders.items():\n            log_once(\n                logger.info,\n                f\"Using {loader} for compressed-tensors target '{target}'\",\n            )\n\n    def get_weights(self, weights: Weights, prefix: str):\n        loader = self._lookup_loader(prefix)\n        return loader.get_weights(weights, prefix)\n\n    def get_weights_col_packed(\n        self,\n        weights: \"Weights\",\n        prefix: str,\n        block_sizes: Union[int, List[int]],\n    ):\n        loader = self._lookup_loader(prefix)\n        return loader.get_weights_col_packed(weights, prefix, block_sizes)\n\n    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):\n        loader = self._lookup_loader(prefixes[0])\n        return loader.get_multi_weights_col(weights, prefixes, dim)\n\n    def get_multi_weights(self, weights: Weights, prefixes: List[str], dim: int):\n        loader = self._lookup_loader(prefixes[0])\n        return loader.get_multi_weights(weights, prefixes, dim)\n\n    def get_weights_row(self, weights: Weights, prefix: str):\n        loader = self._lookup_loader(prefix)\n        return loader.get_weights_row(weights, prefix)\n\n    def _get_target_loaders(\n        self, quantization_config: QuantizationConfig\n    ) -> Dict[str, WeightsLoader]:\n        \"\"\"\n        A compressed-tensors checkpoint can use different quantizations\n        for different targets. 
This method returns a dictionary with a\n        loader per target.\n        \"\"\"\n\n        loaders: Dict[str, WeightsLoader] = {}\n\n        format = quantization_config.format\n\n        for group_name, group in quantization_config.config_groups.items():\n            # The group configuration can be a string, but does that ever\n            # happen in a serialized quantization config?\n            assert isinstance(group, QuantizationScheme)\n\n            loader = self._create_loader_for_group(format, group_name, group)\n\n            # A quantized parameter group can have multiple targets, add the\n            # loader for all the targets.\n            for target in group.targets:\n                if target in loaders:\n                    raise ValueError(\n                        f\"Target '{target} has multiple configured loaders'\"\n                    )\n                loaders[target] = loader\n\n        return loaders\n\n    def _create_loader_for_group(\n        self, format: str, group_name: str, group: QuantizationScheme\n    ) -> WeightsLoader:\n        \"\"\"\n        Find and create a loader for the group with the given quantization\n        scheme.\n        \"\"\"\n        # NOTE: we ignore group.output_activations because we don't support\n        #       output quantization yet.\n\n        input_activations = group.input_activations\n        weights = group.weights\n        if (\n            format\n            in {\n                CompressionFormat.float_quantized.value,\n                CompressionFormat.naive_quantized.value,\n            }\n            and weights is not None\n            and weights.type == QuantizationType.FLOAT\n            and weights.num_bits == 8\n        ):\n            # FP W8A8 or W8A16.\n            return W8ANFpLoader(input_activations=input_activations, weights=weights)\n        else:\n            raise ValueError(\n                f\"Group '{group_name}' has unsupported compressed-tensors configurtion\"\n   
         )\n\n    def _lookup_loader(self, prefix: str) -> WeightsLoader:\n        \"\"\"\n        Look up the loader to use for a given parameter name (prefix).\n        \"\"\"\n\n        if len(find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.ignore)) > 0:\n            return DefaultWeightsLoader(UnquantizedWeight)\n\n        # We currently only handle linear layers, so unconditionally pass\n        # a `Linear` instance.\n        targets = find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.loaders.keys())\n        if len(targets) == 0:\n            raise ValueError(\n                f\"Cannot find compressed-tensors target for prefix: {prefix}\"\n            )\n        return self.loaders[targets[0]]\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/layers/compressed_tensors/w8an_fp.py",
    "content": "from typing import List, Optional, Union\n\nimport torch\nfrom compressed_tensors.quantization import QuantizationArgs, QuantizationType\n\nfrom text_generation_server.layers.fp8 import (\n    Fp8Weight,\n    _load_scalar_or_matrix_scale,\n    requantize_with_max_scale,\n)\nfrom text_generation_server.utils.weights import Weights, WeightsLoader\n\n\nclass W8ANFpLoader(WeightsLoader):\n    \"\"\"\n    Loader for W8A8/W8A16 FP compressed-tensors parameters.\n    \"\"\"\n\n    def __init__(\n        self,\n        *,\n        input_activations: Optional[QuantizationArgs],\n        weights: QuantizationArgs,\n    ):\n        assert weights.type == QuantizationType.FLOAT and weights.num_bits == 8\n\n        # We ignore the `strategy` option which sets the scales to be\n        # per-tensor, per-channel or per-token. What scales are supported\n        # is dependent on the kernels used (e.g. cutlass can do tokenwise,\n        # Torch cannot, and FP8-Marlin does not quantize inputs at all).\n        # So, instead we try to use the best-possible configuration.\n\n        self.load_weight_scale = not weights.dynamic\n        self.load_input_scale = (\n            input_activations is not None and not input_activations.dynamic\n        )\n        self.force_w8a16 = (\n            input_activations is not None and input_activations.num_bits == 16\n        )\n\n    def __str__(self) -> str:\n        def scale_to_str(scale):\n            return \"static\" if scale else \"dynamic\"\n\n        quantization_type = f\"W8A{16 if self.force_w8a16 else 8}\"\n\n        return f\"{self.__class__.__name__} ({quantization_type}, weight: {scale_to_str(self.load_weight_scale)}, input: {scale_to_str(self.load_input_scale)})\"\n\n    def get_weights(self, weights: \"Weights\", prefix: str):\n        w = weights.get_tensor(f\"{prefix}.weight\")\n\n        weight_scale = None\n        if self.load_weight_scale:\n            weight_scale = (\n                
weights.get_tensor(f\"{prefix}.weight_scale\", to_dtype=False)\n                .reshape(-1)\n                .expand(w.shape[0])\n            )\n            logical_widths = [w.shape[0]]\n            w, weight_scale = requantize_with_max_scale(\n                w,\n                weight_scale.unsqueeze(-1).to(weights.device),\n                logical_widths,\n                weights.dtype,\n            )\n\n        input_scale = None\n        if self.load_input_scale:\n            input_scale = weights.get_tensor(\n                f\"{prefix}.input_scale\", to_dtype=False\n            ).reshape(-1)\n\n        return Fp8Weight(\n            weight=w,\n            weight_scale=weight_scale,\n            input_scale=input_scale,\n            dtype=weights.dtype,\n            force_w8a16=self.force_w8a16,\n        )\n\n    def get_weights_col_packed(\n        self,\n        weights: Weights,\n        prefix: str,\n        block_sizes: Union[int, List[int]],\n    ):\n        w = weights.get_packed_sharded(\n            f\"{prefix}.weight\", dim=0, block_sizes=block_sizes\n        )\n\n        weight_scale = None\n        if self.load_weight_scale:\n            weight_scale = weights.get_tensor(f\"{prefix}.weight_scale\", to_dtype=False)\n            if weight_scale.numel() > 1:\n                weight_scale = weights.get_packed_sharded(\n                    f\"{prefix}.weight_scale\",\n                    dim=0,\n                    block_sizes=block_sizes,\n                    to_dtype=False,\n                )\n            weight_scale = weight_scale.reshape(-1).expand(w.shape[0])\n            logical_widths = [w.shape[0]]\n            w, weight_scale = requantize_with_max_scale(\n                w,\n                weight_scale.unsqueeze(-1).to(weights.device),\n                logical_widths,\n                weights.dtype,\n            )\n\n        input_scale = None\n        if self.load_input_scale:\n            input_scale = 
weights.get_tensor(f\"{prefix}.input_scale\", to_dtype=False)\n            if input_scale.numel() > 1:\n                input_scale = weights.get_packed_sharded(\n                    f\"{prefix}.input_scale\",\n                    dim=0,\n                    block_sizes=block_sizes,\n                    to_dtype=False,\n                )\n            input_scale = input_scale.reshape(-1).max()\n\n        return Fp8Weight(\n            weight=w,\n            weight_scale=weight_scale,\n            input_scale=input_scale,\n            dtype=weights.dtype,\n            force_w8a16=self.force_w8a16,\n        )\n\n    def get_multi_weights_col(self, weights: \"Weights\", prefixes: List[str], dim: int):\n        # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet\n        w = [\n            weights.get_sharded(f\"{p}.weight\", dim=0, to_device=False) for p in prefixes\n        ]\n        shapes = [x.shape for x in w]\n\n        # Concat then send to the device\n        w = torch.cat(w, dim=dim).to(weights.device)\n\n        weight_scale = None\n        if self.load_weight_scale:\n            weight_scale = [\n                _load_scalar_or_matrix_scale(weights, f\"{p}.weight_scale\", shape)\n                for p, shape in zip(prefixes, shapes)\n            ]\n            weight_scale = torch.cat(weight_scale, dim=0).reshape(-1)\n            logical_widths = [x[0] for x in shapes]\n            w, weight_scale = requantize_with_max_scale(\n                w,\n                weight_scale.unsqueeze(-1).to(weights.device),\n                logical_widths,\n                weights.dtype,\n            )\n\n        input_scale = None\n        if self.load_input_scale:\n            input_scale = [\n                _load_scalar_or_matrix_scale(weights, f\"{p}.input_scale\", shape)\n                for p, shape in zip(prefixes, shapes)\n                if weights.has_tensor(f\"{p}.input_scale\")\n            ]\n            assert 
len(input_scale) == 0 or len(input_scale) == len(prefixes)\n            input_scale = (\n                torch.cat(input_scale, dim=0).reshape(-1).max()\n                if len(input_scale) != 0\n                else None\n            )\n\n        return Fp8Weight(\n            weight=w,\n            weight_scale=weight_scale,\n            input_scale=input_scale,\n            dtype=weights.dtype,\n            force_w8a16=self.force_w8a16,\n        )\n\n    def get_multi_weights(self, weights: \"Weights\", prefixes: List[str], dim: int):\n        # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet\n        w = [weights.get_tensor(f\"{p}.weight\", to_device=False) for p in prefixes]\n        shapes = [x.shape for x in w]\n\n        # Concat then send to the device\n        w = torch.cat(w, dim=dim).to(weights.device)\n\n        weight_scale = None\n\n        if self.load_weight_scale:\n            weight_scale = [\n                weights.get_tensor(f\"{p}.weight_scale\", to_dtype=False)\n                .reshape(-1)\n                .expand(shape[0])\n                for p, shape in zip(prefixes, shapes)\n            ]\n            weight_scale = torch.cat(weight_scale, dim=0).reshape(-1)\n            logical_widths = [x[0] for x in shapes]\n            w, weight_scale = requantize_with_max_scale(\n                w,\n                weight_scale.unsqueeze(-1).to(weights.device),\n                logical_widths,\n                weights.dtype,\n            )\n\n        input_scale = None\n        if self.load_input_scale:\n            input_scale = [\n                weights.get_tensor(f\"{p}.input_scale\", to_dtype=False)\n                .reshape(-1)\n                .expand(shape[0])\n                for p, shape in zip(prefixes, shapes)\n                if weights.has_tensor(f\"{p}.input_scale\")\n            ]\n            assert len(input_scale) == 0 or len(input_scale) == len(prefixes)\n            input_scale = (\n        
        torch.cat(input_scale, dim=0).reshape(-1).max()\n                if len(input_scale) != 0\n                else None\n            )\n\n        return Fp8Weight(\n            weight=w,\n            weight_scale=weight_scale,\n            input_scale=input_scale,\n            dtype=weights.dtype,\n            force_w8a16=self.force_w8a16,\n        )\n\n    def get_weights_row(self, weights: \"Weights\", prefix: str):\n        w = weights.get_sharded(f\"{prefix}.weight\", dim=1)\n        weight_scale = None\n        if self.load_weight_scale:\n            weight_scale = weights.get_tensor(f\"{prefix}.weight_scale\", to_dtype=False)\n            weight_scale = weight_scale.reshape(-1).expand(w.shape[0])\n            logical_widths = [w.shape[0]]\n            w, weight_scale = requantize_with_max_scale(\n                w,\n                weight_scale.unsqueeze(-1).to(weights.device),\n                logical_widths,\n                weights.dtype,\n            )\n\n        input_scale = None\n        if self.load_input_scale:\n            input_scale = weights.get_tensor(\n                f\"{prefix}.input_scale\", to_dtype=False\n            ).reshape(-1)\n\n        return Fp8Weight(\n            weight=w,\n            weight_scale=weight_scale,\n            input_scale=input_scale,\n            dtype=weights.dtype,\n            force_w8a16=self.force_w8a16,\n        )\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/layers/conv.py",
    "content": "from accelerate import init_empty_weights\nimport torch\n\n\n@classmethod\ndef load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):\n    weight = weights.get_tensor(f\"{prefix}.weight\")\n    bias = weights.get_tensor(f\"{prefix}.bias\")\n    with init_empty_weights():\n        conv2d = cls(\n            in_channels=in_channels,\n            out_channels=out_channels,\n            kernel_size=kernel_size,\n            stride=stride,\n        )\n\n    conv2d.weight = torch.nn.Parameter(weight)\n    conv2d.bias = torch.nn.Parameter(bias)\n    return conv2d\n\n\n@classmethod\ndef load_conv2d_no_bias(\n    cls, prefix, weights, in_channels, out_channels, kernel_size, stride\n):\n    weight = weights.get_tensor(f\"{prefix}.weight\")\n    with init_empty_weights():\n        conv2d = cls(\n            in_channels=in_channels,\n            out_channels=out_channels,\n            kernel_size=kernel_size,\n            stride=stride,\n        )\n\n    conv2d.weight = torch.nn.Parameter(weight)\n    conv2d.bias = None\n    return conv2d\n\n\ntorch.nn.Conv2d.load = load_conv2d\ntorch.nn.Conv2d.load_no_bias = load_conv2d_no_bias\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/layers/exl2.py",
    "content": "from dataclasses import dataclass\nfrom typing import List, Union\n\nimport torch\nfrom text_generation_server.utils.weights import Weight, Weights, WeightsLoader\n\n\n@dataclass\nclass Exl2Weight(Weight):\n    \"\"\"\n    Exllama2 exl2 quantized weights.\n    \"\"\"\n\n    q_weight: torch.Tensor\n    q_scale: torch.Tensor\n    q_invperm: torch.Tensor\n    q_scale_max: torch.Tensor\n    q_groups: torch.Tensor\n\n    def __post_init__(self):\n        self.q_scale_max /= 256\n        self.q_invperm = self.q_invperm.short()\n\n    @property\n    def device(self) -> torch.device:\n        return self.q_weight.device\n\n    def get_linear(self, bias: torch.Tensor):\n        from text_generation_server.layers.gptq import ExllamaQuantLinear\n\n        return ExllamaQuantLinear(self, bias)\n\n\nclass Exl2WeightsLoader(WeightsLoader):\n    \"\"\"Loader for exl2-quantized weights.\"\"\"\n\n    def get_weights(self, weights: \"Weights\", prefix: str):\n        \"\"\"\n        Get weights at the given prefix and apply without tensor paralllism.\n        \"\"\"\n        try:\n            q_weight = weights.get_tensor(f\"{prefix}.q_weight\")\n        except RuntimeError:\n            raise RuntimeError(\n                \"Cannot load `exl2`-quantized weight, make sure the model is already quantized.\"\n            )\n\n        q_scale = weights.get_tensor(f\"{prefix}.q_scale\")\n        q_invperm = weights.get_tensor(f\"{prefix}.q_invperm\")\n        q_scale_max = weights.get_tensor(f\"{prefix}.q_scale_max\")\n        q_groups = weights.get_tensor(f\"{prefix}.q_groups\")\n\n        return Exl2Weight(\n            q_weight=q_weight,\n            q_scale=q_scale,\n            q_invperm=q_invperm,\n            q_scale_max=q_scale_max,\n            q_groups=q_groups,\n        )\n\n    def get_weights_col_packed(\n        self,\n        weights: Weights,\n        prefix: str,\n        block_sizes: Union[int, List[int]],\n    ):\n        raise 
RuntimeError(\"Column-packed weights are not supported for exl\")\n\n    def get_weights_col(self, weights: Weights, prefix: str):\n        # Sharding is not yet supported, so we return the weights as-is.\n        return self.get_weights(weights, prefix)\n\n    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):\n        raise ValueError(\"get_multi_weights_col is not supported for exl2\")\n\n    def get_weights_row(self, weights: Weights, prefix: str):\n        # Sharding is not yet supported, so we return the weights as-is.\n        return self.get_weights(weights, prefix)\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/layers/fp8.py",
    "content": "from dataclasses import dataclass\nfrom typing import Optional, Tuple, Type, Union, List\n\nimport torch\n\nfrom text_generation_server.utils.weights import (\n    Weight,\n    WeightsLoader,\n    UnquantizedWeight,\n    Weights,\n)\n\nfrom vllm_hpu_extension.ops import scaled_fp8_quant\nfrom vllm_hpu_extension.scales import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2\n\nquant_dtype: torch.dtype = torch.float8_e4m3fn\nFP8_MAX = torch.finfo(torch.float8_e4m3fn).max\nif is_hpu_gaudi2():\n    FP8_MAX = torch.finfo(torch.float8_e4m3fnuz).max\n\n\ndef pad_weight(weight, block_size):\n    \"\"\"Pads a matrix to make its dimensions multiples of block_size.\"\"\"\n    M, N = weight.shape[-2:]\n    block_size_m, block_size_n = block_size\n    pad_M = (block_size_m - M % block_size_m) % block_size_m\n    pad_N = (block_size_n - N % block_size_n) % block_size_n\n\n    if pad_M == 0 and pad_N == 0:\n        return weight, M, N  # No padding needed\n    padded_weight = torch.nn.functional.pad(\n        weight, (0, pad_N, 0, pad_M), mode=\"constant\", value=0\n    )\n    return padded_weight, M, N  # Return original dimensions for unpadding\n\n\ndef unpad_weight(weight, original_M, original_N, keep_first_dim=False):\n    \"\"\"Removes padding from the matrix to restore its original shape.\"\"\"\n    if (weight.shape[-2] == original_M) and (weight.shape[-1] == original_N):\n        return weight\n    if keep_first_dim:\n        return weight[:, :original_M, :original_N]\n    else:\n        return weight[:original_M, :original_N]\n\n\ndef pad_block_fp8_weight_naive(weight, weight_scale, block_size):\n\n    assert len(block_size) == 2\n\n    block_size_m, block_size_n = block_size\n    weight_scale_m, weight_scale_n = weight_scale.shape[-2:]\n\n    weight, orig_M, orig_N = pad_weight(weight, block_size)\n    M, N = weight.shape[-2:]\n\n    assert weight_scale_m == M // block_size_m\n    assert weight_scale_n == N // block_size_n\n\n    return weight, orig_M, 
orig_N\n\n\ndef dynamic_quant(data, single_scale=False):\n    if single_scale:\n        scale = ((torch.abs(data)).max() + 1e-8) / FP8_MAX\n    else:\n        scale = ((torch.abs(data)).max(dim=-1).values + 1e-8) / FP8_MAX\n        scale = scale.unsqueeze(-1)\n    data_fp8 = torch.ops.hpu.cast_to_fp8_v2(\n        data, 1.0 / scale, False, False, torch.float8_e4m3fn\n    )[0]\n    return data_fp8, scale.float()\n\n\ndef dequant_block_fp8_weight_naive(\n    weight,\n    weight_scale,\n    block_size,\n    dtype=torch.bfloat16,\n    original_M=None,\n    original_N=None,\n    do_unpad=False,\n):\n    if weight_scale is None:\n        return weight\n    assert len(block_size) == 2\n\n    weight_shape_len = len(weight.shape)\n\n    block_size_m, block_size_n = block_size\n\n    # mul scale\n    if weight_shape_len == 2:\n        weight_scale_m, weight_scale_n = weight_scale.shape\n        weight_scale = weight_scale.view(weight_scale_m, 1, weight_scale_n, 1)\n        weight = weight.view(weight_scale_m, block_size_m, weight_scale_n, block_size_n)\n        if is_hpu_gaudi2():\n            fake_weight = weight.cpu().to(dtype).to(weight.device)\n            dequant_weight = fake_weight * weight_scale.to(dtype)\n        else:\n            dequant_weight = weight.to(dtype) * weight_scale.to(dtype)\n        dequant_weight = dequant_weight.view(\n            weight_scale_m * block_size_m, weight_scale_n * block_size_n\n        )\n        keep_first_dim = False\n    elif weight_shape_len == 3:\n        fd, weight_scale_m, weight_scale_n = weight_scale.shape\n        weight_scale = weight_scale.view(fd, weight_scale_m, 1, weight_scale_n, 1)\n        weight = weight.view(\n            fd, weight_scale_m, block_size_m, weight_scale_n, block_size_n\n        )\n        if is_hpu_gaudi2():\n            fake_weight = weight.cpu().to(dtype).to(weight.device)\n            dequant_weight = fake_weight * weight_scale.to(dtype)\n        else:\n            dequant_weight = weight.to(dtype) 
* weight_scale.to(dtype)\n        dequant_weight = dequant_weight.view(\n            fd, weight_scale_m * block_size_m, weight_scale_n * block_size_n\n        )\n        keep_first_dim = True\n    else:\n        raise ValueError(\"Only support original weight shape is either 2 or 3\")\n\n    if do_unpad:\n        dequant_weight = unpad_weight(\n            dequant_weight, original_M, original_N, keep_first_dim=keep_first_dim\n        )\n\n    return dequant_weight\n\n\ndef apply_block_fp8_linear_hpu_dynamic(\n    input: torch.Tensor,\n    weight: torch.Tensor,\n    weight_scale: torch.Tensor,\n    input_scale: Optional[torch.Tensor] = None,\n    bias: Optional[torch.Tensor] = None,\n) -> torch.Tensor:\n    # View input as 2D matrix for fp8 methods\n    input_2d = input.view(-1, input.shape[-1])\n    output_shape = [*input.shape[:-1], weight.shape[0]]\n\n    x_fp8, x_scale = dynamic_quant(input_2d)\n\n    output = torch.ops.hpu.fp8_gemm_v2(\n        x_fp8,\n        False,\n        weight,\n        True,\n        None,\n        torch.bfloat16,\n        x_scale,\n        weight_scale,\n        None,\n        False,\n    )\n    if bias is not None:\n        output = output + bias\n    return output.to(dtype=input.dtype).view(*output_shape)\n\n\ndef get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:\n    \"\"\"\n    Return an FP8 linear `Module` that is compatible with the current system.\n    \"\"\"\n    # On other systems let Torch decide if the hardware supports FP8.\n    return Fp8Linear\n\n\ndef normalize_e4m3fn_to_native_float8(\n    weight: torch.Tensor,\n    weight_scale: torch.Tensor,\n    input_scale: Optional[torch.Tensor] = None,\n) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:\n    return weight, weight_scale, input_scale\n\n\ndef per_tensor_dequantize(\n    tensor: torch.Tensor,\n    inv_scale: Union[float, torch.Tensor],\n    dtype: torch.dtype = torch.float16,\n) -> torch.Tensor:\n    device = tensor.device\n    dtype = 
torch.bfloat16\n    if is_hpu_gaudi2():\n        # dequant on cpu to avoid nan on gaudi2\n        tensor = tensor.to(\"cpu\")\n\n    fake_qweight = tensor.to(dtype).to(device)\n    dq_weight = fake_qweight * inv_scale\n    return dq_weight\n\n\ndef requantize_with_max_scale(\n    weight: torch.Tensor,\n    weight_scale: torch.Tensor,\n    logical_widths: int,\n    dtype: torch.dtype,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    # Max scale to be used for requanitzation.\n    max_w_scale = weight_scale.max()\n\n    if is_hpu_gaudi2():\n        max_w_scale = max_w_scale * get_hpu_gaudi2_scale_factor()\n\n    start = 0\n    for idx, logical_width in enumerate(logical_widths):\n        end = start + logical_width\n        weight_dq = per_tensor_dequantize(\n            weight[start:end, :], weight_scale[start:end, :], dtype\n        )\n        weight[start:end, :], max_w_scale_normalized = fp8_quantize(\n            weight_dq, max_w_scale\n        )\n        start = end\n\n    return weight, max_w_scale_normalized\n\n\ndef fp8_quantize(\n    weight: torch.Tensor,\n    scale: Optional[torch.Tensor] = None,\n    scale_upper_bound: Optional[torch.Tensor] = None,\n    qdtype: torch.dtype = torch.float8_e4m3fn,\n    scalar: bool = False,\n):\n    \"\"\"\n    This function returns a reciprocal of the scale, so that a tensor can be unscaled\n    by multiplying it with the returned scale. 
If a scale is given through the `scale`\n    argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can\n    be used without modification).\n    \"\"\"\n    shape = weight.shape\n    qweight, scale = scaled_fp8_quant(\n        weight.reshape(-1, shape[-1]),\n        scale=scale,\n        scale_ub=scale_upper_bound,\n        # TODO: don't do this when we have to use the Torch kernel.\n        use_per_token_if_dynamic=not scalar,\n    )\n\n    return qweight.reshape(shape), scale\n\n\nclass HybridFP8UnquantLoader(WeightsLoader):\n    \"\"\"Weight loader that loads FP8 and unquantized Torch tensors.\"\"\"\n\n    def __init__(\n        self,\n        activation_scale_ub: Optional[float],\n        to_fp8: bool,\n        weight_block_size: Optional[List[int]] = None,\n    ):\n        self.activation_scale_ub = activation_scale_ub\n        self.to_fp8 = to_fp8\n        self.weight_block_size = weight_block_size\n\n    def get_weights(self, weights: \"Weights\", prefix: str):\n        w = weights.get_tensor(f\"{prefix}.weight\")\n\n        if w.dtype == torch.float8_e4m3fn:\n            if self.weight_block_size is not None:\n                scale = weights.get_tensor(f\"{prefix}.weight_scale_inv\")\n                return Fp8Weight(\n                    weight=w,\n                    weight_scale=scale,\n                    activation_scale_ub=self.activation_scale_ub,\n                    dtype=weights.dtype,\n                    weight_block_size=self.weight_block_size,\n                )\n            # FP8 branch\n            scale = weights.get_tensor(f\"{prefix}.weight_scale\", to_dtype=False)\n            scale = scale.reshape(-1).expand(w.shape[0])\n            logical_widths = [w.shape[0]]\n            w, scale = requantize_with_max_scale(\n                w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype\n            )\n\n            input_scale = None\n            if weights.has_tensor(f\"{prefix}.input_scale\"):\n   
             input_scale = (\n                    weights.get_tensor(f\"{prefix}.input_scale\", to_dtype=False)\n                    .reshape(-1)\n                    .max()\n                )\n\n            return Fp8Weight(\n                weight=w,\n                weight_scale=scale,\n                input_scale=input_scale,\n                activation_scale_ub=self.activation_scale_ub,\n                dtype=weights.dtype,\n            )\n        if self.to_fp8:\n            return Fp8Weight(weight=w, dtype=weights.dtype)\n\n        return UnquantizedWeight(w)\n\n    def get_weights_col_packed(\n        self,\n        weights: Weights,\n        prefix: str,\n        block_sizes: Union[int, List[int]],\n    ):\n        w = weights.get_packed_sharded(\n            f\"{prefix}.weight\", dim=0, block_sizes=block_sizes\n        )\n\n        if w.dtype == torch.float8_e4m3fn:\n            # FP8 branch\n            scale = weights.get_tensor(f\"{prefix}.weight_scale\", to_dtype=False)\n\n            if scale.numel() > 1:\n                scale = weights.get_packed_sharded(\n                    f\"{prefix}.weight_scale\",\n                    dim=0,\n                    block_sizes=block_sizes,\n                    to_dtype=False,\n                )\n            scale = scale.reshape(-1).expand(w.shape[0])\n            logical_widths = [w.shape[0]]\n            w, scale = requantize_with_max_scale(\n                w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype\n            )\n\n            input_scale = None\n            if weights.has_tensor(f\"{prefix}.input_scale\"):\n                input_scale = weights.get_tensor(\n                    f\"{prefix}.input_scale\", to_dtype=False\n                )\n                if input_scale.numel() > 1:\n                    input_scale = weights.get_packed_sharded(\n                        f\"{prefix}.input_scale\",\n                        dim=0,\n                        block_sizes=block_sizes,\n  
                      to_dtype=False,\n                    )\n                input_scale = input_scale.reshape(-1).max()\n\n            return Fp8Weight(\n                weight=w,\n                weight_scale=scale,\n                input_scale=input_scale,\n                activation_scale_ub=self.activation_scale_ub,\n                dtype=weights.dtype,\n            )\n        if self.to_fp8:\n            return Fp8Weight(weight=w, dtype=weights.dtype)\n\n        return UnquantizedWeight(w)\n\n    def get_multi_weights_col(self, weights: \"Weights\", prefixes: List[str], dim: int):\n        # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet\n        w = [\n            weights.get_sharded(f\"{p}.weight\", dim=0, to_device=False) for p in prefixes\n        ]\n        shapes = [x.shape for x in w]\n\n        # Concat then send to the device\n        w = torch.cat(w, dim=dim).to(weights.device)\n\n        # FP8 branch\n        if w.dtype == torch.float8_e4m3fn:\n            if self.weight_block_size is not None:\n                scale = [\n                    weights.get_sharded(f\"{p}.weight_scale_inv\", dim=0, to_device=False)\n                    for p in prefixes\n                ]\n                scale = torch.cat(scale, dim=dim)\n                scale = scale.to(weights.device)\n                return Fp8Weight(\n                    weight=w,\n                    weight_scale=scale,\n                    activation_scale_ub=self.activation_scale_ub,\n                    dtype=weights.dtype,\n                    weight_block_size=self.weight_block_size,\n                )\n\n            scale = [\n                _load_scalar_or_matrix_scale(weights, f\"{p}.weight_scale\", shape)\n                for p, shape in zip(prefixes, shapes)\n            ]\n            scale = torch.cat(scale, dim=0).reshape(-1)\n\n            logical_widths = [x[0] for x in shapes]\n            w, scale = requantize_with_max_scale(\n             
   w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype\n            )\n\n            input_scale = [\n                _load_scalar_or_matrix_scale(weights, f\"{p}.input_scale\", shape)\n                for p, shape in zip(prefixes, shapes)\n                if weights.has_tensor(f\"{p}.input_scale\")\n            ]\n            assert len(input_scale) == 0 or len(input_scale) == len(prefixes)\n            input_scale = (\n                torch.cat(input_scale, dim=0).reshape(-1).max()\n                if len(input_scale) != 0\n                else None\n            )\n\n            return Fp8Weight(\n                weight=w,\n                weight_scale=scale,\n                input_scale=input_scale,\n                activation_scale_ub=self.activation_scale_ub,\n                dtype=weights.dtype,\n            )\n        if self.to_fp8:\n            return Fp8Weight(weight=w, dtype=weights.dtype)\n\n        return UnquantizedWeight(w)\n\n    def get_multi_weights(self, weights: \"Weights\", prefixes: List[str], dim: int):\n        # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet\n        w = [weights.get_tensor(f\"{p}.weight\", to_device=False) for p in prefixes]\n        shapes = [x.shape for x in w]\n\n        # Concat then send to the device\n        w = torch.cat(w, dim=dim).to(weights.device)\n\n        # FP8 branch\n        if w.dtype == torch.float8_e4m3fn:\n            if self.weight_block_size is not None:\n                scale = [\n                    weights.get_tensor(f\"{p}.weight_scale_inv\", to_device=False)\n                    for p in prefixes\n                ]\n                scale = torch.cat(scale, dim=dim)\n                scale = scale.to(weights.device)\n                return Fp8Weight(\n                    weight=w,\n                    weight_scale=scale,\n                    activation_scale_ub=self.activation_scale_ub,\n                    dtype=weights.dtype,\n          
          weight_block_size=self.weight_block_size,\n                )\n\n            scale = [\n                weights.get_tensor(f\"{p}.weight_scale\", to_dtype=False)\n                .reshape(-1)\n                .expand(shape[0])\n                for p, shape in zip(prefixes, shapes)\n            ]\n            scale = torch.cat(scale, dim=0).reshape(-1)\n\n            logical_widths = [x[0] for x in shapes]\n            w, scale = requantize_with_max_scale(\n                w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype\n            )\n\n            input_scale = [\n                weights.get_tensor(f\"{p}.input_scale\", to_dtype=False).reshape(-1)\n                for p in prefixes\n                if weights.has_tensor(f\"{p}.input_scale\")\n            ]\n            assert len(input_scale) == 0 or len(input_scale) == len(prefixes)\n            input_scale = (\n                torch.cat(input_scale, dim=0).reshape(-1).max()\n                if len(input_scale) != 0\n                else None\n            )\n\n            return Fp8Weight(\n                weight=w,\n                weight_scale=scale,\n                input_scale=input_scale,\n                activation_scale_ub=self.activation_scale_ub,\n                dtype=weights.dtype,\n            )\n        if self.to_fp8:\n            return Fp8Weight(weight=w, dtype=weights.dtype)\n\n        return UnquantizedWeight(w)\n\n    def get_weights_row(self, weights: \"Weights\", prefix: str):\n        w = weights.get_sharded(f\"{prefix}.weight\", dim=1)\n        # FP8 branch\n        if w.dtype == torch.float8_e4m3fn:\n            if self.weight_block_size is not None:\n                # XXX: Yes the weights is named scale_inv, but corresponds to scale it seems.\n                scale = weights.get_sharded(f\"{prefix}.weight_scale_inv\", dim=1)\n\n                return Fp8Weight(\n                    weight=w,\n                    weight_scale=scale,\n                    
activation_scale_ub=self.activation_scale_ub,\n                    dtype=weights.dtype,\n                    weight_block_size=self.weight_block_size,\n                )\n\n            scale = (\n                weights.get_tensor(f\"{prefix}.weight_scale\", to_dtype=False)\n                .reshape(-1)\n                .expand(w.shape[0])\n            )\n            logical_widths = [w.shape[0]]\n            w, scale = requantize_with_max_scale(\n                w, scale.unsqueeze(-1).to(weights.device), logical_widths, weights.dtype\n            )\n\n            input_scale = None\n            if weights.has_tensor(f\"{prefix}.input_scale\"):\n                input_scale = (\n                    weights.get_tensor(f\"{prefix}.input_scale\", to_dtype=False)\n                    .reshape(-1)\n                    .max()\n                )\n\n            return Fp8Weight(\n                weight=w,\n                weight_scale=scale,\n                input_scale=input_scale,\n                activation_scale_ub=self.activation_scale_ub,\n                dtype=weights.dtype,\n            )\n        if self.to_fp8:\n            return Fp8Weight(weight=w, dtype=weights.dtype)\n\n        return UnquantizedWeight(w)\n\n\n@dataclass\nclass Fp8Weight(Weight):\n    weight: torch.Tensor\n    dtype: torch.dtype\n    weight_scale: Optional[torch.Tensor] = None\n    input_scale: Optional[torch.Tensor] = None\n    activation_scale_ub: Optional[float] = None\n    force_w8a16: bool = False\n    weight_block_size: Optional[List[int]] = None\n\n    def get_linear(self, bias: torch.Tensor):\n        if self.weight_scale is None:\n            return get_fp8_linear(force_w8a16=self.force_w8a16).from_unquant(\n                self.weight, bias, self.dtype\n            )\n        # This is not checked by the fbgemm kernels, but they require contiguous\n        # memory. Can be non-contiguous when we e.g. 
expand from scalars.\n        self.weight_scale = self.weight_scale.contiguous()\n        return get_fp8_linear(force_w8a16=self.force_w8a16).from_fp8(\n            weight=self.weight,\n            scale=self.weight_scale,\n            dtype=self.dtype,\n            bias=bias,\n            input_scale=self.input_scale,\n            scale_upper_bound=self.activation_scale_ub,\n            weight_block_size=self.weight_block_size,\n        )\n\n\nclass Fp8Linear(torch.nn.Module):\n    _device_identity_cache = {}\n\n    def __init__(\n        self,\n        qweight: torch.Tensor,\n        scale: torch.Tensor,\n        dtype: torch.dtype,\n        bias: Optional[torch.Tensor] = None,\n        input_scale: Optional[torch.Tensor] = None,\n        scale_upper_bound: Optional[float] = None,\n        weight_block_size: Optional[List[int]] = None,\n    ) -> None:\n        super().__init__()\n\n        self.dtype = dtype\n        self.qweight = qweight\n        self.scale = scale.float()\n        self.input_scale = input_scale.float() if input_scale is not None else None\n        self.weight_block_size = weight_block_size\n        self.scale_upper_bound = scale_upper_bound\n\n        self.bias = bias if bias is not None else None\n\n    @classmethod\n    def from_unquant(cls, weight, bias, dtype):\n        qweight, scale = fp8_quantize(weight, scalar=True)\n        return cls(\n            qweight=qweight,\n            scale=scale,\n            dtype=dtype,\n            bias=bias,\n            input_scale=None,\n            scale_upper_bound=None,\n        )\n\n    @classmethod\n    def from_fp8(\n        cls,\n        weight: torch.Tensor,\n        scale: torch.Tensor,\n        dtype: torch.dtype,\n        bias: Optional[torch.Tensor] = None,\n        **kwargs,\n    ) -> \"Fp8Linear\":\n        input_scale = kwargs.get(\"input_scale\", None)\n        scale_upper_bound = kwargs.get(\"scale_upper_bound\", None)\n        weight_block_size = kwargs.get(\"weight_block_size\", 
None)\n\n        if weight_block_size is not None:\n            weight, orig_M, orig_N = pad_block_fp8_weight_naive(\n                weight, scale, weight_block_size\n            )\n            weight, scale = dynamic_quant(\n                dequant_block_fp8_weight_naive(\n                    weight,\n                    scale,\n                    weight_block_size,\n                    original_M=orig_M,\n                    original_N=orig_N,\n                    do_unpad=True,\n                )\n            )\n            scale = scale.squeeze(-1)\n\n        return cls(\n            qweight=weight,\n            scale=scale,\n            input_scale=input_scale,\n            scale_upper_bound=scale_upper_bound,\n            bias=bias,\n            dtype=dtype,\n            weight_block_size=weight_block_size,\n        )\n\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        if self.weight_block_size is not None or self.input_scale is None:\n            return apply_block_fp8_linear_hpu_dynamic(\n                input, self.qweight, self.scale, self.input_scale, self.bias\n            )\n\n        x_fp8 = torch.ops.hpu.cast_to_fp8_v2(\n            input, 1.0 / self.input_scale, False, False, torch.float8_e4m3fn\n        )[0]\n        return torch.ops.hpu.fp8_gemm_v2(\n            A=x_fp8,\n            trans_A=False,\n            B=self.qweight,\n            trans_B=True,\n            D=None,\n            out_dtype=input.dtype,\n            A_scale_inv=self.input_scale,\n            B_scale_inv=self.scale,\n            bias=self.bias,\n            accumulate=False,\n        )\n\n\ndef _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):\n    scale = weights.get_tensor(prefix, to_dtype=False)\n\n    if scale.numel() > 1:\n        scale = weights.get_sharded(prefix, dim=0, to_dtype=False)\n    return scale.reshape(-1).expand(shape[0])\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/layers/gptq/__init__.py",
    "content": "from dataclasses import dataclass\nfrom typing import List, Optional, Union\n\nimport torch\nfrom loguru import logger\nfrom text_generation_server.utils.log import log_once\nfrom text_generation_server.utils.weights import (\n    Weight,\n    Weights,\n    WeightsLoader,\n    DefaultWeightsLoader,\n)\n\n\nfrom .hpu import QuantLinear\n\n\n@dataclass\nclass GPTQWeight(Weight):\n    qweight: torch.Tensor\n    qzeros: torch.Tensor\n    scales: torch.Tensor\n    g_idx: Optional[torch.Tensor]\n    bits: int\n    groupsize: int\n    use_awq_kernel: bool\n    use_exllama: bool\n\n    def __post_init__(self):\n        if self.scales.dtype == torch.float:\n            self.scales = self.scales.half()\n\n    @property\n    def device(self) -> torch.device:\n        return self.qweight.device\n\n    def get_linear(self, bias: torch.Tensor):\n        if self.use_awq_kernel:\n            try:\n                from text_generation_server.layers.awq.quantize import WQLinear\n\n                return WQLinear(\n                    w_bit=self.bits,\n                    group_size=self.groupsize,\n                    qweight=self.qweight,\n                    qzeros=self.qzeros,\n                    scales=self.scales,\n                    bias=bias,\n                )\n            except ImportError:\n                raise NotImplementedError(\n                    \"You do not seem to have awq installed, either install it (cd server &&  make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly\"\n                )\n        else:\n            return QuantLinear(\n                self.qweight,\n                self.qzeros,\n                self.scales,\n                self.g_idx,\n                bias,\n                self.bits,\n                self.groupsize,\n            )\n\n\nclass GPTQWeightsLoader(WeightsLoader):\n    \"\"\"\n    Loader for GPTQ- and AWQ-quantized weights.\n    \"\"\"\n\n    def __init__(\n        
self,\n        *,\n        bits: int,\n        desc_act: bool,\n        groupsize: int,\n        quant_method: str,\n        quantize: str,\n        sym: bool,\n        modules_to_not_convert: List[str],\n    ):\n        self.bits = bits\n        self.desc_act = desc_act\n        self.groupsize = groupsize\n        self.quant_method = quant_method\n        self.quantize = quantize\n        self.sym = sym\n        self.modules_to_not_convert = modules_to_not_convert\n\n    def is_layer_skipped_quantization(\n        self, prefix: str, modules_to_not_convert: List[str]\n    ):\n        return any(module_name in prefix for module_name in modules_to_not_convert)\n\n    def get_weights(self, weights: Weights, prefix: str):\n        self._get_gptq_params(weights)\n\n        use_exllama = True\n        if self.bits != 4:\n            use_exllama = False\n\n        if self.desc_act:\n            log_once(logger.warning, \"Disabling exllama because desc_act=True\")\n            use_exllama = False\n\n        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):\n            return DefaultWeightsLoader.get_weights(weights, prefix)\n\n        try:\n            qweight = weights.get_tensor(f\"{prefix}.qweight\")\n        except RuntimeError:\n            raise RuntimeError(\n                \"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`\"\n            )\n\n        if self.quantize == \"gptq\" and self.quant_method == \"gptq\":\n            g_idx = weights.get_tensor(f\"{prefix}.g_idx\")\n        else:\n            g_idx = None\n\n        qzeros = weights.get_tensor(f\"{prefix}.qzeros\")\n        scales = weights.get_tensor(f\"{prefix}.scales\")\n\n        if use_exllama and g_idx is not None:\n            g_idx = g_idx - g_idx[0]\n\n        if self.quantize == \"gptq\" and self.quant_method == \"awq\":\n            log_once(\n                
logger.info, \"Converting AWQ model to Exllama/GPTQ packing format.\"\n            )\n            from text_generation_server.layers.awq.conversion_utils import (\n                fast_awq_to_gptq,\n            )\n\n            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)\n            if use_exllama:\n                g_idx = None\n            else:\n                g_idx = (\n                    torch.arange(\n                        qweight.shape[0] * (32 // self.bits),\n                        device=qweight.device,\n                    )\n                    // self.groupsize\n                ).to(dtype=torch.int32)\n\n        return GPTQWeight(\n            qweight=qweight,\n            qzeros=qzeros,\n            scales=scales,\n            g_idx=g_idx,\n            bits=self.bits,\n            groupsize=self.groupsize,\n            use_exllama=use_exllama,\n        )\n\n    def get_weights_col_packed(\n        self,\n        weights: Weights,\n        prefix: str,\n        block_sizes: Union[int, List[int]],\n    ):\n        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):\n            return DefaultWeightsLoader.get_weights_col_packed(\n                weights, prefix, block_sizes\n            )\n        try:\n            qweight = weights.get_packed_sharded(\n                f\"{prefix}.qweight\", dim=1, block_sizes=block_sizes\n            )\n        except RuntimeError:\n            raise RuntimeError(\n                f\"Cannot load `{self.quantize}` weight, make sure the model is already quantized.\"\n            )\n        scales = weights.get_packed_sharded(\n            f\"{prefix}.scales\", dim=1, block_sizes=block_sizes\n        )\n        scales = scales.to(dtype=weights.dtype)\n\n        self._get_gptq_params(weights)\n\n        qzeros = weights.get_packed_sharded(\n            f\"{prefix}.qzeros\", dim=1, block_sizes=block_sizes\n        )\n        if self.quantize == \"gptq\" and self.quant_method == 
\"gptq\":\n            g_idx = weights.get_tensor(f\"{prefix}.g_idx\")\n        elif self.quantize == \"gptq\" and self.quant_method == \"awq\":\n            log_once(\n                logger.info, \"Converting AWQ model to Exllama/GPTQ packing format.\"\n            )\n            from text_generation_server.layers.awq.conversion_utils import (\n                fast_awq_to_gptq,\n            )\n\n            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)\n            g_idx = (\n                torch.arange(\n                    qweight.shape[0] * (32 // self.bits),\n                    device=qweight.device,\n                )\n                // self.groupsize\n            ).to(dtype=torch.int32)\n        else:\n            g_idx = None\n\n        return GPTQWeight(\n            qweight=qweight,\n            qzeros=qzeros,\n            scales=scales,\n            g_idx=g_idx,\n            bits=self.bits,\n            groupsize=self.groupsize,\n            use_awq_kernel=self.quantize == \"awq\",\n            use_exllama=False,\n        )\n\n    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):\n        if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):\n            return DefaultWeightsLoader.get_multi_weights_col(weights, prefixes, dim)\n        try:\n            qweight = torch.cat(\n                [weights.get_sharded(f\"{p}.qweight\", dim=1) for p in prefixes], dim=1\n            )\n        except RuntimeError:\n            raise RuntimeError(\n                f\"Cannot load `{self.quantize}` weight, make sure the model is already quantized\"\n            )\n\n        scales = torch.cat(\n            [weights.get_sharded(f\"{p}.scales\", dim=1) for p in prefixes], dim=1\n        )\n\n        self._get_gptq_params(weights)\n\n        qzeros = torch.cat(\n            [weights.get_sharded(f\"{p}.qzeros\", dim=1) for p in prefixes], dim=1\n        )\n\n        use_exllama = self.bits == 4 and 
self.quantize == \"gptq\" and not self.desc_act\n\n        if self.quantize == \"gptq\" and self.quant_method == \"gptq\":\n            w = [weights.get_tensor(f\"{p}.g_idx\") for p in prefixes]\n            for w2 in w[1:]:\n                torch.testing.assert_close(w2, w[0])\n            g_idx = w[0]\n        elif self.quantize == \"gptq\" and self.quant_method == \"awq\":\n            log_once(\n                logger.info, \"Converting AWQ model to Exllama/GPTQ packing format.\"\n            )\n            from text_generation_server.layers.awq.conversion_utils import (\n                fast_awq_to_gptq,\n            )\n\n            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)\n            if use_exllama:\n                g_idx = None\n            else:\n                g_idx = (\n                    torch.arange(\n                        qweight.shape[0] * (32 // self.bits),\n                        device=qweight.device,\n                    )\n                    // self.groupsize\n                ).to(dtype=torch.int32)\n        else:\n            g_idx = None\n\n        return GPTQWeight(\n            qweight=qweight,\n            qzeros=qzeros,\n            scales=scales,\n            g_idx=g_idx,\n            bits=self.bits,\n            groupsize=self.groupsize,\n            use_awq_kernel=self.quantize == \"awq\",\n            use_exllama=use_exllama,\n        )\n\n    def get_multi_weights(self, weights: Weights, prefixes: List[str], dim: int):\n        if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):\n            return DefaultWeightsLoader.get_multi_weights(weights, prefixes, dim)\n        try:\n            qweight = torch.cat(\n                [weights.get_tensor(f\"{p}.qweight\") for p in prefixes], dim=1\n            )\n        except RuntimeError:\n            raise RuntimeError(\n                f\"Cannot load `{self.quantize}` weight, make sure the model is already quantized\"\n            )\n\n        
scales = torch.cat([weights.get_tensor(f\"{p}.scales\") for p in prefixes], dim=1)\n\n        self._get_gptq_params(weights)\n\n        qzeros = torch.cat([weights.get_tensor(f\"{p}.qzeros\") for p in prefixes], dim=1)\n\n        use_exllama = self.bits == 4 and self.quantize == \"gptq\" and not self.desc_act\n\n        if self.quantize == \"gptq\" and self.quant_method == \"gptq\":\n            w = [weights.get_tensor(f\"{p}.g_idx\") for p in prefixes]\n            for w2 in w[1:]:\n                torch.testing.assert_close(w2, w[0])\n            g_idx = w[0]\n        elif self.quantize == \"gptq\" and self.quant_method == \"awq\":\n            log_once(\n                logger.info, \"Converting AWQ model to Exllama/GPTQ packing format.\"\n            )\n            from text_generation_server.layers.awq.conversion_utils import (\n                fast_awq_to_gptq,\n            )\n\n            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)\n            if use_exllama:\n                g_idx = None\n            else:\n                g_idx = (\n                    torch.arange(\n                        qweight.shape[0] * (32 // self.bits),\n                        device=qweight.device,\n                    )\n                ).to(dtype=torch.int32)\n        else:\n            g_idx = None\n\n        return GPTQWeight(\n            qweight=qweight,\n            qzeros=qzeros,\n            scales=scales,\n            g_idx=g_idx,\n            bits=self.bits,\n            groupsize=self.groupsize,\n            use_awq_kernel=self.quantize == \"awq\",\n            use_exllama=use_exllama,\n        )\n\n    def get_weights_row(self, weights: Weights, prefix: str):\n        self._get_gptq_params(weights)\n\n        use_exllama = True\n        desc_act = self.desc_act\n        if self.bits != 4:\n            use_exllama = False\n\n        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):\n            return 
DefaultWeightsLoader.get_weights_row(weights, prefix)\n\n        if self.desc_act:\n            log_once(logger.warning, \"Disabling exllama because desc_act=True\")\n            use_exllama = False\n\n        try:\n            qweight = weights.get_sharded(f\"{prefix}.qweight\", dim=0)\n        except RuntimeError:\n            raise RuntimeError(\n                \"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`\"\n            )\n\n        if self.quantize == \"gptq\" and self.quant_method == \"gptq\":\n            g_idx = weights.get_sharded(f\"{prefix}.g_idx\", dim=0)\n        else:\n            g_idx = None\n\n        if weights.process_group.size() > 1:\n            if g_idx is not None:\n                if (\n                    not torch.equal(\n                        # Remove g_idx[0] to adapt the check with TP>1.\n                        (g_idx - g_idx[0]).cpu(),\n                        torch.tensor(\n                            [i // self.groupsize for i in range(g_idx.shape[0])],\n                            dtype=torch.int32,\n                        ),\n                    )\n                    and not (g_idx == 0).all()\n                ):\n                    # Exllama implementation does not support row tensor parallelism with act-order, as\n                    # it would require to reorder input activations that are split unto several GPUs\n                    use_exllama = False\n                    desc_act = True\n\n        from text_generation_server.layers.gptq import (\n            GPTQWeight,\n        )\n\n        if not desc_act and self.groupsize != -1:\n            qzeros = weights.get_sharded(f\"{prefix}.qzeros\", dim=0)\n            scales = weights.get_sharded(f\"{prefix}.scales\", dim=0)\n            if g_idx is not None:\n                # qzeros, scales sharded, and g_idx must be adjusted accordingly\n                g_idx = 
g_idx - g_idx[0]\n        else:\n            qzeros = weights.get_tensor(f\"{prefix}.qzeros\")\n            scales = weights.get_tensor(f\"{prefix}.scales\")\n\n        if self.quantize == \"gptq\" and self.quant_method == \"awq\":\n            log_once(\n                logger.info, \"Converting AWQ model to Exllama/GPTQ packing format.\"\n            )\n            from text_generation_server.layers.awq.conversion_utils import (\n                fast_awq_to_gptq,\n            )\n\n            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)\n            if use_exllama:\n                g_idx = None\n            else:\n                g_idx = (\n                    torch.arange(\n                        qweight.shape[0] * (32 // self.bits),\n                        device=qweight.device,\n                    )\n                    // self.groupsize\n                ).to(dtype=torch.int32)\n\n        return GPTQWeight(\n            qweight=qweight,\n            qzeros=qzeros,\n            scales=scales,\n            g_idx=g_idx,\n            bits=self.bits,\n            groupsize=self.groupsize,\n            use_awq_kernel=self.quantize == \"awq\",\n            use_exllama=use_exllama,\n        )\n\n    def _get_gptq_params(self, weights: Weights):\n        if weights.has_tensor(\"gptq_bits\") and weights.has_tensor(\"gptq_groupsize\"):\n            self.bits = weights.get_tensor(\"gptq_bits\").item()\n            self.groupsize = weights.get_tensor(\"gptq_groupsize\").item()\n            self.desc_act = False\n            # `server quantize` used asymmetric quantization unconditionally\n            # before the `gptq_sym` setting tensor was added.\n            self.sym = (\n                weights.get_tensor(\"gptq_sym\").item()\n                if weights.has_tensor(\"gptq_sym\")\n                else False\n            )\n            self.quant_method = \"gptq\"\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/layers/gptq/hpu.py",
    "content": "import math\nimport numpy as np\nimport torch\nimport torch.nn as nn\n\ntry:\n\n    convert_from_uint4 = torch.ops.hpu.convert_from_uint4\nexcept Exception as e:\n    hpu_import_exception = e\n\n    def error_raiser_hpu(*args, **kwargs):\n        raise ValueError(\n            f\"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}\"\n        )\n\n    convert_from_uint4 = error_raiser_hpu\n\n\ndef pack_tensor(input, bits=4):\n    normal = input.to(torch.int32)\n    q = torch.zeros((normal.shape[0], normal.shape[1] // 32 * bits), dtype=torch.int32)\n    i = 0\n    col = 0\n    while col < q.shape[1]:\n        for j in range(i, i + (32 // bits)):\n            q[:, col] |= normal[:, j] << (bits * (j - i))\n        i += 32 // bits\n        col += 1\n    q = q.to(torch.int32)\n    return q\n\n\nclass QuantLinear(nn.Module):\n    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):\n        super().__init__()\n        self.register_buffer(\"qweight\", qweight)\n        self.register_buffer(\"qzeros\", qzeros)\n        self.register_buffer(\"scales\", scales)\n        self.register_buffer(\"g_idx\", g_idx)\n        if bias is not None:\n            self.register_buffer(\"bias\", bias)\n        else:\n            self.bias = None\n        if bits not in [4]:\n            raise NotImplementedError(\"Only 4 bits are supported.\")\n        self.bits = bits\n        self.maxq = 2**self.bits - 1\n        self.groupsize = groupsize\n\n        self.outfeatures = qweight.shape[1]\n        self.infeatures = qweight.shape[0] * 32 // bits\n        self.wf = torch.tensor(\n            list(range(0, 32, self.bits)), dtype=torch.int32\n        ).unsqueeze(0)\n        self._preprocessing()\n\n    def unpack_zeros_from_cuda_old_format(self):\n        zeros = torch.bitwise_right_shift(\n            torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits),\n            self.wf.unsqueeze(0),\n  
      ).to(torch.int16 if self.bits == 8 else torch.int8)\n\n        zeros = zeros + 1\n        zeros = torch.bitwise_and(zeros, (2**self.bits) - 1).to(\n            self.scales.dtype\n        )  # NOTE: It appears that casting here after the `zeros = zeros + 1` is important.\n        zeros = zeros.reshape(-1, zeros.shape[1] * zeros.shape[2])\n        return zeros\n\n    def unpack_weight_from_cuda_old_format(self):\n        weight = torch.bitwise_right_shift(\n            torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1),\n            self.wf.unsqueeze(-1),\n        ).to(torch.int16 if self.bits == 8 else torch.int8)\n        weight = torch.bitwise_and(weight, (2**self.bits) - 1)\n        weight = weight.reshape((weight.shape[0] * weight.shape[1], weight.shape[2]))\n        return weight\n\n    def _preprocessing(self):\n        orig_device = self.qweight.device\n        self.qweight = self.qweight.cpu()\n        weight = self.unpack_weight_from_cuda_old_format()\n        new_qweight = pack_tensor(weight)\n        self.qweight = new_qweight.to(orig_device)\n        # TODO: Support group indexing and remove the check\n        columns = self.qweight.shape[0]\n        g_idx_trivial = [i // self.groupsize for i in range(columns)]\n        g_idx_trivial = torch.tensor(\n            g_idx_trivial, dtype=torch.int32, device=self.g_idx.device\n        )\n        sort_zeros = not (torch.equal(self.g_idx, g_idx_trivial))\n        self.qzeros = self.qzeros.cpu()\n        zeros = self.unpack_zeros_from_cuda_old_format()\n        if sort_zeros:\n            zeros_group_1 = torch.zeros(\n                (self.infeatures, self.outfeatures),\n                dtype=zeros.dtype,\n                device=zeros.device,\n            )\n            scales = self.scales.cpu()\n            scale_group_1 = torch.zeros(\n                (self.infeatures, self.outfeatures),\n                dtype=scales.dtype,\n                device=scales.device,\n            )\n            
for i in range(self.infeatures):\n                zeros_group_1[i] = zeros[self.g_idx[i]]\n                scale_group_1[i] = self.scales[self.g_idx[i]]\n            self.qzeros = pack_tensor(zeros_group_1).to(orig_device)\n            self.scales = scale_group_1.to(orig_device)\n            self.groupsize = 1\n            self.g_idx = None\n        else:\n            new_qzeros = pack_tensor(zeros)\n            self.qzeros = new_qzeros.to(orig_device)\n\n    @classmethod\n    def new(cls, bits, groupsize, infeatures, outfeatures, bias):\n        if bits not in [4]:\n            raise NotImplementedError(\"Only 4 bits are supported.\")\n\n        qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)\n        qzeros = torch.zeros(\n            (math.ceil(infeatures / groupsize), outfeatures // 32 * bits),\n            dtype=torch.int32,\n        )\n        scales = torch.zeros(\n            (math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16\n        )\n        g_idx = torch.tensor(\n            [i // groupsize for i in range(infeatures)], dtype=torch.int32\n        )\n        if bias:\n            bias = torch.zeros((outfeatures), dtype=torch.float16)\n        else:\n            bias = None\n        return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)\n\n    def pack(self, linear, scales, zeros, g_idx=None):\n        self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx\n\n        scales = scales.t().contiguous()\n        zeros = zeros.t().contiguous()\n        scale_zeros = zeros * scales\n        self.scales = scales.clone().half()\n        if linear.bias is not None:\n            self.bias = linear.bias.clone().half()\n\n        intweight = []\n        for idx in range(self.infeatures):\n            intweight.append(\n                torch.round(\n                    (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])\n                    / self.scales[self.g_idx[idx]]\n                
).to(torch.int)[:, None]\n            )\n        intweight = torch.cat(intweight, dim=1)\n        intweight = intweight.t().contiguous()\n        intweight = intweight.numpy().astype(np.uint32)\n        qweight = np.zeros(\n            (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32\n        )\n        i = 0\n        row = 0\n        while row < qweight.shape[0]:\n            if self.bits in [4]:\n                for j in range(i, i + (32 // self.bits)):\n                    qweight[row] |= intweight[j] << (self.bits * (j - i))\n                i += 32 // self.bits\n                row += 1\n            else:\n                raise NotImplementedError(\"Only 4 bits are supported.\")\n\n        qweight = qweight.astype(np.int32)\n        self.qweight = torch.from_numpy(qweight)\n\n        zeros -= 1\n        zeros = zeros.numpy().astype(np.uint32)\n        qzeros = np.zeros(\n            (zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32\n        )\n        i = 0\n        col = 0\n        while col < qzeros.shape[1]:\n            if self.bits in [4]:\n                for j in range(i, i + (32 // self.bits)):\n                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))\n                i += 32 // self.bits\n                col += 1\n            else:\n                raise NotImplementedError(\"Only 4 bits are supported.\")\n\n        qzeros = qzeros.astype(np.int32)\n        self.qzeros = torch.from_numpy(qzeros)\n\n    def forward(self, x):\n        out_shape = x.shape[:-1] + (self.outfeatures,)\n        x = x.reshape(-1, x.shape[-1])\n        weight = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype)\n        out = torch.matmul(x, weight)\n        out = out.reshape(out_shape)\n        out = out + self.bias if self.bias is not None else out\n        return out\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/layers/gptq/quantize.py",
    "content": "import time\nimport torch.nn as nn\nimport math\nimport json\nimport os\nimport torch\nimport transformers\n\nfrom texttable import Texttable\nfrom transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer\nfrom huggingface_hub import HfApi\nfrom accelerate import init_empty_weights\nfrom text_generation_server.utils import initialize_torch_distributed, Weights\nfrom text_generation_server.utils.hub import weight_files\nfrom text_generation_server.layers.gptq import QuantLinear\nfrom loguru import logger\nfrom typing import Optional\nfrom text_generation_server.layers.gptq.utils import torch_snr_error\n\nfrom text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight\n\nDEV = torch.device(\"cuda:0\")\n\n\nclass Quantizer(nn.Module):\n    def __init__(self, shape=1):\n        super(Quantizer, self).__init__()\n        self.register_buffer(\"maxq\", torch.tensor(0))\n        self.register_buffer(\"scale\", torch.zeros(shape))\n        self.register_buffer(\"zero\", torch.zeros(shape))\n\n    def configure(\n        self,\n        bits,\n        perchannel=False,\n        sym=True,\n        mse=False,\n        norm=2.4,\n        grid=100,\n        maxshrink=0.8,\n        trits=False,\n    ):\n        self.maxq = torch.tensor(2**bits - 1)\n        self.perchannel = perchannel\n        self.sym = sym\n        self.mse = mse\n        self.norm = norm\n        self.grid = grid\n        self.maxshrink = maxshrink\n        if trits:\n            self.maxq = torch.tensor(-1)\n        self.scale = torch.zeros_like(self.scale)\n\n    def _quantize(self, x, scale, zero, maxq):\n        if maxq < 0:\n            return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero\n        q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)\n        return scale * (q - zero)\n\n    def find_params(self, x, weight=False):\n        dev = x.device\n        self.maxq = self.maxq.to(dev)\n\n        shape = x.shape\n        if 
self.perchannel:\n            if weight:\n                x = x.flatten(1)\n            else:\n                if len(shape) == 4:\n                    x = x.permute([1, 0, 2, 3])\n                    x = x.flatten(1)\n                if len(shape) == 3:\n                    x = x.reshape((-1, shape[-1])).t()\n                if len(shape) == 2:\n                    x = x.t()\n        else:\n            x = x.flatten().unsqueeze(0)\n\n        tmp = torch.zeros(x.shape[0], device=dev)\n        xmin = torch.minimum(x.min(1)[0], tmp)\n        xmax = torch.maximum(x.max(1)[0], tmp)\n\n        if self.sym:\n            xmax = torch.maximum(torch.abs(xmin), xmax)\n            tmp = xmin < 0\n            if torch.any(tmp):\n                xmin[tmp] = -xmax[tmp]\n        tmp = (xmin == 0) & (xmax == 0)\n        xmin[tmp] = -1\n        xmax[tmp] = +1\n\n        if self.maxq < 0:\n            self.scale = xmax\n            self.zero = xmin\n        else:\n            self.scale = (xmax - xmin) / self.maxq\n            if self.sym:\n                self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)\n            else:\n                self.zero = torch.round(-xmin / self.scale)\n\n        if self.mse:\n            best = torch.full([x.shape[0]], float(\"inf\"), device=dev)\n            for i in range(int(self.maxshrink * self.grid)):\n                p = 1 - i / self.grid\n                xmin1 = p * xmin\n                xmax1 = p * xmax\n                scale1 = (xmax1 - xmin1) / self.maxq\n                zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero\n                q = self._quantize(\n                    x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq\n                )\n                q -= x\n                q.abs_()\n                q.pow_(self.norm)\n                err = torch.sum(q, 1)\n                tmp = err < best\n                if torch.any(tmp):\n                    best[tmp] = err[tmp]\n                    
self.scale[tmp] = scale1[tmp]\n                    self.zero[tmp] = zero1[tmp]\n        if not self.perchannel:\n            if weight:\n                tmp = shape[0]\n            else:\n                tmp = shape[1] if len(shape) != 3 else shape[2]\n            self.scale = self.scale.repeat(tmp)\n            self.zero = self.zero.repeat(tmp)\n\n        if weight:\n            shape = [-1] + [1] * (len(shape) - 1)\n            self.scale = self.scale.reshape(shape)\n            self.zero = self.zero.reshape(shape)\n            return\n        if len(shape) == 4:\n            self.scale = self.scale.reshape((1, -1, 1, 1))\n            self.zero = self.zero.reshape((1, -1, 1, 1))\n        if len(shape) == 3:\n            self.scale = self.scale.reshape((1, 1, -1))\n            self.zero = self.zero.reshape((1, 1, -1))\n        if len(shape) == 2:\n            self.scale = self.scale.unsqueeze(0)\n            self.zero = self.zero.unsqueeze(0)\n\n    def quantize(self, x):\n        if self.ready():\n            return self._quantize(x, self.scale, self.zero, self.maxq)\n\n        return x\n\n    def enabled(self):\n        return self.maxq > 0\n\n    def ready(self):\n        return torch.all(self.scale != 0)\n\n\nclass GPTQ:\n    def __init__(self, layer, observe=False):\n        self.layer = layer\n        self.dev = self.layer.weight.device\n        W = layer.weight.data.clone()\n        if isinstance(self.layer, nn.Conv2d):\n            W = W.flatten(1)\n        if isinstance(self.layer, transformers.Conv1D):\n            W = W.t()\n        self.rows = W.shape[0]\n        self.columns = W.shape[1]\n        self.H = torch.zeros((self.columns, self.columns), device=self.dev)\n        self.nsamples = 0\n        self.quantizer = Quantizer()\n        self.observe = observe\n\n    def add_batch(self, inp, out):\n        # Hessian H = 2 X XT + λ I\n        if self.observe:\n            self.inp1 = inp\n            self.out1 = out\n        else:\n            self.inp1 
= None\n            self.out1 = None\n\n        if len(inp.shape) == 2:\n            inp = inp.unsqueeze(0)\n        tmp = inp.shape[0]\n        if isinstance(self.layer, nn.Linear) or isinstance(\n            self.layer, transformers.Conv1D\n        ):\n            if len(inp.shape) == 3:\n                inp = inp.reshape((-1, inp.shape[-1]))\n            inp = inp.t()\n        if isinstance(self.layer, nn.Conv2d):\n            unfold = nn.Unfold(\n                self.layer.kernel_size,\n                dilation=self.layer.dilation,\n                padding=self.layer.padding,\n                stride=self.layer.stride,\n            )\n            inp = unfold(inp)\n            inp = inp.permute([1, 0, 2])\n            inp = inp.flatten(1)\n        self.H *= self.nsamples / (self.nsamples + tmp)\n        self.nsamples += tmp\n        # inp = inp.float()\n        inp = math.sqrt(2 / self.nsamples) * inp.float()\n        # self.H += 2 / self.nsamples * inp.matmul(inp.t())\n        self.H += inp.matmul(inp.t())\n\n    def print_loss(self, name, q_weight, weight_error, timecost):\n        table = Texttable()\n        length = 28\n        name = (\n            (name + \" \" * (length - len(name)))\n            if len(name) <= length\n            else name[:length]\n        )\n\n        table.header([\"name\", \"weight_error\", \"fp_inp_SNR\", \"q_inp_SNR\", \"time\"])\n\n        # assign weight\n        self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to(\n            self.layer.weight.data.dtype\n        )\n\n        if self.inp1 is not None:\n            # quantize input to int8\n            quantizer = Quantizer()\n            quantizer.configure(8, perchannel=False, sym=True, mse=False)\n            quantizer.find_params(self.inp1)\n            q_in = quantizer.quantize(self.inp1).type(torch.float16)\n            q_out = self.layer(q_in)\n\n            # get kinds of SNR\n            q_SNR = torch_snr_error(q_out, self.out1).item()\n            
fp_SNR = torch_snr_error(self.layer(self.inp1), self.out1).item()\n        else:\n            q_SNR = \"-\"\n            fp_SNR = \"-\"\n\n        table.add_row([name, weight_error, fp_SNR, q_SNR, timecost])\n        print(table.draw().split(\"\\n\")[-2])\n\n    def fasterquant(\n        self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False, name=\"\"\n    ):\n        self.layer.to(self.dev)\n\n        W = self.layer.weight.data.clone()\n        if isinstance(self.layer, nn.Conv2d):\n            W = W.flatten(1)\n        if isinstance(self.layer, transformers.Conv1D):\n            W = W.t()\n        W = W.float()\n\n        tick = time.time()\n\n        if not self.quantizer.ready():\n            self.quantizer.find_params(W, weight=True)\n\n        H = self.H\n        if not self.observe:\n            del self.H\n        dead = torch.diag(H) == 0\n        H[dead, dead] = 1\n        W[:, dead] = 0\n\n        if act_order:\n            perm = torch.argsort(torch.diag(H), descending=True)\n            W = W[:, perm]\n            H = H[perm][:, perm]\n\n        Losses = torch.zeros_like(W)\n        Q = torch.zeros_like(W)\n\n        damp = percdamp * torch.mean(torch.diag(H))\n        diag = torch.arange(self.columns, device=self.dev)\n        H[diag, diag] += damp\n        H = torch.linalg.cholesky(H)\n        H = torch.cholesky_inverse(H)\n        try:\n            H = torch.linalg.cholesky(H, upper=True)\n        except Exception:\n            # Addition because Falcon fails on h_to_4h\n            H = torch.linalg.cholesky(\n                H + 1e-5 * torch.eye(H.shape[0]).to(H.device), upper=True\n            )\n        Hinv = H\n\n        g_idx = []\n        scale = []\n        zero = []\n        now_idx = 1\n\n        for i1 in range(0, self.columns, blocksize):\n            i2 = min(i1 + blocksize, self.columns)\n            count = i2 - i1\n\n            W1 = W[:, i1:i2].clone()\n            Q1 = torch.zeros_like(W1)\n            Err1 = 
torch.zeros_like(W1)\n            Losses1 = torch.zeros_like(W1)\n            Hinv1 = Hinv[i1:i2, i1:i2]\n\n            for i in range(count):\n                w = W1[:, i]\n                d = Hinv1[i, i]\n\n                if groupsize != -1:\n                    if (i1 + i) % groupsize == 0:\n                        self.quantizer.find_params(\n                            W[:, (i1 + i) : (i1 + i + groupsize)], weight=True\n                        )\n\n                    if ((i1 + i) // groupsize) - now_idx == -1:\n                        scale.append(self.quantizer.scale)\n                        zero.append(self.quantizer.zero)\n                        now_idx += 1\n\n                q = self.quantizer.quantize(w.unsqueeze(1)).flatten()\n                Q1[:, i] = q\n                Losses1[:, i] = (w - q) ** 2 / d**2\n\n                err1 = (w - q) / d\n                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))\n                Err1[:, i] = err1\n\n            Q[:, i1:i2] = Q1\n            Losses[:, i1:i2] = Losses1 / 2\n\n            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])\n\n        torch.cuda.synchronize()\n        error = torch.sum(Losses).item()\n\n        groupsize = groupsize if groupsize != -1 else self.columns\n        g_idx = [i // groupsize for i in range(self.columns)]\n        g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)\n        if act_order:\n            invperm = torch.argsort(perm)\n            Q = Q[:, invperm]\n            g_idx = g_idx[invperm]\n\n        if isinstance(self.layer, transformers.Conv1D):\n            Q = Q.t()\n\n        self.print_loss(\n            name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick)\n        )\n\n        if scale == []:\n            scale.append(self.quantizer.scale)\n            zero.append(self.quantizer.zero)\n        scale = torch.cat(scale, dim=1)\n        zero = torch.cat(zero, dim=1)\n        return scale, zero, g_idx, error\n\n    
def free(self):\n        self.inp1 = None\n        self.out1 = None\n        self.H = None\n        self.Losses = None\n        self.Trace = None\n        torch.cuda.empty_cache()\n\n\ndef get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code):\n    from datasets import load_dataset\n\n    traindata = load_dataset(\"wikitext\", \"wikitext-2-raw-v1\", split=\"train\")\n    testdata = load_dataset(\"wikitext\", \"wikitext-2-raw-v1\", split=\"test\")\n\n    try:\n        tokenizer = AutoTokenizer.from_pretrained(\n            model_id, use_fast=False, trust_remote_code=trust_remote_code\n        )\n    except Exception:\n        tokenizer = AutoTokenizer.from_pretrained(\n            model_id, use_fast=True, trust_remote_code=trust_remote_code\n        )\n\n    trainenc = tokenizer(\"\\n\\n\".join(traindata[\"text\"]), return_tensors=\"pt\")\n    testenc = tokenizer(\"\\n\\n\".join(testdata[\"text\"]), return_tensors=\"pt\")\n\n    import random\n\n    random.seed(seed)\n    trainloader = []\n    for _ in range(nsamples):\n        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)\n        j = i + seqlen\n        inp = trainenc.input_ids[:, i:j]\n        tar = inp.clone()\n        tar[:, :-1] = -100\n        trainloader.append((inp, tar))\n    return trainloader, testenc\n\n\ndef get_ptb(nsamples, seed, seqlen, model_id, trust_remote_code):\n    from datasets import load_dataset\n\n    traindata = load_dataset(\"ptb_text_only\", \"penn_treebank\", split=\"train\")\n    valdata = load_dataset(\"ptb_text_only\", \"penn_treebank\", split=\"validation\")\n\n    try:\n        tokenizer = AutoTokenizer.from_pretrained(\n            model_id, use_fast=False, trust_remote_code=trust_remote_code\n        )\n    except Exception:\n        tokenizer = AutoTokenizer.from_pretrained(\n            model_id, use_fast=True, trust_remote_code=trust_remote_code\n        )\n\n    trainenc = tokenizer(\"\\n\\n\".join(traindata[\"sentence\"]), 
return_tensors=\"pt\")\n    testenc = tokenizer(\"\\n\\n\".join(valdata[\"sentence\"]), return_tensors=\"pt\")\n\n    import random\n\n    random.seed(seed)\n    trainloader = []\n    for _ in range(nsamples):\n        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)\n        j = i + seqlen\n        inp = trainenc.input_ids[:, i:j]\n        tar = inp.clone()\n        tar[:, :-1] = -100\n        trainloader.append((inp, tar))\n    return trainloader, testenc\n\n\ndef get_c4(nsamples, seed, seqlen, model_id, trust_remote_code):\n    from datasets import load_dataset\n\n    traindata = load_dataset(\n        \"allenai/c4\",\n        \"allenai--c4\",\n        data_files={\"train\": \"en/c4-train.00000-of-01024.json.gz\"},\n        split=\"train\",\n        use_auth_token=False,\n    )\n    valdata = load_dataset(\n        \"allenai/c4\",\n        \"allenai--c4\",\n        data_files={\"validation\": \"en/c4-validation.00000-of-00008.json.gz\"},\n        split=\"validation\",\n        use_auth_token=False,\n    )\n\n    try:\n        tokenizer = AutoTokenizer.from_pretrained(\n            model_id, use_fast=False, trust_remote_code=trust_remote_code\n        )\n    except Exception:\n        tokenizer = AutoTokenizer.from_pretrained(\n            model_id, use_fast=True, trust_remote_code=trust_remote_code\n        )\n\n    import random\n\n    random.seed(seed)\n    trainloader = []\n    for _ in range(nsamples):\n        while True:\n            i = random.randint(0, len(traindata) - 1)\n            trainenc = tokenizer(traindata[i][\"text\"], return_tensors=\"pt\")\n            if trainenc.input_ids.shape[1] >= seqlen:\n                break\n        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)\n        j = i + seqlen\n        inp = trainenc.input_ids[:, i:j]\n        tar = inp.clone()\n        tar[:, :-1] = -100\n        trainloader.append((inp, tar))\n\n    import random\n\n    random.seed(0)\n    valenc = []\n    for _ in 
range(256):\n        while True:\n            i = random.randint(0, len(valdata) - 1)\n            tmp = tokenizer(valdata[i][\"text\"], return_tensors=\"pt\")\n            if tmp.input_ids.shape[1] >= seqlen:\n                break\n        i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1)\n        j = i + seqlen\n        valenc.append(tmp.input_ids[:, i:j])\n    valenc = torch.hstack(valenc)\n\n    class TokenizerWrapper:\n        def __init__(self, input_ids):\n            self.input_ids = input_ids\n\n    valenc = TokenizerWrapper(valenc)\n\n    return trainloader, valenc\n\n\ndef get_ptb_new(nsamples, seed, seqlen, model_id, trust_remote_code):\n    from datasets import load_dataset\n\n    traindata = load_dataset(\"ptb_text_only\", \"penn_treebank\", split=\"train\")\n    testdata = load_dataset(\"ptb_text_only\", \"penn_treebank\", split=\"test\")\n\n    try:\n        tokenizer = AutoTokenizer.from_pretrained(\n            model_id, use_fast=False, trust_remote_code=trust_remote_code\n        )\n    except Exception:\n        tokenizer = AutoTokenizer.from_pretrained(\n            model_id, use_fast=True, trust_remote_code=trust_remote_code\n        )\n\n    trainenc = tokenizer(\" \".join(traindata[\"sentence\"]), return_tensors=\"pt\")\n    testenc = tokenizer(\" \".join(testdata[\"sentence\"]), return_tensors=\"pt\")\n\n    import random\n\n    random.seed(seed)\n    trainloader = []\n    for _ in range(nsamples):\n        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)\n        j = i + seqlen\n        inp = trainenc.input_ids[:, i:j]\n        tar = inp.clone()\n        tar[:, :-1] = -100\n        trainloader.append((inp, tar))\n    return trainloader, testenc\n\n\ndef get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code):\n    from datasets import load_dataset\n\n    traindata = load_dataset(\n        \"allenai/c4\",\n        \"allenai--c4\",\n        data_files={\"train\": \"en/c4-train.00000-of-01024.json.gz\"},\n     
   split=\"train\",\n    )\n    valdata = load_dataset(\n        \"allenai/c4\",\n        \"allenai--c4\",\n        data_files={\"validation\": \"en/c4-validation.00000-of-00008.json.gz\"},\n        split=\"validation\",\n    )\n\n    try:\n        tokenizer = AutoTokenizer.from_pretrained(\n            model_id, use_fast=False, trust_remote_code=trust_remote_code\n        )\n    except Exception:\n        tokenizer = AutoTokenizer.from_pretrained(\n            model_id, use_fast=True, trust_remote_code=trust_remote_code\n        )\n\n    import random\n\n    random.seed(seed)\n    trainloader = []\n    for _ in range(nsamples):\n        while True:\n            i = random.randint(0, len(traindata) - 1)\n            trainenc = tokenizer(traindata[i][\"text\"], return_tensors=\"pt\")\n            if trainenc.input_ids.shape[1] >= seqlen:\n                break\n        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)\n        j = i + seqlen\n        inp = trainenc.input_ids[:, i:j]\n        tar = inp.clone()\n        tar[:, :-1] = -100\n        trainloader.append((inp, tar))\n\n    valenc = tokenizer(\" \".join(valdata[:1100][\"text\"]), return_tensors=\"pt\")\n    valenc = valenc.input_ids[:, : (256 * seqlen)]\n\n    class TokenizerWrapper:\n        def __init__(self, input_ids):\n            self.input_ids = input_ids\n\n    valenc = TokenizerWrapper(valenc)\n\n    return trainloader, valenc\n\n\ndef get_loaders(\n    name, nsamples=128, seed=0, seqlen=2048, model_id=\"\", trust_remote_code=False\n):\n    if \"wikitext2\" in name:\n        return get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code)\n    if \"ptb\" in name:\n        if \"new\" in name:\n            return get_ptb_new(nsamples, seed, seqlen, model_id, trust_remote_code)\n        return get_ptb(nsamples, seed, seqlen, model_id, trust_remote_code)\n    if \"c4\" in name:\n        if \"new\" in name:\n            return get_c4_new(nsamples, seed, seqlen, model_id, 
trust_remote_code)\n        return get_c4(nsamples, seed, seqlen, model_id, trust_remote_code)\n\n\ndef find_layers(module, layers=(nn.Conv2d, nn.Linear), name=\"\"):\n    # Skip last lm_head linear\n    # Need isintance Falcon is inheriting Linear.\n    if isinstance(module, layers) and \"lm_head\" not in name:\n        return {name: module}\n    res = {}\n    for name1, child in module.named_children():\n        res.update(\n            find_layers(\n                child, layers=layers, name=name + \".\" + name1 if name != \"\" else name1\n            )\n        )\n    return res\n\n\n@torch.no_grad()\ndef sequential(\n    model,\n    dataloader,\n    dev,\n    nsamples,\n    bits,\n    groupsize,\n    *,\n    hooks,\n    percdamp=0.01,\n    sym: bool = False,\n    act_order: bool = False,\n):\n    print(\"Starting ...\")\n\n    use_cache = model.config.use_cache\n    model.config.use_cache = False\n    try:\n        layers = model.model.layers\n        prefix = \"model.layers\"\n    except Exception:\n        layers = model.transformer.h\n        prefix = \"transformer.h\"\n\n    dtype = next(iter(model.parameters())).dtype\n    inps = torch.zeros(\n        (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev\n    )\n\n    cache = {\"i\": 0}\n    extra = {}\n\n    class Catcher(nn.Module):\n        def __init__(self, module):\n            super().__init__()\n            self.module = module\n\n        def forward(self, inp, **kwargs):\n            inps[cache[\"i\"]] = inp\n            cache[\"i\"] += 1\n            extra.update(kwargs.copy())\n            raise ValueError\n\n    layers[0] = Catcher(layers[0])\n    for batch in dataloader:\n        try:\n            model(batch[0].cuda())\n        except ValueError:\n            pass\n    layers[0] = layers[0].module\n\n    # layers[0] = layers[0].cpu()\n    # model.model.embed_tokens = model.model.embed_tokens.cpu()\n    # model.model.norm = model.model.norm.cpu()\n    
torch.cuda.empty_cache()\n    for hook in hooks:\n        hook.remove()\n\n    outs = torch.zeros_like(inps)\n\n    extra = {\n        k: v.to(dev) if isinstance(v, torch.Tensor) else v for k, v in extra.items()\n    }\n\n    print(\"Ready.\")\n\n    quantizers = {}\n    for i in range(len(layers)):\n        print(f\"Quantizing layer {i+1}/{len(layers)}..\")\n        print(\"+------------------+--------------+------------+-----------+-------+\")\n        print(\"|       name       | weight_error | fp_inp_SNR | q_inp_SNR | time  |\")\n        print(\"+==================+==============+============+===========+=======+\")\n\n        layer = layers[i]\n        layer.load()\n        full = find_layers(layer)\n        sequential = [list(full.keys())]\n\n        for names in sequential:\n            subset = {n: full[n] for n in names}\n            gptq = {}\n            for name in subset:\n                gptq[name] = GPTQ(subset[name])\n                gptq[name].quantizer.configure(\n                    bits, perchannel=True, sym=sym, mse=False\n                )\n                pass\n\n            def add_batch(name):\n                nonlocal gptq\n\n                def tmp(_, inp, out):\n                    gptq[name].add_batch(inp[0].data, out.data)\n\n                return tmp\n\n            handles = []\n            for name in subset:\n                handles.append(subset[name].register_forward_hook(add_batch(name)))\n            for j in range(nsamples):\n                outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]\n            for h in handles:\n                h.remove()\n\n            for name in subset:\n                scale, zero, g_idx, error = gptq[name].fasterquant(\n                    percdamp=percdamp,\n                    groupsize=groupsize,\n                    act_order=act_order,\n                    name=name,\n                )\n                quantizers[f\"{prefix}.{i}.{name}\"] = (\n                    
gptq[name].quantizer.cpu(),\n                    scale.cpu(),\n                    zero.cpu(),\n                    g_idx.cpu(),\n                    bits,\n                    groupsize,\n                )\n\n                gptq[name].free()\n\n        for j in range(nsamples):\n            outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]\n\n        layer.unload()\n        del layer\n        del gptq\n        torch.cuda.empty_cache()\n\n        inps, outs = outs, inps\n        print(\"+------------------+--------------+------------+-----------+-------+\")\n        print(\"\\n\")\n\n    model.config.use_cache = use_cache\n\n    return quantizers\n\n\ndef make_quant_linear(module, names, bits, groupsize, name=\"\"):\n    if isinstance(module, QuantLinear):\n        return\n    for attr in dir(module):\n        tmp = getattr(module, attr)\n        name1 = name + \".\" + attr if name != \"\" else attr\n        if name1 in names:\n            delattr(module, attr)\n            setattr(\n                module,\n                attr,\n                QuantLinear.new(\n                    bits,\n                    groupsize,\n                    tmp.in_features,\n                    tmp.out_features,\n                    tmp.bias is not None,\n                ),\n            )\n    for name1, child in module.named_children():\n        make_quant_linear(\n            child, names, bits, groupsize, name + \".\" + name1 if name != \"\" else name1\n        )\n\n\n# TODO: perform packing on GPU\ndef pack(model, quantizers, bits, groupsize):\n    layers = find_layers(model)\n    layers = {n: layers[n] for n in quantizers}\n    make_quant_linear(model, quantizers, bits, groupsize)\n    qlayers = find_layers(model, (QuantLinear,))\n    print(\"Packing ...\")\n    for name in qlayers:\n        print(name)\n        quantizers[name], scale, zero, g_idx, _, _ = quantizers[name]\n        qlayers[name].pack(layers[name], scale, zero, g_idx)\n    print(\"Done.\")\n    return 
model\n\n\ndef setdeepattr(module, full_name, tensor):\n    current = module\n    tokens = full_name.split(\".\")\n    for token in tokens[:-1]:\n        current = getattr(current, token)\n    setattr(current, tokens[-1], tensor)\n\n\ndef getdeepattr(module, full_name):\n    current = module\n    tokens = full_name.split(\".\")\n    for token in tokens:\n        current = getattr(current, token)\n    return current\n\n\ndef load_weights_pre_hook(module_name, weights, recursive=False):\n    def inner(module, args):\n        print(f\"Pre hook {module_name}\")\n        local_params = {}\n        for k, v in module.named_parameters():\n            if not recursive and k.count(\".\") != 1:\n                continue\n            local_params[k] = v\n        for k, v in module.named_buffers():\n            if not recursive and k.count(\".\") != 1:\n                continue\n            local_params[k] = v\n\n        for local_param in local_params:\n            current_tensor = getdeepattr(module, local_param)\n            if current_tensor.device == torch.device(\"meta\"):\n                # print(f\"Loading {local_param}\")\n                if module_name:\n                    tensor_name = f\"{module_name}.{local_param}\"\n                else:\n                    tensor_name = local_param\n                tensor = weights.get_tensor(tensor_name)\n                setdeepattr(module, local_param, nn.Parameter(tensor))\n            else:\n                tensor = current_tensor.to(device=torch.device(\"cuda:0\"))\n                if current_tensor.requires_grad:\n                    tensor = nn.Parameter(tensor)\n                setdeepattr(module, local_param, tensor)\n\n    return inner\n\n\ndef load_weights_post_hook(module_name, weights, recursive=False):\n    def inner(module, args, output):\n        print(f\"Post hook {module_name}\")\n        local_params = {}\n        for k, v in module.named_parameters():\n            if not recursive and k.count(\".\") != 1:\n 
               continue\n            local_params[k] = v\n        for k, v in module.named_buffers():\n            if not recursive and k.count(\".\") != 1:\n                continue\n            local_params[k] = v\n        for local_param in local_params:\n            # print(f\"Unloading {local_param}\")\n            current_tensor = getdeepattr(module, local_param)\n            setdeepattr(\n                module,\n                local_param,\n                nn.Parameter(current_tensor.to(device=torch.device(\"cpu\"))),\n            )\n        return output\n\n    return inner\n\n\ndef quantize(\n    model_id: str,\n    bits: int,\n    groupsize: int,\n    output_dir: str,\n    revision: str,\n    trust_remote_code: bool,\n    upload_to_model_id: Optional[str],\n    percdamp: float,\n    act_order: bool,\n    sym: bool,\n):\n    print(\"loading model\")\n    config = AutoConfig.from_pretrained(\n        model_id,\n        trust_remote_code=trust_remote_code,\n    )\n\n    with init_empty_weights():\n        model = AutoModelForCausalLM.from_config(\n            config, torch_dtype=torch.float16, trust_remote_code=trust_remote_code\n        )\n    model = model.eval()\n\n    print(\"LOADED model\")\n    files = weight_files(model_id, revision, extension=\".safetensors\")\n    process_group, _, _ = initialize_torch_distributed()\n    weights = Weights(\n        files,\n        device=torch.device(\"cuda:0\"),\n        dtype=torch.float16,\n        process_group=process_group,\n        aliases={\"embed_tokens.weight\": [\"lm_head.weight\"]},\n        weights_loader=DefaultWeightsLoader(UnquantizedWeight),\n    )\n    hooks = []\n    for name, module in model.named_modules():\n\n        def load(module, name):\n            def _load():\n                load_weights_pre_hook(name, weights, recursive=True)(module, None)\n\n            return _load\n\n        def unload(module, name):\n            def _unload():\n                load_weights_post_hook(name, 
weights, recursive=True)(\n                    module, None, None\n                )\n\n            return _unload\n\n        module.load = load(module, name)\n        module.unload = unload(module, name)\n        hooks.append(\n            module.register_forward_pre_hook(load_weights_pre_hook(name, weights))\n        )\n        hooks.append(\n            module.register_forward_hook(load_weights_post_hook(name, weights))\n        )\n    model.seqlen = 2048\n\n    dataset = \"wikitext2\"\n    nsamples = 128\n    seed = None\n\n    dataloader, testloader = get_loaders(\n        dataset,\n        nsamples=nsamples,\n        seed=seed,\n        model_id=model_id,\n        seqlen=model.seqlen,\n        trust_remote_code=trust_remote_code,\n    )\n\n    tick = time.time()\n    quantizers = sequential(\n        model,\n        dataloader,\n        DEV,\n        nsamples,\n        bits,\n        groupsize,\n        percdamp=percdamp,\n        act_order=act_order,\n        hooks=hooks,\n        sym=sym,\n    )\n    print(time.time() - tick)\n\n    pack(model, quantizers, bits, groupsize)\n    from safetensors.torch import save_file\n    from huggingface_hub import split_torch_state_dict_into_shards\n\n    state_dict = model.state_dict()\n    state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}\n\n    max_shard_size = \"10GB\"\n    state_dict_split = split_torch_state_dict_into_shards(\n        state_dict,\n        filename_pattern=\"model.safetensors\",\n        max_shard_size=max_shard_size,\n    )\n    index = None\n    if state_dict_split.is_sharded:\n        index = {\n            \"metadata\": state_dict_split.metadata,\n            \"weight_map\": state_dict_split.tensor_to_filename,\n        }\n    shards = state_dict_split.filename_to_tensors\n    os.makedirs(output_dir, exist_ok=True)\n    for shard_file, shard in shards.items():\n        save_file(\n            shard,\n            os.path.join(output_dir, shard_file),\n            metadata={\n   
             \"format\": \"pt\",\n                \"quantized\": \"gptq\",\n                \"origin\": \"text-generation-inference\",\n            },\n        )\n    if index is None:\n        path_to_weights = os.path.join(output_dir, \"model.safetensors\")\n        logger.info(f\"Model weights saved in {path_to_weights}\")\n    else:\n        save_index_file = \"model.safetensors.index.json\"\n        save_index_file = os.path.join(output_dir, save_index_file)\n        with open(save_index_file, \"w\", encoding=\"utf-8\") as f:\n            content = json.dumps(index, indent=2, sort_keys=True) + \"\\n\"\n            f.write(content)\n        logger.info(\n            f\"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be \"\n            f\"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the \"\n            f\"index located at {save_index_file}.\"\n        )\n    config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)\n    config.quantization_config = {\n        \"bits\": bits,\n        \"group_size\": groupsize,\n        \"damp_percent\": percdamp,\n        \"desc_act\": act_order,\n        \"static_groups\": False,\n        \"sym\": sym,\n        \"quant_method\": \"gptq\",\n    }\n    config.save_pretrained(output_dir)\n    logger.info(\"Saved config\")\n    logger.info(\"Saving tokenizer\")\n    tokenizer = AutoTokenizer.from_pretrained(\n        model_id, trust_remote_code=trust_remote_code\n    )\n    tokenizer.save_pretrained(output_dir)\n    logger.info(\"Saved tokenizer\")\n\n    if upload_to_model_id:\n        api = HfApi()\n\n        api.upload_folder(\n            folder_path=output_dir, repo_id=upload_to_model_id, repo_type=\"model\"\n        )\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/layers/gptq/utils.py",
    "content": "import torch\n\n\n# copied from https://github.com/openppl-public/ppq/blob/master/ppq/quantization/measure/norm.py\ndef torch_snr_error(\n    y_pred: torch.Tensor, y_real: torch.Tensor, reduction: str = \"mean\"\n) -> torch.Tensor:\n    \"\"\"\n    Compute SNR between y_pred(tensor) and y_real(tensor)\n\n    SNR can be calculated with the following equation:\n\n        SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2\n\n    If pred and real are matrices, the SNR error over the matrix is the mean value of the SNR error over all elements.\n\n        SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2)\n\n    Args:\n        y_pred (torch.Tensor): Predicted tensor.\n        y_real (torch.Tensor): Reference tensor; must have the same shape as y_pred.\n        reduction (str, optional): One of 'mean', 'sum' or 'none' (case-insensitive). Defaults to 'mean'.\n\n    Raises:\n        ValueError: If y_pred and y_real have different shapes.\n        ValueError: If reduction is not 'mean', 'sum' or 'none'.\n\n    Returns:\n        torch.Tensor: The SNR error, reduced according to reduction.\n    \"\"\"\n    if y_pred.shape != y_real.shape:\n        raise ValueError(\n            f\"Can not compute snr loss for tensors with different shape. \"\n            f\"({y_pred.shape} and {y_real.shape})\"\n        )\n    reduction = str(reduction).lower()\n\n    if y_pred.ndim == 1:\n        y_pred = y_pred.unsqueeze(0)\n        y_real = y_real.unsqueeze(0)\n\n    y_pred = y_pred.flatten(start_dim=1)\n    y_real = y_real.flatten(start_dim=1)\n\n    noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1)\n    signal_power = torch.pow(y_real, 2).sum(dim=-1)\n    snr = (noise_power) / (signal_power + 1e-7)\n\n    if reduction == \"mean\":\n        return torch.mean(snr)\n    elif reduction == \"sum\":\n        return torch.sum(snr)\n    elif reduction == \"none\":\n        return snr\n    else:\n        raise ValueError(\"Unsupported reduction method.\")\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/layers/layernorm.py",
    "content": "import torch\nfrom torch import nn\nfrom accelerate import init_empty_weights\n\n\n# Monkey patching\n@classmethod\ndef load_layer_norm(cls, prefix, weights, eps):\n    weight = weights.get_tensor(f\"{prefix}.weight\")\n    bias = weights.get_tensor(f\"{prefix}.bias\")\n    with init_empty_weights():\n        ln = cls(weight.shape, eps=eps)\n\n    ln.weight = torch.nn.Parameter(weight)\n    ln.bias = torch.nn.Parameter(bias)\n    return ln\n\n\n@classmethod\ndef load_layer_norm_no_bias(cls, prefix, weights, eps):\n    weight = weights.get_tensor(f\"{prefix}.weight\")\n    with init_empty_weights():\n        ln = cls(weight.shape, eps=eps)\n\n    ln.weight = torch.nn.Parameter(weight)\n    ln.bias = None\n    return ln\n\n\ntorch.nn.LayerNorm.load = load_layer_norm\ntorch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias\n\n\nclass FastLayerNorm(nn.LayerNorm):\n    def forward(self, hidden_states, residual=None):\n        if residual is not None:\n            hidden_states += residual\n        residual = hidden_states\n\n        return super().forward(hidden_states), residual\n\n\nclass FastRMSNorm(nn.Module):\n    def __init__(self, weight: torch.Tensor, eps: float):\n        super().__init__()\n\n        self.weight = nn.Parameter(weight)\n        self.variance_epsilon = eps\n\n    @classmethod\n    def load(cls, prefix, weights, eps=1e-6):\n        weight = weights.get_tensor(f\"{prefix}.weight\")\n        return cls(weight, eps)\n\n    def forward(self, hidden_states, residual=None):\n        if residual is not None:\n            hidden_states += residual\n        residual = hidden_states\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(self.weight.dtype), residual\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/layers/linear.py",
    "content": "import torch\nfrom torch.nn import functional as F\n\n\nclass FastLinear(torch.nn.Module):\n    def __init__(\n        self,\n        weight,\n        bias,\n    ) -> None:\n        super().__init__()\n        self.weight = torch.nn.Parameter(weight, requires_grad=False)\n        if bias is not None:\n            self.bias = torch.nn.Parameter(bias, requires_grad=False)\n        else:\n            self.bias = None\n\n    @classmethod\n    def load(cls, config, prefix: str, weights, bias: bool):\n        weight = weights.get_tensor(f\"{prefix}.weight\")\n        if bias:\n            bias = weights.get_tensor(f\"{prefix}.bias\")\n        else:\n            bias = None\n        return cls(weight, bias)\n\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        return F.linear(input, self.weight, self.bias)\n\n\ndef get_linear(weight, bias):\n    # Weights that are loaded through methods that are not\n    # quantization-aware are still bare tensors. We may want\n    # to change this in the future.\n    if isinstance(weight, torch.Tensor):\n        return FastLinear(weight, bias)\n\n    return weight.get_linear(bias)\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/layers/lora.py",
    "content": "from typing import TYPE_CHECKING, Optional, List\n\nimport torch\nimport torch.distributed\nfrom torch import nn\nfrom torch.distributed import ProcessGroup\n\nfrom text_generation_server.utils.sgmv import (\n    add_lora_a_bgmv,\n    add_lora_b_bgmv,\n    has_sgmv,\n    lora_a_sgmv_cutlass,\n    lora_b_sgmv_cutlass,\n    orient_for_rank,\n)\n\nif TYPE_CHECKING:\n    from text_generation_server.adapters import AdapterBatchData\n    from text_generation_server.adapters.lora import BatchLoraWeights\n\n\nclass LoraLinear(nn.Module):\n    def __init__(\n        self, base_layer: nn.Module, layer_id: int, process_group: ProcessGroup\n    ):\n        super().__init__()\n        self.base_layer = base_layer\n        self.layer_id = layer_id\n        self.process_group = process_group\n\n    def forward_layer_type(\n        self,\n        result: torch.Tensor,\n        input: torch.Tensor,\n        adapter_data: \"AdapterBatchData\",\n        layer_type: str,\n        start_idx: int,\n        end_idx: int,\n    ) -> torch.Tensor:\n        if adapter_data is None:\n            return result\n        data: Optional[\"BatchLoraWeights\"] = adapter_data.data.get(layer_type)\n\n        if has_sgmv() and data is not None and data.can_vectorize(self.process_group):\n            # In tensor-parallel configurations, each GPU processes a specific segment of the output.\n            # The 'result' tensor represents the full output, which can vary in size based on\n            # the layer type (e.g., attention vs. feed-forward layers). We define the current\n            # segment using start_idx and end_idx. 
If the segment size doesn't match this GPU's\n            # slice of 'result', we create a zero tensor of the correct size for LoRA computation.\n            # This approach ensures accurate LoRA application across various layer sizes and\n            # configurations, adapting to different model architectures and parallelization strategies.\n            #\n            # Example scenarios where this is necessary:\n            # 1. The adapter's size doesn't evenly divide across GPUs.\n            # 2. We're processing the last segment which might be smaller.\n            # 3. Different projection layers (q, k, v) have different sizes.\n            if end_idx - start_idx != result.shape[1]:\n                proj = torch.zeros_like(result[:, start_idx:end_idx])\n            else:\n                proj = result\n\n            for r, rank_segments in data.rank_data.items():\n                lora_a_ptr = rank_segments.lora_a_ptr\n                lora_b_ptr = rank_segments.lora_b_ptr\n\n                if lora_a_ptr is None or lora_b_ptr is None:\n                    raise ValueError(\"LoRA data is missing\")\n\n                if data.use_sgmv:\n                    # Use SGMV for prefill\n                    v = lora_a_sgmv_cutlass(\n                        input,\n                        rank_segments.tmp_shrink,\n                        lora_a_ptr,\n                        rank_segments.segment_starts,\n                        rank_segments.segment_ends,\n                        self.layer_id,\n                        r,\n                    )\n\n                    if self.process_group.size() > 1:\n                        v = self.collect_lora_a(v)\n\n                    lora_b_sgmv_cutlass(\n                        proj,\n                        v,\n                        rank_segments.tmp_expand,\n                        lora_b_ptr,\n                        rank_segments.segment_starts,\n                        rank_segments.segment_ends,\n                        
self.layer_id,\n                    )\n                else:\n                    # Use BGMV for decode\n                    v = torch.zeros(\n                        (input.size(0), r), dtype=input.dtype, device=input.device\n                    )\n                    # TODO: error with [-1, 0], but not [0, -1]\n                    add_lora_a_bgmv(\n                        v,\n                        input,\n                        lora_a_ptr,\n                        rank_segments.indices,\n                        self.layer_id,\n                    )\n\n                    if self.process_group.size() > 1:\n                        v = self.collect_lora_a(v)\n\n                    add_lora_b_bgmv(\n                        proj,\n                        v,\n                        lora_b_ptr,\n                        rank_segments.indices,\n                        self.layer_id,\n                    )\n\n            if end_idx - start_idx != result.shape[1]:\n                result[:, start_idx:end_idx] += proj\n        else:\n            for adapter_index in adapter_data.meta.adapter_set:\n                if data is not None and data.has_adapter(adapter_index):\n                    adapter_mask = (\n                        (adapter_data.meta.adapter_indices == adapter_index)\n                        .to(input.dtype)\n                        .view(-1, 1)\n                    )\n                    layer_result = self.forward_lora(\n                        input, data, adapter_index, adapter_mask\n                    )\n                    result[:, start_idx:end_idx] += layer_result\n\n        return result\n\n    def forward_lora(\n        self,\n        input: torch.Tensor,\n        data: \"BatchLoraWeights\",\n        adapter_index: int,\n        adapter_mask: torch.Tensor,\n    ) -> torch.Tensor:\n        lora_a = data.lora_a[adapter_index][self.layer_id, :, :]\n        lora_b = data.lora_b[adapter_index][self.layer_id, :, :]\n\n        lora_a = 
orient_for_rank(lora_a, lora_b.size(0))\n\n        a_out = input @ lora_a\n        if self.process_group.size() > 1:\n            a_out = self.collect_lora_a(a_out)\n\n        result = (a_out @ lora_b) * adapter_mask\n        return result\n\n    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:\n        raise NotImplementedError(\"Implemented in subclasses\")\n\n\nclass TensorParallelMultiAdapterLinear(LoraLinear):\n    def __init__(\n        self,\n        base_layer: nn.Module,\n        layer_id: int,\n        layer_names: List[str],\n        sizes: List[int],\n        process_group: ProcessGroup,\n    ):\n        super().__init__(base_layer, layer_id, process_group)\n        self.layer_names = layer_names\n        self.sizes = sizes\n\n    @classmethod\n    def load(\n        cls,\n        base_layer: nn.Module,\n        layer_id: int,\n        layer_names: List[str],\n        sizes: List[int],\n        process_group: ProcessGroup,\n    ):\n        return TensorParallelMultiAdapterLinear(\n            base_layer, layer_id, layer_names, sizes, process_group\n        )\n\n    def forward(\n        self, input: torch.Tensor, adapter_data: \"AdapterBatchData\"\n    ) -> torch.Tensor:\n        result = self.base_layer(input)\n\n        # noop if no layer names are provided (e.g. 
for models without adapters)\n        if self.layer_names is None:\n            return result\n\n        # handle models like Bloom that have inputs of shape\n        # (batch_size, sequence_length, hidden_size)\n        # we need to reshape them to (batch_size * sequence_length, hidden_size)\n        # for the LoRA computation, then reshape back\n        prev_shape = result.shape\n        is_3d = len(input.shape) >= 3\n        if is_3d:\n            input = input.reshape(-1, input.shape[-1])\n            result = result.reshape(-1, result.shape[-1])\n\n        offset = 0\n        for i, layer_name in enumerate(self.layer_names):\n            start_idx = offset // self.process_group.size()\n            # The 'sizes' parameter is essential in tensor-parallel setups for handling multiple\n            # projection layers (q_proj, k_proj, v_proj) by defining their output dimensions. It\n            # ensures correct slicing of the result tensor, accommodating variations like grouped-query\n            # attention where k_proj and v_proj differ from q_proj. 
This allows precise application of\n            # LoRA adapters to each sub-component of the multi-head attention mechanism, managing the\n            # different projection sizes across layers and model architectures.\n            if self.sizes is not None:\n                offset += self.sizes[i]\n                end_idx = offset // self.process_group.size()\n            else:\n                end_idx = result.shape[1]\n\n            result = self.forward_layer_type(\n                result, input, adapter_data, layer_name, start_idx, end_idx\n            )\n\n        if is_3d:\n            result = result.reshape(prev_shape)\n\n        return result\n\n    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:\n        # Tensor parallel implementation of X @ A@B, where A and B are sharded column-wise.\n        # We use an all-gather between X@A and (X@A)@B to ensure alignment across ranks.\n        #\n        # TODO(travis): this is not very efficient as we do an all-gather for every adapter,\n        #   instead we could pre-allocate a (B, a, r) tensor for all adapters with the same\n        #   rank, compute `a_out` on each, and then slice them into the buffer as shown here:\n        #   https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609\n        gathered_tensors = [\n            torch.empty_like(a_out) for _ in range(self.process_group.size())\n        ]\n        torch.distributed.all_gather(gathered_tensors, a_out)\n        return torch.cat(gathered_tensors, dim=1)\n\n\nclass TensorParallelAdapterRowLinear(LoraLinear):\n    def __init__(self, base_layer, layer_id, layer_name, process_group):\n        super().__init__(base_layer, layer_id, process_group)\n        self.layer_name = layer_name\n\n    @classmethod\n    def load(cls, base_layer, layer_id, layer_name, process_group):\n        return cls(base_layer, layer_id, layer_name, process_group)\n\n    def forward(\n        self, input: torch.Tensor, adapter_data: 
\"AdapterBatchData\"\n    ) -> torch.Tensor:\n        result = self.base_layer(input)\n\n        if self.layer_name is None:\n            return result\n\n        # Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285\n        stride = result.shape[-1] // self.process_group.size()\n        start_idx = self.process_group.rank() * stride\n        end_idx = (self.process_group.rank() + 1) * stride\n\n        self.forward_layer_type(\n            result, input, adapter_data, self.layer_name, start_idx, end_idx\n        )\n\n        return result\n\n    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:\n        # Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise.\n        # We use an all-reduce between X@A and (X@A)@B to ensure alignment across ranks.\n        #\n        # TODO(travis): this is not very efficient as we do an all-reduce for every adapter,\n        #   instead we could pre-allocate a (B, a, r) tensor for all adapters with the same\n        #   rank, compute `a_out` on each, and then slice them into the buffer as shown here:\n        #   https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609\n        torch.distributed.all_reduce(a_out, group=self.process_group)\n        return a_out\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/layers/medusa.py",
    "content": "import torch\nfrom torch import nn\nfrom typing import Tuple, Optional\nfrom text_generation_server.utils.speculate import get_speculate\nfrom text_generation_server.layers.linear import FastLinear\nfrom text_generation_server.layers.tensor_parallel import (\n    TensorParallelHead,\n    TensorParallelColumnLinear,\n)\n\n\nclass ResBlock(torch.nn.Module):\n    def __init__(self, config, prefix, weights):\n        super().__init__()\n        self.linear = FastLinear.load(\n            config, prefix=f\"{prefix}.linear\", weights=weights, bias=True\n        )\n        self.act = torch.nn.SiLU()\n\n    def forward(self, x):\n        return x + self.act(self.linear(x))\n\n\nclass MedusaModel(torch.nn.Module):\n    def __init__(self, config, medusa_config, weights):\n        super().__init__()\n        self.heads = torch.nn.ModuleList(\n            [\n                MedusaHead(config, medusa_config, prefix=f\"{i}\", weights=weights)\n                for i in range(get_speculate())\n            ]\n        )\n\n    def forward(self, x):\n        if not self.heads:\n            return None\n        speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)\n        return speculative_logits\n\n\nclass MedusaHead(torch.nn.Module):\n    def __init__(self, config, medusa_config, prefix, weights):\n        super().__init__()\n        self.blocks = torch.nn.ModuleList(\n            [\n                ResBlock(config, prefix=f\"{prefix}.{i}\", weights=weights)\n                for i in range(medusa_config[\"medusa_num_layers\"])\n            ]\n        )\n        n = len(self.blocks)\n        self.out = FastLinear.load(\n            config, prefix=f\"{prefix}.{n}\", weights=weights, bias=False\n        )\n\n    def forward(self, x):\n        for block in self.blocks:\n            x = block(x)\n        x = self.out(x)\n        return x\n\n\nclass MedusaHeadV1(nn.Module):\n    def __init__(self, lm_head, medusa):\n        super().__init__()\n        
self.lm_head = lm_head\n        self.medusa = medusa\n\n    @staticmethod\n    def load(config, prefix: str, weights):\n        from pathlib import Path\n        from safetensors import safe_open\n        import json\n\n        speculator = config.speculator\n\n        path = speculator[\"path\"]\n        medusa_config = str(Path(path) / \"config.json\")\n\n        for fname in speculator[\"model_paths\"]:\n            filename = str(Path(path) / fname)\n\n            with open(medusa_config, \"r\") as f:\n                medusa_config = json.load(f)\n            routing = weights.routing\n            with safe_open(filename, framework=\"pytorch\") as f:\n                for k in f.keys():\n                    if k in routing and routing[k] != filename:\n                        raise RuntimeError(\n                            f\"Key {k} was found in multiple files: {filename} and {routing[k]}\"\n                        )\n                    routing[k] = filename\n\n        medusa = MedusaModel(config, medusa_config, weights)\n        lm_head = TensorParallelHead.load(config, prefix, weights)\n        return MedusaHeadV1(lm_head, medusa)\n\n    def forward(\n        self, input: torch.Tensor\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        logits = self.lm_head(input)\n        # If we have too many tokens, we skip speculative logits\n        if input.shape[0] > 128:\n            return logits, None\n\n        speculative_logits = self.medusa(input)\n        return logits, speculative_logits\n\n\nclass MedusaHeadV2(nn.Module):\n    def __init__(self, config, prefix, weights):\n        super().__init__()\n        from pathlib import Path\n        from safetensors import safe_open\n        import json\n\n        speculator_path = config.speculator[\"path\"]\n\n        medusa_config = str(Path(speculator_path) / \"config.json\")\n        filename = str(Path(speculator_path) / \"medusa_lm_head.safetensors\")\n\n        with open(medusa_config, \"r\") as 
f:\n            medusa_config = json.load(f)\n        routing = weights.routing\n        with safe_open(filename, framework=\"pytorch\") as f:\n            for k in f.keys():\n                if k in routing and routing[k] != filename:\n                    raise RuntimeError(\n                        f\"Key {k} was found in multiple files: {filename} and {routing[k]}\"\n                    )\n                routing[k] = filename\n\n        self.n_medusa_heads = get_speculate()\n\n        assert medusa_config[\"medusa_num_layers\"] == 1\n        self.linear = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{i}.0.linear\" for i in range(self.n_medusa_heads)],\n            dim=0,\n            weights=weights,\n            bias=True,\n        )\n        self.process_group = weights.process_group\n        self.world_size = self.process_group.size()\n        self.rank = self.process_group.rank()\n\n        self.act = torch.nn.SiLU()\n\n        self.lm_head = TensorParallelHead.load(config, prefix, weights)\n\n    def forward(self, x):\n        # If we have too many tokens, we skip speculative logits\n        if x.shape[0] > 128:\n            logits = self.lm_head(x)\n            return logits, None\n\n        size = x.shape[-1]\n        block_size = (size + self.world_size - 1) // self.world_size\n        start = self.rank * block_size\n        stop = (self.rank + 1) * block_size\n\n        x_block = x[:, start:stop]\n\n        # Compute all medusa heads at the same time, then reshape and move the n_medusa_heads dim to dim 1\n        medusa_res = self.act(self.linear(x)).reshape(\n            *x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1]\n        )\n\n        # Apply all residual medusa heads\n        output = x[:, start:stop].unsqueeze(-2) + medusa_res\n\n        # Gather medusa heads\n        world_output = [\n            torch.empty_like(output) for _ in range(self.process_group.size())\n        ]\n        
torch.distributed.all_gather(world_output, output, group=self.process_group)\n        world_output = torch.cat(world_output, dim=-1)\n\n        # Stack x and medusa residual x\n        stacked_x = torch.cat([x.unsqueeze(-2), world_output], dim=-2)\n\n        # Compute lm head on x + medusa residual x\n        logits = self.lm_head(stacked_x)\n\n        # Finally, split logits from speculative logits\n        logits, speculative_logits = torch.split(\n            logits, [1, self.n_medusa_heads], dim=-2\n        )\n        # Squeeze added dimension\n        logits = logits.squeeze(-2)\n\n        return logits, speculative_logits\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/layers/mlp.py",
    "content": "import torch\nimport math\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom typing import Optional, Tuple\nfrom text_generation_server.layers import TensorParallelEmbedding, FastLinear\nfrom text_generation_server.layers.tensor_parallel import TensorParallelHead\nfrom text_generation_server.utils.speculate import get_speculate\n\n\nclass MLPSpeculatorLayerNorm(nn.Module):\n    \"\"\"\n    A L2 normalization implementation\n    ...\n    Args\n    ----\n    normalized_shape : int\n        Dimensionality of input data (size of final tensor axis)\n    elementwise_scale_weight : torch.Tensor\n        learned scaling term after normalization?\n    elementwise_shift_bias : torch.Tensor\n        learned bias term after normalization?\n    eps : float\n        Safety term to prevent division by zero. Make sure the chosen value fits in the range of your encoding scheme (i.e. fp16 requires eps >= 6e-8).\n    \"\"\"\n\n    def __init__(\n        self,\n        prefix,\n        config,\n        weights,\n        eps=1e-06,\n    ):\n        super(MLPSpeculatorLayerNorm, self).__init__()\n        self.weight = weights.get_tensor(f\"{prefix}.weight\")\n        self.bias = weights.get_tensor(f\"{prefix}.bias\")\n        self.eps = eps\n\n    def forward(self, x):\n        xf = x\n        xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)\n        x = xf.type_as(x)\n        x = self.weight * x\n        x = x + self.bias\n        return x\n\n\nINV_SQRT2 = 2**-0.5\n\n\ndef simple_norm(x: torch.Tensor, eps=1e-06):\n    xf = x\n    xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)\n    x = xf.type_as(x)\n    return x * INV_SQRT2\n\n\nclass MLPSpeculatorModelTied(torch.nn.Module):\n    def __init__(self, config, prefix, weights):\n        super().__init__()\n        self.config = config\n        self.n_predict = get_speculate()\n        self.hidden_size = config.hidden_size\n\n        self.emb = 
TensorParallelEmbedding(f\"{prefix}.emb.0\", weights)\n        self.proj0 = FastLinear.load(\n            config,\n            prefix=f\"{prefix}.proj.0\",\n            weights=weights,\n            bias=False,\n        )\n        self.proj1 = FastLinear.load(\n            config,\n            prefix=f\"{prefix}.proj.1\",\n            weights=weights,\n            bias=False,\n        )\n        self.head = FastLinear.load(config, f\"{prefix}.head.0\", weights, bias=False)\n        self.ln = MLPSpeculatorLayerNorm(\n            prefix=f\"{prefix}.ln.0\",\n            config=config,\n            weights=weights,\n        )\n\n        # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation\n        self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1\n        self.activation = nn.GELU()\n        self.vsize = config.vocab_size\n        self.inner_dim = config.speculator_config[\"inner_dim\"]\n        self.top_k_tokens_per_head = [1] * self.n_predict\n        self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt(\n            self.inner_dim / 2\n        )\n        self.emb.weight *= self.emb_weight\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        input_ids: torch.Tensor,\n    ):\n        top_k_tokens_per_head = self.top_k_tokens_per_head\n\n        # k indicates # of candidates\n        # h indicates # of generated tokens\n        state = hidden_states\n        b = state.size(0)\n        ind = input_ids.unsqueeze(0)\n        all_probs = torch.empty(\n            b, self.n_predict, self.vsize, device=state.device\n        )  # b k h v\n        assert (\n            len(top_k_tokens_per_head) == self.n_predict\n        ), f\"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)\"\n        for i in range(self.n_predict):\n            # Project and predict\n            z = self.emb(ind)\n            # z = 
z.mul(self.emb_weight)  # b k d\n            if i == 0:\n                state = self.proj0(state) * self.state_weight + z\n            else:\n                state = self.proj1(state) * self.state_weight + z\n            state = self.activation(self.ln(state))  # b k d\n            probs = F.log_softmax(self.head(state), dim=-1)  # b k v\n            _probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1)  # b k k'\n\n            # Update candidate set with new predictions\n\n            # Update distribution set with new logits\n            all_probs[:, i] = probs.exp()\n\n            # Update state, log_probs and ind for new predictions\n            state = state.unsqueeze(2).expand(\n                -1, -1, top_k_tokens_per_head[i], -1\n            )  # b k k' d\n            state = state.reshape(-1, b, state.size(3))  # b kk' d\n            ind = preds.view(-1, b)  # b kk'\n\n        speculative_logits = all_probs\n        return speculative_logits\n\n\nclass MLPSpeculatorModel(torch.nn.Module):\n    def __init__(self, config, prefix, weights):\n        super().__init__()\n        self.config = config\n        self.n_predict = get_speculate()\n        self.hidden_size = config.hidden_size\n\n        self.emb = nn.ModuleList(\n            [\n                TensorParallelEmbedding(f\"{prefix}.emb.{i}\", weights)\n                for i in range(self.n_predict)\n            ]\n        )\n        self.proj = [\n            FastLinear.load(\n                config,\n                prefix=f\"{prefix}.proj.{i}\",\n                weights=weights,\n                bias=False,\n            )\n            for i in range(self.n_predict)\n        ]\n        self.head = nn.ModuleList(\n            [\n                FastLinear.load(config, f\"{prefix}.head.{i}\", weights, bias=False)\n                for i in range(self.n_predict)\n            ]\n        )\n        self.ln = nn.ModuleList(\n            [\n                MLPSpeculatorLayerNorm(\n                    
prefix=f\"{prefix}.ln.{i}\",\n                    config=config,\n                    weights=weights,\n                )\n                for i in range(self.n_predict)\n            ]\n        )\n\n        # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation\n        self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1\n        self.activation = nn.GELU()\n        self.vsize = config.vocab_size\n        self.inner_dim = config.speculator_config[\"inner_dim\"]\n        self.top_k_tokens_per_head = [1] * self.n_predict\n        self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt(\n            self.inner_dim / 2\n        )\n        self.emb.weight *= self.emb_weight\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        input_ids: torch.Tensor,\n    ):\n        top_k_tokens_per_head = self.top_k_tokens_per_head\n\n        # k indicates # of candidates\n        # h indicates # of generated tokens\n        state = hidden_states\n        b = state.size(0)\n        ind = input_ids.unsqueeze(0)\n        all_probs = torch.empty(\n            b, self.n_predict, self.vsize, device=state.device\n        )  # b k h v\n        assert (\n            len(top_k_tokens_per_head) == self.n_predict\n        ), f\"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)\"\n        for i in range(self.n_predict):\n            # Project and predict\n            z = self.emb[i](ind)\n            # z = z.mul(self.emb_weight)  # b k d\n            state = self.proj[i](state) * self.state_weight + z\n            state = self.activation(self.ln[i](state))  # b k d\n            probs = F.log_softmax(self.head[i](state), dim=-1)  # b k v\n            _probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1)  # b k k'\n\n            # Update candidate set with new predictions\n\n            # Update distribution set with new logits\n 
           all_probs[:, i] = probs.exp()\n\n            # Update state, log_probs and ind for new predictions\n            state = state.unsqueeze(2).expand(\n                -1, -1, top_k_tokens_per_head[i], -1\n            )  # b k k' d\n            state = state.reshape(-1, b, state.size(3))  # b kk' d\n            ind = preds.view(-1, b)  # b kk'\n\n        speculative_logits = all_probs\n        return speculative_logits\n\n\nclass MLPSpeculatorHead(nn.Module):\n    def __init__(self, lm_head, mlp_speculator, scale_input: bool):\n        super().__init__()\n        self.lm_head = lm_head\n        self.mlp_speculator = mlp_speculator\n        self.scale_input = scale_input\n\n    def forward(\n        self, input: torch.Tensor\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        logits = self.lm_head(input)\n        # If we have too many tokens, we skip speculative logits\n        if input.shape[0] > 128:\n            return logits, None\n\n        input_ids = logits.argmax(dim=-1)\n        if self.scale_input:\n            input = simple_norm(input)\n        speculative_logits = self.mlp_speculator(input, input_ids)\n        return logits, speculative_logits\n\n    @staticmethod\n    def load(config, prefix: str, weights):\n        from pathlib import Path\n        from safetensors import safe_open\n\n        speculator_path = config.speculator[\"path\"]\n\n        for fname in config.speculator[\"model_paths\"]:\n            filename = str(Path(speculator_path) / fname)\n            routing = weights.routing\n            with safe_open(filename, framework=\"pytorch\") as f:\n                for k in f.keys():\n                    if k in routing and routing[k] != filename:\n                        raise RuntimeError(\n                            f\"Key {k} was found in multiple files: {filename} and {routing[k]}\"\n                        )\n                    routing[k] = filename\n\n        tie_weights = 
config.speculator_config.get(\"tie_weights\", False)\n        if tie_weights:\n            mlp_speculator = MLPSpeculatorModelTied(config, \"speculator\", weights)\n        else:\n            mlp_speculator = MLPSpeculatorModel(config, \"speculator\", weights)\n        # This is used in https://huggingface.co/ibm-fms/llama3-70b-accelerator\n        scale_input = config.speculator_config.get(\"scale_input\", False)\n        lm_head = TensorParallelHead.load(config, prefix, weights)\n        return MLPSpeculatorHead(lm_head, mlp_speculator, scale_input)\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/layers/moe/__init__.py",
    "content": "from typing import Optional, Protocol, runtime_checkable\n\nimport torch\nimport torch.nn as nn\nfrom loguru import logger\nfrom transformers.activations import ACT2FN\n\nfrom text_generation_server.layers import (\n    TensorParallelColumnLinear,\n    TensorParallelRowLinear,\n)\nfrom text_generation_server.layers.fp8 import HybridFP8UnquantLoader\nfrom text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer\nfrom text_generation_server.layers.moe.fp8 import FP8SparseMoELayer\nfrom text_generation_server.utils.log import log_once\nfrom text_generation_server.utils.weights import (\n    DefaultWeightsLoader,\n    Weights,\n    UnquantizedWeight,\n)\n\nfrom .fused_moe import fused_topk, grouped_topk\n\n# NOTE: we are using a protocol here, because multiple inherance is not nice.\n#       We need `Module`, and `Module` -> some abstract class -> some concrete\n#       class inheritance is whacky.\n\n\n@runtime_checkable\nclass MoELayer(Protocol):\n    def __init__(\n        self,\n        *,\n        n_expert_group: Optional[int],\n        n_experts: int,\n        prefix: str,\n        renormalize: bool,\n        topk: int,\n        topk_group: Optional[int],\n        weights: Weights,\n        gate_proj_name: str = \"gate_proj\",\n        up_proj_name: str = \"up_proj\",\n        down_proj_name: str = \"down_proj\",\n        hidden_act: str = \"silu\",\n        scoring_func: Optional[str] = None,\n        e_score_correction_bias: Optional[float] = None,\n    ): ...\n\n    def forward(\n        self, x: torch.Tensor, *, gating_output: torch.Tensor\n    ) -> torch.Tensor: ...\n\n\nclass DenseMoELayer(nn.Module):\n    \"\"\"\n    Layer for MoE that applies *all* experts to each tokens and then weights\n    their outputs based on the calculated routing. This layer is much slower\n    than `SparseMoELayer` and should only be used when no fused kernels are\n    available (e.g. 
for unsupported quantizers).\n    \"\"\"\n\n    def __init__(\n        self,\n        *,\n        n_expert_group: Optional[int],\n        n_experts: int,\n        prefix: str,\n        renormalize: bool,\n        topk: int,\n        topk_group: Optional[int],\n        weights: Weights,\n        gate_proj_name: str = \"gate_proj\",\n        up_proj_name: str = \"up_proj\",\n        down_proj_name: str = \"down_proj\",\n        hidden_act: str = \"silu\",\n        scoring_func: Optional[str] = None,\n        e_score_correction_bias: Optional[float] = None,\n    ):\n        super().__init__()\n\n        assert scoring_func is None, \"scoring func is not handled\"\n        assert e_score_correction_bias is None, \"scoring correction bias is not handled\"\n\n        log_once(\n            logger.info,\n            \"No fused layers are available for this model type, using (slower) dense MoE layer\",\n        )\n\n        assert (n_expert_group is None) == (\n            topk_group is None\n        ), \"n_expert_group and topk_group must both be None or have some value\"\n\n        self.n_expert_group = n_expert_group\n        self.n_experts = n_experts\n        self.renormalize = renormalize\n        self.topk = topk\n        self.topk_group = topk_group\n\n        if \"gelu\" in hidden_act:\n            self.act = lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\"\n                    if hidden_act in [\"gelu_fast\", \"gelu_pytorch_tanh\"]\n                    else \"none\"\n                ),\n            )\n        elif \"silu\" in hidden_act:\n            self.act = torch.nn.functional.silu\n        else:\n            self.act = ACT2FN[hidden_act]\n\n        self.gate_proj = [\n            TensorParallelColumnLinear.load(\n                None,\n                prefix=f\"{prefix}.{i}.{gate_proj_name}\",\n                weights=weights,\n                bias=False,\n            )\n            for i in 
range(self.n_experts)\n        ]\n        self.up_proj = [\n            TensorParallelColumnLinear.load(\n                None,\n                prefix=f\"{prefix}.{i}.{up_proj_name}\",\n                weights=weights,\n                bias=False,\n            )\n            for i in range(self.n_experts)\n        ]\n        self.down_proj = [\n            TensorParallelRowLinear.load(\n                None,\n                prefix=f\"{prefix}.{i}.{down_proj_name}\",\n                weights=weights,\n                bias=False,\n            )\n            for i in range(self.n_experts)\n        ]\n\n        self.process_group = weights.process_group\n\n    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        x: (sequence_length, model_dim)\n        gating_output: (sequence_length, n_experts)\n        \"\"\"\n        # optional reshape\n        input_shape = x.shape\n        x = x.view(-1, input_shape[-1])\n\n        if self.n_expert_group is not None and self.topk_group is not None:\n            topk_weights, topk_ids = grouped_topk(\n                x,\n                gating_output,\n                self.topk,\n                renormalize=self.renormalize,\n                num_expert_group=self.n_expert_group,\n                topk_group=self.topk_group,\n            )\n        else:\n            topk_weights, topk_ids = fused_topk(\n                x, gating_output, self.topk, self.renormalize\n            )\n            topk_weights = topk_weights.to(x.dtype)\n\n        weights = torch.zeros(\n            topk_ids.shape[0], self.n_experts, dtype=x.dtype, device=x.device\n        )\n\n        weights.scatter_(1, topk_ids.long(), topk_weights.to(weights.dtype))\n\n        out = torch.zeros_like(x)\n        for i in range(self.n_experts):\n            h = self.act(self.gate_proj[i](x)) * self.up_proj[i](x)\n            h = self.down_proj[i](h, reduce=False)\n            out += h * weights[:, i].view(-1, 
1)\n\n        return out\n\n\nclass SparseMoELayer(nn.Module):\n    \"\"\"\n    Layer for MoE that uses fused kernels to only apply the active experts\n    for each token (rather than applying all experts and selecting the\n    outputs of active experts).\n    \"\"\"\n\n    def __init__(\n        self,\n        *,\n        n_expert_group: Optional[int],\n        n_experts: int,\n        prefix: str,\n        renormalize: bool,\n        topk: int,\n        topk_group: Optional[int],\n        weights: Weights,\n        scoring_func: Optional[str] = \"softmax\",\n        e_score_correction_bias: Optional[float] = None,\n        gate_proj_name: str = \"gate_proj\",\n        up_proj_name: str = \"up_proj\",\n        down_proj_name: str = \"down_proj\",\n    ):\n        super().__init__()\n        if (\n            isinstance(weights.loader, DefaultWeightsLoader)\n            and isinstance(weights.loader.weight_class, UnquantizedWeight)\n        ) or isinstance(weights.loader, HybridFP8UnquantLoader):\n            if (\n                isinstance(weights.loader, HybridFP8UnquantLoader)\n                and weights.loader.to_fp8\n            ):\n                cls = FP8SparseMoELayer\n            else:\n                cls = UnquantizedSparseMoELayer\n        else:\n            raise ValueError(\n                f\"Unsupported weights loader: {type(weights.loader)}, sparse MoE is only supported for unquantized, AWQ, and GPTQ weights\"\n            )\n\n        log_once(\n            logger.info,\n            \"Using MoE layer wih fused gemm\",\n        )\n\n        self.moe = cls(\n            n_expert_group=n_expert_group,\n            n_experts=n_experts,\n            prefix=prefix,\n            renormalize=renormalize,\n            topk=topk,\n            topk_group=topk_group,\n            weights=weights,\n            scoring_func=scoring_func,\n            e_score_correction_bias=e_score_correction_bias,\n            gate_proj_name=gate_proj_name,\n            
up_proj_name=up_proj_name,\n            down_proj_name=down_proj_name,\n        )\n\n    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:\n        return self.moe(x, gating_output=gating_output)\n\n    @staticmethod\n    def is_supported(weights: Weights) -> bool:\n        return (\n            isinstance(weights.loader, DefaultWeightsLoader)\n            and isinstance(weights.loader.weight_class, UnquantizedWeight)\n        ) or isinstance(weights.loader, HybridFP8UnquantLoader)\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/layers/moe/fp8.py",
    "content": "from typing import Optional\n\nimport torch\nimport torch.nn as nn\nimport os\n\nfrom text_generation_server.utils.weights import Weights\nfrom text_generation_server.layers.fp8 import (\n    Fp8Weight,\n    fp8_quantize,\n    quant_dtype,\n    normalize_e4m3fn_to_native_float8,\n    dynamic_quant,\n    dequant_block_fp8_weight_naive,\n)\nfrom text_generation_server.layers.moe.fused_moe import select_experts\nimport habana_frameworks.torch as htorch\n\n\nclass FP8SparseMoELayer(nn.Module):\n    def __init__(\n        self,\n        *,\n        n_expert_group: Optional[int],\n        n_experts: int,\n        prefix: str,\n        renormalize: bool,\n        topk: int,\n        topk_group: Optional[int],\n        weights: Weights,\n        scoring_func: Optional[str] = \"softmax\",\n        e_score_correction_bias: Optional[float] = None,\n        gate_proj_name: str = \"gate_proj\",\n        up_proj_name: str = \"up_proj\",\n        down_proj_name: str = \"down_proj\",\n    ):\n        super().__init__()\n\n        assert (n_expert_group is None) == (\n            topk_group is None\n        ), \"n_expert_group and topk_group must both be None or have some value\"\n\n        self.n_expert_group = n_expert_group\n        self.topk = topk\n        self.topk_group = topk_group\n        self.renormalize = renormalize\n        self.weight_block_size = weights.weights_loader.weight_block_size\n        self.scoring_func = scoring_func\n        self.e_score_correction_bias = e_score_correction_bias\n        self.world_size = weights.process_group.size()\n        self.rank = weights.process_group.rank()\n        self.ep_rank = self.rank\n        self.use_ep = os.getenv(\"USE_EXPERT_PARALLEL\", \"true\").lower() == \"true\"\n        if (n_experts + self.world_size - 1) // self.world_size < 4:\n            self.use_ep = False\n        if self.use_ep:\n            n_experts_per_rank = (n_experts + self.world_size - 1) // self.world_size\n            
self.ep_offset = self.ep_rank * n_experts_per_rank\n            n_experts = min(n_experts_per_rank, n_experts - self.ep_offset)\n        else:\n            self.ep_offset = 0\n\n        (\n            self.gate_up_proj,\n            self.gate_up_proj_weight_scale,\n            self.gate_up_proj_input_scale,\n        ) = _load_expert_multi_weights_col(\n            prefix=prefix,\n            n_experts=n_experts,\n            gate_proj_name=gate_proj_name,\n            up_proj_name=up_proj_name,\n            weights=weights,\n            use_ep=self.use_ep,\n            ep_offset=self.ep_offset,\n        )\n\n        self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = (\n            _load_expert_weights_row(\n                prefix=prefix,\n                n_experts=n_experts,\n                name=down_proj_name,\n                weights=weights,\n                use_ep=self.use_ep,\n                ep_offset=self.ep_offset,\n            )\n        )\n        if self.weight_block_size is not None:\n            self.gate_up_proj, self.gate_up_proj_weight_scale = dynamic_quant(\n                dequant_block_fp8_weight_naive(\n                    self.gate_up_proj,\n                    self.gate_up_proj_weight_scale,\n                    self.weight_block_size,\n                )\n            )\n            self.down_proj, self.down_proj_weight_scale = dynamic_quant(\n                dequant_block_fp8_weight_naive(\n                    self.down_proj, self.down_proj_weight_scale, self.weight_block_size\n                )\n            )\n            self.gate_up_proj_weight_scale, self.down_proj_weight_scale = (\n                self.gate_up_proj_weight_scale.squeeze(-1),\n                self.down_proj_weight_scale.squeeze(-1),\n            )\n\n    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:\n        topk_weights, topk_ids = select_experts(\n            hidden_states=x,\n            
router_logits=gating_output,\n            use_grouped_topk=self.n_expert_group is not None,\n            top_k=self.topk,\n            renormalize=self.renormalize,\n            topk_group=self.topk_group,\n            num_expert_group=self.n_expert_group,\n            scoring_func=self.scoring_func,\n            e_score_correction_bias=self.e_score_correction_bias,\n        )\n        total_num_experts = gating_output.size(-1)\n        x_fp8, x_scale = dynamic_quant(x, single_scale=True)\n\n        if self.use_ep:\n            moe_n_slice = 1\n            n_expert_slice = (\n                total_num_experts + self.world_size - 1\n            ) // self.world_size\n        else:\n            moe_n_slice = 1\n            n_expert_slice = (total_num_experts + moe_n_slice - 1) // moe_n_slice\n        for i in range(moe_n_slice):\n            min_expert = i * n_expert_slice\n            max_expert = min((i + 1) * n_expert_slice, total_num_experts)\n            w13_list_slice = [\n                self.gate_up_proj[j, ...] for j in range(min_expert, max_expert)\n            ]\n            w2_list_slice = [\n                self.down_proj[j, ...] 
for j in range(min_expert, max_expert)\n            ]\n            w13_weight_scale = [\n                self.gate_up_proj_weight_scale[j, ...]\n                for j in range(min_expert, max_expert)\n            ]\n            w2_weight_scale = [\n                self.down_proj_weight_scale[j, ...]\n                for j in range(min_expert, max_expert)\n            ]\n\n            current_hidden_states = torch.ops.hpu.mixture_of_experts(\n                hidden_states=x_fp8,\n                expert_routing_table=topk_ids.to(torch.int64),\n                router_weights=topk_weights.to(x.dtype),\n                w12=w13_list_slice,\n                w3=w2_list_slice,\n                d_scale_hidden_states=x_scale,\n                d_scale_w12=w13_weight_scale,\n                d_scale_w3=w2_weight_scale,\n                permuted_weights=True,\n                activation=\"silu\",\n                experts_min=min_expert + self.ep_offset,\n                experts_max=max_expert + self.ep_offset - 1,\n            )\n            htorch.core.mark_step()\n            if i == 0:\n                final_hidden_states = current_hidden_states\n            else:\n                final_hidden_states.add_(current_hidden_states)\n        return final_hidden_states\n\n\ndef _load_expert_weights(\n    get_weight_fn,\n    *,\n    prefix: str,\n    n_experts: int,\n    name: str,\n    weights: Weights,\n    ep_offset: int = 0,\n) -> torch.Tensor:\n    all_weight = None\n    all_weight_scales = None\n    max_input_scale = None\n\n    for i in range(n_experts):\n        weight = get_weight_fn(prefix, i + ep_offset, name, weights)\n\n        assert isinstance(weight, Fp8Weight)\n\n        if all_weight is None:\n            all_weight = torch.empty(\n                (n_experts,) + weight.weight.shape,\n                dtype=quant_dtype,\n                device=weight.weight.device,\n            )\n        if all_weight_scales is None:\n            all_weight_scales = torch.empty(\n    
            (n_experts,) + weight.weight_scale.shape,\n                dtype=torch.float32,\n                device=weight.weight.device,\n            )\n\n        if weight.weight.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz}:\n            all_weight[i], all_weight_scales[i], current_input_scale = (\n                normalize_e4m3fn_to_native_float8(\n                    weight.weight, weight.weight_scale, weight.input_scale\n                )\n            )\n            if current_input_scale is not None:\n                if max_input_scale is None or current_input_scale > max_input_scale:\n                    max_input_scale = current_input_scale\n        else:\n            all_weight[i], all_weight_scales[i] = fp8_quantize(\n                weight.weight, scalar=True\n            )\n\n    assert all_weight is not None\n\n    return all_weight, all_weight_scales, max_input_scale\n\n\ndef _load_expert_multi_weights_col(\n    *,\n    prefix: str,\n    n_experts: int,\n    gate_proj_name: str,\n    up_proj_name: str,\n    weights: Weights,\n    use_ep: bool = False,\n    ep_offset: int = 0,\n) -> torch.Tensor:\n    def get_weight_fn_sharded(prefix, i, name, weights):\n        return weights.get_multi_weights_col(\n            [f\"{prefix}.{i}.{gate_proj_name}\", f\"{prefix}.{i}.{up_proj_name}\"], 0\n        )\n\n    def get_weight_fn(prefix, i, name, weights):\n        return weights.get_multi_weights(\n            [f\"{prefix}.{i}.{gate_proj_name}\", f\"{prefix}.{i}.{up_proj_name}\"], 0\n        )\n\n    return _load_expert_weights(\n        get_weight_fn if use_ep else get_weight_fn_sharded,\n        prefix=prefix,\n        n_experts=n_experts,\n        name=None,\n        weights=weights,\n        ep_offset=ep_offset if use_ep else 0,\n    )\n\n\ndef _load_expert_weights_row(\n    *,\n    prefix: str,\n    n_experts: int,\n    name: str,\n    weights: Weights,\n    use_ep: bool = False,\n    ep_offset: int = 0,\n) -> torch.Tensor:\n    def 
get_weight_fn_sharded(prefix, i, name, weights):\n        return weights.get_weights_row(f\"{prefix}.{i}.{name}\")\n\n    def get_weight_fn(prefix, i, name, weights):\n        return weights.get_weights(f\"{prefix}.{i}.{name}\")\n\n    return _load_expert_weights(\n        get_weight_fn if use_ep else get_weight_fn_sharded,\n        prefix=prefix,\n        n_experts=n_experts,\n        name=name,\n        weights=weights,\n        ep_offset=ep_offset if use_ep else 0,\n    )\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py",
    "content": "# coding=utf-8\n# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Tuple, Optional\n\nimport torch\n\n\ndef grouped_topk(\n    hidden_states: torch.Tensor,\n    gating_output: torch.Tensor,\n    topk: int,\n    renormalize: bool,\n    num_expert_group: int = 0,\n    topk_group: int = 0,\n    scoring_func: str = \"softmax\",\n    e_score_correction_bias: Optional[torch.Tensor] = None,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    assert hidden_states.shape[0] == gating_output.shape[0], \"Number of tokens mismatch\"\n\n    gating_output = gating_output.float()\n    if e_score_correction_bias is not None:\n        e_score_correction_bias = e_score_correction_bias.float()\n\n    if scoring_func == \"softmax\":\n        scores = torch.softmax(gating_output, dim=-1)\n    elif scoring_func == \"sigmoid\":\n        scores = gating_output.sigmoid()\n    else:\n        raise ValueError(f\"Unsupported scoring function: {scoring_func}\")\n\n    num_token = scores.shape[0]\n    if e_score_correction_bias is not None:\n        # Store original scores before applying correction bias. 
We use biased\n        # scores for expert selection but original scores for routing weights\n        original_scores = scores\n        scores = scores + e_score_correction_bias.unsqueeze(0)\n        group_scores = (\n            scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)\n        )\n    else:\n        group_scores = (\n            scores.view(num_token, num_expert_group, -1).max(dim=-1).values\n        )  # [n, n_group]\n\n    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[\n        1\n    ]  # [n, top_k_group]\n    group_mask = torch.zeros_like(group_scores)  # [n, n_group]\n    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]\n    score_mask = (\n        group_mask.unsqueeze(-1)\n        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)\n        .reshape(num_token, -1)\n    )  # [n, e]\n    tmp_scores = scores.masked_fill(~score_mask.bool(), float(\"-inf\"))  # [n, e]\n\n    if e_score_correction_bias is not None:\n        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]\n        # Use original unbiased scores for the routing weights\n        topk_weights = original_scores.gather(1, topk_ids)\n    else:\n        topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)\n\n    if renormalize:\n        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)\n\n    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)\n\n\ndef fused_topk(\n    hidden_states: torch.Tensor,\n    gating_output: torch.Tensor,\n    topk: int,\n    renormalize: bool,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    topk_weights = torch.nn.functional.softmax(\n        gating_output, dim=1, dtype=torch.float32\n    )\n    topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)\n    if renormalize:\n        topk_weights /= topk_weights.sum(dim=-1, keepdim=True)\n    return topk_weights, topk_ids\n\n\ndef select_experts(\n    hidden_states: 
torch.Tensor,\n    router_logits: torch.Tensor,\n    top_k: int,\n    use_grouped_topk: bool,\n    renormalize: bool,\n    topk_group: Optional[int] = None,\n    num_expert_group: Optional[int] = None,\n    scoring_func: str = \"softmax\",\n    e_score_correction_bias: Optional[torch.Tensor] = None,\n):\n\n    # DeekSeekv2 uses grouped_top_k\n    if use_grouped_topk:\n        assert topk_group is not None\n        assert num_expert_group is not None\n        topk_weights, topk_ids = grouped_topk(\n            hidden_states=hidden_states,\n            gating_output=router_logits,\n            topk=top_k,\n            renormalize=renormalize,\n            num_expert_group=num_expert_group,\n            topk_group=topk_group,\n            scoring_func=scoring_func,\n            e_score_correction_bias=e_score_correction_bias,\n        )\n    else:\n        topk_weights, topk_ids = fused_topk(\n            hidden_states=hidden_states,\n            gating_output=router_logits,\n            topk=top_k,\n            renormalize=renormalize,\n        )\n    return topk_weights, topk_ids\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/layers/moe/unquantized.py",
    "content": "from typing import Optional\n\nimport torch\nimport torch.nn as nn\n\nfrom text_generation_server.utils.weights import UnquantizedWeight, Weights\nfrom vllm_hpu_extension.ops import VllmMixtureOfExpertsOp\nimport habana_frameworks.torch as htorch\nimport torch.nn.functional as F\nimport os\n\n\nclass UnquantizedSparseMoELayer(nn.Module):\n    def __init__(\n        self,\n        *,\n        n_expert_group: Optional[int],\n        n_experts: int,\n        prefix: str,\n        renormalize: bool,\n        topk: int,\n        topk_group: Optional[int],\n        weights: Weights,\n        scoring_func: Optional[str] = \"softmax\",\n        e_score_correction_bias: Optional[float] = None,\n        gate_proj_name: str = \"gate_proj\",\n        up_proj_name: str = \"up_proj\",\n        down_proj_name: str = \"down_proj\",\n    ):\n        super().__init__()\n\n        assert (n_expert_group is None) == (\n            topk_group is None\n        ), \"n_expert_group and topk_group must both be None or have some value\"\n\n        self.n_expert_group = n_expert_group\n        self.topk = topk\n        self.topk_group = topk_group\n        self.renormalize = renormalize\n        self.weight_block_size = weights.weights_loader.weight_block_size\n        self.scoring_func = scoring_func\n        self.e_score_correction_bias = e_score_correction_bias\n        self.rank = weights.process_group.rank()\n        self.world_size = weights.process_group.size()\n        self.use_ep = os.getenv(\"USE_EXPERT_PARALLEL\", \"true\").lower() == \"true\"\n        if (n_experts + self.world_size - 1) // self.world_size < 4:\n            self.use_ep = False\n        if self.use_ep:\n            n_experts_per_rank = (n_experts + self.world_size - 1) // self.world_size\n            self.ep_offset = self.rank * n_experts_per_rank\n            n_experts = min(n_experts_per_rank, n_experts - self.ep_offset)\n            experts_min = self.ep_offset\n            experts_max = 
self.ep_offset + n_experts - 1\n        else:\n            self.ep_offset = 0\n            experts_min = 0\n            experts_max = n_experts - 1\n\n        self.gate_up_proj = _load_expert_multi_weights_col(\n            prefix=prefix,\n            n_experts=n_experts,\n            gate_proj_name=gate_proj_name,\n            up_proj_name=up_proj_name,\n            weights=weights,\n            use_ep=self.use_ep,\n            ep_offset=self.ep_offset,\n        )\n\n        self.down_proj = _load_expert_weights_row(\n            prefix=prefix,\n            n_experts=n_experts,\n            name=down_proj_name,\n            weights=weights,\n            use_ep=self.use_ep,\n            ep_offset=self.ep_offset,\n        )\n\n        self.MoeOp = VllmMixtureOfExpertsOp(n_experts, experts_min, experts_max)\n        for i in range(n_experts):\n            self.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])\n            self.MoeOp.w2_list[i].set_weight(self.down_proj[i])\n\n    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:\n        htorch.core.mark_step()\n        routing_weights = F.softmax(gating_output, dim=1, dtype=torch.float32)\n        routing_weights, selected_experts = torch.topk(\n            routing_weights, self.topk, dim=-1\n        )\n        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)\n        routing_weights = routing_weights.to(x.dtype)\n\n        final_hidden_states = self.MoeOp(\n            hidden_states=x,\n            expert_routing_table=selected_experts,\n            router_weights=routing_weights,\n            permuted_weights=True,\n            activation=\"silu\",\n        )\n\n        return final_hidden_states.view(-1, x.shape[1])\n\n\ndef _load_expert_multi_weights_col(\n    *,\n    prefix: str,\n    n_experts: int,\n    gate_proj_name: str,\n    up_proj_name: str,\n    weights: Weights,\n    use_ep: bool = False,\n    ep_offset: int = 0,\n) -> torch.Tensor:\n    all_weight = 
None\n    for i in range(n_experts):\n        if not use_ep:\n            weight = weights.get_multi_weights_col(\n                [f\"{prefix}.{i}.{gate_proj_name}\", f\"{prefix}.{i}.{up_proj_name}\"], 0\n            )\n        else:\n            weight = weights.get_multi_weights(\n                [\n                    f\"{prefix}.{i+ep_offset}.{gate_proj_name}\",\n                    f\"{prefix}.{i+ep_offset}.{up_proj_name}\",\n                ],\n                0,\n            )\n\n        assert isinstance(weight, UnquantizedWeight)\n\n        if all_weight is None:\n            all_weight = torch.empty(\n                (n_experts,) + weight.weight.shape,\n                dtype=weight.weight.dtype,\n                device=weight.weight.device,\n            )\n\n        all_weight[i] = weight.weight\n\n    assert all_weight is not None\n\n    return all_weight\n\n\ndef _load_expert_weights_row(\n    *,\n    prefix: str,\n    n_experts: int,\n    name: str,\n    weights: Weights,\n    use_ep: bool = False,\n    ep_offset: int = 0,\n) -> torch.Tensor:\n    all_weight = None\n    for i in range(n_experts):\n        if not use_ep:\n            weight = weights.get_weights_row(\n                f\"{prefix}.{i}.{name}\",\n            )\n        else:\n            weight = weights.get_weights(\n                f\"{prefix}.{i+ep_offset}.{name}\",\n            )\n\n        assert isinstance(weight, UnquantizedWeight)\n\n        if all_weight is None:\n            all_weight = torch.empty(\n                (n_experts,) + weight.weight.shape,\n                dtype=weight.weight.dtype,\n                device=weight.weight.device,\n            )\n\n        all_weight[i] = weight.weight\n\n    assert all_weight is not None\n\n    return all_weight\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/layers/rotary.py",
    "content": "import os\nimport math\nimport torch\nfrom torch import nn\nfrom habana_frameworks.torch.hpex.kernels import (\n    RotaryPosEmbeddingMode,\n    apply_rotary_pos_emb,\n)\n\n\ndef _create_inv_freq(dim, base, device):\n    inv_freq = 1.0 / (\n        base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)\n    )\n    return inv_freq\n\n\ndef _get_rope_config(config):\n    if os.getenv(\"ROPE_SCALING\", None) is not None:\n        rope_scaling = {\n            \"type\": os.environ[\"ROPE_SCALING\"],\n            \"factor\": float(os.environ[\"ROPE_FACTOR\"]),\n        }\n        return rope_scaling\n    return getattr(config, \"rope_scaling\", None)\n\n\nclass PositionRotaryEmbedding(nn.Module):\n    def __init__(self, inv_freq, scaling_factor, max_position_embeddings):\n        super().__init__()\n        self.inv_freq = inv_freq\n        self._seq_len_cached = 0\n        self._cos_cached = None\n        self._sin_cached = None\n        self._cos_k_cached = None\n        self._sin_k_cached = None\n        self.scaling_factor = scaling_factor\n        self.dynamic_args = None\n        self._update_cos_sin_cache(\n            torch.float32, inv_freq.device, max_position_embeddings\n        )\n\n    def forward(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        cos: torch.Tensor,\n        sin: torch.Tensor,\n    ):\n        num_tokens = query.shape[0]\n        head_size = query.shape[-1]\n        # HPU RoPE kernel requires hidden dimension for cos and sin to be equal\n        # to query hidden dimension, so the original tensors need to be\n        # expanded\n        # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE\n        # and expansion of cos/sin tensors via concatenation\n        rope_mode = RotaryPosEmbeddingMode.BLOCKWISE\n        cos = torch.cat((cos, cos), dim=-1)\n        sin = torch.cat((sin, sin), dim=-1)\n        rotary_dim = cos.shape[-1]\n        query_shape = 
query.shape\n        query = query.view(num_tokens, -1, head_size)\n        query_rot = query[..., :rotary_dim]\n        query_pass = query[..., rotary_dim:]\n        query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)\n        query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape))\n\n        key_shape = key.shape\n        key = key.view(num_tokens, -1, head_size)\n        key_rot = key[..., :rotary_dim]\n        key_pass = key[..., rotary_dim:]\n        key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)\n        key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape))\n\n    @classmethod\n    def static(cls, config, dim, base, device):\n        inv_freq = _create_inv_freq(dim, base, device)\n        scaling_factor = None\n        rope_scaling = _get_rope_config(config)\n        if not hasattr(config, \"max_position_embeddings\") and hasattr(\n            config, \"max_seq_len\"\n        ):\n            # handling for dbrx\n            config.max_position_embeddings = config.max_seq_len\n        if rope_scaling is not None:\n            # `rope_type` is now standard in transformers, but some existing models\n            # have `type` instead.\n            rope_type = rope_scaling.get(\"rope_type\", rope_scaling.get(\"type\", None))\n\n            if rope_type == \"linear\":\n                pass\n            elif rope_type == \"default\":\n                pass\n            elif rope_type == \"mrope\":\n                mrope_section = rope_scaling[\"mrope_section\"]\n                if mrope_section is not None:\n                    return RotaryPositionEmbeddingMultimodalSections(\n                        inv_freq,\n                        scaling_factor,\n                        mrope_section,\n                        config.max_position_embeddings,\n                    )\n            elif rope_type == \"dynamic\":\n                scaling_factor = rope_scaling[\"factor\"]\n                
return DynamicPositionRotaryEmbedding(\n                    dim=dim,\n                    max_position_embeddings=config.max_position_embeddings,\n                    base=base,\n                    device=inv_freq.device,\n                    scaling_factor=scaling_factor,\n                )\n            elif rope_type == \"llama3\":\n                inv_freq = apply_llama3_scaling(\n                    inv_freq,\n                    scaling_factor=rope_scaling[\"factor\"],\n                    low_freq_factor=rope_scaling[\"low_freq_factor\"],\n                    high_freq_factor=rope_scaling[\"high_freq_factor\"],\n                    original_max_position_embeddings=rope_scaling[\n                        \"original_max_position_embeddings\"\n                    ],\n                )\n\n                return cls(inv_freq, scaling_factor, config.max_position_embeddings)\n\n            elif rope_type == \"yarn\":\n                scaling_factor = rope_scaling[\"factor\"]\n                mscale = rope_scaling.get(\"mscale\", 1.0)\n                mscale_all_dim = rope_scaling.get(\"mscale_all_dim\", 0.0)\n                return YarnPositionRotaryEmbedding(\n                    dim=2 * inv_freq.shape[0],\n                    max_position_embeddings=rope_scaling[\n                        \"original_max_position_embeddings\"\n                    ],\n                    base=base,\n                    device=inv_freq.device,\n                    scaling_factor=scaling_factor,\n                    extrapolation_factor=1,\n                    attn_factor=1,\n                    beta_fast=32,\n                    beta_slow=1,\n                    mscale=mscale,\n                    mscale_all_dim=mscale_all_dim,\n                )\n            elif rope_type in [\"su\", \"longrope\"]:\n                short_factor = torch.tensor(\n                    rope_scaling[\"short_factor\"], dtype=torch.float32, device=device\n                )\n                short_inv_freq = 
1.0 / (\n                    short_factor\n                    * base\n                    ** (\n                        torch.arange(0, dim, 2, device=device, dtype=torch.float32)\n                        / dim\n                    )\n                )\n                long_factor = torch.tensor(\n                    rope_scaling[\"long_factor\"], dtype=torch.float32, device=device\n                )\n                long_inv_freq = 1.0 / (\n                    long_factor\n                    * base\n                    ** (\n                        torch.arange(0, dim, 2, device=device, dtype=torch.float32)\n                        / dim\n                    )\n                )\n\n                original_max_position_embeddings = (\n                    config.original_max_position_embeddings\n                )\n                max_position_embeddings = config.max_position_embeddings\n                if max_position_embeddings <= original_max_position_embeddings:\n                    scaling_factor = 1.0\n                else:\n                    scale = max_position_embeddings / original_max_position_embeddings\n                    scaling_factor = math.sqrt(\n                        1 + math.log(scale) / math.log(original_max_position_embeddings)\n                    )\n\n                # if short_mscale and long_mscale are provided we need to scale the freqs\n                # using the Phi3LongRoPEScaledRotaryEmbedding\n                if (\"short_mscale\" in rope_scaling) and (\"long_mscale\" in rope_scaling):\n                    short_mscale = rope_scaling[\"short_mscale\"]\n                    long_mscale = rope_scaling[\"long_mscale\"]\n                    return Phi3LongRoPEScaledRotaryEmbedding(\n                        short_inv_freq=short_inv_freq,\n                        long_inv_freq=long_inv_freq,\n                        max_position_embeddings=config.max_position_embeddings,\n                        short_mscale=short_mscale,\n              
          long_mscale=long_mscale,\n                        original_max_position_embeddings=original_max_position_embeddings,\n                    )\n\n                return SuRotaryEmbedding(\n                    short_inv_freq=short_inv_freq,\n                    long_inv_freq=long_inv_freq,\n                    scaling_factor=scaling_factor,\n                    original_max_position_embeddings=original_max_position_embeddings,\n                    max_position_embeddings=config.max_position_embeddings,\n                )\n            else:\n                raise NotImplementedError(\n                    f\"rope scaling type {rope_scaling['type']} is not implemented or invalid\"\n                )\n        return cls(inv_freq, scaling_factor, config.max_position_embeddings)\n\n    @classmethod\n    def load(cls, config, prefix, weights):\n        # XXX: Always load this in float32 !\n        dtype = weights.dtype\n        weights.dtype = torch.float32\n        inv_freq = weights.get_tensor(f\"{prefix}.inv_freq\")\n        weights.dtype = dtype\n\n        scaling_factor = None\n        rope_scaling = _get_rope_config(config)\n        if rope_scaling is not None:\n            scaling_factor = rope_scaling[\"factor\"]\n            if rope_scaling[\"type\"] == \"linear\":\n                pass\n            elif rope_scaling[\"type\"] == \"dynamic\":\n                return DynamicPositionRotaryEmbedding(\n                    dim=2 * inv_freq.shape[0],\n                    max_position_embeddings=config.max_position_embeddings,\n                    base=10000.0,\n                    device=inv_freq.device,\n                    scaling_factor=scaling_factor,\n                )\n            elif rope_scaling[\"type\"] == \"yarn\":\n                mscale = rope_scaling.get(\"mscale\", 1.0)\n                mscale_all_dim = rope_scaling.get(\"mscale_all_dim\", 0.0)\n                return YarnPositionRotaryEmbedding(\n                    dim=2 * inv_freq.shape[0],\n   
                 max_position_embeddings=rope_scaling[\n                        \"original_max_position_embeddings\"\n                    ],\n                    base=10000.0,\n                    device=inv_freq.device,\n                    scaling_factor=scaling_factor,\n                    extrapolation_factor=1,\n                    attn_factor=1,\n                    beta_fast=32,\n                    beta_slow=1,\n                    mscale=mscale,\n                    mscale_all_dim=mscale_all_dim,\n                )\n            else:\n                raise NotImplementedError(\n                    f\"rope scaling type {rope_scaling['type']} is not implemented or invalid\"\n                )\n        return cls(inv_freq, scaling_factor, config.max_position_embeddings)\n\n    def _update_cos_sin_cache(self, dtype, device, seqlen):\n        # Reset the tables if the sequence length has changed,\n        # or if we're on a new device (possibly due to tracing for instance)\n        if (\n            seqlen > self._seq_len_cached\n            or self._cos_cached.device != device\n            or self._cos_cached.dtype != dtype\n        ):\n            self._seq_len_cached = seqlen\n            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)\n            if self.scaling_factor is not None:\n                t /= self.scaling_factor\n            # Don't do einsum, it converts fp32 to fp16\n            # freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n\n            freqs = torch.outer(t, self.inv_freq.to(device=t.device))\n            self._cos_cached = torch.cos(freqs).to(dtype)\n            self._sin_cached = torch.sin(freqs).to(dtype)\n\n    def get_cos_sin(self, position_ids: torch.Tensor):\n\n        cos = torch.index_select(self._cos_cached, 0, position_ids)\n        sin = torch.index_select(self._sin_cached, 0, position_ids)\n\n        # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to 
avoid yet an other controlflow.\n        return cos.unsqueeze(1), sin.unsqueeze(1)\n\n\nclass SuRotaryEmbedding(PositionRotaryEmbedding):\n    def __init__(\n        self,\n        short_inv_freq,\n        long_inv_freq,\n        scaling_factor,\n        original_max_position_embeddings,\n        max_position_embeddings,\n    ):\n        super(PositionRotaryEmbedding, self).__init__()\n        self.short_inv_freq = short_inv_freq\n        self.long_inv_freq = long_inv_freq\n        self.scaling_factor = scaling_factor\n        self.original_max_position_embeddings = original_max_position_embeddings\n        self._seq_len_cached = 0\n        self._cos_cached = None\n        self._sin_cached = None\n        self._cos_k_cached = None\n        self._sin_k_cached = None\n        self.dynamic_args = None\n        self._update_cos_sin_cache(\n            torch.float32, short_inv_freq.device, max_position_embeddings\n        )\n\n    def _update_cos_sin_cache(self, dtype, device, seqlen):\n        # Reset the tables if the sequence length has changed,\n        # or if we're on a new device (possibly due to tracing for instance)\n        if (\n            seqlen > self._seq_len_cached\n            or self._cos_cached is None\n            or self._cos_cached.device != device\n            or self._cos_cached.dtype != dtype\n        ):\n            self._seq_len_cached = seqlen\n\n            t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)\n            short_freqs = torch.outer(\n                t[: self.original_max_position_embeddings],\n                self.short_inv_freq.to(device=t.device),\n            )\n            long_freqs = torch.outer(\n                t[self.original_max_position_embeddings :],\n                self.long_inv_freq.to(device=t.device),\n            )\n\n            freqs = torch.cat([short_freqs, long_freqs])\n\n            self._cos_cached = (torch.cos(freqs) * self.scaling_factor).to(dtype)\n            self._sin_cached = 
(torch.sin(freqs) * self.scaling_factor).to(dtype)\n\n\nclass Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):\n    def __init__(\n        self,\n        short_inv_freq: torch.Tensor,\n        long_inv_freq: torch.Tensor,\n        max_position_embeddings: int,\n        short_mscale: float,\n        long_mscale: float,\n        original_max_position_embeddings: int,\n    ):\n        super(PositionRotaryEmbedding, self).__init__()\n        self.short_inv_freq = short_inv_freq\n        self.long_inv_freq = long_inv_freq\n        self.max_position_embeddings = max_position_embeddings\n        self.short_mscale = short_mscale\n        self.long_mscale = long_mscale\n        self.original_max_position_embeddings = original_max_position_embeddings\n\n        # cache\n        self._seq_len_cached = 0\n        self._cos_cached = None\n        self._sin_cached = None\n        self._cos_k_cached = None\n        self._sin_k_cached = None\n        self.dynamic_args = None\n        self._update_cos_sin_cache(\n            torch.float32, short_inv_freq.device, max_position_embeddings\n        )\n\n    def _update_cos_sin_cache(self, dtype, device, seqlen):\n        if (\n            seqlen > self._seq_len_cached\n            or self._cos_cached is None\n            or self._cos_cached.device != device\n            or self._cos_cached.dtype != dtype\n        ):\n            self._seq_len_cached = seqlen\n            t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)\n\n            short_freqs = torch.outer(\n                t[: self.original_max_position_embeddings],\n                self.short_inv_freq.to(device=t.device),\n            )\n\n            long_freqs = torch.outer(\n                t[self.original_max_position_embeddings :],\n                self.long_inv_freq.to(device=t.device),\n            )\n\n            short_freqs = short_freqs * self.short_mscale\n            long_freqs = long_freqs * self.long_mscale\n\n            freqs = 
torch.empty((seqlen, short_freqs.shape[1]), device=device)\n            freqs[: self.original_max_position_embeddings] = short_freqs\n            freqs[self.original_max_position_embeddings :] = long_freqs\n\n            self._cos_cached = torch.cos(freqs).to(dtype)\n            self._sin_cached = torch.sin(freqs).to(dtype)\n\n\nclass DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):\n    def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):\n        inv_freq = _create_inv_freq(dim, base, device)\n        super().__init__(inv_freq, scaling_factor, max_position_embeddings)\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n\n    def _update_cos_sin_cache(self, dtype, device, seqlen):\n        # Reset the tables if the sequence length has changed,\n        # or if we're on a new device (possibly due to tracing for instance)\n        if (\n            seqlen > self._seq_len_cached\n            or self._cos_cached.device != device\n            or self._cos_cached.dtype != dtype\n        ):\n            if seqlen > self.max_position_embeddings:\n                newbase = self.base * (\n                    (self.scaling_factor * seqlen / self.max_position_embeddings)\n                    - (self.scaling_factor - 1)\n                ) ** (self.dim / (self.dim - 2))\n                self.inv_freq = _create_inv_freq(\n                    self.dim, newbase, self.inv_freq.device\n                )\n            self._seq_len_cached = seqlen\n            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)\n            # Don't do einsum, it converts fp32 to fp16\n            # freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n\n            freqs = torch.outer(t, self.inv_freq.to(device=t.device))\n            self._cos_cached = torch.cos(freqs).to(dtype)\n            self._sin_cached = torch.sin(freqs).to(dtype)\n\n\ndef find_correction_dim(num_rotations, dim, 
base=10000, max_position_embeddings=2048):\n    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (\n        2 * math.log(base)\n    )\n\n\n# Find dim range bounds based on rotations\ndef find_correction_range(\n    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048\n):\n    low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))\n    high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))\n    return max(low, 0), min(high, dim - 1)  # Clamp values just in case\n\n\ndef linear_ramp_mask(min, max, dim):\n    if min == max:\n        max += 0.001  # Prevent singularity\n\n    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)\n    ramp_func = torch.clamp(linear_func, 0, 1)\n    return ramp_func\n\n\ndef get_mscale(scale: float = 1.0, mscale: float = 1.0):\n    if scale <= 1:\n        return 1.0\n    return 0.1 * mscale * math.log(scale) + 1.0\n\n\nclass YarnPositionRotaryEmbedding(PositionRotaryEmbedding):\n    def __init__(\n        self,\n        dim,\n        max_position_embeddings,\n        base,\n        device,\n        scaling_factor,\n        *,\n        extrapolation_factor,\n        attn_factor,\n        beta_fast,\n        beta_slow,\n        mscale: float,\n        mscale_all_dim: float,\n    ):\n        inv_freq = _create_inv_freq(dim, base, device)\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n        self.extrapolation_factor = extrapolation_factor\n        self.attn_factor = attn_factor\n        self.beta_fast = beta_fast\n        self.beta_slow = beta_slow\n        self.mscale_all_dim = mscale_all_dim\n        self.scaling_factor = scaling_factor\n        self.mscale = float(\n            get_mscale(self.scaling_factor, mscale)\n            / get_mscale(self.scaling_factor, mscale_all_dim)\n            * self.attn_factor\n        )  # Get n-d magnitude 
scaling corrected for interpolation\n        super().__init__(inv_freq, scaling_factor, max_position_embeddings)\n\n    def _update_cos_sin_cache(self, dtype, device, seqlen):\n        # Reset the tables if the sequence length has changed,\n        # or if we're on a new device (possibly due to tracing for instance)\n        if (\n            seqlen > self._seq_len_cached\n            or self._cos_cached.device != device\n            or self._cos_cached.dtype != dtype\n        ):\n            if seqlen > self.max_position_embeddings or True:\n                inv_freq_extrapolation = _create_inv_freq(\n                    self.dim, self.base, self.inv_freq.device\n                )\n                freqs = 1.0 / inv_freq_extrapolation\n                inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)\n                low, high = find_correction_range(\n                    self.beta_fast,\n                    self.beta_slow,\n                    self.dim,\n                    self.base,\n                    self.max_position_embeddings,\n                )\n\n                inv_freq_mask = (\n                    1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)\n                ) * self.extrapolation_factor  # Get n-d rotational scaling corrected for extrapolation\n                inv_freq = (\n                    inv_freq_interpolation * (1 - inv_freq_mask)\n                    + inv_freq_extrapolation * inv_freq_mask\n                )\n\n                self.inv_freq = inv_freq\n\n            self._seq_len_cached = seqlen\n            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)\n            # Don't do einsum, it converts fp32 to fp16\n            # freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n\n            freqs = torch.outer(t, self.inv_freq.to(device=t.device))\n            self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)\n            self._sin_cached = (torch.sin(freqs) * 
self.mscale).to(dtype)\n\n\ndef apply_llama3_scaling(\n    freqs: torch.Tensor,\n    *,\n    scaling_factor: int,\n    low_freq_factor: int,\n    high_freq_factor: int,\n    original_max_position_embeddings: int,\n):\n    low_freq_wavelen = original_max_position_embeddings / low_freq_factor\n    high_freq_wavelen = original_max_position_embeddings / high_freq_factor\n    new_freqs = []\n\n    for freq in freqs:\n        wavelen = 2 * math.pi / freq\n\n        if wavelen < high_freq_wavelen:\n            new_freqs.append(freq)\n        elif wavelen > low_freq_wavelen:\n            new_freqs.append(freq / scaling_factor)\n        else:\n            assert low_freq_wavelen != high_freq_wavelen\n            smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (\n                high_freq_factor - low_freq_factor\n            )\n            new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq)\n\n    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)\n\n\nclass RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):\n    def __init__(\n        self,\n        inv_freq: torch.Tensor,\n        scaling_factor: float,\n        sections: list,\n        max_position_embeddings,\n    ):\n        self.sections = sections\n        self._cos_cached = None\n        self._sin_cached = None\n        self.section_indices = (\n            torch.arange(len(self.sections))\n            .repeat_interleave(torch.tensor(self.sections))\n            .view(1, 1, -1)\n            .to(inv_freq.device)\n        )\n        super().__init__(inv_freq, scaling_factor, max_position_embeddings)\n\n    def _update_cos_sin_cache(\n        self, dtype: torch.dtype, device: torch.device, seqlen: int\n    ):\n        # always cache the cos/sin for the full sequence length to avoid\n        # recomputing if the sequence length is smaller than the cached one\n        if (\n            seqlen > self._seq_len_cached\n            or 
self._cos_cached.device != device\n            or self._cos_cached.dtype != dtype\n        ):\n            self._seq_len_cached = seqlen\n            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)\n            freqs = torch.outer(t, self.inv_freq.to(device=t.device))\n            self._cos_cached = torch.cos(freqs).to(dtype)\n            self._sin_cached = torch.sin(freqs).to(dtype)\n            self._sections = self.section_indices.expand(seqlen, -1, -1)\n\n    def get_cos_sin(\n        self,\n        position_ids: torch.Tensor,\n    ):\n        slen = position_ids.shape[0]\n\n        cos = self._cos_cached[position_ids].gather(1, self._sections[:slen])\n        sin = self._sin_cached[position_ids].gather(1, self._sections[:slen])\n        return cos, sin\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/layers/speculative.py",
    "content": "import torch\nimport json\nfrom typing import Tuple, Optional\nfrom text_generation_server.layers.tensor_parallel import TensorParallelHead\nfrom text_generation_server.layers.medusa import MedusaHeadV1, MedusaHeadV2\nfrom text_generation_server.layers.mlp import MLPSpeculatorHead\n\n\nclass SpeculativeHead(torch.nn.Module):\n    def __init__(self, lm_head, speculator):\n        super().__init__()\n        self.head = lm_head\n        self.speculator = speculator\n\n    @staticmethod\n    def load(config, prefix: str, weights):\n        speculator = config.speculator\n        if speculator:\n            speculator_path = config.speculator[\"path\"]\n            speculator_config = str(speculator_path / \"config.json\")\n\n            with open(speculator_config, \"r\") as f:\n                speculator_config = json.load(f)\n\n            config.speculator_config = speculator_config\n            try:\n                architecture = speculator_config[\"architectures\"][0]\n\n                if architecture == \"MLPSpeculatorPreTrainedModel\":\n                    speculator = MLPSpeculatorHead.load(config, prefix, weights)\n                else:\n                    speculator = None\n            except KeyError:\n                try:\n                    speculator = MedusaHeadV1.load(config, prefix, weights)\n                except Exception:\n                    speculator = MedusaHeadV2(config, prefix, weights)\n            lm_head = None\n        else:\n            lm_head = TensorParallelHead.load(config, prefix, weights)\n            speculator = None\n        return SpeculativeHead(lm_head, speculator)\n\n    def forward(\n        self, input: torch.Tensor\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        if self.speculator is not None:\n            return self.speculator(input)\n\n        assert self.head is not None\n        logits = self.head(input)\n        return logits, None\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/layers/tensor_parallel.py",
    "content": "import torch\nfrom torch.nn import functional as F\nfrom typing import Iterable, List\nfrom text_generation_server.layers.linear import get_linear, FastLinear\n\nimport habana_frameworks.torch as htorch\n\n\nclass LayerConcat(torch.nn.Module):\n    \"\"\"\n    Apply multiple layers to the input and concatenate their\n    outputs.\n    \"\"\"\n\n    def __init__(self, layers: Iterable[torch.nn.Module], dim: int = -1):\n        \"\"\"\n        `dim` is the dimension along which layer outputs are concatenated.\n        \"\"\"\n        super().__init__()\n        self.layers = layers\n        self.dim = dim\n\n    def forward(self, x: torch.Tensor):\n        outputs = [layer(x) for layer in self.layers]\n        return torch.cat(outputs, self.dim)\n\n\nclass SuperLayer(torch.nn.Module):\n    def __init__(self, linear):\n        super().__init__()\n        self.linear = linear\n\n    def forward(self, x):\n        return self.linear.forward(x)\n\n\nclass TensorParallelHead(SuperLayer):\n    def __init__(self, linear, process_group, should_gather: bool):\n        super().__init__(linear)\n        self.process_group = process_group\n        self.should_gather = should_gather\n\n    @staticmethod\n    def load(config, prefix: str, weights):\n        if config.quantize == \"exl2\":\n            try:\n                # If the piece and LM head embeddings are shared, we have\n                # non-quantized weights...\n                weight = weights.get_tensor(f\"{prefix}.weight\")\n            except Exception:\n                # ...otherwise they are quantized.\n                weight = weights.get_weights_col(prefix)\n            should_gather = weights.process_group.size() > 1\n        elif weights.process_group.size() > 1:\n            try:\n                weight = weights.get_sharded(f\"{prefix}.weight\", dim=0)\n                should_gather = True\n            except AssertionError:\n                # If the vocab size is not divisible by number of 
shards\n                # just load the entire thing.\n                weight = weights.get_tensor(f\"{prefix}.weight\")\n                should_gather = False\n        else:\n            weight = weights.get_tensor(f\"{prefix}.weight\")\n            should_gather = False\n\n        return TensorParallelHead(\n            get_linear(weight, bias=None),\n            process_group=weights.process_group,\n            should_gather=should_gather,\n        )\n\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        if not self.should_gather:\n            return super().forward(input)\n\n        world_size = self.process_group.size()\n        if len(input.shape) == 2 and isinstance(self.linear, FastLinear):\n            out_dim = self.linear.weight.shape[0]\n\n            if input.shape[0] == 1:\n                world_out = input.new_empty(1, out_dim * world_size)\n                local_out = input.new_empty(1, out_dim)\n                gather_input = local_out\n            else:\n                world_out = input.new_empty(out_dim * world_size, input.shape[0])\n                gather_input = input.new_empty(out_dim, input.shape[0])\n                local_out = gather_input.T\n\n            torch.mm(input, self.linear.weight.T, out=local_out)\n            htorch.core.mark_step()\n            torch.distributed.all_gather_into_tensor(\n                world_out, gather_input, group=self.process_group\n            )\n\n            if input.shape[0] == 1:\n                return world_out\n            return world_out.T\n\n        output = super().forward(input)\n        world_output = [\n            torch.empty_like(output) for _ in range(self.process_group.size())\n        ]\n\n        htorch.core.mark_step()\n        torch.distributed.all_gather(world_output, output, group=self.process_group)\n        world_output = torch.cat(world_output, dim=-1)\n        return world_output\n\n\nclass TensorParallelColumnLinear(SuperLayer):\n    @classmethod\n    def 
load_gate_up(cls, config, prefix: str, weights, bias: bool):\n        \"\"\"Specific method when the gate and up projections were joined after the fact\"\"\"\n        weight = weights.get_weights_col_packed_gate_up(prefix)\n        if bias:\n            raise NotImplementedError(\"packed_gate_up only implemented without bias\")\n        else:\n            bias = None\n        linear = get_linear(weight, bias)\n        return cls(linear)\n\n    @classmethod\n    def load_qkv(\n        cls,\n        config,\n        prefix: str,\n        weights,\n        bias: bool,\n        num_heads: int,\n        num_key_value_heads: int,\n    ):\n        \"\"\"Specific method when the QKV was joined after the fact\"\"\"\n        weight = weights.get_weights_col_packed_qkv(\n            prefix,\n            num_heads=num_heads,\n            num_key_value_heads=num_key_value_heads,\n        )\n        if bias:\n            raise NotImplementedError(\"packed_qkv only implemented for baichuan\")\n        else:\n            bias = None\n        linear = get_linear(weight, bias)\n        return cls(linear)\n\n    @classmethod\n    def load(cls, config, prefix: str, weights, bias: bool):\n        weight = weights.get_weights_col(prefix)\n        if bias:\n            bias = weights.get_sharded(f\"{prefix}.bias\", dim=0)\n        else:\n            bias = None\n        linear = get_linear(weight, bias)\n        return cls(linear)\n\n    @classmethod\n    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):\n        if config.quantize == \"exl2\":\n            linears = []\n            for prefix in prefixes:\n                weight = weights.get_weights_col(prefix)\n                b = weights.get_tensor(f\"{prefix}.bias\") if bias else None\n                linears.append(get_linear(weight, b))\n            linear = LayerConcat(linears)\n        else:\n            weight = weights.get_multi_weights_col(prefixes, dim=dim)\n            if bias:\n                b = 
[weights.get_sharded(f\"{p}.bias\", dim=0) for p in prefixes]\n                bias = torch.cat(b, dim=dim)\n            else:\n                bias = None\n            linear = get_linear(weight, bias)\n        return cls(linear)\n\n\nclass TensorParallelRowLinear(SuperLayer):\n    def __init__(self, linear, process_group):\n        super().__init__(linear)\n        self.process_group = process_group\n\n    @classmethod\n    def load(cls, config, prefix: str, weights, bias: bool):\n        weight = weights.get_weights_row(prefix)\n\n        if bias and weights.process_group.rank() == 0:\n            # Bias is only loaded on the first rank process\n            bias = weights.get_tensor(f\"{prefix}.bias\")\n        else:\n            bias = None\n        return cls(\n            get_linear(weight, bias),\n            process_group=weights.process_group,\n        )\n\n    def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:\n        out = super().forward(input)\n        if self.process_group.size() > 1 and reduce:\n            # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge\n            # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used\n            # (which is required for tensor parallel HPUGraph inference)\n            htorch.core.mark_step()\n            torch.distributed.all_reduce(out, group=self.process_group)\n        return out\n\n\nclass TensorParallelEmbedding(torch.nn.Module):\n    def __init__(self, prefix: str, weights, reduce=True):\n        super().__init__()\n        weight = weights.get_partial_sharded(f\"{prefix}.weight\", dim=0)\n        num_embeddings = weights.get_shape(f\"{prefix}.weight\")[0]\n\n        process_group = weights.process_group\n\n        world_size = process_group.size()\n        rank = process_group.rank()\n\n        block_size = (num_embeddings + world_size - 1) // world_size\n        self.min_id = rank * block_size\n        self.max_id = min(num_embeddings, (rank + 1) * 
block_size)\n        self.null_idx = weight.shape[\n            0\n        ]  # Usually block_size, might be less in non even vocab_size.\n        self.process_group = weights.process_group\n        self.reduce = reduce\n\n        \"\"\"Additional 0 entry used for masking\"\"\"\n        self.weight = torch.nn.Parameter(F.pad(weight, (0, 0, 0, 1)))\n\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        # default all out of bounds values to `self.null_idx` that will then be mapped to 0\n        # translate for [0, self.max_id - self.min_id[\n        input = torch.where(\n            (self.min_id > input) | (input >= self.max_id),\n            self.null_idx,\n            input - self.min_id,\n        )\n        out = torch.nn.functional.embedding(input, self.weight)\n        if self.reduce and self.process_group.size() > 1:\n            # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge\n            # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used\n            # (which is required for tensor parallel HPUGraph inference)\n            htorch.core.mark_step()\n            torch.distributed.all_reduce(out, group=self.process_group)\n        return out\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/__init__.py",
    "content": "# ruff: noqa: F821\n# the above line disables the `undefined-name` rule for the model type variables\nimport torch\nimport os\n\nfrom loguru import logger\nfrom transformers.configuration_utils import PretrainedConfig\nfrom huggingface_hub import hf_hub_download, HfApi\nfrom typing import Optional\nfrom pathlib import Path\nfrom typing import List, Dict\nimport enum\n\n# Needed to properly setup habana_frameworks\n\nfrom text_generation_server.utils.speculate import get_speculate, set_speculate\nfrom text_generation_server.models.model import Model\nfrom text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (\n    PhiMoEConfig,\n)\n\nfrom text_generation_server.utils.adapter import (\n    AdapterParameters,\n    build_layer_weight_lookup,\n    load_and_merge_adapters,\n    AdapterInfo,\n)\nfrom text_generation_server.adapters.lora import LoraWeights\n\nfrom text_generation_server.utils.log import log_master\n\n__all__ = [\n    \"Model\",\n    \"CausalLM\",\n    \"Seq2SeqLM\",\n    \"get_model_with_lora_adapters\",\n]\n\nVLM_BATCH_TYPES = set()\n\nFLASH_ATTENTION = True\n\ntry:\n    from text_generation_server.models.flash_causal_lm import FlashCausalLM\n    from text_generation_server.models.flash_vlm_causal_lm import FlashVlmCausalLM\n    from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLM\n    from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (\n        FlashDeepseekV2ForCausalLM,\n        DeepseekV2Config,\n    )\n    from text_generation_server.models.custom_modeling.flash_deepseek_v3_modeling import (\n        FlashDeepseekV3ForCausalLM,\n        DeepseekV3Config,\n    )\n    from text_generation_server.models.custom_modeling.flash_llama_modeling import (\n        FlashLlamaForCausalLM,\n    )\n    from text_generation_server.models.custom_modeling.flash_llama4_modeling import (\n        Llama4ForConditionalGeneration,\n    )\n    from 
text_generation_server.models.custom_modeling.flash_cohere_modeling import (\n        FlashCohereForCausalLM,\n    )\n    from text_generation_server.models.custom_modeling.flash_gemma_modeling import (\n        FlashGemmaForCausalLM,\n    )\n    from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (\n        FlashGemma2ForCausalLM,\n    )\n    from text_generation_server.models.custom_modeling.flash_gemma3_modeling import (\n        Gemma3ForConditionalGeneration,\n        FlashGemma3ForCausalLM,\n    )\n    from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (\n        FlashDbrxForCausalLM,\n        DbrxConfig,\n    )\n    from text_generation_server.models.custom_modeling.flash_rw_modeling import (\n        RWConfig,\n        FlashRWForCausalLM,\n    )\n    from text_generation_server.models.custom_modeling.flash_neox_modeling import (\n        FlashGPTNeoXForCausalLM,\n    )\n    from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (\n        PaliGemmaForConditionalGeneration,\n    )\n    from text_generation_server.models.custom_modeling.flash_phi_modeling import (\n        FlashPhiForCausalLM,\n    )\n    from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch\n    from text_generation_server.models.custom_modeling.flash_mllama import (\n        FlashMllamaForConditionalGeneration,\n    )\n    from text_generation_server.models.custom_modeling.flash_llava_next import (\n        FlashLlavaNextForConditionalGeneration,\n    )\n\n    from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (\n        FlashSantacoderForCausalLM,\n    )\n    from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (\n        FlashStarcoder2ForCausalLM,\n    )\n    from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (\n        Qwen2ForCausalLM,\n    )\n    from 
text_generation_server.models.custom_modeling.flash_qwen3_modeling import (\n        Qwen3ForCausalLM,\n    )\n    from text_generation_server.models.custom_modeling.flash_qwen3_moe_modeling import (\n        Qwen3MoeForCausalLM,\n    )\n    from text_generation_server.models.custom_modeling.flash_mistral_modeling import (\n        FlashMistralForCausalLM,\n    )\n    from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (\n        FlashMixtralForCausalLM,\n    )\n    from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (\n        FlashGPT2ForCausalLM,\n    )\n    from text_generation_server.models.custom_modeling.flash_gptj_modeling import (\n        FlashGPTJForCausalLM,\n    )\n    from text_generation_server.models.custom_modeling.idefics2 import (\n        Idefics2ForConditionalGeneration,\n    )\n    from text_generation_server.models.custom_modeling.idefics3 import (\n        Idefics3ForConditionalGeneration,\n    )\n    from text_generation_server.models.custom_modeling.qwen2_vl import (\n        Qwen2VLForConditionalGeneration,\n    )\n    from text_generation_server.models.custom_modeling.qwen2_5_vl import (\n        Qwen2_5VLForConditionalGeneration,\n        Qwen2_5_VLConfig,\n        Qwen2_5_VLProcessor,\n    )\n    from text_generation_server.layers.attention import SUPPORTS_WINDOWING\nexcept ImportError as e:\n    log_master(logger.warning, f\"Could not import Flash Attention enabled models: {e}\")\n    SUPPORTS_WINDOWING = False\n    FLASH_ATTENTION = False\n    VLM_BATCH_TYPES = set()\n\nif FLASH_ATTENTION:\n    __all__.append(FlashCausalLM)\n\n    from text_generation_server.models.flash_vlm_causal_lm import (\n        FlashVlmCausalLMBatch,\n    )\n\n    VLM_BATCH_TYPES = {\n        FlashVlmCausalLMBatch,\n        FlashMllamaCausalLMBatch,\n    }\n\n\n__all__.append(VLM_BATCH_TYPES)\n\n\nclass ModelType(enum.Enum):\n    DEEPSEEK_V2 = {\n        \"type\": \"deepseek_v2\",\n        \"name\": 
\"Deepseek V2\",\n        \"url\": \"https://huggingface.co/deepseek-ai/DeepSeek-V2\",\n    }\n    DEEPSEEK_V3 = {\n        \"type\": \"deepseek_v3\",\n        \"name\": \"Deepseek V3\",\n        \"url\": \"https://huggingface.co/deepseek-ai/DeepSeek-V3\",\n    }\n    IDEFICS2 = {\n        \"type\": \"idefics2\",\n        \"name\": \"Idefics 2\",\n        \"url\": \"https://huggingface.co/HuggingFaceM4/idefics2-8b\",\n        \"multimodal\": True,\n    }\n    IDEFICS3 = {\n        \"type\": \"idefics3\",\n        \"name\": \"Idefics 3\",\n        \"url\": \"https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3\",\n        \"multimodal\": True,\n    }\n    LLAVA_NEXT = {\n        \"type\": \"llava_next\",\n        \"name\": \"Llava Next (1.6)\",\n        \"url\": \"https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf\",\n        \"multimodal\": True,\n    }\n    LLAMA = {\n        \"type\": \"llama\",\n        \"name\": \"Llama\",\n        \"url\": \"https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f\",\n    }\n    LLAMA4 = {\n        \"type\": \"llama4\",\n        \"name\": \"Llama4\",\n        \"url\": \"https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f\",\n    }\n    PHI3 = {\n        \"type\": \"phi3\",\n        \"name\": \"Phi 3\",\n        \"url\": \"https://huggingface.co/microsoft/Phi-3-mini-4k-instruct\",\n    }\n    GRANITE = {\n        \"type\": \"granite\",\n        \"name\": \"Granite\",\n        \"url\": \"https://huggingface.co/ibm-granite/granite-3.0-8b-instruct\",\n    }\n    GEMMA = {\n        \"type\": \"gemma\",\n        \"name\": \"Gemma\",\n        \"url\": \"https://huggingface.co/google/gemma-7b\",\n    }\n    PALIGEMMA = {\n        \"type\": \"paligemma\",\n        \"name\": \"PaliGemma\",\n        \"url\": \"https://huggingface.co/google/paligemma-3b-pt-224\",\n    }\n    GEMMA2 = {\n        \"type\": \"gemma2\",\n        \"name\": \"Gemma2\",\n        \"url\": 
\"https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315\",\n    }\n    GEMMA3 = {\n        \"type\": \"gemma3\",\n        \"name\": \"Gemma3\",\n        \"url\": \"https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d\",\n    }\n    GEMMA3_TEXT = {\n        \"type\": \"gemma3_text\",\n        \"name\": \"Gemma3 Text\",\n        \"url\": \"https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d\",\n    }\n    COHERE = {\n        \"type\": \"cohere\",\n        \"name\": \"Cohere\",\n        \"url\": \"https://huggingface.co/CohereForAI/c4ai-command-r-plus\",\n    }\n    DBRX = {\n        \"type\": \"dbrx\",\n        \"name\": \"Dbrx\",\n        \"url\": \"https://huggingface.co/databricks/dbrx-instruct\",\n    }\n    MAMBA = {\n        \"type\": \"mamba\",\n        \"name\": \"Mamba\",\n        \"url\": \"https://huggingface.co/state-spaces/mamba-2.8b-slimpj\",\n    }\n    MISTRAL = {\n        \"type\": \"mistral\",\n        \"name\": \"Mistral\",\n        \"url\": \"https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407\",\n    }\n    MIXTRAL = {\n        \"type\": \"mixtral\",\n        \"name\": \"Mixtral\",\n        \"url\": \"https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1\",\n    }\n    GPT_BIGCODE = {\n        \"type\": \"gpt_bigcode\",\n        \"name\": \"Gpt Bigcode\",\n        \"url\": \"https://huggingface.co/bigcode/gpt_bigcode-santacoder\",\n    }\n    PHI = {\n        \"type\": \"phi\",\n        \"name\": \"Phi\",\n        \"url\": \"https://huggingface.co/microsoft/phi-1_5\",\n    }\n    PHI_MOE = {\n        \"type\": \"phimoe\",\n        \"name\": \"PhiMoe\",\n        \"url\": \"https://huggingface.co/microsoft/Phi-3.5-MoE-instruct\",\n    }\n    BAICHUAN = {\n        \"type\": \"baichuan\",\n        \"name\": \"Baichuan\",\n        \"url\": \"https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat\",\n    }\n    FALCON = {\n        \"type\": 
\"falcon\",\n        \"name\": \"Falcon\",\n        \"url\": \"https://huggingface.co/tiiuae/falcon-7b-instruct\",\n    }\n    STARCODER2 = {\n        \"type\": \"starcoder2\",\n        \"name\": \"StarCoder 2\",\n        \"url\": \"https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1\",\n    }\n    QWEN2 = {\n        \"type\": \"qwen2\",\n        \"name\": \"Qwen 2\",\n        \"url\": \"https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f\",\n    }\n    QWEN2_VL = {\n        \"type\": \"qwen2_vl\",\n        \"name\": \"Qwen 2 VL\",\n        \"url\": \"https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d\",\n    }\n    QWEN2_5_VL = {\n        \"type\": \"qwen2_5_vl\",\n        \"name\": \"Qwen 2.5 VL\",\n        \"url\": \"https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e\",\n    }\n    QWEN3 = {\n        \"type\": \"qwen3\",\n        \"name\": \"Qwen 3\",\n        \"url\": \"https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f\",\n    }\n    QWEN3_MOE = {\n        \"type\": \"qwen3_moe\",\n        \"name\": \"Qwen 3 Moe\",\n        \"url\": \"https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f\",\n    }\n    GALACTICA = {\n        \"type\": \"galactica\",\n        \"name\": \"Galactica\",\n        \"url\": \"https://huggingface.co/facebook/galactica-120b\",\n    }\n    SANTACODER = {\n        \"type\": \"santacoder\",\n        \"name\": \"SantaCoder\",\n        \"url\": \"https://huggingface.co/bigcode/santacoder\",\n    }\n    GPT2 = {\n        \"type\": \"gpt2\",\n        \"name\": \"Gpt2\",\n        \"url\": \"https://huggingface.co/openai-community/gpt2\",\n    }\n    GPT_NEOX = {\n        \"type\": \"gpt_neox\",\n        \"name\": \"Gpt Neox\",\n        \"url\": \"https://huggingface.co/EleutherAI/gpt-neox-20b\",\n    }\n    GPTJ = {\n        \"type\": \"gptj\",\n        \"name\": \"Gptj\",\n        \"url\": 
\"https://huggingface.co/EleutherAI/gpt-j-6b\",\n    }\n    MLLAMA = {\n        \"type\": \"mllama\",\n        \"name\": \"Mllama\",\n        \"url\": \"https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct\",\n        \"multimodal\": True,\n    }\n\n\n__GLOBALS = locals()\nfor data in ModelType:\n    __GLOBALS[data.name] = data.value[\"type\"]\n\nSDP_ON_BF16 = int(os.environ.get(\"SDP_ON_BF16\", 0))\n# Disable gradients\ntorch.set_grad_enabled(False)\n\n\ndef get_model(\n    model_id: str,\n    lora_adapter_ids: Optional[List[str]],\n    revision: Optional[str],\n    sharded: bool,\n    quantize: Optional[str],\n    speculate: Optional[int],\n    dtype: Optional[torch.dtype],\n    kv_cache_dtype: Optional[str],\n    trust_remote_code: bool,\n    max_input_tokens: int,\n) -> Model:\n    global FLASH_ATTENTION\n\n    if speculate is not None:\n        set_speculate(speculate)\n    else:\n        set_speculate(0)\n\n    config_dict, _ = PretrainedConfig.get_config_dict(\n        model_id, revision=revision, trust_remote_code=trust_remote_code\n    )\n    model_type = config_dict.get(\"model_type\", None)\n\n    speculator = None\n    if \"medusa_num_heads\" in config_dict:\n        medusa_model_id = model_id\n        medusa_revision = revision\n        model_id = config_dict[\"base_model_name_or_path\"]\n        revision = \"main\"\n        speculate_medusa = config_dict[\"medusa_num_heads\"]\n        if speculate is not None:\n            if speculate > speculate_medusa:\n                raise RuntimeError(\n                    f\"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match\"\n                )\n            else:\n                set_speculate(speculate)\n        else:\n            set_speculate(speculate_medusa)\n\n        config_dict, _ = PretrainedConfig.get_config_dict(\n            model_id, revision=revision, trust_remote_code=trust_remote_code\n        )\n        # Reload model 
type from parent.\n        model_type = config_dict.get(\"model_type\", None)\n        is_local = Path(medusa_model_id).exists()\n        if not is_local:\n            medusa_config = hf_hub_download(\n                medusa_model_id, revision=medusa_revision, filename=\"config.json\"\n            )\n            hf_hub_download(\n                medusa_model_id,\n                revision=medusa_revision,\n                filename=\"medusa_lm_head.safetensors\",\n            )\n            speculator = {\n                \"path\": Path(medusa_config).parent,\n                \"model_paths\": [\"medusa_lm_head.safetensors\"],\n            }\n        else:\n            speculator = {\n                \"path\": Path(medusa_model_id),\n                \"model_paths\": [\"medusa_lm_head.safetensors\"],\n            }\n\n        method = \"medusa\"\n    elif model_type == \"mlp_speculator\":\n        mlp_model_id = model_id\n        mlp_revision = revision\n        model_id = config_dict[\"base_model_name_or_path\"]\n        revision = \"main\"\n        speculate_mlp = config_dict[\"n_predict\"]\n        if speculate is not None:\n            if speculate > speculate_mlp:\n                raise RuntimeError(\n                    f\"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match\"\n                )\n            else:\n                set_speculate(speculate)\n        else:\n            set_speculate(speculate_mlp)\n\n        config_dict, _ = PretrainedConfig.get_config_dict(\n            model_id, revision=revision, trust_remote_code=trust_remote_code\n        )\n        # Reload model type from parent.\n        model_type = config_dict.get(\"model_type\", None)\n        is_local = Path(mlp_model_id).exists()\n        extension = \".safetensors\"\n        if not is_local:\n            mlp_speculator_config = hf_hub_download(\n                mlp_model_id, revision=mlp_revision, 
filename=\"config.json\"\n            )\n            api = HfApi()\n            info = api.model_info(mlp_model_id, revision=mlp_revision)\n            filenames = [\n                s.rfilename\n                for s in info.siblings\n                if s.rfilename.endswith(extension)\n                and len(s.rfilename.split(\"/\")) == 1\n                and \"arguments\" not in s.rfilename\n                and \"args\" not in s.rfilename\n                and \"training\" not in s.rfilename\n            ]\n            for filename in filenames:\n                hf_hub_download(\n                    mlp_model_id,\n                    revision=mlp_revision,\n                    filename=filename,\n                )\n            speculator_dir_path = Path(mlp_speculator_config).parent\n            # if these are downloaded, they get converted to safetensors\n            filenames.extend(\n                [p for p in os.listdir(speculator_dir_path) if p.endswith(extension)]\n            )\n            speculator = {\n                \"path\": Path(mlp_speculator_config).parent,\n                \"model_paths\": filenames,\n            }\n        else:\n            speculator = Path(mlp_model_id)\n            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]\n            speculator = {\"path\": speculator, \"model_paths\": filenames}\n        method = \"mlp_speculator\"\n    else:\n        method = \"n-gram\"\n\n    speculate = get_speculate()\n    if speculate > 0:\n        logger.info(f\"Using speculation {method} with {speculate} input ids.\")\n\n    model_type = config_dict[\"model_type\"]\n\n    if kv_cache_dtype == \"fp8_e4m3fn\":\n        kv_cache_dtype = torch.float8_e4m3fn\n    elif kv_cache_dtype == \"fp8_e5m2\":\n        kv_cache_dtype = torch.float8_e5m2\n    else:\n        kv_cache_dtype = dtype\n\n    if FLASH_ATTENTION:\n        if model_type == DEEPSEEK_V2:\n            head_size = max(\n                
config_dict.get(\"qk_nope_dim\", 128)\n                + config_dict.get(\"qk_rope_dim\", 64),\n                config_dict.get(\"v_head_dim\", 128),\n            )\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashDeepseekV2ForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                default_dtype=torch.bfloat16,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n                config_class=DeepseekV2Config,\n                head_size=head_size,\n            )\n        elif model_type == DEEPSEEK_V3:\n            head_size = max(\n                config_dict.get(\"qk_nope_dim\", 128)\n                + config_dict.get(\"qk_rope_dim\", 64),\n                config_dict.get(\"v_head_dim\", 128),\n            )\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashDeepseekV3ForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                default_dtype=torch.bfloat16,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n                config_class=DeepseekV3Config,\n                head_size=head_size,\n            )\n\n        elif (\n            model_type == GPT_BIGCODE\n            or model_type == GPT2\n            and model_id.startswith(\"bigcode/\")\n        ):\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashSantacoderForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n            
    kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n                aliases={\"transformer.wte.weight\": [\"lm_head.weight\"]},\n                num_kv_heads=1,\n            )\n        elif model_type == GPT2:\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashGPT2ForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n            )\n        elif model_type == GPTJ:\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashGPTJForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n            )\n        elif model_type == GPT_NEOX:\n            from text_generation_server.models.custom_modeling.flash_neox_modeling import (\n                GPTNeoXConfig,\n            )\n\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashGPTNeoXForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n                config_class=GPTNeoXConfig,\n            )\n        elif model_type == PHI:\n            return FlashCausalLM(\n                model_id=model_id,\n                
model_class=FlashPhiForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n            )\n        elif model_type == PHI_MOE:\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashLlamaForCausalLM,\n                config_class=PhiMoEConfig,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n            )\n        elif model_type == LLAMA or model_type == PHI3 or model_type == GRANITE:\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashLlamaForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n            )\n        elif model_type == LLAMA4:\n            print(f\"Llama4 model detected: {model_id}\")\n            return FlashVlmCausalLM(\n                model_id=model_id,\n                model_class=Llama4ForConditionalGeneration,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                default_dtype=torch.bfloat16,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n                support_chunking=False,\n          
  )\n        elif model_type == BAICHUAN:\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashLlamaForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n            )\n        elif model_type == GEMMA:\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashGemmaForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                # Works better for these models\n                default_dtype=torch.bfloat16,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n            )\n        elif model_type == GEMMA2:\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashGemma2ForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                # Works better for these models\n                default_dtype=torch.bfloat16,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n            )\n        elif model_type == GEMMA3:\n            return FlashVlmCausalLM(\n                model_id=model_id,\n                model_class=Gemma3ForConditionalGeneration,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                
default_dtype=torch.bfloat16,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n                support_chunking=False,\n            )\n        elif model_type == GEMMA3_TEXT:\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashGemma3ForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                # Works better for these models\n                default_dtype=torch.bfloat16,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n            )\n        elif model_type == COHERE:\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashCohereForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n            )\n        elif model_type == DBRX:\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashDbrxForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                # Dbrx works better in bfloat16.\n                default_dtype=torch.bfloat16,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n                config_class=DbrxConfig,\n            )\n        elif (\n            model_type in [\"RefinedWeb\", \"RefinedWebModel\", FALCON]\n            and not sharded\n            and not 
config_dict.get(\"alibi\", False)\n        ):\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashRWForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                aliases={\n                    \"lm_head.weight\": [\"transformer.word_embeddings.weight\"],\n                    \"transformer.word_embeddings.weight\": [\"lm_head.weight\"],\n                },\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n                config_class=RWConfig,\n            )\n        elif model_type == MISTRAL:\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashMistralForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n            )\n        elif model_type == MIXTRAL:\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashMixtralForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n            )\n        elif model_type == STARCODER2:\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashStarcoder2ForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n             
   kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n            )\n        elif model_type == QWEN2:\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=Qwen2ForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n            )\n        elif model_type == QWEN2_VL:\n            return FlashVlmCausalLM(\n                model_id=model_id,\n                model_class=Qwen2VLForConditionalGeneration,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                default_dtype=torch.bfloat16,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n                # TODO: Fix bug in rust image_text_replacement implementation\n                support_chunking=False,\n            )\n        elif model_type == QWEN2_5_VL:\n            return FlashVlmCausalLM(\n                model_id=model_id,\n                model_class=Qwen2_5VLForConditionalGeneration,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                default_dtype=torch.bfloat16,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n                config_class=Qwen2_5_VLConfig,\n                processor_class=Qwen2_5_VLProcessor,\n                # TODO: Fix bug in rust image_text_replacement implementation\n                
support_chunking=False,\n            )\n        elif model_type == QWEN3:\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=Qwen3ForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n            )\n        elif model_type == QWEN3_MOE:\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=Qwen3MoeForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n            )\n        elif model_type == MLLAMA:\n            return FlashMllamaCausalLM(\n                model_id=model_id,\n                model_class=FlashMllamaForConditionalGeneration,\n                batch_class=FlashMllamaCausalLMBatch,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                default_dtype=torch.bfloat16,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n                support_chunking=False,\n            )\n        elif model_type == IDEFICS2:\n            return FlashVlmCausalLM(\n                model_id=model_id,\n                model_class=Idefics2ForConditionalGeneration,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                
lora_adapter_ids=lora_adapter_ids,\n                # XXX: Extremely important to cap resolution in order to limit\n                # VRAM usage.\n                processor_kwargs={\"size\": {\"longest_edge\": 448, \"shortest_edge\": 378}},\n            )\n        elif model_type == IDEFICS3:\n            return FlashVlmCausalLM(\n                model_id=model_id,\n                model_class=Idefics3ForConditionalGeneration,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                default_dtype=torch.bfloat16,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n                # XXX: Extremely important to cap resolution in order to limit\n                # VRAM usage.\n                processor_kwargs={\"size\": {\"longest_edge\": 1456}},\n            )\n        elif model_type == PALIGEMMA:\n            return FlashVlmCausalLM(\n                model_id=model_id,\n                model_class=PaliGemmaForConditionalGeneration,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                # Works better for these models\n                default_dtype=torch.bfloat16,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n            )\n        elif model_type == LLAVA_NEXT:\n            return FlashVlmCausalLM(\n                model_class=FlashLlavaNextForConditionalGeneration,\n                model_id=model_id,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n            
)\n\n    raise ValueError(f\"Unsupported model type {model_type}\")\n\n\n# get_model_with_lora_adapters wraps the internal get_model function and adds support for loading adapters\n# this provides a post model loading hook to load adapters into the model after the model has been loaded\ndef get_model_with_lora_adapters(\n    model_id: str,\n    lora_adapters: Optional[List[AdapterInfo]],\n    revision: Optional[str],\n    sharded: bool,\n    quantize: Optional[str],\n    speculate: Optional[int],\n    dtype: Optional[torch.dtype],\n    kv_cache_dtype: Optional[str],\n    trust_remote_code: bool,\n    max_input_tokens: int,\n    adapter_to_index: Dict[str, int],\n):\n    lora_adapter_ids = [adapter.id for adapter in lora_adapters]\n    model = get_model(\n        model_id,\n        lora_adapter_ids,\n        revision,\n        sharded,\n        quantize,\n        speculate,\n        dtype,\n        kv_cache_dtype,\n        trust_remote_code,\n        max_input_tokens,\n    )\n\n    if len(lora_adapters) > 0:\n        target_to_layer = build_layer_weight_lookup(model.model)\n\n        for index, adapter in enumerate(lora_adapters):\n            # The AdapterParameters object allows for merging multiple adapters into a single adapter.\n            # At the moment, we only support loading a single adapter into the model, but we keep the\n            # AdapterParameters object for easier extension in the future.\n            adapter_parameters = AdapterParameters(\n                adapter_info=[adapter],\n                # when merging multiple adapters we can weight them differently\n                # if this is not set, all adapters will be weighted equally\n                # see: text_generation_server.utils.merges.strategies for impl\n                weights=None,\n                merge_strategy=0,\n                density=1.0,\n                majority_sign_method=0,\n            )\n\n            adapter_index = index + 1\n            adapter_to_index[adapter.id] = 
adapter_index\n\n            logger.info(\n                f\"Loading adapter weights into model: {','.join([adapter.id for adapter in adapter_parameters.adapter_info])}\"\n            )\n            weight_names = tuple([v[0] for v in target_to_layer.values()])\n            (\n                module_map,\n                adapter_config,\n                adapter_weight_names,\n                adapter_tokenizer,\n            ) = load_and_merge_adapters(\n                model.model_id,\n                adapter_parameters,\n                adapter_index,\n                weight_names,\n                False,\n            )\n\n            unused_weight_names = adapter_weight_names.copy()\n\n            adapter_layers = [\n                \"q_proj\",\n                \"k_proj\",\n                \"v_proj\",\n                \"o_proj\",\n                \"gate_proj\",\n                \"up_proj\",\n                \"down_proj\",\n                \"qkv_proj\",\n            ]\n\n            for layer_name in adapter_layers:\n                nlayers = (\n                    1 if layer_name == \"lm_head\" else len(model.model.model.layers)\n                )\n                adapter_weights = LoraWeights.prepare_weights(\n                    config=adapter_config,\n                    module_map=module_map,\n                    layer_type=layer_name,\n                    unused_weight_names=unused_weight_names,\n                    nlayers=nlayers,\n                    dtype=model.dtype,\n                    world_size=model.world_size,\n                    process_group=model.process_group,\n                    target_to_layer=target_to_layer,\n                )\n\n                if adapter_weights is None:\n                    continue\n\n                model.layer_to_adapter_weights[layer_name].add_adapter(\n                    adapter_index, adapter_weights\n                )\n\n            if len(unused_weight_names) > 0:\n                logger.warning(\n            
        f\"{','.join([a.id for a in lora_adapters])} unused adapter weights: {unused_weight_names}\"\n                )\n\n            if adapter_tokenizer is not None:\n                model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)\n\n            model.loaded_adapters.add(adapter_index)\n\n    return model\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/__init__.py",
    "content": ""
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/bloom_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2022 HuggingFace Inc. team and BigScience workshop.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch BLOOM model.\"\"\"\n\nimport math\nimport os\nimport warnings\nfrom typing import Optional, Tuple, Union\n\nimport torch\nimport torch.distributed\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom torch.nn import LayerNorm\nfrom torch.nn import functional as F\n\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPastAndCrossAttentions,\n    CausalLMOutputWithCrossAttentions,\n)\nfrom transformers import BloomConfig, PreTrainedModel\n\nfrom text_generation_server.layers import (\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    TensorParallelRowLinear,\n    SpeculativeHead,\n)\n\nCUSTOM_KERNELS_ENABLED = False\nif (\n    torch.cuda.is_available()\n    and not os.environ.get(\"DISABLE_CUSTOM_KERNELS\", \"False\") == \"True\"\n):\n    try:\n        from custom_kernels import fused_bloom_attention_cuda\n\n        CUSTOM_KERNELS_ENABLED = True\n    except ImportError:\n        pass\n\n_CHECKPOINT_FOR_DOC = \"bigscience/bloom-560m\"\n_CONFIG_FOR_DOC = \"BloomConfig\"\n\nBLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = [\n    \"bigscience/bigscience-small-testing\",\n    \"bigscience/bloom-560m\",\n    \"bigscience/bloom-1b1\",\n    \"bigscience/bloom-1b7\",\n    \"bigscience/bloom-3b\",\n    \"bigscience/bloom-7b1\",\n    \"bigscience/bloom\",\n]\n\n\ndef _make_causal_mask(\n   
 input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int\n) -> torch.BoolTensor:\n    \"\"\"\n    Make causal mask used for self-attention.\n    \"\"\"\n    batch_size, target_length = input_ids_shape\n    mask = torch.ones(\n        (target_length, target_length + past_key_values_length),\n        dtype=torch.bool,\n        device=device,\n    )\n    mask = mask.triu(1 + past_key_values_length)\n\n    expanded_mask = mask.unsqueeze(0).expand(\n        batch_size, target_length, target_length + past_key_values_length\n    )\n    return expanded_mask\n\n\ndef _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:\n    \"\"\"\n    Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.\n    \"\"\"\n    batch_size, src_length = mask.shape\n    tgt_length = tgt_length if tgt_length is not None else src_length\n\n    expanded_mask = ~(mask[:, None, :].to(torch.bool))\n    return expanded_mask.expand(batch_size, tgt_length, src_length)\n\n\ndef build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int) -> torch.Tensor:\n    \"\"\"\n    Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it\n    relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value\n    `softmax(l+a) = softmax(l)`. 
Based on\n    https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742\n    TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.\n\n    Args:\n        attention_mask (`torch.Tensor`):\n            Token-wise attention mask, this should be of shape (batch_size, max_seq_len).\n        num_heads (`int`, *required*):\n            number of heads\n        dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):\n            dtype of the output tensor\n\n    Returns:\n        Tensor shaped (batch_size, num_heads, max_seq_len); the caller later reshapes it to\n        (batch_size * num_heads, 1, max_seq_len).\n    \"\"\"\n    batch_size, seq_length = attention_mask.shape\n    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))\n    base = torch.tensor(\n        2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),\n        device=attention_mask.device,\n        dtype=torch.float32,\n    )\n    powers = torch.arange(\n        1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32\n    )\n    slopes = torch.pow(base, powers)\n\n    if closest_power_of_2 != num_heads:\n        extra_base = torch.tensor(\n            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),\n            device=attention_mask.device,\n            dtype=torch.float32,\n        )\n        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)\n        extra_powers = torch.arange(\n            1,\n            1 + 2 * num_remaining_heads,\n            2,\n            device=attention_mask.device,\n            dtype=torch.int32,\n        )\n        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)\n\n    # Note: alibi will be added to the attention bias that will be applied to the query, key product of attention\n    # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)\n    # => here we set (batch_size=1, num_heads=num_heads, query_length=1, 
key_length=max_length)\n    # => the query_length dimension will then be broadcasted correctly\n    # This is more or less identical to T5's relative position bias:\n    # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527\n    arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]\n    alibi = slopes[..., None] * arange_tensor\n    return alibi\n\n\n# @torch.jit.script\ndef dropout_add(\n    x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool\n) -> torch.Tensor:\n    \"\"\"\n    Dropout add function\n\n    Args:\n        x (`torch.tensor`, *required*):\n            input tensor\n        residual (`torch.tensor`, *required*):\n            residual tensor\n        prob (`float`, *required*):\n            dropout probability\n        training (`bool`, *required*):\n            training mode\n    \"\"\"\n    out = F.dropout(x, p=prob, training=training)\n    out = residual + out\n    return out\n\n\n# @torch.jit.script # this is shit for unknown reasons.\ndef _split_heads(\n    fused_qkv: torch.Tensor, num_heads: int, head_dim: int\n) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory\n    storage as `fused_qkv`\n\n    Args:\n        fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]\n\n    Returns:\n        query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]\n        value: [batch_size, seq_length, num_heads, head_dim]\n    \"\"\"\n    batch_size, seq_length, three_times_hidden_size = fused_qkv.shape\n    fused_qkv = fused_qkv.view(batch_size, seq_length, num_heads, 3 * head_dim)\n    query_layer, key_layer, value_layer = fused_qkv.split(head_dim, dim=-1)\n\n    query_layer = query_layer.transpose(1, 2).reshape(\n        
batch_size * num_heads, seq_length, head_dim\n    )\n    key_layer = key_layer.permute(0, 2, 3, 1).reshape(\n        batch_size * num_heads, head_dim, seq_length\n    )\n    value_layer = value_layer.transpose(1, 2).reshape(\n        batch_size * num_heads, seq_length, head_dim\n    )\n\n    return query_layer, key_layer, value_layer\n\n\n# @torch.jit.script\ndef _merge_heads(x: torch.Tensor, num_heads: int, head_dim: int) -> torch.Tensor:\n    \"\"\"\n    Merge heads together over the last dimension\n\n    Args:\n        x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]\n\n    Returns:\n        torch.tensor: [batch_size, seq_length, num_heads * head_dim]\n    \"\"\"\n    # What we want to achieve is:\n    # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim\n    batch_size_and_num_heads, seq_length, _ = x.shape\n    batch_size = batch_size_and_num_heads // num_heads\n\n    # First view to decompose the batch size\n    # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim\n    x = x.view(batch_size, num_heads, seq_length, head_dim)\n\n    # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim\n    x = x.permute(0, 2, 1, 3)\n\n    # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim\n    return x.reshape(batch_size, seq_length, num_heads * head_dim)\n\n\nclass BloomAttention(nn.Module):\n    def __init__(self, prefix, config: BloomConfig, weights):\n        super().__init__()\n\n        self.pretraining_tp = config.pretraining_tp\n        self.slow_but_exact = config.slow_but_exact\n\n        self.process_group = weights.process_group\n\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.n_head\n        self.head_dim = self.hidden_size // self.num_heads\n        self.split_size = self.hidden_size\n        self.hidden_dropout = 
config.hidden_dropout\n\n        if self.head_dim * self.num_heads != self.hidden_size:\n            raise ValueError(\n                f\"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:\"\n                f\" {self.num_heads}).\"\n            )\n\n        # Layer-wise attention scaling\n        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)\n        self.beta = 1.0\n\n        process_group = weights.process_group\n        if self.num_heads % process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // process_group.size()\n        self.query_key_value = TensorParallelColumnLinear.load(\n            config=config,\n            prefix=f\"{prefix}.query_key_value\",\n            weights=weights,\n            bias=True,\n        )\n        self.dense = TensorParallelRowLinear.load(\n            config=config, prefix=f\"{prefix}.dense\", weights=weights, bias=True\n        )\n        self.attention_dropout = nn.Dropout(config.attention_dropout)\n\n    @staticmethod\n    def compute_attention(\n        fused_qkv: torch.Tensor,\n        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]],\n        alibi: torch.Tensor,\n        attention_mask: torch.Tensor,\n        head_mask: Optional[torch.Tensor],\n        beta: float,\n        inv_norm_factor: float,\n        num_heads: int,\n        use_cache: bool,\n    ):\n        batch_size, q_length, three_times_hidden_size = fused_qkv.shape\n        head_dim = three_times_hidden_size // (3 * num_heads)\n        batch_size * num_heads\n\n        ### TODO @thomasw21: this takes quite a bit of time, how do I accelerate that?\n        # 3 x [batch_size, seq_length, num_heads, head_dim]\n        (query_layer, key_layer, value_layer) = _split_heads(\n   
         fused_qkv, num_heads=num_heads, head_dim=head_dim\n        )\n\n        if layer_past is not None:\n            past_key, past_value = layer_past\n            # concatenate along seq_length dimension:\n            #  - key: [batch_size * self.num_heads, head_dim, kv_length]\n            #  - value: [batch_size * self.num_heads, kv_length, head_dim]\n            past_key = past_key.view(-1, *past_key.shape[-2:])\n            key_layer = torch.cat((past_key, key_layer), dim=2)\n            past_value = past_value.view(-1, *past_value.shape[-2:])\n            value_layer = torch.cat((past_value, value_layer), dim=1)\n\n        _, _, kv_length = key_layer.shape\n\n        if use_cache is True:\n            present = (key_layer, value_layer)\n        else:\n            present = None\n        ###\n\n        # [batch_size * num_heads, q_length, kv_length]\n        # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11\n        attention_scores = alibi.baddbmm(\n            batch1=query_layer,\n            batch2=key_layer,\n            beta=beta,\n            alpha=inv_norm_factor,\n        )\n\n        # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]\n        input_dtype = attention_scores.dtype\n        # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`\n        if input_dtype == torch.float16:\n            attention_scores = attention_scores.to(torch.float)\n        # torch.finfo not supported by torch.jit, we temporarily replace with `-1e34`\n        attn_weights = attention_scores.masked_fill_(\n            attention_mask, torch.finfo(attention_scores.dtype).min\n        )\n        attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(\n            input_dtype\n        )\n\n        # # [batch_size, num_heads, q_length, kv_length]\n        
# attention_probs = self.attention_dropout(attention_probs)\n\n        if head_mask is not None:\n            attention_probs = attention_probs * head_mask\n\n        # matmul: [batch_size * num_heads, q_length, head_dim]\n        context_layer = torch.bmm(attention_probs, value_layer, out=query_layer)\n\n        # change view [batch_size, num_heads, q_length, head_dim]\n        context_layer = _merge_heads(\n            context_layer, num_heads=num_heads, head_dim=head_dim\n        )\n\n        return context_layer, present, attention_probs\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        residual: torch.Tensor,\n        alibi: torch.Tensor,\n        attention_mask: torch.Tensor,\n        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        use_cache: bool = False,\n        output_attentions: bool = False,\n    ):\n        fused_qkv = self.query_key_value(\n            hidden_states\n        )  # [batch_size, seq_length, 3 x hidden_size]\n        batch_size, q_length, _ = fused_qkv.shape\n\n        if layer_past is not None:\n            past_key, past_value = layer_past\n            layer_past = (\n                past_key.view(-1, *past_key.shape[-2:]),\n                past_value.view(-1, *past_value.shape[-2:]),\n            )\n\n        if CUSTOM_KERNELS_ENABLED and attention_mask.shape[-1] < 4096:\n            assert self.training is False, \"Only foward pass was implemented\"\n            assert (\n                attention_mask.shape[-1] < 4096\n            ), \"Custom kernel support only up to 4096 tokens\"\n            (\n                context_layer,\n                present,\n                attention_probs,\n            ) = fused_bloom_attention_cuda.forward(\n                fused_qkv,\n                layer_past,\n                alibi,\n                attention_mask,\n                head_mask,\n                self.beta,\n                
self.inv_norm_factor,\n                self.num_heads,\n                use_cache,\n            )\n        else:\n            context_layer, present, attention_probs = self.compute_attention(\n                fused_qkv=fused_qkv,\n                layer_past=layer_past,\n                alibi=alibi,\n                attention_mask=attention_mask,\n                head_mask=head_mask,\n                beta=self.beta,\n                inv_norm_factor=self.inv_norm_factor,\n                num_heads=self.num_heads,\n                use_cache=use_cache,\n            )\n\n        # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232\n        if self.pretraining_tp > 1 and self.slow_but_exact:\n            slices = self.hidden_size / self.pretraining_tp\n            output_tensor = torch.zeros_like(context_layer)\n            for i in range(self.pretraining_tp):\n                output_tensor = output_tensor + F.linear(\n                    context_layer[:, :, int(i * slices) : int((i + 1) * slices)],\n                    self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],\n                )\n        else:\n            output_tensor = self.dense(context_layer)\n\n        # output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)\n        output_tensor += residual\n\n        outputs = (output_tensor, present)\n        if output_attentions:\n            outputs += (attention_probs,)\n\n        return outputs\n\n\nclass BloomMLP(nn.Module):\n    def __init__(self, prefix, config: BloomConfig, weights):\n        super().__init__()\n\n        self.pretraining_tp = config.pretraining_tp\n        self.slow_but_exact = config.slow_but_exact\n        self.dense_h_to_4h = TensorParallelColumnLinear.load(\n            config=config, prefix=f\"{prefix}.dense_h_to_4h\", weights=weights, bias=True\n        )\n        self.dense_4h_to_h = TensorParallelRowLinear.load(\n            config=config, 
prefix=f\"{prefix}.dense_4h_to_h\", weights=weights, bias=True\n        )\n        self.gelu_impl = torch.nn.GELU(approximate=\"tanh\")\n        self.hidden_dropout = config.hidden_dropout\n\n    def forward(\n        self, hidden_states: torch.Tensor, residual: torch.Tensor\n    ) -> torch.Tensor:\n        hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))\n\n        if self.pretraining_tp > 1 and self.slow_but_exact:\n            intermediate_output = torch.zeros_like(residual)\n            slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp\n            for i in range(self.pretraining_tp):\n                intermediate_output = intermediate_output + F.linear(\n                    hidden_states[:, :, int(i * slices) : int((i + 1) * slices)],\n                    self.dense_4h_to_h.weight[\n                        :, int(i * slices) : int((i + 1) * slices)\n                    ],\n                )\n        else:\n            intermediate_output = self.dense_4h_to_h(hidden_states)\n\n        # output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)\n        intermediate_output += residual\n\n        return intermediate_output\n\n\nclass BloomBlock(nn.Module):\n    def __init__(self, layer_id: int, config: BloomConfig, weights):\n        super().__init__()\n\n        prefix = f\"h.{layer_id}\"\n        self.input_layernorm = LayerNorm.load(\n            prefix=f\"{prefix}.input_layernorm\",\n            weights=weights,\n            eps=config.layer_norm_epsilon,\n        )\n        self.num_heads = config.n_head\n        self.self_attention = BloomAttention(\n            prefix=f\"{prefix}.self_attention\", config=config, weights=weights\n        )\n        self.post_attention_layernorm = LayerNorm.load(\n            prefix=f\"{prefix}.post_attention_layernorm\",\n            weights=weights,\n            eps=config.layer_norm_epsilon,\n        )\n\n        self.mlp = BloomMLP(prefix=f\"{prefix}.mlp\", 
config=config, weights=weights)\n        self.apply_residual_connection_post_layernorm = (\n            config.apply_residual_connection_post_layernorm\n        )\n        self.hidden_dropout = config.hidden_dropout\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        alibi: torch.Tensor,\n        attention_mask: torch.Tensor,\n        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        use_cache: bool = False,\n        output_attentions: bool = False,\n    ):\n        # hidden_states: [batch_size, seq_length, hidden_size]\n\n        # Layer norm at the beginning of the transformer layer.\n        layernorm_output = self.input_layernorm(hidden_states)\n\n        # Layer norm post the self attention.\n        if self.apply_residual_connection_post_layernorm:\n            residual = layernorm_output\n        else:\n            residual = hidden_states\n\n        # Self attention.\n        attn_outputs = self.self_attention(\n            layernorm_output,\n            residual,\n            layer_past=layer_past,\n            attention_mask=attention_mask,\n            alibi=alibi,\n            head_mask=head_mask,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n        )\n\n        attention_output = attn_outputs[0]\n\n        outputs = attn_outputs[1:]\n\n        layernorm_output = self.post_attention_layernorm(attention_output)\n\n        # Get residual\n        if self.apply_residual_connection_post_layernorm:\n            residual = layernorm_output\n        else:\n            residual = attention_output\n\n        # MLP.\n        output = self.mlp(layernorm_output, residual)\n\n        if use_cache:\n            outputs = (output,) + outputs\n        else:\n            outputs = (output,) + outputs[1:]\n\n        return outputs  # hidden_states, present, attentions\n\n\nclass BloomPreTrainedModel(PreTrainedModel):\n    
config_class = BloomConfig\n    base_model_prefix = \"transformer\"\n    _no_split_modules = [\"BloomBlock\"]\n\n    @staticmethod\n    def _convert_to_standard_cache(\n        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int\n    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:\n        \"\"\"\n        Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,\n        num_heads, ...]))\n        \"\"\"\n        batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape\n        num_heads = batch_size_times_num_heads // batch_size\n        # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]\n        # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]\n        return tuple(\n            (\n                layer_past[0].view(batch_size, num_heads, head_dim, seq_length),\n                layer_past[1].view(batch_size, num_heads, seq_length, head_dim),\n            )\n            for layer_past in past_key_value\n        )\n\n    @staticmethod\n    def _convert_to_bloom_cache(\n        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]],\n    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:\n        \"\"\"\n        Converts the cache to the format expected by Bloom, i.e. 
to tuple(tuple([batch_size * num_heads, ...]))\n        \"\"\"\n        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape\n        batch_size_times_num_heads = batch_size * num_heads\n        # key:  [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]\n        # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]\n        return tuple(\n            (\n                layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),\n                layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),\n            )\n            for layer_past in past_key_value\n        )\n\n\nclass BloomModel(BloomPreTrainedModel):\n    def __init__(self, config: BloomConfig, weights):\n        super().__init__(config)\n\n        self.embed_dim = config.hidden_size\n        self.num_heads = config.n_head\n\n        process_group = weights.process_group\n        self.tp_rank = process_group.rank()\n        self.tp_world_size = process_group.size()\n\n        self.word_embeddings = TensorParallelEmbedding(\n            prefix=\"word_embeddings\", weights=weights\n        )\n\n        self.word_embeddings_layernorm = LayerNorm.load(\n            prefix=\"word_embeddings_layernorm\",\n            weights=weights,\n            eps=config.layer_norm_epsilon,\n        )\n\n        # Transformer blocks\n        self.h = nn.ModuleList(\n            [\n                BloomBlock(layer_id=layer_id, config=config, weights=weights)\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n\n        # Final Layer Norm\n        self.ln_f = LayerNorm.load(\n            prefix=\"ln_f\", weights=weights, eps=config.layer_norm_epsilon\n        )\n\n    def _prepare_attn_mask(\n        self,\n        attention_mask: torch.Tensor,\n        input_shape: Tuple[int, int],\n        past_key_values_length: int,\n    ) -> 
torch.BoolTensor:\n        # create causal mask\n        # [batch_size, seq_length] -> [batch_size, tgt_length, src_length]\n        combined_attention_mask = None\n        device = attention_mask.device\n        _, src_length = input_shape\n\n        if src_length > 1:\n            combined_attention_mask = _make_causal_mask(\n                input_shape,\n                device=device,\n                past_key_values_length=past_key_values_length,\n            )\n\n        # [batch_size, seq_length] -> [batch_size, tgt_length, src_length]\n        expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)\n        combined_attention_mask = (\n            expanded_attn_mask\n            if combined_attention_mask is None\n            else expanded_attn_mask | combined_attention_mask\n        )\n\n        return combined_attention_mask\n\n    def set_input_embeddings(self, new_embeddings: torch.Tensor):\n        self.word_embeddings = new_embeddings\n\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        head_mask: Optional[torch.LongTensor] = None,\n        inputs_embeds: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        **deprecated_arguments,\n    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:\n        if deprecated_arguments.pop(\"position_ids\", False) is not False:\n            # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`\n            warnings.warn(\n                \"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. 
You can safely ignore\"\n                \" passing `position_ids`.\",\n                FutureWarning,\n            )\n        if len(deprecated_arguments) > 0:\n            raise ValueError(f\"Got unexpected arguments: {deprecated_arguments}\")\n\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\n                \"You cannot specify both input_ids and inputs_embeds at the same time\"\n            )\n        elif input_ids is not None:\n            batch_size, seq_length = input_ids.shape\n        elif inputs_embeds is not None:\n            batch_size, seq_length, _ = inputs_embeds.shape\n        else:\n            raise ValueError(\"You have to specify either input_ids or inputs_embeds\")\n\n        if past_key_values is None:\n            past_key_values = tuple([None] * len(self.h))\n\n        # Prepare head mask if needed\n        # 1.0 in head_mask indicate we keep the head\n        # attention_probs has shape batch_size x num_heads x N x N\n        # head_mask has shape n_layer x batch x num_heads x N x N\n        head_mask = self.get_head_mask(head_mask, self.config.n_layer)\n\n        if inputs_embeds is None:\n            inputs_embeds = self.word_embeddings(input_ids)\n\n        hidden_states = self.word_embeddings_layernorm(inputs_embeds)\n\n        presents = () if use_cache else None\n        all_self_attentions = () if output_attentions else None\n        
all_hidden_states = () if output_hidden_states else None\n\n        # Compute alibi tensor: check build_alibi_tensor documentation\n        seq_length_with_past = seq_length\n        past_key_values_length = 0\n        if past_key_values[0] is not None:\n            past_key_values_length = past_key_values[0][0].shape[-1]\n            seq_length_with_past = seq_length_with_past + past_key_values_length\n        if attention_mask is None:\n            attention_mask = torch.ones(\n                (batch_size, seq_length_with_past), device=hidden_states.device\n            )\n        else:\n            attention_mask = attention_mask.to(hidden_states.device)\n\n        alibi = build_alibi_tensor(attention_mask, self.num_heads)\n\n        causal_mask = self._prepare_attn_mask(\n            attention_mask,\n            input_shape=(batch_size, seq_length),\n            past_key_values_length=past_key_values_length,\n        )\n\n        if hasattr(self, \"tp_rank\"):\n            assert self.num_heads % self.tp_world_size == 0\n            block_size = self.num_heads // self.tp_world_size\n            alibi = alibi[\n                :, self.tp_rank * block_size : (self.tp_rank + 1) * block_size\n            ]\n            alibi = alibi.reshape(batch_size * block_size, 1, seq_length_with_past)\n            causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0)\n        else:\n            alibi = alibi.reshape(batch_size * self.num_heads, 1, seq_length_with_past)\n            causal_mask = torch.repeat_interleave(causal_mask, self.num_heads, dim=0)\n\n        alibi = alibi.to(hidden_states.dtype)\n\n        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):\n            if output_hidden_states:\n                all_hidden_states = all_hidden_states + (hidden_states,)\n\n            outputs = block(\n                hidden_states,\n                layer_past=layer_past,\n                attention_mask=causal_mask,\n                
head_mask=head_mask[i],\n                use_cache=use_cache,\n                output_attentions=output_attentions,\n                alibi=alibi,\n            )\n\n            hidden_states = outputs[0]\n            if use_cache is True:\n                presents = presents + (outputs[1],)\n\n            if output_attentions:\n                all_self_attentions = all_self_attentions + (\n                    outputs[2 if use_cache else 1],\n                )\n\n        # Add last hidden state\n        hidden_states = self.ln_f(hidden_states)\n\n        if output_hidden_states:\n            all_hidden_states = all_hidden_states + (hidden_states,)\n\n        if not return_dict:\n            return tuple(\n                v\n                for v in [\n                    hidden_states,\n                    presents,\n                    all_hidden_states,\n                    all_self_attentions,\n                ]\n                if v is not None\n            )\n\n        return BaseModelOutputWithPastAndCrossAttentions(\n            last_hidden_state=hidden_states,\n            past_key_values=presents,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attentions,\n        )\n\n\nclass BloomForCausalLM(BloomPreTrainedModel):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__(config)\n        self.transformer = BloomModel(config, weights)\n\n        self.lm_head = SpeculativeHead.load(\n            config,\n            prefix=\"word_embeddings\",\n            weights=weights,\n        )\n\n    def prepare_inputs_for_generation(\n        self,\n        input_ids: torch.LongTensor,\n        past_key_values: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        inputs_embeds: Optional[torch.Tensor] = None,\n        **kwargs,\n    ) -> dict:\n        # only last token for input_ids if past is not None\n        if past_key_values:\n            input_ids = input_ids[:, 
-1].unsqueeze(-1)\n\n            # the cache may be in the standard format (e.g. in contrastive search), convert to bloom's format if needed\n            if past_key_values[0][0].shape[0] == input_ids.shape[0]:\n                past_key_values = self._convert_to_bloom_cache(past_key_values)\n\n        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step\n        if inputs_embeds is not None and past_key_values is None:\n            model_inputs = {\"inputs_embeds\": inputs_embeds}\n        else:\n            model_inputs = {\"input_ids\": input_ids}\n\n        model_inputs.update(\n            {\n                \"past_key_values\": past_key_values,\n                \"use_cache\": kwargs.get(\"use_cache\"),\n                \"attention_mask\": attention_mask,\n            }\n        )\n        return model_inputs\n\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        inputs_embeds: Optional[torch.Tensor] = None,\n        labels: Optional[torch.Tensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        **deprecated_arguments,\n    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set\n            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`\n            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`\n        \"\"\"\n        if deprecated_arguments.pop(\"position_ids\", False) is not False:\n            # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`\n            warnings.warn(\n                \"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore\"\n                \" passing `position_ids`.\",\n                FutureWarning,\n            )\n        if len(deprecated_arguments) > 0:\n            raise ValueError(f\"Got unexpected arguments: {deprecated_arguments}\")\n\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        transformer_outputs = self.transformer(\n            input_ids,\n            past_key_values=past_key_values,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = transformer_outputs[0]\n\n        logits, speculative_logits = self.lm_head(hidden_states)\n        loss = None\n\n        if not return_dict:\n            output = (logits,) + transformer_outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return (\n            CausalLMOutputWithCrossAttentions(\n                loss=loss,\n                logits=logits,\n                past_key_values=transformer_outputs.past_key_values,\n                hidden_states=transformer_outputs.hidden_states,\n                
attentions=transformer_outputs.attentions,\n            ),\n            speculative_logits,\n        )\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/clip.py",
    "content": "from typing import Optional, Tuple\n\nimport torch\nfrom torch import nn\n\nfrom transformers.activations import ACT2FN\nfrom transformers.modeling_attn_mask_utils import (\n    _create_4d_causal_attention_mask,\n    _prepare_4d_attention_mask,\n)\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPooling,\n)\nfrom transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig\n\nfrom text_generation_server.layers import (\n    TensorParallelEmbedding,\n    TensorParallelColumnLinear,\n    TensorParallelRowLinear,\n)\n\n\nclass CLIPVisionEmbeddings(nn.Module):\n    def __init__(self, prefix, config: CLIPVisionConfig, weights):\n        super().__init__()\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.image_size = config.image_size\n        self.patch_size = config.patch_size\n\n        # TODO Should we TP this ?\n        self.class_embedding = weights.get_tensor(f\"{prefix}.class_embedding\")\n\n        self.patch_embedding = nn.Conv2d(\n            in_channels=config.num_channels,\n            out_channels=self.embed_dim,\n            kernel_size=self.patch_size,\n            stride=self.patch_size,\n            bias=False,\n        )\n        self.patch_embedding.weight = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.patch_embedding.weight\"), requires_grad=False\n        )\n\n        self.num_patches = (self.image_size // self.patch_size) ** 2\n        self.num_positions = self.num_patches + 1\n        self.position_embedding = TensorParallelEmbedding(\n            prefix=f\"{prefix}.position_embedding\", weights=weights\n        )\n        self.register_buffer(\n            \"position_ids\",\n            torch.arange(self.num_positions, device=weights.device).expand((1, -1)),\n            persistent=False,\n        )\n\n    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:\n        batch_size = pixel_values.shape[0]\n        target_dtype = 
self.patch_embedding.weight.dtype\n        patch_embeds = self.patch_embedding(\n            pixel_values.to(dtype=target_dtype)\n        )  # shape = [*, width, grid, grid]\n        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)\n\n        class_embeds = self.class_embedding.expand(batch_size, 1, -1)\n        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)\n        embeddings = embeddings + self.position_embedding(self.position_ids)\n        return embeddings\n\n\nclass CLIPTextEmbeddings(nn.Module):\n    def __init__(self, config: CLIPTextConfig):\n        super().__init__()\n        embed_dim = config.hidden_size\n\n        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)\n        self.position_embedding = nn.Embedding(\n            config.max_position_embeddings, embed_dim\n        )\n\n        # position_ids (1, len position emb) is contiguous in memory and exported when serialized\n        self.register_buffer(\n            \"position_ids\",\n            torch.arange(config.max_position_embeddings).expand((1, -1)),\n            persistent=False,\n        )\n\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n    ) -> torch.Tensor:\n        seq_length = (\n            input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]\n        )\n\n        if position_ids is None:\n            position_ids = self.position_ids[:, :seq_length]\n\n        if inputs_embeds is None:\n            inputs_embeds = self.token_embedding(input_ids)\n\n        position_embeddings = self.position_embedding(position_ids)\n        embeddings = inputs_embeds + position_embeddings\n\n        return embeddings\n\n\nclass CLIPAttention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, prefix, config, weights):\n       
 super().__init__()\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_size = self.embed_dim // self.num_heads\n        if self.head_size * self.num_heads != self.embed_dim:\n            raise ValueError(\n                f\"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:\"\n                f\" {self.num_heads}).\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.embed_dim = self.embed_dim // weights.process_group.size()\n        self.scale = self.head_size**-0.5\n        self.dropout = config.attention_dropout\n\n        self.qkv = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n            dim=0,\n            weights=weights,\n            bias=True,\n        )\n        self.out_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.out_proj\",\n            weights=weights,\n            bias=True,\n        )\n\n    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n        return (\n            tensor.view(bsz, seq_len, self.num_heads, self.head_size)\n            .transpose(1, 2)\n            .contiguous()\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        causal_attention_mask: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        \"\"\"Input shape: Batch x Time x Channel\"\"\"\n\n        bsz, tgt_len, _ = hidden_states.size()\n\n        # get query proj\n\n        qkv = self.qkv(hidden_states)\n        query_states, key_states, value_states = qkv.split(\n            [\n                self.head_size * self.num_heads,\n            ]\n            * 3,\n     
       dim=2,\n        )\n        query_states = query_states * self.scale\n        key_states = self._shape(key_states, -1, bsz)\n        value_states = self._shape(value_states, -1, bsz)\n\n        proj_shape = (bsz * self.num_heads, -1, self.head_size)\n        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)\n        key_states = key_states.view(*proj_shape)\n        value_states = value_states.view(*proj_shape)\n\n        src_len = key_states.size(1)\n        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))\n\n        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is\"\n                f\" {attn_weights.size()}\"\n            )\n\n        # apply the causal_attention_mask first\n        if causal_attention_mask is not None:\n            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is\"\n                    f\" {causal_attention_mask.size()}\"\n                )\n            attn_weights = (\n                attn_weights.view(bsz, self.num_heads, tgt_len, src_len)\n                + causal_attention_mask\n            )\n            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)\n\n        if attention_mask is not None:\n            if attention_mask.size() != (bsz, 1, tgt_len, src_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}\"\n                )\n            attn_weights = (\n                attn_weights.view(bsz, self.num_heads, tgt_len, src_len)\n                + attention_mask\n            )\n            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)\n\n      
  attn_weights = nn.functional.softmax(attn_weights, dim=-1)\n\n        attn_probs = nn.functional.dropout(\n            attn_weights, p=self.dropout, training=self.training\n        )\n\n        attn_output = torch.bmm(attn_probs, value_states)\n\n        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_size):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_size)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_size)\n        attn_output = attn_output.transpose(1, 2)\n        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)\n\n        attn_output = self.out_proj(attn_output)\n\n        return attn_output, None\n\n\nclass CLIPMLP(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        self.activation_fn = ACT2FN[config.hidden_act]\n        self.fc1 = TensorParallelColumnLinear.load(\n            prefix=f\"{prefix}.fc1\", config=config, weights=weights, bias=True\n        )\n        self.fc2 = TensorParallelRowLinear.load(\n            prefix=f\"{prefix}.fc2\", config=config, weights=weights, bias=True\n        )\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = self.activation_fn(hidden_states)\n        hidden_states = self.fc2(hidden_states)\n        return hidden_states\n\n\nclass CLIPEncoderLayer(nn.Module):\n    def __init__(self, prefix, config: CLIPConfig, weights):\n        super().__init__()\n        self.embed_dim = config.hidden_size\n        self.self_attn = CLIPAttention(\n            prefix=f\"{prefix}.self_attn\", config=config, weights=weights\n        )\n        self.layer_norm1 = nn.LayerNorm.load(\n            prefix=f\"{prefix}.layer_norm1\", weights=weights, 
eps=config.layer_norm_eps\n        )\n        self.mlp = CLIPMLP(prefix=f\"{prefix}.mlp\", config=config, weights=weights)\n        self.layer_norm2 = nn.LayerNorm.load(\n            prefix=f\"{prefix}.layer_norm2\", weights=weights, eps=config.layer_norm_eps\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: torch.Tensor,\n        causal_attention_mask: torch.Tensor,\n    ):\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`): attention mask of size\n                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n                `(config.encoder_attention_heads,)`.\n        \"\"\"\n        residual = hidden_states\n\n        hidden_states = self.layer_norm1(hidden_states)\n        hidden_states, attn_weights = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            causal_attention_mask=causal_attention_mask,\n        )\n        hidden_states = residual + hidden_states\n\n        residual = hidden_states\n        hidden_states = self.layer_norm2(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        return hidden_states\n\n\nclass CLIPPreTrainedModel(nn.Module):\n    \"\"\"\n    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained\n    models.\n    \"\"\"\n\n    config_class = CLIPConfig\n    base_model_prefix = \"clip\"\n    supports_gradient_checkpointing = True\n\n\nCLIP_START_DOCSTRING = r\"\"\"\n    This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the\n    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads\n    etc.)\n\n    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.\n    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage\n    and behavior.\n\n    Parameters:\n        config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.\n            Initializing with a config file does not load the weights associated with the model, only the\n            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n\"\"\"\n\nCLIP_TEXT_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide\n            it.\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Indices of positions of each input sequence tokens in the position embeddings. 
Selected in the range `[0,\n            config.max_position_embeddings - 1]`.\n\n            [What are position IDs?](../glossary#position-ids)\n\"\"\"\n\nCLIP_VISION_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):\n            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using\n            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.\n\"\"\"\n\nCLIP_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide\n            it.\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,\n            config.max_position_embeddings - 1]`.\n\n            [What are position IDs?](../glossary#position-ids)\n        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):\n            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using\n            [`AutoImageProcessor`]. 
See [`CLIPImageProcessor.__call__`] for details.\n        return_loss (`bool`, *optional*):\n            Whether or not to return the contrastive loss.\n\"\"\"\n\n\nclass CLIPEncoder(nn.Module):\n    \"\"\"\n    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a\n    [`CLIPEncoderLayer`].\n\n    Args:\n        config: CLIPConfig\n    \"\"\"\n\n    def __init__(self, prefix, config: CLIPConfig, weights):\n        super().__init__()\n        self.config = config\n        self.layers = nn.ModuleList(\n            [\n                CLIPEncoderLayer(\n                    prefix=f\"{prefix}.layers.{i}\", config=config, weights=weights\n                )\n                for i in range(config.num_hidden_layers)\n            ]\n        )\n\n    def forward(\n        self,\n        inputs_embeds,\n        attention_mask: Optional[torch.Tensor] = None,\n        causal_attention_mask: Optional[torch.Tensor] = None,\n    ):\n        r\"\"\"\n        Args:\n            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):\n                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.\n                This is useful if you want more control over how to convert `input_ids` indices into associated vectors\n                than the model's internal embedding lookup matrix.\n            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n                - 1 for tokens that are **not masked**,\n                - 0 for tokens that are **masked**.\n\n                [What are attention masks?](../glossary#attention-mask)\n            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Causal mask for the text model. 
Mask values selected in `[0, 1]`:\n\n                - 1 for tokens that are **not masked**,\n                - 0 for tokens that are **masked**.\n\n                [What are attention masks?](../glossary#attention-mask)\n        \"\"\"\n\n        hidden_states = inputs_embeds\n        for idx, encoder_layer in enumerate(self.layers):\n            hidden_states = encoder_layer(\n                hidden_states,\n                attention_mask,\n                causal_attention_mask,\n            )\n\n        return hidden_states\n\n\nclass CLIPTextTransformer(nn.Module):\n    def __init__(self, prefix: str, config: CLIPTextConfig, weights=None):\n        super().__init__()\n        self.config = config\n        embed_dim = config.hidden_size\n        self.embeddings = CLIPTextEmbeddings(config)\n        # Initialize weights and apply final processing with `self.post_init()`\n        self.encoder = CLIPEncoder(\n            prefix=f\"{prefix}.encoder\", config=config, weights=weights\n        )\n        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)\n\n        # For `pooled_output` computation\n        self.eos_token_id = config.eos_token_id\n\n    def forward(\n        self,\n        input_ids: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.Tensor] = None,\n    ):\n        r\"\"\"\n        Returns:\n\n        \"\"\"\n        if input_ids is None:\n            raise ValueError(\"You have to specify input_ids\")\n\n        input_shape = input_ids.size()\n        input_ids = input_ids.view(-1, input_shape[-1])\n\n        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)\n\n        # CLIP's text model uses causal mask, prepare it here.\n        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324\n        causal_attention_mask = _create_4d_causal_attention_mask(\n            input_shape, 
hidden_states.dtype, device=hidden_states.device\n        )\n        # expand attention_mask\n        if attention_mask is not None:\n            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n            attention_mask = _prepare_4d_attention_mask(\n                attention_mask, hidden_states.dtype\n            )\n\n        encoder_outputs = self.encoder(\n            inputs_embeds=hidden_states,\n            attention_mask=attention_mask,\n            causal_attention_mask=causal_attention_mask,\n        )\n\n        last_hidden_state = encoder_outputs[0]\n        last_hidden_state = self.final_layer_norm(last_hidden_state)\n\n        if self.eos_token_id == 2:\n            # The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here.\n            # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added\n            # ------------------------------------------------------------\n            # text_embeds.shape = [batch_size, sequence_length, transformer.width]\n            # take features from the eot embedding (eot_token is the highest number in each sequence)\n            # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14\n            last_hidden_state[\n                torch.arange(\n                    last_hidden_state.shape[0], device=last_hidden_state.device\n                ),\n                input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(\n                    dim=-1\n                ),\n            ]\n        else:\n            # The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible)\n            last_hidden_state[\n                torch.arange(\n                    last_hidden_state.shape[0], device=last_hidden_state.device\n                ),\n                # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to 
`eos_token_id`)\n                (\n                    input_ids.to(dtype=torch.int, device=last_hidden_state.device)\n                    == self.eos_token_id\n                )\n                .int()\n                .argmax(dim=-1),\n            ]\n\n        return last_hidden_state\n\n\nclass CLIPTextModel(CLIPPreTrainedModel):\n    config_class = CLIPTextConfig\n\n    _no_split_modules = [\"CLIPTextEmbeddings\", \"CLIPEncoderLayer\"]\n\n    def __init__(self, prefix, config: CLIPTextConfig):\n        super().__init__(config)\n        self.text_model = CLIPTextTransformer(prefix, config)\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def forward(\n        self,\n        input_ids: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.Tensor] = None,\n    ):\n        r\"\"\"\n        Returns:\n\n        Examples:\n\n        ```python\n        >>> from transformers import AutoTokenizer, CLIPTextModel\n\n        >>> model = CLIPTextModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n        >>> tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-base-patch32\")\n\n        >>> inputs = tokenizer([\"a photo of a cat\", \"a photo of a dog\"], padding=True, return_tensors=\"pt\")\n\n        >>> outputs = model(**inputs)\n        >>> last_hidden_state = outputs.last_hidden_state\n        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states\n        ```\"\"\"\n\n        return self.text_model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n        )\n\n\nclass CLIPVisionTransformer(nn.Module):\n    def __init__(self, prefix, config: CLIPVisionConfig, weights):\n        super().__init__()\n        self.config = config\n\n        self.embeddings = CLIPVisionEmbeddings(\n            prefix=f\"{prefix}.embeddings\", config=config, weights=weights\n        )\n 
       self.pre_layrnorm = nn.LayerNorm.load(\n            prefix=f\"{prefix}.pre_layrnorm\", weights=weights, eps=config.layer_norm_eps\n        )\n        self.encoder = CLIPEncoder(\n            prefix=f\"{prefix}.encoder\", config=config, weights=weights\n        )\n        # self.post_layernorm = nn.LayerNorm.load(prefix=f\"{prefix}.post_layernorm\", weights=weights, eps=config.layer_norm_eps)\n\n    def forward(\n        self,\n        pixel_values: Optional[torch.FloatTensor] = None,\n    ):\n        r\"\"\"\n        Returns:\n\n        \"\"\"\n        if pixel_values is None:\n            raise ValueError(\"You have to specify pixel_values\")\n\n        hidden_states = self.embeddings(pixel_values)\n        hidden_states = self.pre_layrnorm(hidden_states)\n\n        encoder_outputs = self.encoder(\n            inputs_embeds=hidden_states,\n        )\n        last_hidden_state = encoder_outputs\n        # pooled_output = last_hidden_state[:, 0, :]\n        # pooled_output = self.post_layernorm(pooled_output)\n\n        return BaseModelOutputWithPooling(\n            last_hidden_state=last_hidden_state,\n            # pooler_output=pooled_output,\n            # hidden_states=encoder_outputs,\n        )\n\n\nclass CLIPVisionModel(CLIPPreTrainedModel):\n    config_class = CLIPVisionConfig\n    main_input_name = \"pixel_values\"\n    _no_split_modules = [\"CLIPEncoderLayer\"]\n\n    def __init__(self, config: CLIPVisionConfig):\n        super().__init__(config)\n        self.vision_model = CLIPVisionTransformer(config)\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self) -> nn.Module:\n        return self.vision_model.embeddings.patch_embedding\n\n    def forward(\n        self,\n        pixel_values: Optional[torch.FloatTensor] = None,\n    ):\n        r\"\"\"\n        Returns:\n\n        Examples:\n\n        ```python\n        >>> from PIL import Image\n        >>> import requests\n        >>> 
from transformers import AutoProcessor, CLIPVisionModel\n\n        >>> model = CLIPVisionModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n        >>> processor = AutoProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n\n        >>> url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n        >>> image = Image.open(requests.get(url, stream=True).raw)\n\n        >>> inputs = processor(images=image, return_tensors=\"pt\")\n\n        >>> outputs = model(**inputs)\n        >>> last_hidden_state = outputs.last_hidden_state\n        >>> pooled_output = outputs.pooler_output  # pooled CLS states\n        ```\"\"\"\n\n        return self.vision_model(\n            pixel_values=pixel_values,\n        )\n\n\nclass CLIPModel(nn.Module):\n    def __init__(self, prefix, config: CLIPConfig, weights):\n        super().__init__()\n        text_config = config.text_config\n        vision_config = config.vision_config\n\n        self.projection_dim = config.projection_dim\n        self.text_embed_dim = text_config.hidden_size\n        self.vision_embed_dim = vision_config.hidden_size\n\n        self.text_model = CLIPTextTransformer(text_config)\n        self.vision_model = CLIPVisionTransformer(vision_config)\n\n        self.visual_projection = nn.Linear(\n            self.vision_embed_dim, self.projection_dim, bias=False\n        )\n        self.text_projection = nn.Linear(\n            self.text_embed_dim, self.projection_dim, bias=False\n        )\n        self.logit_scale = nn.Parameter(\n            torch.tensor(self.config.logit_scale_init_value)\n        )\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_text_features(\n        self,\n        input_ids: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.Tensor] = None,\n    ) -> torch.FloatTensor:\n        r\"\"\"\n        Returns:\n            text_features 
(`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by\n            applying the projection layer to the pooled output of [`CLIPTextModel`].\n\n        Examples:\n\n        ```python\n        >>> from transformers import AutoTokenizer, CLIPModel\n\n        >>> model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n        >>> tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-base-patch32\")\n\n        >>> inputs = tokenizer([\"a photo of a cat\", \"a photo of a dog\"], padding=True, return_tensors=\"pt\")\n        >>> text_features = model.get_text_features(**inputs)\n        ```\"\"\"\n        text_outputs = self.text_model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n        )\n\n        pooled_output = text_outputs[1]\n        text_features = self.text_projection(pooled_output)\n\n        return text_features\n\n    def get_image_features(\n        self,\n        pixel_values: Optional[torch.FloatTensor] = None,\n    ) -> torch.FloatTensor:\n        r\"\"\"\n        Returns:\n            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by\n            applying the projection layer to the pooled output of [`CLIPVisionModel`].\n\n        Examples:\n\n        ```python\n        >>> from PIL import Image\n        >>> import requests\n        >>> from transformers import AutoProcessor, CLIPModel\n\n        >>> model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n        >>> processor = AutoProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n\n        >>> url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n        >>> image = Image.open(requests.get(url, stream=True).raw)\n\n        >>> inputs = processor(images=image, return_tensors=\"pt\")\n\n        >>> image_features = model.get_image_features(**inputs)\n        ```\"\"\"\n        # Use CLIP 
model's config for some fields (if specified) instead of those of vision & text components.\n        vision_outputs = self.vision_model(\n            pixel_values=pixel_values,\n        )\n\n        pooled_output = vision_outputs[1]  # pooled_output\n        image_features = self.visual_projection(pooled_output)\n\n        return image_features\n\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        pixel_values: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ):\n        r\"\"\"\n        Returns:\n\n        Examples:\n\n        ```python\n        >>> from PIL import Image\n        >>> import requests\n        >>> from transformers import AutoProcessor, CLIPModel\n\n        >>> model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n        >>> processor = AutoProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n\n        >>> url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n        >>> image = Image.open(requests.get(url, stream=True).raw)\n\n        >>> inputs = processor(\n        ...     text=[\"a photo of a cat\", \"a photo of a dog\"], images=image, return_tensors=\"pt\", padding=True\n        ... 
)\n\n        >>> outputs = model(**inputs)\n        >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score\n        >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities\n        ```\"\"\"\n        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.\n        vision_outputs = self.vision_model(\n            pixel_values=pixel_values,\n        )\n\n        text_outputs = self.text_model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n        )\n\n        image_embeds = vision_outputs[1]\n        image_embeds = self.visual_projection(image_embeds)\n\n        text_embeds = text_outputs[1]\n        text_embeds = self.text_projection(text_embeds)\n\n        # normalized features\n        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)\n        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)\n\n        # cosine similarity as logits\n        logit_scale = self.logit_scale.exp()\n        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale\n        logits_per_image = logits_per_text.t()\n\n        return logits_per_image, logits_per_text\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2024 Cohere team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.distributed\n\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom typing import Optional, List, Tuple\n\nfrom text_generation_server.layers.attention import (\n    paged_attention,\n    attention,\n    set_block_mapping,\n    Seqlen,\n    HPUPagedAttentionMetadata,\n)\nfrom text_generation_server.layers.attention.kv_cache import get_kv_scales\nfrom text_generation_server.layers import (\n    TensorParallelRowLinear,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    SpeculativeHead,\n    get_linear,\n)\nfrom text_generation_server.layers.layernorm import (\n    FastLayerNorm,\n)\nfrom text_generation_server.layers.rotary import (\n    PositionRotaryEmbedding,\n)\nfrom text_generation_server.utils.weights import UnquantizedWeight\nfrom habana_frameworks.torch.hpex.kernels import (\n    RotaryPosEmbeddingMode,\n    apply_rotary_pos_emb,\n)\n\nimport habana_frameworks.torch as htorch\n\n\nclass CohereRotary(PositionRotaryEmbedding):\n    def forward(\n        self,\n        
query: torch.Tensor,\n        key: torch.Tensor,\n        cos: torch.Tensor,\n        sin: torch.Tensor,\n    ):\n        # Such controlflows may add some overhead.\n        num_tokens = query.shape[0]\n        head_size = query.shape[-1]\n        rope_mode = RotaryPosEmbeddingMode.PAIRWISE\n        sin = torch.repeat_interleave(sin, 2, dim=-1)\n        cos = torch.repeat_interleave(cos, 2, dim=-1)\n        rotary_dim = cos.shape[-1]\n        query_shape = query.shape\n        query = query.view(num_tokens, -1, head_size)\n        query_rot = query[..., :rotary_dim]\n        query_pass = query[..., rotary_dim:]\n        query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)\n        query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape))\n\n        key_shape = key.shape\n        key = key.view(num_tokens, -1, head_size)\n        key_rot = key[..., :rotary_dim]\n        key_pass = key[..., rotary_dim:]\n        key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)\n        key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape))\n\n\nclass CohereLayerNorm(nn.Module):\n    def __init__(self, prefix, weights, eps):\n        super().__init__()\n        weight = weights.get_sharded(f\"{prefix}.weight\", dim=0)\n        self.weight = nn.Parameter(weight)\n        # Fake weights\n        self.ones = weight.new_ones(weight.shape[1])\n        self.eps = eps\n\n    def forward(self, hidden_states):\n        hidden_states = hidden_states.reshape(\n            -1, self.weight.shape[0], self.weight.shape[1]\n        )\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        mean = hidden_states.mean(-1, keepdim=True)\n        hidden_states_minus_mean = hidden_states - mean\n        variance = hidden_states_minus_mean.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states_minus_mean * torch.rsqrt(variance + self.eps)\n        hidden_states = 
self.weight.to(torch.float32) * hidden_states\n        hidden_states = hidden_states.view(-1, self.weight.shape[1])\n        return hidden_states.to(input_dtype)\n\n\ndef load_attention(config, prefix, weights):\n    if config.num_attention_heads != config.num_key_value_heads:\n        return _load_gqa(config, prefix, weights)\n    else:\n        return TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n            dim=0,\n            weights=weights,\n            bias=config.attention_bias,\n        )\n\n\ndef _load_gqa(config, prefix: str, weights):\n    assert config.hidden_size % config.num_attention_heads == 0\n    assert config.num_attention_heads % weights.process_group.size() == 0\n\n    weight = weights.get_multi_weights_col(\n        prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n        dim=0,\n    )\n\n    if isinstance(weight, UnquantizedWeight):\n        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)\n\n        head_size = config.hidden_size // config.num_attention_heads\n        num_heads = config.num_attention_heads // weights.process_group.size()\n        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()\n        assert list(weight.weight.shape) == [\n            (num_heads + 2 * num_key_value_heads) * head_size,\n            config.hidden_size,\n        ], f\"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}\"\n\n    if config.attention_bias:\n        w = [\n            weights.get_sharded(f\"{p}.bias\", dim=0)\n            for p in [f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"]\n        ]\n        bias = torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device)\n    else:\n        bias = None\n\n    return TensorParallelColumnLinear(get_linear(weight, 
bias=bias))\n\n\nclass FlashCohereAttention(torch.nn.Module):\n    def __init__(\n        self,\n        prefix: str,\n        config,\n        weights,\n        rotary_emb,\n    ):\n        super().__init__()\n        self.num_heads = config.num_attention_heads\n        self.hidden_size = config.hidden_size\n        self.head_size = self.hidden_size // self.num_heads\n\n        self.rotary_emb = rotary_emb\n\n        self.softmax_scale = self.head_size**-0.5\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.num_key_value_heads = (\n            config.num_key_value_heads // weights.process_group.size()\n        )\n\n        self.query_key_value = load_attention(config, prefix, weights)\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n\n        self.use_qk_norm = config.use_qk_norm\n        if self.use_qk_norm:\n            self.q_norm = CohereLayerNorm(\n                prefix=f\"{prefix}.q_norm\",\n                weights=weights,\n                eps=config.layer_norm_eps,\n            )\n            self.k_norm = CohereLayerNorm(\n                prefix=f\"{prefix}.k_norm\",\n                weights=weights,\n                eps=config.layer_norm_eps,\n            )\n        else:\n            self.q_norm = None\n            self.k_norm = None\n\n        self.o_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.o_proj\",\n            weights=weights,\n            bias=config.attention_bias,\n        )\n        self.num_groups = self.num_heads // self.num_key_value_heads\n        self.kv_head_mapping = torch.arange(\n            0, self.num_key_value_heads, dtype=torch.int32, 
device=weights.device\n        ).repeat_interleave(self.num_groups)\n\n    def forward(\n        self,\n        hidden_states,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        hpu_attention_meta,\n    ):\n        qkv = self.query_key_value(hidden_states)\n        query, key, value = qkv.split(\n            [\n                self.head_size * self.num_heads,\n                self.head_size * self.num_key_value_heads,\n                self.head_size * self.num_key_value_heads,\n            ],\n            dim=1,\n        )\n\n        if self.use_qk_norm:\n            query = query.reshape(-1, self.head_size)\n            key = key.reshape(-1, self.head_size)\n            query = self.q_norm(query.contiguous())\n            key = self.k_norm(key.contiguous())\n\n        query = query.view(-1, self.num_heads, self.head_size)\n        key = key.view(-1, self.num_key_value_heads, self.head_size)\n        value = value.view(-1, self.num_key_value_heads, self.head_size)\n\n        self.rotary_emb(query, key, cos, sin)\n\n        kv_cache.store(\n            key=key,\n            value=value,\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # sdpa\n            attn_output = attention(\n                query=query,\n                key=key,\n                value=value,\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                softmax_scale=self.softmax_scale,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                seqlen,\n                kv_scales=self.kv_scales,\n                hpu_attention_meta=hpu_attention_meta,\n            )\n\n        return 
self.o_proj(\n            attn_output.view(-1, self.num_heads * self.head_size), reduce=False\n        )\n\n\nclass CohereMLP(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        act = config.hidden_act\n        self.act = (\n            ACT2FN[act]\n            if \"gelu\" not in act\n            else lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\" if act in [\"gelu_fast\", \"gelu_pytorch_tanh\"] else \"none\"\n                ),\n            )\n        )\n        # Fuse gate and up proj\n        self.gate_up_proj = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.gate_proj\", f\"{prefix}.up_proj\"],\n            weights=weights,\n            dim=0,\n            bias=False,\n        )\n        self.down_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.down_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.intermediate_size = (\n            config.intermediate_size // weights.process_group.size()\n        )\n\n    def forward(self, hidden_states):\n        gate_up_states = self.gate_up_proj(hidden_states)\n        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)\n        return self.down_proj(\n            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=False\n        )\n\n\nclass FlashCohereLayer(nn.Module):\n    def __init__(self, prefix: str, layer_id, config, weights, rotary_emb):\n        super().__init__()\n        prefix = f\"{prefix}.layers.{layer_id}\"\n        self.self_attn = FlashCohereAttention(\n            prefix=f\"{prefix}.self_attn\",\n            config=config,\n            weights=weights,\n            rotary_emb=rotary_emb,\n        )\n        self.mlp = CohereMLP(prefix=f\"{prefix}.mlp\", config=config, weights=weights)\n\n        self.input_layernorm = 
FastLayerNorm.load_no_bias(\n            prefix=f\"{prefix}.input_layernorm\",\n            weights=weights,\n            eps=config.layer_norm_eps,\n        )\n        self.process_group = weights.process_group\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        hpu_attention_meta,\n    ):\n        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        attn_output = self.self_attn(\n            normed_hidden_states,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            hpu_attention_meta,\n        )\n\n        mlp_output = self.mlp(normed_hidden_states)\n        output = attn_output + mlp_output\n\n        if self.process_group.size() > 1:\n            torch.distributed.all_reduce(output, group=self.process_group)\n\n        return output, res\n\n\nclass FlashCohereModel(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n\n        process_group = weights.process_group\n        self.tp_rank = process_group.rank()\n        self.tp_world_size = process_group.size()\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=f\"{prefix}.embed_tokens\", weights=weights\n        )\n        rotary_emb = CohereRotary.static(\n            config=config,\n            dim=config.hidden_size // config.num_attention_heads,\n            base=config.rope_theta,\n            device=weights.device,\n        )\n        self.layers = nn.ModuleList(\n            [\n                FlashCohereLayer(\n                    prefix,\n                    layer_id,\n                    config,\n                    weights,\n                    rotary_emb,\n                )\n                for layer_id in range(config.num_hidden_layers)\n     
       ]\n        )\n        self.norm = FastLayerNorm.load_no_bias(\n            prefix=f\"{prefix}.norm\", weights=weights, eps=config.layer_norm_eps\n        )\n\n        self.gradient_checkpointing = False\n\n        self.head_size = self.layers[0].self_attn.head_size\n        self.num_heads = self.layers[0].self_attn.num_heads\n        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: torch.Tensor,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n    ) -> torch.Tensor:\n        if hpu_attention_meta is not None:\n            hpu_attention_meta = set_block_mapping(\n                hpu_attention_meta, input_ids.shape[0]\n            )\n        hidden_states = self.embed_tokens(input_ids)\n\n        # Get rotary cos and sin for this forward\n        # Avoid to index in each layer\n        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)\n\n        residual = None\n        lazy_mode = htorch.utils.internal.is_lazy()\n        if lazy_mode:\n            htorch.core.mark_step()\n        for i, layer in enumerate(self.layers):\n            hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                slots,\n                seqlen,\n                hpu_attention_meta,\n            )\n            if lazy_mode:\n                htorch.core.mark_step()\n\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n\nclass FlashCohereForCausalLM(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n\n        if 
not prefix:\n            prefix = \"model\"\n        else:\n            prefix = f\"{prefix}.model\"\n\n        self.model = FlashCohereModel(prefix, config, weights)\n        try:\n            self.lm_head = SpeculativeHead.load(\n                config,\n                prefix=\"lm_head\",\n                weights=weights,\n            )\n        except RuntimeError:\n            self.lm_head = SpeculativeHead.load(\n                config,\n                prefix=f\"{prefix}.embed_tokens\",\n                weights=weights,\n            )\n        self.logit_scale = config.logit_scale\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        hidden_states = self.model(\n            input_ids,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            hpu_attention_meta,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.lm_head(hidden_states)\n        logits *= self.logit_scale\n        if speculative_logits is not None:\n            speculative_logits *= self.logit_scale\n        return logits, speculative_logits\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2022 HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.distributed\n\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom transformers.configuration_utils import PretrainedConfig\nfrom typing import Optional, List, Tuple, Any\nfrom text_generation_server.layers.attention.kv_cache import get_kv_scales\n\n\nfrom text_generation_server.layers.attention import (\n    paged_attention,\n    attention,\n    set_block_mapping,\n    Seqlen,\n    HPUPagedAttentionMetadata,\n)\nfrom text_generation_server.layers import (\n    FastLinear,\n    TensorParallelRowLinear,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    SpeculativeHead,\n    get_linear,\n)\nfrom text_generation_server.layers.rotary import (\n    PositionRotaryEmbedding,\n)\nfrom text_generation_server.layers.layernorm import (\n    FastLayerNorm,\n)\nfrom vllm_hpu_extension.ops import DynamicFusedMOE\nimport habana_frameworks.torch as htorch\n\n\nclass DbrxAttentionConfig(PretrainedConfig):\n    def __init__(\n        self,\n        attn_pdrop: float = 0,\n        clip_qkv: Optional[float] = None,\n        kv_n_heads: int = 1,\n        rope_theta: float = 10000.0,\n        **kwargs: Any,\n    ):\n        super().__init__(**kwargs)\n        self.attn_pdrop = attn_pdrop\n        self.clip_qkv = clip_qkv\n        self.kv_n_heads = kv_n_heads\n        
self.rope_theta = rope_theta\n\n        for k in [\"model_type\"]:\n            if k in kwargs:\n                kwargs.pop(k)\n        if len(kwargs) != 0:\n            raise ValueError(f\"Found unknown {kwargs=}\")\n\n\nclass DbrxFFNConfig(PretrainedConfig):\n    def __init__(\n        self,\n        ffn_act_fn: Optional[dict] = None,\n        ffn_hidden_size: int = 3584,\n        moe_num_experts: int = 4,\n        moe_top_k: int = 1,\n        moe_jitter_eps: Optional[float] = None,\n        moe_loss_weight: float = 0.01,\n        moe_normalize_expert_weights: Optional[float] = 1,\n        uniform_expert_assignment: bool = False,\n        **kwargs: Any,\n    ):\n        super().__init__()\n        if ffn_act_fn is None:\n            ffn_act_fn = {\"name\": \"silu\"}\n        self.ffn_act_fn = ffn_act_fn\n        self.ffn_hidden_size = ffn_hidden_size\n        self.moe_num_experts = moe_num_experts\n        self.moe_top_k = moe_top_k\n        self.moe_jitter_eps = moe_jitter_eps\n        self.moe_loss_weight = moe_loss_weight\n        self.moe_normalize_expert_weights = moe_normalize_expert_weights\n        self.uniform_expert_assignment = uniform_expert_assignment\n\n        if uniform_expert_assignment:\n            raise ValueError(\"`uniform_expert_assignment = True` is not supported\")\n\n        for k in [\"model_type\"]:\n            if k in kwargs:\n                kwargs.pop(k)\n        if len(kwargs) != 0:\n            raise ValueError(f\"Found unknown {kwargs=}\")\n\n\nclass DbrxConfig(PretrainedConfig):\n    attribute_map = {\n        \"hidden_size\": \"d_model\",\n        \"num_attention_heads\": \"n_heads\",\n        \"num_hidden_layers\": \"n_layers\",\n    }\n\n    def __init__(\n        self,\n        d_model: int = 2048,\n        n_heads: int = 16,\n        n_layers: int = 24,\n        max_seq_len: int = 2048,\n        vocab_size: int = 32000,\n        resid_pdrop: float = 0.0,\n        emb_pdrop: float = 0.0,\n        attn_config: 
Optional[DbrxAttentionConfig] = None,\n        ffn_config: Optional[DbrxFFNConfig] = None,\n        use_cache: bool = True,\n        initializer_range: float = 0.02,\n        output_router_logits: bool = False,\n        router_aux_loss_coef: float = 0.05,\n        **kwargs: Any,\n    ):\n        if attn_config is None:\n            self.attn_config = DbrxAttentionConfig()\n        elif isinstance(attn_config, dict):\n            self.attn_config = DbrxAttentionConfig(**attn_config)\n        else:\n            self.attn_config = attn_config\n\n        if ffn_config is None:\n            self.ffn_config = DbrxFFNConfig()\n        elif isinstance(ffn_config, dict):\n            self.ffn_config = DbrxFFNConfig(**ffn_config)\n        else:\n            self.ffn_config = ffn_config\n\n        self.d_model = d_model\n        self.n_heads = n_heads\n        self.n_layers = n_layers\n        self.max_seq_len = max_seq_len\n        self.vocab_size = vocab_size\n        self.resid_pdrop = resid_pdrop\n        self.emb_pdrop = emb_pdrop\n        self.use_cache = use_cache\n        self.initializer_range = initializer_range\n        self.output_router_logits = output_router_logits\n        self.router_aux_loss_coef = router_aux_loss_coef\n\n        tie_word_embeddings = kwargs.pop(\"tie_word_embeddings\", False)\n        if tie_word_embeddings:\n            raise ValueError(\"tie_word_embeddings is not supported for Dbrx models.\")\n\n        super().__init__(\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n    @property\n    def num_key_value_heads(self):\n        # We can't use the attribute map, since this the number of KV\n        # heads is not top-level.\n        return self.attn_config.kv_n_heads\n\n\ndef promote_scalar(x: torch.Tensor) -> torch.Tensor:\n    return x.view(1) if len(x.size()) == 0 else x\n\n\ndef load_attention(config, prefix, weights):\n    return TensorParallelColumnLinear.load_qkv(\n        config,\n        
prefix=f\"{prefix}.Wqkv\",\n        weights=weights,\n        bias=False,\n        num_heads=config.n_heads,\n        num_key_value_heads=config.attn_config.kv_n_heads,\n    )\n\n\ndef _load_experts(config, prefix, weights):\n    world_size = weights.process_group.size()\n    rank = weights.process_group.rank()\n\n    assert (\n        config.ffn_config.ffn_hidden_size % world_size == 0\n    ), f\"The chosen size {config.ffn_config.ffn_hidden_size} is not compatible with sharding on {world_size} shards\"\n\n    expert_size = config.ffn_config.ffn_hidden_size\n    block_size = expert_size // world_size\n    start = rank * block_size\n    stop = (rank + 1) * block_size\n\n    tensor = torch.empty(\n        (config.ffn_config.moe_num_experts * block_size, config.d_model),\n        dtype=weights.dtype,\n        device=weights.device,\n    )\n\n    slice_ = weights._get_slice(f\"{prefix}\")\n\n    for i in range(config.ffn_config.moe_num_experts):\n        offset = i * expert_size\n        expert_slice = slice_[start + offset : stop + offset]\n\n        tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(\n            dtype=weights.dtype\n        ).to(device=weights.device)\n    return tensor\n\n\ndef _load_experts_quantized(config, prefix, weights, cls):\n    world_size = weights.process_group.size()\n    rank = weights.process_group.rank()\n\n    assert (\n        config.ffn_config.ffn_hidden_size % world_size == 0\n    ), f\"The chosen size {config.ffn_config.ffn_hidden_size} is not compatible with sharding on {world_size} shards\"\n\n    expert_size = config.ffn_config.ffn_hidden_size\n    block_size = expert_size // world_size\n    start = rank * block_size\n    stop = (rank + 1) * block_size\n\n    slice_ = weights._get_slice(f\"{prefix}\")\n\n    experts = []\n    for i in range(config.ffn_config.moe_num_experts):\n        if config.quantize in [\"gptq\", \"awq\"]:\n            raise NotImplementedError(\n                \"Dbrx does not support 
gptq/awq quantization yet.\"\n            )\n        else:\n            offset = i * expert_size\n            expert_slice = (\n                slice_[start + offset : stop + offset]\n                .to(dtype=weights.dtype)\n                .to(device=weights.device)\n            )\n\n        if cls == TensorParallelRowLinear:\n            expert_slice = expert_slice.t().contiguous()\n            linear = get_linear(expert_slice, None)\n            experts.append(cls(linear, weights.process_group))\n        else:\n            linear = get_linear(expert_slice, None)\n            experts.append(cls(linear))\n\n    return experts\n\n\nclass DbrxAttention(torch.nn.Module):\n    def __init__(\n        self,\n        prefix: str,\n        config,\n        weights,\n        rotary_emb,\n    ):\n        super().__init__()\n        self.clip_qkv = config.attn_config.clip_qkv\n        self.num_heads = config.n_heads\n        self.hidden_size = config.d_model\n        self.head_size = self.hidden_size // self.num_heads\n\n        self.rotary_emb = rotary_emb\n\n        self.softmax_scale = self.head_size**-0.5\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.num_key_value_heads = (\n            config.attn_config.kv_n_heads // weights.process_group.size()\n        )\n\n        self.query_key_value = load_attention(config, prefix, weights)\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n\n        self.o_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.out_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.num_groups = self.num_heads // self.num_key_value_heads\n   
     self.kv_head_mapping = torch.arange(\n            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device\n        ).repeat_interleave(self.num_groups)\n\n    def forward(\n        self,\n        hidden_states,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        hpu_attention_meta,\n    ):\n        qkv = self.query_key_value(hidden_states)\n        if self.clip_qkv is not None:\n            qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)\n\n        query, kv = qkv.split(\n            [\n                self.head_size * self.num_heads,\n                2 * self.head_size * self.num_key_value_heads,\n            ],\n            dim=1,\n        )\n        query = query.view(-1, self.num_heads, self.head_size)\n        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)\n\n        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)\n\n        kv_cache.store(\n            key=kv[:, 0],\n            value=kv[:, 1],\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # sdpa\n            attn_output = attention(\n                query=query,\n                key=kv[:, 0],\n                value=kv[:, 1],\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                softmax_scale=self.softmax_scale,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                seqlen,\n                kv_scales=self.kv_scales,\n                hpu_attention_meta=hpu_attention_meta,\n            )\n\n        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))\n\n\nclass DbrxNormAttentionNorm(nn.Module):\n    def 
__init__(\n        self,\n        prefix: str,\n        config,\n        weights,\n        rotary_emb,\n    ):\n        super().__init__()\n        self.norm_1 = FastLayerNorm.load_no_bias(\n            prefix=f\"{prefix}.norm_1\", weights=weights, eps=1e-5\n        )\n        self.self_attn = DbrxAttention(\n            prefix=f\"{prefix}.attn\",\n            config=config,\n            weights=weights,\n            rotary_emb=rotary_emb,\n        )\n        self.norm_2 = FastLayerNorm.load_no_bias(\n            prefix=f\"{prefix}.norm_2\",\n            weights=weights,\n            eps=1e-5,\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        hpu_attention_meta,\n    ):\n        normed_hidden_states, res = self.norm_1(hidden_states, residual)\n\n        # Self Attention\n        attn_output = self.self_attn(\n            normed_hidden_states,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            hpu_attention_meta,\n        )\n\n        # faster post attention rms norm\n        normed_attn_res_output, attn_res = self.norm_2(attn_output, res)\n\n        return normed_attn_res_output, attn_res\n\n\n@torch.jit.script\ndef select_experts(\n    gate_logits: torch.Tensor, top_k: int, moe_normalize_expert_weights: int\n):\n    # all_probs: (sequence_length, n_experts) and upcast for softmax\n    all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)\n    # weights, selected_experts: (sequence_length, top-k)\n    weights, selected_experts = torch.topk(all_probs, top_k, dim=-1)\n    if moe_normalize_expert_weights:\n        weights = weights / torch.norm(\n            weights, p=moe_normalize_expert_weights, dim=-1, keepdim=True\n        )\n    weights = weights.view(-1)\n    selected_experts = 
selected_experts.view(-1)\n\n    return selected_experts, weights\n\n\n@torch.jit.script\ndef round_up(x: torch.Tensor, value: int):\n    return torch.div(x + (value - 1), value, rounding_mode=\"trunc\") * value\n\n\nclass BlockSparseMoE(nn.Module):\n    def __init__(self, prefix, config: DbrxConfig, weights):\n        super().__init__()\n        self.moe_normalize_expert_weights = (\n            config.ffn_config.moe_normalize_expert_weights\n        )\n        self.hidden_dim = config.d_model\n        self.ffn_dim = config.ffn_config.ffn_hidden_size // weights.process_group.size()\n        self.num_experts = config.ffn_config.moe_num_experts\n        self.top_k = config.ffn_config.moe_top_k\n\n        act = config.ffn_config.ffn_act_fn[\"name\"]\n        if \"gelu\" in act:\n            self.act = lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\" if act in [\"gelu_fast\", \"gelu_pytorch_tanh\"] else \"none\"\n                ),\n            )\n        elif \"silu\" in act:\n            self.act = torch.nn.functional.silu\n        else:\n            self.act = ACT2FN[act]\n\n        # gating\n        self.gate = FastLinear.load(\n            config, f\"{prefix}.router.layer\", weights, bias=False\n        )\n\n        # merged expert weights, all of size  (n_experts * ffn_dim, hidden_dim)\n        w1 = _load_experts(config, f\"{prefix}.experts.mlp.w1\", weights).view(\n            self.num_experts, self.ffn_dim, self.hidden_dim\n        )\n        v1 = _load_experts(config, f\"{prefix}.experts.mlp.v1\", weights).view(\n            self.num_experts, self.ffn_dim, self.hidden_dim\n        )\n        self.wv1 = torch.cat([w1, v1], dim=1)\n        self.w2 = (\n            _load_experts(config, f\"{prefix}.experts.mlp.w2\", weights)\n            .view(self.num_experts, self.ffn_dim, self.hidden_dim)\n            .transpose(1, 2)\n            .contiguous()\n        )\n\n        self.process_group = 
weights.process_group\n\n        self.hpu_fused_moe = DynamicFusedMOE(self.num_experts)\n        for i in range(self.num_experts):\n            self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.wv1[i])\n            self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.w2[i])\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        # router_logits: (num_tokens, n_experts)\n        router_logits = self.gate(x)\n        out = self.hpu_fused_moe(x, router_logits, self.top_k)\n\n        # Reduce sum\n        if self.process_group.size() > 1:\n            torch.distributed.all_reduce(out, group=self.process_group)\n\n        return out.view(*x.shape)\n\n\nclass DenseMoE(nn.Module):\n    def __init__(self, prefix, config: DbrxConfig, weights):\n        super().__init__()\n\n        self.moe_normalize_expert_weights = (\n            config.ffn_config.moe_normalize_expert_weights\n        )\n        self.hidden_dim = config.d_model\n        self.ffn_dim = config.ffn_config.ffn_hidden_size // weights.process_group.size()\n        self.num_experts = config.ffn_config.moe_num_experts\n        self.top_k = config.ffn_config.moe_top_k\n\n        act = config.ffn_config.ffn_act_fn[\"name\"]\n        if \"gelu\" in act:\n            self.act = lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\" if act in [\"gelu_fast\", \"gelu_pytorch_tanh\"] else \"none\"\n                ),\n            )\n        elif \"silu\" in act:\n            self.act = torch.nn.functional.silu\n        else:\n            self.act = ACT2FN[act]\n\n        # gating\n        self.gate = FastLinear.load(\n            config, f\"{prefix}.router.layer\", weights, bias=False\n        )\n\n        self.w1 = _load_experts_quantized(\n            config,\n            prefix=f\"{prefix}.experts.mlp.w1\",\n            weights=weights,\n            cls=TensorParallelColumnLinear,\n        )\n        self.w2 = _load_experts_quantized(\n        
    config,\n            prefix=f\"{prefix}.experts.mlp.w2\",\n            weights=weights,\n            cls=TensorParallelRowLinear,\n        )\n        self.v1 = _load_experts_quantized(\n            config,\n            prefix=f\"{prefix}.experts.mlp.v1\",\n            weights=weights,\n            cls=TensorParallelColumnLinear,\n        )\n\n        self.process_group = weights.process_group\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        x: (sequence_length, model_dim)\n        gate_logits: (sequence_length, n_experts)\n        \"\"\"\n        # optional reshape\n        input_shape = x.shape\n        x = x.view(-1, input_shape[-1])\n\n        # gate_logits: (sequence_length, n_experts)\n        gate_logits = self.gate(x)\n        # all_probs: (sequence_length, n_experts) and upcast for softmax\n        weights = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)\n\n        if self.top_k < self.num_experts:\n            _, not_selected_experts = torch.topk(\n                weights,\n                self.num_experts - self.top_k,\n                largest=False,\n                sorted=False,\n                dim=1,\n            )\n            # Mask not selected experts\n            weights.scatter_(1, not_selected_experts, 0)\n\n        # Re-normalize\n        if self.moe_normalize_expert_weights:\n            weights = weights / torch.norm(\n                weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True\n            )\n        weights = weights.to(x.dtype)\n\n        # Final output tensor\n        out = x.new_zeros(x.shape[0], self.hidden_dim)\n        for i in range(self.num_experts):\n            h = self.act(self.w1[i](x)) * self.v1[i](x)\n            h = self.w2[i](h, reduce=False)\n            # Add expert output to out with masking\n            out += h * weights[:, i].view(-1, 1)\n\n        # Reduce sum\n        if self.process_group.size() > 1:\n            
torch.distributed.all_reduce(out, group=self.process_group)\n\n        return out\n\n\nclass DbrxLayer(nn.Module):\n    def __init__(self, prefix: str, layer_id, config, weights, rotary_emb):\n        super().__init__()\n        prefix = f\"{prefix}.blocks.{layer_id}\"\n\n        self.attn = DbrxNormAttentionNorm(\n            prefix=f\"{prefix}.norm_attn_norm\",\n            config=config,\n            weights=weights,\n            rotary_emb=rotary_emb,\n        )\n\n        moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE\n        self.moe = moe_cls(f\"{prefix}.ffn\", config, weights)\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        hpu_attention_meta,\n    ):\n        # Self Attention\n        attn_output, attn_res = self.attn(\n            hidden_states,\n            residual,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            hpu_attention_meta,\n        )\n\n        moe_output = self.moe(attn_output)\n\n        return moe_output, attn_res\n\n\nclass DbrxModel(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=f\"{prefix}.wte\", weights=weights\n        )\n        rotary_emb = PositionRotaryEmbedding.static(\n            config=config,\n            dim=config.d_model // config.n_heads,\n            base=config.attn_config.rope_theta,\n            device=weights.device,\n        )\n\n        self.layers = nn.ModuleList(\n            [\n                DbrxLayer(\n                    prefix,\n                    layer_id,\n                    config,\n                    weights,\n                    rotary_emb,\n                )\n                for layer_id in 
range(config.n_layers)\n            ]\n        )\n        self.norm = FastLayerNorm.load_no_bias(\n            prefix=f\"{prefix}.norm_f\", weights=weights, eps=1e-5\n        )\n\n        self.head_size = self.layers[0].attn.self_attn.head_size\n        self.num_heads = self.layers[0].attn.self_attn.num_heads\n        self.num_key_value_heads = self.layers[0].attn.self_attn.num_key_value_heads\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n    ) -> torch.Tensor:\n        if hpu_attention_meta is not None:\n            hpu_attention_meta = set_block_mapping(\n                hpu_attention_meta, input_ids.shape[0]\n            )\n        hidden_states = self.embed_tokens(input_ids)\n\n        # Get rotary cos and sin for this forward\n        # Avoid to index in each layer\n        cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(position_ids)\n        residual = None\n        lazy_mode = htorch.utils.internal.is_lazy()\n        if lazy_mode:\n            htorch.core.mark_step()\n        for i, layer in enumerate(self.layers):\n            hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                slots,\n                seqlen,\n                hpu_attention_meta,\n            )\n            if lazy_mode:\n                htorch.core.mark_step()\n\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n\nclass FlashDbrxForCausalLM(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n\n        if not prefix:\n          
  prefix = \"transformer\"\n        else:\n            prefix = f\"{prefix}.transformer\"\n\n        self.model = DbrxModel(prefix, config, weights)\n        self.lm_head = SpeculativeHead.load(\n            config,\n            prefix=\"lm_head\",\n            weights=weights,\n        )\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        hidden_states = self.model(\n            input_ids,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            hpu_attention_meta,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.lm_head(hidden_states)\n        return logits, speculative_logits\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import List, Optional, Tuple, Type\n\nimport torch\nimport torch.distributed\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom text_generation_server.layers import (\n    FastLinear,\n    SpeculativeHead,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    TensorParallelRowLinear,\n    get_linear,\n    Fp8Linear,\n)\nfrom text_generation_server.layers.attention import (\n    Seqlen,\n    attention,\n    paged_attention_mla,\n    set_block_mapping,\n    HPUPagedAttentionMetadata,\n)\nfrom text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales\nfrom text_generation_server.layers.layernorm import FastRMSNorm\nfrom text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer\nfrom text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale\nfrom text_generation_server.utils.weights import Weights\nimport habana_frameworks.torch as htorch\n\n\ndef get_and_maybe_dequant_weights(layer: torch.nn.Module) -> torch.Tensor:\n    if isinstance(layer, Fp8Linear):\n        eye = torch.eye(\n            layer.qweight.shape[-1], dtype=torch.bfloat16, device=layer.qweight.device\n        )\n        dequant_weights = layer(eye)\n        
del eye\n        # standardize to (output, input)\n        return dequant_weights.T\n    return layer.weight\n\n\nclass DeepseekV2Config(PretrainedConfig):\n    def __init__(\n        self,\n        vocab_size=102400,\n        hidden_size=4096,\n        intermediate_size=11008,\n        moe_intermediate_size=1407,\n        num_hidden_layers=30,\n        num_attention_heads=32,\n        num_key_value_heads=32,\n        n_shared_experts=2,\n        n_routed_experts=160,\n        ep_size=1,\n        routed_scaling_factor=1.0,\n        kv_lora_rank=512,\n        q_lora_rank=1536,\n        qk_rope_head_dim=64,\n        v_head_dim=128,\n        qk_nope_head_dim=128,\n        topk_method=\"gready\",\n        n_group=8,\n        topk_group=3,\n        num_experts_per_tok=6,\n        moe_layer_freq=1,\n        first_k_dense_replace=0,\n        norm_topk_prob=False,\n        scoring_func=\"softmax\",\n        aux_loss_alpha=0.001,\n        seq_aux=True,\n        hidden_act=\"silu\",\n        max_position_embeddings=2048,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        pad_token_id=None,\n        bos_token_id=100000,\n        eos_token_id=100001,\n        pretraining_tp=1,\n        tie_word_embeddings=False,\n        rope_theta=10000.0,\n        rope_scaling=None,\n        attention_bias=False,\n        attention_dropout=0.0,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.moe_intermediate_size = moe_intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.n_shared_experts = n_shared_experts\n        self.n_routed_experts = n_routed_experts\n        self.ep_size = ep_size\n        self.routed_scaling_factor = routed_scaling_factor\n        self.kv_lora_rank = 
kv_lora_rank\n        self.q_lora_rank = q_lora_rank\n        self.qk_rope_head_dim = qk_rope_head_dim\n        self.v_head_dim = v_head_dim\n        self.qk_nope_head_dim = qk_nope_head_dim\n        self.topk_method = topk_method\n        self.n_group = n_group\n        self.topk_group = topk_group\n        self.num_experts_per_tok = num_experts_per_tok\n        self.moe_layer_freq = moe_layer_freq\n        self.first_k_dense_replace = first_k_dense_replace\n        self.norm_topk_prob = norm_topk_prob\n        self.scoring_func = scoring_func\n        self.aux_loss_alpha = aux_loss_alpha\n        self.seq_aux = seq_aux\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.pretraining_tp = pretraining_tp\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n\n        tie_word_embeddings = kwargs.pop(\"tie_word_embeddings\", False)\n        if tie_word_embeddings:\n            raise ValueError(\n                \"tie_word_embeddings is not supported for Deepseek V2 models.\"\n            )\n\n        if ep_size != 1:\n            raise ValueError(\n                f\"Currently only ep_size == 1 is supported for Deepseek V2 models, was {ep_size}\"\n            )\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n\nclass DeepseekV2Attention(torch.nn.Module):\n    def __init__(\n        self,\n        prefix: str,\n        config,\n        
weights: Weights,\n        rotary_emb,\n    ):\n        super().__init__()\n        self.num_heads = config.num_attention_heads\n        self.hidden_size = config.hidden_size\n        self.kv_lora_rank = config.kv_lora_rank\n        self.q_lora_rank = config.q_lora_rank\n        self.qk_nope_head_dim = config.qk_nope_head_dim\n        self.qk_rope_head_dim = config.qk_rope_head_dim\n        self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim\n        self.value_head_size = config.v_head_dim\n        self.head_pad_size = max(self.head_size, self.value_head_size)\n        self.rotary_emb = rotary_emb\n\n        mscale = get_mscale(\n            self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim\n        )\n        self.softmax_scale = self.head_size**-0.5 * mscale * mscale\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.num_key_value_heads = (\n            config.num_key_value_heads // weights.process_group.size()\n        )\n\n        if self.q_lora_rank is None:\n            self.q_proj = TensorParallelColumnLinear.load(\n                config,\n                prefix=f\"{prefix}.q_proj\",\n                weights=weights,\n                bias=config.attention_bias,\n            )\n        else:\n            self.q_a_proj = get_linear(\n                weight=weights.get_weights(f\"{prefix}.q_a_proj\"),\n                bias=(\n                    weights.get_tensor(f\"{prefix}.q_a_proj.bias\")\n                    if config.attention_bias\n                    else None\n                ),\n            )\n            self.q_a_layernorm = FastRMSNorm.load(\n                prefix=f\"{prefix}.q_a_layernorm\",\n         
       weights=weights,\n                eps=config.rms_norm_eps,\n            )\n            self.q_b_proj = TensorParallelColumnLinear.load(\n                config,\n                prefix=f\"{prefix}.q_b_proj\",\n                weights=weights,\n                bias=config.attention_bias,\n            )\n\n        self.kv_a_proj_with_mqa = get_linear(\n            weight=weights.get_weights(f\"{prefix}.kv_a_proj_with_mqa\"),\n            bias=(\n                weights.get_tensor(f\"{prefix}.kv_a_proj_with_mqa.bias\")\n                if config.attention_bias\n                else None\n            ),\n        )\n\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n\n        self.kv_a_layernorm = FastRMSNorm.load(\n            prefix=f\"{prefix}.kv_a_layernorm\", weights=weights, eps=config.rms_norm_eps\n        )\n\n        self.kv_b_proj = TensorParallelColumnLinear.load(\n            config,\n            prefix=f\"{prefix}.kv_b_proj\",\n            weights=weights,\n            bias=config.attention_bias,\n        )\n\n        self.o_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.o_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.num_groups = self.num_heads // self.num_key_value_heads\n        self.kv_head_mapping = torch.arange(\n            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device\n        ).repeat_interleave(self.num_groups)\n\n        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj.linear).T\n        kv_b_proj_weight = kv_b_proj_weight.view(\n            self.kv_lora_rank,\n            self.num_heads,\n            self.qk_nope_head_dim + self.value_head_size,\n        )\n\n        W_UK, W_UV = kv_b_proj_weight.split(\n            [self.qk_nope_head_dim, self.value_head_size], dim=-1\n        )\n        # Convert from (L, N, V) to (N, L, V)\n        self.W_UV = W_UV.transpose(0, 1)\n        # Convert from (L, N, P) 
to (N, P, L)\n        self.W_UK_T = W_UK.permute(1, 2, 0)\n\n    def _q_proj_and_k_up_proj(self, x):\n        q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj\n        q_nope, q_pe = (\n            q_proj(x)\n            .view(-1, self.num_heads, self.head_size)\n            .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)\n        )\n\n        # Convert from (B, N, P) to (N, B, P)\n        q_nope = q_nope.transpose(0, 1)\n        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)\n        ql_nope = torch.bmm(q_nope, self.W_UK_T)\n        # Convert from (N, B, L) to (B, N, L)\n        return ql_nope.transpose(0, 1), q_pe\n\n    def _v_up_proj_and_o_proj(self, x):\n        # Convert from (B, N, L) to (N, B, L)\n        x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)\n        # Multiply (N, B, L) x (N, L, V) -> (N, B, V)\n        x = torch.bmm(x, self.W_UV)\n        # Convert from (N, B, V) to (B, N * V)\n        x = x.transpose(0, 1).reshape(-1, self.num_heads * self.value_head_size)\n        return self.o_proj(x)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        cos: torch.Tensor,\n        sin: torch.Tensor,\n        cu_seqlen_prefill: torch.Tensor,\n        kv_cache: KVCache,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n    ):\n        if self.q_lora_rank is None:\n            hidden_states_or_q_c = hidden_states\n        else:\n            hidden_states_or_q_c = self.q_a_layernorm(self.q_a_proj(hidden_states))[0]\n\n        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)\n        compressed_kv, key_pe = torch.split(\n            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n        )\n\n        key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)\n        kv_c_normed = self.kv_a_layernorm(compressed_kv.contiguous())[0]\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n   
         q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj\n            query = q_proj(hidden_states_or_q_c)\n            query = query.view(-1, self.num_heads, self.head_size)\n            query_nope, query_pe = torch.split(\n                query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1\n            )\n        else:\n            query_nope, query_pe = self._q_proj_and_k_up_proj(hidden_states_or_q_c)\n\n        batch_size, heads, head_dim = query_pe.shape\n        query_pe = (\n            query_pe.view(batch_size, heads, head_dim // 2, 2)\n            .transpose(2, 3)\n            .reshape(batch_size, heads, head_dim)\n        )\n        batch_size, heads, head_dim = key_pe.shape\n        key_pe = (\n            key_pe.view(batch_size, heads, head_dim // 2, 2)\n            .transpose(2, 3)\n            .reshape(batch_size, heads, head_dim)\n        )\n        self.rotary_emb(query_pe, key_pe, cos, sin)\n        latent_vec_k = torch.concat(\n            (kv_c_normed, key_pe.view(-1, self.qk_rope_head_dim)), dim=-1\n        )\n        latent_vec_k = latent_vec_k.view(-1, self.qk_rope_head_dim + self.kv_lora_rank)\n\n        latent_vec_k = latent_vec_k.unflatten(0, (slots.size(0), -1))\n\n        kv_cache.store(\n            key=latent_vec_k,\n            value=None,\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        if cu_seqlen_prefill is not None:\n            kv = self.kv_b_proj(kv_c_normed).view(\n                -1,\n                self.num_key_value_heads,\n                self.qk_nope_head_dim + self.value_head_size,\n            )\n\n            key_nope, value = torch.split(\n                kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1\n            )\n            query[..., self.qk_nope_head_dim :] = query_pe\n            key = torch.empty_like(query)\n            key[..., : self.qk_nope_head_dim] = key_nope\n            key[..., self.qk_nope_head_dim :] = key_pe\n\n         
   # We need to pad the heads because Flash Attention does not support\n            # qk and v with different head sizes.\n            query = torch.nn.functional.pad(\n                query, (0, self.head_pad_size - self.head_size), value=0\n            )\n            key = torch.nn.functional.pad(\n                key, (0, self.head_pad_size - self.head_size), value=0\n            )\n            value = torch.nn.functional.pad(\n                value, (0, self.head_pad_size - self.value_head_size), value=0\n            )\n\n            # flash attention\n            attn_output = attention(\n                query=query,\n                key=key,\n                value=value,\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                softmax_scale=self.softmax_scale,\n            )\n            attn_output = attn_output[..., : self.value_head_size]\n\n            return self.o_proj(\n                attn_output.reshape(-1, self.num_heads * self.value_head_size)\n            )\n        else:\n            # Decode\n            query = torch.cat([query_nope, query_pe], dim=-1)\n            attn_output = paged_attention_mla(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                seqlen,\n                kv_scales=self.kv_scales,\n                hpu_attention_meta=hpu_attention_meta,\n                kv_lora_rank=self.kv_lora_rank,\n            )\n            attn_output = self._v_up_proj_and_o_proj(attn_output)\n            return attn_output\n\n\nclass DeepseekV2MLP(nn.Module):\n    def __init__(self, prefix: str, config, weights, intermediate_size: int):\n        super().__init__()\n        self.hidden_act = config.hidden_act\n        if self.hidden_act != \"silu\":\n            # Bail out because MoE only supports silu.\n            raise NotImplementedError(\n                \"Currently only `silu` is 
supported as an activation for Deepseek V2.\"\n            )\n        self.act = ACT2FN[self.hidden_act]\n\n        self.gate_up_proj = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.gate_proj\", f\"{prefix}.up_proj\"],\n            weights=weights,\n            dim=0,\n            bias=False,\n        )\n\n        self.down_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.down_proj\",\n            weights=weights,\n            bias=False,\n        )\n\n        self.intermediate_size = intermediate_size // weights.process_group.size()\n\n        # TODO: This is a hotfix to be removed & properly refactored.\n        self.quantize = config.quantize\n\n    def forward(self, hidden_states: torch.Tensor, reduce: bool = True):\n        gate_up_states = self.gate_up_proj(hidden_states)\n        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)\n        return self.down_proj(\n            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce\n        )\n\n\nclass DeepseekV2MoE(nn.Module):\n    def __init__(\n        self,\n        prefix,\n        config: DeepseekV2Config,\n        moe_layer_cls: Type[MoELayer],\n        weights,\n    ):\n        super().__init__()\n\n        self.hidden_dim = config.hidden_size\n        self.moe_intermediate_size = (\n            config.moe_intermediate_size // weights.process_group.size()\n        )\n        self.routed_scaling_factor = config.routed_scaling_factor\n\n        # Gating\n        self.gate = FastLinear.load(config, f\"{prefix}.gate\", weights, bias=False)\n\n        self.moe_layer = moe_layer_cls(\n            prefix=f\"{prefix}.experts\",\n            n_experts=config.n_routed_experts,\n            n_expert_group=config.n_group,\n            renormalize=config.norm_topk_prob,\n            topk=config.num_experts_per_tok,\n            topk_group=config.topk_group,\n            weights=weights,\n        
)\n        assert isinstance(self.moe_layer, MoELayer)\n\n        if config.n_shared_experts is not None:\n            self.shared_experts = DeepseekV2MLP(\n                prefix=f\"{prefix}.shared_experts\",\n                config=config,\n                weights=weights,\n                intermediate_size=config.moe_intermediate_size\n                * config.n_shared_experts,\n            )\n        else:\n            self.shared_experts = None\n\n        self.process_group = weights.process_group\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        if self.shared_experts is not None:\n            shared_output = self.shared_experts(x, reduce=False)\n        else:\n            shared_output = None\n\n        router_logits = self.gate(x)\n\n        out = self.moe_layer(x, gating_output=router_logits)\n\n        if shared_output is not None:\n            out = out + shared_output\n\n        # Reduce sum\n        if self.process_group.size() > 1:\n            torch.distributed.all_reduce(out, group=self.process_group)\n\n        return out.view(*x.shape)\n\n\nclass DeepseekV2Layer(nn.Module):\n    def __init__(self, prefix, layer_id, config, weights, rotary_emb):\n        super().__init__()\n        prefix = f\"{prefix}.layers.{layer_id}\"\n\n        self.self_attn = DeepseekV2Attention(\n            prefix=f\"{prefix}.self_attn\",\n            config=config,\n            weights=weights,\n            rotary_emb=rotary_emb,\n        )\n\n        if (\n            config.n_routed_experts is not None\n            and layer_id >= config.first_k_dense_replace\n            and layer_id % config.moe_layer_freq == 0\n        ):\n            moe_layer_cls = (\n                SparseMoELayer\n                if SparseMoELayer.is_supported(weights)\n                else DenseMoELayer\n            )\n            self.mlp = DeepseekV2MoE(f\"{prefix}.mlp\", config, moe_layer_cls, weights)\n        else:\n            self.mlp = DeepseekV2MLP(\n                
prefix=f\"{prefix}.mlp\",\n                config=config,\n                weights=weights,\n                intermediate_size=config.intermediate_size,\n            )\n\n        self.input_layernorm = FastRMSNorm.load(\n            prefix=f\"{prefix}.input_layernorm\", weights=weights, eps=config.rms_norm_eps\n        )\n        self.post_attention_layernorm = FastRMSNorm.load(\n            prefix=f\"{prefix}.post_attention_layernorm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        residual: torch.Tensor,\n        cos: torch.Tensor,\n        sin: torch.Tensor,\n        cu_seqlen_prefill: torch.Tensor,\n        kv_cache,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n    ):\n        normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        attn_output = self.self_attn(\n            normed_hidden_states,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            hpu_attention_meta,\n        )\n\n        # faster post attention rms norm\n        normed_attn_res_output, residual = self.post_attention_layernorm(\n            attn_output, residual\n        )\n\n        output = self.mlp(normed_attn_res_output)\n\n        return output, residual\n\n\nclass DeepseekV2Model(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights: Weights):\n        super().__init__()\n\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=f\"{prefix}.embed_tokens\", weights=weights\n        )\n\n        rotary_emb = PositionRotaryEmbedding.static(\n            config=config,\n            dim=config.qk_rope_head_dim,\n            base=config.rope_theta,\n            device=weights.device,\n        )\n        self.layers = 
nn.ModuleList(\n            [\n                DeepseekV2Layer(\n                    prefix,\n                    layer_id,\n                    config,\n                    weights,\n                    rotary_emb,\n                )\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n        self.norm = FastRMSNorm.load(\n            prefix=f\"{prefix}.norm\", weights=weights, eps=config.rms_norm_eps\n        )\n\n        self.head_size = self.layers[0].self_attn.head_size\n        self.num_heads = self.layers[0].self_attn.num_heads\n        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n    ) -> torch.Tensor:\n        if hpu_attention_meta is not None:\n            hpu_attention_meta = set_block_mapping(\n                hpu_attention_meta, input_ids.shape[0]\n            )\n        hidden_states = self.embed_tokens(input_ids)\n\n        # Get rotary cos and sin for this forward\n        # Avoid to index in each layer\n        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)\n\n        residual = None\n        lazy_mode = htorch.utils.internal.is_lazy()\n        if lazy_mode:\n            htorch.core.mark_step()\n        for i, layer in enumerate(self.layers):\n            hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                slots,\n                seqlen,\n                hpu_attention_meta,\n            )\n            if lazy_mode:\n                htorch.core.mark_step()\n\n  
      hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n\nclass FlashDeepseekV2ForCausalLM(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights: Weights):\n        super().__init__()\n\n        self.model = DeepseekV2Model(\n            \"model\" if not prefix else f\"{prefix}.model\", config, weights\n        )\n        self.lm_head = SpeculativeHead.load(\n            config,\n            prefix=\"lm_head\" if not prefix else f\"{prefix}.lm_head\",\n            weights=weights,\n        )\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        hidden_states = self.model(\n            input_ids,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            hpu_attention_meta,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.lm_head(hidden_states)\n        return logits, speculative_logits\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import List, Optional, Tuple, Type\n\nimport torch\nimport torch.distributed\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom text_generation_server.layers import (\n    FastLinear,\n    SpeculativeHead,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    TensorParallelRowLinear,\n    get_linear,\n    Fp8Linear,\n)\nfrom text_generation_server.layers.attention import (\n    Seqlen,\n    attention,\n    paged_attention_mla,\n    set_block_mapping,\n    HPUPagedAttentionMetadata,\n)\nfrom text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales\nfrom text_generation_server.layers.layernorm import FastRMSNorm\nfrom text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer\nfrom text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale\nfrom text_generation_server.utils.weights import Weights\nimport habana_frameworks.torch as htorch\n\n\ndef get_and_maybe_dequant_weights(layer: torch.nn.Module) -> torch.Tensor:\n    if isinstance(layer, Fp8Linear):\n        eye = torch.eye(\n            layer.qweight.shape[-1], dtype=torch.bfloat16, device=layer.qweight.device\n        )\n        dequant_weights = layer(eye)\n        
del eye\n        # standardize to (output, input)\n        return dequant_weights.T\n    return layer.weight\n\n\nclass DeepseekV3Config(PretrainedConfig):\n    def __init__(\n        self,\n        vocab_size=102400,\n        hidden_size=4096,\n        intermediate_size=11008,\n        moe_intermediate_size=1407,\n        num_hidden_layers=30,\n        num_attention_heads=32,\n        num_key_value_heads=32,\n        n_shared_experts=2,\n        n_routed_experts=160,\n        ep_size=1,\n        routed_scaling_factor=1.0,\n        kv_lora_rank=512,\n        q_lora_rank=1536,\n        qk_rope_head_dim=64,\n        v_head_dim=128,\n        qk_nope_head_dim=128,\n        topk_method=\"gready\",\n        n_group=8,\n        topk_group=3,\n        num_experts_per_tok=6,\n        moe_layer_freq=1,\n        first_k_dense_replace=0,\n        norm_topk_prob=False,\n        scoring_func=\"softmax\",\n        aux_loss_alpha=0.001,\n        seq_aux=True,\n        hidden_act=\"silu\",\n        max_position_embeddings=2048,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        pad_token_id=None,\n        bos_token_id=100000,\n        eos_token_id=100001,\n        pretraining_tp=1,\n        tie_word_embeddings=False,\n        rope_theta=10000.0,\n        rope_scaling=None,\n        attention_bias=False,\n        attention_dropout=0.0,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.moe_intermediate_size = moe_intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.n_shared_experts = n_shared_experts\n        self.n_routed_experts = n_routed_experts\n        self.ep_size = ep_size\n        self.routed_scaling_factor = routed_scaling_factor\n        self.kv_lora_rank = 
kv_lora_rank\n        self.q_lora_rank = q_lora_rank\n        self.qk_rope_head_dim = qk_rope_head_dim\n        self.v_head_dim = v_head_dim\n        self.qk_nope_head_dim = qk_nope_head_dim\n        self.topk_method = topk_method\n        self.n_group = n_group\n        self.topk_group = topk_group\n        self.num_experts_per_tok = num_experts_per_tok\n        self.moe_layer_freq = moe_layer_freq\n        self.first_k_dense_replace = first_k_dense_replace\n        self.norm_topk_prob = norm_topk_prob\n        self.scoring_func = scoring_func\n        self.aux_loss_alpha = aux_loss_alpha\n        self.seq_aux = seq_aux\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.pretraining_tp = pretraining_tp\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n\n        tie_word_embeddings = kwargs.pop(\"tie_word_embeddings\", False)\n        if tie_word_embeddings:\n            raise ValueError(\n                \"tie_word_embeddings is not supported for Deepseek V2 models.\"\n            )\n\n        if ep_size != 1:\n            raise ValueError(\n                f\"Currently only ep_size == 1 is supported for Deepseek V2 models, was {ep_size}\"\n            )\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n\nclass DeepseekV3Attention(torch.nn.Module):\n    def __init__(\n        self,\n        prefix: str,\n        config,\n        
weights: Weights,\n        rotary_emb,\n    ):\n        super().__init__()\n        self.num_heads = config.num_attention_heads\n        self.hidden_size = config.hidden_size\n        self.kv_lora_rank = config.kv_lora_rank\n        self.q_lora_rank = config.q_lora_rank\n        self.qk_nope_head_dim = config.qk_nope_head_dim\n        self.qk_rope_head_dim = config.qk_rope_head_dim\n        self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim\n        self.value_head_size = config.v_head_dim\n        self.head_pad_size = max(self.head_size, self.value_head_size)\n        self.rotary_emb = rotary_emb\n\n        mscale = get_mscale(\n            self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim\n        )\n        self.softmax_scale = self.head_size**-0.5 * mscale * mscale\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.num_key_value_heads = (\n            config.num_key_value_heads // weights.process_group.size()\n        )\n\n        if self.q_lora_rank is None:\n            self.q_proj = TensorParallelColumnLinear.load(\n                config,\n                prefix=f\"{prefix}.q_proj\",\n                weights=weights,\n                bias=config.attention_bias,\n            )\n        else:\n            self.q_a_proj = get_linear(\n                weight=weights.get_weights(f\"{prefix}.q_a_proj\"),\n                bias=(\n                    weights.get_tensor(f\"{prefix}.q_a_proj.bias\")\n                    if config.attention_bias\n                    else None\n                ),\n            )\n            self.q_a_layernorm = FastRMSNorm.load(\n                prefix=f\"{prefix}.q_a_layernorm\",\n         
       weights=weights,\n                eps=config.rms_norm_eps,\n            )\n            self.q_b_proj = TensorParallelColumnLinear.load(\n                config,\n                prefix=f\"{prefix}.q_b_proj\",\n                weights=weights,\n                bias=config.attention_bias,\n            )\n\n        self.kv_a_proj_with_mqa = get_linear(\n            weight=weights.get_weights(f\"{prefix}.kv_a_proj_with_mqa\"),\n            bias=(\n                weights.get_tensor(f\"{prefix}.kv_a_proj_with_mqa.bias\")\n                if config.attention_bias\n                else None\n            ),\n        )\n\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n\n        self.kv_a_layernorm = FastRMSNorm.load(\n            prefix=f\"{prefix}.kv_a_layernorm\", weights=weights, eps=config.rms_norm_eps\n        )\n\n        self.kv_b_proj = TensorParallelColumnLinear.load(\n            config,\n            prefix=f\"{prefix}.kv_b_proj\",\n            weights=weights,\n            bias=config.attention_bias,\n        )\n\n        self.o_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.o_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.num_groups = self.num_heads // self.num_key_value_heads\n        self.kv_head_mapping = torch.arange(\n            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device\n        ).repeat_interleave(self.num_groups)\n\n        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj.linear).T\n        kv_b_proj_weight = kv_b_proj_weight.view(\n            self.kv_lora_rank,\n            self.num_heads,\n            self.qk_nope_head_dim + self.value_head_size,\n        )\n        W_UK, W_UV = kv_b_proj_weight.split(\n            [self.qk_nope_head_dim, self.value_head_size], dim=-1\n        )\n        # Convert from (L, N, V) to (N, L, V)\n        self.W_UV = W_UV.transpose(0, 1)\n        # Convert from (L, N, P) to 
(N, P, L)\n        self.W_UK_T = W_UK.permute(1, 2, 0)\n\n    def _q_proj_and_k_up_proj(self, x):\n        q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj\n        q_nope, q_pe = (\n            q_proj(x)\n            .view(-1, self.num_heads, self.head_size)\n            .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)\n        )\n\n        # Convert from (B, N, P) to (N, B, P)\n        q_nope = q_nope.transpose(0, 1)\n        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)\n        ql_nope = torch.bmm(q_nope, self.W_UK_T)\n        # Convert from (N, B, L) to (B, N, L)\n        return ql_nope.transpose(0, 1), q_pe\n\n    def _v_up_proj_and_o_proj(self, x):\n        # Convert from (B, N, L) to (N, B, L)\n        x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)\n        # Multiply (N, B, L) x (N, L, V) -> (N, B, V)\n        x = torch.bmm(x, self.W_UV)\n        # Convert from (N, B, V) to (B, N * V)\n        x = x.transpose(0, 1).reshape(-1, self.num_heads * self.value_head_size)\n        return self.o_proj(x)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        cos: torch.Tensor,\n        sin: torch.Tensor,\n        cu_seqlen_prefill: torch.Tensor,\n        kv_cache: KVCache,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n    ):\n        if self.q_lora_rank is None:\n            hidden_states_or_q_c = hidden_states\n        else:\n            hidden_states_or_q_c = self.q_a_layernorm(self.q_a_proj(hidden_states))[0]\n\n        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)\n        compressed_kv, key_pe = torch.split(\n            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n        )\n\n        key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)\n        kv_c_normed = self.kv_a_layernorm(compressed_kv.contiguous())[0]\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n      
      q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj\n            query = q_proj(hidden_states_or_q_c)\n            query = query.view(-1, self.num_heads, self.head_size)\n            query_nope, query_pe = torch.split(\n                query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1\n            )\n        else:\n            query_nope, query_pe = self._q_proj_and_k_up_proj(hidden_states_or_q_c)\n\n        batch_size, heads, head_dim = query_pe.shape\n        query_pe = (\n            query_pe.view(batch_size, heads, head_dim // 2, 2)\n            .transpose(2, 3)\n            .reshape(batch_size, heads, head_dim)\n        )\n        batch_size, heads, head_dim = key_pe.shape\n        key_pe = (\n            key_pe.view(batch_size, heads, head_dim // 2, 2)\n            .transpose(2, 3)\n            .reshape(batch_size, heads, head_dim)\n        )\n        self.rotary_emb(query_pe, key_pe, cos, sin)\n        latent_vec_k = torch.concat(\n            (kv_c_normed, key_pe.view(-1, self.qk_rope_head_dim)), dim=-1\n        )\n        latent_vec_k = latent_vec_k.view(-1, self.qk_rope_head_dim + self.kv_lora_rank)\n\n        latent_vec_k = latent_vec_k.unflatten(0, (slots.size(0), -1))\n\n        kv_cache.store(\n            key=latent_vec_k,\n            value=None,\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        if cu_seqlen_prefill is not None:\n            kv = self.kv_b_proj(kv_c_normed).view(\n                -1,\n                self.num_key_value_heads,\n                self.qk_nope_head_dim + self.value_head_size,\n            )\n\n            key_nope, value = torch.split(\n                kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1\n            )\n            query[..., self.qk_nope_head_dim :] = query_pe\n            key = torch.empty_like(query)\n            key[..., : self.qk_nope_head_dim] = key_nope\n            key[..., self.qk_nope_head_dim :] = key_pe\n\n            
# We need to pad the heads because Flash Attention does not support\n            # qk and v with different head sizes.\n            query = torch.nn.functional.pad(\n                query, (0, self.head_pad_size - self.head_size), value=0\n            )\n            key = torch.nn.functional.pad(\n                key, (0, self.head_pad_size - self.head_size), value=0\n            )\n            value = torch.nn.functional.pad(\n                value, (0, self.head_pad_size - self.value_head_size), value=0\n            )\n\n            # flash attention\n            attn_output = attention(\n                query=query,\n                key=key,\n                value=value,\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                softmax_scale=self.softmax_scale,\n            )\n            attn_output = attn_output[..., : self.value_head_size]\n\n            return self.o_proj(\n                attn_output.reshape(-1, self.num_heads * self.value_head_size)\n            )\n        else:\n            # Decode\n            query = torch.cat([query_nope, query_pe], dim=-1)\n            attn_output = paged_attention_mla(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                seqlen,\n                kv_scales=self.kv_scales,\n                hpu_attention_meta=hpu_attention_meta,\n                kv_lora_rank=self.kv_lora_rank,\n            )\n            attn_output = self._v_up_proj_and_o_proj(attn_output)\n            return attn_output\n\n\nclass DeepseekV3MLP(nn.Module):\n    def __init__(self, prefix: str, config, weights, intermediate_size: int):\n        super().__init__()\n        self.hidden_act = config.hidden_act\n        if self.hidden_act != \"silu\":\n            # Bail out because MoE only supports silu.\n            raise NotImplementedError(\n                \"Currently only `silu` is supported 
as an activation for Deepseek V2.\"\n            )\n        self.act = ACT2FN[self.hidden_act]\n\n        self.gate_up_proj = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.gate_proj\", f\"{prefix}.up_proj\"],\n            weights=weights,\n            dim=0,\n            bias=False,\n        )\n\n        self.down_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.down_proj\",\n            weights=weights,\n            bias=False,\n        )\n\n        self.intermediate_size = intermediate_size // weights.process_group.size()\n\n        # TODO: This is a hotfix to be removed & properly refactored.\n        self.quantize = config.quantize\n\n    def forward(self, hidden_states: torch.Tensor, reduce: bool = True):\n        gate_up_states = self.gate_up_proj(hidden_states)\n        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)\n        return self.down_proj(\n            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce\n        )\n\n\nclass DeepseekV3MoE(nn.Module):\n    def __init__(\n        self,\n        prefix,\n        config: DeepseekV3Config,\n        moe_layer_cls: Type[MoELayer],\n        weights,\n    ):\n        super().__init__()\n\n        self.hidden_dim = config.hidden_size\n        self.moe_intermediate_size = (\n            config.moe_intermediate_size // weights.process_group.size()\n        )\n        self.routed_scaling_factor = config.routed_scaling_factor\n\n        # Gating\n        self.gate = FastLinear.load(config, f\"{prefix}.gate\", weights, bias=False)\n\n        if config.topk_method == \"noaux_tc\":\n            self.gate.e_score_correction_bias = torch.zeros(\n                config.n_routed_experts, device=weights.device\n            )\n        else:\n            self.gate.e_score_correction_bias = None\n\n        self.moe_layer = moe_layer_cls(\n            prefix=f\"{prefix}.experts\",\n            
n_experts=config.n_routed_experts,\n            n_expert_group=config.n_group,\n            renormalize=config.norm_topk_prob,\n            topk=config.num_experts_per_tok,\n            topk_group=config.topk_group,\n            weights=weights,\n            scoring_func=config.scoring_func,\n            e_score_correction_bias=self.gate.e_score_correction_bias,\n        )\n        assert isinstance(self.moe_layer, MoELayer)\n\n        if config.n_shared_experts is not None:\n            self.shared_experts = DeepseekV3MLP(\n                prefix=f\"{prefix}.shared_experts\",\n                config=config,\n                weights=weights,\n                intermediate_size=config.moe_intermediate_size\n                * config.n_shared_experts,\n            )\n        else:\n            self.shared_experts = None\n\n        self.process_group = weights.process_group\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        if self.shared_experts is not None:\n            shared_output = self.shared_experts(x, reduce=False)\n        else:\n            shared_output = None\n\n        router_logits = self.gate(x)\n\n        out = self.moe_layer(x, gating_output=router_logits)\n\n        if shared_output is not None:\n            out = out + shared_output\n\n        # Reduce sum\n        if self.process_group.size() > 1:\n            torch.distributed.all_reduce(out, group=self.process_group)\n\n        return out.view(*x.shape)\n\n\nclass DeepseekV3Layer(nn.Module):\n    def __init__(self, prefix, layer_id, config, weights, rotary_emb):\n        super().__init__()\n        prefix = f\"{prefix}.layers.{layer_id}\"\n\n        self.self_attn = DeepseekV3Attention(\n            prefix=f\"{prefix}.self_attn\",\n            config=config,\n            weights=weights,\n            rotary_emb=rotary_emb,\n        )\n\n        if (\n            config.n_routed_experts is not None\n            and layer_id >= config.first_k_dense_replace\n            and layer_id 
% config.moe_layer_freq == 0\n        ):\n            moe_layer_cls = (\n                SparseMoELayer\n                if SparseMoELayer.is_supported(weights)\n                else DenseMoELayer\n            )\n            self.mlp = DeepseekV3MoE(f\"{prefix}.mlp\", config, moe_layer_cls, weights)\n        else:\n            self.mlp = DeepseekV3MLP(\n                prefix=f\"{prefix}.mlp\",\n                config=config,\n                weights=weights,\n                intermediate_size=config.intermediate_size,\n            )\n\n        self.input_layernorm = FastRMSNorm.load(\n            prefix=f\"{prefix}.input_layernorm\", weights=weights, eps=config.rms_norm_eps\n        )\n        self.post_attention_layernorm = FastRMSNorm.load(\n            prefix=f\"{prefix}.post_attention_layernorm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        residual: torch.Tensor,\n        cos: torch.Tensor,\n        sin: torch.Tensor,\n        cu_seqlen_prefill: torch.Tensor,\n        kv_cache,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n    ):\n        normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        attn_output = self.self_attn(\n            normed_hidden_states,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            hpu_attention_meta,\n        )\n\n        # faster post attention rms norm\n        normed_attn_res_output, residual = self.post_attention_layernorm(\n            attn_output, residual\n        )\n\n        output = self.mlp(normed_attn_res_output)\n\n        return output, residual\n\n\nclass DeepseekV3Model(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights: Weights):\n        
super().__init__()\n\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=f\"{prefix}.embed_tokens\", weights=weights\n        )\n        rotary_emb = PositionRotaryEmbedding.static(\n            config=config,\n            dim=config.qk_rope_head_dim,\n            base=config.rope_theta,\n            device=weights.device,\n        )\n\n        self.layers = nn.ModuleList(\n            [\n                DeepseekV3Layer(\n                    prefix,\n                    layer_id,\n                    config,\n                    weights,\n                    rotary_emb,\n                )\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n        self.norm = FastRMSNorm.load(\n            prefix=f\"{prefix}.norm\", weights=weights, eps=config.rms_norm_eps\n        )\n\n        self.head_size = self.layers[0].self_attn.head_size\n        self.num_heads = self.layers[0].self_attn.num_heads\n        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n    ) -> torch.Tensor:\n        if hpu_attention_meta is not None:\n            hpu_attention_meta = set_block_mapping(\n                hpu_attention_meta, input_ids.shape[0]\n            )\n        hidden_states = self.embed_tokens(input_ids)\n\n        # Get rotary cos and sin for this forward\n        # Avoid to index in each layer\n        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)\n\n        residual = None\n        lazy_mode = htorch.utils.internal.is_lazy()\n        if lazy_mode:\n            htorch.core.mark_step()\n        for i, layer in enumerate(self.layers):\n     
       hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                slots,\n                seqlen,\n                hpu_attention_meta,\n            )\n            if lazy_mode:\n                htorch.core.mark_step()\n\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n\nclass FlashDeepseekV3ForCausalLM(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights: Weights):\n        super().__init__()\n\n        self.model = DeepseekV3Model(\n            \"model\" if not prefix else f\"{prefix}.model\", config, weights\n        )\n        self.lm_head = SpeculativeHead.load(\n            config,\n            prefix=\"lm_head\" if not prefix else f\"{prefix}.lm_head\",\n            weights=weights,\n        )\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        hidden_states = self.model(\n            input_ids,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            hpu_attention_meta,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.lm_head(hidden_states)\n        return logits, speculative_logits\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.distributed\n\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom transformers.configuration_utils import PretrainedConfig\nfrom typing import Optional, List, Tuple\nfrom text_generation_server.layers.attention import (\n    paged_attention,\n    attention,\n    set_block_mapping,\n    Seqlen,\n    HPUPagedAttentionMetadata,\n)\nfrom text_generation_server.layers import (\n    TensorParallelRowLinear,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    SpeculativeHead,\n    get_linear,\n    TensorParallelMultiAdapterLinear,\n    TensorParallelAdapterRowLinear,\n)\nfrom text_generation_server.layers.attention.kv_cache import get_kv_scales\nfrom text_generation_server.layers.rotary import PositionRotaryEmbedding\nfrom text_generation_server.layers.layernorm import (\n    FastRMSNorm,\n)\nfrom text_generation_server.utils.weights import UnquantizedWeight\nimport habana_frameworks.torch as htorch\n\n\nclass Gemma2Config(PretrainedConfig):\n    def 
__init__(\n        self,\n        vocab_size=256128,\n        hidden_size=3072,\n        intermediate_size=24576,\n        num_hidden_layers=28,\n        num_attention_heads=16,\n        num_key_value_heads=16,\n        head_dim=256,\n        hidden_act=\"gelu_pytorch_tanh\",\n        max_position_embeddings=8192,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        pad_token_id=None,\n        bos_token_id=1,\n        eos_token_id=2,\n        tie_word_embeddings=True,\n        rope_theta=10000.0,\n        rope_scaling=None,\n        attention_bias=False,\n        attention_dropout=0.0,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.head_dim = head_dim\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n\nclass Gemma2FastRMSNorm(FastRMSNorm):\n    @classmethod\n    def load(cls, prefix: str, weights, eps=1e-6):\n        dtype = weights.dtype\n        weights.dtype = torch.float32\n        weight = 
weights.get_tensor(f\"{prefix}.weight\") + 1\n        weights.dtype = dtype\n        new = cls(weight, eps)\n        new.dtype = dtype\n        return new\n\n    # perform the multiplication in full precision and downcast after\n    def forward(self, hidden_states, residual=None):\n        if residual is not None:\n            hidden_states += residual\n        residual = hidden_states\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        hidden_states = hidden_states * self.weight\n        return hidden_states.to(self.dtype), residual\n\n\ndef load_attention(config, prefix: str, weights):\n    if config.num_attention_heads != config.num_key_value_heads:\n        return _load_gqa(config, prefix, weights)\n    else:\n        return TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n            dim=0,\n            weights=weights,\n            bias=False,\n        )\n\n\ndef _load_gqa(config, prefix: str, weights):\n    assert config.num_attention_heads % weights.process_group.size() == 0\n\n    weight = weights.get_multi_weights_col(\n        prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n        dim=0,\n    )\n\n    if isinstance(weight, UnquantizedWeight):\n        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)\n\n        head_size = config.head_dim\n        num_heads = config.num_attention_heads // weights.process_group.size()\n        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()\n        assert list(weight.weight.shape) == [\n            (num_heads + 2 * num_key_value_heads) * head_size,\n            config.hidden_size,\n        ], f\"{list(weight.weight.shape)} != {[(num_heads + 2 * 
config.num_key_value_heads) * head_size, config.hidden_size]}\"\n\n    return TensorParallelColumnLinear(get_linear(weight, bias=None))\n\n\nclass FlashGemma2Attention(torch.nn.Module):\n    def __init__(\n        self,\n        prefix: str,\n        config,\n        weights,\n        layer_id,\n        causal: bool,\n        is_sliding: bool,\n        rotary_emb,\n    ):\n        super().__init__()\n        self.num_heads = config.num_attention_heads\n        self.head_size = config.head_dim\n        self.causal = causal\n        if is_sliding:\n            self.window_size = config.sliding_window\n        else:\n            self.window_size = -1\n        self.rotary_emb = rotary_emb\n\n        # self.softmax_scale = self.head_size**-0.5\n        self.softmax_scale = config.query_pre_attn_scalar**-0.5\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.num_key_value_heads = (\n            config.num_key_value_heads // weights.process_group.size()\n        )\n        self.softcap = config.attn_logit_softcapping\n\n        query_key_value = load_attention(config, prefix, weights)\n        self.query_key_value = TensorParallelMultiAdapterLinear.load(\n            query_key_value,\n            layer_id,\n            [\"q_proj\", \"k_proj\", \"v_proj\"],\n            sizes=[\n                self.head_size * config.num_attention_heads,\n                self.head_size * config.num_key_value_heads,\n                self.head_size * config.num_key_value_heads,\n            ],\n            process_group=weights.process_group,\n        )\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n\n        o_proj = TensorParallelRowLinear.load(\n        
    config,\n            prefix=f\"{prefix}.o_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.o_proj = TensorParallelAdapterRowLinear.load(\n            o_proj,\n            layer_id,\n            \"o_proj\",\n            process_group=weights.process_group,\n        )\n\n        self.num_groups = self.num_heads // self.num_key_value_heads\n        self.kv_head_mapping = torch.arange(\n            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device\n        ).repeat_interleave(self.num_groups)\n\n    def forward(\n        self,\n        hidden_states,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        adapter_data,\n        hpu_attention_meta,\n    ):\n        qkv = self.query_key_value(hidden_states, adapter_data)\n        query, kv = qkv.split(\n            [\n                self.head_size * self.num_heads,\n                2 * self.head_size * self.num_key_value_heads,\n            ],\n            dim=1,\n        )\n        query = query.view(-1, self.num_heads, self.head_size)\n        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)\n\n        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)\n\n        kv_cache.store(\n            key=kv[:, 0],\n            value=kv[:, 1],\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # sdpa\n            attn_output = attention(\n                query=query,\n                key=kv[:, 0],\n                value=kv[:, 1],\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                softmax_scale=self.softmax_scale,\n                window_size_left=self.window_size,\n                softcap=self.softcap,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n         
       query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                seqlen,\n                softcap=self.softcap,\n                kv_scales=self.kv_scales,\n                hpu_attention_meta=hpu_attention_meta,\n                window_size_left=self.window_size,\n            )\n\n        return self.o_proj(\n            attn_output.view(-1, self.num_heads * self.head_size), adapter_data\n        )\n\n\nclass Gemma2MLP(nn.Module):\n    def __init__(self, prefix, config, weights, layer_id):\n        super().__init__()\n        act = config.hidden_activation\n        self.act = (\n            ACT2FN[act]\n            if \"gelu\" not in act\n            else lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\" if act in [\"gelu_fast\", \"gelu_pytorch_tanh\"] else \"none\"\n                ),\n            )\n        )\n        # Fuse gate and up proj\n        gate_up_proj = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.gate_proj\", f\"{prefix}.up_proj\"],\n            weights=weights,\n            dim=0,\n            bias=False,\n        )\n        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(\n            gate_up_proj,\n            layer_id,\n            [\"gate_proj\", \"up_proj\"],\n            sizes=[\n                config.intermediate_size,\n                config.intermediate_size,\n            ],\n            process_group=weights.process_group,\n        )\n\n        down_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.down_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.down_proj = TensorParallelAdapterRowLinear.load(\n            down_proj,\n            layer_id,\n            \"down_proj\",\n            process_group=weights.process_group,\n        )\n\n        self.intermediate_size = 
(\n            config.intermediate_size // weights.process_group.size()\n        )\n\n    def forward(self, hidden_states, adapter_data):\n        gate_up_states = self.gate_up_proj(hidden_states, adapter_data)\n        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)\n        return self.down_proj(\n            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data\n        )\n\n\nclass FlashGemma2Layer(nn.Module):\n    def __init__(\n        self,\n        prefix: str,\n        config,\n        weights,\n        layer_id,\n        causal: bool,\n        is_sliding: bool,\n        rotary_emb,\n    ):\n        super().__init__()\n        self.self_attn = FlashGemma2Attention(\n            prefix=f\"{prefix}.self_attn\",\n            config=config,\n            weights=weights,\n            layer_id=layer_id,\n            causal=causal,\n            is_sliding=is_sliding,\n            rotary_emb=rotary_emb,\n        )\n        self.mlp = Gemma2MLP(\n            prefix=f\"{prefix}.mlp\", config=config, weights=weights, layer_id=layer_id\n        )\n\n        self.input_layernorm = Gemma2FastRMSNorm.load(\n            prefix=f\"{prefix}.input_layernorm\", weights=weights, eps=config.rms_norm_eps\n        )\n        self.post_attention_layernorm = Gemma2FastRMSNorm.load(\n            prefix=f\"{prefix}.post_attention_layernorm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n        self.pre_feedforward_layernorm = Gemma2FastRMSNorm.load(\n            prefix=f\"{prefix}.pre_feedforward_layernorm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n        self.post_feedforward_layernorm = Gemma2FastRMSNorm.load(\n            prefix=f\"{prefix}.post_feedforward_layernorm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        
cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        adapter_data,\n        hpu_attention_meta,\n    ):\n        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        attn_output = self.self_attn(\n            normed_hidden_states,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            adapter_data,\n            hpu_attention_meta,\n        )\n\n        # faster post attention rms norm\n        normed_attn_res_output, _ = self.post_attention_layernorm(attn_output)\n        normed_attn_res_output = normed_attn_res_output + res\n        res = normed_attn_res_output\n\n        pre_normed, _ = self.pre_feedforward_layernorm(normed_attn_res_output)\n        mlp_output = self.mlp(pre_normed, adapter_data)\n        post_hidden_states, _ = self.post_feedforward_layernorm(mlp_output)\n\n        return post_hidden_states, normed_attn_res_output\n\n\nclass FlashGemma2Model(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights, causal: bool):\n        super().__init__()\n\n        process_group = weights.process_group\n        self.tp_rank = process_group.rank()\n        self.tp_world_size = process_group.size()\n        rotary_emb = PositionRotaryEmbedding.static(\n            config=config,\n            dim=config.head_dim,\n            base=config.rope_theta,\n            device=weights.device,\n        )\n\n        self.layers = nn.ModuleList(\n            [\n                FlashGemma2Layer(\n                    prefix=f\"{prefix}.layers.{layer_id}\",\n                    config=config,\n                    weights=weights,\n                    layer_id=layer_id,\n                    causal=causal,\n                    is_sliding=layer_id % 2 == 0,\n                    rotary_emb=rotary_emb,\n                )\n                for layer_id in range(config.num_hidden_layers)\n          
  ]\n        )\n        self.norm = Gemma2FastRMSNorm.load(\n            prefix=f\"{prefix}.norm\", weights=weights, eps=config.rms_norm_eps\n        )\n\n        self.head_size = self.layers[0].self_attn.head_size\n        self.num_heads = self.layers[0].self_attn.num_heads\n        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        adapter_data: Optional[torch.Tensor],\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n    ) -> torch.Tensor:\n        if hpu_attention_meta is not None:\n            hpu_attention_meta = set_block_mapping(\n                hpu_attention_meta, inputs_embeds.shape[0]\n            )\n        hidden_states = inputs_embeds\n\n        # Get rotary cos and sin for this forward\n        # Avoid to index in each layer\n        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)\n\n        residual = None\n        lazy_mode = htorch.utils.internal.is_lazy()\n        if lazy_mode:\n            htorch.core.mark_step()\n\n        for i, layer in enumerate(self.layers):\n            hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                slots,\n                seqlen,\n                adapter_data,\n                hpu_attention_meta,\n            )\n            if lazy_mode:\n                htorch.core.mark_step()\n\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n\nclass FlashGemma2ForCausalLM(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights, *, causal: bool = True):\n        
super().__init__()\n\n        embed_norm = config.hidden_size**0.5\n        if not prefix:\n            prefix = \"model\"\n        else:\n            prefix = f\"{prefix}.model\"\n\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=f\"{prefix}.embed_tokens\", weights=weights\n        )\n        self.embed_tokens.weight *= embed_norm\n\n        self.model = FlashGemma2Model(\n            prefix=prefix, config=config, weights=weights, causal=causal\n        )\n        self.lm_head = SpeculativeHead.load(\n            prefix=(\n                f\"{prefix}.embed_tokens\"\n                if config.tie_word_embeddings\n                else f\"{prefix}.lm_head\"\n            ),\n            config=config,\n            weights=weights,\n        )\n        self.softcap = config.final_logit_softcapping\n        assert isinstance(self.softcap, float)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        input_embeds = self.embed_tokens(input_ids)\n        hidden_states = self.model(\n            input_embeds,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            adapter_data,\n            hpu_attention_meta,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.lm_head(hidden_states)\n\n        logits /= self.softcap\n        logits = torch.tanh(logits)\n        logits *= self.softcap\n\n        return logits, 
speculative_logits\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2024 HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.distributed\nfrom torch import nn\nfrom typing import Optional, List, Tuple\nimport copy\n\nfrom text_generation_server.layers import (\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    TensorParallelRowLinear,\n    get_linear,\n    #\n    SpeculativeHead,\n    TensorParallelMultiAdapterLinear,\n    TensorParallelAdapterRowLinear,\n)\n\nimport torch\n\n\nfrom text_generation_server.models.custom_modeling.vlm import (\n    load_text_model,\n    load_vision_model,\n)\n\n\nfrom text_generation_server.layers.attention.kv_cache import get_kv_scales\nfrom text_generation_server.layers.rotary import PositionRotaryEmbedding\nfrom text_generation_server.layers.layernorm import (\n    FastRMSNorm,\n)\nfrom text_generation_server.utils.weights import UnquantizedWeight\nfrom transformers.activations import ACT2FN\nfrom text_generation_server.layers.attention import (\n    paged_attention,\n    attention,\n    Seqlen,\n    set_block_mapping,\n    HPUPagedAttentionMetadata,\n)\nimport habana_frameworks.torch as htorch\n\nATTENTION_TYPE_GLOBAL = \"global\"\nATTENTION_TYPE_LOCAL = \"local_sliding\"\n\n\nclass Gemma3FastRMSNorm(FastRMSNorm):\n    @classmethod\n    def load(cls, prefix: str, weights, eps=1e-6):\n        dtype = weights.dtype\n        weights.dtype = torch.float32\n        
weight = weights.get_tensor(f\"{prefix}.weight\") + 1\n        weights.dtype = dtype\n        new = cls(weight, eps)\n        new.dtype = dtype\n        return new\n\n    # perform the multiplication in full precision and downcast after\n    def forward(self, hidden_states, residual=None):\n        if residual is not None:\n            hidden_states += residual\n        residual = hidden_states\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        hidden_states = hidden_states * self.weight\n        return hidden_states.to(self.dtype), residual\n\n\ndef load_attention(config, prefix: str, weights):\n    if config.num_attention_heads != config.num_key_value_heads:\n        return _load_gqa(config, prefix, weights)\n    else:\n        return TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n            dim=0,\n            weights=weights,\n            bias=False,\n        )\n\n\ndef _load_gqa(config, prefix: str, weights):\n    assert config.num_attention_heads % weights.process_group.size() == 0\n\n    weight = weights.get_multi_weights_col(\n        prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n        dim=0,\n    )\n\n    if isinstance(weight, UnquantizedWeight):\n        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)\n\n        head_size = config.head_dim\n        num_heads = config.num_attention_heads // weights.process_group.size()\n        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()\n        assert list(weight.weight.shape) == [\n            (num_heads + 2 * num_key_value_heads) * head_size,\n            config.hidden_size,\n        ], f\"{list(weight.weight.shape)} != {[(num_heads + 2 * 
config.num_key_value_heads) * head_size, config.hidden_size]}\"\n\n    return TensorParallelColumnLinear(get_linear(weight, bias=None))\n\n\nclass FlashGemma3Attention(torch.nn.Module):\n    def __init__(\n        self,\n        prefix: str,\n        config,\n        weights,\n        layer_id,\n        causal: bool,\n        is_sliding: bool,\n        local_rotary_emb,\n        global_rotary_emb,\n    ):\n        super().__init__()\n        self.num_heads = config.num_attention_heads\n        self.head_size = config.head_dim\n        self.causal = causal\n        if is_sliding:\n            self.window_size = config.sliding_window\n            self.rotary_emb = local_rotary_emb\n        else:\n            self.window_size = -1\n            self.rotary_emb = global_rotary_emb\n\n        self.softmax_scale = (\n            config.query_pre_attn_scalar**-0.5\n            if config.query_pre_attn_scalar is not None\n            else None\n        )\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.num_key_value_heads = (\n            config.num_key_value_heads // weights.process_group.size()\n        )\n        self.softcap = None  # config.attn_logit_softcapping\n\n        query_key_value = load_attention(config, prefix, weights)\n        self.query_key_value = TensorParallelMultiAdapterLinear.load(\n            query_key_value,\n            layer_id,\n            [\"q_proj\", \"k_proj\", \"v_proj\"],\n            sizes=[\n                self.head_size * config.num_attention_heads,\n                self.head_size * config.num_key_value_heads,\n                self.head_size * config.num_key_value_heads,\n            ],\n            
process_group=weights.process_group,\n        )\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n\n        o_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.o_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.o_proj = TensorParallelAdapterRowLinear.load(\n            o_proj,\n            layer_id,\n            \"o_proj\",\n            process_group=weights.process_group,\n        )\n\n        self.num_groups = self.num_heads // self.num_key_value_heads\n        self.kv_head_mapping = torch.arange(\n            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device\n        ).repeat_interleave(self.num_groups)\n        self.q_norm = Gemma3FastRMSNorm.load(\n            prefix=f\"{prefix}.q_norm\", weights=weights, eps=config.rms_norm_eps\n        )\n        self.k_norm = Gemma3FastRMSNorm.load(\n            prefix=f\"{prefix}.k_norm\", weights=weights, eps=config.rms_norm_eps\n        )\n        self.enable_gqa = self.num_heads != self.num_key_value_heads\n\n    def forward(\n        self,\n        hidden_states,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        adapter_data,\n        hpu_attention_meta,\n    ):\n\n        qkv = self.query_key_value(hidden_states, adapter_data)\n        query, kv = qkv.split(\n            [\n                self.head_size * self.num_heads,\n                2 * self.head_size * self.num_key_value_heads,\n            ],\n            dim=1,\n        )\n\n        kv = kv.view(-1, 2, self.num_key_value_heads * self.head_size)\n        key = kv[:, 0]\n        value = kv[:, 1]\n\n        query = query.reshape(-1, self.head_size)\n        key = key.reshape(-1, self.head_size)\n\n        query, _ = self.q_norm(query.contiguous())\n        key, _ = self.k_norm(key.contiguous())\n\n        query = query.view(-1, self.num_heads, self.head_size)\n        key = 
key.view(-1, self.num_key_value_heads, self.head_size)\n        value = value.view(-1, self.num_key_value_heads, self.head_size)\n\n        self.rotary_emb(query, key, cos, sin)\n\n        kv_cache.store(\n            key=key,\n            value=value,\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # sdpa\n            attn_output = attention(\n                query=query,\n                key=key,\n                value=value,\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                softmax_scale=self.softmax_scale,\n                window_size_left=self.window_size,\n                softcap=self.softcap,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                seqlen,\n                softcap=self.softcap,\n                kv_scales=self.kv_scales,\n                hpu_attention_meta=hpu_attention_meta,\n                window_size_left=self.window_size,\n            )\n\n        return self.o_proj(\n            attn_output.view(-1, self.num_heads * self.head_size), adapter_data\n        )\n\n\nclass Gemma3MLP(nn.Module):\n    def __init__(self, prefix, config, weights, layer_id):\n        super().__init__()\n        act = config.hidden_activation\n        self.act = (\n            ACT2FN[act]\n            if \"gelu\" not in act\n            else lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\" if act in [\"gelu_fast\", \"gelu_pytorch_tanh\"] else \"none\"\n                ),\n            )\n        )\n        # Fuse gate and up proj\n        gate_up_proj = TensorParallelColumnLinear.load_multi(\n            config,\n            
prefixes=[f\"{prefix}.gate_proj\", f\"{prefix}.up_proj\"],\n            weights=weights,\n            dim=0,\n            bias=False,\n        )\n        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(\n            gate_up_proj,\n            layer_id,\n            [\"gate_proj\", \"up_proj\"],\n            sizes=[\n                config.intermediate_size,\n                config.intermediate_size,\n            ],\n            process_group=weights.process_group,\n        )\n\n        down_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.down_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.down_proj = TensorParallelAdapterRowLinear.load(\n            down_proj,\n            layer_id,\n            \"down_proj\",\n            process_group=weights.process_group,\n        )\n\n        self.intermediate_size = (\n            config.intermediate_size // weights.process_group.size()\n        )\n\n    def forward(self, hidden_states, adapter_data):\n        gate_up_states = self.gate_up_proj(hidden_states, adapter_data)\n        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)\n        return self.down_proj(\n            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data\n        )\n\n\nclass FlashGemma3Layer(nn.Module):\n    def __init__(\n        self,\n        prefix: str,\n        config,\n        weights,\n        layer_id,\n        causal: bool,\n        is_sliding: bool,\n        local_rotary_emb,\n        global_rotary_emb,\n    ):\n        super().__init__()\n        self.self_attn = FlashGemma3Attention(\n            prefix=f\"{prefix}.self_attn\",\n            config=config,\n            weights=weights,\n            layer_id=layer_id,\n            causal=causal,\n            is_sliding=is_sliding,\n            local_rotary_emb=local_rotary_emb,\n            global_rotary_emb=global_rotary_emb,\n        )\n        self.mlp = 
Gemma3MLP(\n            prefix=f\"{prefix}.mlp\", config=config, weights=weights, layer_id=layer_id\n        )\n\n        self.input_layernorm = Gemma3FastRMSNorm.load(\n            prefix=f\"{prefix}.input_layernorm\", weights=weights, eps=config.rms_norm_eps\n        )\n        self.post_attention_layernorm = Gemma3FastRMSNorm.load(\n            prefix=f\"{prefix}.post_attention_layernorm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n        self.pre_feedforward_layernorm = Gemma3FastRMSNorm.load(\n            prefix=f\"{prefix}.pre_feedforward_layernorm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n        self.post_feedforward_layernorm = Gemma3FastRMSNorm.load(\n            prefix=f\"{prefix}.post_feedforward_layernorm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        adapter_data,\n        hpu_attention_meta,\n    ):\n        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        attn_output = self.self_attn(\n            normed_hidden_states,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            adapter_data,\n            hpu_attention_meta,\n        )\n\n        # faster post attention rms norm\n        normed_attn_res_output, _ = self.post_attention_layernorm(attn_output)\n        normed_attn_res_output = normed_attn_res_output + res\n        res = normed_attn_res_output\n\n        pre_normed, _ = self.pre_feedforward_layernorm(normed_attn_res_output)\n        mlp_output = self.mlp(pre_normed, adapter_data)\n        post_hidden_states, _ = self.post_feedforward_layernorm(mlp_output)\n\n        return 
post_hidden_states, normed_attn_res_output\n\n\nclass FlashGemma3Model(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights, causal: bool):\n        super().__init__()\n\n        process_group = weights.process_group\n        self.tp_rank = process_group.rank()\n        self.tp_world_size = process_group.size()\n        local_config = copy.deepcopy(config)\n        local_config.rope_scaling = dict(rope_type=\"default\")\n        local_rotary_emb = PositionRotaryEmbedding.static(\n            config=local_config,\n            dim=config.head_dim,\n            base=config.rope_local_base_freq,\n            device=weights.device,\n        )\n        global_rotary_emb = PositionRotaryEmbedding.static(\n            config=config,\n            dim=config.head_dim,\n            base=config.rope_theta,\n            device=weights.device,\n        )\n\n        self.layers = nn.ModuleList(\n            [\n                FlashGemma3Layer(\n                    prefix=f\"{prefix}.layers.{layer_id}\",\n                    config=config,\n                    weights=weights,\n                    layer_id=layer_id,\n                    causal=causal,\n                    is_sliding=bool((layer_id + 1) % config.sliding_window_pattern),\n                    local_rotary_emb=local_rotary_emb,\n                    global_rotary_emb=global_rotary_emb,\n                )\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n        self.norm = Gemma3FastRMSNorm.load(\n            prefix=f\"{prefix}.norm\", weights=weights, eps=config.rms_norm_eps\n        )\n\n        self.head_size = self.layers[0].self_attn.head_size\n        self.num_heads = self.layers[0].self_attn.num_heads\n        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        adapter_data: Optional[torch.Tensor],\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n    ) -> torch.Tensor:\n        if hpu_attention_meta is not None:\n            hpu_attention_meta = set_block_mapping(\n                hpu_attention_meta, inputs_embeds.shape[0]\n            )\n        hidden_states = inputs_embeds\n\n        residual = None\n        lazy_mode = htorch.utils.internal.is_lazy()\n        if lazy_mode:\n            htorch.core.mark_step()\n\n        # Get rotary cos and sin for this forward\n        # Avoid to index in each layer\n\n        residual = None\n        for i, layer in enumerate(self.layers):\n            # Get rotary cos and sin for this forward\n            # Avoid to index in each layer\n            cos, sin = layer.self_attn.rotary_emb.get_cos_sin(position_ids)\n            hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                slots,\n                seqlen,\n                adapter_data,\n                hpu_attention_meta,\n            )\n            if lazy_mode:\n                htorch.core.mark_step()\n\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n\nclass FlashGemma3ForCausalLM(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights, *, causal: bool = True):\n        super().__init__()\n\n        embed_norm = config.hidden_size**0.5\n        if not prefix:\n            prefix = \"model\"\n        else:\n            prefix = f\"{prefix}.model\"\n\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=f\"{prefix}.embed_tokens\", weights=weights\n        )\n        self.embed_tokens.weight *= embed_norm\n\n        self.model = FlashGemma3Model(\n       
     prefix=prefix, config=config, weights=weights, causal=causal\n        )\n        self.lm_head = SpeculativeHead.load(\n            prefix=(\n                f\"{prefix}.embed_tokens\"\n                if config.tie_word_embeddings\n                else f\"{prefix}.lm_head\"\n            ),\n            config=config,\n            weights=weights,\n        )\n        # self.softcap = config.attn_logit_softcapping\n        # assert isinstance(self.softcap, float)\n        self.softcap = None\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        input_embeds = self.embed_tokens(input_ids)\n\n        hidden_states = self.model(\n            input_embeds,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            adapter_data,\n            hpu_attention_meta,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.lm_head(hidden_states)\n\n        return logits, speculative_logits\n\n\nclass Gemma3MultimodalInputProjection(torch.nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n\n        self.mm_input_projection_weight = weights.get_tensor(\n            \"multi_modal_projector.mm_input_projection_weight\"\n        )\n\n        self.mm_soft_emb_norm = Gemma3FastRMSNorm.load(\n            prefix=f\"{prefix}.mm_soft_emb_norm\",\n            weights=weights,\n            
eps=config.vision_config.layer_norm_eps,\n        )\n\n        self.patches_per_image = int(\n            config.vision_config.image_size // config.vision_config.patch_size\n        )\n        self.tokens_per_side = int(config.mm_tokens_per_image**0.5)\n        self.kernel_size = self.patches_per_image // self.tokens_per_side\n        self.avg_pool = nn.AvgPool2d(\n            kernel_size=self.kernel_size, stride=self.kernel_size\n        )\n\n    def forward(self, vision_outputs: torch.Tensor):\n        batch_size, _, seq_length = vision_outputs.shape\n\n        reshaped_vision_outputs = vision_outputs.transpose(1, 2)\n        reshaped_vision_outputs = reshaped_vision_outputs.reshape(\n            batch_size, seq_length, self.patches_per_image, self.patches_per_image\n        )\n        reshaped_vision_outputs = reshaped_vision_outputs.contiguous()\n\n        pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)\n        pooled_vision_outputs = pooled_vision_outputs.flatten(2)\n        pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)\n\n        normed_vision_outputs, _ = self.mm_soft_emb_norm(pooled_vision_outputs)\n\n        projected_vision_outputs = torch.matmul(\n            normed_vision_outputs, self.mm_input_projection_weight\n        )\n        return projected_vision_outputs.type_as(vision_outputs)\n\n\nclass Gemma3ForConditionalGeneration(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n\n        self.config = config\n\n        if config.vision_config is not None:\n\n            config.vision_config.quantize = config.quantize\n\n            self.post_vision_model_layernorm = nn.LayerNorm.load(\n                prefix=\"vision_tower.vision_model.post_layernorm\",\n                weights=weights,\n                eps=config.vision_config.layer_norm_eps,\n            )\n\n            self.multimodal_projector = Gemma3MultimodalInputProjection(\n                
prefix=\"multi_modal_projector\",\n                config=config,\n                weights=weights,\n            )\n\n            text_config = config.text_config\n            text_config.speculator = config.speculator\n            text_config.quantize = config.quantize\n\n            self.vision_model = load_vision_model(\n                prefix=\"vision_tower\" if not prefix else f\"{prefix}.vision_tower\",\n                config=config.vision_config,\n                weights=weights,\n            )\n\n            self.text_model = load_text_model(\n                prefix=\"language_model\" if not prefix else f\"{prefix}.language_model\",\n                config=config.text_config,\n                weights=weights,\n            )\n        else:\n            config.text_config.quantize = config.quantize\n            config.text_config.speculator = config.speculator\n            self.text_model = load_text_model(\n                prefix=prefix,\n                config=config.text_config,\n                weights=weights,\n            )\n\n        self.pad_token_id = (\n            config.pad_token_id if config.pad_token_id is not None else -1\n        )\n        self.dtype = weights.dtype\n\n    def get_vision_embeds(\n        self,\n        pixel_values: torch.FloatTensor,\n        pixel_attention_mask: Optional[torch.FloatTensor] = None,\n        image_sizes: Optional[torch.Tensor] = None,\n        image_grid_thw: Optional[torch.LongTensor] = None,\n    ):\n        pixel_values = pixel_values.to(dtype=self.dtype)\n        image_outputs = self.vision_model(pixel_values)\n        vision_outputs = self.post_vision_model_layernorm(\n            image_outputs.last_hidden_state\n        )\n        image_features = self.multimodal_projector(vision_outputs)\n        image_features = image_features.view(-1, image_features.shape[-1])\n        return image_features\n\n    def get_inputs_embeds(\n        self,\n        input_ids: torch.Tensor,\n        vision_embeds: 
torch.Tensor = None,\n    ):\n        inputs_embeds = self.text_model.embed_tokens(input_ids)\n\n        if vision_embeds is not None:\n            # Replace the image token embeddings with the vision features\n            image_token_mask = (input_ids == self.config.image_token_index).to(\n                input_ids.device\n            )\n            inputs_embeds[image_token_mask] = vision_embeds.view(\n                -1, vision_embeds.shape[-1]\n            )\n        return inputs_embeds\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.BoolTensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        if cu_seqlen_prefill is not None:\n            position_ids += 1\n\n        if attention_mask is not None:\n            min_dtype = torch.finfo(inputs_embeds.dtype).min\n            # prefill may be larger than sliding window\n            effective_seq_len = max(\n                position_ids.shape[0], self.config.text_config.sliding_window\n            )\n            sliding_window_mask = torch.tril(\n                torch.ones_like(attention_mask, dtype=torch.bool),\n                diagonal=-self.config.text_config.sliding_window,\n            )\n            attention_mask_local = torch.where(\n                sliding_window_mask, min_dtype, attention_mask\n            )\n            offset = max(0, position_ids.shape[0] - effective_seq_len)\n            attention_mask_local = attention_mask_local[\n                :, :, :, offset : offset + effective_seq_len\n            ]\n        else:\n            
attention_mask_local = None\n\n        hidden_states = self.text_model.model(\n            inputs_embeds=inputs_embeds,\n            position_ids=position_ids,\n            cu_seqlen_prefill=cu_seqlen_prefill,\n            kv_cache=kv_cache,\n            slots=slots,\n            seqlen=seqlen,\n            hpu_attention_meta=hpu_attention_meta,\n            adapter_data=adapter_data,\n        )\n\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.text_model.lm_head(hidden_states)\n\n        return logits, speculative_logits\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.distributed\n\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom transformers.configuration_utils import PretrainedConfig\nfrom typing import Optional, List, Tuple\nfrom text_generation_server.layers.attention import (\n    paged_attention,\n    attention,\n    set_block_mapping,\n    Seqlen,\n    HPUPagedAttentionMetadata,\n)\nfrom text_generation_server.layers import (\n    TensorParallelRowLinear,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    SpeculativeHead,\n    get_linear,\n)\nfrom text_generation_server.layers.attention.kv_cache import get_kv_scales\nfrom text_generation_server.layers.rotary import PositionRotaryEmbedding\nfrom text_generation_server.layers.layernorm import (\n    FastRMSNorm,\n)\nfrom text_generation_server.utils.weights import UnquantizedWeight\nimport habana_frameworks.torch as htorch\n\n\nclass GemmaConfig(PretrainedConfig):\n    def __init__(\n        self,\n        vocab_size=256128,\n        hidden_size=3072,\n  
      intermediate_size=24576,\n        num_hidden_layers=28,\n        num_attention_heads=16,\n        num_key_value_heads=16,\n        head_dim=256,\n        hidden_act=\"gelu_pytorch_tanh\",\n        max_position_embeddings=8192,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        pad_token_id=None,\n        bos_token_id=1,\n        eos_token_id=2,\n        tie_word_embeddings=True,\n        rope_theta=10000.0,\n        rope_scaling=None,\n        attention_bias=False,\n        attention_dropout=0.0,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.head_dim = head_dim\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n\nclass GemmaFastRMSNorm(FastRMSNorm):\n    @classmethod\n    def load(cls, prefix: str, weights, eps=1e-6):\n        dtype = weights.dtype\n        weights.dtype = torch.float32\n        weight = weights.get_tensor(f\"{prefix}.weight\") + 1\n        weights.dtype = dtype\n        new = cls(weight, 
eps)\n        new.dtype = dtype\n        return new\n\n    # perform the multiplication in full precision and downcast after\n    def forward(self, hidden_states, residual=None):\n        if residual is not None:\n            hidden_states += residual\n        residual = hidden_states\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        hidden_states = hidden_states * self.weight\n        return hidden_states.to(self.dtype), residual\n\n\ndef load_attention(config, prefix: str, weights):\n    if config.num_attention_heads != config.num_key_value_heads:\n        return _load_gqa(config, prefix, weights)\n    else:\n        return TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n            dim=0,\n            weights=weights,\n            bias=False,\n        )\n\n\ndef _load_gqa(config, prefix: str, weights):\n    assert config.num_attention_heads % weights.process_group.size() == 0\n\n    weight = weights.get_multi_weights_col(\n        prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n        dim=0,\n    )\n\n    if isinstance(weight, UnquantizedWeight):\n        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)\n\n        head_size = config.head_dim\n        num_heads = config.num_attention_heads // weights.process_group.size()\n        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()\n        assert list(weight.weight.shape) == [\n            (num_heads + 2 * num_key_value_heads) * head_size,\n            config.hidden_size,\n        ], f\"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}\"\n\n    return TensorParallelColumnLinear(get_linear(weight, 
bias=None))\n\n\nclass FlashGemmaAttention(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights, causal: bool, rotary_emb):\n        super().__init__()\n        self.num_heads = config.num_attention_heads\n        self.head_size = config.head_dim\n        self.causal = causal\n        self.rotary_emb = rotary_emb\n        self.softmax_scale = self.head_size**-0.5\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.num_key_value_heads = (\n            config.num_key_value_heads // weights.process_group.size()\n        )\n\n        self.query_key_value = load_attention(config, prefix, weights)\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n\n        self.o_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.o_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.num_groups = self.num_heads // self.num_key_value_heads\n        self.kv_head_mapping = torch.arange(\n            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device\n        ).repeat_interleave(self.num_groups)\n\n    def forward(\n        self,\n        hidden_states,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        hpu_attention_meta,\n    ):\n        qkv = self.query_key_value(hidden_states)\n        query, kv = qkv.split(\n            [\n                self.head_size * self.num_heads,\n                2 * self.head_size * self.num_key_value_heads,\n            ],\n            dim=1,\n        )\n        query = query.view(-1, self.num_heads, self.head_size)\n        kv = kv.view(-1, 2, 
self.num_key_value_heads, self.head_size)\n\n        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)\n\n        kv_cache.store(\n            key=kv[:, 0],\n            value=kv[:, 1],\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # sdpa\n            attn_output = attention(\n                query=query,\n                key=kv[:, 0],\n                value=kv[:, 1],\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                softmax_scale=self.softmax_scale,\n                causal=self.causal,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                seqlen,\n                kv_scales=self.kv_scales,\n                hpu_attention_meta=hpu_attention_meta,\n            )\n\n        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))\n\n\nclass GemmaMLP(nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n        act = config.hidden_act\n        self.act = (\n            ACT2FN[act]\n            if \"gelu\" not in act\n            else lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\" if act in [\"gelu_fast\", \"gelu_pytorch_tanh\"] else \"none\"\n                ),\n            )\n        )\n        # Fuse gate and up proj\n        self.gate_up_proj = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.gate_proj\", f\"{prefix}.up_proj\"],\n            weights=weights,\n            dim=0,\n            bias=False,\n        )\n        self.down_proj = TensorParallelRowLinear.load(\n            config,\n            
prefix=f\"{prefix}.down_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.intermediate_size = (\n            config.intermediate_size // weights.process_group.size()\n        )\n\n    def forward(self, hidden_states):\n        gate_up_states = self.gate_up_proj(hidden_states)\n        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)\n        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])\n\n\nclass FlashGemmaLayer(nn.Module):\n    def __init__(self, prefix: str, config, weights, causal: bool, rotary_emb):\n        super().__init__()\n        self.self_attn = FlashGemmaAttention(\n            prefix=f\"{prefix}.self_attn\",\n            config=config,\n            weights=weights,\n            causal=causal,\n            rotary_emb=rotary_emb,\n        )\n        self.mlp = GemmaMLP(prefix=f\"{prefix}.mlp\", config=config, weights=weights)\n\n        self.input_layernorm = GemmaFastRMSNorm.load(\n            prefix=f\"{prefix}.input_layernorm\", weights=weights, eps=config.rms_norm_eps\n        )\n        self.post_attention_layernorm = GemmaFastRMSNorm.load(\n            prefix=f\"{prefix}.post_attention_layernorm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        hpu_attention_meta,\n    ):\n        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        attn_output = self.self_attn(\n            normed_hidden_states,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            hpu_attention_meta,\n        )\n\n        # faster post attention rms norm\n        normed_attn_res_output, attn_res = 
self.post_attention_layernorm(\n            attn_output, res\n        )\n\n        mlp_output = self.mlp(normed_attn_res_output)\n\n        return mlp_output, attn_res\n\n\nclass FlashGemmaModel(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights, causal: bool):\n        super().__init__()\n\n        process_group = weights.process_group\n        self.tp_rank = process_group.rank()\n        self.tp_world_size = process_group.size()\n        rotary_emb = PositionRotaryEmbedding.static(\n            config=config,\n            dim=config.head_dim,\n            base=config.rope_theta,\n            device=weights.device,\n        )\n\n        self.layers = nn.ModuleList(\n            [\n                FlashGemmaLayer(\n                    prefix=f\"{prefix}.layers.{layer_id}\",\n                    config=config,\n                    weights=weights,\n                    causal=causal,\n                    rotary_emb=rotary_emb,\n                )\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n        self.norm = GemmaFastRMSNorm.load(\n            prefix=f\"{prefix}.norm\", weights=weights, eps=config.rms_norm_eps\n        )\n\n        self.head_size = self.layers[0].self_attn.head_size\n        self.num_heads = self.layers[0].self_attn.num_heads\n        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        adapter_data: Optional[torch.Tensor],\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n    ) -> torch.Tensor:\n        if hpu_attention_meta is not None:\n            hpu_attention_meta = set_block_mapping(\n                hpu_attention_meta, inputs_embeds.shape[0]\n            )\n   
     hidden_states = inputs_embeds\n\n        # Get rotary cos and sin for this forward\n        # Avoid to index in each layer\n        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)\n\n        residual = None\n        lazy_mode = htorch.utils.internal.is_lazy()\n        if lazy_mode:\n            htorch.core.mark_step()\n        for i, layer in enumerate(self.layers):\n            hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                slots,\n                seqlen,\n                hpu_attention_meta,\n            )\n            if lazy_mode:\n                htorch.core.mark_step()\n\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n\nclass FlashGemmaForCausalLM(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights, *, causal: bool = True):\n        super().__init__()\n\n        embed_norm = config.hidden_size**0.5\n        if not prefix:\n            prefix = \"model\"\n        else:\n            prefix = f\"{prefix}.model\"\n\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=f\"{prefix}.embed_tokens\", weights=weights\n        )\n        self.embed_tokens.weight *= embed_norm\n\n        self.model = FlashGemmaModel(\n            prefix=prefix, config=config, weights=weights, causal=causal\n        )\n        self.lm_head = SpeculativeHead.load(\n            prefix=(\n                f\"{prefix}.embed_tokens\"\n                if config.tie_word_embeddings\n                else f\"{prefix}.lm_head\"\n            ),\n            config=config,\n            weights=weights,\n        )\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, 
torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        input_embeds = self.embed_tokens(input_ids)\n        hidden_states = self.model(\n            input_embeds,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            adapter_data,\n            hpu_attention_meta,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.lm_head(hidden_states)\n        return logits, speculative_logits\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.distributed\n\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom typing import Optional, List, Tuple\nfrom text_generation_server.layers.attention import (\n    paged_attention,\n    attention,\n    set_block_mapping,\n    Seqlen,\n    HPUPagedAttentionMetadata,\n)\nfrom text_generation_server.layers import (\n    TensorParallelRowLinear,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    SpeculativeHead,\n    get_linear,\n)\nfrom text_generation_server.layers.attention.kv_cache import get_kv_scales\nimport habana_frameworks.torch as htorch\n\n\ndef load_qkv(config, prefix: str, weights, head_size, num_heads):\n    if config.quantize == \"gptq\":\n        return _load_qkv_gptq(\n            config,\n            prefix,\n            weights,\n        )\n    else:\n        return _load_qkv(config, prefix, weights, head_size, num_heads)\n\n\ndef _load_qkv_gptq(config, prefix: str, weights):\n    world_size = weights.process_group.size()\n    rank = 
weights.process_group.rank()\n\n    # Weights\n    weight = weights.get_weights_col_packed_qkv(\n        f\"{prefix}.c_attn\",\n        config.num_attention_heads,\n        config.num_attention_heads,\n    )\n\n    # Bias\n    slice_ = weights._get_slice(f\"{prefix}.c_attn.bias\")\n    shape = slice_.get_shape()\n    total_size = shape[0]\n    assert total_size % 3 == 0, f\"Prepacked is not divisible by {3}\"\n    single_size = total_size // 3\n    assert single_size % world_size == 0\n    block_size = single_size // world_size\n    start = rank * block_size\n    stop = (rank + 1) * block_size\n    tensors = []\n    for i in range(3):\n        tensor = slice_[start + i * single_size : stop + i * single_size]\n        tensors.append(tensor)\n    bias = torch.cat(tensors, dim=0)\n    bias = bias.to(device=weights.device)\n\n    return TensorParallelColumnLinear(get_linear(weight, bias))\n\n\ndef _load_qkv(config, prefix: str, weights, head_size, num_heads):\n    \"\"\"Load QKV from a single, transposed matrix.\"\"\"\n\n    slice_ = weights._get_slice(f\"{prefix}.c_attn.weight\")\n    shape = slice_.get_shape()\n    total_size = shape[1]\n    assert total_size % 3 == 0, f\"Prepacked is not divisible by {3}\"\n    world_size = weights.process_group.size()\n    single_size = total_size // 3\n    assert single_size % world_size == 0\n    rank = weights.process_group.rank()\n\n    # Weights\n    block_size = single_size // world_size\n    start = rank * block_size\n    stop = (rank + 1) * block_size\n    tensors = []\n    for i in range(3):\n        tensor = slice_[:, start + i * single_size : stop + i * single_size]\n        tensors.append(tensor)\n    weight = torch.cat(tensors, dim=1).T\n    weight = weight.to(dtype=weights.dtype)\n    weight = weight.to(device=weights.device)\n\n    # Bias\n    slice_ = weights._get_slice(f\"{prefix}.c_attn.bias\")\n    shape = slice_.get_shape()\n    total_size = shape[0]\n    single_size = total_size // 3\n    block_size = 
single_size // world_size\n    assert single_size % world_size == 0\n    start = rank * block_size\n    stop = (rank + 1) * block_size\n    b = []\n    for i in range(3):\n        tensor = slice_[start + i * single_size : stop + i * single_size]\n        b.append(tensor)\n    bias = torch.cat(b, dim=0)\n    bias = bias.to(dtype=weights.dtype)\n    bias = bias.to(device=weights.device)\n    assert list(bias.shape) == [\n        3 * num_heads * head_size\n    ], f\"{weight.shape} != {[3 * num_heads * head_size]}\"\n\n    return TensorParallelColumnLinear(get_linear(weight, bias))\n\n\ndef load_row(config, prefix: str, weights, bias: bool):\n    \"\"\"load_row, but with transposed weight matrices.\"\"\"\n\n    if config.quantize == \"gptq\":\n        weight = weights.get_weights_row(prefix)\n    else:\n        weight = weights.get_sharded(f\"{prefix}.weight\", dim=0).T\n\n    if bias and weights.process_group.rank() == 0:\n        # Rank is only on the first rank process\n        bias = weights.get_tensor(f\"{prefix}.bias\")\n    else:\n        bias = None\n\n    return TensorParallelRowLinear(\n        get_linear(weight, bias), process_group=weights.process_group\n    )\n\n\ndef load_col(config, prefix: str, weights, bias: bool):\n    \"\"\"load_col, but with transposed weight matrices.\"\"\"\n    if config.quantize == \"gptq\":\n        weight = weights.get_multi_weights_col([prefix], dim=1)\n    else:\n        weight = weights.get_sharded(f\"{prefix}.weight\", dim=1).T\n\n    if bias:\n        bias = weights.get_sharded(f\"{prefix}.bias\", dim=0)\n    else:\n        bias = None\n\n    return TensorParallelColumnLinear(get_linear(weight, bias))\n\n\nclass FlashGPT2Attention(torch.nn.Module):\n    def __init__(\n        self,\n        prefix: str,\n        config,\n        weights,\n    ):\n        super().__init__()\n        self.num_heads = config.num_attention_heads\n        self.hidden_size = config.hidden_size\n\n        self.head_size = self.hidden_size // 
self.num_heads\n        self.softmax_scale = self.head_size**-0.5\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n\n        self.query_key_value = load_qkv(\n            config,\n            prefix=prefix,\n            weights=weights,\n            head_size=self.head_size,\n            num_heads=self.num_heads,\n        )\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n\n        self.o_proj = load_row(\n            config,\n            prefix=f\"{prefix}.c_proj\",\n            weights=weights,\n            bias=True,\n        )\n\n        self.kv_head_mapping = torch.arange(\n            0, self.num_heads, dtype=torch.int32, device=weights.device\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        hpu_attention_meta,\n    ):\n        query, key, value = self.query_key_value(hidden_states).split(\n            self.head_size * self.num_heads, dim=1\n        )\n        query = query.view(-1, self.num_heads, self.head_size)\n        key = key.view(-1, self.num_heads, self.head_size)\n        value = value.view(-1, self.num_heads, self.head_size)\n\n        kv_cache.store(\n            key=key,\n            value=value,\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # sdpa\n            attn_output = attention(\n                query=query,\n                key=key,\n                value=value,\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                
softmax_scale=self.softmax_scale,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                seqlen,\n                kv_scales=self.kv_scales,\n                hpu_attention_meta=hpu_attention_meta,\n            )\n\n        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))\n\n\nclass GPT2MLP(nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n        act = config.activation_function\n        self.act = (\n            ACT2FN[act]\n            if \"gelu\" not in act\n            else lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\" if act in [\"gelu_fast\", \"gelu_pytorch_tanh\"] else \"none\"\n                ),\n            )\n        )\n\n        self.c_fc = load_col(\n            config, prefix=f\"{prefix}.c_fc\", weights=weights, bias=True\n        )\n        self.c_proj = load_row(\n            config,\n            prefix=f\"{prefix}.c_proj\",\n            weights=weights,\n            bias=True,\n        )\n\n        intermediate_size = (\n            config.n_inner if config.n_inner is not None else 4 * config.hidden_size\n        )\n\n        self.intermediate_size = intermediate_size // weights.process_group.size()\n\n    def forward(self, hidden_states):\n        hidden_states = self.c_fc(hidden_states)\n        hidden_states = self.act(hidden_states)\n        return self.c_proj(hidden_states)\n\n\nclass FlashGPT2Layer(nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n        self.self_attn = FlashGPT2Attention(\n            prefix=f\"{prefix}.attn\", config=config, weights=weights\n        )\n        self.mlp = GPT2MLP(prefix=f\"{prefix}.mlp\", config=config, weights=weights)\n\n        
self.input_layernorm = nn.LayerNorm.load(\n            prefix=f\"{prefix}.ln_1\", weights=weights, eps=config.layer_norm_epsilon\n        )\n        self.post_attention_layernorm = nn.LayerNorm.load(\n            prefix=f\"{prefix}.ln_2\",\n            weights=weights,\n            eps=config.layer_norm_epsilon,\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        hpu_attention_meta,\n    ):\n        residual = hidden_states\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        attn_output = self.self_attn(\n            hidden_states,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            hpu_attention_meta,\n        )\n\n        hidden_states = attn_output + residual\n        residual = hidden_states\n\n        hidden_states = self.post_attention_layernorm(hidden_states)\n\n        mlp_output = self.mlp(hidden_states)\n\n        return residual + mlp_output, residual\n\n\nclass FlashGPT2Model(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n\n        process_group = weights.process_group\n        self.tp_rank = process_group.rank()\n        self.tp_world_size = process_group.size()\n        self.layers = nn.ModuleList(\n            [\n                FlashGPT2Layer(\n                    prefix=(\n                        f\"h.{layer_id}\" if not prefix else f\"{prefix}.h.{layer_id}\"\n                    ),\n                    config=config,\n                    weights=weights,\n                )\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n\n        self.norm = nn.LayerNorm.load(\n            prefix=\"ln_f\" if not prefix else f\"{prefix}.ln_f\",\n            weights=weights,\n            eps=config.layer_norm_epsilon,\n        )\n\n        
self.gradient_checkpointing = False\n\n        self.head_size = self.layers[0].self_attn.head_size\n        self.num_heads = self.layers[0].self_attn.num_heads\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n    ) -> torch.Tensor:\n        if hpu_attention_meta is not None:\n            hpu_attention_meta = set_block_mapping(\n                hpu_attention_meta, inputs_embeds.shape[0]\n            )\n        hidden_states = inputs_embeds\n\n        residual = None\n        lazy_mode = htorch.utils.internal.is_lazy()\n        if lazy_mode:\n            htorch.core.mark_step()\n\n        for i, layer in enumerate(self.layers):\n            hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                slots,\n                seqlen,\n                hpu_attention_meta,\n            )\n            if lazy_mode:\n                htorch.core.mark_step()\n\n        hidden_states = self.norm(hidden_states)\n\n        return hidden_states\n\n\nclass FlashGPT2ForCausalLM(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=(\"wte\" if not prefix else f\"{prefix}.wte\"),\n            weights=weights,\n        )\n        self.embed_positions = TensorParallelEmbedding(\n            prefix=(\"wpe\" if not prefix else f\"{prefix}.wpe\"),\n            weights=weights,\n        )\n\n        self.model = FlashGPT2Model(prefix, config, weights)\n        self.lm_head = SpeculativeHead.load(\n            config,\n            prefix=\"wte\" if not prefix else 
f\"{prefix}.wte\",\n            weights=weights,\n        )\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        token_embeds = self.embed_tokens(input_ids)\n        position_embeds = self.embed_positions(position_ids)\n        inputs_embeds = token_embeds + position_embeds\n        hidden_states = self.model(\n            inputs_embeds,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            hpu_attention_meta=hpu_attention_meta,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.lm_head(hidden_states)\n        return logits, speculative_logits\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.distributed\n\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom typing import Optional, List, Tuple\nfrom text_generation_server.layers.attention.kv_cache import get_kv_scales\nfrom text_generation_server.layers.attention import (\n    paged_attention,\n    attention,\n    set_block_mapping,\n    Seqlen,\n    HPUPagedAttentionMetadata,\n)\nfrom text_generation_server.layers import (\n    TensorParallelRowLinear,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    SpeculativeHead,\n    get_linear,\n)\nfrom text_generation_server.layers.rotary import (\n    PositionRotaryEmbedding,\n)\nfrom text_generation_server.layers.layernorm import (\n    FastLayerNorm,\n)\nfrom habana_frameworks.torch.hpex.kernels import (\n    RotaryPosEmbeddingMode,\n    apply_rotary_pos_emb,\n)\nimport habana_frameworks.torch as htorch\n\n\ndef load_attention(config, prefix: str, weights):\n    return TensorParallelColumnLinear.load_multi(\n        config,\n        
prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n        dim=0,\n        weights=weights,\n        bias=False,\n    )\n\n\ndef load_row(config, prefix: str, weights, bias: bool):\n    weight = weights.get_weights_row(prefix)\n\n    if bias and weights.process_group.rank() == 0:\n        # Rank is only on the first rank process\n        bias = weights.get_tensor(f\"{prefix}.bias\")\n    else:\n        bias = None\n\n    linear = get_linear(weight, bias)\n    return TensorParallelRowLinear(linear, process_group=weights.process_group)\n\n\nclass GPTJRotary(PositionRotaryEmbedding):\n    def forward(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        cos: torch.Tensor,\n        sin: torch.Tensor,\n    ):\n        num_tokens = query.shape[0]\n        head_size = query.shape[-1]\n        rope_mode = RotaryPosEmbeddingMode.PAIRWISE\n        sin = torch.repeat_interleave(sin, 2, dim=-1)\n        cos = torch.repeat_interleave(cos, 2, dim=-1)\n        rotary_dim = cos.shape[-1]\n        query_shape = query.shape\n        query = query.view(num_tokens, -1, head_size)\n        query_rot = query[..., :rotary_dim]\n        query_pass = query[..., rotary_dim:]\n        query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)\n        query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape))\n\n        key_shape = key.shape\n        key = key.view(num_tokens, -1, head_size)\n        key_rot = key[..., :rotary_dim]\n        key_pass = key[..., rotary_dim:]\n        key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)\n        key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape))\n\n\nclass FlashGPTJAttention(torch.nn.Module):\n    def __init__(\n        self,\n        prefix: str,\n        config,\n        weights,\n        rotary_emb,\n    ):\n        super().__init__()\n        self.num_heads = config.num_attention_heads\n        self.hidden_size = 
config.hidden_size\n\n        self.head_size = self.hidden_size // self.num_heads\n        self.softmax_scale = self.head_size**-0.5\n        self.rotary_dim = config.rotary_dim\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n\n        self.query_key_value = load_attention(\n            config,\n            prefix=prefix,\n            weights=weights,\n        )\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n\n        self.o_proj = load_row(\n            config,\n            prefix=f\"{prefix}.out_proj\",\n            weights=weights,\n            bias=False,\n        )\n\n        self.kv_head_mapping = torch.arange(\n            0, self.num_heads, dtype=torch.int32, device=weights.device\n        )\n        self.rotary_emb = rotary_emb\n\n    def forward(\n        self,\n        hidden_states,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        hpu_attention_meta,\n    ):\n        query, key, value = self.query_key_value(hidden_states).split(\n            self.head_size * self.num_heads, dim=1\n        )\n        query = query.view(-1, self.num_heads, self.head_size)\n        key = key.view(-1, self.num_heads, self.head_size)\n        value = value.view(-1, self.num_heads, self.head_size)\n\n        # Compute rotary embeddings on rotary_ndims\n        if self.rotary_dim is not None:\n            self.rotary_emb(\n                query[..., : self.rotary_dim], key[..., : self.rotary_dim], cos, sin\n            )\n        else:\n            self.rotary_emb(query, key, cos, sin)\n\n        kv_cache.store(\n            key=key,\n            value=value,\n            
slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # sdpa\n            attn_output = attention(\n                query=query,\n                key=key,\n                value=value,\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                softmax_scale=self.softmax_scale,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                seqlen,\n                kv_scales=self.kv_scales,\n                hpu_attention_meta=hpu_attention_meta,\n            )\n\n        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))\n\n\nclass GPTJMLP(nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n        act = config.activation_function\n        self.act = (\n            ACT2FN[act]\n            if \"gelu\" not in act\n            else lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\" if act in [\"gelu_fast\", \"gelu_pytorch_tanh\"] else \"none\"\n                ),\n            )\n        )\n\n        self.fc_in = TensorParallelColumnLinear.load(\n            config, prefix=f\"{prefix}.fc_in\", weights=weights, bias=True\n        )\n\n        self.fc_out = load_row(\n            config,\n            prefix=f\"{prefix}.fc_out\",\n            weights=weights,\n            bias=True,\n        )\n\n    def forward(self, hidden_states):\n        hidden_states = self.fc_in(hidden_states)\n        hidden_states = self.act(hidden_states)\n        return self.fc_out(hidden_states)\n\n\nclass FlashGPTJLayer(nn.Module):\n    def __init__(self, prefix: str, config, weights, rotary_emb):\n        super().__init__()\n        
self.self_attn = FlashGPTJAttention(\n            prefix=f\"{prefix}.attn\",\n            config=config,\n            weights=weights,\n            rotary_emb=rotary_emb,\n        )\n        self.mlp = GPTJMLP(prefix=f\"{prefix}.mlp\", config=config, weights=weights)\n\n        self.input_layernorm = FastLayerNorm.load(\n            prefix=f\"{prefix}.ln_1\", weights=weights, eps=config.layer_norm_epsilon\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        hpu_attention_meta,\n    ):\n        hidden_states, residual = self.input_layernorm(hidden_states, residual)\n        # Self Attention\n        attn_output = self.self_attn(\n            hidden_states,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            hpu_attention_meta,\n        )\n\n        feed_forward_hidden_states = self.mlp(hidden_states)\n\n        return attn_output + feed_forward_hidden_states, residual\n\n\nclass FlashGPTJModel(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n        self.config = config\n\n        self.wte = TensorParallelEmbedding(prefix=f\"{prefix}.wte\", weights=weights)\n        rotary_emb = GPTJRotary.static(\n            config=config,\n            dim=config.rotary_dim,\n            base=10000,\n            device=weights.device,\n        )\n        self.layers = nn.ModuleList(\n            [\n                FlashGPTJLayer(\n                    prefix=(\n                        f\"h.{layer_id}\" if not prefix else f\"{prefix}.h.{layer_id}\"\n                    ),\n                    config=config,\n                    weights=weights,\n                    rotary_emb=rotary_emb,\n                )\n                for layer_id in range(config.num_hidden_layers)\n            ]\n    
    )\n\n        self.ln_f = FastLayerNorm.load(\n            prefix=\"ln_f\" if not prefix else f\"{prefix}.ln_f\",\n            weights=weights,\n            eps=config.layer_norm_epsilon,\n        )\n\n        self.gradient_checkpointing = False\n\n        self.head_size = self.layers[0].self_attn.head_size\n        self.num_heads = self.layers[0].self_attn.num_heads\n\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor],\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n    ) -> torch.Tensor:\n        if hpu_attention_meta is not None:\n            hpu_attention_meta = set_block_mapping(\n                hpu_attention_meta, input_ids.shape[0]\n            )\n        hidden_states = self.wte(input_ids)\n\n        # Get rotary cos and sin for this forward\n        # Avoid to index in each layer\n        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)\n\n        residual = None\n        lazy_mode = htorch.utils.internal.is_lazy()\n        if lazy_mode:\n            htorch.core.mark_step()\n        for i, layer in enumerate(self.layers):\n            hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                slots,\n                seqlen,\n                hpu_attention_meta,\n            )\n            if lazy_mode:\n                htorch.core.mark_step()\n\n        hidden_states, _ = self.ln_f(hidden_states, residual)\n\n        return hidden_states\n\n\nclass FlashGPTJForCausalLM(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n        if not prefix:\n            prefix = 
\"transformer\"\n        else:\n            prefix = f\"{prefix}.transformer\"\n        self.model = FlashGPTJModel(prefix, config, weights)\n        self.lm_head = SpeculativeHead.load(\n            config,\n            prefix=\"lm_head\",\n            weights=weights,\n        )\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        hidden_states = self.model(\n            input_ids,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            hpu_attention_meta=hpu_attention_meta,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.lm_head(hidden_states)\n        return logits, speculative_logits\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama4_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved.\n#\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport math\nimport torch.utils.checkpoint\nfrom torch import nn\nimport torch.nn.functional as F\n\nimport habana_frameworks.torch as htorch\nfrom transformers.cache_utils import Cache\nfrom transformers.activations import ACT2FN\nfrom transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS\nfrom transformers.modeling_outputs import (\n    BaseModelOutput,\n)\n\nfrom transformers.modeling_attn_mask_utils import AttentionMaskConverter\n\nfrom text_generation_server.layers import (\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    TensorParallelRowLinear,\n    SpeculativeHead,\n    FastLinear,\n)\nfrom text_generation_server.layers.layernorm import FastRMSNorm\nfrom text_generation_server.layers.attention import (\n    KVCache,\n    paged_attention,\n    set_block_mapping,\n    Seqlen,\n    HPUPagedAttentionMetadata,\n)\nfrom text_generation_server.models.custom_modeling.flash_llama_modeling import (\n    FlashLlamaAttention,\n)\n\n\ndef reshape_for_broadcast(freqs: torch.Tensor, target):\n    ndim = len(target)\n    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(target)]\n    return freqs.view(*shape)\n\n\ndef apply_rotary_emb(\n    query: torch.Tensor,\n    key: torch.Tensor,\n    freqs_ci: 
torch.Tensor,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    query_shape = query.shape\n    key_shape = key.shape\n    cos_emb, sin_emb = freqs_ci.split(1, dim=-1)\n\n    if len(query.shape) == 3:\n        query = query.unsqueeze(0)\n        key = key.unsqueeze(0)\n\n    query_reshaped = query.float().reshape(*query.shape[:-1], -1, 2)\n    key_reshaped = key.float().reshape(*key.shape[:-1], -1, 2)\n    q_shape = query_reshaped.shape[:-1]\n    cos_emb = reshape_for_broadcast(cos_emb, q_shape)\n    sin_emb = reshape_for_broadcast(sin_emb, q_shape)\n    x_q, y_q = query_reshaped.unbind(-1)\n    x_k, y_k = key_reshaped.unbind(-1)\n\n    x_q_rot = x_q * cos_emb - y_q * sin_emb\n    y_q_rot = x_q * sin_emb + y_q * cos_emb\n    x_k_rot = x_k * cos_emb - y_k * sin_emb\n    y_k_rot = x_k * sin_emb + y_k * cos_emb\n\n    query_out = torch.stack([x_q_rot, y_q_rot], dim=-1).flatten(-2)\n    key_out = torch.stack([x_k_rot, y_k_rot], dim=-1).flatten(-2)\n    query_out = query_out.view(*query_shape)\n    key_out = key_out.view(*key_shape)\n    return query_out.type_as(query), key_out.type_as(key)\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(\n        batch, num_key_value_heads, n_rep, slen, head_dim\n    )\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\nclass Llama4TextExperts(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.process_group = weights.process_group\n        self.num_experts = config.num_local_experts\n        self.intermediate_size = (\n            config.intermediate_size // weights.process_group.size()\n        )\n        self.hidden_size = config.hidden_size\n        self.expert_dim = self.intermediate_size\n        self.gate_up_proj = nn.Parameter(\n            weights.get_packed_sharded(f\"{prefix}.gate_up_proj\", dim=-1, block_sizes=2),\n            requires_grad=False,\n        )\n        self.down_proj = nn.Parameter(\n            weights.get_sharded(f\"{prefix}.down_proj\", dim=1), requires_grad=False\n        )\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        This should really not be run on a single machine, as we are reaching compute bound:\n        - the inputs are expected to be \"sorted\" per expert already.\n        - the weights are viewed per expert: gate_up_proj as (num_experts, hidden_size, 2 * expert_dim)\n          and down_proj as (num_experts, expert_dim, hidden_size)\n\n        Args:\n            hidden_states (torch.Tensor): (num_experts * token_num, hidden_size), rows grouped per expert.\n                Expert selection and routing-weight scaling are expected to be applied by the caller;\n                this method takes no `selected_experts` / `routing_weights` arguments.\n        Returns:\n            torch.Tensor: (num_experts * token_num, hidden_size)\n        \"\"\"\n        gate_up_proj = 
self.gate_up_proj.view(self.num_experts, -1, 2 * self.expert_dim)\n        down_proj = self.down_proj.view(self.num_experts, self.expert_dim, -1)\n        hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)\n        gate_up = torch.bmm(hidden_states, gate_up_proj)\n        gate, up = gate_up.chunk(2, dim=-1)  # not supported for DTensors\n        next_states = torch.bmm((up * self.act_fn(gate)), down_proj)\n        next_states = next_states.view(-1, self.hidden_size)\n\n        # Reduce sum\n        if self.process_group.size() > 1:\n            torch.distributed.all_reduce(next_states, group=self.process_group)\n\n        return next_states\n\n\n# Phi3MLP\nclass Llama4TextMLP(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        self.intermediate_size = (\n            config.intermediate_size // weights.process_group.size()\n        )\n        self.gate_up_proj = TensorParallelColumnLinear.load_multi(\n            config=config,\n            prefixes=[f\"{prefix}.gate_proj\", f\"{prefix}.up_proj\"],\n            weights=weights,\n            dim=0,\n            bias=False,\n        )\n        self.down_proj = TensorParallelRowLinear.load(\n            config=config,\n            prefix=f\"{prefix}.down_proj\",\n            weights=weights,\n            bias=False,\n        )\n\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x):\n        gate_up_states = self.gate_up_proj(x)\n        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)\n        return self.down_proj(self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1])\n\n\nclass Llama4TextL2Norm(torch.nn.Module):\n    def __init__(self, eps: float = 1e-6):\n        super().__init__()\n        self.eps = eps\n\n    def _norm(self, x):\n        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)\n\n    def 
forward(self, x):\n        return self._norm(x.float()).type_as(x)\n\n    def extra_repr(self):\n        return f\"eps={self.eps}\"\n\n\nclass Llama4TextMoe(nn.Module):\n    def __init__(\n        self,\n        prefix,\n        config,\n        weights,\n    ):\n        super().__init__()\n        self.top_k = config.num_experts_per_tok\n        self.hidden_dim = config.hidden_size\n        self.num_experts = config.num_local_experts\n        self.experts = Llama4TextExperts(\n            config=config, prefix=f\"{prefix}.experts\", weights=weights\n        )\n        self.router = FastLinear.load(\n            config=config, prefix=f\"{prefix}.router\", weights=weights, bias=False\n        )\n        self.shared_expert = Llama4TextMLP(\n            config=config, prefix=f\"{prefix}.shared_expert\", weights=weights\n        )\n        self.process_group = weights.process_group\n\n    def forward(self, hidden_states, adapter_data):\n        seq_len, hidden_dim = hidden_states.shape\n        hidden_states = hidden_states.view(-1, self.hidden_dim)\n        tokens_per_expert = hidden_states.shape[0]\n        router_logits = self.router(hidden_states)\n\n        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)\n        router_scores = (\n            torch.full_like(router_logits, float(\"-inf\"))\n            .scatter_(1, router_indices, router_top_value)\n            .transpose(0, 1)\n        )\n        # We do this to make sure we have -inf for non topK tokens before going through the sigmoid!\n        # Here we are just creating a tensor to index each and every single one of the hidden states. 
Let s maybe register a buffer for this!\n        router_indices = (\n            torch.arange(tokens_per_expert, device=hidden_states.device)\n            .view(1, -1)\n            .expand(router_scores.size(0), -1)\n        )\n        router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)\n\n        router_indices = router_indices.reshape(-1, 1).expand(-1, self.hidden_dim)\n        routed_in = torch.gather(\n            input=hidden_states,\n            dim=0,\n            index=router_indices,\n        ).to(hidden_states.device)\n\n        # we gather inputs corresponding to each expert based on the router indices\n        routed_in = routed_in * router_scores.reshape(-1, 1)\n        routed_out = self.experts(routed_in)\n        out = self.shared_expert(hidden_states)\n\n        # now that we finished expert computation -> we scatter add because we gathered previously\n        # we have to do this because we used all experts on all tokens. This is faster than the for loop, tho you are compute bound\n        # this scales a lot better if you do EP!\n        out.scatter_add_(\n            dim=0, index=router_indices, src=routed_out.view(-1, self.hidden_dim)\n        )\n        return out\n\n\nclass Llama4TextRotaryEmbedding(nn.Module):\n    def __init__(self, config, device=None):\n        super().__init__()\n        # BC: \"rope_type\" was originally \"type\"\n        self.rope_type = \"llama3\" if config.rope_scaling is not None else \"default\"\n\n        self.max_seq_len_cached = config.max_position_embeddings\n        self.original_max_seq_len = config.max_position_embeddings\n\n        self.config = config\n        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]\n\n        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        self.original_inv_freq = self.inv_freq\n\n    def forward(self, x, position_ids):\n        
inv_freq_expanded = (\n            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)\n        )\n        position_ids_expanded = position_ids[:, None, :].float()\n        device_type = (\n            x.device.type\n            if isinstance(x.device.type, str) and x.device.type != \"mps\"\n            else \"cpu\"\n        )\n        inv_freq_expanded = inv_freq_expanded.to(device_type)\n        position_ids_expanded = position_ids_expanded.to(device_type)\n        with torch.autocast(device_type=device_type, enabled=False):  # Force float32\n            freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)\n            freqs_cis = (\n                torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)\n                * self.attention_scaling\n            )\n        return freqs_cis.to(dtype=x.dtype, device=x.device)\n\n\nclass Llama4TextAttention(FlashLlamaAttention):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, prefix, config, weights, layer_idx):\n        super().__init__(layer_idx, prefix, config, weights, None)\n        self.config = config\n        self.layer_idx = layer_idx\n        self.head_dim = getattr(\n            config, \"head_dim\", config.hidden_size // config.num_attention_heads\n        )\n        self.num_key_value_groups = (\n            config.num_attention_heads // config.num_key_value_heads\n        )\n        self.scaling = self.head_dim**-0.5\n        self.attn_scale = config.attn_scale\n        self.floor_scale = config.floor_scale\n        self.attn_temperature_tuning = config.attn_temperature_tuning\n        self.attention_dropout = config.attention_dropout\n        self.use_rope = int((layer_idx + 1) % 4 != 0)  # rope unused for dense layers\n\n        if self.config.use_qk_norm and self.use_rope:\n            self.qk_norm = Llama4TextL2Norm(config.rms_norm_eps)\n\n    def forward(\n        self,\n        hidden_states: 
torch.Tensor,\n        freqs_ci,\n        cu_seqlen_prefill,\n        kv_cache: KVCache,\n        slots,\n        seqlen,\n        adapter_data,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        bs = seqlen.input_lengths.shape[0]\n        input_shape = hidden_states.shape[:-1]\n        hidden_shape = (*input_shape, -1, self.head_dim)\n        qkv = self.query_key_value(hidden_states, adapter_data)\n        query_states, key_states, value_states = qkv.split(\n            [\n                self.head_dim * self.num_heads,\n                self.head_dim * self.num_key_value_heads,\n                self.head_dim * self.num_key_value_heads,\n            ],\n            dim=-1,\n        )\n\n        query_states = query_states.view(hidden_shape)\n        key_states = key_states.view(hidden_shape)\n        value_states = value_states.view(hidden_shape)\n\n        if self.use_rope:  # the 16E model skips rope for long context on certain layers\n            query_states, key_states = apply_rotary_emb(\n                query_states, key_states, freqs_ci\n            )\n\n        if hasattr(self, \"qk_norm\"):  # the 128E model does not use qk_norm\n            query_states = self.qk_norm(query_states)\n            key_states = self.qk_norm(key_states)\n\n        kv_cache.store(\n            key=key_states,\n            value=value_states,\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Apply temperature tuning (https://arxiv.org/abs/2501.19399) to NoRoPE layers\n        if self.attn_temperature_tuning and not self.use_rope:\n            attn_scales = (\n                torch.log(\n                    torch.floor((position_ids.float() + 1.0) / self.floor_scale) + 1.0\n                )\n                * 
self.attn_scale\n                + 1.0\n            )\n            attn_scales = attn_scales.view(*input_shape, 1, 1)\n            query_states = (query_states * attn_scales).to(query_states.dtype)\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # sdpa\n            query = query_states.view(bs, -1, self.num_heads, self.head_dim).transpose(\n                1, 2\n            )\n            key = key_states.view(\n                bs, -1, self.num_key_value_heads, self.head_dim\n            ).transpose(1, 2)\n            value = value_states.view(\n                bs, -1, self.num_key_value_heads, self.head_dim\n            ).transpose(1, 2)\n            key = repeat_kv(key, self.num_key_value_groups)\n            value = repeat_kv(value, self.num_key_value_groups)\n\n            causal_mask = attention_mask\n            if attention_mask is not None and causal_mask.ndim == 4:\n                causal_mask = causal_mask[:, :, :, : key.shape[-2]]\n            is_causal = query.shape[2] > 1 and causal_mask is None\n            # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions\n            # Reference: https://github.com/pytorch/pytorch/issues/112577.\n            query = query.contiguous()\n            key = key.contiguous()\n            value = value.contiguous()\n\n            attn_output = torch.nn.functional.scaled_dot_product_attention(\n                query,\n                key,\n                value,\n                attn_mask=causal_mask,\n                dropout_p=0,\n                scale=self.scaling,\n                is_causal=is_causal,\n            )\n            attn_output = attn_output.transpose(1, 2).contiguous()\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query_states,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                seqlen,\n         
       kv_scales=self.kv_scales,\n                hpu_attention_meta=hpu_attention_meta,\n            )\n\n        attn_output = attn_output.reshape(*input_shape, -1).contiguous()\n        attn_output = self.o_proj(attn_output, adapter_data)\n        return attn_output\n\n\nclass Llama4TextDecoderLayer(nn.Module):\n    def __init__(self, prefix, config, weights, layer_idx):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n        self.self_attn = Llama4TextAttention(\n            f\"{prefix}.self_attn\", config, weights, layer_idx\n        )\n        self.use_chunked_attention = int((layer_idx + 1) % 4 != 0)  # <=> use rope\n        self.is_moe_layer = layer_idx in config.moe_layers\n        if self.is_moe_layer:  # the 128E model interleaves dense / sparse\n            self.feed_forward = Llama4TextMoe(f\"{prefix}.feed_forward\", config, weights)\n        else:\n            self.feed_forward = Llama4TextMLP(f\"{prefix}.feed_forward\", config, weights)\n\n        self.input_layernorm = FastRMSNorm.load(\n            prefix=f\"{prefix}.input_layernorm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n        self.post_attention_layernorm = FastRMSNorm.load(\n            prefix=f\"{prefix}.post_attention_layernorm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        freqs_ci,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        adapter_data,\n        attention_mask: Optional[torch.Tensor] = None,\n        chunk_causal_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,\n    ) -> Tuple[\n        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]\n    ]:\n        residual = hidden_states\n        hidden_states, _ = 
self.input_layernorm(hidden_states)\n\n        # use local attention mask for ROPE layers\n        if self.use_chunked_attention and chunk_causal_mask is not None:\n            attention_mask = chunk_causal_mask\n\n        attention_states = self.self_attn(\n            hidden_states,\n            freqs_ci,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            adapter_data,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            hpu_attention_meta=hpu_attention_meta,\n        )\n\n        hidden_states = residual + attention_states\n\n        # Fully Connected\n        residual = hidden_states\n\n        hidden_states, _ = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.feed_forward(hidden_states, adapter_data)\n        hidden_states = residual + hidden_states.view(residual.shape)\n        return hidden_states\n\n\nclass Llama4TextModel(nn.Module):\n\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=f\"{prefix}.embed_tokens\", weights=weights\n        )\n        self.layers = nn.ModuleList(\n            [\n                Llama4TextDecoderLayer(\n                    prefix=f\"{prefix}.layers.{layer_idx}\",\n                    config=config,\n                    weights=weights,\n                    layer_idx=layer_idx,\n                )\n                for layer_idx in range(config.num_hidden_layers)\n            ]\n        )\n\n        # self.norm = Llama4TextRMSNorm(prefix=f\"{prefix}.norm\", config=config, weights=weights)\n        self.norm = FastRMSNorm.load(\n            prefix=f\"{prefix}.norm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n\n        self.rotary_emb = 
Llama4TextRotaryEmbedding(config=config)\n        self.gradient_checkpointing = False\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        adapter_data,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n        attention_mask: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        if hpu_attention_meta is not None:\n            hpu_attention_meta = set_block_mapping(\n                hpu_attention_meta, inputs_embeds.shape[0]\n            )\n\n        hidden_states = inputs_embeds\n        bs = seqlen.input_lengths.shape[0]\n        seq_len = inputs_embeds.shape[0] / bs\n        cache_position = torch.arange(0, seq_len, device=inputs_embeds.device)\n\n        if position_ids is None:\n            position_ids = cache_position.unsqueeze(0)\n\n        causal_mask, chunk_causal_mask = self._update_causal_mask(\n            attention_mask,\n            inputs_embeds.view(bs, int(seq_len), -1),\n            cache_position,\n            None,\n            output_attentions=False,\n            use_cache=False,\n        )\n\n        freqs_ci = self.rotary_emb(hidden_states, position_ids.view(bs, -1))\n        lazy_mode = htorch.utils.internal.is_lazy()\n        if lazy_mode:\n            htorch.core.mark_step()\n\n        for i, layer in enumerate(self.layers):\n            hidden_states = layer(\n                hidden_states,\n                freqs_ci,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                slots,\n                seqlen,\n                adapter_data,\n                attention_mask=causal_mask,\n                chunk_causal_mask=chunk_causal_mask,\n                position_ids=position_ids,\n                hpu_attention_meta=hpu_attention_meta,\n            )\n           
 if lazy_mode:\n                htorch.core.mark_step()\n\n        hidden_states, _ = self.norm(hidden_states)\n\n        return hidden_states\n\n    def _update_causal_mask(\n        self,\n        attention_mask: torch.Tensor,\n        input_tensor: torch.Tensor,\n        cache_position: torch.Tensor,\n        past_key_values: Cache,\n        output_attentions: bool = False,\n        chunked_attention_mask=None,\n        use_cache=True,\n    ):\n        if self.config._attn_implementation == \"flash_attention_2\":\n            if attention_mask is not None and (attention_mask == 0.0).any():\n                return (\n                    attention_mask,\n                    attention_mask,\n                )  # flash does not support chunked attn TODO support flash\n            return None, None\n\n        if self.config._attn_implementation not in [\"sdpa\", \"flex_attention\", \"eager\"]:\n            return None, None\n\n        sequence_length = input_tensor.shape[1]\n        attention_chunk_size = self.config.attention_chunk_size\n\n        first_cache_position = cache_position[0]\n\n        if past_key_values is not None:\n            full_cache_length = past_key_values.get_max_cache_shape() or sequence_length\n        else:\n            full_cache_length = (\n                attention_mask.shape[-1]\n                if attention_mask is not None\n                else sequence_length\n            )\n\n        cond1 = first_cache_position >= attention_chunk_size\n        cond2 = (first_cache_position < attention_chunk_size) & (\n            first_cache_position + sequence_length > attention_chunk_size\n        )\n        key_length = (\n            torch.where(\n                cond1,\n                attention_chunk_size + sequence_length - 1,\n                torch.where(\n                    cond2, first_cache_position + sequence_length, attention_chunk_size\n                ),\n            )\n            if use_cache\n            else full_cache_length\n  
      )\n\n        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).\n        dtype, device = input_tensor.dtype, input_tensor.device\n        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(\n            attention_mask,\n            sequence_length=sequence_length,\n            target_length=max(full_cache_length, attention_chunk_size),\n            dtype=dtype,\n            cache_position=cache_position,\n            batch_size=input_tensor.shape[0],\n            device=device,\n        )\n        if full_cache_length > self.config.attention_chunk_size:\n            start_idx = max(first_cache_position - attention_chunk_size + 1, 0)\n            end_idx = start_idx + key_length\n            chunked_attention_mask = self.create_chunked_attention_mask(\n                self.config.attention_chunk_size,\n                start=start_idx,  # same offset as with flex\n                end=end_idx,\n                device=device,\n            )\n\n            local_attention_mask = attention_mask[\n                :, start_idx:end_idx\n            ]  # offset here as well\n            # It may be smaller than attention_chunk_size -> pad it\n            requires_padding = local_attention_mask.shape[-1] < attention_chunk_size\n            if requires_padding:\n                local_attention_mask = nn.functional.pad(\n                    local_attention_mask,\n                    (0, attention_chunk_size - local_attention_mask.shape[-1]),\n                )\n            # Depending on the padding, take the query tokens from the end or the cache_position\n            if not requires_padding:\n                chunked_attention_mask = chunked_attention_mask[\n                    None, None, -sequence_length:, :\n                ]\n            else:\n                chunked_attention_mask = chunked_attention_mask[\n                    None, None, cache_position, :\n                ]\n\n            
chunked_attention_mask = chunked_attention_mask.expand(\n                input_tensor.shape[0], -1, -1, -1\n            )\n            chunked_attention_mask = (\n                chunked_attention_mask * local_attention_mask[:, None, None, :]\n            )\n            if self.config._attn_implementation == \"eager\":\n                min_dtype = torch.finfo(dtype).min\n                chunked_attention_mask = torch.where(\n                    chunked_attention_mask == 0, min_dtype, 0.0\n                ).to(dtype)\n\n        if (\n            self.config._attn_implementation == \"sdpa\"\n            and attention_mask is not None\n            and attention_mask.device.type in [\"cuda\", \"xpu\", \"npu\"]\n            and attention_mask.ndim == 4\n            and not output_attentions  # Only unmask for 4d masks\n        ):\n            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when\n            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.\n            # Details: https://github.com/pytorch/pytorch/issues/110213\n            min_dtype = torch.finfo(dtype).min\n            causal_mask = AttentionMaskConverter._unmask_unattended(\n                causal_mask, min_dtype\n            )\n\n        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward\n        if (\n            self.config._attn_implementation == \"sdpa\"\n            and chunked_attention_mask is not None\n        ):\n            chunked_attention_mask = chunked_attention_mask.bool()\n            causal_mask = causal_mask.bool()\n            if AttentionMaskConverter._ignore_causal_mask_sdpa(\n                attention_mask,\n                inputs_embeds=input_tensor,\n                past_key_values_length=first_cache_position,\n                is_training=self.training,\n            ):\n                causal_mask = None\n     
   return causal_mask, chunked_attention_mask\n\n    def create_chunked_attention_mask(\n        self, attention_chunk_size: int, start: int, end: int, device: torch.device\n    ) -> torch.Tensor:\n        \"\"\"\n        Generate the following:\n\n        'What'      :  0 ■ ⬚ ⬚ ⬚ ⬚ ⬚    |\n        '▁is'       :  1 ■ ■ ⬚ ⬚ ⬚ ⬚     |\n        '▁ch'       :  2 ■ ■ ■ ⬚ ⬚ ⬚     |\n        'unked'     :  3 ⬚ ⬚ ⬚ ■ ⬚ ⬚    |\n        '▁attention':  4 ⬚ ⬚ ⬚ ■ ■ ⬚    |\n        '?'         :  5 ⬚ ⬚ ⬚ ■ ■ ■     |\n\n        If the chunk size is 3.\n        This can just be applied over the already created attention mask\n        \"\"\"\n        arange_vector = torch.arange(start, end, device=device)\n        block_pos = torch.abs(\n            arange_vector.unsqueeze(0) // attention_chunk_size\n            - arange_vector.unsqueeze(1) // attention_chunk_size\n        )\n        token_pos = arange_vector.unsqueeze(0) - arange_vector.unsqueeze(1)\n        mask = (block_pos == 0) & (token_pos <= 0)\n        return mask.to(device)\n\n    @staticmethod\n    def _prepare_4d_causal_attention_mask_with_cache_position(\n        attention_mask: torch.Tensor,\n        sequence_length: int,\n        target_length: int,\n        dtype: torch.dtype,\n        device: torch.device,\n        cache_position: torch.Tensor,\n        batch_size: int,\n        **kwargs,\n    ):\n        \"\"\"\n        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape\n        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.\n\n        Args:\n            attention_mask (`torch.Tensor`):\n                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape\n                `(batch_size, 1, query_length, key_value_length)`.\n            sequence_length (`int`):\n                The sequence length being processed.\n            target_length (`int`):\n                The 
target length: when generating with static cache, the mask should be as long as the static cache,\n                to account for the 0 padding, the part of the cache that is not filled yet.\n            dtype (`torch.dtype`):\n                The dtype to use for the 4D attention mask.\n            device (`torch.device`):\n                The device to place the 4D attention mask on.\n            cache_position (`torch.Tensor`):\n                Indices depicting the position of the input sequence tokens in the sequence.\n            batch_size (`int`):\n                Batch size.\n        \"\"\"\n        if attention_mask is not None and attention_mask.dim() == 4:\n            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.\n            causal_mask = attention_mask\n        else:\n            min_dtype = torch.finfo(dtype).min\n            causal_mask = torch.full(\n                (sequence_length, target_length),\n                fill_value=min_dtype,\n                dtype=dtype,\n                device=device,\n            )\n            if sequence_length != 1:\n                causal_mask = torch.triu(causal_mask, diagonal=1)\n            causal_mask *= torch.arange(\n                target_length, device=device\n            ) > cache_position.to(device).reshape(-1, 1)\n            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)\n            if attention_mask is not None:\n                causal_mask = (\n                    causal_mask.clone()\n                )  # copy to contiguous memory for in-place edit\n                mask_length = attention_mask.shape[-1]\n                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[\n                    :, None, None, :\n                ].to(device)\n                padding_mask = padding_mask == 0\n                causal_mask[:, :, :, :mask_length] = causal_mask[\n                    :, :, :, 
:mask_length\n                ].masked_fill(padding_mask, min_dtype)\n\n        return causal_mask\n\n\nclass Llama4ForCausalLM(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.model = Llama4TextModel(\n            prefix=f\"{prefix}.model\", config=config, weights=weights\n        )\n        self.vocab_size = config.vocab_size\n        self.lm_head = SpeculativeHead.load(\n            config,\n            f\"{prefix}.lm_head\",\n            weights,\n        )\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n        adapter_data: Optional[torch.Tensor] = None,\n        lm_head_indices: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        hidden_states = self.model(\n            inputs_embeds,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            adapter_data=adapter_data,\n            hpu_attention_meta=hpu_attention_meta,\n            attention_mask=attention_mask,\n        )\n\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n\n        logits, speculative_logits = self.lm_head(hidden_states)\n        return logits, speculative_logits\n\n\nclass Llama4VisionMLP2(torch.nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n        self.intermediate_size = config.intermediate_size\n        self.fc1 = 
TensorParallelColumnLinear.load(\n            config=config, prefix=f\"{prefix}.fc1\", weights=weights, bias=False\n        )\n        self.fc2 = TensorParallelRowLinear.load(\n            config=config, prefix=f\"{prefix}.fc2\", weights=weights, bias=False\n        )\n        self.activation_fn = nn.GELU()  # ACT2FN[config.hidden_act]\n        self.dropout = config.projector_dropout\n\n    def forward(self, hidden_states):\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = self.activation_fn(hidden_states)\n        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)\n        hidden_states = self.fc2(hidden_states)\n        return self.activation_fn(\n            hidden_states\n        )  # TODO: check if we need to apply activation again\n\n\nclass Llama4MultiModalProjector(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.linear_1 = FastLinear.load(\n            config=config, prefix=f\"{prefix}.linear_1\", weights=weights, bias=False\n        )\n\n    def forward(self, image_features):\n        hidden_states = self.linear_1(image_features)\n        return hidden_states\n\n\ndef pixel_shuffle(input_tensor, shuffle_ratio):\n    # input_tensor: [batch_size, num_patches, channels]\n    batch_size, num_patches, channels = input_tensor.shape\n    patch_size = int(math.sqrt(num_patches))\n\n    input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1)\n    batch_size, height, width, channels = input_tensor.size()\n    reshaped_tensor = input_tensor.view(\n        batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio)\n    )\n    reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous()\n    reshaped_tensor = reshaped_tensor.view(\n        batch_size,\n        int(height * shuffle_ratio),\n        int(width * shuffle_ratio),\n        int(channels / (shuffle_ratio**2)),\n    )\n    reshaped_tensor = 
reshaped_tensor.permute(0, 2, 1, 3).contiguous()\n\n    output_tensor = reshaped_tensor.view(batch_size, -1, reshaped_tensor.shape[-1])\n    return output_tensor\n\n\nclass Llama4VisionPixelShuffleMLP(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.pixel_shuffle_ratio = config.pixel_shuffle_ratio\n        self.inner_dim = int(\n            config.projector_input_dim // (self.pixel_shuffle_ratio**2)\n        )\n        self.output_dim = config.projector_output_dim\n        self.mlp = Llama4VisionMLP2(\n            prefix=f\"{prefix}.mlp\", config=config, weights=weights\n        )\n\n    def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:\n        encoded_patches = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio)\n        return self.mlp(encoded_patches)\n\n\n# TODO there is a different RoPE for vision encoder, defined as below\ndef vision_reshape_for_broadcast(freqs_ci: torch.Tensor, query: torch.Tensor):\n    ndim = query.ndim\n    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(query.shape)]\n    return freqs_ci.view(*shape)\n\n\nclass Llama4VisionAttention(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.num_heads = config.num_attention_heads // weights.process_group.size()\n        self.progress_group = weights.process_group\n\n        self.head_dim = config.hidden_size // config.num_attention_heads\n        self.num_key_value_groups = 1\n        self.attention_dropout = config.attention_dropout\n        self.qkv_proj = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n            dim=0,\n            weights=weights,\n            bias=True,\n        )\n        self.o_proj = TensorParallelRowLinear.load(\n            config,\n            
prefix=f\"{prefix}.o_proj\",\n            weights=weights,\n            bias=True,\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        freqs_ci: torch.Tensor,  # Now takes (cos_theta, sin_theta) instead of complex\n        attention_mask: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        input_shape = hidden_states.shape[:-1]\n        hidden_shape = (*input_shape, -1, self.head_dim)\n\n        qkv = self.qkv_proj(hidden_states)\n\n        query_states, key_states, value_states = qkv.split(\n            [\n                self.head_dim * self.num_heads,\n                self.head_dim * self.num_heads,\n                self.head_dim * self.num_heads,\n            ],\n            dim=2,\n        )\n        query_states = query_states.view(hidden_shape)\n        key_states = key_states.view(hidden_shape)\n        value_states = value_states.view(hidden_shape)\n\n        query_states, key_states = apply_rotary_emb(\n            query_states, key_states, freqs_ci=freqs_ci\n        )\n\n        query_states = query_states.transpose(1, 2)\n        key_states = key_states.transpose(1, 2)\n        value_states = value_states.transpose(1, 2)\n\n        attn_output = F.scaled_dot_product_attention(\n            query_states,\n            key_states,\n            value_states,\n            attn_mask=attention_mask,\n            is_causal=False,\n            dropout_p=0,\n        )\n\n        attn_output = attn_output.transpose(1, 2)\n        attn_output = attn_output.reshape(*input_shape, -1).contiguous()\n        attn_output = self.o_proj(attn_output)\n        return attn_output\n\n\nclass Llama4VisionMLP(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        self.activation_fn = nn.GELU()  # ACT2FN[config.hidden_act]\n        self.fc1 = TensorParallelColumnLinear.load(\n            
prefix=f\"{prefix}.fc1\", weights=weights, config=config, bias=True\n        )\n        self.fc2 = TensorParallelRowLinear.load(\n            prefix=f\"{prefix}.fc2\", weights=weights, config=config, bias=True\n        )\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = self.activation_fn(hidden_states)\n        hidden_states = self.fc2(hidden_states)\n        return hidden_states\n\n\nclass Llama4VisionEncoderLayer(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n\n        self.self_attn = Llama4VisionAttention(\n            prefix=f\"{prefix}.self_attn\", config=config, weights=weights\n        )\n        self.mlp = Llama4VisionMLP(\n            prefix=f\"{prefix}.mlp\", config=config, weights=weights\n        )\n\n        self.input_layernorm = nn.LayerNorm.load(\n            prefix=f\"{prefix}.input_layernorm\", weights=weights, eps=1e-05\n        )\n        self.post_attention_layernorm = nn.LayerNorm.load(\n            prefix=f\"{prefix}.post_attention_layernorm\", weights=weights, eps=1e-05\n        )\n\n    def forward(\n        self,\n        hidden_state: torch.Tensor,\n        freqs_ci: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n    ):\n        # Self Attention\n        residual = hidden_state\n\n        hidden_state = self.input_layernorm(hidden_state)\n\n        hidden_state = self.self_attn(\n            hidden_state,\n            freqs_ci=freqs_ci,\n            attention_mask=attention_mask,\n        )\n\n        hidden_state = residual + hidden_state\n\n        # Feed forward\n        residual = hidden_state\n        hidden_state = self.post_attention_layernorm(hidden_state)\n        hidden_state = self.mlp(hidden_state)\n        hidden_state = residual + hidden_state\n        outputs = (hidden_state,)\n        return outputs\n\n\nclass 
Llama4VisionEncoder(nn.Module):\n    \"\"\"\n    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a\n    [`Llama4VisionEncoderLayer`].\n\n    Args:\n        config: Llama4VisionConfig\n    \"\"\"\n\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        self.layers = nn.ModuleList(\n            [\n                Llama4VisionEncoderLayer(\n                    prefix=f\"{prefix}.layers.{layer_id}\", config=config, weights=weights\n                )\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n        self.gradient_checkpointing = False\n        self.config = config\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        freqs_ci: torch.Tensor,  # TODO move this to an attribute instead of keeping it around\n        attention_mask: Optional[torch.Tensor] = None,\n    ) -> Union[Tuple, BaseModelOutput]:\n\n        for encoder_layer in self.layers:\n            layer_outputs = encoder_layer(\n                hidden_state=hidden_states,\n                attention_mask=attention_mask,\n                freqs_ci=freqs_ci,\n            )\n\n            hidden_states = layer_outputs[0]\n\n        return hidden_states\n\n\nclass Llama4UnfoldConvolution(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        kernel_size = config.patch_size\n        if isinstance(kernel_size, int):\n            kernel_size = (kernel_size, kernel_size)\n        self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size)\n        self.linear = FastLinear.load(\n            config=config, prefix=f\"{prefix}.linear\", weights=weights, bias=False\n        )\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        hidden_states = self.unfold(hidden_states)\n        hidden_states = hidden_states.permute(0, 2, 1)\n        hidden_states = 
self.linear(hidden_states)\n        return hidden_states\n\n\nclass Llama4VisionRotaryEmbedding(nn.Module):\n    def __init__(self, config, weights):\n        super().__init__()\n        # Calculate image grid indices\n        idx = config.image_size // config.patch_size\n        img_idx = torch.arange(\n            idx**2, dtype=torch.int32, device=weights.device\n        ).reshape(idx**2, 1)\n        img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)\n\n        img_idx[-1, -1] = -2  # ID_CLS_TOKEN\n        # Calculate x and y coordinates\n        frequencies_x = img_idx % idx  # x coordinates\n        frequencies_y = torch.div(img_idx, idx, rounding_mode=\"floor\")  # y coordinates\n        # Calculate frequency components\n        freq_dim = config.hidden_size // config.num_attention_heads // 2\n        rope_freq = 1.0 / (\n            config.rope_theta\n            ** (\n                torch.arange(0, freq_dim, 2, device=weights.device)[\n                    : (freq_dim // 2)\n                ].float()\n                / freq_dim\n            )\n        )\n\n        # Compute frequencies for x and y directions\n        freqs_x = (frequencies_x + 1)[..., None] * rope_freq[None, None, :]\n        freqs_x = freqs_x.repeat_interleave(2, dim=-1)\n        freqs_y = (frequencies_y + 1)[..., None] * rope_freq[None, None, :]\n        freqs_y = freqs_y.repeat_interleave(2, dim=-1)\n\n        # Combine frequencies and mask special tokens\n        freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]\n        freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)\n\n        freq_cis = torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)\n        self.freqs_ci = freq_cis  # idx**2, idx**2, idx * 2\n\n    def forward(self, hidden_states):\n        \"\"\"\n        Returns the rotary embedding components (cosθ, sinθ) for the given hidden states\n        \"\"\"\n        return self.freqs_ci.to(dtype=hidden_states.dtype, 
device=hidden_states.device)\n\n\nclass Llama4VisionModel(nn.Module):\n\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        self.image_size = config.image_size\n        self.patch_size = config.patch_size\n        self.hidden_size = config.hidden_size\n        self.num_channels = config.num_channels\n\n        self.num_patches = (self.image_size // self.patch_size) ** 2 + 1\n        self.scale = config.hidden_size**-0.5\n\n        self.patch_embedding = Llama4UnfoldConvolution(\n            prefix=f\"{prefix}.patch_embedding\", config=config, weights=weights\n        )\n\n        self.class_embedding = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.class_embedding\"), requires_grad=False\n        )\n\n        self.positional_embedding_vlm = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.positional_embedding_vlm\"),\n            requires_grad=False,\n        )\n\n        self.rotary_embedding = Llama4VisionRotaryEmbedding(config, weights)\n\n        # layer norms\n        self.layernorm_pre = nn.LayerNorm.load(\n            prefix=f\"{prefix}.layernorm_pre\", weights=weights, eps=config.norm_eps\n        )\n        self.layernorm_post = nn.LayerNorm.load(\n            prefix=f\"{prefix}.layernorm_post\", weights=weights, eps=config.norm_eps\n        )\n\n        # encoders\n        self.model = Llama4VisionEncoder(\n            prefix=f\"{prefix}.model\", config=config, weights=weights\n        )\n        self.vision_adapter = Llama4VisionPixelShuffleMLP(\n            prefix=f\"{prefix}.vision_adapter\", config=config, weights=weights\n        )\n\n    def forward(\n        self,\n        pixel_values: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        return_dict: Optional[bool] = None,\n    ):\n        # num_concurrent_media and num_chunks are both currently 1\n        batch_size_times_num_tiles, num_channels, height, width = pixel_values.shape\n      
  num_concurrent_media = 1\n        num_chunks = 1\n        hidden_state = self.patch_embedding(pixel_values)\n        _, num_patches, hidden_dim = hidden_state.shape\n\n        # Add cls token\n        hidden_state = hidden_state.reshape(\n            batch_size_times_num_tiles * num_concurrent_media * num_chunks,\n            num_patches,\n            hidden_dim,\n        )\n        class_embedding = self.class_embedding.expand(\n            hidden_state.shape[0], 1, hidden_state.shape[-1]\n        )\n        hidden_state = torch.cat([hidden_state, class_embedding], dim=1)\n        num_patches += 1\n\n        # Position embeddings\n        hidden_state = hidden_state.reshape(\n            batch_size_times_num_tiles * num_concurrent_media,\n            num_chunks,\n            num_patches,\n            hidden_dim,\n        )\n        positional_embedding = self.positional_embedding_vlm.to(\n            dtype=hidden_state.dtype, device=hidden_state.device\n        )\n        hidden_state = hidden_state + positional_embedding\n        hidden_state = self.layernorm_pre(hidden_state)\n        hidden_state = hidden_state.view(batch_size_times_num_tiles, -1, hidden_dim)\n        freqs_ci = self.rotary_embedding(pixel_values)\n\n        hidden_state = self.model(\n            hidden_state,\n            attention_mask=None,\n            freqs_ci=freqs_ci,\n        )\n\n        hidden_state = self.layernorm_post(hidden_state)\n\n        hidden_state = hidden_state[:, :-1, :]\n\n        # now, we use Llama4VisionPixelShuffle + mlp to project embeddings\n        hidden_state = self.vision_adapter(hidden_state)\n        return hidden_state\n\n\nclass Llama4ForConditionalGeneration(nn.Module):\n\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n        self.config = config\n        config.vision_config.quantize = None\n        config.vision_config.speculator = config.speculator\n        config.text_config.quantize = config.quantize\n        
config.text_config.speculator = config.speculator\n        config.text_config._attn_implementation = None\n\n        self.vision_model = Llama4VisionModel(\n            prefix=\"vision_model\", config=config.vision_config, weights=weights\n        )\n\n        self.multi_modal_projector = Llama4MultiModalProjector(\n            prefix=\"multi_modal_projector\", config=config, weights=weights\n        )\n\n        self.text_model = Llama4ForCausalLM(\n            prefix=\"language_model\", config=config.text_config, weights=weights\n        )\n        self.vocab_size = config.text_config.vocab_size\n        self.pad_token_id = (\n            self.config.pad_token_id if self.config.pad_token_id is not None else -1\n        )\n        self.config = config\n        self.dtype = weights.dtype\n        self.device = weights.device\n\n    def get_image_features(\n        self,\n        pixel_values: torch.FloatTensor,\n        vision_feature_layer: Union[int, List[int]],\n        vision_feature_select_strategy: str,\n        **kwargs,\n    ):\n        \"\"\"\n        Obtains image last hidden states from the vision tower and applies its projection.\n\n        Args:\n            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):\n               The tensors corresponding to the input images.\n            vision_feature_layer (`Union[int, List[int]]`):\n                The index of the layer to select the vision feature. 
If multiple indices are provided,\n                the vision feature of the corresponding indices will be concatenated to form the\n                vision features.\n            vision_feature_select_strategy (`str`):\n                The feature selection strategy used to select the vision feature from the vision backbone.\n                Can be one of `\"default\"` or `\"full\"`.\n        Returns:\n            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.\n        \"\"\"\n        if vision_feature_select_strategy not in [\"default\", \"full\"]:\n            raise ValueError(\n                f\"Unexpected select feature strategy: {self.vision_feature_select_strategy}\"\n            )\n        kwargs = {k: v for k, v in kwargs.items() if v is not None}\n        hidden_state = self.vision_model(pixel_values)\n        return hidden_state\n\n    def get_vision_embeds(\n        self,\n        pixel_values: torch.FloatTensor,\n        pixel_attention_mask: Optional[torch.FloatTensor] = None,\n        image_sizes: Optional[torch.Tensor] = None,\n        image_grid_thw: Optional[torch.LongTensor] = None,\n    ):\n        image_features = self.get_image_features(\n            pixel_values=pixel_values,\n            vision_feature_layer=self.config.vision_config.vision_feature_layer,\n            vision_feature_select_strategy=self.config.vision_config.vision_feature_select_strategy,\n            image_sizes=image_sizes,\n        )\n        vision_flat = image_features.view(-1, image_features.size(-1))\n        image_features = self.multi_modal_projector(vision_flat)\n        return image_features\n\n    def get_inputs_embeds(\n        self,\n        input_ids: torch.Tensor,\n        vision_embeds: torch.Tensor = None,\n        pixel_values: torch.FloatTensor = None,\n        image_sizes: Optional[torch.LongTensor] = None,\n    ):\n        inputs_embeds = self.text_model.model.embed_tokens(input_ids)\n\n        if 
vision_embeds is not None:\n            # When we generate, we don't want to replace the potential image_token_id that we generated by images\n            # that simply don't exist\n            original_inputs_embeds_shape = inputs_embeds.shape\n            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(\n                -1\n            )\n            final_mask = special_image_mask.to(inputs_embeds.device)\n            inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1))\n\n            final_mask_1d = final_mask[..., 0].reshape(-1)\n            num_tokens_to_fill = final_mask_1d.sum()\n\n            if num_tokens_to_fill != vision_embeds.size(0):\n                raise ValueError(\n                    f\"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, \"\n                    f\"but multi_modal_projector returned {vision_embeds.size(0)}\"\n                )\n\n            expanded_mask = final_mask_1d.unsqueeze(-1).expand(\n                -1, inputs_embeds.size(-1)\n            )\n            inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, vision_embeds)\n            inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)\n        return inputs_embeds\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: Optional[torch.LongTensor] = None,\n        cu_seqlen_prefill: Optional[torch.Tensor] = None,\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = None,\n        slots: torch.Tensor = None,\n        seqlen: Seqlen = None,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,\n        lm_head_indices: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n        **lm_kwargs,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n\n        logits, speculative_logits = self.text_model(\n            inputs_embeds,\n            position_ids,\n           
 cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            hpu_attention_meta,\n            adapter_data,\n            lm_head_indices,\n            attention_mask,\n        )\n\n        return logits, speculative_logits\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom contextlib import contextmanager\nfrom typing import List, Optional, Tuple, Type\n\nimport torch\nimport torch.distributed\n\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nimport habana_frameworks.torch as htorch\nfrom text_generation_server.layers.attention import (\n    KVCache,\n    get_kv_scales,\n)\nfrom text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer\nfrom text_generation_server.layers.attention import (\n    paged_attention,\n    attention,\n    set_block_mapping,\n    Seqlen,\n    HPUPagedAttentionMetadata,\n)\nfrom text_generation_server.layers import (\n    TensorParallelRowLinear,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    SpeculativeHead,\n    TensorParallelMultiAdapterLinear,\n    TensorParallelAdapterRowLinear,\n)\nfrom text_generation_server.layers.rotary import PositionRotaryEmbedding\nfrom text_generation_server.layers.layernorm import (\n    FastRMSNorm,\n    FastLayerNorm,\n)\nfrom text_generation_server.layers 
import (\n    FastLinear,\n)\nfrom text_generation_server.utils.weights import (\n    Weights,\n)\nfrom text_generation_server.layers.fp8 import HybridFP8UnquantLoader\n\n\ndef load_attention(config, prefix: str, weights, layer_id):\n    # Only defined in granite.\n    bias = getattr(config, \"attention_bias\", False)\n    head_size = config.hidden_size // config.num_attention_heads\n    sizes = None\n    prefixes = None\n\n    if config.model_type == \"phi3\":\n        base_layer = TensorParallelColumnLinear.load_qkv(\n            config,\n            prefix=f\"{prefix}.qkv_proj\",\n            weights=weights,\n            bias=bias,\n            num_heads=config.num_attention_heads,\n            num_key_value_heads=config.num_key_value_heads,\n        )\n        prefixes = [\"qkv_proj\"]\n    elif config.model_type == \"baichuan\":\n        prefix = f\"{prefix}.W_pack\"\n        base_layer = TensorParallelColumnLinear.load_qkv(\n            config,\n            prefix=prefix,\n            weights=weights,\n            bias=bias,\n            num_heads=config.num_attention_heads,\n            num_key_value_heads=config.num_key_value_heads,\n        )\n        prefixes = [prefix]\n    else:\n        prefixes = [\"q_proj\", \"k_proj\", \"v_proj\"]\n        sizes = [\n            head_size * config.num_attention_heads,\n            head_size * config.num_key_value_heads,\n            head_size * config.num_key_value_heads,\n        ]\n        base_layer = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n            dim=0,\n            weights=weights,\n            bias=bias,\n        )\n\n    return TensorParallelMultiAdapterLinear.load(\n        base_layer=base_layer,\n        layer_id=layer_id,\n        layer_names=prefixes,\n        sizes=sizes,\n        process_group=weights.process_group,\n    )\n\n\n@contextmanager\ndef no_fp8(weights: Weights):\n    
\"\"\"De-activate fp8 auto conversion for the duration of this context manager\"\"\"\n    weights_loader = weights.weights_loader\n    if isinstance(weights_loader, HybridFP8UnquantLoader) and weights_loader.to_fp8:\n        weights_loader = HybridFP8UnquantLoader(\n            weights_loader.activation_scale_ub, to_fp8=False\n        )\n\n    with weights.use_loader(weights_loader):\n        yield\n\n\nclass FlashLlamaAttention(torch.nn.Module):\n    def __init__(\n        self,\n        index: int,\n        prefix: str,\n        config,\n        weights,\n        rotary_emb,\n    ):\n        super().__init__()\n        self.num_heads = config.num_attention_heads\n        self.hidden_size = config.hidden_size\n        self.head_size = self.hidden_size // self.num_heads\n\n        self.rotary_emb = rotary_emb\n\n        # `config.attention_multiplier` is used in Granite\n        self.softmax_scale = getattr(\n            config, \"attention_multiplier\", self.head_size**-0.5\n        )\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        if config.num_key_value_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_key_value_heads` must be divisible by `num_shards` (got `num_key_value_heads`: {config.num_key_value_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.num_key_value_heads = (\n            config.num_key_value_heads // weights.process_group.size()\n        )\n\n        self.query_key_value = load_attention(config, prefix, weights, index)\n        self.index = index\n\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n\n        
o_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.o_proj\",\n            weights=weights,\n            bias=getattr(config, \"attention_bias\", False),\n        )\n\n        self.o_proj = TensorParallelAdapterRowLinear.load(\n            o_proj,\n            index,\n            \"o_proj\",\n            process_group=weights.process_group,\n        )\n\n        self.num_groups = self.num_heads // self.num_key_value_heads\n        self.kv_head_mapping = torch.arange(\n            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device\n        ).repeat_interleave(self.num_groups)\n\n    def forward(\n        self,\n        hidden_states,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache: KVCache,\n        slots,\n        seqlen,\n        adapter_data,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n    ):\n        qkv = self.query_key_value(hidden_states, adapter_data)\n        query, kv = qkv.split(\n            [\n                self.head_size * self.num_heads,\n                2 * self.head_size * self.num_key_value_heads,\n            ],\n            dim=1,\n        )\n        query = query.view(-1, self.num_heads, self.head_size)\n        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)\n\n        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)\n\n        kv_cache.store(\n            key=kv[:, 0],\n            value=kv[:, 1],\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # sdpa\n            attn_output = attention(\n                query=query,\n                key=kv[:, 0],\n                value=kv[:, 1],\n                kv_scales=self.kv_scales,\n                kv_cache=kv_cache,\n                seqlen=seqlen,\n                softmax_scale=self.softmax_scale,\n            )\n        # Decode\n        else:\n            
attn_output = paged_attention(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                seqlen,\n                kv_scales=self.kv_scales,\n                hpu_attention_meta=hpu_attention_meta,\n            )\n\n        return self.o_proj(\n            attn_output.view(-1, self.num_heads * self.head_size), adapter_data\n        )\n\n\nclass Phi3MoE(nn.Module):\n    def __init__(\n        self, prefix: str, config, moe_layer_cls: Type[MoELayer], weights: Weights\n    ):\n        super().__init__()\n\n        # gating\n        self.gate = FastLinear.load(config, f\"{prefix}.gate\", weights, bias=False)\n\n        self.moe = moe_layer_cls(\n            prefix=f\"{prefix}.experts\",\n            n_experts=config.num_local_experts,\n            n_expert_group=None,\n            renormalize=True,\n            topk=config.num_experts_per_tok,\n            topk_group=None,\n            weights=weights,\n            gate_proj_name=\"w1\",\n            up_proj_name=\"w3\",\n            down_proj_name=\"w2\",\n        )\n\n        self.process_group = weights.process_group\n\n    def forward(self, x, adapter_data) -> torch.Tensor:\n        # router_logits: (num_tokens, n_experts)\n        router_logits = self.gate(x)\n        out = self.moe(x, gating_output=router_logits)\n\n        # Reduce sum\n        if self.process_group.size() > 1:\n            torch.distributed.all_reduce(out, group=self.process_group)\n\n        return out.view(*x.shape)\n\n\nclass LlamaMLP(nn.Module):\n    def __init__(self, prefix, config, weights, index):\n        super().__init__()\n        self.hidden_act = config.hidden_act\n        self.act = (\n            ACT2FN[self.hidden_act]\n            if \"gelu\" not in self.hidden_act\n            else lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\"\n                    if self.hidden_act in 
[\"gelu_fast\", \"gelu_pytorch_tanh\"]\n                    else \"none\"\n                ),\n            )\n        )\n        prefixes = None\n        sizes = None\n\n        # Fuse gate and up proj\n        bias = getattr(config, \"mlp_bias\", False)\n        if config.model_type == \"phi3\":\n            gate_up_proj = TensorParallelColumnLinear.load_gate_up(\n                config,\n                prefix=f\"{prefix}.gate_up_proj\",\n                weights=weights,\n                bias=bias,\n            )\n        else:\n            prefixes = [\"gate_proj\", \"up_proj\"]\n            sizes = [\n                config.intermediate_size,\n                config.intermediate_size,\n            ]\n            gate_up_proj = TensorParallelColumnLinear.load_multi(\n                config,\n                prefixes=[f\"{prefix}.gate_proj\", f\"{prefix}.up_proj\"],\n                weights=weights,\n                dim=0,\n                bias=bias,\n            )\n\n        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(\n            gate_up_proj,\n            index,\n            layer_names=prefixes,\n            sizes=sizes,\n            process_group=weights.process_group,\n        )\n\n        down_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.down_proj\",\n            weights=weights,\n            bias=bias,\n        )\n\n        self.down_proj = TensorParallelAdapterRowLinear.load(\n            down_proj,\n            index,\n            \"down_proj\",\n            process_group=weights.process_group,\n        )\n\n        self.intermediate_size = (\n            config.intermediate_size // weights.process_group.size()\n        )\n\n        # TODO: This is a hotfix to be removed & properly refactored.\n        self.quantize = config.quantize\n\n        self.hidden_size = config.hidden_size\n\n    def forward(self, hidden_states, adapter_data):\n        gate_up_states = self.gate_up_proj(hidden_states, 
adapter_data)\n        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)\n        return self.down_proj(\n            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data\n        )\n\n\nclass FlashLlamaLayer(nn.Module):\n    def __init__(self, index, prefix, config, weights, rotary_emb):\n        super().__init__()\n\n        with no_fp8(weights):\n            self.self_attn = FlashLlamaAttention(\n                index=index,\n                prefix=f\"{prefix}.self_attn\",\n                config=config,\n                weights=weights,\n                rotary_emb=rotary_emb,\n            )\n\n        if config.model_type == \"phimoe\":\n            moe_layer_cls = (\n                SparseMoELayer\n                if SparseMoELayer.is_supported(weights)\n                else DenseMoELayer\n            )\n            self.mlp = Phi3MoE(\n                f\"{prefix}.block_sparse_moe\", config, moe_layer_cls, weights\n            )\n            # with moe the layernorms are are not rmsnorms and they have bias\n            self.input_layernorm = FastLayerNorm.load(\n                prefix=f\"{prefix}.input_layernorm\",\n                weights=weights,\n                eps=config.rms_norm_eps,\n            )\n            self.post_attention_layernorm = FastLayerNorm.load(\n                prefix=f\"{prefix}.post_attention_layernorm\",\n                weights=weights,\n                eps=config.rms_norm_eps,\n            )\n        else:\n            self.mlp = LlamaMLP(\n                prefix=f\"{prefix}.mlp\", config=config, weights=weights, index=index\n            )\n            self.input_layernorm = FastRMSNorm.load(\n                prefix=f\"{prefix}.input_layernorm\",\n                weights=weights,\n                eps=config.rms_norm_eps,\n            )\n            self.post_attention_layernorm = FastRMSNorm.load(\n                prefix=f\"{prefix}.post_attention_layernorm\",\n                weights=weights,\n     
           eps=config.rms_norm_eps,\n            )\n\n        # Used in Granite\n        # This could eventually be baked into the weights like we do for the embeddings/lm_head\n        # but this would mean modifying the lora code\n        self.residual_multiplier = getattr(config, \"residual_multiplier\", None)\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        adapter_data,\n        cross_attention_states,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n    ):\n        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        attn_output = self.self_attn(\n            normed_hidden_states,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            adapter_data,\n            hpu_attention_meta=hpu_attention_meta,\n        )\n        if self.residual_multiplier is not None:\n            attn_output *= self.residual_multiplier\n\n        normed_attn_res_output, attn_res = self.post_attention_layernorm(\n            attn_output, res\n        )\n\n        mlp_output = self.mlp(normed_attn_res_output, adapter_data)\n        if self.residual_multiplier is not None:\n            mlp_output *= self.residual_multiplier\n\n        return mlp_output, attn_res\n\n\nclass FlashLlamaModel(torch.nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n\n        process_group = weights.process_group\n        self.tp_rank = process_group.rank()\n        self.tp_world_size = process_group.size()\n\n        # Skip fp8 quant for first and last layers\n        self.layers = nn.ModuleList()\n        self.cross_attention_layers = getattr(config, \"cross_attention_layers\", [])\n        # Setting defaults for baichuan custom config which doesn't apply 
them.\n        config.rope_theta = getattr(config, \"rope_theta\", 10000)\n        config.num_key_value_heads = getattr(\n            config, \"num_key_value_heads\", config.num_attention_heads\n        )\n        rotary_emb = PositionRotaryEmbedding.static(\n            config=config,\n            dim=config.hidden_size // config.num_attention_heads,\n            base=config.rope_theta,\n            device=weights.device,\n        )\n        with no_fp8(weights):\n            self.layers.append(\n                FlashLlamaLayer(\n                    index=0,\n                    prefix=f\"{prefix}.layers.0\",\n                    config=config,\n                    weights=weights,\n                    rotary_emb=rotary_emb,\n                )\n            )\n\n        # Skip first and last layers\n        for layer_id in range(1, config.num_hidden_layers - 1):\n            if layer_id in self.cross_attention_layers:\n                from text_generation_server.models.custom_modeling.flash_mllama import (\n                    FlashLlamaCrossLayer,\n                )\n\n                self.layers.append(\n                    FlashLlamaCrossLayer(\n                        index=layer_id,\n                        prefix=(f\"{prefix}.layers.{layer_id}\"),\n                        config=config,\n                        weights=weights,\n                    )\n                )\n            else:\n                self.layers.append(\n                    FlashLlamaLayer(\n                        index=layer_id,\n                        prefix=(f\"{prefix}.layers.{layer_id}\"),\n                        config=config,\n                        weights=weights,\n                        rotary_emb=rotary_emb,\n                    )\n                )\n\n        with no_fp8(weights):\n            last_layer_id = config.num_hidden_layers - 1\n            self.layers.append(\n                FlashLlamaLayer(\n                    index=last_layer_id,\n                    
prefix=(f\"{prefix}.layers.{last_layer_id}\"),\n                    config=config,\n                    weights=weights,\n                    rotary_emb=rotary_emb,\n                )\n            )\n\n        self.norm = FastRMSNorm.load(\n            prefix=f\"{prefix}.norm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n\n        self.gradient_checkpointing = False\n\n        self.head_size = self.layers[0].self_attn.head_size\n        self.num_heads = self.layers[0].self_attn.num_heads\n        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        adapter_data,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n        cross_attention_states=None,\n    ) -> torch.Tensor:\n        if hpu_attention_meta is not None:\n            hpu_attention_meta = set_block_mapping(\n                hpu_attention_meta, inputs_embeds.shape[0]\n            )\n\n        hidden_states = inputs_embeds\n\n        # Get rotary cos and sin for this forward\n        # Avoid to index in each layer\n        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)\n\n        residual = None\n        lazy_mode = htorch.utils.internal.is_lazy()\n        if lazy_mode:\n            htorch.core.mark_step()\n        for i, layer in enumerate(self.layers):\n            hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                slots,\n                seqlen,\n                adapter_data,\n                cross_attention_states,\n                
hpu_attention_meta=hpu_attention_meta,\n            )\n            if lazy_mode:\n                htorch.core.mark_step()\n\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n\nclass FlashLlamaForCausalLM(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights, name=None):\n        if name is None:\n            name = \"model\"\n        super().__init__()\n        with no_fp8(weights):\n            self.embed_tokens = TensorParallelEmbedding(\n                prefix=(\n                    f\"{name}.embed_tokens\"\n                    if not prefix\n                    else f\"{prefix}.{name}.embed_tokens\"\n                ),\n                weights=weights,\n            )\n        self.model = FlashLlamaModel(\n            prefix=name if not prefix else f\"{prefix}.{name}\",\n            config=config,\n            weights=weights,\n        )\n        if config.tie_word_embeddings:\n            suffix = \"model.embed_tokens\"\n        else:\n            suffix = \"lm_head\"\n\n        # Used in Granite\n        embedding_multiplier = getattr(config, \"embedding_multiplier\", None)\n        if embedding_multiplier is not None:\n            self.embed_tokens.weight.data *= embedding_multiplier\n        prefix = suffix if not prefix or name != \"model\" else f\"{prefix}.{suffix}\"\n        with no_fp8(weights):\n            self.lm_head = SpeculativeHead.load(\n                config,\n                prefix,\n                weights,\n            )\n\n        # Used in Granite\n        self.logits_scaling = getattr(config, \"logits_scaling\", None)\n        if self.logits_scaling is not None and self.lm_head.head is not None:\n            try:\n                # Scale the weights directly\n                self.lm_head.head.linear.weight.data /= self.logits_scaling\n                self.logits_scaled = True\n            except Exception:\n                self.logits_scaled = False\n\n    def forward(\n    
    self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n        cross_attention_states=None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        inputs_embeds = self.embed_tokens(input_ids)\n        hidden_states = self.model(\n            inputs_embeds,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            adapter_data=adapter_data,\n            cross_attention_states=cross_attention_states,\n            hpu_attention_meta=hpu_attention_meta,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.lm_head(hidden_states)\n\n        # Used in Granite\n        if self.logits_scaling is not None and not self.logits_scaled:\n            logits /= self.logits_scaling\n            if speculative_logits is not None:\n                speculative_logits /= self.logits_scaling\n\n        return logits, speculative_logits\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llava_next.py",
    "content": "# coding=utf-8\n# Copyright 2024 the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch Llava-NeXT model.\"\"\"\n\nfrom typing import List, Optional, Tuple\n\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\n\nfrom transformers.activations import ACT2FN\nfrom transformers.image_processing_utils import select_best_resolution\n\nfrom text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata\nfrom text_generation_server.models.custom_modeling.vlm import (\n    load_text_model,\n    load_vision_model,\n)\nfrom text_generation_server.layers import (\n    TensorParallelColumnLinear,\n    TensorParallelRowLinear,\n)\n\n\ndef get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):\n    \"\"\"\n    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.\n\n    Args:\n        image_size (`tuple`):\n            The size of the input image in the format (height, width).\n        grid_pinpoints (`List`):\n            A list containing possible resolutions. 
Each item in the list should be a tuple or list\n            of the form `(height, width)`.\n        patch_size (`int`):\n            The size of each image patch.\n\n    Returns:\n        tuple: The shape of the image patch grid in the format (height, width).\n    \"\"\"\n    if not isinstance(grid_pinpoints, list):\n        raise ValueError(\"grid_pinpoints should be a list of tuples or lists\")\n\n    height, width = select_best_resolution(image_size, grid_pinpoints)\n    return height // patch_size, width // patch_size\n\n\ndef unpad_image(tensor, original_size):\n    \"\"\"\n    Unpads a PyTorch tensor of a padded and resized image.\n\n    Args:\n        tensor (`torch.Tensor`):\n            The image tensor, assumed to be of shape (num_channels, height, width).\n        original_size (`tuple`):\n            The original size of the image (height, width).\n\n    Returns:\n        `torch.Tensor`: The unpadded image tensor.\n    \"\"\"\n    original_height, original_width = original_size\n    current_height, current_width = tensor.shape[1:]\n\n    original_aspect_ratio = original_width / original_height\n    current_aspect_ratio = current_width / current_height\n\n    if original_aspect_ratio > current_aspect_ratio:\n        scale_factor = current_width / original_width\n        new_height = int(original_height * scale_factor)\n        padding = (current_height - new_height) // 2\n        unpadded_tensor = tensor[:, padding : current_height - padding, :]\n    else:\n        scale_factor = current_height / original_height\n        new_width = int(original_width * scale_factor)\n        padding = (current_width - new_width) // 2\n        unpadded_tensor = tensor[:, :, padding : current_width - padding]\n\n    return unpadded_tensor\n\n\n# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext\nclass LlavaNextMultiModalProjector(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n\n 
       self.linear_1 = TensorParallelColumnLinear.load(\n            prefix=f\"{prefix}.linear_1\", config=config, weights=weights, bias=True\n        )\n        self.act = ACT2FN[config.projector_hidden_act]\n        self.linear_2 = TensorParallelRowLinear.load(\n            prefix=f\"{prefix}.linear_2\", config=config, weights=weights, bias=True\n        )\n\n    def forward(self, image_features):\n        hidden_states = self.linear_1(image_features)\n        hidden_states = self.act(hidden_states)\n        hidden_states = self.linear_2(hidden_states)\n        return hidden_states\n\n\nclass FlashLlavaNextForConditionalGeneration(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        config.vision_config.quantize = config.quantize\n        vision_config = config.vision_config\n        # Instead of selecting in hidden_states[-2].\n        # Instead compute only the n -2 + 1 layers and don't pool\n        if config.vision_feature_layer < 0:\n            vision_config.num_hidden_layers += config.vision_feature_layer + 1\n        else:\n            vision_config.num_hidden_layers = config.vision_feature_layer + 1\n        self.vision_tower = load_vision_model(\n            prefix=\"vision_tower\" if not prefix else f\"{prefix}.vision_tower\",\n            config=config.vision_config,\n            weights=weights,\n        )\n\n        self.multi_modal_projector = LlavaNextMultiModalProjector(\n            prefix=\"multi_modal_projector\", config=config, weights=weights\n        )\n\n        self.image_newline = weights.get_tensor(\"image_newline\")\n\n        self.vocab_size = config.text_config.vocab_size\n        self.config = config\n        config.text_config.quantize = config.quantize\n        config.text_config.speculator = config.speculator\n        self.text_model = load_text_model(\n            prefix=\"language_model\" if not prefix else f\"{prefix}.language_model\",\n            config=config.text_config,\n      
      weights=weights,\n        )\n        self.pad_token_id = (\n            config.pad_token_id if config.pad_token_id is not None else -1\n        )\n\n    def _merge_input_ids_with_image_features(\n        self,\n        input_ids: torch.Tensor,\n        inputs_embeds: torch.Tensor,\n        image_features: torch.Tensor,\n    ):\n        \"\"\"In place merges in vision_embeddings with inputs_embeds.\"\"\"\n        mask = torch.where(input_ids == self.config.image_token_index)\n        # Let's pray we have enabled enough slots !\n        try:\n            inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])\n        except Exception as e:\n            raise RuntimeError(\n                f\"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens`  to handle images. If error happens at regular runtime, please fill in an issue: {e}\"\n            )\n        return inputs_embeds\n\n    def get_vision_embeds(\n        self,\n        pixel_values: torch.FloatTensor,\n        pixel_attention_mask: Optional[torch.FloatTensor] = None,\n        image_sizes: Optional[torch.Tensor] = None,\n        image_grid_thw: Optional[torch.LongTensor] = None,\n    ):\n        # num_special_image_tokens = (input_ids == self.config.image_token_index).sum()\n        # assert num_special_image_tokens == len(pixel_values), f\"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid\"\n        # 1. Extract the input embeddings\n\n        # 2. 
Merge text and images\n        num_images, num_patches, channels, height, width = pixel_values.shape\n        pixel_values = pixel_values.view(\n            num_images * num_patches, channels, height, width\n        )\n        image_features = self.vision_tower(pixel_values)\n\n        # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]\n        # Already done within the clip model\n        selected_image_feature = image_features.last_hidden_state\n\n        if self.config.vision_feature_select_strategy == \"default\":\n            selected_image_feature = selected_image_feature[:, 1:]\n        elif self.config.vision_feature_select_strategy == \"full\":\n            selected_image_feature = selected_image_feature\n        else:\n            raise RuntimeError(\n                f\"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid.\"\n            )\n\n        image_features = self.multi_modal_projector(selected_image_feature)\n\n        # split up image_features for each of the individual images\n        # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)\n        # if we assume each image has 5 image features (base image + 4 patches)\n        split_sizes = [num_patches] * num_images\n        image_features = torch.split(image_features, split_sizes, dim=0)\n\n        # NOTE we only support multimodal_patch_merge_type == \"spatial_unpad\"\n        height = width = (\n            self.config.vision_config.image_size // self.config.vision_config.patch_size\n        )\n\n        new_image_features = []\n        for image_idx, image_feature in enumerate(image_features):\n            if image_feature.shape[0] > 1:\n                base_image_feature = image_feature[0]\n                image_feature = image_feature[1:]\n\n                if height * width != base_image_feature.shape[0]:\n                    raise ValueError(\n                        \"The number of patches 
is not consistent with the image size.\"\n                    )\n\n                # Dimensions are intentionally swapped to be bug-compatible with\n                # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59\n                num_patch_width, num_patch_height = get_anyres_image_grid_shape(\n                    image_sizes[image_idx],\n                    self.config.image_grid_pinpoints,\n                    self.config.vision_config.image_size,\n                )\n                image_feature = image_feature.view(\n                    num_patch_height, num_patch_width, height, width, -1\n                )\n                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()\n                image_feature = image_feature.flatten(1, 2).flatten(2, 3)\n                image_feature = unpad_image(image_feature, image_sizes[image_idx])\n                image_feature = torch.cat(\n                    (\n                        image_feature,\n                        self.image_newline[:, None, None].expand(\n                            *image_feature.shape[:-1], 1\n                        ),\n                    ),\n                    dim=-1,\n                )\n                image_feature = image_feature.flatten(1, 2).transpose(0, 1)\n                image_feature = torch.cat((base_image_feature, image_feature), dim=0)\n            else:\n                image_feature = image_feature[0]\n                image_feature = torch.cat(\n                    (image_feature, self.image_newline[None]), dim=0\n                )\n            new_image_features.append(image_feature)\n        image_features = torch.stack(new_image_features, dim=0)\n        return image_features.view(-1, image_features.shape[-1])\n\n    def get_inputs_embeds(\n        self,\n        input_ids: torch.Tensor,\n        vision_embeds: torch.Tensor = None,\n        pixel_values: torch.FloatTensor = None,\n        image_sizes: Optional[torch.LongTensor] = None,\n    ):\n    
    inputs_embeds = self.text_model.embed_tokens(input_ids)\n\n        if vision_embeds is not None:\n            # When we generate, we don't want to replace the potential image_token_id that we generated by images\n            # that simply don't exist\n            inputs_embeds = self._merge_input_ids_with_image_features(\n                input_ids, inputs_embeds, vision_embeds\n            )\n        return inputs_embeds\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.BoolTensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ):\n\n        hidden_states = self.text_model.model(\n            inputs_embeds=inputs_embeds,\n            position_ids=position_ids,\n            cu_seqlen_prefill=cu_seqlen_prefill,\n            kv_cache=kv_cache,\n            slots=slots,\n            seqlen=seqlen,\n            hpu_attention_meta=hpu_attention_meta,\n            adapter_data=adapter_data,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.text_model.lm_head(hidden_states)\n        return logits, speculative_logits\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.distributed\n\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom transformers.configuration_utils import PretrainedConfig\nfrom typing import Optional, List, Tuple\n\nfrom text_generation_server.layers.attention.kv_cache import get_kv_scales\nfrom text_generation_server.layers.attention import (\n    paged_attention,\n    attention,\n    set_block_mapping,\n    Seqlen,\n    HPUPagedAttentionMetadata,\n)\nfrom text_generation_server.layers import (\n    TensorParallelRowLinear,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    SpeculativeHead,\n    TensorParallelMultiAdapterLinear,\n    TensorParallelAdapterRowLinear,\n)\nfrom text_generation_server.layers.rotary import PositionRotaryEmbedding\nfrom text_generation_server.layers.layernorm import (\n    FastRMSNorm,\n)\nimport habana_frameworks.torch as htorch\n\n\nclass MistralConfig(PretrainedConfig):\n    model_type = \"mistral\"\n\n    def __init__(\n        self,\n        vocab_size=32000,\n   
     hidden_size=4096,\n        intermediate_size=14336,\n        num_hidden_layers=32,\n        num_attention_heads=32,\n        num_key_value_heads=8,\n        hidden_act=\"silu\",\n        max_position_embeddings=4096 * 32,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        pad_token_id=None,\n        bos_token_id=1,\n        eos_token_id=2,\n        pretraining_tp=1,\n        tie_word_embeddings=False,\n        rope_theta=10000.0,\n        sliding_window=None,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.sliding_window = sliding_window\n\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.pretraining_tp = pretraining_tp\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n\nclass MistralAttention(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights, layer_id, rotary_emb):\n        super().__init__()\n        self.max_past = (\n            config.sliding_window if config.sliding_window is not None else -1\n        )\n        self.num_heads = config.num_attention_heads\n        self.hidden_size = config.hidden_size\n\n        if getattr(config, \"head_dim\", None) is not 
None:\n            self.head_size = config.head_dim\n        else:\n            self.head_size = self.hidden_size // self.num_heads\n\n        self.rotary_emb = rotary_emb\n\n        self.softmax_scale = self.head_size**-0.5\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.num_key_value_heads = (\n            config.num_key_value_heads // weights.process_group.size()\n        )\n\n        query_key_value = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n            dim=0,\n            weights=weights,\n            bias=False,\n        )\n\n        self.query_key_value = TensorParallelMultiAdapterLinear.load(\n            query_key_value,\n            layer_id,\n            [\"q_proj\", \"k_proj\", \"v_proj\"],\n            sizes=[\n                self.head_size * config.num_attention_heads,\n                self.head_size * config.num_key_value_heads,\n                self.head_size * config.num_key_value_heads,\n            ],\n            process_group=weights.process_group,\n        )\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n\n        o_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.o_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.o_proj = TensorParallelAdapterRowLinear.load(\n            o_proj,\n            layer_id,\n            \"o_proj\",\n            process_group=weights.process_group,\n        )\n        self.num_groups = self.num_heads // self.num_key_value_heads\n        self.kv_head_mapping = torch.arange(\n            0, 
self.num_key_value_heads, dtype=torch.int32, device=weights.device\n        ).repeat_interleave(self.num_groups)\n\n    def forward(\n        self,\n        hidden_states,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        adapter_data,\n        hpu_attention_meta,\n    ):\n        qkv = self.query_key_value(hidden_states, adapter_data)\n        query, kv = qkv.split(\n            [\n                self.head_size * self.num_heads,\n                2 * self.head_size * self.num_key_value_heads,\n            ],\n            dim=1,\n        )\n        query = query.view(-1, self.num_heads, self.head_size)\n        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)\n\n        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)\n\n        kv_cache.store(\n            key=kv[:, 0],\n            value=kv[:, 1],\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # sdpa\n            attn_output = attention(\n                query=query,\n                key=kv[:, 0],\n                value=kv[:, 1],\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                softmax_scale=self.softmax_scale,\n                window_size_left=self.max_past,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                seqlen,\n                kv_scales=self.kv_scales,\n                hpu_attention_meta=hpu_attention_meta,\n                window_size_left=self.max_past,\n            )\n\n        return self.o_proj(\n            attn_output.view(-1, self.num_heads * self.head_size), adapter_data\n        )\n\n\nclass MistralMLP(nn.Module):\n    def __init__(self, 
prefix: str, config, weights, layer_id):\n        super().__init__()\n        self.hidden_act = config.hidden_act\n        self.act = (\n            ACT2FN[self.hidden_act]\n            if \"gelu\" not in self.hidden_act\n            else lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\"\n                    if self.hidden_act in [\"gelu_fast\", \"gelu_pytorch_tanh\"]\n                    else \"none\"\n                ),\n            )\n        )\n        # Fuse gate and up proj\n        gate_up_proj = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.gate_proj\", f\"{prefix}.up_proj\"],\n            weights=weights,\n            dim=0,\n            bias=False,\n        )\n        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(\n            gate_up_proj,\n            layer_id,\n            [\"gate_proj\", \"up_proj\"],\n            sizes=[\n                config.intermediate_size,\n                config.intermediate_size,\n            ],\n            process_group=weights.process_group,\n        )\n\n        down_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.down_proj\",\n            weights=weights,\n            bias=False,\n        )\n\n        self.down_proj = TensorParallelAdapterRowLinear.load(\n            down_proj,\n            layer_id,\n            \"down_proj\",\n            process_group=weights.process_group,\n        )\n        self.intermediate_size = (\n            config.intermediate_size // weights.process_group.size()\n        )\n\n        # TODO: This is a hotfix to be removed & properly refactored.\n        self.quantize = config.quantize\n\n    def forward(self, hidden_states, adapter_data):\n        gate_up_states = self.gate_up_proj(hidden_states, adapter_data)\n        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)\n        return self.down_proj(\n      
      self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data\n        )\n\n\nclass MistralLayer(nn.Module):\n    def __init__(self, prefix: str, config, weights, layer_id, rotary_emb):\n        super().__init__()\n        self.self_attn = MistralAttention(\n            prefix=f\"{prefix}.self_attn\",\n            config=config,\n            weights=weights,\n            layer_id=layer_id,\n            rotary_emb=rotary_emb,\n        )\n        self.mlp = MistralMLP(\n            prefix=f\"{prefix}.mlp\", config=config, weights=weights, layer_id=layer_id\n        )\n\n        self.input_layernorm = FastRMSNorm.load(\n            prefix=f\"{prefix}.input_layernorm\", weights=weights, eps=config.rms_norm_eps\n        )\n        self.post_attention_layernorm = FastRMSNorm.load(\n            prefix=f\"{prefix}.post_attention_layernorm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        adapter_data,\n        hpu_attention_meta,\n    ):\n        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        attn_output = self.self_attn(\n            normed_hidden_states,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            adapter_data,\n            hpu_attention_meta,\n        )\n\n        # faster post attention rms norm\n        normed_attn_res_output, attn_res = self.post_attention_layernorm(\n            attn_output, res\n        )\n\n        mlp_output = self.mlp(normed_attn_res_output, adapter_data)\n\n        return mlp_output, attn_res\n\n\nclass MistralModel(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n\n        process_group = 
weights.process_group\n        self.tp_rank = process_group.rank()\n        self.tp_world_size = process_group.size()\n\n        if getattr(config, \"head_dim\", None) is not None:\n            head_dim = config.head_dim\n        else:\n            head_dim = config.hidden_size // config.num_attention_heads\n\n        rotary_emb = PositionRotaryEmbedding.static(\n            config=config,\n            dim=head_dim,\n            base=config.rope_theta,\n            device=weights.device,\n        )\n\n        self.layers = nn.ModuleList(\n            [\n                MistralLayer(\n                    prefix=f\"{prefix}.layers.{layer_id}\",\n                    config=config,\n                    weights=weights,\n                    layer_id=layer_id,\n                    rotary_emb=rotary_emb,\n                )\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n        self.norm = FastRMSNorm.load(\n            prefix=f\"{prefix}.norm\", weights=weights, eps=config.rms_norm_eps\n        )\n\n        self.gradient_checkpointing = False\n\n        self.head_size = self.layers[0].self_attn.head_size\n        self.num_heads = self.layers[0].self_attn.num_heads\n        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n        adapter_data: Optional[torch.Tensor] = None,\n    ):\n        if hpu_attention_meta is not None:\n            hpu_attention_meta = set_block_mapping(\n                hpu_attention_meta, inputs_embeds.shape[0]\n            )\n        hidden_states = inputs_embeds\n        # Get rotary cos and sin for this forward\n        # Avoid to index in 
each layer\n        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)\n\n        residual = None\n        lazy_mode = htorch.utils.internal.is_lazy()\n        if lazy_mode:\n            htorch.core.mark_step()\n        for i, layer in enumerate(self.layers):\n            hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                slots,\n                seqlen,\n                adapter_data,\n                hpu_attention_meta,\n            )\n            if lazy_mode:\n                htorch.core.mark_step()\n\n        hidden_states, _ = self.norm(hidden_states, residual)\n        return hidden_states\n\n\nclass FlashMistralForCausalLM(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights, name=None):\n        if name is None:\n            name = \"model\"\n        super().__init__()\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=(\n                f\"{name}.embed_tokens\"\n                if not prefix\n                else f\"{prefix}.{name}.embed_tokens\"\n            ),\n            weights=weights,\n        )\n        self.model = MistralModel(\n            prefix=name if not prefix else f\"{prefix}.{name}\",\n            config=config,\n            weights=weights,\n        )\n        self.lm_head = SpeculativeHead.load(\n            config,\n            # TODO dirty hack for idefics2.\n            prefix=(\n                \"lm_head\" if not prefix or name != \"model\" else f\"{prefix}.lm_head\"\n            ),\n            weights=weights,\n        )\n        self.max_past = config.sliding_window\n        self.max_past_tensor = (\n            torch.tensor(config.sliding_window, device=weights.device)\n            if self.max_past is not None\n            else None\n        )\n\n    def forward(\n        self,\n        input_ids: 
torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        inputs_embeds = self.embed_tokens(input_ids)\n        hidden_states = self.model(\n            inputs_embeds,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            hpu_attention_meta,\n            adapter_data,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits = self.lm_head(hidden_states)\n        return logits\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import List, Optional, Tuple, Type\n\nimport torch\nimport torch.distributed\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom text_generation_server.layers import (\n    FastLinear,\n    SpeculativeHead,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    TensorParallelRowLinear,\n    get_linear,\n)\nfrom text_generation_server.layers.attention import (\n    Seqlen,\n    attention,\n    paged_attention,\n    set_block_mapping,\n    HPUPagedAttentionMetadata,\n)\nfrom text_generation_server.layers.attention.kv_cache import get_kv_scales\nfrom text_generation_server.layers.layernorm import FastRMSNorm\nfrom text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer\nfrom text_generation_server.layers.rotary import PositionRotaryEmbedding\nfrom text_generation_server.utils.weights import UnquantizedWeight\nimport habana_frameworks.torch as htorch\n\n\nclass MixtralConfig(PretrainedConfig):\n    model_type = \"mixtral\"\n\n 
   def __init__(\n        self,\n        vocab_size=32000,\n        hidden_size=4096,\n        intermediate_size=14336,\n        num_hidden_layers=32,\n        num_attention_heads=32,\n        num_key_value_heads=8,\n        hidden_act=\"silu\",\n        max_position_embeddings=4096 * 32,\n        initializer_range=0.02,\n        rms_norm_eps=1e-05,\n        use_cache=True,\n        pad_token_id=None,\n        bos_token_id=1,\n        eos_token_id=2,\n        pretraining_tp=1,\n        tie_word_embeddings=False,\n        rope_theta=10000.0,\n        sliding_window=None,\n        num_experts_per_tok=2,\n        num_local_experts=8,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.sliding_window = sliding_window\n\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.pretraining_tp = pretraining_tp\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.num_experts_per_tok = num_experts_per_tok\n        self.num_local_experts = num_local_experts\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n\ndef promote_scalar(x: torch.Tensor) -> torch.Tensor:\n    return x.view(1) if len(x.size()) == 0 else x\n\n\ndef load_attention(config, prefix: str, weights):\n    if 
config.num_attention_heads != config.num_key_value_heads:\n        return _load_gqa(config, prefix, weights)\n    else:\n        return TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n            dim=0,\n            weights=weights,\n            bias=False,\n        )\n\n\ndef _load_gqa(config, prefix: str, weights):\n    assert config.hidden_size % config.num_attention_heads == 0\n    assert config.num_attention_heads % weights.process_group.size() == 0\n\n    weight = weights.get_multi_weights_col(\n        prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n        dim=0,\n    )\n\n    if isinstance(weight, UnquantizedWeight):\n        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)\n\n        head_size = config.hidden_size // config.num_attention_heads\n        num_heads = config.num_attention_heads // weights.process_group.size()\n        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()\n        assert list(weight.weight.shape) == [\n            (num_heads + 2 * num_key_value_heads) * head_size,\n            config.hidden_size,\n        ], f\"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}\"\n\n    return TensorParallelColumnLinear(get_linear(weight, bias=None))\n\n\ndef _load_experts(config, prefix: str, mat, weights):\n    if config.quantize is not None:\n        raise NotImplementedError(\"Mixtral does not support weight quantization yet.\")\n\n    assert mat in [\"w1\", \"w2\", \"w3\"]\n\n    world_size = weights.process_group.size()\n    rank = weights.process_group.rank()\n\n    assert (\n        config.intermediate_size % world_size == 0\n    ), f\"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards\"\n\n    block_size = config.intermediate_size // 
world_size\n    start = rank * block_size\n    stop = (rank + 1) * block_size\n\n    tensor = torch.empty(\n        (config.num_local_experts * block_size, config.hidden_size),\n        dtype=weights.dtype,\n        device=weights.device,\n    )\n\n    for i in range(config.num_local_experts):\n        slice_ = weights._get_slice(f\"{prefix}.{i}.{mat}.weight\")\n\n        if mat == \"w2\":\n            expert_slice = slice_[:, start:stop].t().contiguous()\n        else:\n            expert_slice = slice_[start:stop]\n        tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(\n            dtype=weights.dtype\n        ).to(device=weights.device)\n    return tensor\n\n\nclass MixtralAttention(torch.nn.Module):\n    def __init__(\n        self,\n        prefix: str,\n        config,\n        weights,\n        rotary_emb,\n    ):\n        super().__init__()\n        self.max_past = (\n            config.sliding_window if config.sliding_window is not None else -1\n        )\n        self.num_heads = config.num_attention_heads\n        self.hidden_size = config.hidden_size\n        self.head_size = self.hidden_size // self.num_heads\n        self.rotary_emb = rotary_emb\n\n        self.softmax_scale = self.head_size**-0.5\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.num_key_value_heads = (\n            config.num_key_value_heads // weights.process_group.size()\n        )\n\n        self.query_key_value = load_attention(config, prefix, weights)\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n\n        self.o_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.o_proj\",\n            
weights=weights,\n            bias=False,\n        )\n        self.num_groups = self.num_heads // self.num_key_value_heads\n        self.kv_head_mapping = torch.arange(\n            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device\n        ).repeat_interleave(self.num_groups)\n\n    def forward(\n        self,\n        hidden_states,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        hpu_attention_meta,\n    ):\n        qkv = self.query_key_value(hidden_states)\n        query, kv = qkv.split(\n            [\n                self.head_size * self.num_heads,\n                2 * self.head_size * self.num_key_value_heads,\n            ],\n            dim=1,\n        )\n        query = query.view(-1, self.num_heads, self.head_size)\n        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)\n\n        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)\n\n        kv_cache.store(\n            key=kv[:, 0],\n            value=kv[:, 1],\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # sdpa\n            attn_output = attention(\n                query=query,\n                key=kv[:, 0],\n                value=kv[:, 1],\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                softmax_scale=self.softmax_scale,\n                window_size_left=self.max_past,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                seqlen,\n                kv_scales=self.kv_scales,\n                hpu_attention_meta=hpu_attention_meta,\n            )\n\n        return self.o_proj(attn_output.view(-1, self.num_heads * 
self.head_size))\n\n\n@torch.jit.script\ndef select_experts(gate_logits: torch.Tensor, top_k: int):\n    # all_probs: (sequence_length, n_experts) and upcast for softmax\n    all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)\n    # weights, selected_experts: (sequence_length, top-k)\n    weights, selected_experts = torch.topk(all_probs, top_k, dim=-1)\n    weights /= weights.sum(dim=-1, keepdim=True)\n    weights = weights.view(-1)\n    selected_experts = selected_experts.view(-1)\n\n    return selected_experts, weights\n\n\n@torch.jit.script\ndef round_up(x: torch.Tensor, value: int):\n    return torch.div(x + (value - 1), value, rounding_mode=\"trunc\") * value\n\n\nclass MixtralMoE(nn.Module):\n    def __init__(\n        self, prefix, config: MixtralConfig, moe_layer_cls: Type[MoELayer], weights\n    ):\n        super().__init__()\n\n        # gating\n        self.gate = FastLinear.load(config, f\"{prefix}.gate\", weights, bias=False)\n\n        self.moe = moe_layer_cls(\n            n_expert_group=None,\n            n_experts=config.num_local_experts,\n            prefix=f\"{prefix}.experts\",\n            renormalize=True,\n            topk=config.num_experts_per_tok,\n            topk_group=None,\n            weights=weights,\n            gate_proj_name=\"w1\",\n            up_proj_name=\"w3\",\n            down_proj_name=\"w2\",\n        )\n        assert isinstance(self.moe, MoELayer)\n\n        self.process_group = weights.process_group\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        # router_logits: (num_tokens, n_experts)\n        router_logits = self.gate(x)\n        out = self.moe(x, gating_output=router_logits)\n\n        # Reduce sum\n        if self.process_group.size() > 1:\n            torch.distributed.all_reduce(out, group=self.process_group)\n\n        return out.view(*x.shape)\n\n\nclass MixtralLayer(nn.Module):\n    def __init__(self, prefix: str, layer_id, config, weights, rotary_emb):\n       
 super().__init__()\n        prefix = f\"{prefix}.layers.{layer_id}\"\n\n        self.self_attn = MixtralAttention(\n            prefix=f\"{prefix}.self_attn\",\n            config=config,\n            weights=weights,\n            rotary_emb=rotary_emb,\n        )\n\n        moe_layer_cls = (\n            SparseMoELayer if SparseMoELayer.is_supported(weights) else DenseMoELayer\n        )\n        self.moe = MixtralMoE(\n            f\"{prefix}.block_sparse_moe\", config, moe_layer_cls, weights\n        )\n\n        self.input_layernorm = FastRMSNorm.load(\n            prefix=f\"{prefix}.input_layernorm\", weights=weights, eps=config.rms_norm_eps\n        )\n        self.post_attention_layernorm = FastRMSNorm.load(\n            prefix=f\"{prefix}.post_attention_layernorm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        hpu_attention_meta,\n    ):\n        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        attn_output = self.self_attn(\n            normed_hidden_states,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            hpu_attention_meta,\n        )\n\n        # faster post attention rms norm\n        normed_attn_res_output, attn_res = self.post_attention_layernorm(\n            attn_output, res\n        )\n\n        moe_output = self.moe(normed_attn_res_output)\n\n        return moe_output, attn_res\n\n\nclass MixtralModel(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=(\n                \"model.embed_tokens\" if not prefix else 
f\"{prefix}.model.embed_tokens\"\n            ),\n            weights=weights,\n        )\n\n        rotary_emb = PositionRotaryEmbedding.static(\n            config=config,\n            dim=config.hidden_size // config.num_attention_heads,\n            base=config.rope_theta,\n            device=weights.device,\n        )\n        self.layers = nn.ModuleList(\n            [\n                MixtralLayer(\n                    \"model\" if not prefix else f\"{prefix}.model\",\n                    layer_id,\n                    config,\n                    weights,\n                    rotary_emb,\n                )\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n        self.norm = FastRMSNorm.load(\n            prefix=\"model.norm\" if not prefix else f\"{prefix}.model.norm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n\n        self.head_size = self.layers[0].self_attn.head_size\n        self.num_heads = self.layers[0].self_attn.num_heads\n        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n    ) -> torch.Tensor:\n        if hpu_attention_meta is not None:\n            hpu_attention_meta = set_block_mapping(\n                hpu_attention_meta, input_ids.shape[0]\n            )\n        hidden_states = self.embed_tokens(input_ids)\n\n        # Get rotary cos and sin for this forward\n        # Avoid to index in each layer\n        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)\n\n        residual = None\n        lazy_mode = htorch.utils.internal.is_lazy()\n        if lazy_mode:\n            
htorch.core.mark_step()\n        for i, layer in enumerate(self.layers):\n            hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                slots,\n                seqlen,\n                hpu_attention_meta,\n            )\n            if lazy_mode:\n                htorch.core.mark_step()\n\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n\nclass FlashMixtralForCausalLM(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n\n        self.model = MixtralModel(prefix, config, weights)\n        self.lm_head = SpeculativeHead.load(\n            config,\n            prefix=\"lm_head\" if not prefix else f\"{prefix}.lm_head\",\n            weights=weights,\n        )\n        self.max_past = config.sliding_window\n        self.max_past_tensor = (\n            torch.tensor(config.sliding_window, device=weights.device)\n            if self.max_past is not None\n            else None\n        )\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        hidden_states = self.model(\n            input_ids,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            hpu_attention_meta,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits = 
self.lm_head(hidden_states)\n        return logits\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py",
    "content": "# coding=utf-8\n# Copyright 2024 the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch Mllama model.\"\"\"\n\nfrom typing import Optional, Tuple, List\n\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\n\nfrom transformers.activations import ACT2FN\nimport torch.nn.functional as F\n\nfrom text_generation_server.layers import (\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    TensorParallelRowLinear,\n    FastLinear,\n)\nfrom text_generation_server.layers.attention import (\n    Seqlen,\n    HPUPagedAttentionMetadata,\n)\nfrom text_generation_server.models.custom_modeling.flash_llama_modeling import (\n    FlashLlamaForCausalLM,\n)\nfrom habana_frameworks.torch.hpex.kernels import FusedSDPA\nfrom vllm_hpu_extension.utils import ModuleFusedSDPA\nimport habana_frameworks.torch as htorch\n\n\ndef _prepare_aspect_ratio_attention_mask(\n    aspect_ratio_mask: torch.Tensor,\n    num_patches: int,\n    target_length: int,\n    dtype: torch.dtype,\n) -> torch.Tensor:\n    # Expand aspect ratio mask to target_length\n    batch_size, max_num_tiles = aspect_ratio_mask.shape\n    attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype)\n    attention_mask = attention_mask.repeat(1, 1, target_length, 1)\n\n    # Mask padding patches\n    pad_patches = target_length - num_patches\n    attention_mask[:, :, -pad_patches:] = 0\n\n    # Invert 
the mask (0 -> 1, 1 -> 0)\n    attention_mask = 1 - attention_mask\n\n    # Reshape to 2D and create 4D attention mask\n    # (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length)\n    attention_mask = attention_mask.reshape(\n        batch_size, max_num_tiles * target_length, 1\n    )\n    attention_mask = (\n        attention_mask @ attention_mask.transpose(-1, -2) * torch.finfo(dtype).min\n    )\n    attention_mask = attention_mask.unsqueeze(1)\n\n    return attention_mask\n\n\n# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position\ndef _prepare_4d_causal_attention_mask_with_cache_position(\n    attention_mask: torch.Tensor,\n    sequence_length: int,\n    target_length: int,\n    dtype: torch.dtype,\n    device: torch.device,\n    min_dtype: float,\n    cache_position: torch.Tensor,\n    batch_size: int,\n):\n    \"\"\"\n    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape\n    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.\n\n    Args:\n        attention_mask (`torch.Tensor`):\n            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.\n        sequence_length (`int`):\n            The sequence length being processed.\n        target_length (`int`):\n            The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.\n        dtype (`torch.dtype`):\n            The dtype to use for the 4D attention mask.\n        device (`torch.device`):\n            The device to plcae the 4D attention mask on.\n        min_dtype (`float`):\n            The minimum value representable with the dtype `dtype`.\n        cache_position (`torch.Tensor`):\n            Indices depicting 
the position of the input sequence tokens in the sequence.\n        batch_size (`torch.Tensor`):\n            Batch size.\n    \"\"\"\n    if attention_mask is not None and attention_mask.dim() == 4:\n        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.\n        causal_mask = attention_mask\n    else:\n        causal_mask = torch.full(\n            (sequence_length, target_length),\n            fill_value=min_dtype,\n            dtype=dtype,\n            device=device,\n        )\n        if sequence_length != 1:\n            causal_mask = torch.triu(causal_mask, diagonal=1)\n        causal_mask *= torch.arange(\n            target_length, device=device\n        ) > cache_position.reshape(-1, 1)\n        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)\n        if attention_mask is not None:\n            causal_mask = (\n                causal_mask.clone()\n            )  # copy to contiguous memory for in-place edit\n            mask_length = attention_mask.shape[-1]\n            padding_mask = (\n                causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]\n            )\n            padding_mask = padding_mask == 0\n            causal_mask[:, :, :, :mask_length] = causal_mask[\n                :, :, :, :mask_length\n            ].masked_fill(padding_mask, min_dtype)\n\n    return causal_mask\n\n\ndef _prepare_cross_attention_mask(\n    cross_attention_mask: torch.Tensor,\n    num_vision_tokens: int,\n    dtype: str,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    # reshape so it can be used by attn module\n    batch_size, text_total_length, *_ = cross_attention_mask.shape\n    cross_attention_mask = cross_attention_mask.repeat_interleave(\n        num_vision_tokens, dim=3\n    )\n    cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1)\n    cross_attention_mask = cross_attention_mask.unsqueeze(1)\n\n    # invert the mask\n 
   inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype)\n    cross_attention_mask = inverted_cross_attn_mask.masked_fill(\n        inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min\n    )\n\n    # apply full-row bias, which return 4D tensor of shape [B, H, S1, 1] where value is 0 if the a full row in cross attn mask's\n    # last dimension contains negative infinity values, otherwise it's 1\n    negative_inf_value = torch.finfo(dtype).min\n    full_text_row_masked_out_mask = (\n        (cross_attention_mask != negative_inf_value)\n        .any(dim=-1)\n        .type_as(cross_attention_mask)[..., None]\n    )\n    cross_attention_mask *= full_text_row_masked_out_mask\n\n    return cross_attention_mask, full_text_row_masked_out_mask\n\n\n# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision\nclass MllamaVisionMLP(nn.Module):\n    def __init__(self, *, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        self.activation_fn = ACT2FN[config.hidden_act]\n        self.fc1 = TensorParallelColumnLinear.load(\n            prefix=f\"{prefix}.fc1\", weights=weights, config=config, bias=True\n        )\n        self.fc2 = TensorParallelRowLinear.load(\n            prefix=f\"{prefix}.fc2\", weights=weights, config=config, bias=True\n        )\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = self.activation_fn(hidden_states)\n        hidden_states = self.fc2(hidden_states)\n        return hidden_states\n\n\nclass MllamaVisionSdpaAttention(nn.Module):\n    def __init__(self, *, prefix, config, weights):\n        super().__init__()\n\n        self.embed_dim = config.hidden_size\n        self.head_dim = config.hidden_size // config.attention_heads\n        self.num_heads = config.attention_heads // weights.process_group.size()\n\n        self.qkv_proj = 
TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n            dim=0,\n            weights=weights,\n            bias=False,\n        )\n        self.o_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.o_proj\",\n            weights=weights,\n            bias=False,\n        )\n\n    def forward(\n        self,\n        hidden_state: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        qkv = self.qkv_proj(hidden_state)\n        query, key, value = qkv.split(\n            [\n                self.head_dim * self.num_heads,\n                self.head_dim * self.num_heads,\n                self.head_dim * self.num_heads,\n            ],\n            dim=2,\n        )\n\n        batch_size, q_seq_len, _ = query.shape\n        _, kv_seq_len, _ = key.shape\n\n        query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim)\n        key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)\n        value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)\n\n        query = query.transpose(1, 2)\n        key = key.transpose(1, 2)\n        value = value.transpose(1, 2)\n\n        fsdpa_op = ModuleFusedSDPA(FusedSDPA)\n        attn_output = fsdpa_op(\n            query,\n            key,\n            value,\n            attn_mask=attention_mask,\n            dropout_p=0.0,\n            is_causal=False,\n            scale=None,\n            softmax_mode=\"None\",\n            recompute_mode=None,\n            valid_sequence_lengths=None,\n        )\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.reshape(batch_size, q_seq_len, -1)\n\n        output = self.o_proj(attn_output)\n        return output\n\n\nclass MllamaVisionEncoderLayer(nn.Module):\n    def __init__(self, *, prefix, config, weights, 
is_gated: bool):\n        super().__init__()\n\n        self.hidden_size = config.hidden_size\n        self.num_attention_heads = config.attention_heads\n        self.is_gated = is_gated\n        self.intermediate_size = config.intermediate_size\n\n        self.self_attn = MllamaVisionSdpaAttention(\n            prefix=f\"{prefix}.self_attn\", config=config, weights=weights\n        )\n        self.mlp = MllamaVisionMLP(\n            prefix=f\"{prefix}.mlp\", config=config, weights=weights\n        )\n\n        self.input_layernorm = nn.LayerNorm.load(\n            prefix=f\"{prefix}.input_layernorm\", weights=weights, eps=1e-05\n        )\n        self.post_attention_layernorm = nn.LayerNorm.load(\n            prefix=f\"{prefix}.post_attention_layernorm\", weights=weights, eps=1e-05\n        )\n\n        # there used to be an if else here, no code path\n        if is_gated:\n            self.gate_attn = nn.Parameter(\n                weights.get_tensor(f\"{prefix}.gate_attn\"), requires_grad=False\n            )\n            self.gate_ffn = nn.Parameter(\n                weights.get_tensor(f\"{prefix}.gate_ffn\"), requires_grad=False\n            )\n\n    def forward(\n        self,\n        hidden_state: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n    ):\n        # Self Attention\n        residual = hidden_state\n        hidden_state = self.input_layernorm(hidden_state)\n        hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask)\n        gate_attn = 1 if not self.is_gated else self.gate_attn.tanh()\n        hidden_state = residual + gate_attn * hidden_state\n\n        # Feed forward\n        residual = hidden_state\n        hidden_state = self.post_attention_layernorm(hidden_state)\n        hidden_state = self.mlp(hidden_state)\n        gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh()\n        hidden_state = residual + gate_ffn * hidden_state\n        return hidden_state\n\n\nclass 
MllamaVisionEncoder(nn.Module):\n    def __init__(self, *, prefix, config, weights, is_gated: bool, num_layers: int):\n        super().__init__()\n        self.config = config\n        self.layers = [\n            MllamaVisionEncoderLayer(\n                prefix=f\"{prefix}.layers.{i}\",\n                config=config,\n                weights=weights,\n                is_gated=is_gated,\n            )\n            for i in range(num_layers)\n        ]\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n    ):\n        encoder_states = [hidden_states]\n        lazy_mode = htorch.utils.internal.is_lazy()\n        if lazy_mode:\n            htorch.core.mark_step()\n        for encoder_layer in self.layers:\n            layer_outputs = encoder_layer(\n                hidden_states,\n                attention_mask,\n            )\n\n            hidden_states = layer_outputs\n            encoder_states.append(hidden_states)\n            if lazy_mode:\n                htorch.core.mark_step()\n\n        return hidden_states, encoder_states\n\n\nclass MllamaPrecomputedAspectRatioEmbedding(nn.Module):\n    def __init__(self, *, prefix, config, weights):\n        super().__init__()\n        self.max_num_tiles = config.max_num_tiles\n        self.hidden_size = config.hidden_size\n        self.max_aspect_ratio_id = config.max_aspect_ratio_id\n\n        self.embedding = TensorParallelEmbedding(\n            prefix=f\"{prefix}.embedding\", weights=weights\n        )\n        self.gate = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.gate\"), requires_grad=False\n        )\n\n    def forward(\n        self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor\n    ) -> torch.Tensor:\n        embeddings = self.embedding(aspect_ratio_ids)\n        embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size)\n\n        # Always gated.\n        embeddings = embeddings * 
self.gate.tanh()\n\n        hidden_state = hidden_state + embeddings\n        return hidden_state\n\n\nclass MllamaPrecomputedPositionEmbedding(nn.Module):\n    def __init__(self, *, prefix, config, weights):\n        super().__init__()\n        self.max_num_tiles = config.max_num_tiles\n        self.max_aspect_ratio_id = config.max_aspect_ratio_id\n        self.num_patches = (config.image_size // config.patch_size) ** 2 + 1\n        self.hidden_size = config.hidden_size\n        self.scale = config.hidden_size**-0.5\n\n        self.gate = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.gate\"), requires_grad=False\n        )\n\n        # position embedding\n        embedding = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.embedding\"), requires_grad=False\n        )\n        self.gated_position_embedding = (1 - self.gate.tanh()) * embedding\n        self.tile_embedding = TensorParallelEmbedding(\n            prefix=f\"{prefix}.tile_embedding\", weights=weights\n        )\n\n    def forward(\n        self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor\n    ) -> torch.Tensor:\n        # position embeddings\n        hidden_state = hidden_state + self.gated_position_embedding.view(\n            1, 1, self.num_patches, self.hidden_size\n        )\n\n        # precomputed tile position embeddings\n        tile_position_embedding = self.tile_embedding(aspect_ratio_ids)\n        batch_size = hidden_state.shape[0]\n        tile_position_embedding = tile_position_embedding.reshape(\n            batch_size, self.max_num_tiles, self.num_patches, self.hidden_size\n        )\n        gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding\n        hidden_state = hidden_state + gated_tile_position_embedding\n\n        return hidden_state\n\n\nclass MllamaVisionModel(nn.Module):\n    def __init__(self, *, prefix, config, weights):\n        super().__init__()\n        self.image_size = config.image_size\n        self.patch_size 
= config.patch_size\n        self.max_num_tiles = config.max_num_tiles\n        self.hidden_size = config.hidden_size\n        self.num_channels = config.num_channels\n        self.intermediate_layers_indices = config.intermediate_layers_indices\n\n        self.num_patches = (self.image_size // self.patch_size) ** 2 + 1\n        self.scale = config.hidden_size**-0.5\n        self.dtype = weights.dtype\n\n        self.patch_embedding = nn.Conv2d(\n            in_channels=config.num_channels,\n            out_channels=self.hidden_size,\n            kernel_size=self.patch_size,\n            stride=self.patch_size,\n            padding=\"valid\",\n            bias=False,\n        )\n        self.patch_embedding.weight = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.patch_embedding.weight\"), requires_grad=False\n        )\n\n        self.class_embedding = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.class_embedding\"), requires_grad=False\n        )\n\n        self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(\n            prefix=f\"{prefix}.gated_positional_embedding\",\n            config=config,\n            weights=weights,\n        )\n\n        self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(\n            prefix=f\"{prefix}.pre_tile_positional_embedding\",\n            config=config,\n            weights=weights,\n        )\n        self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(\n            prefix=f\"{prefix}.post_tile_positional_embedding\",\n            config=config,\n            weights=weights,\n        )\n\n        ## layer norms\n        self.layernorm_pre = nn.LayerNorm.load(\n            prefix=f\"{prefix}.layernorm_pre\",\n            weights=weights,\n            # torch default\n            eps=1e-05,\n        )\n        self.layernorm_post = nn.LayerNorm.load(\n            prefix=f\"{prefix}.layernorm_post\",\n            weights=weights,\n            
# torch default\n            eps=1e-05,\n        )\n\n        ## encoders\n        self.transformer = MllamaVisionEncoder(\n            prefix=f\"{prefix}.transformer\",\n            config=config,\n            weights=weights,\n            is_gated=False,\n            num_layers=config.num_hidden_layers,\n        )\n        self.global_transformer = MllamaVisionEncoder(\n            prefix=f\"{prefix}.global_transformer\",\n            config=config,\n            weights=weights,\n            is_gated=True,\n            num_layers=config.num_global_layers,\n        )\n\n    def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:\n        batch_size, _, hidden_size = hidden_state.shape\n        class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)\n        hidden_state = torch.cat([class_embedding, hidden_state], dim=1)\n        return hidden_state\n\n    def forward(\n        self,\n        pixel_values: torch.Tensor,\n        aspect_ratio_ids: torch.Tensor,\n        attention_mask: torch.Tensor,\n    ) -> torch.Tensor:\n        (\n            batch_size,\n            num_concurrent_media,\n            num_tiles,\n            num_channels,\n            height,\n            width,\n        ) = pixel_values.shape\n\n        pixel_values = pixel_values.reshape(\n            batch_size * num_concurrent_media * num_tiles, num_channels, height, width\n        )\n        aspect_ratio_ids = aspect_ratio_ids.reshape(\n            batch_size * num_concurrent_media, -1\n        )\n\n        # patch embedding\n        patch_embeds = self.patch_embedding(pixel_values)\n        hidden_state = patch_embeds.flatten(2).transpose(1, 2)\n\n        # tile embeddings\n        _, num_patches, dim = hidden_state.shape\n        hidden_state = hidden_state.reshape(\n            batch_size * num_concurrent_media, num_tiles, -1, dim\n        )\n        hidden_state = self.pre_tile_positional_embedding(\n            hidden_state, aspect_ratio_ids\n    
    )\n\n        # apply cls token\n        hidden_state = hidden_state.reshape(\n            batch_size * num_concurrent_media * num_tiles, num_patches, dim\n        )\n        hidden_state = self.apply_class_embedding(hidden_state)\n        num_patches += 1\n\n        # apply position embeddings\n        hidden_state = hidden_state.reshape(\n            batch_size * num_concurrent_media, num_tiles, num_patches, dim\n        )\n        hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids)\n\n        # apply encoder\n        hidden_state = self.layernorm_pre(hidden_state)\n\n        # Compute the number of tokens to pad\n        num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8\n        # Compute padding tuple for pad function\n        padding = (\n            0,\n            0,\n            0,\n            num_padding_patches,\n        )  # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)\n        # Pad the tensor\n        hidden_state = F.pad(hidden_state, padding, mode=\"constant\", value=0)\n        slice_index = -num_padding_patches if num_padding_patches > 0 else None\n\n        if attention_mask is not None:\n            attention_mask = attention_mask.reshape(\n                batch_size * num_concurrent_media, -1\n            )\n            attention_mask = _prepare_aspect_ratio_attention_mask(\n                aspect_ratio_mask=attention_mask,\n                num_patches=self.num_patches,\n                target_length=hidden_state.shape[2],\n                dtype=self.dtype,\n            )\n\n        hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim)\n        hidden_state, all_intermediate_hidden_states = self.transformer(\n            hidden_state,\n            attention_mask=attention_mask,\n        )\n        intermediate_hidden_states = [\n            hidden_state\n            for idx, hidden_state in enumerate(all_intermediate_hidden_states)\n            if idx in 
self.intermediate_layers_indices\n        ]\n        intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1)\n\n        # apply global encoder\n        hidden_state = self.layernorm_post(hidden_state)\n        hidden_state = hidden_state.reshape(\n            batch_size * num_concurrent_media,\n            num_tiles,\n            num_patches + num_padding_patches,\n            dim,\n        )\n        hidden_state = self.post_tile_positional_embedding(\n            hidden_state, aspect_ratio_ids\n        )\n        hidden_state = hidden_state.reshape(\n            batch_size * num_concurrent_media,\n            num_tiles * (num_patches + num_padding_patches),\n            dim,\n        )\n        hidden_state, _ = self.global_transformer(\n            hidden_state, attention_mask=attention_mask\n        )\n        hidden_state = hidden_state.reshape(\n            batch_size * num_concurrent_media,\n            num_tiles,\n            num_patches + num_padding_patches,\n            dim,\n        )\n        hidden_state = hidden_state[:, :, :slice_index]\n\n        # adding intermediate layer outputs\n        hidden_state = hidden_state.reshape(\n            batch_size, num_concurrent_media, num_tiles, num_patches, dim\n        )\n        intermediate_hidden_states = intermediate_hidden_states.reshape(\n            batch_size * num_concurrent_media,\n            num_tiles,\n            num_patches + num_padding_patches,\n            -1,\n        )\n        intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index]\n        intermediate_hidden_states = intermediate_hidden_states.reshape(\n            batch_size, num_concurrent_media, num_tiles, num_patches, -1\n        )\n        hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1)\n        return hidden_state\n\n\nclass MllamaTextCrossAttention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def 
__init__(self, *, prefix, config, weights, layer_idx):\n        super().__init__()\n        self.config = config\n        self.num_heads = self.config.num_attention_heads\n        self.num_key_value_heads = self.config.num_key_value_heads\n        self.dropout = config.dropout\n        self.hidden_size = config.hidden_size\n        self.head_size = config.hidden_size // self.num_heads\n        self.num_key_value_groups = self.num_heads // self.num_key_value_heads\n        self.layer_idx = layer_idx\n\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.num_key_value_heads = (\n            self.num_key_value_heads // weights.process_group.size()\n        )\n\n        self.q_proj = TensorParallelColumnLinear.load(\n            config,\n            prefix=f\"{prefix}.q_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.k_proj = TensorParallelColumnLinear.load(\n            config,\n            prefix=f\"{prefix}.k_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.v_proj = TensorParallelColumnLinear.load(\n            config,\n            prefix=f\"{prefix}.v_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.o_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.o_proj\",\n            weights=weights,\n            bias=False,\n        )\n\n        self.q_norm = MllamaTextRMSNorm.load(\n            prefix=f\"{prefix}.q_norm\", weights=weights, eps=config.rms_norm_eps\n        )\n        self.k_norm = MllamaTextRMSNorm.load(\n            prefix=f\"{prefix}.k_norm\", weights=weights, eps=config.rms_norm_eps\n        )\n        self.softmax_scale = self.head_size**-0.5\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        cross_attention_states: Optional[torch.Tensor] = None,\n        # past_key_value=None,\n        # attention_mask: Optional[torch.Tensor] = 
None,\n        # cache_position: Optional[torch.LongTensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        \"\"\"Input shape: Batch x Time x Channel\"\"\"\n        # hidden_states = hidden_states.unsqueeze(0)\n        # bsz, q_len, _ = hidden_states.size()\n        (\n            cross_attention_states,\n            cross_attention_len,\n            indices,\n        ) = cross_attention_states\n        bs = cross_attention_len.size(0)\n        query_states = self.q_proj(hidden_states)\n        query_states = query_states.view(bs, -1, self.num_heads, self.head_size)\n        query_states = self.q_norm(query_states)\n\n        key_states = self.k_proj(cross_attention_states)\n        value_states = self.v_proj(cross_attention_states)\n        key_states = key_states.view(bs, -1, self.num_key_value_heads, self.head_size)\n        value_states = value_states.view(\n            bs, -1, self.num_key_value_heads, self.head_size\n        )\n        key_states = self.k_norm(key_states)\n\n        # key_states = key_states.repeat(1, self.num_key_value_groups, 1)\n        # value_states = value_states.repeat(1, self.num_key_value_groups, 1)\n        # logger.info(\n        #     f\"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}\"\n        # )\n        # execute sdpa\n        query_states = query_states.transpose(1, 2)\n        key_states = key_states.transpose(1, 2)\n        value_states = value_states.transpose(1, 2)\n        fsdpa_op = ModuleFusedSDPA(FusedSDPA)\n        attn_output = fsdpa_op(\n            query_states,\n            key_states,\n            value_states,\n            attn_mask=None,\n            dropout_p=0.0,\n            is_causal=False,\n            scale=None,\n            softmax_mode=\"None\",\n            recompute_mode=None,\n            valid_sequence_lengths=None,\n        )\n        attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()\n        attn_output 
= self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))\n\n        return attn_output\n\n\n# Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText\nclass MllamaTextMLP(nn.Module):\n    def __init__(self, *, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        self.intermediate_size = (\n            config.intermediate_size // weights.process_group.size()\n        )\n        self.gate_up_proj = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.gate_proj\", f\"{prefix}.up_proj\"],\n            weights=weights,\n            dim=0,\n            bias=False,\n        )\n        self.down_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.down_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x):\n        shape = x.shape\n        gate_up_states = self.gate_up_proj(x)\n        gate_up_states = gate_up_states.view(*shape[:-1], 2, self.intermediate_size)\n        result = self.down_proj(\n            self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1]\n        )\n        return result\n\n\nclass FlashLlamaCrossLayer(torch.nn.Module):\n    \"\"\"Cross-attention transformer block with tanh-gated attention and feedforward.\"\"\"\n\n    def __init__(self, *, prefix, config, weights, index) -> None:\n        layer_idx = index\n        super().__init__()\n        self.cross_attn = MllamaTextCrossAttention(\n            prefix=f\"{prefix}.cross_attn\",\n            config=config,\n            weights=weights,\n            layer_idx=layer_idx,\n        )\n\n        self.input_layernorm = MllamaTextRMSNorm.load(\n            prefix=f\"{prefix}.input_layernorm\", weights=weights, eps=config.rms_norm_eps\n        )\n        self.cross_attn_attn_gate = 
torch.nn.Parameter(\n            weights.get_tensor(f\"{prefix}.cross_attn_attn_gate\"), requires_grad=False\n        )\n\n        self.mlp = MllamaTextMLP(prefix=f\"{prefix}.mlp\", config=config, weights=weights)\n        self.post_attention_layernorm = MllamaTextRMSNorm.load(\n            prefix=f\"{prefix}.post_attention_layernorm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n        self.cross_attn_mlp_gate = torch.nn.Parameter(\n            weights.get_tensor(f\"{prefix}.cross_attn_mlp_gate\"), requires_grad=False\n        )\n        self.layer_idx = layer_idx\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        adapter_data,\n        cross_attention_states,  # [ IB, ...]\n        hpu_attention_meta,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        if cross_attention_states is None:\n            return hidden_states, residual\n        if residual is not None:\n            hidden_states += residual\n\n        indices = cross_attention_states[-1]\n        out_hidden_states = hidden_states[:]\n        hidden_states = hidden_states[indices]\n        residual = hidden_states\n        hidden_states = self.input_layernorm(hidden_states)\n\n        hidden_states = self.cross_attn(\n            hidden_states=hidden_states,\n            # attention_mask=cross_attention_mask,\n            cross_attention_states=cross_attention_states,\n        )\n        hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states\n\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states\n\n        out_hidden_states[indices] = hidden_states\n        hidden_states = out_hidden_states\n\n        return hidden_states, 
None\n\n\n# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText\nclass MllamaTextRMSNorm(nn.Module):\n    def __init__(self, weight, eps):\n        super().__init__()\n        self.weight = weight\n        self.variance_epsilon = eps\n\n    @classmethod\n    def load(cls, *, prefix, weights, eps):\n        weight = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.weight\"), requires_grad=False\n        )\n        return cls(weight=weight, eps=eps)\n\n    def forward(self, hidden_states):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n    def extra_repr(self):\n        return f\"{tuple(self.weight.shape)}, eps={self.variance_epsilon}\"\n\n\nclass FlashMllamaForConditionalGeneration(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        config.vision_config.quantize = None\n        config.vision_config.speculator = config.speculator\n        config.text_config.quantize = config.quantize\n        config.text_config.speculator = config.speculator\n        config.text_config._attn_implementation = \"sdpa\"\n        self.hidden_size = config.text_config.hidden_size\n        self.vision_model = MllamaVisionModel(\n            prefix=\"vision_model\", config=config.vision_config, weights=weights\n        )\n        self.multi_modal_projector = FastLinear.load(\n            prefix=\"multi_modal_projector\", config=config, weights=weights, bias=True\n        )\n        self.text_model = FlashLlamaForCausalLM(\n            prefix=\"language_model\", config=config.text_config, weights=weights\n        )\n        self.config = config\n        self.dtype = weights.dtype\n        self.device = weights.device\n\n    def 
vision_forward(self, pixel_values, aspect_ratio_ids, aspect_ratio_mask):\n        if aspect_ratio_ids is None:\n            raise ValueError(\n                \"`aspect_ratio_ids` must be provided if `pixel_values` is provided\"\n            )\n        # logger.info(f\"PIxel values {pixel_values.shape}\")\n        batch_size = pixel_values.shape[0]\n        vision_states = self.vision_model(\n            pixel_values, aspect_ratio_ids, aspect_ratio_mask\n        )\n        cross_attention_states = self.multi_modal_projector(vision_states).reshape(\n            -1, vision_states.shape[-2], self.hidden_size\n        )\n        _, _, h = cross_attention_states.shape\n        cross_attention_states = cross_attention_states.view(batch_size, -1, h)\n        # logger.info(f\"cross {cross_attention_states.shape}\")\n        return cross_attention_states\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n        lm_head_indices: Optional[torch.Tensor],\n        adapter_data: Optional[torch.Tensor] = None,\n        cross_attention_states: Optional[torch.Tensor] = None,\n        indices=None,\n        cross_attention_len: Optional[torch.Tensor] = None,\n    ):\n        if cross_attention_states is not None:\n            cross_attention_states = (\n                cross_attention_states,\n                cross_attention_len,\n                indices,\n            )\n\n        outputs = self.text_model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            cu_seqlen_prefill=cu_seqlen_prefill,\n            kv_cache=kv_cache,\n            slots=slots,\n            seqlen=seqlen,\n            hpu_attention_meta=hpu_attention_meta,\n            
lm_head_indices=lm_head_indices,\n            adapter_data=adapter_data,\n            cross_attention_states=cross_attention_states,\n        )\n\n        return outputs\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.distributed\n\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom transformers.modeling_utils import PreTrainedModel\nfrom transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig\nfrom typing import Optional, List, Tuple\nfrom text_generation_server.layers.attention import (\n    paged_attention,\n    attention,\n    set_block_mapping,\n    Seqlen,\n    HPUPagedAttentionMetadata,\n)\nfrom text_generation_server.layers import (\n    TensorParallelRowLinear,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    SpeculativeHead,\n    get_linear,\n)\nfrom text_generation_server.layers.attention.kv_cache import get_kv_scales\nfrom text_generation_server.layers.layernorm import (\n    FastLayerNorm,\n)\nfrom text_generation_server.layers.rotary import (\n    PositionRotaryEmbedding,\n)\nfrom text_generation_server.utils.weights import UnquantizedWeight\nimport habana_frameworks.torch as htorch\n\n\nclass 
GPTNeoXConfig(TransformersGPTNeoXConfig):\n    attribute_map = {\n        \"num_key_value_heads\": \"num_attention_heads\",\n    }\n\n\ndef load_row(config, prefix: str, weights, bias: bool):\n    weight = weights.get_weights_row(prefix)\n\n    if bias and weights.process_group.rank() == 0:\n        # Rank is only on the first rank process\n        bias = weights.get_tensor(f\"{prefix}.bias\")\n    else:\n        bias = None\n\n    linear = get_linear(weight, bias)\n    if config.use_parallel_residual:\n        return linear\n    else:\n        return TensorParallelRowLinear(linear, process_group=weights.process_group)\n\n\ndef load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):\n    weight = weights.get_multi_weights_col([prefix], dim=0)\n    if isinstance(weight, UnquantizedWeight):\n        # Only on non quantized versions\n        weight.weight = (\n            weight.weight.view(\n                num_heads,\n                3,\n                head_size,\n                hidden_size,\n            )\n            .permute(1, 0, 2, 3)\n            .reshape(-1, hidden_size)\n        )\n\n    bias = weights.get_sharded(f\"{prefix}.bias\", dim=0)\n    bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1)\n\n    linear = get_linear(weight, bias)\n    if config.use_parallel_residual:\n        return linear\n    else:\n        return TensorParallelColumnLinear(linear)\n\n\nclass FlashNeoxAttention(torch.nn.Module):\n    def __init__(self, config, prefix, weights, rotary_emb):\n        super().__init__()\n        num_heads = config.num_attention_heads\n        hidden_size = config.hidden_size\n\n        self.num_heads = num_heads\n        self.hidden_size = hidden_size\n        self.head_size = hidden_size // num_heads\n\n        self.rotary_dim = int(config.rotary_pct * self.head_size)\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible 
by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.rotary_emb = rotary_emb\n        self.softmax_scale = self.head_size ** (-0.5)\n\n        self.query_key_value = load_qkv(\n            config,\n            prefix=f\"{prefix}.query_key_value\",\n            weights=weights,\n            num_heads=self.num_heads,\n            head_size=self.head_size,\n            hidden_size=self.hidden_size,\n        )\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n        self.dense = load_row(\n            config, prefix=f\"{prefix}.dense\", weights=weights, bias=True\n        )\n        self.kv_head_mapping = torch.arange(\n            0, self.num_heads, dtype=torch.int32, device=weights.device\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        hpu_attention_meta,\n    ):\n        qkv = self.query_key_value(hidden_states)\n        qkv = qkv.view(-1, 3, self.num_heads, self.head_size)\n\n        # Compute rotary embeddings on rotary_ndims\n        query_rot = qkv[:, 0][..., : self.rotary_dim]\n        query_pass = qkv[:, 0][..., self.rotary_dim :]\n        key_rot = qkv[:, 1][..., : self.rotary_dim]\n        key_pass = qkv[:, 1][..., self.rotary_dim :]\n\n        # Inplace rotary\n        self.rotary_emb(query_rot, key_rot, cos, sin)\n        qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1)\n        qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1)\n\n        kv_cache.store(\n            key=qkv[:, 1],\n            value=qkv[:, 2],\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # sdpa\n            attn_output = attention(\n                
query=qkv[:, 0],\n                key=qkv[:, 1],\n                value=qkv[:, 2],\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                softmax_scale=self.softmax_scale,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                qkv[:, 0],\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                seqlen,\n                kv_scales=self.kv_scales,\n                hpu_attention_meta=hpu_attention_meta,\n            )\n\n        return self.dense(attn_output.view(-1, self.num_heads * self.head_size))\n\n\nclass FlashMLP(nn.Module):\n    def __init__(self, config, prefix, weights):\n        super().__init__()\n        act = config.hidden_act\n        self.act = (\n            ACT2FN[act]\n            if \"gelu\" not in act\n            else lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\" if act in [\"gelu_fast\", \"gelu_pytorch_tanh\"] else \"none\"\n                ),\n            )\n        )\n\n        self.dense_h_to_4h = TensorParallelColumnLinear.load(\n            config, prefix=f\"{prefix}.dense_h_to_4h\", weights=weights, bias=True\n        )\n        self.dense_4h_to_h = load_row(\n            config, prefix=f\"{prefix}.dense_4h_to_h\", weights=weights, bias=True\n        )\n\n    def forward(self, hidden_states):\n        hidden_states = self.dense_h_to_4h(hidden_states)\n        hidden_states = self.act(hidden_states)\n        hidden_states = self.dense_4h_to_h(hidden_states)\n        return hidden_states\n\n\nclass FlashNeoXLayer(nn.Module):\n    def __init__(self, layer_id, config, weights, rotary_emb):\n        super().__init__()\n\n        layer_norm_eps = config.layer_norm_eps\n\n        prefix = f\"gpt_neox.layers.{layer_id}\"\n\n        self.use_parallel_residual = config.use_parallel_residual\n    
    self.input_layernorm = FastLayerNorm.load(\n            prefix=f\"{prefix}.input_layernorm\", weights=weights, eps=layer_norm_eps\n        )\n        self.post_attention_layernorm = FastLayerNorm.load(\n            prefix=f\"{prefix}.post_attention_layernorm\",\n            weights=weights,\n            eps=layer_norm_eps,\n        )\n        self.attention = FlashNeoxAttention(\n            config,\n            prefix=f\"{prefix}.attention\",\n            weights=weights,\n            rotary_emb=rotary_emb,\n        )\n\n        self.mlp = FlashMLP(config, prefix=f\"{prefix}.mlp\", weights=weights)\n        self.process_group = weights.process_group\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        hpu_attention_meta,\n    ):\n        if self.use_parallel_residual:\n            ln1_hidden_states, _ = self.input_layernorm(hidden_states)\n\n            attn_output = self.attention(\n                ln1_hidden_states,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache,\n                slots,\n                seqlen,\n                hpu_attention_meta,\n            )\n\n            ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)\n\n            mlp_output = self.mlp(ln2_hidden_states)\n            intermediate = mlp_output + attn_output\n\n            if self.process_group.size() > 1:\n                torch.distributed.all_reduce(intermediate, group=self.process_group)\n\n            return intermediate + hidden_states, None\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n            hidden_states = self.attention(\n                hidden_states,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache,\n                slots,\n                
seqlen,\n                hpu_attention_meta,\n            )\n\n            hidden_states, residual = self.post_attention_layernorm(\n                hidden_states, residual\n            )\n\n            mlp_output = self.mlp(hidden_states)\n\n            return mlp_output, residual\n\n\nclass FlashGPTNeoXPreTrainedModel(PreTrainedModel):\n    config_class = GPTNeoXConfig\n    base_model_prefix = \"gpt_neox\"\n    supports_gradient_checkpointing = False\n    _no_split_modules = None\n\n\nclass FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__(config)\n        self.config = config\n\n        self.embed_in = TensorParallelEmbedding(\n            prefix=f\"{prefix}.embed_in\", weights=weights\n        )\n\n        rotary_emb = PositionRotaryEmbedding.static(\n            config=config,\n            dim=int(\n                config.rotary_pct * (config.hidden_size // config.num_attention_heads)\n            ),\n            base=config.rotary_emb_base,\n            device=weights.device,\n        )\n\n        self.layers = nn.ModuleList(\n            [\n                FlashNeoXLayer(layer_id, config, weights, rotary_emb)\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n        self.final_layer_norm = FastLayerNorm.load(\n            prefix=f\"{prefix}.final_layer_norm\",\n            weights=weights,\n            eps=config.layer_norm_eps,\n        )\n\n        self.gradient_checkpointing = False\n\n        self.head_size = self.layers[0].attention.head_size\n        self.num_heads = self.layers[0].attention.num_heads\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: 
Optional[HPUPagedAttentionMetadata],\n    ) -> torch.Tensor:\n        if hpu_attention_meta is not None:\n            hpu_attention_meta = set_block_mapping(\n                hpu_attention_meta, input_ids.shape[0]\n            )\n        hidden_states = self.embed_in(input_ids)\n\n        # Get rotary cos and sin for this forward\n        # Avoid to index in each layer\n        cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(position_ids)\n\n        residual = None\n        lazy_mode = htorch.utils.internal.is_lazy()\n        if lazy_mode:\n            htorch.core.mark_step()\n        for i, layer in enumerate(self.layers):\n            hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                slots,\n                seqlen,\n                hpu_attention_meta,\n            )\n            if lazy_mode:\n                htorch.core.mark_step()\n\n        hidden_states, _ = self.final_layer_norm(hidden_states, residual)\n\n        return hidden_states\n\n\nclass FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):\n    def __init__(self, prefix, config, weights):\n        super().__init__(config)\n\n        if not prefix:\n            prefix = \"gpt_neox\"\n        else:\n            prefix = f\"{prefix}.gpt_neox\"\n\n        self.gpt_neox = FlashGPTNeoXModel(prefix, config, weights)\n\n        self.embed_out = SpeculativeHead.load(\n            config, prefix=\"embed_out\", weights=weights\n        )\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n        lm_head_indices: Optional[torch.Tensor] = None,\n       
 adapter_data: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        hidden_states = self.gpt_neox(\n            input_ids,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            hpu_attention_meta,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits = self.embed_out(hidden_states)\n        return logits\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2024 HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.distributed\nfrom torch import nn\nfrom typing import Optional, List, Tuple\n\nfrom text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear\nfrom text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata\nfrom text_generation_server.models.custom_modeling.vlm import (\n    load_text_model,\n    load_vision_model,\n)\n\n\nclass PaliGemmaForConditionalGeneration(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        config.vision_config.quantize = config.quantize\n        self.vision_tower = load_vision_model(\n            prefix=\"vision_tower\" if not prefix else f\"{prefix}.vision_tower\",\n            config=config.vision_config,\n            weights=weights,\n        )\n        self.post_vision_tower_layernorm = nn.LayerNorm.load(\n            prefix=\"vision_tower.vision_model.post_layernorm\",\n            weights=weights,\n            eps=config.vision_config.layer_norm_eps,\n        )\n\n        self.multi_modal_projector = TensorParallelColumnLinear.load(\n            config,\n            prefix=\"multi_modal_projector.linear\",\n            weights=weights,\n            bias=True,\n        )\n\n        self.vocab_size = config.text_config.vocab_size\n        self.config = config\n\n        
text_config = config.text_config\n        text_config.speculator = config.speculator\n        text_config.quantize = config.quantize\n        self.text_model = load_text_model(\n            prefix=\"language_model\" if not prefix else f\"{prefix}.language_model\",\n            config=config.text_config,\n            weights=weights,\n        )\n        self.pad_token_id = (\n            config.pad_token_id if config.pad_token_id is not None else -1\n        )\n        self.dtype = weights.dtype\n\n    def get_vision_embeds(\n        self,\n        pixel_values: torch.FloatTensor,\n        pixel_attention_mask: Optional[torch.FloatTensor] = None,\n        image_sizes: Optional[torch.Tensor] = None,\n        image_grid_thw: Optional[torch.LongTensor] = None,\n    ):\n        pixel_values = pixel_values.to(dtype=self.dtype)\n        image_outputs = self.vision_tower(pixel_values)\n        last_hidden_state = self.post_vision_tower_layernorm(\n            image_outputs.last_hidden_state\n        )\n        image_features = self.multi_modal_projector(last_hidden_state)\n        image_features = image_features.view(-1, image_features.shape[-1])\n        return image_features\n\n    def get_inputs_embeds(\n        self,\n        input_ids: torch.Tensor,\n        vision_embeds: torch.Tensor = None,\n    ):\n        inputs_embeds = self.text_model.embed_tokens(input_ids)\n\n        if vision_embeds is not None:\n            mask = input_ids == self.config.image_token_index\n            inputs_embeds[mask] = vision_embeds\n\n        return inputs_embeds\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        attention_mask: 
Optional[torch.BoolTensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        # TODO This is odd but apparently pali gemma position ids start at 1.\n        if cu_seqlen_prefill is not None:\n            position_ids += 1\n\n        hidden_states = self.text_model.model(\n            inputs_embeds=inputs_embeds,\n            position_ids=position_ids,\n            cu_seqlen_prefill=cu_seqlen_prefill,\n            kv_cache=kv_cache,\n            slots=slots,\n            seqlen=seqlen,\n            hpu_attention_meta=hpu_attention_meta,\n            adapter_data=adapter_data,\n        )\n\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.text_model.lm_head(hidden_states)\n\n        return logits, speculative_logits\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py",
    "content": "import torch\nimport torch.distributed\n\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom transformers.configuration_utils import PretrainedConfig\nfrom typing import Optional, List, Tuple\n\nfrom text_generation_server.layers.attention import (\n    paged_attention,\n    attention,\n    set_block_mapping,\n    Seqlen,\n    HPUPagedAttentionMetadata,\n)\nfrom text_generation_server.layers import (\n    TensorParallelRowLinear,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    SpeculativeHead,\n    get_linear,\n)\nfrom text_generation_server.layers.attention.kv_cache import get_kv_scales\nfrom text_generation_server.layers.layernorm import (\n    FastLayerNorm,\n)\nfrom text_generation_server.layers.rotary import (\n    PositionRotaryEmbedding,\n)\nimport habana_frameworks.torch as htorch\n\n\nclass PhiConfig(PretrainedConfig):\n    def __init__(\n        self,\n        vocab_size=51200,\n        hidden_size=2560,\n        num_hidden_layers=32,\n        num_attention_heads=32,\n        num_key_value_heads=32,\n        hidden_act=\"gelu_fast\",  # llama uses silu\n        layer_norm_eps=1e-05,  # rms in llama,\n        pad_token_id=0,\n        bos_token_id=1,\n        eos_token_id=2,\n        tie_word_embeddings=False,\n        rope_theta=10000.0,\n        resid_pdrop=0.1,  # llama doesn't have this\n        partial_rotary_factor=0.5,  # important difference between llama and phi\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.hidden_size = hidden_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.layer_norm_eps = layer_norm_eps\n        self.rope_theta = rope_theta\n        self.resid_pdrop = resid_pdrop\n        self.partial_rotary_factor = partial_rotary_factor\n\n        super().__init__(\n            
pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n\n# this is the same as llama except for Phi uses bias=True\ndef load_attention(config, prefix, weights):\n    if config.num_attention_heads != config.num_key_value_heads:\n        return _load_gqa(config, prefix, weights)\n    else:\n        return TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n            dim=0,\n            weights=weights,\n            bias=True,\n        )\n\n\ndef _load_gqa(config, prefix: str, weights):\n    assert config.hidden_size % config.num_attention_heads == 0\n    assert config.num_attention_heads % weights.process_group.size() == 0\n\n    weight = weights.get_multi_weights_col(\n        prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n        dim=0,\n    )\n\n    if config.quantize not in [\"gptq\", \"awq\"]:\n        weight = weight.to(dtype=weights.dtype).to(device=weights.device)\n\n        head_size = config.hidden_size // config.num_attention_heads\n        num_heads = config.num_attention_heads // weights.process_group.size()\n        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()\n        assert list(weight.shape) == [\n            (num_heads + 2 * num_key_value_heads) * head_size,\n            config.hidden_size,\n        ], f\"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}\"\n\n    # this is the same as llama except for Phi uses bias=True\n    return TensorParallelColumnLinear(get_linear(weight, bias=True))\n\n\nclass FlashPhiAttention(torch.nn.Module):\n    def __init__(\n        self,\n        prefix: str,\n        config,\n        weights,\n        rotary_emb,\n    ):\n        super().__init__()\n        
self.num_heads = config.num_attention_heads\n        self.hidden_size = config.hidden_size\n        self.head_size = self.hidden_size // self.num_heads\n\n        self.softmax_scale = self.head_size**-0.5\n        self.rotary_dim = int(config.partial_rotary_factor * self.head_size)\n        self.rotary_emb = rotary_emb\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.num_key_value_heads = (\n            config.num_key_value_heads // weights.process_group.size()\n        )\n\n        self.query_key_value = load_attention(config, prefix, weights)\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n\n        # in llama the dense layer is called \"o_proj\" and has bias=False\n        self.dense = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.dense\",\n            weights=weights,\n            bias=True,\n        )\n        self.num_groups = self.num_heads // self.num_key_value_heads\n        self.kv_head_mapping = torch.arange(\n            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device\n        ).repeat_interleave(self.num_groups)\n\n    def forward(\n        self,\n        hidden_states,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        hpu_attention_meta,\n    ):\n        # Compute query, key, value and split\n        qkv = self.query_key_value(hidden_states)\n        query, kv = qkv.split(\n            [\n                self.head_size * self.num_heads,\n                2 * self.head_size * self.num_key_value_heads,\n            ],\n            dim=1,\n        )\n\n        # Reshape query and key for rotary 
embeddings\n        query = query.view(-1, self.num_heads, self.head_size)\n        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)\n\n        # NOTE: this is the main difference between Llama and Phi\n        # in llama the rotary embeddings are applied to the whole query and key.\n        # Phi uses PARTIAL rotary embeddings, which are applied to the first 32 dimensions\n        #\n        # Apply partial positional embeddings in place\n        self.rotary_emb(\n            query[:, :, : self.rotary_dim], kv[:, 0, :, : self.rotary_dim], cos, sin\n        )\n\n        # Reshape key and value and cache\n        kv_cache.store(\n            key=kv[:, 0],\n            value=kv[:, 1],\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            attn_output = attention(\n                query=query,\n                key=kv[:, 0],\n                value=kv[:, 1],\n                kv_scales=self.kv_scales,\n                kv_cache=kv_cache,\n                seqlen=seqlen,\n                softmax_scale=self.softmax_scale,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                seqlen,\n                kv_scales=self.kv_scales,\n                hpu_attention_meta=hpu_attention_meta,\n            )\n\n        return self.dense(attn_output.view(-1, self.num_heads * self.head_size))\n\n\nclass PhiMLP(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        act = config.hidden_act\n        self.act = (\n            ACT2FN[act]\n            if \"gelu\" not in act\n            else lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\" if act in [\"gelu_fast\", \"gelu_pytorch_tanh\"] else 
\"none\"\n                ),\n            )\n        )\n\n        # llama weights are up_proj and down_proj and bias=False\n        self.up_proj = TensorParallelColumnLinear.load(\n            config,\n            prefix=f\"{prefix}.fc1\",\n            weights=weights,\n            bias=True,\n        )\n        self.down_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.fc2\",\n            weights=weights,\n            bias=True,\n        )\n\n    def forward(self, hidden_states):\n        # NOTE: Llama requires the gate up states to an intermediate size\n        # Phi does not and we can avoid the `view` operation\n        return self.down_proj(self.act(self.up_proj(hidden_states)))\n\n\nclass FlashPhiLayer(nn.Module):\n    def __init__(self, prefix: str, layer_id, config, weights, rotary_emb):\n        super().__init__()\n        prefix = f\"{prefix}.layers.{layer_id}\"\n        self.self_attn = FlashPhiAttention(\n            prefix=f\"{prefix}.self_attn\",\n            config=config,\n            weights=weights,\n            rotary_emb=rotary_emb,\n        )\n        self.mlp = PhiMLP(prefix=f\"{prefix}.mlp\", config=config, weights=weights)\n        self.input_layernorm = FastLayerNorm.load(\n            prefix=f\"{prefix}.input_layernorm\",\n            weights=weights,\n            eps=config.layer_norm_eps,\n        )\n        self.resid_dropout = torch.nn.Dropout(config.resid_pdrop)\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        hpu_attention_meta,\n    ):\n        hidden_states, res = self.input_layernorm(hidden_states, residual)\n        # Self Attention\n        attn_output = self.self_attn(\n            hidden_states,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            
hpu_attention_meta,\n        )\n\n        hidden_states = self.resid_dropout(attn_output).add(\n            self.resid_dropout(self.mlp(hidden_states))\n        )\n\n        return hidden_states, res\n\n\nclass FlashPhiModel(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n\n        process_group = weights.process_group\n        self.tp_rank = process_group.rank()\n        self.tp_world_size = process_group.size()\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=f\"{prefix}.embed_tokens\", weights=weights\n        )\n        rotary_emb = PositionRotaryEmbedding.static(\n            config=config,\n            dim=int(\n                config.partial_rotary_factor\n                * (config.hidden_size // config.num_attention_heads)\n            ),\n            base=config.rope_theta,\n            device=weights.device,\n        )\n\n        self.layers = nn.ModuleList(\n            [\n                FlashPhiLayer(\n                    prefix,\n                    layer_id,\n                    config,\n                    weights,\n                    rotary_emb,\n                )\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n        self.gradient_checkpointing = False\n\n        self.head_size = self.layers[0].self_attn.head_size\n        self.num_heads = self.layers[0].self_attn.num_heads\n        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads\n\n        self.norm = FastLayerNorm.load(\n            prefix=\"model.final_layernorm\",\n            weights=weights,\n            eps=config.layer_norm_eps,\n        )\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: 
Optional[HPUPagedAttentionMetadata],\n    ) -> torch.Tensor:\n        if hpu_attention_meta is not None:\n            hpu_attention_meta = set_block_mapping(\n                hpu_attention_meta, input_ids.shape[0]\n            )\n        hidden_states = self.embed_tokens(input_ids)\n\n        # Get rotary cos and sin for this forward\n        # Avoid to index in each layer\n        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)\n\n        residual = None\n        lazy_mode = htorch.utils.internal.is_lazy()\n        if lazy_mode:\n            htorch.core.mark_step()\n        for i, layer in enumerate(self.layers):\n            hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                slots,\n                seqlen,\n                hpu_attention_meta,\n            )\n            if lazy_mode:\n                htorch.core.mark_step()\n\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n\nclass FlashPhiForCausalLM(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n\n        if not prefix:\n            prefix = \"model\"\n        else:\n            prefix = f\"{prefix}.model\"\n\n        self.model = FlashPhiModel(prefix, config, weights)\n        self.lm_head = SpeculativeHead.load(\n            config,\n            prefix=\"lm_head\",\n            weights=weights,\n        )\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: 
Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        hidden_states = self.model(\n            input_ids,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            hpu_attention_meta,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n\n        return self.lm_head(hidden_states)\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/flash_phi_moe_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"PyTorch Phi-MoE model.\"\"\"\n\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.utils import logging\n\nlogger = logging.get_logger(__name__)\n\n\nPHIMOE_PRETRAINED_CONFIG_ARCHIVE_MAP = {\n    \"microsoft/Phi-3.5-MoE-instruct\": \"https://huggingface.co/microsoft/Phi-3.5-MoE-instruct/resolve/main/config.json\",\n}\n\n\nclass PhiMoEConfig(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`PhiMoEModel`]. It is used to instantiate a Phi-MoE\n    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the\n    defaults will yield a similar configuration to that of the\n    [microsoft/Phi-3.5-MoE-instruct](https://huggingface.co/microsoft/Phi-3.5-MoE-instruct).\n\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the\n    documentation from [`PretrainedConfig`] for more information.\n\n\n    Args:\n        vocab_size (`int`, *optional*, defaults to 32064):\n            Vocabulary size of the PhiMoE model. 
Defines the number of different tokens that can be represented by the\n            `input_ids` passed when calling [`PhiMoEModel`]\n        hidden_size (`int`, *optional*, defaults to 4096):\n            Dimension of the hidden representations.\n        intermediate_size (`int`, *optional*, defaults to 6400):\n            Dimension of the MLP representations.\n        num_hidden_layers (`int`, *optional*, defaults to 32):\n            Number of hidden layers in the Transformer encoder.\n        num_attention_heads (`int`, *optional*, defaults to 32):\n            Number of attention heads for each attention layer in the Transformer encoder.\n        num_key_value_heads (`int`, *optional*, defaults to 8):\n            This is the number of key_value heads that should be used to implement Grouped Query Attention. If\n            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if\n            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA), otherwise GQA is used. When\n            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed\n            by meanpooling all the original heads within that group. For more details check out [this\n            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.\n        hidden_act (`str` or `function`, *optional*, defaults to `\"silu\"`):\n            The non-linear activation function (function or string) in the decoder.\n        max_position_embeddings (`int`, *optional*, defaults to `4096*32`):\n            The maximum sequence length that this model might ever be used with. 
Mixtral's sliding window attention\n            allows sequences of up to 4096*32 tokens.\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        rms_norm_eps (`float`, *optional*, defaults to 1e-05):\n            The epsilon used by the rms normalization layers.\n        use_cache (`bool`, *optional*, defaults to `True`):\n            Whether or not the model should return the last key/values attentions (not used by all models). Only\n            relevant if `config.is_decoder=True`.\n        pad_token_id (`int`, *optional*):\n            The id of the padding token.\n        bos_token_id (`int`, *optional*, defaults to 1):\n            The id of the \"beginning-of-sequence\" token.\n        eos_token_id (`int`, *optional*, defaults to 2):\n            The id of the \"end-of-sequence\" token.\n        tie_word_embeddings (`bool`, *optional*, defaults to `False`):\n            Whether the model's input and output word embeddings should be tied.\n        rope_theta (`float`, *optional*, defaults to 1000000.0):\n            The base period of the RoPE embeddings.\n        rope_scaling (`dict`, *optional*):\n            The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must\n            contain the following keys: `type`, `short_factor`, `long_factor`, `short_mscale`, `long_mscale` and\n            `original_max_position_embeddings`. The `type` must be `longrope`, the `short_mscale` and `long_mscale` must\n            be numbers, the `short_factor` and `long_factor` must be lists of numbers with the same length as half of\n            the attention head size and the `original_max_position_embeddings` must be an integer.\n        sliding_window (`int`, *optional*):\n            Sliding window attention window size. 
If not specified, will default to `262144`.\n        attention_dropout (`float`, *optional*, defaults to 0.0):\n            The dropout ratio for the attention probabilities.\n        num_experts_per_tok (`int`, *optional*, defaults to 2):\n            The number of experts to route per-token, can be also interpreted as the `top-k` routing\n            parameter.\n        num_local_experts (`int`, *optional*, defaults to 16):\n            Number of experts per Sparse MLP layer.\n        output_router_logits (`bool`, *optional*, defaults to `False`):\n            Whether or not the router logits should be returned by the model. Enabling this will also\n            allow the model to output the auxiliary loss. See [here]() for more details.\n        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):\n            The aux loss factor for the total loss.\n        router_jitter_noise (`float`, *optional*, defaults to 0.01):\n            Amount of noise to add to the router.\n\n    ```python\n    >>> from transformers import PhiMoEModel, PhiMoEConfig\n\n    >>> # Initializing a Phi-3 style configuration\n    >>> configuration = PhiMoEConfig.from_pretrained(\"microsoft/Phi-3.5-MoE-instruct\")\n\n    >>> # Initializing a model from the configuration\n    >>> model = PhiMoEModel(configuration)\n\n    >>> # Accessing the model configuration\n    >>> configuration = model.config\n    ```\"\"\"\n\n    model_type = \"phimoe\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    def __init__(\n        self,\n        vocab_size=32064,\n        hidden_size=4096,\n        intermediate_size=6400,\n        num_hidden_layers=32,\n        num_attention_heads=32,\n        num_key_value_heads=8,\n        hidden_act=\"silu\",\n        max_position_embeddings=4096 * 32,\n        initializer_range=0.02,\n        rms_norm_eps=1e-5,\n        use_cache=True,\n        pad_token_id=None,\n        bos_token_id=1,\n        eos_token_id=2,\n        tie_word_embeddings=False,\n     
   rope_theta=1e6,\n        rope_scaling=None,\n        sliding_window=None,\n        attention_dropout=0.0,\n        num_experts_per_tok=2,\n        num_local_experts=16,\n        output_router_logits=False,\n        router_aux_loss_coef=0.001,\n        router_jitter_noise=0.01,\n        input_jitter_noise=0.0,\n        attention_bias=False,\n        lm_head_bias=False,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.sliding_window = sliding_window\n        self.attention_bias = attention_bias\n        self.lm_head_bias = lm_head_bias\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.attention_dropout = attention_dropout\n\n        self.num_experts_per_tok = num_experts_per_tok\n        self.num_local_experts = num_local_experts\n        self.output_router_logits = output_router_logits\n        self.router_aux_loss_coef = router_aux_loss_coef\n        self.router_jitter_noise = router_jitter_noise\n        self.input_jitter_noise = input_jitter_noise\n\n        self.rope_scaling = rope_scaling\n        self._rope_scaling_validation()\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n    def _rope_scaling_validation(self):\n        
\"\"\"\n        Validate the `rope_scaling` configuration.\n        \"\"\"\n        if self.rope_scaling is None:\n            return\n\n        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 6:\n            raise ValueError(\n                \"`rope_scaling` must be a dictionary with six fields, `type`, `short_factor`, `long_factor`, \"\n                f\"`short_mscale`, `long_mscale` and `original_max_position_embeddings`, got {self.rope_scaling}\"\n            )\n        rope_scaling_type = self.rope_scaling.get(\"type\", None)\n        rope_scaling_short_factor = self.rope_scaling.get(\"short_factor\", None)\n        rope_scaling_long_factor = self.rope_scaling.get(\"long_factor\", None)\n        rope_scaling_short_mscale = self.rope_scaling.get(\"short_mscale\", None)\n        rope_scaling_long_mscale = self.rope_scaling.get(\"long_mscale\", None)\n        original_max_position_embeddings = self.rope_scaling.get(\n            \"original_max_position_embeddings\", None\n        )\n        if rope_scaling_type is None or rope_scaling_type not in [\"longrope\"]:\n            raise ValueError(\n                f\"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}\"\n            )\n        if not (\n            isinstance(rope_scaling_short_factor, list)\n            and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)\n        ):\n            raise ValueError(\n                f\"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}\"\n            )\n        if (\n            not len(rope_scaling_short_factor)\n            == self.hidden_size // self.num_attention_heads // 2\n        ):\n            raise ValueError(\n                f\"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}\"\n            )\n        if not (\n            
isinstance(rope_scaling_long_factor, list)\n            and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor)\n        ):\n            raise ValueError(\n                f\"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}\"\n            )\n        if (\n            not len(rope_scaling_long_factor)\n            == self.hidden_size // self.num_attention_heads // 2\n        ):\n            raise ValueError(\n                f\"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}\"\n            )\n        if not isinstance(rope_scaling_short_mscale, (int, float)):\n            raise ValueError(\n                f\"`rope_scaling`'s short_mscale field must be a number, got {rope_scaling_short_mscale}\"\n            )\n        if not isinstance(rope_scaling_long_mscale, (int, float)):\n            raise ValueError(\n                f\"`rope_scaling`'s long_mscale field must be a number, got {rope_scaling_long_mscale}\"\n            )\n        if not isinstance(original_max_position_embeddings, int):\n            raise ValueError(\n                f\"`rope_scaling`'s original_max_position_embeddings field must be an integer, got {original_max_position_embeddings}\"\n            )\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py",
    "content": "import torch\nimport torch.distributed\n\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom typing import Optional, List, Tuple\n\nfrom text_generation_server.layers.attention import (\n    paged_attention,\n    attention,\n    set_block_mapping,\n    Seqlen,\n    HPUPagedAttentionMetadata,\n)\nfrom text_generation_server.layers import (\n    TensorParallelRowLinear,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    SpeculativeHead,\n)\nfrom text_generation_server.layers.attention.kv_cache import get_kv_scales\nfrom text_generation_server.layers.rotary import PositionRotaryEmbedding\nfrom text_generation_server.layers.layernorm import (\n    FastRMSNorm,\n)\nimport habana_frameworks.torch as htorch\n\n\ndef load_attention(config, prefix, weights):\n    if config.num_attention_heads != config.num_key_value_heads:\n        return _load_gqa(config, prefix, weights)\n    else:\n        return TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n            dim=0,\n            weights=weights,\n            bias=True,\n        )\n\n\ndef _load_gqa(config, prefix: str, weights):\n    assert config.hidden_size % config.num_attention_heads == 0\n    assert config.num_attention_heads % weights.process_group.size() == 0\n\n    return TensorParallelColumnLinear.load_multi(\n        config,\n        prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n        dim=0,\n        weights=weights,\n        bias=True,\n    )\n\n\nclass Qwen2Attention(torch.nn.Module):\n    def __init__(\n        self,\n        prefix: str,\n        config,\n        weights,\n        rotary_emb,\n    ):\n        super().__init__()\n        self.max_past = (\n            config.sliding_window\n            if config.use_sliding_window and config.sliding_window is not None\n            else -1\n        )\n        self.num_heads = 
config.num_attention_heads\n        self.hidden_size = config.hidden_size\n        self.head_size = self.hidden_size // self.num_heads\n        self.rotary_emb = rotary_emb\n\n        self.softmax_scale = self.head_size**-0.5\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.num_key_value_heads = (\n            config.num_key_value_heads // weights.process_group.size()\n        )\n\n        self.query_key_value = load_attention(config, prefix, weights)\n\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n\n        self.o_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.o_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.num_groups = self.num_heads // self.num_key_value_heads\n        self.kv_head_mapping = torch.arange(\n            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device\n        ).repeat_interleave(self.num_groups)\n\n    def forward(\n        self,\n        hidden_states,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        hpu_attention_meta,\n    ):\n        qkv = self.query_key_value(hidden_states)\n        query, kv = qkv.split(\n            [\n                self.head_size * self.num_heads,\n                2 * self.head_size * self.num_key_value_heads,\n            ],\n            dim=1,\n        )\n        query = query.view(-1, self.num_heads, self.head_size)\n        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)\n\n        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)\n\n        kv_cache.store(\n            
key=kv[:, 0],\n            value=kv[:, 1],\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # sdpa\n            attn_output = attention(\n                query=query,\n                key=kv[:, 0],\n                value=kv[:, 1],\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                softmax_scale=self.softmax_scale,\n                window_size_left=self.max_past,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                seqlen,\n                kv_scales=self.kv_scales,\n                hpu_attention_meta=hpu_attention_meta,\n                window_size_left=self.max_past,\n            )\n\n        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))\n\n\nclass Qwen2MLP(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        act = config.hidden_act\n        self.act = (\n            ACT2FN[act]\n            if \"gelu\" not in act\n            else lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\" if act in [\"gelu_fast\", \"gelu_pytorch_tanh\"] else \"none\"\n                ),\n            )\n        )\n        # Fuse gate and up proj\n        self.gate_up_proj = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.gate_proj\", f\"{prefix}.up_proj\"],\n            weights=weights,\n            dim=0,\n            bias=False,\n        )\n        self.down_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.down_proj\",\n            weights=weights,\n            bias=False,\n        )\n        
self.intermediate_size = (\n            config.intermediate_size // weights.process_group.size()\n        )\n\n    def forward(self, hidden_states):\n        gate_up_states = self.gate_up_proj(hidden_states)\n        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)\n        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])\n\n\nclass Qwen2Layer(nn.Module):\n    def __init__(self, prefix, layer_id, config, weights, rotary_emb):\n        super().__init__()\n        prefix = f\"{prefix}.layers.{layer_id}\"\n        self.self_attn = Qwen2Attention(\n            prefix=f\"{prefix}.self_attn\",\n            config=config,\n            weights=weights,\n            rotary_emb=rotary_emb,\n        )\n        self.mlp = Qwen2MLP(prefix=f\"{prefix}.mlp\", config=config, weights=weights)\n        self.input_layernorm = FastRMSNorm.load(\n            prefix=f\"{prefix}.input_layernorm\", weights=weights, eps=config.rms_norm_eps\n        )\n        self.post_attention_layernorm = FastRMSNorm.load(\n            prefix=f\"{prefix}.post_attention_layernorm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        hpu_attention_meta,\n    ):\n        normed_hidden_states, residual = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        attn_output = self.self_attn(\n            normed_hidden_states,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            hpu_attention_meta,\n        )\n        hidden_states = attn_output + residual\n\n        # faster post attention rms norm\n        hidden_states, residual = self.post_attention_layernorm(hidden_states)\n        mlp_output = self.mlp(hidden_states)\n        
hidden_states = mlp_output + residual\n        return hidden_states\n\n\nclass Qwen2Model(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n\n        prefix = f\"{prefix}.model\" if prefix else \"model\"\n\n        process_group = weights.process_group\n        self.tp_rank = process_group.rank()\n        self.tp_world_size = process_group.size()\n\n        rotary_emb = PositionRotaryEmbedding.static(\n            config=config,\n            dim=config.hidden_size // config.num_attention_heads,\n            base=config.rope_theta,\n            device=weights.device,\n        )\n\n        self.layers = nn.ModuleList(\n            [\n                Qwen2Layer(\n                    prefix,\n                    layer_id,\n                    config,\n                    weights,\n                    rotary_emb,\n                )\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n        self.norm = FastRMSNorm.load(\n            prefix=f\"{prefix}.norm\", weights=weights, eps=config.rms_norm_eps\n        )\n\n        self.gradient_checkpointing = False\n\n        self.head_size = self.layers[0].self_attn.head_size\n        self.num_heads = self.layers[0].self_attn.num_heads\n        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n    ) -> torch.Tensor:\n        if hpu_attention_meta is not None:\n            hpu_attention_meta = set_block_mapping(\n                hpu_attention_meta, inputs_embeds.shape[0]\n            )\n        hidden_states = inputs_embeds\n\n        cos, sin = 
self.layers[0].self_attn.rotary_emb.get_cos_sin(\n            position_ids,\n        )\n\n        residual = None\n        lazy_mode = htorch.utils.internal.is_lazy()\n        if lazy_mode:\n            htorch.core.mark_step()\n        for i, layer in enumerate(self.layers):\n            hidden_states = layer(\n                hidden_states,\n                residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                slots,\n                seqlen,\n                hpu_attention_meta,\n            )\n            if lazy_mode:\n                htorch.core.mark_step()\n\n        hidden_states, _ = self.norm(hidden_states)\n\n        return hidden_states\n\n\nclass Qwen2ForCausalLM(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n\n        self.model = Qwen2Model(prefix, config, weights)\n\n        if config.tie_word_embeddings:\n            suffix = \"model.embed_tokens\"\n        else:\n            suffix = \"lm_head\"\n\n        self.lm_head = SpeculativeHead.load(\n            config,\n            prefix=f\"{prefix}.{suffix}\" if prefix else suffix,\n            weights=weights,\n        )\n\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=f\"{prefix}.embed_tokens\" if prefix else \"model.embed_tokens\",\n            weights=weights,\n        )\n\n        self.max_past = config.sliding_window\n        self.max_past_tensor = (\n            torch.tensor(config.sliding_window, device=weights.device)\n            if self.max_past is not None\n            else None\n        )\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n        
lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        inputs_embeds = self.embed_tokens(input_ids)\n\n        hidden_states = self.model(\n            inputs_embeds,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            hpu_attention_meta,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits = self.lm_head(hidden_states)\n        return logits\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_modeling.py",
    "content": "# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Optional, Tuple, List\n\nimport torch\nfrom torch import nn\nimport habana_frameworks.torch as htorch\nfrom text_generation_server.layers.attention import (\n    paged_attention,\n    attention,\n    set_block_mapping,\n    Seqlen,\n    HPUPagedAttentionMetadata,\n)\nfrom text_generation_server.layers.attention.kv_cache import get_kv_scales\nfrom text_generation_server.layers import (\n    TensorParallelEmbedding,\n    TensorParallelRowLinear,\n    TensorParallelColumnLinear,\n    SpeculativeHead,\n)\n\n\nfrom text_generation_server.layers.layernorm import (\n    FastRMSNorm,\n)\nfrom .flash_qwen2_modeling import Qwen2MLP\nfrom text_generation_server.layers.rotary import PositionRotaryEmbedding\n\n\nclass Qwen3Attention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config, prefix, weights, layer_idx, rotary_emb):\n        super().__init__()\n        self.config = config\n        self.layer_idx = layer_idx\n        self.head_dim = getattr(\n            config, \"head_dim\", config.hidden_size // config.num_attention_heads\n        )\n        self.num_key_value_groups = (\n            config.num_attention_heads // config.num_key_value_heads\n        )\n        self.num_heads = config.num_attention_heads\n        self.attention_dropout = config.attention_dropout\n        self.softmax_scale = self.head_dim**-0.5\n  
      self.rotary_emb = rotary_emb\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.num_key_value_heads = (\n            config.num_key_value_heads // weights.process_group.size()\n        )\n        self.query_key_value = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n            dim=0,\n            weights=weights,\n            bias=False,\n        )\n\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n\n        self.o_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.o_proj\",\n            weights=weights,\n            bias=False,\n        )\n\n        self.num_groups = self.num_heads // self.num_key_value_heads\n        self.kv_head_mapping = torch.arange(\n            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device\n        ).repeat_interleave(self.num_groups)\n\n        self.max_past = (\n            config.sliding_window if config.sliding_window is not None else -1\n        )\n\n        self.q_norm = FastRMSNorm.load(\n            prefix=f\"{prefix}.q_norm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n        self.k_norm = FastRMSNorm.load(\n            prefix=f\"{prefix}.k_norm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n        self.sliding_window = config.sliding_window\n        if not (\n            self.config.use_sliding_window\n            and getattr(self.config, \"sliding_window\", None) is not None\n            and self.layer_idx >= self.config.max_window_layers\n        ):\n    
        self.sliding_window = None\n\n    def forward(\n        self,\n        hidden_states,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        hpu_attention_meta,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        input_shape = hidden_states.shape[:-1]\n        hidden_shape = (*input_shape, -1, self.head_dim)\n        qkv = self.query_key_value(hidden_states)\n        query_states, key_states, value_states = qkv.split(\n            [\n                self.head_dim * self.num_heads,\n                self.head_dim * self.num_key_value_heads,\n                self.head_dim * self.num_key_value_heads,\n            ],\n            dim=1,\n        )\n\n        query_states, _ = self.q_norm(query_states.view(hidden_shape))\n        key_states, _ = self.k_norm(key_states.view(hidden_shape))\n        value_states = value_states.view(hidden_shape)\n        self.rotary_emb(query_states, key_states, cos, sin)\n\n        kv_cache.store(\n            key=key_states,\n            value=value_states,\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # sdpa\n            attn_output = attention(\n                query=query_states,\n                key=key_states,\n                value=value_states,\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                softmax_scale=self.softmax_scale,\n                window_size_left=self.max_past,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query_states,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                seqlen,\n                kv_scales=self.kv_scales,\n                hpu_attention_meta=hpu_attention_meta,\n                
window_size_left=self.max_past,\n            )\n\n        attn_output = attn_output.reshape(*input_shape, -1).contiguous()\n        return self.o_proj(attn_output)\n\n\nclass Qwen3DecoderLayer(nn.Module):\n    def __init__(self, config, prefix, weights, layer_idx: int, rotary_emb):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n        self.self_attn = Qwen3Attention(\n            config=config,\n            prefix=f\"{prefix}.self_attn\",\n            weights=weights,\n            layer_idx=layer_idx,\n            rotary_emb=rotary_emb,\n        )\n        self.mlp = Qwen2MLP(config=config, prefix=f\"{prefix}.mlp\", weights=weights)\n        self.input_layernorm = FastRMSNorm.load(\n            prefix=f\"{prefix}.input_layernorm\", weights=weights, eps=config.rms_norm_eps\n        )\n        self.post_attention_layernorm = FastRMSNorm.load(\n            prefix=f\"{prefix}.post_attention_layernorm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        hpu_attention_meta,\n    ) -> torch.Tensor:\n        residual = hidden_states\n        hidden_states, _ = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            hpu_attention_meta,\n        )\n\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states, _ = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n        return hidden_states\n\n\nclass Qwen3Model(nn.Module):\n    def 
__init__(self, config, prefix: str, weights):\n        super().__init__()\n        self.config = config\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        head_dim = getattr(\n            config, \"head_dim\", config.hidden_size // config.num_attention_heads\n        )\n        rotary_emb = PositionRotaryEmbedding.static(\n            config=config,\n            dim=head_dim,\n            base=config.rope_theta,\n            device=weights.device,\n        )\n\n        self.layers = nn.ModuleList(\n            [\n                Qwen3DecoderLayer(\n                    config=config,\n                    prefix=f\"{prefix}.layers.{layer_idx}\",\n                    weights=weights,\n                    layer_idx=layer_idx,\n                    rotary_emb=rotary_emb,\n                )\n                for layer_idx in range(config.num_hidden_layers)\n            ]\n        )\n        self.norm = FastRMSNorm.load(\n            prefix=f\"{prefix}.norm\", weights=weights, eps=config.rms_norm_eps\n        )\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n    ) -> torch.Tensor:\n        if hpu_attention_meta is not None:\n            hpu_attention_meta = set_block_mapping(\n                hpu_attention_meta, inputs_embeds.shape[0]\n            )\n        hidden_states = inputs_embeds\n\n        # create position embeddings to be shared across the decoder layers\n        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(\n            position_ids,\n        )\n\n        residual = None\n        lazy_mode = htorch.utils.internal.is_lazy()\n        if lazy_mode:\n            htorch.core.mark_step()\n\n        for i, 
decoder_layer in enumerate(self.layers):\n            hidden_states = decoder_layer(\n                hidden_states,\n                residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                slots,\n                seqlen,\n                hpu_attention_meta,\n            )\n            if lazy_mode:\n                htorch.core.mark_step()\n\n        hidden_states, _ = self.norm(hidden_states)\n\n        # add hidden states from the last decoder layer\n        return hidden_states\n\n\nclass Qwen3ForCausalLM(nn.Module):\n\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n        self.model = Qwen3Model(config=config, prefix=\"model\", weights=weights)\n        self.vocab_size = config.vocab_size\n        if config.tie_word_embeddings:\n            suffix = \"model.embed_tokens\"\n        else:\n            suffix = \"lm_head\"\n\n        self.lm_head = SpeculativeHead.load(\n            config,\n            prefix=f\"{prefix}.{suffix}\" if prefix else suffix,\n            weights=weights,\n        )\n\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=f\"{prefix}.embed_tokens\" if prefix else \"model.embed_tokens\",\n            weights=weights,\n        )\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        inputs_embeds = self.embed_tokens(input_ids)\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        hidden_states = self.model(\n            inputs_embeds,\n    
        position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            hpu_attention_meta,\n        )\n\n        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss\n\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits = self.lm_head(hidden_states)\n\n        return logits\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py",
    "content": "# coding=utf-8\n# Copyright 5 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import List, Optional, Tuple, Type\n\nimport torch\nfrom torch import nn\nimport torch.nn.functional as F\nfrom text_generation_server.layers.attention import (\n    attention,\n    paged_attention,\n    set_block_mapping,\n    Seqlen,\n    HPUPagedAttentionMetadata,\n)\nfrom text_generation_server.layers.attention.kv_cache import get_kv_scales\nfrom text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer\nfrom text_generation_server.layers import (\n    TensorParallelEmbedding,\n    TensorParallelColumnLinear,\n    TensorParallelRowLinear,\n    SpeculativeHead,\n    FastLinear,\n)\n\nfrom text_generation_server.layers.layernorm import (\n    FastRMSNorm,\n)\nfrom .flash_qwen2_modeling import Qwen2MLP\nfrom .flash_qwen3_modeling import Qwen3Attention\nfrom transformers.activations import ACT2FN\nfrom text_generation_server.layers.rotary import PositionRotaryEmbedding\n\n\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):\n    \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n    Args:\n        q 
(`torch.Tensor`): The query tensor.\n        k (`torch.Tensor`): The key tensor.\n        cos (`torch.Tensor`): The cosine part of the rotary embedding.\n        sin (`torch.Tensor`): The sine part of the rotary embedding.\n        position_ids (`torch.Tensor`, *optional*):\n            Deprecated and unused.\n        unsqueeze_dim (`int`, *optional*, defaults to 1):\n            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have\n            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n    Returns:\n        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n    \"\"\"\n    cos = cos.unsqueeze(unsqueeze_dim)\n    sin = sin.unsqueeze(unsqueeze_dim)\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\nclass Qwen3MoeAttention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config, prefix, weights, layer_idx, rotary_emb):\n        super().__init__()\n        self.config = config\n        self.layer_idx = layer_idx\n        self.head_dim = getattr(\n            config, \"head_dim\", config.hidden_size // config.num_attention_heads\n        )\n        self.num_key_value_heads = config.num_key_value_heads\n        self.num_key_value_groups = (\n            config.num_attention_heads // config.num_key_value_heads\n        )\n        
self.scaling = self.head_dim**-0.5\n        self.attention_dropout = config.attention_dropout\n        self.is_causal = True\n\n        self.q_proj = FastLinear.load(\n            config, f\"{prefix}.q_proj\", weights, bias=config.attention_bias\n        )\n\n        self.k_proj = FastLinear.load(\n            config, f\"{prefix}.k_proj\", weights, bias=config.attention_bias\n        )\n        self.v_proj = FastLinear.load(\n            config, f\"{prefix}.v_proj\", weights, bias=config.attention_bias\n        )\n        self.o_proj = FastLinear.load(\n            config, f\"{prefix}.o_proj\", weights, bias=config.attention_bias\n        )\n        self.rotary_emb = rotary_emb\n\n        self.q_norm = FastRMSNorm.load(\n            prefix=f\"{prefix}.q_norm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n\n        self.k_norm = FastRMSNorm.load(\n            prefix=f\"{prefix}.k_norm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n\n        self.max_past = (\n            config.sliding_window if config.sliding_window is not None else -1\n        )\n\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n        self.kv_head_mapping = torch.arange(\n            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device\n        ).repeat_interleave(self.num_key_value_groups)\n\n        self.sliding_window = config.sliding_window\n        if not (\n            self.config.use_sliding_window\n            and getattr(self.config, \"sliding_window\", None) is not None\n            and self.layer_idx >= self.config.max_window_layers\n        ):\n            self.sliding_window = None\n\n    def forward(\n        self,\n        hidden_states,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        hpu_attention_meta,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        input_shape 
= hidden_states.shape[:-1]\n        hidden_shape = (*input_shape, -1, self.head_dim)\n\n        query_states, _ = self.q_norm(self.q_proj(hidden_states).view(hidden_shape))\n        key_states, _ = self.k_norm(self.k_proj(hidden_states).view(hidden_shape))\n        value_states = self.v_proj(hidden_states).view(hidden_shape)\n\n        self.rotary_emb(query_states, key_states, cos, sin)\n        # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n\n        kv_cache.store(\n            key=key_states,\n            value=value_states,\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # sdpa\n            attn_output = attention(\n                query=query_states,\n                key=key_states,\n                value=value_states,\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                softmax_scale=self.scaling,\n                window_size_left=self.max_past,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query_states,\n                kv_cache,\n                self.kv_head_mapping,\n                self.scaling,\n                seqlen,\n                kv_scales=self.kv_scales,\n                hpu_attention_meta=hpu_attention_meta,\n                window_size_left=self.max_past,\n            )\n\n        attn_output = attn_output.reshape(*input_shape, -1).contiguous()\n        attn_output = self.o_proj(attn_output)\n        return attn_output\n\n\nclass Qwen3MoE(nn.Module):\n    def __init__(self, prefix, config, moe_layer_cls: Type[MoELayer], weights):\n        super().__init__()\n\n        # gating\n        self.gate = FastLinear.load(config, f\"{prefix}.gate\", weights, bias=False)\n\n        self.moe = moe_layer_cls(\n            n_expert_group=None,\n            
n_experts=config.num_experts,\n            prefix=f\"{prefix}.experts\",\n            renormalize=True,\n            topk=config.num_experts_per_tok,\n            topk_group=None,\n            weights=weights,\n        )\n        # gate_proj_name=\"w1\",\n        # up_proj_name=\"w3\",\n        # down_proj_name=\"w2\",\n\n        assert isinstance(self.moe, MoELayer)\n\n        self.process_group = weights.process_group\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        router_logits = self.gate(x)\n        out = self.moe(x, gating_output=router_logits)\n\n        # Reduce sum\n        if self.process_group.size() > 1:\n            torch.distributed.all_reduce(out, group=self.process_group)\n\n        return out.view(*x.shape)\n\n\nclass Qwen3MoeMLP(nn.Module):\n    def __init__(self, prefix, config, weights, intermediate_size=None):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        self.intermediate_size = (\n            intermediate_size\n            if intermediate_size is not None\n            else config.intermediate_size\n        )\n        # Fuse gate and up proj\n        self.gate_up_proj = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.gate_proj\", f\"{prefix}.up_proj\"],\n            weights=weights,\n            dim=0,\n            bias=False,\n        )\n        self.down_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.down_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.intermediate_size = (\n            config.intermediate_size // weights.process_group.size()\n        )\n\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x):\n        gate_up_states = self.gate_up_proj(x)\n        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)\n        return self.down_proj(self.act(gate_up_states[:, 0]) * 
gate_up_states[:, 1])\n\n\nclass Qwen3MoeSparseMoeBlock(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.num_experts = config.num_experts\n        self.top_k = config.num_experts_per_tok\n        self.norm_topk_prob = config.norm_topk_prob\n\n        # gating\n        # self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)\n        self.gate = FastLinear.load(config, f\"{prefix}.gate\", weights, bias=False)\n        self.experts = nn.ModuleList(\n            [\n                Qwen3MoeMLP(\n                    prefix=f\"{prefix}.experts.{i}\",\n                    config=config,\n                    weights=weights,\n                    intermediate_size=config.moe_intermediate_size,\n                )\n                for i in range(self.num_experts)\n            ]\n        )\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        \"\"\" \"\"\"\n        input_shape = hidden_states.shape\n        _, hidden_dim = hidden_states.shape\n        # hidden_states = hidden_states.view(-1, hidden_dim)\n        # router_logits: (batch * sequence_length, n_experts)\n        router_logits = self.gate(hidden_states)\n\n        routing_weights = F.softmax(router_logits, dim=1, dtype=hidden_states.dtype)\n        routing_weights, selected_experts = torch.topk(\n            routing_weights, self.top_k, dim=-1\n        )\n        if self.norm_topk_prob:  # only diff with mixtral sparse moe block!\n            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)\n        # we cast back to the input dtype\n        routing_weights = routing_weights.to(hidden_states.dtype)\n\n        final_hidden_states = torch.zeros(\n            (input_shape), dtype=hidden_states.dtype, device=hidden_states.device\n        )\n\n        # One hot encode the selected experts to create an expert mask\n        # this will be used to easily index which expert is going to be sollicitated\n        
expert_mask = torch.nn.functional.one_hot(\n            selected_experts, num_classes=self.num_experts\n        ).permute(2, 1, 0)\n        # Loop over all available experts in the model and perform the computation on each expert\n        for expert_idx in range(self.num_experts):\n            expert_layer = self.experts[expert_idx]\n            idx, top_x = torch.where(expert_mask[expert_idx])\n\n            # Index the correct hidden states and compute the expert hidden state for\n            # the current expert. We need to make sure to multiply the output hidden\n            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)\n            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)\n            current_hidden_states = (\n                expert_layer(current_state) * routing_weights[top_x, idx, None]\n            )\n\n            # However `index_add_` only support torch tensors for indexing so we'll use\n            # the `top_x` tensor here.\n            final_hidden_states.index_add_(\n                0, top_x, current_hidden_states.to(hidden_states.dtype)\n            )\n        final_hidden_states = final_hidden_states.reshape(input_shape)\n        return final_hidden_states\n\n\nclass Qwen3MoeDecoderLayer(nn.Module):\n    def __init__(self, config, prefix, weights, layer_idx: int, rotary_emb):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n\n        if config.num_key_value_heads // weights.process_group.size() > 0:\n            self.self_attn = Qwen3Attention(\n                config,\n                prefix=f\"{prefix}.self_attn\",\n                weights=weights,\n                layer_idx=layer_idx,\n                rotary_emb=rotary_emb,\n            )\n        else:\n            self.self_attn = Qwen3MoeAttention(\n                config,\n                prefix=f\"{prefix}.self_attn\",\n                weights=weights,\n                layer_idx=layer_idx,\n                
rotary_emb=rotary_emb,\n            )\n\n        moe_layer_cls = (\n            SparseMoELayer if SparseMoELayer.is_supported(weights) else DenseMoELayer\n        )\n\n        if (layer_idx not in config.mlp_only_layers) and (\n            config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0\n        ):\n            self.mlp = Qwen3MoE(f\"{prefix}.mlp\", config, moe_layer_cls, weights)\n            # self.mlp = Qwen3MoeSparseMoeBlock(f\"{prefix}.mlp\", config, weights)\n\n        else:\n            self.mlp = Qwen2MLP(config=config, prefix=f\"{prefix}.mlp\", weights=weights)\n\n        self.input_layernorm = FastRMSNorm.load(\n            prefix=f\"{prefix}.input_layernorm\", weights=weights, eps=config.rms_norm_eps\n        )\n        self.post_attention_layernorm = FastRMSNorm.load(\n            prefix=f\"{prefix}.post_attention_layernorm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        hpu_attention_meta,\n    ) -> torch.Tensor:\n        if residual is None:\n            residual = hidden_states\n\n        hidden_states, _ = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        hidden_states = self.self_attn(\n            hidden_states,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            hpu_attention_meta,\n        )\n\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states, _ = self.post_attention_layernorm(hidden_states)\n\n        hidden_states = self.mlp(hidden_states)\n\n        hidden_states = residual + hidden_states\n        return hidden_states\n\n\nclass Qwen3MoeModel(nn.Module):\n    def __init__(self, config, 
prefix: str, weights):\n        super().__init__()\n        self.config = config\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n        head_dim = getattr(\n            config, \"head_dim\", config.hidden_size // config.num_attention_heads\n        )\n        rotary_emb = PositionRotaryEmbedding.static(\n            config=config,\n            dim=head_dim,\n            base=config.rope_theta,\n            device=weights.device,\n        )\n\n        self.layers = nn.ModuleList(\n            [\n                Qwen3MoeDecoderLayer(\n                    config=config,\n                    prefix=f\"{prefix}.layers.{layer_idx}\",\n                    weights=weights,\n                    layer_idx=layer_idx,\n                    rotary_emb=rotary_emb,\n                )\n                for layer_idx in range(config.num_hidden_layers)\n            ]\n        )\n        self.norm = FastRMSNorm.load(\n            prefix=f\"{prefix}.norm\", weights=weights, eps=config.rms_norm_eps\n        )\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n    ) -> torch.Tensor:\n        if hpu_attention_meta is not None:\n            hpu_attention_meta = set_block_mapping(\n                hpu_attention_meta, inputs_embeds.shape[0]\n            )\n\n        hidden_states = inputs_embeds\n\n        # create position embeddings to be shared across the decoder layers\n        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(\n            position_ids,\n        )\n\n        residual = None\n        for i, decoder_layer in enumerate(self.layers):\n            hidden_states = decoder_layer(\n                hidden_states,\n                
residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                slots,\n                seqlen,\n                hpu_attention_meta,\n            )\n\n        hidden_states, _ = self.norm(hidden_states)\n\n        # add hidden states from the last decoder layer\n        return hidden_states\n\n\nclass Qwen3MoeForCausalLM(nn.Module):\n\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n        self.model = Qwen3MoeModel(config=config, prefix=\"model\", weights=weights)\n        self.vocab_size = config.vocab_size\n        if config.tie_word_embeddings:\n            suffix = \"model.embed_tokens\"\n        else:\n            suffix = \"lm_head\"\n\n        self.lm_head = SpeculativeHead.load(\n            config,\n            prefix=f\"{prefix}.{suffix}\" if prefix else suffix,\n            weights=weights,\n        )\n\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=f\"{prefix}.embed_tokens\" if prefix else \"model.embed_tokens\",\n            weights=weights,\n        )\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n\n        inputs_embeds = self.embed_tokens(input_ids)\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        hidden_states = self.model(\n            inputs_embeds,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            hpu_attention_meta,\n        )\n\n        # Only compute 
necessary logits, and do not upcast them to float if we are not computing the loss\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits = self.lm_head(hidden_states)\n\n        return logits\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py",
    "content": "from typing import List, Optional, Tuple\n\nimport torch\nimport torch.distributed\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.modeling_utils import PreTrainedModel\nfrom text_generation_server.layers import (\n    SpeculativeHead,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    TensorParallelRowLinear,\n    get_linear,\n)\nfrom text_generation_server.layers.attention.kv_cache import get_kv_scales\nfrom text_generation_server.layers.layernorm import FastLayerNorm\nfrom text_generation_server.layers.rotary import PositionRotaryEmbedding\nfrom text_generation_server.layers.attention import (\n    attention,\n    paged_attention,\n    set_block_mapping,\n    Seqlen,\n    HPUPagedAttentionMetadata,\n)\nimport habana_frameworks.torch as htorch\n\n\ndef load_row(config, prefix: str, weights, bias: bool):\n    weight = weights.get_weights_row(prefix)\n\n    if bias and weights.process_group.rank() == 0:\n        # Rank is only on the first rank process\n        bias = weights.get_tensor(f\"{prefix}.bias\")\n    else:\n        bias = None\n\n    linear = get_linear(weight, bias)\n    if config.parallel_attn:\n        return linear\n    else:\n        return TensorParallelRowLinear(linear, process_group=weights.process_group)\n\n\nclass RWConfig(PretrainedConfig):\n    attribute_map = {\n        \"num_hidden_layers\": \"n_layer\",\n        \"num_attention_heads\": \"n_head\",\n        \"num_key_value_heads\": \"n_head_kv\",\n    }\n\n    def __init__(\n        self,\n        model_type=\"RefinedWeb\",\n        vocab_size=250880,\n        hidden_size=64,\n        num_hidden_layers=None,\n        num_attention_heads=None,\n        num_ln_in_prallel_attention=None,\n        layer_norm_epsilon=1e-5,\n        initializer_range=0.02,\n        use_cache=True,\n        bos_token_id=1,\n        eos_token_id=2,\n        hidden_dropout=0.0,\n        attention_dropout=0.0,\n        
num_kv_heads=None,\n        multi_query=False,\n        alibi=False,\n        new_decoder_architecture=None,\n        bias=False,\n        parallel_attn=False,\n        rope_theta=10_000.0,\n        **kwargs,\n    ):\n        if alibi:\n            raise NotImplementedError(\n                \"alibi is not supported by this version of the model\"\n            )\n\n        self.model_type = model_type\n        self.alibi = False\n        self.rotary = True\n        self.rope_theta = rope_theta\n        self.max_position_embeddings = 2048\n\n        self.vocab_size = vocab_size\n        # Backward compatibility with n_embed kwarg\n        n_embed = kwargs.pop(\"n_embed\", None)\n        self.hidden_size = hidden_size if n_embed is None else n_embed\n        self.n_layer = (\n            num_hidden_layers\n            if num_hidden_layers is not None\n            else kwargs.pop(\"n_layer\", 2)\n        )\n        self.n_head = (\n            num_attention_heads\n            if num_attention_heads is not None\n            else kwargs.pop(\"n_head\", 8)\n        )\n        self.layer_norm_epsilon = layer_norm_epsilon\n        self.num_ln_in_parallel_attn = num_ln_in_prallel_attention\n        self.initializer_range = initializer_range\n        self.use_cache = use_cache\n        self.hidden_dropout = hidden_dropout\n        self.attention_dropout = attention_dropout\n        self.bias = bias\n        self.parallel_attn = parallel_attn\n\n        self.bos_token_id = bos_token_id\n        self.eos_token_id = eos_token_id\n\n        if num_kv_heads is not None:\n            self.n_head_kv = num_kv_heads\n        else:\n            old_n_head_kv = kwargs.pop(\"n_head_kv\", None)\n            if old_n_head_kv is not None:\n                self.n_head_kv = old_n_head_kv\n            else:\n                self.n_head_kv = 1 if multi_query else self.n_head\n\n        if new_decoder_architecture is not None:\n            self.new_decoder_architecture = 
new_decoder_architecture\n        elif model_type == \"RefinedWeb\":\n            self.new_decoder_architecture = True\n        else:\n            self.new_decoder_architecture = False\n\n        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)\n\n\nclass FlashRWAttention(torch.nn.Module):\n    def __init__(\n        self,\n        config,\n        prefix: str,\n        weights,\n        rotary_emb,\n    ):\n        super().__init__()\n        self.num_heads = config.n_head\n        self.num_heads_kv = config.n_head_kv\n        self.hidden_size = config.hidden_size\n        self.head_size = self.hidden_size // self.num_heads\n        self.rope_theta = config.rope_theta\n        self.rotary_emb = rotary_emb\n\n        self.softmax_scale = self.head_size ** (-0.5)\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n\n        self.query_key_value = TensorParallelColumnLinear.load(\n            config,\n            prefix=f\"{prefix}.query_key_value\",\n            weights=weights,\n            bias=config.bias,\n        )\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n        self.dense = load_row(\n            config, prefix=f\"{prefix}.dense\", weights=weights, bias=config.bias\n        )\n\n        if self.num_heads_kv == 1:\n            self.kv_head_mapping = torch.zeros(\n                self.num_heads, dtype=torch.int32, device=weights.device\n            )\n        else:\n            self.kv_head_mapping = torch.arange(\n                0, self.num_heads, dtype=torch.int32, device=weights.device\n            )\n\n    def forward(\n        self,\n        hidden_states,\n        cos,\n        sin,\n        
cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        hpu_attention_meta,\n    ):\n        qkv = self.query_key_value(hidden_states)\n\n        # Split query from key_value\n        query, kv = qkv.split(\n            [self.head_size * self.num_heads, 2 * self.head_size * self.num_heads_kv],\n            dim=1,\n        )\n\n        # Prepare query and key_value for indexing\n        query = query.view(-1, self.num_heads, self.head_size)\n        kv = kv.view(-1, 2, self.num_heads_kv, self.head_size)\n\n        # Inplace rotary\n        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)\n\n        kv_cache.store(\n            key=kv[:, 0],\n            value=kv[:, 1],\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # sdpa\n            attn_output = attention(\n                query=query,\n                key=kv[:, 0],\n                value=kv[:, 1],\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                softmax_scale=self.softmax_scale,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                seqlen,\n                kv_scales=self.kv_scales,\n                hpu_attention_meta=hpu_attention_meta,\n            )\n\n        return self.dense(attn_output.view(-1, self.num_heads * self.head_size))\n\n\nclass FlashRWLargeAttention(torch.nn.Module):\n    def __init__(\n        self,\n        config,\n        prefix: str,\n        weights,\n        rotary_emb,\n    ):\n        super().__init__()\n\n        hidden_size = config.hidden_size\n        num_heads = config.n_head\n        # num_heads_kv = config.n_head_kv\n        num_groups = config.n_head_kv\n\n        
self.hidden_size = hidden_size\n        self.head_size = hidden_size // num_heads\n        self.num_groups = num_groups\n        self.rope_theta = config.rope_theta\n        self.rotary_emb = rotary_emb\n\n        self.softmax_scale = self.head_size ** (-0.5)\n\n        # self.num_groups = num_heads // (num_heads_kv * 2)\n        self.num_heads = num_heads // self.num_groups\n        # self.num_heads_kv = num_heads_kv // self.num_groups\n        process_group = weights.process_group\n\n        if process_group.size() > self.num_groups:\n            raise NotImplementedError(\n                \"Tensor Parallelism is not implemented for world_size > n groups\"\n            )\n        if self.num_groups % process_group.size() != 0:\n            raise NotImplementedError(\n                f\"Tensor Parallelism is not implemented for {self.num_groups} not divisible by {process_group.size()}\"\n            )\n\n        self.num_groups = self.num_groups // process_group.size()\n\n        self.query_key_value = TensorParallelColumnLinear.load(\n            config,\n            prefix=f\"{prefix}.query_key_value\",\n            weights=weights,\n            bias=config.bias,\n        )\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n        self.dense = load_row(\n            config, prefix=f\"{prefix}.dense\", weights=weights, bias=config.bias\n        )\n\n        self.kv_head_mapping = torch.arange(\n            0, self.num_groups, dtype=torch.int32, device=weights.device\n        ).repeat_interleave(self.num_heads)\n\n    def forward(\n        self,\n        hidden_states,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        hpu_attention_meta,\n    ):\n        qkv = self.query_key_value(hidden_states)\n        qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)\n\n        # Split on group dimension\n        query, kv = qkv.split(\n            [self.num_heads, 2],\n      
      dim=2,\n        )\n        # Merge groups and heads\n        query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size)\n\n        # Inplace rotary\n        self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)\n\n        kv_cache.store(\n            key=kv[:, :, 0].contiguous(),\n            value=kv[:, :, 1].contiguous(),\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # flash attention\n            attn_output = attention(\n                query=query,\n                key=kv[:, :, 0],\n                value=kv[:, :, 1],\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                softmax_scale=self.softmax_scale,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                seqlen,\n                kv_scales=self.kv_scales,\n                hpu_attention_meta=hpu_attention_meta,\n            )\n\n        return self.dense(\n            attn_output.view(-1, self.num_groups * self.num_heads * self.head_size)\n        )\n\n\nclass FlashMLP(nn.Module):\n    def __init__(self, config, prefix: str, weights):\n        super().__init__()\n        self.act = torch.nn.functional.gelu\n\n        self.dense_h_to_4h = TensorParallelColumnLinear.load(\n            config, prefix=f\"{prefix}.dense_h_to_4h\", weights=weights, bias=config.bias\n        )\n        self.dense_4h_to_h = load_row(\n            config, prefix=f\"{prefix}.dense_4h_to_h\", weights=weights, bias=config.bias\n        )\n\n    def forward(self, hidden_states):\n        hidden_states = self.dense_h_to_4h(hidden_states)\n        hidden_states = self.act(hidden_states)\n        hidden_states = self.dense_4h_to_h(hidden_states)\n     
   return hidden_states\n\n\nclass FlashRWLayer(nn.Module):\n    def __init__(\n        self,\n        layer_id,\n        prefix: str,\n        config,\n        weights,\n        rotary_emb,\n    ):\n        super().__init__()\n\n        parallel_attn = config.parallel_attn\n        self.parallel_attn = parallel_attn\n\n        prefix = f\"{prefix}.h.{layer_id}\"\n\n        # NOTE: Falcon 180B uses the ln_attn prefix\n        ln_prefix = \"input_layernorm\"\n        if config.num_hidden_layers == 80:\n            ln_prefix = \"ln_attn\"\n\n        self.input_layernorm = FastLayerNorm.load(\n            prefix=f\"{prefix}.{ln_prefix}\",\n            weights=weights,\n            eps=config.layer_norm_epsilon,\n        )\n        self.self_attention = FlashRWAttention(\n            config,\n            prefix=f\"{prefix}.self_attention\",\n            weights=weights,\n            rotary_emb=rotary_emb,\n        )\n        self.post_attention_layernorm = (\n            FastLayerNorm.load(\n                prefix=f\"{prefix}.post_attention_layernorm\",\n                weights=weights,\n                eps=config.layer_norm_epsilon,\n            )\n            if not parallel_attn\n            else None\n        )\n\n        self.mlp = FlashMLP(\n            config,\n            prefix=f\"{prefix}.mlp\",\n            weights=weights,\n        )\n\n        self.process_group = weights.process_group\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        hpu_attention_meta,\n    ):\n        if self.parallel_attn:\n            ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n            attn_output = self.self_attention(\n                ln_hidden_states,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache,\n                slots,\n                
seqlen,\n                hpu_attention_meta,\n            )\n\n            mlp_output = self.mlp(ln_hidden_states)\n            intermediate = mlp_output + attn_output\n\n            if self.process_group.size() > 1:\n                torch.distributed.all_reduce(intermediate, group=self.process_group)\n\n            return intermediate, residual\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n            hidden_states = self.self_attention(\n                hidden_states,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache,\n                slots,\n                seqlen,\n                hpu_attention_meta,\n            )\n\n            if self.post_attention_layernorm is not None:\n                hidden_states, residual = self.post_attention_layernorm(\n                    hidden_states, residual\n                )\n\n            mlp_output = self.mlp(hidden_states)\n\n            return mlp_output, residual\n\n\nclass FlashRWLayerNorm(nn.Module):\n    def __init__(self, config, prefix: str, weights):\n        super().__init__()\n        # Falcon2 includes the number of layer norms in the config\n        # in the case no number of layer norms is provided, we default to 1\n        self.num_ln = getattr(config, \"num_ln_in_parallel_attn\", 1)\n\n        # Falcon 180B uses the ln_attn prefix and has 2 layer norms\n        if config.num_hidden_layers == 80:\n            self.num_ln = 2\n\n        if self.num_ln == 1:\n            self.input_ln = FastLayerNorm.load(\n                prefix=f\"{prefix}.input_layernorm\",\n                weights=weights,\n                eps=config.layer_norm_epsilon,\n            )\n        elif self.num_ln == 2:\n            self.ln_attn = FastLayerNorm.load(\n                prefix=f\"{prefix}.ln_attn\",\n                weights=weights,\n                eps=config.layer_norm_epsilon,\n            )\n            self.ln_mlp 
= FastLayerNorm.load(\n                prefix=f\"{prefix}.ln_mlp\",\n                weights=weights,\n                eps=config.layer_norm_epsilon,\n            )\n        else:\n            raise ValueError(\"Number of layer norms can either be 1 or 2.\")\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n    ):\n        if self.num_ln == 1:\n            ln_hidden_states, residual = self.input_ln(hidden_states, residual)\n            return ln_hidden_states, ln_hidden_states, residual\n        elif self.num_ln == 2:\n            ln_attn, residual = self.ln_attn(hidden_states, residual)\n            ln_mlp, _ = self.ln_mlp(residual)\n            return ln_attn, ln_mlp, residual\n\n\nclass FlashRWLargeLayer(nn.Module):\n    def __init__(self, layer_id, prefix: str, config, weights, rotary_emb):\n        super().__init__()\n        prefix = f\"{prefix}.h.{layer_id}\"\n\n        self.ln_layer = FlashRWLayerNorm(config, prefix, weights)\n\n        self.self_attention = FlashRWLargeAttention(\n            config,\n            prefix=f\"{prefix}.self_attention\",\n            weights=weights,\n            rotary_emb=rotary_emb,\n        )\n        assert config.parallel_attn, \"This version doesn't support non parallel_attn\"\n\n        self.mlp = FlashMLP(config, prefix=f\"{prefix}.mlp\", weights=weights)\n\n        self.process_group = weights.process_group\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        hpu_attention_meta,\n    ):\n        # Layer norm.\n        ln_attn, ln_mlp, residual = self.ln_layer(hidden_states, residual)\n\n        # Self attention.\n        attn_output = self.self_attention(\n            ln_attn,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            hpu_attention_meta,\n        
)\n\n        # MLP.\n        mlp_output = self.mlp(ln_mlp)\n\n        intermediate = attn_output + mlp_output\n\n        if self.process_group.size() > 1:\n            torch.distributed.all_reduce(intermediate, group=self.process_group)\n\n        return intermediate, residual\n\n\nclass FlashRWPreTrainedModel(PreTrainedModel):\n    config_class = RWConfig\n\n\nclass FlashRWModel(FlashRWPreTrainedModel):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__(config)\n        self.config = config\n\n        self.word_embeddings = TensorParallelEmbedding(\n            prefix=f\"{prefix}.word_embeddings\", weights=weights\n        )\n        rotary_emb = PositionRotaryEmbedding.static(\n            config=config,\n            dim=config.hidden_size // config.n_head,\n            base=config.rope_theta,\n            device=weights.device,\n        )\n\n        if config.new_decoder_architecture:\n            self.h = nn.ModuleList(\n                [\n                    FlashRWLargeLayer(layer_id, prefix, config, weights, rotary_emb)\n                    for layer_id in range(config.num_hidden_layers)\n                ]\n            )\n            self.cache_size = self.h[0].self_attention.num_groups\n        else:\n            self.h = nn.ModuleList(\n                [\n                    FlashRWLayer(layer_id, prefix, config, weights, rotary_emb)\n                    for layer_id in range(config.num_hidden_layers)\n                ]\n            )\n            self.cache_size = self.h[0].self_attention.num_heads_kv\n\n        self.ln_f = FastLayerNorm.load(\n            prefix=f\"{prefix}.ln_f\",\n            weights=weights,\n            eps=config.layer_norm_epsilon,\n        )\n\n        self.head_size = self.h[0].self_attention.head_size\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, 
torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n    ) -> torch.Tensor:\n        if hpu_attention_meta is not None:\n            hpu_attention_meta = set_block_mapping(\n                hpu_attention_meta, input_ids.shape[0]\n            )\n        hidden_states = self.word_embeddings(input_ids)\n\n        # Get rotary cos and sin for this forward\n        # Avoid to index in each layer\n        cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(position_ids)\n\n        residual = None\n        lazy_mode = htorch.utils.internal.is_lazy()\n        if lazy_mode:\n            htorch.core.mark_step()\n        for i, layer in enumerate(self.h):\n            hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                slots,\n                seqlen,\n                hpu_attention_meta,\n            )\n            if lazy_mode:\n                htorch.core.mark_step()\n\n        hidden_states, _ = self.ln_f(hidden_states, residual)\n\n        return hidden_states\n\n\nclass FlashRWForCausalLM(FlashRWPreTrainedModel):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__(config)\n\n        if not prefix:\n            prefix = \"transformer\"\n        else:\n            prefix = f\"{prefix}.transformer\"\n\n        self.transformer = FlashRWModel(prefix, config, weights)\n\n        self.lm_head = SpeculativeHead.load(config, prefix=\"lm_head\", weights=weights)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n   
     lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        hidden_states = self.transformer(\n            input_ids,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            hpu_attention_meta,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits = self.lm_head(hidden_states)\n        return logits\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py",
    "content": "import torch\nimport torch.distributed\n\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom typing import Optional, List, Tuple\n\nfrom text_generation_server.layers.attention import (\n    paged_attention,\n    attention,\n    set_block_mapping,\n    Seqlen,\n    HPUPagedAttentionMetadata,\n)\nfrom text_generation_server.layers import (\n    TensorParallelRowLinear,\n    TensorParallelColumnLinear,\n    SpeculativeHead,\n    TensorParallelEmbedding,\n    get_linear,\n)\nfrom text_generation_server.layers.attention.kv_cache import get_kv_scales\nfrom text_generation_server.layers.gptq import GPTQWeightsLoader\nfrom text_generation_server.layers.layernorm import (\n    FastLayerNorm,\n)\nimport habana_frameworks.torch as htorch\n\n\ndef load_multi_mqa(\n    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size\n):\n    if config.quantize == \"gptq\":\n        return _load_multi_mqa_gptq(\n            config, prefix, weights, bias, head_size, num_heads, hidden_size\n        )\n    else:\n        return _load_multi_mqa(\n            config, prefix, weights, bias, head_size, num_heads, hidden_size\n        )\n\n\ndef _load_multi_mqa_gptq(\n    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size\n):\n    from text_generation_server.layers.gptq import GPTQWeight\n\n    if any(\"c_attn\" in k for k in weights.routing.keys()) and not config.transpose:\n        world_size = weights.process_group.size()\n        rank = weights.process_group.rank()\n\n        slice_ = weights._get_slice(f\"{prefix}.c_attn.qweight\")\n        shape = slice_.get_shape()\n        block_size = (shape[1] - 2 * head_size) // world_size\n        start = rank * block_size\n        stop = (rank + 1) * block_size\n        assert (shape[1] - 2 * head_size) % world_size == 0\n        q_tensor = slice_[:, start:stop]\n        kv_tensor = slice_[:, -2 * head_size :]\n        qweight = torch.cat([q_tensor, kv_tensor], 
dim=1)\n        qweight = qweight.to(device=weights.device)\n\n        slice_ = weights._get_slice(f\"{prefix}.c_attn.scales\")\n        shape = slice_.get_shape()\n        block_size = (shape[1] - 2 * head_size) // world_size\n        start = rank * block_size\n        stop = (rank + 1) * block_size\n        assert (shape[1] - 2 * head_size) % world_size == 0\n        q_tensor = slice_[:, start:stop]\n        kv_tensor = slice_[:, -2 * head_size :]\n        scales = torch.cat([q_tensor, kv_tensor], dim=1)\n        scales = scales.to(device=weights.device)\n\n        slice_ = weights._get_slice(f\"{prefix}.c_attn.qzeros\")\n        shape = slice_.get_shape()\n        block_size = (shape[1] - (2 * head_size) * 4 // 32) // world_size\n        start = rank * block_size\n        stop = (rank + 1) * block_size\n        assert 2 * head_size % (32 // 4) == 0\n        q_tensor = slice_[:, start:stop]\n        kv_tensor = slice_[:, -2 * head_size * 4 // 32 :]\n        qzeros = torch.cat([q_tensor, kv_tensor], dim=1)\n        qzeros = qzeros.to(device=weights.device)\n\n        loader = weights.weights_loader\n        assert isinstance(loader, GPTQWeightsLoader)\n        loader._get_gptq_params(weights)\n        if loader.quant_method == \"gptq\":\n            g_idx = weights.get_tensor(f\"{prefix}.c_attn.g_idx\")\n            g_idx = g_idx.to(device=weights.device)\n        elif loader.quant_method == \"awq\":\n            g_idx = None\n            from text_generation_server.layers.awq.conversion_utils import (\n                fast_awq_to_gptq,\n            )\n\n            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)\n\n        from text_generation_server.layers.gptq import HAS_EXLLAMA\n\n        weight = GPTQWeight(\n            qweight=qweight,\n            qzeros=qzeros,\n            scales=scales,\n            g_idx=g_idx,\n            bits=loader.bits,\n            groupsize=loader.groupsize,\n            use_awq_kernel=loader.quantize == \"awq\",\n           
 use_exllama=HAS_EXLLAMA,\n        )\n\n        if bias:\n            slice_ = weights._get_slice(f\"{prefix}.c_attn.bias\")\n            shape = slice_.get_shape()\n            block_size = (shape[0] - 2 * head_size) // world_size\n            assert (shape[0] - 2 * head_size) % world_size == 0\n            q_tensor = slice_[start:stop]\n            start = rank * block_size\n            stop = (rank + 1) * block_size\n            q_tensor = slice_[start:stop]\n            kv_tensor = slice_[-2 * head_size :]\n            bias = torch.cat([q_tensor, kv_tensor], dim=0)\n            bias = bias.to(device=weights.device)\n\n        return TensorParallelColumnLinear(get_linear(weight, bias))\n    else:\n        raise NotImplementedError(\"Gptq loading with santacoder is not implemented\")\n\n\ndef _load_multi_mqa(\n    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size\n):\n    if any(\"c_attn\" in k for k in weights.routing.keys()):\n        slice_ = weights._get_slice(f\"{prefix}.c_attn.weight\")\n        shape = slice_.get_shape()\n        world_size = weights.process_group.size()\n        rank = weights.process_group.rank()\n        if config.transpose:\n            block_size = (shape[1] - 2 * head_size) // world_size\n            start = rank * block_size\n            stop = (rank + 1) * block_size\n            assert (shape[1] - 2 * head_size) % world_size == 0\n            q_tensor = slice_[:, start:stop]\n            kv_tensor = slice_[:, -2 * head_size :]\n            weight = torch.cat([q_tensor, kv_tensor], dim=1).T\n        else:\n            block_size = (shape[0] - 2 * head_size) // world_size\n            start = rank * block_size\n            stop = (rank + 1) * block_size\n            assert (shape[0] - 2 * head_size) % world_size == 0\n            q_tensor = slice_[start:stop]\n            kv_tensor = slice_[-2 * head_size :]\n            weight = torch.cat([q_tensor, kv_tensor], dim=0)\n        if bias:\n            slice_ 
= weights._get_slice(f\"{prefix}.c_attn.bias\")\n            shape = slice_.get_shape()\n            block_size = (shape[0] - 2 * head_size) // world_size\n            assert (shape[0] - 2 * head_size) % world_size == 0\n            start = rank * block_size\n            stop = (rank + 1) * block_size\n            q_tensor = slice_[start:stop]\n            kv_tensor = slice_[-2 * head_size :]\n            bias = torch.cat([q_tensor, kv_tensor], dim=0)\n    else:\n        if config.transpose:\n            w = [\n                weights.get_sharded(f\"{prefix}.q_attn.weight\", dim=1).T,\n                weights.get_tensor(f\"{prefix}.kv_attn.weight\").T,\n            ]\n            weight = torch.cat(w, dim=0)\n        else:\n            w = [\n                weights.get_sharded(f\"{prefix}.q_attn.weight\", dim=0),\n                weights.get_tensor(f\"{prefix}.kv_attn.weight\"),\n            ]\n            weight = torch.cat(w, dim=1)\n\n        if bias:\n            b = [\n                weights.get_sharded(f\"{prefix}.q_attn.bias\", dim=0),\n                weights.get_tensor(f\"{prefix}.kv_attn.bias\"),\n            ]\n            bias = torch.cat(b, dim=0)\n        else:\n            bias = None\n\n    weight = weight.to(dtype=weights.dtype).to(device=weights.device)\n    assert list(weight.shape) == [\n        (num_heads + 2) * head_size,\n        hidden_size,\n    ], f\"{weight.shape} != {[(num_heads + 2) * head_size, hidden_size]}\"\n    if bias is not None:\n        bias = bias.to(dtype=weights.dtype).to(device=weights.device)\n        assert list(bias.shape) == [\n            (num_heads + 2) * head_size\n        ], f\"{bias.shape} != {[(num_heads + 2) * head_size]}\"\n    return TensorParallelColumnLinear(get_linear(weight, bias))\n\n\ndef load_col(config, prefix: str, weights, bias: bool):\n    if config.transpose:\n        weight = weights.get_sharded(f\"{prefix}.weight\", dim=1).T\n    else:\n        weight = weights.get_multi_weights_col([prefix], 
dim=0)\n\n    if bias:\n        bias = weights.get_sharded(f\"{prefix}.bias\", dim=0)\n    else:\n        bias = None\n    return TensorParallelColumnLinear(get_linear(weight, bias))\n\n\ndef load_row(config, prefix: str, weights, bias: bool):\n    if config.transpose:\n        weight = weights.get_sharded(f\"{prefix}.weight\", dim=0).T\n    else:\n        weight = weights.get_weights_row(prefix)\n\n    if bias and weights.process_group.rank() == 0:\n        # The bias is only loaded on the first rank process\n        bias = weights.get_tensor(f\"{prefix}.bias\")\n    else:\n        bias = None\n    return TensorParallelRowLinear(\n        get_linear(weight, bias), process_group=weights.process_group\n    )\n\n\nclass FlashMQAttention(torch.nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        num_heads = config.num_attention_heads\n        hidden_size = config.hidden_size\n\n        self.num_heads = num_heads\n        self.hidden_size = hidden_size\n        self.head_size = hidden_size // num_heads\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n\n        self.softmax_scale = self.head_size ** (-0.5)\n\n        self.c_attn = load_multi_mqa(\n            config,\n            prefix=prefix,\n            weights=weights,\n            bias=True,\n            head_size=self.head_size,\n            hidden_size=hidden_size,\n            num_heads=self.num_heads,\n        )\n        self.c_proj = load_row(\n            config, prefix=f\"{prefix}.c_proj\", weights=weights, bias=True\n        )\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n        self.kv_head_mapping = torch.zeros(\n            self.num_heads, 
dtype=torch.int32, device=weights.device\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        hpu_attention_meta,\n    ):\n        qkv = self.c_attn(hidden_states)\n\n        # Split query from key_value\n        query, key_value = qkv.split(\n            [self.head_size * self.num_heads, 2 * self.head_size], dim=1\n        )\n\n        # Prepare query and key_value for indexing\n        query = query.view(-1, self.num_heads, self.head_size)\n        key_value = key_value.view(-1, 2, 1, self.head_size)\n\n        kv_cache.store(\n            key=key_value[:, 0],\n            value=key_value[:, 1],\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # sdpa\n            attn_output = attention(\n                query=query,\n                key=key_value[:, 0],\n                value=key_value[:, 1],\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                softmax_scale=self.softmax_scale,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                seqlen,\n                kv_scales=self.kv_scales,\n                hpu_attention_meta=hpu_attention_meta,\n            )\n\n        return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))\n\n\nclass MLP(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        act = config.activation_function\n        self.act = (\n            ACT2FN[act]\n            if \"gelu\" not in act\n            else lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\" if act in 
[\"gelu_fast\", \"gelu_pytorch_tanh\"] else \"none\"\n                ),\n            )\n        )\n\n        self.c_fc = load_col(\n            config, prefix=f\"{prefix}.c_fc\", weights=weights, bias=True\n        )\n        self.c_proj = load_row(\n            config, prefix=f\"{prefix}.c_proj\", weights=weights, bias=True\n        )\n\n    def forward(self, hidden_states):\n        hidden_states = self.c_fc(hidden_states)\n        hidden_states = self.act(hidden_states)\n        hidden_states = self.c_proj(hidden_states)\n        return hidden_states\n\n\nclass Block(nn.Module):\n    def __init__(self, prefix: str, layer_id, config, weights):\n        super().__init__()\n        prefix = f\"{prefix}.h.{layer_id}\"\n        self.ln_1 = FastLayerNorm.load(\n            prefix=f\"{prefix}.ln_1\", weights=weights, eps=config.layer_norm_epsilon\n        )\n        self.ln_2 = FastLayerNorm.load(\n            prefix=f\"{prefix}.ln_2\", weights=weights, eps=config.layer_norm_epsilon\n        )\n        self.self_attn = FlashMQAttention(\n            prefix=f\"{prefix}.attn\",\n            config=config,\n            weights=weights,\n        )\n        self.mlp = MLP(\n            prefix=f\"{prefix}.mlp\",\n            config=config,\n            weights=weights,\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        hpu_attention_meta,\n    ):\n        hidden_states, residual = self.ln_1(hidden_states, residual)\n        hidden_states = self.self_attn(\n            hidden_states,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            hpu_attention_meta,\n        )\n\n        hidden_states, residual = self.ln_2(hidden_states, residual)\n\n        mlp_output = self.mlp(hidden_states)\n\n        return mlp_output, residual\n\n\nclass FlashSantacoderModel(nn.Module):\n    def __init__(self, 
prefix: str, config, weights):\n        super().__init__()\n        self.config = config\n\n        self.process_group = weights.process_group\n        self.wte = TensorParallelEmbedding(\n            prefix=f\"{prefix}.wte\",\n            weights=weights,\n            reduce=False,\n        )\n        self.wpe = TensorParallelEmbedding(\n            prefix=f\"{prefix}.wpe\",\n            weights=weights,\n            reduce=False,\n        )\n\n        self.layers = nn.ModuleList(\n            [\n                Block(\n                    prefix,\n                    layer_id,\n                    config,\n                    weights,\n                )\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n        self.ln_f = FastLayerNorm.load(\n            prefix=\"transformer.ln_f\", weights=weights, eps=config.layer_norm_epsilon\n        )\n\n        self.head_size = self.layers[0].self_attn.head_size\n        self.num_heads = self.layers[0].self_attn.num_heads\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n    ) -> torch.Tensor:\n        if hpu_attention_meta is not None:\n            hpu_attention_meta = set_block_mapping(\n                hpu_attention_meta, input_ids.shape[0]\n            )\n        hidden_states = self.wte(input_ids) + self.wpe(position_ids)\n\n        if self.process_group.size() > 1:\n            torch.distributed.all_reduce(hidden_states, group=self.process_group)\n\n        residual = None\n        lazy_mode = htorch.utils.internal.is_lazy()\n        if lazy_mode:\n            htorch.core.mark_step()\n        for i, layer in enumerate(self.layers):\n            hidden_states, residual = layer(\n               
 hidden_states,\n                residual,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                slots,\n                seqlen,\n                hpu_attention_meta,\n            )\n            if lazy_mode:\n                htorch.core.mark_step()\n\n        hidden_states, _ = self.ln_f(hidden_states, residual)\n\n        return hidden_states\n\n\nclass FlashSantacoderForCausalLM(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n\n        if not prefix:\n            prefix = \"transformer\"\n        else:\n            prefix = f\"{prefix}.transformer\"\n\n        config.transpose = config.architectures[0].startswith(\"GPT2\")\n        self.model = FlashSantacoderModel(prefix, config, weights)\n        self.lm_head = SpeculativeHead.load(\n            config, prefix=f\"{prefix}.wte\", weights=weights\n        )\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        hidden_states = self.model(\n            input_ids,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            hpu_attention_meta,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits = self.lm_head(hidden_states)\n        return logits\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2024 Starcoder2 AI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.distributed\n\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom transformers.configuration_utils import PretrainedConfig\nfrom typing import Optional, List, Tuple\n\nfrom text_generation_server.layers.attention import (\n    paged_attention,\n    attention,\n    set_block_mapping,\n    Seqlen,\n    HPUPagedAttentionMetadata,\n)\nfrom text_generation_server.layers import (\n    TensorParallelMultiAdapterLinear,\n    TensorParallelAdapterRowLinear,\n    TensorParallelRowLinear,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    SpeculativeHead,\n    get_linear,\n)\nfrom text_generation_server.layers.attention.kv_cache import get_kv_scales\nfrom text_generation_server.layers.layernorm import (\n    FastLayerNorm,\n    FastRMSNorm,\n)\nfrom text_generation_server.layers.rotary import (\n    PositionRotaryEmbedding,\n)\nfrom text_generation_server.utils.weights import UnquantizedWeight\nimport habana_frameworks.torch as htorch\n\n\nclass 
Starcoder2Config(PretrainedConfig):\n    model_type = \"starcoder2\"\n\n    def __init__(\n        self,\n        vocab_size=49152,\n        hidden_size=3072,\n        intermediate_size=12288,\n        num_hidden_layers=30,\n        num_attention_heads=24,\n        num_key_value_heads=2,\n        mlp_type=\"default\",\n        hidden_act=\"gelu_pytorch_tanh\",\n        max_position_embeddings=4096,\n        initializer_range=0.018042,\n        norm_type=\"layer_norm\",\n        norm_epsilon=1e-5,\n        use_cache=True,\n        bos_token_id=50256,\n        eos_token_id=50256,\n        rope_theta=10000.0,\n        sliding_window=None,\n        attention_dropout=0.0,\n        residual_dropout=0.0,\n        embedding_dropout=0.0,\n        use_bias: bool = True,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.sliding_window = sliding_window\n        self.use_bias = use_bias\n\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.mlp_type = mlp_type\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.norm_type = norm_type\n        self.norm_epsilon = norm_epsilon\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.attention_dropout = attention_dropout\n        self.residual_dropout = residual_dropout\n        self.embedding_dropout = embedding_dropout\n\n        super().__init__(\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            **kwargs,\n        )\n\n\ndef load_attention(config, prefix, weights, 
layer_id):\n    prefixes = [f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"]\n    head_size = config.hidden_size // config.num_attention_heads\n    sizes = [\n        head_size * config.num_attention_heads,\n        head_size * config.num_key_value_heads,\n        head_size * config.num_key_value_heads,\n    ]\n    if config.num_attention_heads != config.num_key_value_heads:\n        base_layer = _load_gqa(config, prefix, weights)\n    else:\n        base_layer = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=prefixes,\n            dim=0,\n            weights=weights,\n            bias=config.use_bias,\n        )\n    return TensorParallelMultiAdapterLinear.load(\n        base_layer=base_layer,\n        layer_id=layer_id,\n        layer_names=prefixes,\n        sizes=sizes,\n        process_group=weights.process_group,\n    )\n\n\ndef _load_gqa(config, prefix: str, weights):\n    assert config.hidden_size % config.num_attention_heads == 0\n    assert config.num_attention_heads % weights.process_group.size() == 0\n\n    weight = weights.get_multi_weights_col(\n        prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n        dim=0,\n    )\n\n    if isinstance(weight, UnquantizedWeight):\n        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)\n\n        head_size = config.hidden_size // config.num_attention_heads\n        num_heads = config.num_attention_heads // weights.process_group.size()\n        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()\n        assert list(weight.weight.shape) == [\n            (num_heads + 2 * num_key_value_heads) * head_size,\n            config.hidden_size,\n        ], f\"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}\"\n\n    if config.use_bias:\n        w = [\n            weights.get_sharded(f\"{p}.bias\", dim=0)\n            
for p in [f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"]\n        ]\n        bias = torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device)\n    else:\n        bias = None\n\n    return TensorParallelColumnLinear(get_linear(weight, bias=bias))\n\n\nclass Starcoder2Attention(torch.nn.Module):\n    def __init__(\n        self,\n        index: int,\n        prefix: str,\n        config,\n        weights,\n        rotary_emb,\n    ):\n        super().__init__()\n        self.max_past = (\n            config.sliding_window if config.sliding_window is not None else -1\n        )\n        self.num_heads = config.num_attention_heads\n        self.hidden_size = config.hidden_size\n        self.head_size = self.hidden_size // self.num_heads\n        self.rotary_emb = rotary_emb\n\n        self.softmax_scale = self.head_size**-0.5\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.num_key_value_heads = (\n            config.num_key_value_heads // weights.process_group.size()\n        )\n\n        self.query_key_value = load_attention(config, prefix, weights, index)\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n\n        o_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.o_proj\",\n            weights=weights,\n            bias=getattr(config, \"use_bias\", False),\n        )\n\n        self.o_proj = TensorParallelAdapterRowLinear.load(\n            o_proj,\n            index,\n            \"o_proj\",\n            process_group=weights.process_group,\n        )\n\n        self.num_groups = self.num_heads // self.num_key_value_heads\n        self.kv_head_mapping = 
torch.arange(\n            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device\n        ).repeat_interleave(self.num_groups)\n\n    def forward(\n        self,\n        hidden_states,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        adapter_data,\n        hpu_attention_meta,\n    ):\n        qkv = self.query_key_value(hidden_states, adapter_data)\n        query, kv = qkv.split(\n            [\n                self.head_size * self.num_heads,\n                2 * self.head_size * self.num_key_value_heads,\n            ],\n            dim=1,\n        )\n        query = query.view(-1, self.num_heads, self.head_size)\n        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)\n\n        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)\n\n        kv_cache.store(\n            key=kv[:, 0],\n            value=kv[:, 1],\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # sdpa\n            attn_output = attention(\n                query=query,\n                key=kv[:, 0],\n                value=kv[:, 1],\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                softmax_scale=self.softmax_scale,\n                window_size_left=self.max_past,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                seqlen,\n                kv_scales=self.kv_scales,\n                hpu_attention_meta=hpu_attention_meta,\n                window_size_left=self.max_past,\n            )\n\n        return self.o_proj(\n            attn_output.view(-1, self.num_heads * self.head_size), adapter_data\n        )\n\n\nclass 
Starcoder2MLP(nn.Module):\n    def __init__(self, prefix, config, weights, index):\n        super().__init__()\n        act = config.hidden_act\n        self.act = (\n            ACT2FN[act]\n            if \"gelu\" not in act\n            else lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\" if act in [\"gelu_fast\", \"gelu_pytorch_tanh\"] else \"none\"\n                ),\n            )\n        )\n        # Fuse gate and up proj\n        c_fc = TensorParallelColumnLinear.load(\n            config,\n            prefix=f\"{prefix}.c_fc\",\n            weights=weights,\n            bias=config.use_bias,\n        )\n        c_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.c_proj\",\n            weights=weights,\n            bias=config.use_bias,\n        )\n\n        self.c_fc = TensorParallelMultiAdapterLinear.load(\n            c_fc,\n            layer_id=index,\n            layer_names=[f\"{prefix}.c_fc\"],\n            sizes=[config.intermediate_size, config.intermediate_size],\n            process_group=weights.process_group,\n        )\n\n        self.c_proj = TensorParallelAdapterRowLinear.load(\n            c_proj,\n            index,\n            \"c_proj\",\n            process_group=weights.process_group,\n        )\n\n    def forward(self, hidden_states, adapter_data):\n        hidden_states = self.c_fc(hidden_states, adapter_data)\n        hidden_states = self.act(hidden_states)\n        return self.c_proj(hidden_states, adapter_data)\n\n\nclass Starcoder2GatedMLP(nn.Module):\n    def __init__(self, index, prefix, config, weights):\n        super().__init__()\n        act = config.hidden_act\n        self.act = (\n            ACT2FN[act]\n            if \"gelu\" not in act\n            else lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\" if act in [\"gelu_fast\", 
\"gelu_pytorch_tanh\"] else \"none\"\n                ),\n            )\n        )\n        # Fuse gate and up proj\n        prefixes = [f\"{prefix}.gate_proj\", f\"{prefix}.up_proj\"]\n        sizes = [\n            config.intermediate_size,\n            config.intermediate_size,\n        ]\n        gate_up_proj = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=prefixes,\n            weights=weights,\n            dim=0,\n            bias=config.use_bias,\n        )\n        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(\n            gate_up_proj,\n            index,\n            layer_names=prefixes,\n            sizes=sizes,\n            process_group=weights.process_group,\n        )\n        down_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.down_proj\",\n            weights=weights,\n            bias=config.use_bias,\n        )\n        self.down_proj = TensorParallelAdapterRowLinear.load(\n            down_proj,\n            index,\n            \"down_proj\",\n            process_group=weights.process_group,\n        )\n        self.intermediate_size = (\n            config.intermediate_size // weights.process_group.size()\n        )\n\n    def forward(self, hidden_states, adapter_data):\n        gate_up_states = self.gate_up_proj(hidden_states, adapter_data)\n        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)\n        return self.down_proj(\n            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data\n        )\n\n\nSTARCODER2_NORMALIZATION_CLASSES = {\n    \"layer_norm\": FastLayerNorm,\n    \"rms_norm\": FastRMSNorm,\n}\n\nSTARCODER2_MLP_CLASSES = {\n    \"default\": Starcoder2MLP,\n    \"gated\": Starcoder2GatedMLP,\n}\n\n\nclass Starcoder2Layer(nn.Module):\n    def __init__(self, layer_id, config, weights, rotary_emb):\n        super().__init__()\n        prefix = f\"model.layers.{layer_id}\"\n        self.self_attn = 
Starcoder2Attention(\n            prefix=f\"{prefix}.self_attn\",\n            config=config,\n            weights=weights,\n            index=layer_id,\n            rotary_emb=rotary_emb,\n        )\n\n        self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](\n            prefix=f\"{prefix}.mlp\", config=config, weights=weights, index=layer_id\n        )\n\n        self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(\n            prefix=f\"{prefix}.input_layernorm\", weights=weights, eps=config.norm_epsilon\n        )\n        self.post_attention_layernorm = STARCODER2_NORMALIZATION_CLASSES[\n            config.norm_type\n        ].load(\n            prefix=f\"{prefix}.post_attention_layernorm\",\n            weights=weights,\n            eps=config.norm_epsilon,\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        slots,\n        seqlen,\n        adapter_data,\n        hpu_attention_meta,\n    ):\n        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        attn_output = self.self_attn(\n            normed_hidden_states,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            adapter_data,\n            hpu_attention_meta,\n        )\n\n        # faster post attention rms norm\n        normed_attn_res_output, attn_res = self.post_attention_layernorm(\n            attn_output, res\n        )\n\n        mlp_output = self.mlp(normed_attn_res_output, adapter_data)\n\n        return mlp_output, attn_res\n\n\nclass Starcoder2Model(torch.nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n\n        process_group = weights.process_group\n        self.tp_rank = process_group.rank()\n        self.tp_world_size = process_group.size()\n        
self.embed_tokens = TensorParallelEmbedding(\n            prefix=f\"{prefix}.embed_tokens\", weights=weights\n        )\n        rotary_emb = PositionRotaryEmbedding.static(\n            config=config,\n            dim=config.hidden_size // config.num_attention_heads,\n            base=config.rope_theta,\n            device=weights.device,\n        )\n        self.layers = nn.ModuleList(\n            [\n                Starcoder2Layer(\n                    layer_id,\n                    config,\n                    weights,\n                    rotary_emb,\n                )\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n        self.norm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(\n            prefix=f\"{prefix}.norm\", weights=weights, eps=config.norm_epsilon\n        )\n\n        self.gradient_checkpointing = False\n\n        self.head_size = self.layers[0].self_attn.head_size\n        self.num_heads = self.layers[0].self_attn.num_heads\n        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        adapter_data,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n    ) -> torch.Tensor:\n        if hpu_attention_meta is not None:\n            hpu_attention_meta = set_block_mapping(\n                hpu_attention_meta, input_ids.shape[0]\n            )\n        hidden_states = self.embed_tokens(input_ids)\n\n        # Get rotary cos and sin for this forward\n        # Avoid to index in each layer\n        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)\n\n        residual = None\n        lazy_mode = htorch.utils.internal.is_lazy()\n        if lazy_mode:\n            
htorch.core.mark_step()\n        for i, layer in enumerate(self.layers):\n            hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                slots,\n                seqlen,\n                adapter_data,\n                hpu_attention_meta,\n            )\n            if lazy_mode:\n                htorch.core.mark_step()\n\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n\nclass FlashStarcoder2ForCausalLM(torch.nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n\n        if not prefix:\n            prefix = \"model\"\n        else:\n            prefix = f\"{prefix}.model\"\n\n        self.model = Starcoder2Model(prefix, config, weights)\n        try:\n            self.lm_head = SpeculativeHead.load(\n                config,\n                prefix=\"lm_head\",\n                weights=weights,\n            )\n        except RuntimeError:\n            self.lm_head = SpeculativeHead.load(\n                config,\n                prefix=f\"{prefix}.embed_tokens\",\n                weights=weights,\n            )\n\n        self.max_past = config.sliding_window\n        self.max_past_tensor = (\n            torch.tensor(config.sliding_window, device=weights.device)\n            if self.max_past is not None\n            else None\n        )\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        
hidden_states = self.model(\n            input_ids,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            slots,\n            seqlen,\n            adapter_data,\n            hpu_attention_meta,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits = self.lm_head(hidden_states)\n        return logits\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/idefics2.py",
    "content": "# coding=utf-8\n# Copyright 2024 the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch Idefics2 model.\"\"\"\n\nfrom typing import List, Optional, Tuple\n\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\nimport math\n\nfrom transformers.activations import ACT2FN\nfrom text_generation_server.models.custom_modeling.vlm import (\n    load_text_model,\n)\nfrom text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata\nfrom transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask\n\nfrom text_generation_server.layers import (\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    TensorParallelRowLinear,\n)\nfrom text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(\n        batch, num_key_value_heads, n_rep, slen, head_dim\n    )\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\nclass Idefics2VisionEmbeddings(nn.Module):\n    \"\"\"\n    This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable\n    resolution.\n\n    The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)\n    which allows treating images in their native aspect ratio and without the need to resize them to the same\n    fixed size. In particular, we start from the original pre-trained SigLIP model\n    (which uses images of fixed-size square images) and adapt it by training on images of variable resolutions.\n    \"\"\"\n\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.embed_dim = config.hidden_size\n        self.image_size = config.image_size\n        self.patch_size = config.patch_size\n\n        self.patch_embedding = nn.Conv2d(\n            in_channels=config.num_channels,\n            out_channels=self.embed_dim,\n            kernel_size=self.patch_size,\n            stride=self.patch_size,\n            padding=\"valid\",\n        )\n        self.patch_embedding.weight = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.patch_embedding.weight\"), requires_grad=False\n        )\n        self.patch_embedding.bias = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.patch_embedding.bias\"), requires_grad=False\n        )\n\n        self.num_patches_per_side = self.image_size // self.patch_size\n        
self.num_patches = self.num_patches_per_side**2\n        self.num_positions = self.num_patches\n        self.position_embedding = TensorParallelEmbedding(\n            prefix=f\"{prefix}.position_embedding\", weights=weights\n        )\n\n    def forward(\n        self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor\n    ) -> torch.Tensor:\n        batch_size, _, max_im_h, max_im_w = pixel_values.shape\n\n        patch_embeds = self.patch_embedding(pixel_values)\n        embeddings = patch_embeds.flatten(2).transpose(1, 2)\n\n        max_nb_patches_h, max_nb_patches_w = (\n            max_im_h // self.patch_size,\n            max_im_w // self.patch_size,\n        )\n        boundaries = torch.arange(\n            1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side\n        )\n        position_ids = torch.full(\n            size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0\n        )\n\n        for batch_idx, p_attn_mask in enumerate(patch_attention_mask):\n            nb_patches_h = p_attn_mask[:, 0].sum()\n            nb_patches_w = p_attn_mask[0].sum()\n\n            fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)\n            fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)\n\n            bucket_coords_h = torch.bucketize(\n                fractional_coords_h, boundaries, right=True\n            )\n            bucket_coords_w = torch.bucketize(\n                fractional_coords_w, boundaries, right=True\n            )\n\n            pos_ids = (\n                bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w\n            ).flatten()\n            position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids\n\n        position_ids = position_ids.to(self.position_embedding.weight.device)\n        embeddings = embeddings + self.position_embedding(position_ids)\n        return embeddings\n\n\nclass Idefics2VisionAttention(nn.Module):\n    def 
__init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_size = self.embed_dim // self.num_heads\n        if self.head_size * self.num_heads != self.embed_dim:\n            raise ValueError(\n                f\"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:\"\n                f\" {self.num_heads}).\"\n            )\n        self.scale = self.head_size**-0.5\n        self.dropout = config.attention_dropout\n\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.embed_dim = self.embed_dim // weights.process_group.size()\n\n        self.qkv = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n            dim=0,\n            weights=weights,\n            bias=True,\n        )\n        self.out_proj = TensorParallelRowLinear.load(\n            config=config, prefix=f\"{prefix}.out_proj\", weights=weights, bias=True\n        )\n        self.is_causal = False\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        batch_size, q_len, _ = hidden_states.size()\n\n        qkv = self.qkv(hidden_states)\n        query_states, key_states, value_states = qkv.split(\n            [\n                self.head_size * self.num_heads,\n                self.head_size * self.num_heads,\n                self.head_size * self.num_heads,\n            ],\n            dim=2,\n        )\n\n        query_states = query_states.view(\n            batch_size, q_len, self.num_heads, self.head_size\n        ).transpose(1, 2)\n        key_states = key_states.view(\n            batch_size, q_len, self.num_heads, self.head_size\n        ).transpose(1, 2)\n        
value_states = value_states.view(\n            batch_size, q_len, self.num_heads, self.head_size\n        ).transpose(1, 2)\n\n        k_v_seq_len = key_states.shape[-2]\n        attn_weights = (\n            torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale\n        )\n\n        if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is\"\n                f\" {attn_weights.size()}\"\n            )\n\n        if attention_mask is not None:\n            if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}\"\n                )\n            attn_weights = attn_weights + attention_mask\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(\n            attn_weights, dim=-1, dtype=torch.float32\n        ).to(query_states.dtype)\n        attn_weights = nn.functional.dropout(\n            attn_weights, p=self.dropout, training=self.training\n        )\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_size):\n            raise ValueError(\n                f\"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_size)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)\n\n        attn_output = self.out_proj(attn_output)\n\n        return attn_output\n\n\nclass Idefics2VisionMLP(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        
self.activation_fn = ACT2FN[config.hidden_act]\n        self.fc1 = TensorParallelColumnLinear.load(\n            prefix=f\"{prefix}.fc1\", config=config, weights=weights, bias=True\n        )\n        self.fc2 = TensorParallelRowLinear.load(\n            prefix=f\"{prefix}.fc2\", config=config, weights=weights, bias=True\n        )\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = self.activation_fn(hidden_states)\n        hidden_states = self.fc2(hidden_states)\n        return hidden_states\n\n\nclass Idefics2EncoderLayer(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.embed_dim = config.hidden_size\n        self.self_attn = Idefics2VisionAttention(\n            prefix=f\"{prefix}.self_attn\", config=config, weights=weights\n        )\n        self.layer_norm1 = nn.LayerNorm.load(\n            prefix=f\"{prefix}.layer_norm1\", eps=config.layer_norm_eps, weights=weights\n        )\n        self.layer_norm2 = nn.LayerNorm.load(\n            prefix=f\"{prefix}.layer_norm2\", eps=config.layer_norm_eps, weights=weights\n        )\n        self.mlp = Idefics2VisionMLP(\n            prefix=f\"{prefix}.mlp\", config=config, weights=weights\n        )\n\n    # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: torch.Tensor,\n    ) -> torch.Tensor:\n        residual = hidden_states\n\n        hidden_states = self.layer_norm1(hidden_states)\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n        )\n        hidden_states = residual + hidden_states\n\n        residual = hidden_states\n        hidden_states = self.layer_norm2(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + 
hidden_states\n\n        return hidden_states\n\n\nclass Idefics2Encoder(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        self.layers = nn.ModuleList(\n            [\n                Idefics2EncoderLayer(\n                    prefix=f\"{prefix}.layers.{i}\", config=config, weights=weights\n                )\n                for i in range(config.num_hidden_layers)\n            ]\n        )\n\n    # Ignore copy\n    def forward(\n        self,\n        inputs_embeds,\n        attention_mask: Optional[torch.Tensor] = None,\n    ):\n        hidden_states = inputs_embeds\n        for encoder_layer in self.layers:\n            hidden_states = encoder_layer(\n                hidden_states,\n                attention_mask,\n            )\n        return hidden_states\n\n\nclass Idefics2VisionTransformer(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        self.embeddings = Idefics2VisionEmbeddings(\n            prefix=f\"{prefix}.embeddings\", config=config, weights=weights\n        )\n        self.encoder = Idefics2Encoder(\n            prefix=f\"{prefix}.encoder\", config=config, weights=weights\n        )\n        self.post_layernorm = nn.LayerNorm.load(\n            prefix=f\"{prefix}.post_layernorm\",\n            weights=weights,\n            eps=config.layer_norm_eps,\n        )\n\n    def forward(\n        self,\n        pixel_values,\n        patch_attention_mask: Optional[torch.BoolTensor] = None,\n    ):\n        batch_size = pixel_values.size(0)\n        if patch_attention_mask is None:\n            patch_size = self.config.patch_size\n            patch_attention_mask = torch.ones(\n                (\n                    batch_size,\n                    pixel_values.size(2) // patch_size,\n                    pixel_values.size(3) // patch_size,\n                )\n            )\n            
patch_attention_mask = patch_attention_mask.to(\n                dtype=torch.bool, device=pixel_values.device\n            )\n\n        hidden_states = self.embeddings(\n            pixel_values=pixel_values, patch_attention_mask=patch_attention_mask\n        )\n\n        patch_attention_mask = patch_attention_mask.view(batch_size, -1)\n        # The call to `_upad_input` in `_flash_attention_forward` is expensive\n        # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),\n        # avoiding passing the attention_mask, which is equivalent to attending to the full sequence\n        if not torch.any(~patch_attention_mask):\n            patch_attention_mask = None\n        else:\n            patch_attention_mask = _prepare_4d_attention_mask(\n                patch_attention_mask, hidden_states.dtype\n            )\n\n        encoder_outputs = self.encoder(\n            inputs_embeds=hidden_states,\n            attention_mask=patch_attention_mask,\n        )\n\n        last_hidden_state = encoder_outputs\n        last_hidden_state = self.post_layernorm(last_hidden_state)\n\n        return last_hidden_state\n\n\nclass Idefics2MLP(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        act = config.text_config.hidden_act\n        self.act = (\n            ACT2FN[act]\n            if \"gelu\" not in act\n            else lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\" if act in [\"gelu_fast\", \"gelu_pytorch_tanh\"] else \"none\"\n                ),\n            )\n        )\n        self.gate_up_proj = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.gate_proj\", f\"{prefix}.up_proj\"],\n            weights=weights,\n            dim=0,\n            bias=False,\n        )\n        self.down_proj = TensorParallelRowLinear.load(\n            config,\n            
prefix=f\"{prefix}.down_proj\",\n            weights=weights,\n            bias=False,\n        )\n\n    def forward(self, hidden_states):\n        start_shape = hidden_states.shape[:-1]\n        gate_up_states = self.gate_up_proj(hidden_states)\n        intermediate_size = gate_up_states.shape[-1] // 2\n        gate_up_states = gate_up_states.view(-1, 2, intermediate_size)\n        return self.down_proj(\n            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]\n        ).view(*start_shape, -1)\n\n\nclass Idefics2RMSNorm(nn.Module):\n    def __init__(self, prefix, weights, eps):\n        \"\"\"\n        Idefics2RMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        self.weight = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.weight\"), requires_grad=False\n        )\n        self.variance_epsilon = eps\n\n    def forward(self, hidden_states):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n\nclass Idefics2PerceiverAttention(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n\n        self.layer_idx = None\n        self.hidden_size = config.text_config.hidden_size\n        self.num_heads = config.perceiver_config.resampler_n_heads\n        self.head_size = config.perceiver_config.resampler_head_dim\n        self.num_key_value_heads = config.perceiver_config.num_key_value_heads\n        self.num_key_value_groups = self.num_heads // self.num_key_value_heads\n        self.attention_dropout = config.perceiver_config.attention_dropout\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.num_key_value_heads = (\n            self.num_key_value_heads // weights.process_group.size()\n   
     )\n\n        self.q_proj = TensorParallelColumnLinear.load(\n            config,\n            prefix=f\"{prefix}.q_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.kv = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n            dim=0,\n            weights=weights,\n            bias=False,\n        )\n        self.o_proj = TensorParallelRowLinear.load(\n            config=config, prefix=f\"{prefix}.o_proj\", weights=weights, bias=False\n        )\n\n        self.is_causal = False\n\n    def forward(\n        self,\n        latents: torch.Tensor,\n        context: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        bsz, q_len, _ = latents.size()\n        kv_seq_len = q_len + context.size()[1]\n\n        hidden_states = torch.concat([context, latents], dim=-2)\n        query_states = self.q_proj(latents)\n        kv = self.kv(hidden_states)\n        key_states, value_states = kv.split(\n            [\n                self.head_size * self.num_key_value_heads,\n                self.head_size * self.num_key_value_heads,\n            ],\n            dim=2,\n        )\n\n        query_states = query_states.view(\n            bsz, q_len, self.num_heads, self.head_size\n        ).transpose(1, 2)\n        key_states = key_states.view(\n            bsz, kv_seq_len, self.num_key_value_heads, self.head_size\n        ).transpose(1, 2)\n        value_states = value_states.view(\n            bsz, kv_seq_len, self.num_key_value_heads, self.head_size\n        ).transpose(1, 2)\n\n        # repeat k/v heads if n_kv_heads < n_heads\n        key_states = repeat_kv(key_states, self.num_key_value_groups)\n        value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n        attn_weights = torch.matmul(\n            query_states, 
key_states.transpose(2, 3)\n        ) / math.sqrt(self.head_size)\n\n        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is\"\n                f\" {attn_weights.size()}\"\n            )\n\n        if attention_mask is not None:\n            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}\"\n                )\n\n            attn_weights = attn_weights + attention_mask\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(\n            attn_weights, dim=-1, dtype=torch.float32\n        ).to(query_states.dtype)\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_size):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_size)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_size)\n\n        attn_output = self.o_proj(attn_output)\n\n        return attn_output\n\n\nclass Idefics2PerceiverLayer(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.hidden_size = config.text_config.hidden_size\n        self.n_latents = config.perceiver_config.resampler_n_latents\n        self.depth = config.perceiver_config.resampler_depth\n        self.rms_norm_eps = config.text_config.rms_norm_eps\n\n        self.input_latents_norm = Idefics2RMSNorm(\n            prefix=f\"{prefix}.input_latents_norm\",\n            weights=weights,\n            eps=self.rms_norm_eps,\n    
    )\n        self.input_context_norm = Idefics2RMSNorm(\n            prefix=f\"{prefix}.input_context_norm\",\n            weights=weights,\n            eps=self.rms_norm_eps,\n        )\n        self.self_attn = Idefics2PerceiverAttention(\n            prefix=f\"{prefix}.self_attn\", config=config, weights=weights\n        )\n        self.post_attention_layernorm = Idefics2RMSNorm(\n            prefix=f\"{prefix}.post_attention_layernorm\",\n            weights=weights,\n            eps=self.rms_norm_eps,\n        )\n        self.mlp = Idefics2MLP(prefix=f\"{prefix}.mlp\", config=config, weights=weights)\n\n    def forward(\n        self,\n        latents: torch.Tensor,\n        context: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n    ):\n        \"\"\"\n        Args:\n            latents (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n                `(batch, sequence_length)` where padding elements are indicated by 0.\n        \"\"\"\n        residual = latents\n\n        latents = self.input_latents_norm(latents)\n        context = self.input_context_norm(context)\n\n        latents = self.self_attn(\n            latents=latents,\n            context=context,\n            attention_mask=attention_mask,\n        )\n        latents = residual + latents\n        residual = latents\n\n        latents = self.post_attention_layernorm(latents)\n        latents = self.mlp(latents)\n        latents = residual + latents\n\n        return latents\n\n\nclass Idefics2PerceiverResampler(nn.Module):\n    def __init__(self, prefix, config, weights) -> None:\n        super().__init__()\n        self.hidden_size = config.text_config.hidden_size\n        self.hidden_act = config.perceiver_config.hidden_act\n        self.n_latents = 
config.perceiver_config.resampler_n_latents\n        self.depth = config.perceiver_config.resampler_depth\n        self.rms_norm_eps = config.text_config.rms_norm_eps\n\n        # Create Latents for Perceiver\n        self.latents = weights.get_tensor(f\"{prefix}.latents\")\n\n        # Create Transformer Blocks\n        self.layers = nn.ModuleList(\n            [\n                Idefics2PerceiverLayer(\n                    prefix=f\"{prefix}.layers.{idx}\", config=config, weights=weights\n                )\n                for idx in range(self.depth)\n            ]\n        )\n        self.norm = Idefics2RMSNorm(\n            prefix=f\"{prefix}.norm\",\n            weights=weights,\n            eps=config.text_config.rms_norm_eps,\n        )\n\n    def forward(\n        self,\n        context: torch.Tensor,\n        attention_mask,\n    ) -> torch.Tensor:\n        # seq embed -> bsz seq embed\n        latents = self.latents.unsqueeze(0).expand(\n            (context.shape[0], *self.latents.size())\n        )\n\n        latent_attention_mask = torch.ones(\n            (attention_mask.size(0), latents.size(1)),\n            dtype=attention_mask.dtype,\n            device=attention_mask.device,\n        )\n        attention_mask = torch.cat([attention_mask, latent_attention_mask], dim=-1)\n        attention_mask = _prepare_4d_attention_mask(\n            attention_mask, latents.dtype, tgt_len=self.n_latents\n        )\n\n        compressed_context = latents\n        for perceiver_layer in self.layers:\n            compressed_context = perceiver_layer(\n                compressed_context,\n                context,\n                attention_mask=attention_mask,\n            )\n        compressed_context = self.norm(compressed_context)\n\n        return compressed_context\n\n\nclass Idefics2Connector(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.modality_projection = Idefics2MLP(\n            
prefix=f\"{prefix}.modality_projection\", config=config, weights=weights\n        )\n        self.perceiver_resampler = Idefics2PerceiverResampler(\n            prefix=f\"{prefix}.perceiver_resampler\", config=config, weights=weights\n        )\n\n    def forward(self, image_hidden_states, attention_mask):\n        image_hidden_states = self.modality_projection(image_hidden_states)\n        image_hidden_states = self.perceiver_resampler(\n            context=image_hidden_states, attention_mask=attention_mask\n        )\n        return image_hidden_states\n\n\nclass Idefics2ForConditionalGeneration(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        config.vision_config.quantize = None\n        config.vision_config.speculator = config.speculator\n        config.text_config.quantize = config.quantize\n        config.text_config.speculator = config.speculator\n\n        vision_config = config.vision_config\n        self.text_model = load_text_model(\n            prefix=\"model\" if not prefix else f\"{prefix}.model\",\n            config=config.text_config,\n            weights=weights,\n            name=\"text_model\",\n        )\n        self.dtype = weights.dtype\n\n        # The vision and connector models are not quantized.\n        with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)):\n            self.vision_model = Idefics2VisionTransformer(\n                prefix=(\n                    f\"{prefix}.model.vision_model\" if prefix else \"model.vision_model\"\n                ),\n                config=vision_config,\n                weights=weights,\n            )\n\n            config.quantize = None\n            self.connector = Idefics2Connector(\n                prefix=f\"{prefix}.model.connector\" if prefix else \"model.connector\",\n                config=config,\n                weights=weights,\n            )\n\n        self.config = config\n        self.image_seq_len = 
config.perceiver_config.resampler_n_latents\n        self.image_token_id = config.image_token_id\n        self.pad_token_id = (\n            config.pad_token_id if config.pad_token_id is not None else -1\n        )\n\n    def _merge_input_ids_with_image_features(\n        self,\n        input_ids: torch.Tensor,\n        inputs_embeds: torch.Tensor,\n        image_features: torch.Tensor,\n    ):\n        \"\"\"In place merges in vision_embeddings with inputs_embeds.\"\"\"\n        # mask = input_ids == self.config.image_token_index\n        #  - replace `==` with torch.where to fix the issue in hpu graph\n        mask = torch.where(input_ids == self.config.image_token_id)\n        # Let's pray we have enabled enough slots !\n        inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])\n        return inputs_embeds\n\n    def get_vision_embeds(\n        self,\n        pixel_values: torch.FloatTensor,\n        pixel_attention_mask: Optional[torch.FloatTensor] = None,\n        image_sizes: Optional[torch.Tensor] = None,\n        image_grid_thw: Optional[torch.LongTensor] = None,\n    ):\n        assert pixel_values is not None\n        batch_size, num_images, num_channels, height, width = pixel_values.shape\n        all_states = []\n        all_pixel_values = pixel_values\n        all_pixel_mask = pixel_attention_mask\n        for i in range(batch_size):\n            pixel_values = all_pixel_values.to(dtype=self.dtype)  # fp16 compatibility\n            pixel_values = pixel_values[i : i + 1]\n            pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])\n\n            # Remove padding images - padding images are full 0.\n            nb_values_per_image = pixel_values.shape[1:].numel()\n            real_images_inds = (pixel_values == 0.0).sum(\n                dim=(-1, -2, -3)\n            ) != nb_values_per_image\n            pixel_values = pixel_values[real_images_inds].contiguous()\n\n            # Handle the vision attention 
mask\n            if pixel_attention_mask is None:\n                pixel_attention_mask = torch.ones(\n                    size=(\n                        pixel_values.size(0),\n                        pixel_values.size(2),\n                        pixel_values.size(3),\n                    ),\n                    dtype=torch.bool,\n                    device=pixel_values.device,\n                )\n            else:\n                # Remove padding images from the mask/pP p\n                pixel_attention_mask = all_pixel_mask[i : i + 1]\n                pixel_attention_mask = pixel_attention_mask.view(\n                    1 * num_images, *pixel_attention_mask.shape[2:]\n                )\n                pixel_attention_mask = pixel_attention_mask[\n                    real_images_inds\n                ].contiguous()\n\n            patch_size = self.config.vision_config.patch_size\n            \"\"\"\n            patches_subgrid = pixel_attention_mask.unfold(\n                dimension=1, size=patch_size, step=patch_size\n            )\n            patches_subgrid = patches_subgrid.unfold(\n                dimension=2, size=patch_size, step=patch_size\n            )\n            patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()\n            \"\"\"\n            # hpu does none support unfold\n            conv_kernel = torch.ones(\n                [1, 1, patch_size, patch_size],\n                dtype=pixel_values.dtype,\n                device=pixel_values.device,\n            )\n            patches_subgrid = torch.nn.functional.conv2d(\n                pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype),\n                conv_kernel,\n                stride=patch_size,\n            ).squeeze(1)\n            patch_attention_mask = torch.gt(patches_subgrid, 0)\n\n            # Get sequence from the vision encoder\n            image_hidden_states = self.vision_model(\n                pixel_values=pixel_values,\n                
patch_attention_mask=patch_attention_mask,\n            )\n\n            # Modality projection & resampling\n            image_hidden_states = self.connector(\n                image_hidden_states,\n                attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),\n            )\n            all_states.append(image_hidden_states)\n        image_hidden_states = torch.stack(all_states, dim=0)\n        return image_hidden_states.view(-1, image_hidden_states.shape[-1])\n\n    def get_inputs_embeds(\n        self,\n        input_ids: torch.Tensor,\n        vision_embeds: torch.Tensor = None,\n    ):\n        inputs_embeds = self.text_model.embed_tokens(input_ids)\n\n        if vision_embeds is not None:\n            # When we generate, we don't want to replace the potential image_token_id that we generated by images\n            # that simply don't exist\n            inputs_embeds = self._merge_input_ids_with_image_features(\n                input_ids, inputs_embeds, vision_embeds\n            )\n        return inputs_embeds\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.BoolTensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ):\n        hidden_states = self.text_model.model(\n            inputs_embeds=inputs_embeds,\n            position_ids=position_ids,\n            cu_seqlen_prefill=cu_seqlen_prefill,\n            kv_cache=kv_cache,\n            slots=slots,\n            seqlen=seqlen,\n            hpu_attention_meta=hpu_attention_meta,\n            adapter_data=adapter_data,\n        )\n        if lm_head_indices is not None:\n            
hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.text_model.lm_head(hidden_states)\n        return logits, speculative_logits\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/idefics3.py",
    "content": "# coding=utf-8\n# Copyright 2024 the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch Idefics3 model.\"\"\"\n\nfrom typing import List, Optional, Tuple\n\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\n\nfrom transformers.activations import ACT2FN\nfrom text_generation_server.models.custom_modeling.vlm import (\n    load_text_model,\n)\nfrom text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata\nfrom transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask\n\nfrom text_generation_server.layers import (\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    TensorParallelRowLinear,\n)\nfrom text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(\n        batch, num_key_value_heads, n_rep, slen, head_dim\n    )\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\nclass Idefics3VisionEmbeddings(nn.Module):\n    \"\"\"\n    This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable\n    resolution.\n\n    The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)\n    which allows treating images in their native aspect ratio and without the need to resize them to the same\n    fixed size. In particular, we start from the original pre-trained SigLIP model\n    (which uses images of fixed-size square images) and adapt it by training on images of variable resolutions.\n    \"\"\"\n\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.embed_dim = config.hidden_size\n        self.image_size = config.image_size\n        self.patch_size = config.patch_size\n\n        self.patch_embedding = nn.Conv2d(\n            in_channels=config.num_channels,\n            out_channels=self.embed_dim,\n            kernel_size=self.patch_size,\n            stride=self.patch_size,\n            padding=\"valid\",\n        )\n        self.patch_embedding.weight = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.patch_embedding.weight\"), requires_grad=False\n        )\n        self.patch_embedding.bias = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.patch_embedding.bias\"), requires_grad=False\n        )\n\n        self.num_patches_per_side = self.image_size // self.patch_size\n        
self.num_patches = self.num_patches_per_side**2\n        self.num_positions = self.num_patches\n        self.position_embedding = TensorParallelEmbedding(\n            prefix=f\"{prefix}.position_embedding\", weights=weights\n        )\n\n    def forward(\n        self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor\n    ) -> torch.Tensor:\n        batch_size, _, max_im_h, max_im_w = pixel_values.shape\n\n        patch_embeds = self.patch_embedding(pixel_values)\n        embeddings = patch_embeds.flatten(2).transpose(1, 2)\n\n        max_nb_patches_h, max_nb_patches_w = (\n            max_im_h // self.patch_size,\n            max_im_w // self.patch_size,\n        )\n        boundaries = torch.arange(\n            1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side\n        )\n        position_ids = torch.full(\n            size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0\n        )\n\n        for batch_idx, p_attn_mask in enumerate(patch_attention_mask):\n            nb_patches_h = p_attn_mask[:, 0].sum()\n            nb_patches_w = p_attn_mask[0].sum()\n\n            fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)\n            fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)\n\n            bucket_coords_h = torch.bucketize(\n                fractional_coords_h, boundaries, right=True\n            )\n            bucket_coords_w = torch.bucketize(\n                fractional_coords_w, boundaries, right=True\n            )\n\n            pos_ids = (\n                bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w\n            ).flatten()\n            position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids\n\n        position_ids = position_ids.to(self.position_embedding.weight.device)\n        embeddings = embeddings + self.position_embedding(position_ids)\n        return embeddings\n\n\nclass Idefics3VisionAttention(nn.Module):\n    def 
__init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_size = self.embed_dim // self.num_heads\n        if self.head_size * self.num_heads != self.embed_dim:\n            raise ValueError(\n                f\"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:\"\n                f\" {self.num_heads}).\"\n            )\n        self.scale = self.head_size**-0.5\n        self.dropout = config.attention_dropout\n\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.embed_dim = self.embed_dim // weights.process_group.size()\n\n        self.qkv = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n            dim=0,\n            weights=weights,\n            bias=True,\n        )\n        self.out_proj = TensorParallelRowLinear.load(\n            config=config, prefix=f\"{prefix}.out_proj\", weights=weights, bias=True\n        )\n        self.is_causal = False\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        batch_size, q_len, _ = hidden_states.size()\n\n        qkv = self.qkv(hidden_states)\n        query_states, key_states, value_states = qkv.split(\n            [\n                self.head_size * self.num_heads,\n                self.head_size * self.num_heads,\n                self.head_size * self.num_heads,\n            ],\n            dim=2,\n        )\n\n        query_states = query_states.view(\n            batch_size, q_len, self.num_heads, self.head_size\n        ).transpose(1, 2)\n        key_states = key_states.view(\n            batch_size, q_len, self.num_heads, self.head_size\n        ).transpose(1, 2)\n        
value_states = value_states.view(\n            batch_size, q_len, self.num_heads, self.head_size\n        ).transpose(1, 2)\n\n        k_v_seq_len = key_states.shape[-2]\n        attn_weights = (\n            torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale\n        )\n\n        if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is\"\n                f\" {attn_weights.size()}\"\n            )\n\n        if attention_mask is not None:\n            if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}\"\n                )\n            attn_weights = attn_weights + attention_mask\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(\n            attn_weights, dim=-1, dtype=torch.float32\n        ).to(query_states.dtype)\n        attn_weights = nn.functional.dropout(\n            attn_weights, p=self.dropout, training=self.training\n        )\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_size):\n            raise ValueError(\n                f\"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_size)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)\n\n        attn_output = self.out_proj(attn_output)\n\n        return attn_output\n\n\nclass Idefics3VisionMLP(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        
self.activation_fn = ACT2FN[config.hidden_act]\n        self.fc1 = TensorParallelColumnLinear.load(\n            prefix=f\"{prefix}.fc1\", config=config, weights=weights, bias=True\n        )\n        self.fc2 = TensorParallelRowLinear.load(\n            prefix=f\"{prefix}.fc2\", config=config, weights=weights, bias=True\n        )\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = self.activation_fn(hidden_states)\n        hidden_states = self.fc2(hidden_states)\n        return hidden_states\n\n\nclass Idefics3EncoderLayer(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.embed_dim = config.hidden_size\n        self.self_attn = Idefics3VisionAttention(\n            prefix=f\"{prefix}.self_attn\", config=config, weights=weights\n        )\n        self.layer_norm1 = nn.LayerNorm.load(\n            prefix=f\"{prefix}.layer_norm1\", eps=config.layer_norm_eps, weights=weights\n        )\n        self.layer_norm2 = nn.LayerNorm.load(\n            prefix=f\"{prefix}.layer_norm2\", eps=config.layer_norm_eps, weights=weights\n        )\n        self.mlp = Idefics3VisionMLP(\n            prefix=f\"{prefix}.mlp\", config=config, weights=weights\n        )\n\n    # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: torch.Tensor,\n    ) -> torch.Tensor:\n        residual = hidden_states\n\n        hidden_states = self.layer_norm1(hidden_states)\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n        )\n        hidden_states = residual + hidden_states\n\n        residual = hidden_states\n        hidden_states = self.layer_norm2(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + 
hidden_states\n\n        return hidden_states\n\n\nclass Idefics3Encoder(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        self.layers = nn.ModuleList(\n            [\n                Idefics3EncoderLayer(\n                    prefix=f\"{prefix}.layers.{i}\", config=config, weights=weights\n                )\n                for i in range(config.num_hidden_layers)\n            ]\n        )\n\n    # Ignore copy\n    def forward(\n        self,\n        inputs_embeds,\n        attention_mask: Optional[torch.Tensor] = None,\n    ):\n        hidden_states = inputs_embeds\n        for encoder_layer in self.layers:\n            hidden_states = encoder_layer(\n                hidden_states,\n                attention_mask,\n            )\n        return hidden_states\n\n\nclass Idefics3VisionTransformer(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        self.embeddings = Idefics3VisionEmbeddings(\n            prefix=f\"{prefix}.embeddings\", config=config, weights=weights\n        )\n        self.encoder = Idefics3Encoder(\n            prefix=f\"{prefix}.encoder\", config=config, weights=weights\n        )\n        self.post_layernorm = nn.LayerNorm.load(\n            prefix=f\"{prefix}.post_layernorm\",\n            weights=weights,\n            eps=config.layer_norm_eps,\n        )\n\n    def forward(\n        self,\n        pixel_values,\n        patch_attention_mask: Optional[torch.BoolTensor] = None,\n    ):\n        batch_size = pixel_values.size(0)\n        if patch_attention_mask is None:\n            patch_size = self.config.patch_size\n            patch_attention_mask = torch.ones(\n                (\n                    batch_size,\n                    pixel_values.size(2) // patch_size,\n                    pixel_values.size(3) // patch_size,\n                )\n            )\n            
patch_attention_mask = patch_attention_mask.to(\n                dtype=torch.bool, device=pixel_values.device\n            )\n\n        hidden_states = self.embeddings(\n            pixel_values=pixel_values, patch_attention_mask=patch_attention_mask\n        )\n\n        patch_attention_mask = patch_attention_mask.view(batch_size, -1)\n        # The call to `_upad_input` in `_flash_attention_forward` is expensive\n        # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),\n        # avoiding passing the attention_mask, which is equivalent to attending to the full sequence\n        if not torch.any(~patch_attention_mask):\n            patch_attention_mask = None\n        else:\n            patch_attention_mask = _prepare_4d_attention_mask(\n                patch_attention_mask, hidden_states.dtype\n            )\n\n        encoder_outputs = self.encoder(\n            inputs_embeds=hidden_states,\n            attention_mask=patch_attention_mask,\n        )\n\n        last_hidden_state = encoder_outputs\n        last_hidden_state = self.post_layernorm(last_hidden_state)\n\n        return last_hidden_state\n\n\nclass Idefics3SimpleMLP(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        input_size = config.vision_config.hidden_size * (config.scale_factor**2)\n        output_size = config.text_config.hidden_size\n        proj = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.modality_projection.proj.weight\"),\n            requires_grad=False,\n        ).to(weights.dtype)\n        self.proj = nn.Linear(input_size, output_size, bias=False)\n        self.proj.weight = proj\n\n    def forward(self, x):\n        return self.proj(x)\n\n\nclass Idefics3Connector(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.modality_projection = Idefics3SimpleMLP(prefix, config, weights)\n        self.scale_factor = config.scale_factor\n\n  
  def pixel_shuffle(self, x, scale_factor=2):\n        bsz, seq, embed_dim = x.size()\n        height = width = int(seq**0.5)\n        x = x.view(bsz, height, width, embed_dim)\n        x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor)\n        x = x.permute(0, 2, 1, 3)\n        x = x.reshape(\n            bsz,\n            int(width / scale_factor),\n            int(height / scale_factor),\n            embed_dim * (scale_factor**2),\n        )\n        x = x.permute(0, 2, 1, 3)\n        x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))\n        return x\n\n    def forward(self, image_hidden_states):\n        image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor)\n        image_hidden_states = self.modality_projection(image_hidden_states)\n        return image_hidden_states\n\n\nclass Idefics3ForConditionalGeneration(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        config.vision_config.quantize = None\n        config.vision_config.speculator = config.speculator\n        config.text_config.quantize = config.quantize\n        config.text_config.speculator = config.speculator\n        # set tie_word_embeddings to True to load `.embed_tokens.weight` instead of `.lm_head.weight`\n        # since Idefics3 uses the `embed_tokens` for the final prediction\n        # config.text_config.tie_word_embeddings = True\n\n        vision_config = config.vision_config\n        self.text_model = load_text_model(\n            prefix=\"model\" if not prefix else f\"{prefix}.model\",\n            config=config.text_config,\n            weights=weights,\n            name=\"text_model\",\n        )\n        self.dtype = weights.dtype\n\n        # The vision and connector models are not quantized.\n        with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)):\n            self.vision_model = Idefics3VisionTransformer(\n                prefix=(\n   
                 f\"{prefix}.model.vision_model\" if prefix else \"model.vision_model\"\n                ),\n                config=vision_config,\n                weights=weights,\n            )\n\n            config.quantize = None\n            self.connector = Idefics3Connector(\n                prefix=f\"{prefix}.model.connector\" if prefix else \"model.connector\",\n                config=config,\n                weights=weights,\n            )\n\n        self.config = config\n        self.image_token_id = config.image_token_id\n        self.pad_token_id = (\n            config.pad_token_id if config.pad_token_id is not None else -1\n        )\n\n    def _merge_input_ids_with_image_features(\n        self,\n        input_ids: torch.Tensor,\n        inputs_embeds: torch.Tensor,\n        image_features: torch.Tensor,\n    ):\n        \"\"\"In place merges in vision_embeddings with inputs_embeds.\"\"\"\n        # mask = input_ids == self.config.image_token_index\n        #  - replace `==` with torch.where to fix the issue in hpu graph\n        mask = torch.where(input_ids == self.config.image_token_id)\n        # Let's pray we have enabled enough slots !\n        inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])\n        return inputs_embeds\n\n    def get_vision_embeds(\n        self,\n        pixel_values: torch.FloatTensor,\n        pixel_attention_mask: Optional[torch.FloatTensor] = None,\n        image_sizes: Optional[torch.Tensor] = None,\n        image_grid_thw: Optional[torch.LongTensor] = None,\n    ):\n        batch_size, num_images, num_channels, height, width = pixel_values.shape\n        all_states = []\n        all_pixel_values = pixel_values\n        all_pixel_mask = pixel_attention_mask\n        for i in range(batch_size):\n            pixel_values = all_pixel_values.to(dtype=self.dtype)  # fp16 compatibility\n            pixel_values = pixel_values[i : i + 1]\n            pixel_values = pixel_values.view(num_images, 
*pixel_values.shape[2:])\n\n            # Remove padding images - padding images are full 0.\n            nb_values_per_image = pixel_values.shape[1:].numel()\n            real_images_inds = (pixel_values == 0.0).sum(\n                dim=(-1, -2, -3)\n            ) != nb_values_per_image\n            pixel_values = pixel_values[real_images_inds].contiguous()\n            # Handle the vision attention mask\n            if pixel_attention_mask is None:\n                pixel_attention_mask = torch.ones(\n                    size=(\n                        pixel_values.size(0),\n                        pixel_values.size(2),\n                        pixel_values.size(3),\n                    ),\n                    dtype=torch.bool,\n                    device=pixel_values.device,\n                )\n            else:\n                # Remove padding images from the mask\n                pixel_attention_mask = all_pixel_mask[i : i + 1]\n                pixel_attention_mask = pixel_attention_mask.view(\n                    1 * num_images, *pixel_attention_mask.shape[2:]\n                )\n                pixel_attention_mask = pixel_attention_mask[\n                    real_images_inds\n                ].contiguous()\n\n            patch_size = self.config.vision_config.patch_size\n\n            \"\"\"\n            patches_subgrid = pixel_attention_mask.unfold(\n                dimension=1, size=patch_size, step=patch_size\n            )\n            patches_subgrid = patches_subgrid.unfold(\n                dimension=2, size=patch_size, step=patch_size\n            )\n            patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()\n            \"\"\"\n            # hpu does not support unfold\n            conv_kernel = torch.ones(\n                [1, 1, patch_size, patch_size],\n                dtype=pixel_values.dtype,\n                device=pixel_values.device,\n            )\n            patches_subgrid = torch.nn.functional.conv2d(\n    
            pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype),\n                conv_kernel,\n                stride=patch_size,\n            ).squeeze(1)\n            patch_attention_mask = torch.gt(patches_subgrid, 0)\n\n            # Get sequence from the vision encoder\n            image_hidden_states = self.vision_model(\n                pixel_values=pixel_values,\n                patch_attention_mask=patch_attention_mask,\n            )\n\n            # Modality projection & resampling\n            image_hidden_states = self.connector(\n                image_hidden_states,\n            )\n\n            all_states.append(image_hidden_states)\n        image_hidden_states = torch.stack(all_states, dim=0)\n\n        return image_hidden_states.view(-1, image_hidden_states.shape[-1])\n\n    def get_inputs_embeds(\n        self,\n        input_ids: torch.Tensor,\n        vision_embeds: torch.Tensor = None,\n    ):\n        inputs_embeds = self.text_model.embed_tokens(input_ids)\n\n        if vision_embeds is not None:\n            # When we generate, we don't want to replace the potential image_token_id that we generated by images\n            # that simply don't exist\n            inputs_embeds = self._merge_input_ids_with_image_features(\n                input_ids, inputs_embeds, vision_embeds\n            )\n        return inputs_embeds\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.BoolTensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n        image_indices=None,\n    ):\n        hidden_states = self.text_model.model(\n            
inputs_embeds=inputs_embeds,\n            position_ids=position_ids,\n            cu_seqlen_prefill=cu_seqlen_prefill,\n            kv_cache=kv_cache,\n            slots=slots,\n            seqlen=seqlen,\n            hpu_attention_meta=hpu_attention_meta,\n            adapter_data=adapter_data,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.text_model.lm_head(hidden_states)\n        return logits, speculative_logits\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/mamba_modeling.py",
    "content": "import torch\nimport torch.distributed\n\nfrom mamba_ssm.ops.triton.selective_state_update import selective_state_update\nfrom mamba_ssm.ops.selective_scan_interface import selective_scan_fn\nfrom torch import nn\nfrom typing import Optional, Tuple, Any\nfrom transformers.configuration_utils import PretrainedConfig\nimport torch.nn.functional as F\n\nfrom text_generation_server.layers import (\n    SpeculativeHead,\n    TensorParallelEmbedding,\n    FastLinear,\n)\nfrom text_generation_server.layers.layernorm import FastRMSNorm\n\nfrom einops import rearrange\nfrom causal_conv1d import causal_conv1d_fn, causal_conv1d_update\nimport math\nfrom dataclasses import dataclass\n\n\n@dataclass\nclass InferenceParams:\n    \"\"\"Inference parameters that are passed to the main model in order\n    to efficiently calculate and store the context during inference.\"\"\"\n\n    max_seqlen: int\n    max_batch_size: int\n    conv_states: torch.Tensor\n    ssm_states: torch.Tensor\n    seqlen_offset: int\n\n\nclass MambaConfig(PretrainedConfig):\n    def __init__(\n        self,\n        vocab_size=50280,\n        d_model=768,\n        d_state=16,\n        n_layer=32,\n        layer_norm_epsilon=1e-5,\n        tie_word_embeddings=False,\n        pad_token_id=0,\n        bos_token_id=1,\n        eos_token_id=2,\n        expand=2,\n        dt_rank=\"auto\",\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.n_layer = n_layer\n        self.layer_norm_epsilon = layer_norm_epsilon\n        self.d_model = d_model\n        self.d_inner = d_model * 2\n        self.d_conv = 4\n        self.d_state = d_state\n        self.expand = expand\n        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == \"auto\" else dt_rank\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n 
       )\n\n\nclass MambaBlock(nn.Module):\n    def __init__(self, prefix, config, weights, layer_id):\n        super().__init__()\n        self.layer_id = layer_id\n        self.in_proj = FastLinear.load(config, f\"{prefix}.in_proj\", weights, bias=False)\n        self.x_proj = FastLinear.load(config, f\"{prefix}.x_proj\", weights, bias=False)\n        self.dt_proj = FastLinear.load(config, f\"{prefix}.dt_proj\", weights, bias=True)\n        self.dt_proj_no_bias = FastLinear.load(\n            config, f\"{prefix}.dt_proj\", weights, bias=False\n        )\n        self.out_proj = FastLinear.load(\n            config, f\"{prefix}.out_proj\", weights, bias=False\n        )\n        self.conv1d = FastLinear.load(config, f\"{prefix}.conv1d\", weights, bias=True)\n        self.negA = -torch.exp(weights.get_tensor(f\"{prefix}.A_log\").float())\n        self.D = weights.get_tensor(f\"{prefix}.D\")\n        self.activation = \"silu\"\n        self.dt_rank = config.dt_rank\n        self.d_state = config.d_state\n        self.d_conv = config.d_conv\n        self.act = nn.SiLU()\n\n    # inference_params\n    def forward(self, hidden_states: torch.Tensor, inference_params=None):\n        if inference_params.seqlen_offset > 0:\n            conv_state = inference_params.conv_states[self.layer_id]\n            ssm_state = inference_params.ssm_states[self.layer_id]\n            out, conv_state, ssm_state = self.step(hidden_states, conv_state, ssm_state)\n            return out, conv_state, ssm_state\n\n        _, seqlen, _ = hidden_states.shape\n        projected_states = self.in_proj(hidden_states).transpose(1, 2)\n        # assert projected_states.shape == [batch_size, 2 * dstate, seqlen], f\"{projected_states.shape} [{batch_size}, {dstate}, {seqlen}]\"\n        x, z = projected_states.chunk(2, dim=1)\n        conv_state = F.pad(x, (self.d_conv - seqlen, 0))\n        x = causal_conv1d_fn(\n            x=x,\n            weight=self.conv1d.weight.squeeze(1),\n            
bias=self.conv1d.bias,\n            activation=self.activation,\n        )\n\n        # We're careful here about the layout, to avoid extra transposes.\n        # We want dt to have d as the slowest moving dimension\n        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.\n        x_dbl = self.x_proj(rearrange(x, \"b d l -> (b l) d\"))  # (bl d)\n        dt, B, C = torch.split(\n            x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1\n        )\n        dt = self.dt_proj.weight @ dt.t()\n        dt = rearrange(dt, \"d (b l) -> b d l\", l=seqlen)\n        B = rearrange(B, \"(b l) dstate -> b dstate l\", l=seqlen).contiguous()\n        C = rearrange(C, \"(b l) dstate -> b dstate l\", l=seqlen).contiguous()\n        y, last_state = selective_scan_fn(\n            x,\n            dt,\n            self.negA,\n            B,\n            C,\n            self.D.float(),\n            z=z,\n            delta_bias=self.dt_proj.bias.float(),\n            delta_softplus=True,\n            return_last_state=True,\n        )\n        y = rearrange(y, \"b d l -> b l d\")\n        attn_outputs = self.out_proj(y)\n        return attn_outputs, conv_state, last_state\n\n    def step(self, hidden_states, conv_state, ssm_state):\n        xz = self.in_proj(hidden_states.squeeze(1))\n        x, z = xz.chunk(2, dim=-1)  # (B D)\n        x = causal_conv1d_update(\n            x,\n            conv_state,\n            self.conv1d.weight.squeeze(1),\n            self.conv1d.bias,\n            self.activation,\n        )\n        x_db = self.x_proj(x)  # (B dt_rank+2*d_state)\n        dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)\n        dt = F.linear(dt, self.dt_proj.weight)\n        A = self.negA\n        y = selective_state_update(\n            ssm_state,\n            x,\n            dt,\n            A,\n            B,\n            C,\n            self.D,\n            z=z,\n            
dt_bias=self.dt_proj.bias,\n            dt_softplus=True,\n        )\n        out = self.out_proj(y)\n        return out.unsqueeze(1), conv_state.clone(), ssm_state.clone()\n\n\nclass ResidualBlock(nn.Module):\n    def __init__(self, prefix, config, weights, layer_id):\n        super().__init__()\n        self.mamba_block = MambaBlock(\n            prefix=f\"{prefix}.mixer\", config=config, weights=weights, layer_id=layer_id\n        )\n        self.layer_norm = FastRMSNorm.load(\n            prefix=f\"{prefix}.norm\", weights=weights, eps=config.layer_norm_epsilon\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        residual: Optional[torch.Tensor] = None,\n        inference_params: Optional[Any] = None,\n    ):\n        residual = (hidden_states + residual) if residual is not None else hidden_states\n        shape = residual.shape\n        hidden_states, _ = self.layer_norm(residual.view(-1, shape[-1]))\n        hidden_states, conv_state, last_ssm_state = self.mamba_block(\n            hidden_states.view(*shape), inference_params\n        )\n        return hidden_states, residual, conv_state, last_ssm_state\n\n\nclass MambaModel(nn.Module):\n    def __init__(self, config, weights):\n        super().__init__()\n        prefix = \"backbone\"\n        try:\n            self.embed_tokens = TensorParallelEmbedding(f\"{prefix}.embeddings\", weights)\n        except RuntimeError:\n            self.embed_tokens = TensorParallelEmbedding(f\"{prefix}.embedding\", weights)\n        self.blocks = nn.ModuleList(\n            [\n                ResidualBlock(f\"{prefix}.layers.{i}\", config, weights, layer_id=i)\n                for i in range(config.n_layer)\n            ]\n        )\n        self.norm_f = FastRMSNorm.load(\n            f\"{prefix}.norm_f\", weights, eps=config.layer_norm_epsilon\n        )\n        try:\n            self.lm_head = SpeculativeHead.load(config, f\"{prefix}.embeddings\", weights)\n        except 
RuntimeError:\n            self.lm_head = SpeculativeHead.load(config, f\"{prefix}.embedding\", weights)\n        self.config = config\n\n    def forward(\n        self, input_ids: torch.Tensor, inference_params=None, residual=None\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        hidden_states = self.embed_tokens(input_ids)\n        for i, block in enumerate(self.blocks):\n            hidden_states, residual, conv_state, ssm_state = block(\n                hidden_states, residual, inference_params\n            )\n            inference_params.conv_states[i].copy_(conv_state)\n            inference_params.ssm_states[i].copy_(ssm_state)\n\n        hidden_states = (\n            hidden_states + residual if residual is not None else hidden_states\n        )\n        hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1)))\n        hidden_states = hidden_states.view(residual.shape)\n        logits, speculative_logits = self.lm_head(hidden_states)\n\n        # update the offset for the next inference using these params\n        inference_params.seqlen_offset += input_ids.size(1)\n        return logits, speculative_logits\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py",
    "content": "# coding=utf-8\n# Copyright 2025 the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch Qwen2.5 VL model.\"\"\"\n\nfrom typing import Optional, Tuple, List\n\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\n\nfrom habana_frameworks.torch.hpex.kernels import FusedSDPA\nfrom vllm_hpu_extension.utils import ModuleFusedSDPA\n\n\nimport numpy as np\n\nfrom transformers.activations import ACT2FN\nfrom transformers.configuration_utils import PretrainedConfig\n\nimport torch.nn.functional as F\n\nfrom text_generation_server.layers.layernorm import FastRMSNorm\nfrom text_generation_server.layers import (\n    TensorParallelColumnLinear,\n    TensorParallelRowLinear,\n    TensorParallelEmbedding,\n    SpeculativeHead,\n)\nfrom text_generation_server.layers.attention import (\n    Seqlen,\n    HPUPagedAttentionMetadata,\n)\nfrom text_generation_server.models.custom_modeling.flash_qwen2_modeling import (\n    Qwen2Model,\n)\nfrom habana_frameworks.torch.hpex.kernels import (\n    RotaryPosEmbeddingMode,\n    apply_rotary_pos_emb,\n)\nimport habana_frameworks.torch as htorch\n\n# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py\nfrom typing import Union\nfrom transformers.feature_extraction_utils import BatchFeature\nfrom transformers.image_utils import ImageInput\nfrom transformers.video_utils import 
VideoInput\nfrom transformers.processing_utils import (\n    ProcessingKwargs,\n    ProcessorMixin,\n    Unpack,\n    VideosKwargs,\n)\nfrom transformers.tokenization_utils_base import PreTokenizedInput, TextInput\n\n\nclass Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False):\n    fps: Union[List[float], float]\n\n\nclass Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):\n    videos_kwargs: Qwen2_5_VLVideosProcessorKwargs\n    _defaults = {\n        \"text_kwargs\": {\n            \"padding\": False,\n        },\n        \"videos_kwargs\": {\"fps\": 2.0},\n    }\n\n\nclass Qwen2_5_VLProcessor(ProcessorMixin):\n    r\"\"\"\n    Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor.\n    [`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the\n    [`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information.\n    Args:\n        image_processor ([`Qwen2VLImageProcessor`], *optional*):\n            The image processor is a required input.\n        tokenizer ([`Qwen2TokenizerFast`], *optional*):\n            The tokenizer is a required input.\n        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages\n            in a chat into a tokenizable string.\n    \"\"\"\n\n    attributes = [\"image_processor\", \"tokenizer\"]\n    valid_kwargs = [\"chat_template\"]\n\n    image_processor_class = \"AutoImageProcessor\"\n    tokenizer_class = (\"Qwen2Tokenizer\", \"Qwen2TokenizerFast\")\n\n    def __init__(\n        self, image_processor=None, tokenizer=None, chat_template=None, **kwargs\n    ):\n        self.image_token = (\n            \"<|image_pad|>\"\n            if not hasattr(tokenizer, \"image_token\")\n            else tokenizer.image_token\n        )\n        self.video_token = (\n            \"<|video_pad|>\"\n            if not hasattr(tokenizer, 
\"video_token\")\n            else tokenizer.video_token\n        )\n        super().__init__(image_processor, tokenizer, chat_template=chat_template)\n\n    def __call__(\n        self,\n        images: ImageInput = None,\n        text: Union[\n            TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]\n        ] = None,\n        videos: VideoInput = None,\n        **kwargs: Unpack[Qwen2_5_VLProcessorKwargs],\n    ) -> BatchFeature:\n        \"\"\"\n        Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`\n        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode\n        the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwargs` arguments to\n        Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.\n\n        Args:\n            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):\n                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch\n                tensor. Both channels-first and channels-last formats are supported.\n            text (`str`, `List[str]`, `List[List[str]]`):\n                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings\n                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set\n                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).\n            videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):\n                The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch\n                tensor, or a nested list of 3D frames. 
Both channels-first and channels-last formats are supported.\n            return_tensors (`str` or [`~utils.TensorType`], *optional*):\n                If set, will return tensors of a particular framework. Acceptable values are:\n                - `'tf'`: Return TensorFlow `tf.constant` objects.\n                - `'pt'`: Return PyTorch `torch.Tensor` objects.\n                - `'np'`: Return NumPy `np.ndarray` objects.\n                - `'jax'`: Return JAX `jnp.ndarray` objects.\n\n        Returns:\n            [`BatchFeature`]: A [`BatchFeature`] with the following fields:\n\n            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.\n            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when\n              `return_attention_mask=True` or if *\"attention_mask\"* is in `self.model_input_names` and if `text` is not\n              `None`).\n            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.\n            - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.\n            - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.\n            - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.\n            - **second_per_grid_ts** -- List of video seconds per time grid. 
Returned when `videos` is not `None`.\n        \"\"\"\n        output_kwargs = self._merge_kwargs(\n            Qwen2_5_VLProcessorKwargs,\n            tokenizer_init_kwargs=self.tokenizer.init_kwargs,\n            **kwargs,\n        )\n        if images is not None:\n            image_inputs = self.image_processor(\n                images=images, videos=None, **output_kwargs[\"images_kwargs\"]\n            )\n            image_grid_thw = image_inputs[\"image_grid_thw\"]\n        else:\n            image_inputs = {}\n            image_grid_thw = None\n\n        if videos is not None:\n            videos_inputs = self.image_processor(\n                images=None, videos=videos, **output_kwargs[\"images_kwargs\"]\n            )\n            video_grid_thw = videos_inputs[\"video_grid_thw\"]\n\n            fps = output_kwargs[\"videos_kwargs\"].pop(\"fps\", 2.0)\n            if isinstance(fps, (int, float)):\n                second_per_grid_ts = [\n                    self.image_processor.temporal_patch_size / fps\n                ] * len(video_grid_thw)\n            elif hasattr(fps, \"__len__\") and len(fps) == len(video_grid_thw):\n                second_per_grid_ts = [\n                    self.image_processor.temporal_patch_size / tmp for tmp in fps\n                ]\n            else:\n                raise ValueError(\n                    f\"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number.\"\n                )\n            videos_inputs.update({\"second_per_grid_ts\": second_per_grid_ts})\n\n        else:\n            videos_inputs = {}\n            video_grid_thw = None\n\n        if not isinstance(text, list):\n            text = [text]\n\n        if image_grid_thw is not None:\n            merge_length = self.image_processor.merge_size**2\n            index = 0\n            for i in range(len(text)):\n                while 
self.image_token in text[i]:\n                    text[i] = text[i].replace(\n                        self.image_token,\n                        \"<|placeholder|>\"\n                        * (image_grid_thw[index].prod() // merge_length),\n                        1,\n                    )\n                    index += 1\n                text[i] = text[i].replace(\"<|placeholder|>\", self.image_token)\n\n        if video_grid_thw is not None:\n            merge_length = self.image_processor.merge_size**2\n            index = 0\n            for i in range(len(text)):\n                while self.video_token in text[i]:\n                    text[i] = text[i].replace(\n                        self.video_token,\n                        \"<|placeholder|>\"\n                        * (video_grid_thw[index].prod() // merge_length),\n                        1,\n                    )\n                    index += 1\n                text[i] = text[i].replace(\"<|placeholder|>\", self.video_token)\n\n        text_inputs = self.tokenizer(text, **output_kwargs[\"text_kwargs\"])\n\n        return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})\n\n    def batch_decode(self, *args, **kwargs):\n        \"\"\"\n        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please\n        refer to the docstring of this method for more information.\n        \"\"\"\n        return self.tokenizer.batch_decode(*args, **kwargs)\n\n    def decode(self, *args, **kwargs):\n        \"\"\"\n        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. 
Please refer to\n        the docstring of this method for more information.\n        \"\"\"\n        return self.tokenizer.decode(*args, **kwargs)\n\n    def post_process_image_text_to_text(self, generated_outputs):\n        \"\"\"\n        Post-process the output of the model to decode the text.\n\n        Args:\n            generated_outputs (`torch.Tensor` or `np.ndarray`):\n                The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`\n                or `(sequence_length,)`.\n\n        Returns:\n            `List[str]`: The decoded text.\n        \"\"\"\n        return self.tokenizer.batch_decode(\n            generated_outputs,\n            skip_special_tokens=True,\n            clean_up_tokenization_spaces=False,\n        )\n\n    @property\n    def model_input_names(self):\n        tokenizer_input_names = self.tokenizer.model_input_names\n        image_processor_input_names = self.image_processor.model_input_names\n        names_from_processor = list(\n            dict.fromkeys(tokenizer_input_names + image_processor_input_names)\n        )\n        return names_from_processor + [\"second_per_grid_ts\"]\n\n\n# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py\nclass Qwen2_5_VLVisionConfig(PretrainedConfig):\n    model_type = \"qwen2_5_vl\"\n    base_config_key = \"vision_config\"\n\n    def __init__(\n        self,\n        depth=32,\n        hidden_size=3584,\n        hidden_act=\"silu\",\n        intermediate_size=3420,\n        num_heads=16,\n        in_channels=3,\n        patch_size=14,\n        spatial_merge_size=2,\n        spatial_patch_size=14,\n        temporal_patch_size=2,\n        tokens_per_second=4,\n        window_size=112,\n        out_hidden_size=3584,\n        fullatt_block_indexes=[7, 15, 23, 31],\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n\n        self.depth = 
depth\n        self.hidden_size = hidden_size\n        self.hidden_act = hidden_act\n        self.intermediate_size = intermediate_size\n        self.num_heads = num_heads\n        self.in_channels = in_channels\n        self.patch_size = patch_size\n        self.spatial_patch_size = spatial_patch_size\n        self.spatial_merge_size = spatial_merge_size\n        self.temporal_patch_size = temporal_patch_size\n        self.tokens_per_second = tokens_per_second\n        self.window_size = window_size\n        self.fullatt_block_indexes = fullatt_block_indexes\n        self.out_hidden_size = out_hidden_size\n\n\nclass Qwen2_5_VLConfig(PretrainedConfig):\n\n    def __init__(\n        self,\n        vocab_size=152064,\n        hidden_size=8192,\n        intermediate_size=29568,\n        num_hidden_layers=80,\n        num_attention_heads=64,\n        num_key_value_heads=8,\n        hidden_act=\"silu\",\n        max_position_embeddings=32768,\n        initializer_range=0.02,\n        rms_norm_eps=1e-05,\n        use_cache=True,\n        tie_word_embeddings=False,\n        rope_theta=1000000.0,\n        use_sliding_window=False,\n        sliding_window=4096,\n        max_window_layers=80,\n        attention_dropout=0.0,\n        vision_config=None,\n        rope_scaling=None,\n        **kwargs,\n    ):\n        if vision_config is not None:\n            self.vision_config = Qwen2_5_VLVisionConfig(**vision_config)\n\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.use_sliding_window = use_sliding_window\n        self.sliding_window = sliding_window\n        self.max_window_layers = max_window_layers\n\n        # for backward compatibility\n        if num_key_value_heads is None:\n            
num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.attention_dropout = attention_dropout\n        self.rope_scaling = rope_scaling\n\n        # Validate the correctness of rotary position embeddings parameters\n        # BC: if there is a 'type' field, move it to 'rope_type'.\n        # and change type from 'mrope' to 'default' because `mrope` does defeault RoPE calculations\n        # one can set it to \"linear\"/\"dynamic\" etc. to have scaled RoPE\n        # TODO: @raushan update config in the hub\n        if self.rope_scaling is not None and \"type\" in self.rope_scaling:\n            if self.rope_scaling[\"type\"] == \"mrope\":\n                self.rope_scaling[\"type\"] = \"default\"\n            self.rope_scaling[\"rope_type\"] = self.rope_scaling[\"type\"]\n\n        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)\n\n\nclass Qwen2_5VLAttention(nn.Module):\n    def __init__(self, *, prefix, config, weights):\n        super().__init__()\n        self.embed_dim = config.hidden_size // weights.process_group.size()\n        self.head_dim = config.hidden_size // config.num_heads\n        self.num_heads = config.num_heads // weights.process_group.size()\n\n        self.qkv = TensorParallelColumnLinear.load_qkv(\n            config,\n            prefix=f\"{prefix}.qkv\",\n            weights=weights,\n            bias=False,\n            num_heads=self.num_heads,\n            num_key_value_heads=self.num_heads,\n        )\n        self.qkv.linear.bias = weights.get_sharded(f\"{prefix}.qkv.bias\", dim=0)\n\n        self.proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.proj\",\n            weights=weights,\n            bias=True,\n        )\n 
       self.softmax_scale = 1.0 / np.sqrt(self.embed_dim // self.num_heads)\n\n    def forward(\n        self,\n        hidden_state: torch.Tensor,\n        cu_seqlens: torch.Tensor,\n        cos: torch.Tensor,\n        sin: torch.Tensor,\n        max_seqlen: int,\n    ) -> torch.Tensor:\n        # apply the qkv linear layer to the hidden state\n        qkv = self.qkv(hidden_state)\n        query, key, value = qkv.split(\n            [self.embed_dim, self.embed_dim, self.embed_dim], dim=1\n        )\n\n        # reshape the query, key, and value tensors\n        _shape = (\n            hidden_state.shape[0],\n            self.num_heads,\n            self.embed_dim // self.num_heads,\n        )\n        query = query.view(*_shape)\n        key = key.view(*_shape)\n        value = value.view(*_shape)\n        # apply rotary positional embeddings\n        rope_mode = RotaryPosEmbeddingMode.BLOCKWISE\n        rotary_dim = cos.shape[-1]\n        query_rot = query[..., :rotary_dim]\n        query_pass = query[..., rotary_dim:]\n        query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)\n        query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query.shape))\n\n        key_rot = key[..., :rotary_dim]\n        key_pass = key[..., rotary_dim:]\n        key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)\n        key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape))\n\n        # execute sdpa\n        causal = False\n        query = query.transpose(0, 1)\n        key = key.transpose(0, 1)\n        value = value.transpose(0, 1)\n        fsdpa_op = ModuleFusedSDPA(FusedSDPA)\n        attention_mask = torch.zeros(\n            [1, max_seqlen, max_seqlen], device=query.device, dtype=torch.bool\n        )\n        for i in range(1, len(cu_seqlens)):\n            attention_mask[\n                :, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]\n            ] = True\n        attn_output = 
fsdpa_op(\n            query,\n            key,\n            value,\n            attn_mask=attention_mask,\n            dropout_p=0.0,\n            is_causal=causal,\n            scale=None,\n            softmax_mode=\"None\",\n            recompute_mode=None,\n            valid_sequence_lengths=None,\n        )\n        attn_output = attn_output.transpose(0, 1)\n\n        # reshape output to original dimensions\n        attn_output = attn_output.reshape(hidden_state.shape[0], -1)\n        attn_output = self.proj(attn_output)\n        return attn_output\n\n\nclass Qwen2_5VLVisionMLP(nn.Module):\n    def __init__(self, *, prefix, config, weights):\n        super().__init__()\n        self.activation_fn = ACT2FN[config.hidden_act]\n\n        self.intermediate_size = (\n            config.intermediate_size // weights.process_group.size()\n        )\n\n        self.up = TensorParallelColumnLinear.load(\n            prefix=f\"{prefix}.up_proj\", weights=weights, config=config, bias=True\n        )\n        self.gate = TensorParallelColumnLinear.load(\n            prefix=f\"{prefix}.gate_proj\", weights=weights, config=config, bias=True\n        )\n        self.down = TensorParallelRowLinear.load(\n            prefix=f\"{prefix}.down_proj\", weights=weights, config=config, bias=True\n        )\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        gate_states = self.gate(hidden_states)\n        up_states = self.up(hidden_states)\n        activated_states = self.activation_fn(gate_states) * up_states\n        down_states = self.down(activated_states)\n        return down_states\n\n\nclass Qwen2_5VLVisionBlock(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.attn = Qwen2_5VLAttention(\n            prefix=f\"{prefix}.attn\",\n            config=config,\n            weights=weights,\n        )\n        self.norm1 = FastRMSNorm.load(\n            prefix=f\"{prefix}.norm1\",\n            
weights=weights,\n            eps=1e-6,\n        )\n        self.norm2 = FastRMSNorm.load(\n            prefix=f\"{prefix}.norm2\",\n            weights=weights,\n            eps=1e-6,\n        )\n        self.mlp = Qwen2_5VLVisionMLP(\n            prefix=f\"{prefix}.mlp\",\n            config=config,\n            weights=weights,\n        )\n\n    def forward(self, hidden_states, cu_seqlens, cos, sin, max_seqlen) -> torch.Tensor:\n        norm1_out, _ = self.norm1(hidden_states)\n        attn_out = self.attn(norm1_out, cu_seqlens, cos, sin, max_seqlen)\n        hidden_states = hidden_states + attn_out\n        norm2_out, _ = self.norm2(hidden_states)\n        mlp_out = self.mlp(norm2_out)\n        hidden_states = hidden_states + mlp_out\n        return hidden_states\n\n\nclass Qwen2_5VLPatchMerger(nn.Module):\n    def __init__(self, *, prefix, config, weights):\n        super().__init__()\n        self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)\n        self.patch_merger_ln_q = FastRMSNorm.load(\n            prefix=f\"{prefix}.ln_q\",\n            weights=weights,\n            eps=1e-6,\n        )\n        self.fc1 = TensorParallelColumnLinear.load(\n            prefix=f\"{prefix}.mlp.0\", weights=weights, config=config, bias=True\n        )\n        self.fc2 = TensorParallelRowLinear.load(\n            prefix=f\"{prefix}.mlp.2\", weights=weights, config=config, bias=True\n        )\n\n    def forward(self, hidden_states) -> torch.Tensor:\n        hidden_states, _ = self.patch_merger_ln_q(hidden_states)\n        hidden_states = hidden_states.view(-1, self.hidden_size)\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = F.gelu(hidden_states)\n        hidden_states = self.fc2(hidden_states)\n        return hidden_states\n\n\nclass Qwen2_5VisionModel(nn.Module):\n    def __init__(self, *, prefix, config, weights):\n        super().__init__()\n\n        self.spatial_merge_size = config.spatial_merge_size\n        
kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size]\n        self.patch_embedding = nn.Conv3d(\n            in_channels=config.in_channels,\n            out_channels=config.hidden_size,\n            kernel_size=kernel_size,\n            stride=kernel_size,\n            bias=False,\n        )\n        self.patch_embedding.weight = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.patch_embed.proj.weight\"), requires_grad=False\n        )\n        head_dim = config.hidden_size // config.num_heads\n\n        theta = 10000.0\n        dim = head_dim // 2\n        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        self.blocks = nn.ModuleList(\n            [\n                Qwen2_5VLVisionBlock(\n                    prefix=f\"{prefix}.blocks.{i}\",\n                    config=config,\n                    weights=weights,\n                )\n                for i in range(config.depth)\n            ]\n        )\n        self.merger = Qwen2_5VLPatchMerger(\n            prefix=f\"{prefix}.merger\",\n            config=config,\n            weights=weights,\n        )\n\n        self.temporal_patch_size = config.temporal_patch_size\n        self.spatial_patch_size = config.spatial_patch_size\n        self.in_channels = config.in_channels\n        self.embed_dim = config.hidden_size\n        self.window_size = config.window_size\n        self.patch_size = config.patch_size\n        self.spatial_merge_unit = config.spatial_merge_size * config.spatial_merge_size\n        self.fullatt_block_indexes = config.fullatt_block_indexes\n\n    def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:\n        batch_size, _, hidden_size = hidden_state.shape\n        class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)\n        hidden_state = torch.cat([class_embedding, hidden_state], dim=1)\n        
return hidden_state\n\n    def get_window_index(self, grid_thw):\n        window_index: list = []\n        cu_window_seqlens: list = [0]\n        window_index_id = 0\n        vit_merger_window_size = (\n            self.window_size // self.spatial_merge_size // self.patch_size\n        )\n\n        for grid_t, grid_h, grid_w in grid_thw:\n            llm_grid_h, llm_grid_w = (\n                grid_h // self.spatial_merge_size,\n                grid_w // self.spatial_merge_size,\n            )\n            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(\n                grid_t, llm_grid_h, llm_grid_w\n            )\n            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size\n            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size\n            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size\n            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size\n            index_padded = F.pad(index, (0, pad_w, 0, pad_h), \"constant\", -100)\n            index_padded = index_padded.reshape(\n                grid_t,\n                num_windows_h,\n                vit_merger_window_size,\n                num_windows_w,\n                vit_merger_window_size,\n            )\n            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(\n                grid_t,\n                num_windows_h * num_windows_w,\n                vit_merger_window_size,\n                vit_merger_window_size,\n            )\n            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)\n            index_padded = index_padded.reshape(-1)\n            index_new = index_padded[index_padded != -100]\n            window_index.append(index_new + window_index_id)\n            cu_seqlens_tmp = (\n                seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]\n            )\n            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())\n            window_index_id += (grid_t * 
llm_grid_h * llm_grid_w).item()\n        window_index = torch.cat(window_index, dim=0)\n\n        return window_index, cu_window_seqlens\n\n    def forward(\n        self,\n        pixel_values: torch.Tensor,\n        grid_thw: Optional[torch.LongTensor] = None,\n    ) -> torch.Tensor:\n\n        # reshape the input tensor for processing\n        shape = (\n            -1,\n            self.in_channels,\n            self.temporal_patch_size,\n            self.spatial_patch_size,\n            self.spatial_patch_size,\n        )\n        pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype)\n        hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim)\n        # TODO: revisit to see if we can avoid some of these reshapes\n\n        # find the position ids for the input tensor based on the grid_thw\n        pos_ids = []\n        for t, h, w in grid_thw:\n            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)\n            hpos_ids = hpos_ids.reshape(\n                h // self.spatial_merge_size,\n                self.spatial_merge_size,\n                w // self.spatial_merge_size,\n                self.spatial_merge_size,\n            )\n            hpos_ids = hpos_ids.permute(0, 2, 1, 3)\n            hpos_ids = hpos_ids.flatten()\n\n            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)\n            wpos_ids = wpos_ids.reshape(\n                h // self.spatial_merge_size,\n                self.spatial_merge_size,\n                w // self.spatial_merge_size,\n                self.spatial_merge_size,\n            )\n            wpos_ids = wpos_ids.permute(0, 2, 1, 3)\n            wpos_ids = wpos_ids.flatten()\n            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))\n\n        pos_ids = torch.cat(pos_ids, dim=0)\n\n        max_grid_size = grid_thw[:, 1:].max()\n\n        # apply the positional embeddings to the position ids\n        seq = torch.arange(\n            
max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype\n        )\n        rotary_pos_emb_full = torch.outer(seq, self.inv_freq)\n        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)\n        window_index, cu_window_seqlens = self.get_window_index(grid_thw)\n        seq_len = hidden_states.shape[0]\n        patch_shape = (seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)\n        og_shape = (seq_len, -1)\n\n        hidden_states = hidden_states.view(patch_shape)[window_index, :, :].view(\n            og_shape\n        )\n        rotary_pos_emb = rotary_pos_emb.view(patch_shape)[window_index, :, :].view(\n            og_shape\n        )\n\n        rotary_pos_emb = rotary_pos_emb.to(device=hidden_states.device)\n        cos = rotary_pos_emb.cos()\n        sin = rotary_pos_emb.sin()\n        cos = torch.cat((cos, cos), dim=-1).unsqueeze(1)\n        sin = torch.cat((sin, sin), dim=-1).unsqueeze(1)\n\n        cu_window_seqlens = torch.tensor(\n            cu_window_seqlens,\n            device=\"cpu\",\n            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,\n        )\n        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens).to(\n            hidden_states.device\n        )\n\n        # create a cu_seqlens tensor to be used in the attention mask\n        cu_seqlens = torch.repeat_interleave(\n            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]\n        ).cumsum(dim=0, dtype=torch.int32)\n        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)\n        max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])\n\n        # iterately apply the blocks to the hidden states\n        lazy_mode = htorch.utils.internal.is_lazy()\n        if lazy_mode:\n            htorch.core.mark_step()\n        for layer_num, block in enumerate(self.blocks):\n            # NOTE: qwen2_5_vl.py has a concept of full attention blocks\n            # that are applied at specific layers.\n            if layer_num in 
self.fullatt_block_indexes:\n                cu_seqlens_now = cu_seqlens\n            else:\n                cu_seqlens_now = cu_window_seqlens\n\n            hidden_states = block(hidden_states, cu_seqlens_now, cos, sin, max_seqlen)\n            if lazy_mode:\n                htorch.core.mark_step()\n\n        # apply the final patch merger to the hidden states\n        hidden_states = self.merger(hidden_states)\n        reverse_indices = torch.argsort(window_index)\n        hidden_states = hidden_states[reverse_indices, :]\n        return hidden_states\n\n\nclass Qwen2_5VLForConditionalGeneration(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        config.vision_config.quantize = None\n        config.vision_config.speculator = config.speculator\n        # set rope_scaling.type == \"mrope\" since AutoConfig.from_pretrained incorrectly\n        # returns rope_scaling.type == \"default\" for Qwen2_5-VL model at the moment\n        if (\n            hasattr(config, \"rope_scaling\")\n            and config.rope_scaling is not None\n            and config.rope_scaling.get(\"type\", None) == \"default\"\n        ):\n            config.rope_scaling.update({\"rope_type\": \"mrope\"})\n        self.hidden_size = config.hidden_size\n        self.vision_start_token_id = config.vision_start_token_id\n        self.vision_end_token_id = config.vision_end_token_id\n        self.image_token_id = config.image_token_id\n        self.video_token_id = config.video_token_id\n        self.spatial_merge_size = config.vision_config.spatial_merge_size\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=\"model.embed_tokens\", weights=weights\n        )\n        self.visual = Qwen2_5VisionModel(\n            prefix=\"visual\", config=config.vision_config, weights=weights\n        )\n        self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)\n        if 
config.tie_word_embeddings:\n            suffix = \"model.embed_tokens\"\n        else:\n            suffix = \"lm_head\"\n\n        self.lm_head = SpeculativeHead.load(\n            config,\n            prefix=suffix if not prefix else f\"{prefix}.{suffix}\",\n            weights=weights,\n        )\n        self.device = weights.device\n\n    # based on https://github.com/huggingface/transformers/blob/e284c7e954abe12c34b50461c17f8115a0afe115/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1391\n    # modified to first find segments then initialize position ids for each segment\n    # Steps:\n    #  locate all vision and text segments\n    #  calculate `vision_segment_lengths` for each vision segment to be use as offset\n    #  calculate `text_segment_lengths` for each text segment to be used as offset\n    #  create position ids for each vision segment based on the image grid\n    #  create position ids for each text segment\n    #  combine all the position ids\n    #  the final segment is the difference between the last vision segment and the end of the input\n    #  combine all the position ids and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3)\n    def get_position_ids(\n        self,\n        input_ids: torch.Tensor,\n        image_grid_thw: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        if image_grid_thw is None:\n            return (\n                torch.arange(input_ids.shape[0], device=input_ids.device)\n                .unsqueeze(1)\n                .repeat(1, 3)\n            )\n\n        spatial_merge_size = self.spatial_merge_size\n        vision_start_token_id = self.vision_start_token_id\n        vision_end_token_id = self.vision_end_token_id\n        device = input_ids.device\n        dtype = input_ids.dtype\n        input_ids_len = input_ids.shape[0]\n\n        vision_starts = torch.where(input_ids == vision_start_token_id)[0]\n        vision_ends = torch.where(input_ids == 
vision_end_token_id)[0]\n        vision_segments = torch.stack((vision_starts, vision_ends), dim=1)\n        prev_vision_end = torch.cat(\n            [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]]\n        )\n        text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1\n        vision_widths_max = torch.cat(\n            [\n                torch.zeros(1, device=image_grid_thw.device, dtype=dtype),\n                image_grid_thw[:-1, 2] // spatial_merge_size,\n            ]\n        )\n        vision_segment_lengths = vision_widths_max + text_lengths_between_vision\n        vision_segment_lengths = vision_segment_lengths.cumsum(dim=0)\n        text_segment_lengths = vision_segment_lengths - text_lengths_between_vision\n\n        # create position ids for each vision segment based on the image grid\n        llm_pos_ids_list = []\n        for i, _ in enumerate(vision_segments):\n            t, h, w = (\n                image_grid_thw[i][0],\n                image_grid_thw[i][1] // spatial_merge_size,\n                image_grid_thw[i][2] // spatial_merge_size,\n            )\n            t_indices = torch.arange(t, device=device).repeat_interleave(h * w)\n            h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t)\n            w_indices = torch.arange(w, device=device).repeat(t * h)\n            image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)\n\n            # offset by the position of the last vision segment\n            im = image_position_ids + vision_segment_lengths[i]\n            llm_pos_ids_list.append(im)\n\n        # create position ids for each text segment\n        text_ranges = [\n            torch.arange(seq_len, device=device).view(1, -1).expand(3, -1)\n            + text_segment_lengths[i]\n            for i, seq_len in enumerate(text_lengths_between_vision)\n        ]\n\n        full_llm_pos_ids_list = [\n            item for sublist in 
zip(text_ranges, llm_pos_ids_list) for item in sublist\n        ]\n        max_s = full_llm_pos_ids_list[-1].max() + 1\n        final_text_len = input_ids_len - vision_ends[-1]\n        if final_text_len > 0:\n            m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)\n            full_llm_pos_ids_list.append(m + max_s)\n\n        position_ids = (\n            torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)\n        )\n        return position_ids\n\n    def get_vision_embeds(\n        self,\n        pixel_values: torch.FloatTensor,\n        pixel_attention_mask: Optional[torch.FloatTensor] = None,\n        image_sizes: Optional[torch.Tensor] = None,\n        image_grid_thw: Optional[torch.LongTensor] = None,\n    ):\n        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)\n        return image_embeds\n\n    def get_inputs_embeds(\n        self,\n        input_ids: torch.Tensor,\n        vision_embeds: torch.Tensor = None,\n    ):\n        inputs_embeds = self.embed_tokens(input_ids)\n\n        # apply the visual model to the pixel values if they are provided\n        if vision_embeds is not None:\n            mask = torch.where(input_ids == self.image_token_id)\n            inputs_embeds[mask] = vision_embeds\n\n        return inputs_embeds\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n        lm_head_indices: Optional[torch.Tensor],\n        attention_mask: Optional[torch.BoolTensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n        image_indices=None,\n    ):\n\n        hidden_states = self.text_model(\n            inputs_embeds=inputs_embeds,\n            
position_ids=position_ids,\n            cu_seqlen_prefill=cu_seqlen_prefill,\n            kv_cache=kv_cache,\n            slots=slots,\n            seqlen=seqlen,\n            hpu_attention_meta=hpu_attention_meta,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.lm_head(hidden_states)\n        return logits, speculative_logits\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_vl.py",
    "content": "# coding=utf-8\n# Copyright 2024 the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch Qwen2 VL model.\"\"\"\n\nfrom typing import Optional, Tuple, List\n\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\n\n\nfrom habana_frameworks.torch.hpex.kernels import FusedSDPA\nfrom vllm_hpu_extension.utils import ModuleFusedSDPA\n\n\nimport numpy as np\n\nfrom transformers.activations import ACT2FN\nimport torch.nn.functional as F\n\nfrom text_generation_server.layers.layernorm import FastLayerNorm, FastRMSNorm\nfrom text_generation_server.layers import (\n    TensorParallelColumnLinear,\n    TensorParallelRowLinear,\n    TensorParallelEmbedding,\n    SpeculativeHead,\n)\nfrom text_generation_server.layers.attention import (\n    Seqlen,\n    HPUPagedAttentionMetadata,\n)\nfrom text_generation_server.models.custom_modeling.flash_qwen2_modeling import (\n    Qwen2Model,\n)\nfrom habana_frameworks.torch.hpex.kernels import (\n    RotaryPosEmbeddingMode,\n    apply_rotary_pos_emb,\n)\nimport habana_frameworks.torch as htorch\n\n\nclass Qwen2VLAttention(nn.Module):\n    def __init__(self, *, prefix, config, weights):\n        super().__init__()\n        self.embed_dim = config.embed_dim // weights.process_group.size()\n        self.head_dim = config.hidden_size // config.num_heads\n        self.num_heads = config.num_heads // weights.process_group.size()\n\n        self.qkv = 
TensorParallelColumnLinear.load_qkv(\n            config,\n            prefix=f\"{prefix}.qkv\",\n            weights=weights,\n            bias=False,\n            num_heads=self.num_heads,\n            num_key_value_heads=self.num_heads,\n        )\n        self.qkv.linear.bias = weights.get_sharded(f\"{prefix}.qkv.bias\", dim=0)\n        self.proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.proj\",\n            weights=weights,\n            bias=True,\n        )\n        self.softmax_scale = 1.0 / np.sqrt(self.embed_dim // self.num_heads)\n\n    def forward(\n        self,\n        hidden_state: torch.Tensor,\n        cu_seqlens: torch.Tensor,\n        cos: torch.Tensor,\n        sin: torch.Tensor,\n        max_seqlen: int,\n    ) -> torch.Tensor:\n        # apply the qkv linear layer to the hidden state\n        qkv = self.qkv(hidden_state)\n        query, key, value = qkv.split(\n            [self.embed_dim, self.embed_dim, self.embed_dim], dim=1\n        )\n\n        # reshape the query, key, and value tensors\n        _shape = (\n            hidden_state.shape[0],\n            self.num_heads,\n            self.embed_dim // self.num_heads,\n        )\n        query = query.view(*_shape)\n        key = key.view(*_shape)\n        value = value.view(*_shape)\n\n        # apply rotary positional embeddings\n        rope_mode = RotaryPosEmbeddingMode.BLOCKWISE\n        rotary_dim = cos.shape[-1]\n        query_rot = query[..., :rotary_dim]\n        query_pass = query[..., rotary_dim:]\n        query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)\n        query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query.shape))\n\n        key_rot = key[..., :rotary_dim]\n        key_pass = key[..., rotary_dim:]\n        key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)\n        key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key.shape))\n\n        # execute sdpa\n        
causal = False\n        query = query.transpose(0, 1)\n        key = key.transpose(0, 1)\n        value = value.transpose(0, 1)\n        fsdpa_op = ModuleFusedSDPA(FusedSDPA)\n        attention_mask = torch.zeros(\n            [1, max_seqlen, max_seqlen], device=query.device, dtype=torch.bool\n        )\n        for i in range(1, len(cu_seqlens)):\n            attention_mask[\n                :, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]\n            ] = True\n        attn_output = fsdpa_op(\n            query,\n            key,\n            value,\n            attn_mask=attention_mask,\n            dropout_p=0.0,\n            is_causal=causal,\n            scale=None,\n            softmax_mode=\"None\",\n            recompute_mode=None,\n            valid_sequence_lengths=None,\n        )\n        attn_output = attn_output.transpose(0, 1)\n        # reshape output to original dimensions\n        attn_output = attn_output.reshape(hidden_state.shape[0], -1)\n        attn_output = self.proj(attn_output)\n        return attn_output\n\n\nclass Qwen2VLVisionMLP(nn.Module):\n    def __init__(self, *, prefix, config, weights):\n        super().__init__()\n        self.activation_fn = ACT2FN[config.hidden_act]\n        self.fc1 = TensorParallelColumnLinear.load(\n            prefix=f\"{prefix}.fc1\", weights=weights, config=config, bias=True\n        )\n        self.fc2 = TensorParallelRowLinear.load(\n            prefix=f\"{prefix}.fc2\", weights=weights, config=config, bias=True\n        )\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = self.activation_fn(hidden_states)\n        hidden_states = self.fc2(hidden_states)\n        return hidden_states\n\n\nclass Qwen2VLVisionBlock(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.attn = Qwen2VLAttention(\n            prefix=f\"{prefix}.attn\",\n       
     config=config,\n            weights=weights,\n        )\n        self.norm1 = FastLayerNorm.load(\n            prefix=f\"{prefix}.norm1\",\n            weights=weights,\n            eps=1e-6,\n        )\n        self.norm2 = FastLayerNorm.load(\n            prefix=f\"{prefix}.norm2\",\n            weights=weights,\n            eps=1e-6,\n        )\n        self.mlp = Qwen2VLVisionMLP(\n            prefix=f\"{prefix}.mlp\",\n            config=config,\n            weights=weights,\n        )\n\n    def forward(self, hidden_states, cu_seqlens, cos, sin, max_seqlen) -> torch.Tensor:\n        norm1_out, residual = self.norm1(hidden_states)\n        attn_out = self.attn(norm1_out, cu_seqlens, cos, sin, max_seqlen)\n        hidden_states = attn_out + residual\n        norm2_out, residual = self.norm2(hidden_states)\n        hidden_states = hidden_states + self.mlp(norm2_out)\n        return hidden_states\n\n\nclass Qwen2VLPatchMerger(nn.Module):\n    def __init__(self, *, prefix, config, weights):\n        super().__init__()\n        self.hidden_size = config.embed_dim * (config.spatial_merge_size**2)\n        self.patch_merger_ln_q = FastLayerNorm.load(\n            prefix=f\"{prefix}.ln_q\",\n            weights=weights,\n            eps=1e-6,\n        )\n        self.fc1 = TensorParallelColumnLinear.load(\n            prefix=f\"{prefix}.mlp.0\", weights=weights, config=config, bias=True\n        )\n        self.fc2 = TensorParallelRowLinear.load(\n            prefix=f\"{prefix}.mlp.2\", weights=weights, config=config, bias=True\n        )\n\n    def forward(self, hidden_states) -> torch.Tensor:\n        hidden_states, _ = self.patch_merger_ln_q(hidden_states)\n        hidden_states = hidden_states.view(-1, self.hidden_size)\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = F.gelu(hidden_states)\n        hidden_states = self.fc2(hidden_states)\n        return hidden_states\n\n\nclass Qwen2VisionModel(nn.Module):\n    def __init__(self, *, 
prefix, config, weights):\n        super().__init__()\n        self.spatial_merge_size = config.spatial_merge_size\n        kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size]\n        self.patch_embedding = nn.Conv3d(\n            in_channels=config.in_chans,\n            out_channels=config.embed_dim,\n            kernel_size=kernel_size,\n            stride=kernel_size,\n            bias=False,\n        )\n        self.patch_embedding.weight = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.patch_embed.proj.weight\"), requires_grad=False\n        )\n        head_dim = config.embed_dim // config.num_heads\n        # TODO: replace with static positional embeddings once implemented\n        theta = 10000.0\n        dim = head_dim // 2\n        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        self.blocks = nn.ModuleList(\n            [\n                Qwen2VLVisionBlock(\n                    prefix=f\"{prefix}.blocks.{i}\",\n                    config=config,\n                    weights=weights,\n                )\n                for i in range(config.depth)\n            ]\n        )\n        self.merger = Qwen2VLPatchMerger(\n            prefix=f\"{prefix}.merger\",\n            config=config,\n            weights=weights,\n        )\n\n        self.temporal_patch_size = config.temporal_patch_size\n        self.spatial_patch_size = config.spatial_patch_size\n        self.in_channels = config.in_channels\n        self.embed_dim = config.embed_dim\n\n    def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:\n        batch_size, _, hidden_size = hidden_state.shape\n        class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)\n        hidden_state = torch.cat([class_embedding, hidden_state], dim=1)\n        return hidden_state\n\n    def forward(\n        self,\n        
pixel_values: torch.Tensor,\n        grid_thw: Optional[torch.LongTensor] = None,\n    ) -> torch.Tensor:\n        # reshape the input tensor for processing\n        shape = (\n            -1,\n            self.in_channels,\n            self.temporal_patch_size,\n            self.spatial_patch_size,\n            self.spatial_patch_size,\n        )\n        pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype)\n        hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim)\n        # TODO: revisit to see if we can avoid some of these reshapes\n\n        # find the position ids for the input tensor based on the grid_thw\n        pos_ids = []\n        for t, h, w in grid_thw:\n            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)\n            hpos_ids = hpos_ids.reshape(\n                h // self.spatial_merge_size,\n                self.spatial_merge_size,\n                w // self.spatial_merge_size,\n                self.spatial_merge_size,\n            )\n            hpos_ids = hpos_ids.permute(0, 2, 1, 3)\n            hpos_ids = hpos_ids.flatten()\n\n            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)\n            wpos_ids = wpos_ids.reshape(\n                h // self.spatial_merge_size,\n                self.spatial_merge_size,\n                w // self.spatial_merge_size,\n                self.spatial_merge_size,\n            )\n            wpos_ids = wpos_ids.permute(0, 2, 1, 3)\n            wpos_ids = wpos_ids.flatten()\n            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))\n\n        pos_ids = torch.cat(pos_ids, dim=0)\n        max_grid_size = grid_thw[:, 1:].max()\n\n        # apply the positional embeddings to the position ids\n        seq = torch.arange(\n            max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype\n        )\n        rotary_pos_emb_full = torch.outer(seq, self.inv_freq)\n        rotary_pos_emb = 
rotary_pos_emb_full[pos_ids].flatten(1)\n        rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, hidden_states.dtype)\n\n        cos = rotary_pos_emb.cos()\n        sin = rotary_pos_emb.sin()\n        cos = torch.cat((cos, cos), dim=-1).unsqueeze(1)\n        sin = torch.cat((sin, sin), dim=-1).unsqueeze(1)\n\n        # create a cu_seqlens tensor to be used in the attention mask\n        cu_seqlens = torch.repeat_interleave(\n            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]\n        ).cumsum(dim=0, dtype=torch.int32)\n        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)\n        max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])\n        # iterately apply the blocks to the hidden states\n        lazy_mode = htorch.utils.internal.is_lazy()\n        if lazy_mode:\n            htorch.core.mark_step()\n        for block in self.blocks:\n            hidden_states = block(hidden_states, cu_seqlens, cos, sin, max_seqlen)\n            if lazy_mode:\n                htorch.core.mark_step()\n\n        # apply the final patch merger to the hidden states\n        hidden_states = self.merger(hidden_states)\n        return hidden_states\n\n\nclass Qwen2VLForConditionalGeneration(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        config.vision_config.quantize = None\n        config.vision_config.speculator = config.speculator\n        # set rope_scaling.type == \"mrope\" since AutoConfig.from_pretrained incorrectly\n        # returns rope_scaling.type == \"default\" for Qwen2-VL model at the moment\n        if (\n            hasattr(config, \"rope_scaling\")\n            and config.rope_scaling is not None\n            and config.rope_scaling.get(\"type\", None) == \"default\"\n        ):\n            config.rope_scaling.update({\"rope_type\": \"mrope\"})\n        self.hidden_size = config.hidden_size\n        self.vision_start_token_id = config.vision_start_token_id\n        
self.vision_end_token_id = config.vision_end_token_id\n        self.image_token_id = config.image_token_id\n        self.video_token_id = config.video_token_id\n        self.spatial_merge_size = config.vision_config.spatial_merge_size\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=\"model.embed_tokens\", weights=weights\n        )\n        self.visual = Qwen2VisionModel(\n            prefix=\"visual\", config=config.vision_config, weights=weights\n        )\n        self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)\n        if config.tie_word_embeddings:\n            suffix = \"model.embed_tokens\"\n        else:\n            suffix = \"lm_head\"\n\n        self.lm_head = SpeculativeHead.load(\n            config,\n            prefix=suffix if not prefix else f\"{prefix}.{suffix}\",\n            weights=weights,\n        )\n        self.norm = FastRMSNorm.load(\n            prefix=\"model.norm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n        self.device = weights.device\n\n    # based on https://github.com/huggingface/transformers/blob/e284c7e954abe12c34b50461c17f8115a0afe115/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1391\n    # modified to first find segments then initialize position ids for each segment\n    # Steps:\n    #  locate all vision and text segments\n    #  calculate `vision_segment_lengths` for each vision segment to be use as offset\n    #  calculate `text_segment_lengths` for each text segment to be used as offset\n    #  create position ids for each vision segment based on the image grid\n    #  create position ids for each text segment\n    #  combine all the position ids\n    #  the final segment is the difference between the last vision segment and the end of the input\n    #  combine all the position ids and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3)\n    def get_position_ids(\n        self,\n        input_ids: 
torch.Tensor,\n        image_grid_thw: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        if image_grid_thw is None:\n            return (\n                torch.arange(input_ids.shape[0], device=input_ids.device)\n                .unsqueeze(1)\n                .repeat(1, 3)\n            )\n\n        spatial_merge_size = self.spatial_merge_size\n        vision_start_token_id = self.vision_start_token_id\n        vision_end_token_id = self.vision_end_token_id\n        device = input_ids.device\n        dtype = input_ids.dtype\n        input_ids_len = input_ids.shape[0]\n\n        vision_starts = torch.where(input_ids == vision_start_token_id)[0]\n        vision_ends = torch.where(input_ids == vision_end_token_id)[0]\n        vision_segments = torch.stack((vision_starts, vision_ends), dim=1)\n        prev_vision_end = torch.cat(\n            [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]]\n        )\n        text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1\n        vision_widths_max = torch.cat(\n            [\n                torch.zeros(1, device=image_grid_thw.device, dtype=dtype),\n                image_grid_thw[:-1, 2] // spatial_merge_size,\n            ]\n        )\n        vision_segment_lengths = vision_widths_max + text_lengths_between_vision\n        vision_segment_lengths = vision_segment_lengths.cumsum(dim=0)\n        text_segment_lengths = vision_segment_lengths - text_lengths_between_vision\n\n        # create position ids for each vision segment based on the image grid\n        llm_pos_ids_list = []\n        for i, _ in enumerate(vision_segments):\n            t, h, w = (\n                image_grid_thw[i][0],\n                image_grid_thw[i][1] // spatial_merge_size,\n                image_grid_thw[i][2] // spatial_merge_size,\n            )\n            t_indices = torch.arange(t, device=device).repeat_interleave(h * w)\n            h_indices = torch.arange(h, 
device=device).repeat_interleave(w).repeat(t)\n            w_indices = torch.arange(w, device=device).repeat(t * h)\n            image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)\n\n            # offset by the position of the last vision segment\n            im = image_position_ids + vision_segment_lengths[i]\n            llm_pos_ids_list.append(im)\n\n        # create position ids for each text segment\n        text_ranges = [\n            torch.arange(seq_len, device=device).view(1, -1).expand(3, -1)\n            + text_segment_lengths[i]\n            for i, seq_len in enumerate(text_lengths_between_vision)\n        ]\n\n        full_llm_pos_ids_list = [\n            item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist\n        ]\n        max_s = full_llm_pos_ids_list[-1].max() + 1\n        final_text_len = input_ids_len - vision_ends[-1]\n        if final_text_len > 0:\n            m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)\n            full_llm_pos_ids_list.append(m + max_s)\n\n        position_ids = (\n            torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)\n        )\n        return position_ids\n\n    def get_vision_embeds(\n        self,\n        pixel_values: torch.FloatTensor,\n        pixel_attention_mask: Optional[torch.FloatTensor] = None,\n        image_sizes: Optional[torch.Tensor] = None,\n        image_grid_thw: Optional[torch.LongTensor] = None,\n    ):\n        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)\n        return image_embeds\n\n    def get_inputs_embeds(\n        self,\n        input_ids: torch.Tensor,\n        vision_embeds: torch.Tensor = None,\n    ):\n        inputs_embeds = self.embed_tokens(input_ids)\n\n        # apply the visual model to the pixel values if they are provided\n        if vision_embeds is not None:\n            mask = torch.where(input_ids == self.image_token_id)\n            
inputs_embeds[mask] = vision_embeds\n\n        return inputs_embeds\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],\n        lm_head_indices: Optional[torch.Tensor],\n        attention_mask: Optional[torch.BoolTensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n        image_indices=None,\n    ):\n        hidden_states = self.text_model(\n            inputs_embeds=inputs_embeds,\n            position_ids=position_ids,\n            cu_seqlen_prefill=cu_seqlen_prefill,\n            kv_cache=kv_cache,\n            slots=slots,\n            seqlen=seqlen,\n            hpu_attention_meta=hpu_attention_meta,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.lm_head(hidden_states)\n        return logits, speculative_logits\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/siglip.py",
    "content": "from typing import Optional, Tuple\nimport warnings\nimport math\nimport torch\nfrom torch import nn\n\nfrom transformers.activations import ACT2FN\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPooling,\n)\nfrom transformers import SiglipConfig, SiglipVisionConfig\nfrom torch.nn.init import _calculate_fan_in_and_fan_out\n\nfrom text_generation_server.layers.tensor_parallel import (\n    TensorParallelEmbedding,\n    TensorParallelColumnLinear,\n    TensorParallelRowLinear,\n)\n\n\nclass SiglipVisionEmbeddings(nn.Module):\n    def __init__(self, prefix, config: SiglipVisionConfig, weights):\n        super().__init__()\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.image_size = config.image_size\n        self.patch_size = config.patch_size\n        self.patch_embedding = nn.Conv2d(\n            in_channels=config.num_channels,\n            out_channels=self.embed_dim,\n            kernel_size=self.patch_size,\n            stride=self.patch_size,\n            padding=\"valid\",\n        )\n        self.patch_embedding.weight = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.patch_embedding.weight\"), requires_grad=False\n        )\n        self.patch_embedding.bias = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.patch_embedding.bias\"), requires_grad=False\n        )\n        self.num_patches = (self.image_size // self.patch_size) ** 2\n        self.num_positions = self.num_patches\n        self.position_embedding = TensorParallelEmbedding(\n            prefix=f\"{prefix}.position_embedding\", weights=weights\n        )\n        self.register_buffer(\n            \"position_ids\",\n            torch.arange(self.num_positions, device=weights.device).expand((1, -1)),\n            persistent=False,\n        )\n\n    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:\n        patch_embeds = self.patch_embedding(\n            pixel_values\n        )  # 
shape = [*, width, grid, grid]\n        embeddings = patch_embeds.flatten(2).transpose(1, 2)\n\n        embeddings = embeddings + self.position_embedding(self.position_ids)\n        return embeddings\n\n\nclass SiglipAttention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.embed_dim // self.num_heads\n        self.head_size = self.head_dim\n        if self.head_dim * self.num_heads != self.embed_dim:\n            raise ValueError(\n                f\"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:\"\n                f\" {self.num_heads}).\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.embed_dim = self.embed_dim // weights.process_group.size()\n        self.scale = self.head_dim**-0.5\n        self.dropout = config.attention_dropout\n\n        self.k_proj = TensorParallelColumnLinear.load(\n            config, prefix=f\"{prefix}.k_proj\", weights=weights, bias=True\n        )\n        self.v_proj = TensorParallelColumnLinear.load(\n            config, prefix=f\"{prefix}.v_proj\", weights=weights, bias=True\n        )\n        self.q_proj = TensorParallelColumnLinear.load(\n            config, prefix=f\"{prefix}.q_proj\", weights=weights, bias=True\n        )\n        self.out_proj = TensorParallelRowLinear.load(\n            config, prefix=f\"{prefix}.out_proj\", weights=weights, bias=True\n        )\n\n    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n        return (\n            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)\n            .transpose(1, 2)\n            .contiguous()\n        )\n\n    def forward(\n        self,\n        hidden_states: 
torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        \"\"\"Input shape: Batch x Time x Channel\"\"\"\n\n        bsz, tgt_len, _ = hidden_states.size()\n        query_states = self.q_proj(hidden_states)\n        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)\n        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)\n        proj_shape = (bsz * self.num_heads, -1, self.head_dim)\n        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)\n        key_states = key_states.view(*proj_shape)\n        value_states = value_states.view(*proj_shape)\n\n        src_len = key_states.size(1)\n        # scale post matmul\n        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) * self.scale\n\n        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is\"\n                f\" {attn_weights.size()}\"\n            )\n\n        if attention_mask is not None:\n            if attention_mask.size() != (bsz, 1, tgt_len, src_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}\"\n                )\n            attn_weights = (\n                attn_weights.view(bsz, self.num_heads, tgt_len, src_len)\n                + attention_mask\n            )\n            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(\n            attn_weights, dim=-1, dtype=torch.float32\n        ).to(attn_weights.dtype)\n        attn_weights = nn.functional.dropout(\n            attn_weights, p=self.dropout, training=self.training\n        )\n        attn_output = 
torch.matmul(attn_weights, value_states)\n\n        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_size):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_size)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_size)\n        attn_output = attn_output.transpose(1, 2)\n        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)\n\n        attn_output = self.out_proj(attn_output)\n\n        return attn_output, attn_weights\n\n\nclass SiglipMLP(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        self.activation_fn = ACT2FN[config.hidden_act]\n        self.fc1 = TensorParallelColumnLinear.load(  # config.hidden_size, config.intermediate_size\n            prefix=f\"{prefix}.fc1\", config=config, weights=weights, bias=True\n        )\n        self.fc2 = TensorParallelRowLinear.load(  # config.intermediate_size, config.hidden_size\n            prefix=f\"{prefix}.fc2\", config=config, weights=weights, bias=True\n        )\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = self.activation_fn(hidden_states)\n        hidden_states = self.fc2(hidden_states)\n        return hidden_states\n\n\nclass SiglipEncoderLayer(nn.Module):\n    def __init__(self, prefix, config: SiglipConfig, weights):\n        super().__init__()\n        self.embed_dim = config.hidden_size\n        self.self_attn = SiglipAttention(\n            prefix=f\"{prefix}.self_attn\", config=config, weights=weights\n        )\n        self.layer_norm1 = nn.LayerNorm.load(\n            prefix=f\"{prefix}.layer_norm1\", weights=weights, eps=config.layer_norm_eps\n        )\n        self.mlp = SiglipMLP(prefix=f\"{prefix}.mlp\", 
config=config, weights=weights)\n        self.layer_norm2 = nn.LayerNorm.load(\n            prefix=f\"{prefix}.layer_norm2\", weights=weights, eps=config.layer_norm_eps\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: torch.Tensor,\n    ) -> Tuple[torch.FloatTensor]:\n        residual = hidden_states\n        hidden_states = self.layer_norm1(hidden_states)\n        hidden_states, attn_weights = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n        )\n        hidden_states = residual + hidden_states\n        residual = hidden_states\n        hidden_states = self.layer_norm2(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n        return hidden_states, None\n\n\nclass SiglipMultiheadAttentionPoolingHead(nn.Module):\n    \"\"\"Multihead Attention Pooling.\"\"\"\n\n    def __init__(self, prefix, config: SiglipVisionConfig, weights):\n        super().__init__()\n\n        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))\n        self.attention = torch.nn.MultiheadAttention(\n            config.hidden_size, config.num_attention_heads, batch_first=True\n        )\n        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n        self.mlp = SiglipMLP(prefix, config, weights)\n\n    def forward(self, hidden_state):\n        batch_size = hidden_state.shape[0]\n        probe = self.probe.repeat(batch_size, 1, 1)\n\n        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]\n\n        residual = hidden_state\n        hidden_state = self.layernorm(hidden_state)\n        hidden_state = residual + self.mlp(hidden_state)\n\n        return hidden_state[:, 0]\n\n\ndef _trunc_normal_(tensor, mean, std, a, b):\n    # Cut & paste from PyTorch official master until it's in a few official releases - RW\n    # Method based on 
https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf\n    def norm_cdf(x):\n        # Computes standard normal cumulative distribution function\n        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0\n\n    if (mean < a - 2 * std) or (mean > b + 2 * std):\n        warnings.warn(\n            \"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. \"\n            \"The distribution of values may be incorrect.\",\n            stacklevel=2,\n        )\n\n    # Values are generated by using a truncated uniform distribution and\n    # then using the inverse CDF for the normal distribution.\n    # Get upper and lower cdf values\n    lower = norm_cdf((a - mean) / std)\n    upper = norm_cdf((b - mean) / std)\n\n    # Uniformly fill tensor with values from [l, u], then translate to\n    # [2l-1, 2u-1].\n    tensor.uniform_(2 * lower - 1, 2 * upper - 1)\n\n    # Use inverse cdf transform for normal distribution to get truncated\n    # standard normal\n    tensor.erfinv_()\n\n    # Transform to proper mean, std\n    tensor.mul_(std * math.sqrt(2.0))\n    tensor.add_(mean)\n\n    # Clamp to ensure it's in the proper range\n    tensor.clamp_(min=a, max=b)\n\n\ndef trunc_normal_tf_(\n    tensor: torch.Tensor,\n    mean: float = 0.0,\n    std: float = 1.0,\n    a: float = -2.0,\n    b: float = 2.0,\n) -> torch.Tensor:\n    \"\"\"Fills the input Tensor with values drawn from a truncated\n    normal distribution. The values are effectively drawn from the\n    normal distribution :math:`\\\\mathcal{N}(\\text{mean}, \\text{std}^2)`\n    with values outside :math:`[a, b]` redrawn until they are within\n    the bounds. 
The method used for generating the random values works\n    best when :math:`a \\\\leq \\text{mean} \\\\leq b`.\n\n    NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the\n    bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0\n    and the result is subsquently scaled and shifted by the mean and std args.\n\n    Args:\n        tensor: an n-dimensional `torch.Tensor`\n        mean: the mean of the normal distribution\n        std: the standard deviation of the normal distribution\n        a: the minimum cutoff value\n        b: the maximum cutoff value\n    \"\"\"\n    with torch.no_grad():\n        _trunc_normal_(tensor, 0, 1.0, a, b)\n        tensor.mul_(std).add_(mean)\n\n\ndef variance_scaling_(tensor, scale=1.0, mode=\"fan_in\", distribution=\"normal\"):\n    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)\n    if mode == \"fan_in\":\n        denom = fan_in\n    elif mode == \"fan_out\":\n        denom = fan_out\n    elif mode == \"fan_avg\":\n        denom = (fan_in + fan_out) / 2\n\n    variance = scale / denom\n\n    if distribution == \"truncated_normal\":\n        # constant is stddev of standard normal truncated to (-2, 2)\n        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)\n    elif distribution == \"normal\":\n        with torch.no_grad():\n            tensor.normal_(std=math.sqrt(variance))\n    elif distribution == \"uniform\":\n        bound = math.sqrt(3 * variance)\n        with torch.no_grad():\n            tensor.uniform_(-bound, bound)\n    else:\n        raise ValueError(f\"invalid distribution {distribution}\")\n\n\ndef lecun_normal_(tensor):\n    variance_scaling_(tensor, mode=\"fan_in\", distribution=\"truncated_normal\")\n\n\ndef default_flax_embed_init(tensor):\n    variance_scaling_(tensor, mode=\"fan_in\", distribution=\"normal\")\n\n\nclass SiglipEncoder(nn.Module):\n    \"\"\"\n    Transformer encoder consisting of `config.num_hidden_layers` 
self attention layers. Each layer is a\n    [`SiglipEncoderLayer`].\n\n    Args:\n        config: SiglipConfig\n    \"\"\"\n\n    def __init__(self, prefix, config: SiglipConfig, weights):\n        super().__init__()\n        self.config = config\n        self.layers = nn.ModuleList(\n            [\n                SiglipEncoderLayer(\n                    prefix=f\"{prefix}.layers.{i}\", config=config, weights=weights\n                )\n                for i in range(config.num_hidden_layers)\n            ]\n        )\n\n    def forward(\n        self,\n        inputs_embeds,\n        attention_mask: Optional[torch.Tensor] = None,\n    ):\n        hidden_states = inputs_embeds\n        for idx, encoder_layer in enumerate(self.layers):\n            hidden_states, _ = encoder_layer(\n                hidden_states,\n                attention_mask,\n            )\n\n        return hidden_states\n\n\nclass SiglipVisionTransformer(nn.Module):\n    def __init__(self, prefix, config: SiglipVisionConfig, weights):\n        super().__init__()\n        self.config = config\n\n        self.embeddings = SiglipVisionEmbeddings(\n            prefix=f\"{prefix}.embeddings\", config=config, weights=weights\n        )\n        self.encoder = SiglipEncoder(\n            prefix=f\"{prefix}.encoder\", config=config, weights=weights\n        )\n\n    def forward(\n        self,\n        pixel_values: Optional[torch.FloatTensor] = None,\n    ):\n        if pixel_values is None:\n            raise ValueError(\"You have to specify pixel_values\")\n\n        hidden_states = self.embeddings(pixel_values)\n\n        # NOTE: up until this point, the code logits are exactly\n        # the same as the transformers code. 
The values evaluate\n        # slightly differently in our encoder layer.\n        encoder_outputs = self.encoder(\n            inputs_embeds=hidden_states,\n        )\n        last_hidden_state = encoder_outputs\n\n        return BaseModelOutputWithPooling(\n            last_hidden_state=last_hidden_state,\n            # pooler_output=pooled_output,\n            # hidden_states=encoder_outputs,\n        )\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/custom_modeling/vlm.py",
    "content": "def load_text_model(prefix, config, weights, name=None):\n    if config.model_type == \"llama\":\n        from text_generation_server.models.custom_modeling.flash_llama_modeling import (\n            FlashLlamaForCausalLM,\n        )\n\n        return FlashLlamaForCausalLM(prefix, config, weights, name=name)\n    elif config.model_type == \"mistral\":\n        from text_generation_server.models.custom_modeling.flash_mistral_modeling import (\n            FlashMistralForCausalLM,\n        )\n\n        return FlashMistralForCausalLM(prefix, config, weights, name=name)\n    elif config.model_type == \"gemma\":\n        from text_generation_server.models.custom_modeling.flash_gemma_modeling import (\n            FlashGemmaForCausalLM,\n        )\n\n        return FlashGemmaForCausalLM(prefix, config, weights)\n    elif config.model_type == \"gemma2\":\n        from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (\n            FlashGemma2ForCausalLM,\n        )\n\n        return FlashGemma2ForCausalLM(prefix, config, weights)\n    elif config.model_type == \"gemma3\" or config.model_type == \"gemma3_text\":\n        from text_generation_server.models.custom_modeling.flash_gemma3_modeling import (\n            FlashGemma3ForCausalLM,\n        )\n\n        return FlashGemma3ForCausalLM(prefix, config, weights)\n    elif config.model_type == \"paligemma\":\n        from text_generation_server.models.custom_modeling.flash_gemma_modeling import (\n            FlashGemmaForCausalLM,\n        )\n\n        return FlashGemmaForCausalLM(prefix, config, weights)\n    else:\n        raise RuntimeError(f\"Unsupported model type {config.model_type}\")\n\n\ndef load_vision_model(prefix, config, weights):\n    if config.model_type == \"clip_vision_model\":\n        from text_generation_server.models.custom_modeling.clip import (\n            CLIPVisionTransformer,\n        )\n\n        return CLIPVisionTransformer(\n            
prefix=f\"{prefix}.vision_model\", config=config, weights=weights\n        )\n    if (\n        config.model_type == \"siglip_vision_model\"\n        or config.model_type == \"gemma3_vision\"\n    ):\n        from text_generation_server.models.custom_modeling.siglip import (\n            SiglipVisionTransformer,\n        )\n\n        # TODO: ensure that using the prefix doesn't break any existing models\n        # that rely on the old prefix (update the old models if necessary)\n        return SiglipVisionTransformer(\n            prefix=f\"{prefix}.vision_model\",\n            config=config,\n            weights=weights,\n        )\n    else:\n        raise RuntimeError(f\"Unsupported model type {config.model_type}\")\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/flash_causal_lm.py",
    "content": "import math\nimport os\nimport time\nimport torch\nimport torch.distributed\n\nimport numpy as np\n\nfrom loguru import logger\nfrom dataclasses import dataclass\nfrom opentelemetry import trace\nfrom transformers import (\n    PreTrainedTokenizerBase,\n    AutoConfig,\n    AutoTokenizer,\n    GenerationConfig,\n)\nfrom typing import (\n    Any,\n    Iterable,\n    Optional,\n    Tuple,\n    List,\n    Type,\n    Dict,\n    Union,\n)\nimport torch.nn.functional as F\nfrom text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata\nfrom text_generation_server.utils.chunks import concat_text_chunks\nfrom text_generation_server.models import Model\nfrom text_generation_server.utils.log import log_master\nfrom text_generation_server.utils.tokens import batch_top_tokens\nfrom text_generation_server.utils.speculate import get_speculate\nfrom text_generation_server.utils import (\n    initialize_torch_distributed,\n    weight_files,\n    Weights,\n    pad_next_token_chooser_parameters,\n)\nfrom text_generation_server.models.types import (\n    Batch,\n    Tokens,\n    Generation,\n    GeneratedText,\n)\nfrom text_generation_server.pb import generate_pb2\nfrom text_generation_server.models.globals import (\n    BLOCK_SIZE,\n    REQUEST_LOGPROBS,\n    TGI_WIGGLE_ROOM,\n    get_adapter_to_index,\n)\nfrom text_generation_server.layers.attention import (\n    KVCache,\n    KVCompressCache,\n    Seqlen,\n    HPUPagedAttentionMetadata,\n    trim_attn_metadata,\n    trim_seqlen_metadata,\n    _async_h2d_tensor_copy,\n)\nfrom text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser\nfrom text_generation_server.utils.dist import MEMORY_FRACTION\nfrom text_generation_server.utils.quantization import get_loader\nfrom text_generation_server.utils.segments import SegmentConcatBuilder, find_segments\nfrom text_generation_server.utils.import_utils import (\n    empty_cache,\n    synchronize,\n    get_free_memory,\n)\nfrom 
text_generation_server.utils.prefill_chunking import (\n    get_max_prefill_tokens,\n)\nimport vllm_hpu_extension.environment as environment\nimport habana_frameworks.torch as htorch\nimport itertools\nfrom vllm_hpu_extension.bucketing.common import get_bucketing_context\nfrom vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes\n\ntracer = trace.get_tracer(__name__)\n\n\ndef generate_block_metadata(\n    dtype,\n    use_contiguous_pa,\n    slots,\n    block_tables,\n    bucketing_ctx,\n    slots_in_window=None,\n    block_bucket_size=None,\n):\n    # Prepare values if we need to continue decoding\n    # need for HPUPagedAttentionMetadata preparation\n    def flatten(in_list):\n        return list(itertools.chain(*in_list))\n\n    def gather_list(input, indices, v):\n        return [input[i] if i is not None else v for i in indices]\n\n    def pad_list(input, k, v):\n        input_len = len(input)\n        target_len = (input_len + k - 1) // k * k\n        padding = target_len - input_len\n        return input + [v] * padding\n\n    last_block_usage = [slot % BLOCK_SIZE + 1 for slot in slots]\n    block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)]\n    block_usage = [\n        [BLOCK_SIZE] * (len(bt) - 1) + [lbu]\n        for bt, lbu in zip(block_tables, last_block_usage)\n        if bt\n    ]\n\n    block_list = flatten(block_tables)\n    block_groups = flatten(block_groups)\n    block_usage = flatten(block_usage)\n    assert len(block_list) == len(block_groups)\n    assert len(block_list) == len(block_usage)\n    if use_contiguous_pa:\n        if block_bucket_size is None:\n            block_bucket_size = max(max(block_list) + 1, len(block_list))\n            if bucketing_ctx is not None:\n                block_bucket_size = bucketing_ctx.get_padded_decode_num_blocks(\n                    block_bucket_size\n                )\n        indices: List[Any]\n        indices = [None] * block_bucket_size\n        for i, bid in 
enumerate(block_list):\n            indices[bid] = i\n        block_list = gather_list(block_list, indices, 0)\n        block_groups = gather_list(block_groups, indices, -1)\n        block_usage = gather_list(block_usage, indices, 1)\n    else:\n        if block_bucket_size is None:\n            block_bucket_size = len(block_list)\n            if bucketing_ctx is not None:\n                block_bucket_size = bucketing_ctx.get_padded_decode_num_blocks(\n                    block_bucket_size\n                )\n        block_list = pad_list(block_list, block_bucket_size, 0)\n        block_groups = pad_list(block_groups, block_bucket_size, -1)\n        block_usage = pad_list(block_usage, block_bucket_size, 1)\n    slots_in_window_mask = None\n    if slots_in_window is not None:\n        slot_list = [\n            block_id * BLOCK_SIZE + slot_idx\n            for block_id in block_list\n            for slot_idx in range(BLOCK_SIZE)\n        ]\n        slot_list = torch.tensor(slot_list, dtype=torch.int64)\n        slot_list = slot_list.view(-1, BLOCK_SIZE)\n        slots_in_window_mask = torch.isin(slot_list, slots_in_window)\n        for i in range(slots_in_window_mask.shape[0]):\n            if not slots_in_window_mask[i].any():\n                slots_in_window_mask[i, 0] = True\n\n    block_list = torch.tensor(block_list, dtype=torch.int, device=\"cpu\")\n    block_groups = torch.tensor(block_groups, dtype=torch.int, device=\"cpu\")\n    block_usage = torch.tensor(block_usage, dtype=dtype, device=\"cpu\")\n    return (\n        block_list,\n        block_groups,\n        block_usage,\n        slots_in_window_mask,\n        block_bucket_size,\n    )\n\n\n@dataclass\nclass FlashCausalLMBatch(Batch):\n    batch_id: int\n    requests: List[generate_pb2.Request]\n    # request id -> idx in list mapping\n    requests_idx_mapping: Dict[int, int]\n\n    # Decoder values\n    # Can be a list for easy filtering\n    # If `input_ids` is a list, it needs to be materialized to 
a tensor first\n    input_ids: Union[torch.Tensor, List[List[int]]]\n    # Will be set by `generate_token` and reset after each prefill forward before staying set in decode\n    position_ids: Optional[torch.Tensor]\n    speculative_ids: Optional[torch.Tensor]\n\n    # Set when creating the batch\n    # tensor of indices of the currently used slots, length = \\sum_{i=0}^{b} s_i in prefill, length = b in decode\n    # Will be set by `generate_token` and reset after each prefill forward before staying set in decode\n    slot_indices: Optional[torch.Tensor]\n\n    # list of length b of list of length s_i // block_size\n    block_tables: List[List[int]]\n    # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences\n    block_tables_tensor: torch.Tensor\n    # tensor of length \\sum_{i=0}^{b} max_s_i  holding the paged attention slots for all sequences\n    slots: torch.Tensor\n    # list of length b + 1  containing the cumulative sequence slot lengths of the sequences in the batch\n    # used for filtering\n    cu_slots: torch.Tensor\n\n    max_input_length: int\n    max_current_length: int\n\n    # Whether this batch contains at least one request that is prefilling\n    prefilling: bool\n    # Whether each request is prefilling\n    prefilling_mask: List[bool]\n\n    # Prefill metadata tensors to efficiently compute logprobs\n    # tensor of length b + 1  containing the cumulative sequence lengths of the sequences in the batch, only used in prefill\n    cu_seqlen_prefill: Optional[torch.Tensor]\n    # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers\n    # as we only keep SLIDING_WINDOW values instead of the whole tensor\n    prefill_cache_indices: Optional[torch.Tensor]\n    # Will be set by `generate_token` and reset after each prefill forward\n    prefill_head_indices: Optional[torch.Tensor]\n    # Will be set by `generate_token` and reset after each 
prefill forward\n    prefill_next_token_indices: Optional[torch.tensor]\n    # Will be set by `generate_token` and reset after each prefill forward\n    prefill_cu_outlens: Optional[List[int]]\n    # Will be set by `generate_token` and reset after each prefill forward\n    prefill_logprob_tokens: List[Optional[Tokens]]\n\n    # All tokens\n    all_input_ids: List[List[int]]\n    all_input_ids_tensor: torch.Tensor\n\n    # Lengths of all generations present in the batch\n    input_lengths: List[int]\n    # size [b], containing the number of blocks that can be retrieved from the cache\n    cache_lengths: List[int]\n    prompt_lengths: List[int]\n    # Will be set by `generate_token` and reset after each prefill forward before staying set in decode\n    input_lengths_tensor: Optional[torch.Tensor]\n    cache_lengths_tensor: Optional[torch.Tensor]\n    prompt_lengths_tensor: torch.Tensor\n\n    prefix_offsets: List[Optional[int]]\n    read_offsets: List[Optional[int]]\n\n    # Generation helpers\n    next_token_chooser: HeterogeneousNextTokenChooser\n    stopping_criterias: List[StoppingCriteria]\n    top_n_tokens: List[int]\n    top_n_tokens_tensor: torch.Tensor\n\n    # Adapter metadata for each request\n    # Will be set by `generate_token` and reset after each prefill forward before staying set in decode\n    adapter_meta: Optional[AdapterBatchMetadata]\n\n    # Number of blocks in this batch\n    num_blocks: int\n    # Maximum number of blocks\n    max_blocks: int\n\n    hpu_attn_meta: Optional[HPUPagedAttentionMetadata]\n\n    next_token_logits: Optional[torch.Tensor]\n    speculative_logits: Optional[torch.Tensor]\n    valid_indices: Optional[List[int]]\n\n    def to_pb(self) -> generate_pb2.CachedBatch:\n        return generate_pb2.CachedBatch(\n            id=self.batch_id,\n            request_ids=[r.id for r in self.requests],\n            size=len(self),\n            max_tokens=self.num_blocks * BLOCK_SIZE,\n            current_tokens=(\n                
sum([len(i) for i in self.input_ids])\n                if isinstance(self.input_ids, list)\n                else len(self.input_ids)\n            ),\n        )\n\n    @classmethod\n    def batch_tokenized_inputs(\n        cls, requests: Iterable[generate_pb2.Request], tokenizer\n    ):\n        max_length = 0\n        all_input_ids = []\n        batch_size = 0\n        for r in requests:\n            batch_size += 1\n            inputs = concat_text_chunks(r.input_chunks.chunks)\n            input_ids = tokenizer(\n                inputs,\n                truncation=True,\n                max_length=r.truncate,\n                add_special_tokens=r.add_special_tokens,\n            )[\"input_ids\"]\n            max_length = max(max_length, len(input_ids))\n            all_input_ids.append(input_ids)\n        return all_input_ids\n\n    @classmethod\n    def from_tokenized(\n        cls,\n        pb: generate_pb2.Batch,\n        tokenizer: PreTrainedTokenizerBase,\n        batch_tokenized_inputs,\n        dtype: torch.dtype,\n        device: torch.device,\n    ) -> \"FlashCausalLMBatch\":\n        cache_lengths = []\n        input_lengths = []\n        prompt_lengths = []\n        prefix_offsets = []\n        read_offsets = []\n        all_input_ids = []\n        all_postfix_ids = []\n        requests_idx_mapping = {}\n        slots = []\n        cu_slots = [0]\n\n        next_token_chooser_parameters = []\n        stopping_criterias = []\n        top_n_tokens = []\n\n        num_blocks = 0\n        max_input_length = 0\n        max_current_length = 0\n        max_length = 0\n        max_blocks = 0\n\n        cu_blocks = [0]\n        block_tables = []\n        block_tables_ragged = []\n\n        # Parse batch\n        for i, (r, tokenized_input) in enumerate(\n            zip(pb.requests, batch_tokenized_inputs)\n        ):\n            ### XXX: This consumes so much memory on long requests\n            ### Deactivating it by default seems like the best course.\n     
       if not REQUEST_LOGPROBS:\n                r.prefill_logprobs = False\n            else:\n                assert False, \"prefill_logprobs not supported yet\"\n            # request id -> idx in list mapping\n            requests_idx_mapping[r.id] = i\n\n            prompt_length = len(tokenized_input)\n            prompt_lengths.append(prompt_length)\n\n            cache_length = r.cache_len\n\n            assert (\n                cache_length <= prompt_length\n            ), f\"Prefix {cache_length} vs input {prompt_length}\"\n            if cache_length == prompt_length:\n                assert False, \"unreachable\"\n\n            # `chunk_len` is an optional field in the protobuf\n            # It is only set if the model support chunking\n            # Use all the remaining ids\n            postfix_ids = tokenized_input[cache_length:]\n            input_length = len(postfix_ids)\n\n            input_lengths.append(input_length)\n\n            prefix_offsets.append(prompt_length - 5)\n            read_offsets.append(prompt_length)\n\n            all_postfix_ids.append(postfix_ids)\n            all_input_ids.append(tokenized_input)\n\n            next_token_chooser_parameters.append(r.parameters)\n\n            stopping_criteria = StoppingCriteria.from_pb(\n                r.stopping_parameters, tokenizer\n            )\n            max_new_tokens = stopping_criteria.max_new_tokens\n            stopping_criterias.append(stopping_criteria)\n            top_n_tokens.append(r.top_n_tokens)\n\n            # Paged attention\n            # Remove one as the first token des not have a past\n            speculative_length = get_speculate()\n            speculative_length = 0 if speculative_length is None else speculative_length\n\n            # Tokens that need to be mapped to blocks.\n            block_tokens = prompt_length + max_new_tokens - 1 + speculative_length\n\n            # blocks and slots can be empty (for example in warmup)\n            if not 
r.blocks:\n                needed_blocks = math.ceil(block_tokens / BLOCK_SIZE)\n                request_blocks = [\n                    b for b in range(num_blocks, num_blocks + needed_blocks)\n                ]\n                request_slots = [\n                    s\n                    for b in request_blocks\n                    for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)\n                ]\n            else:\n                request_blocks = r.blocks\n                request_slots = r.slots\n\n            block_tables.append(request_blocks)\n            block_tables_ragged.extend(request_blocks)\n            cu_blocks.append(len(block_tables_ragged))\n\n            slots.extend(request_slots)\n            cu_slots.append(len(slots))\n\n            cache_lengths.append(cache_length)\n            num_blocks += len(request_blocks)\n\n            # Update\n            max_blocks = max(max_blocks, len(request_blocks))\n            max_input_length = max(max_input_length, input_length)\n            max_current_length = max(max_current_length, cache_length + input_length)\n            max_length = max(\n                max_length,\n                prompt_length + max_new_tokens + speculative_length,\n            )\n\n        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(\n            next_token_chooser_parameters, dtype, device, tokenizer\n        )\n\n        # Padded all_input_ids_tensor\n        all_input_ids_tensor = np.zeros(\n            (len(all_input_ids), max_length), dtype=np.int64\n        )\n        for i, input_ids in enumerate(all_input_ids):\n            all_input_ids_tensor[i, : len(input_ids)] = input_ids\n\n        # put on cpu temporarily, move to hpu in prepare_for_prefill\n        all_input_ids_tensor = torch.tensor(all_input_ids_tensor, dtype=torch.int64)\n\n        top_n_tokens_tensor = torch.tensor(top_n_tokens, dtype=torch.int64)\n\n        block_tables_ragged = torch.tensor(block_tables_ragged, dtype=torch.int32)\n    
    cu_blocks = torch.tensor(cu_blocks, dtype=torch.int64)\n        block_tables_tensor = torch.empty(\n            (len(block_tables), max_blocks),\n            dtype=torch.int32,\n        )\n\n        for i, request_blocks in enumerate(block_tables):\n            block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)\n\n        prompt_lengths_tensor = torch.tensor(prompt_lengths, dtype=torch.int32)\n\n        slots = torch.tensor(slots, dtype=torch.int64)\n        cu_slots = torch.tensor(cu_slots, dtype=torch.int64)\n\n        return cls(\n            batch_id=pb.id,\n            requests=pb.requests,\n            requests_idx_mapping=requests_idx_mapping,\n            input_ids=all_postfix_ids,\n            block_tables=block_tables,\n            block_tables_tensor=block_tables_tensor,\n            cache_lengths=cache_lengths,\n            max_input_length=max_input_length,\n            max_current_length=max_current_length,\n            prefilling=True,\n            prefilling_mask=[True] * len(pb.requests),\n            prefill_logprob_tokens=[None] * len(pb.requests),\n            input_lengths=input_lengths,\n            prompt_lengths=prompt_lengths,\n            prefix_offsets=prefix_offsets,\n            read_offsets=read_offsets,\n            all_input_ids=all_input_ids,\n            all_input_ids_tensor=all_input_ids_tensor,\n            next_token_chooser=next_token_chooser,\n            stopping_criterias=stopping_criterias,\n            top_n_tokens=top_n_tokens,\n            top_n_tokens_tensor=top_n_tokens_tensor,\n            num_blocks=num_blocks,\n            max_blocks=max_blocks,\n            speculative_ids=None,\n            prompt_lengths_tensor=prompt_lengths_tensor,\n            # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`\n            position_ids=None,\n            cu_seqlen_prefill=None,\n            prefill_cache_indices=None,\n            slot_indices=None,\n            slots=slots,\n  
          cu_slots=cu_slots,\n            prefill_head_indices=None,\n            prefill_next_token_indices=None,\n            prefill_cu_outlens=None,\n            cache_lengths_tensor=None,\n            input_lengths_tensor=None,\n            adapter_meta=None,\n            hpu_attn_meta=None,\n            next_token_logits=None,\n            speculative_logits=None,\n            valid_indices=None,\n        )\n\n    @classmethod\n    def from_pb(\n        cls,\n        pb: generate_pb2.Batch,\n        tokenizer: PreTrainedTokenizerBase,\n        dtype: torch.dtype,\n        device: torch.device,\n    ) -> \"FlashCausalLMBatch\":\n        assert len(pb.requests) > 0\n        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)\n        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)\n\n    @tracer.start_as_current_span(\"filter\")\n    def filter(self, request_ids: List[int]) -> \"FlashCausalLMBatch\":\n        if len(request_ids) == 0:\n            raise ValueError(\"Batch must have at least one request\")\n        # We assume that if len(requests) == len(self) then the requests are the same\n        if len(request_ids) == len(self):\n            return self\n\n        device = self.block_tables_tensor.device\n\n        # New values after filtering\n        requests_idx_mapping = {}\n\n        # Used to index into tensors\n        indices = []\n\n        # slots to keep after filtering\n        slot_filtering_indices = torch.zeros(self.slots.shape[0], dtype=torch.bool)\n\n        # Create on CPU to only move to GPU once instead of at every copy\n        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)\n        max_input_length = 0\n        max_current_length = 0\n\n        requests = []\n        block_tables = []\n        all_input_ids = []\n        input_ids = []\n\n        prompt_lengths = []\n        input_lengths = []\n        cache_lengths = []\n        prefix_offsets = []\n        
read_offsets = []\n        cu_slots = [0]\n\n        prefilling_mask = []\n        prefill_logprob_tokens = []\n\n        stopping_criterias = []\n        adapter_set = set()\n\n        num_blocks = 0\n        max_blocks = 0\n        max_slots = 0\n        cumulative_slot_tokens = 0\n\n        for i, request_id in enumerate(request_ids):\n            idx = self.requests_idx_mapping[request_id]\n            indices.append(idx)\n            requests_idx_mapping[request_id] = i\n\n            requests.append(self.requests[idx])\n\n            # Prefilling\n            request_prefilling = self.prefilling_mask[idx]\n            prefilling_mask.append(request_prefilling)\n\n            # Get length\n            request_input_length = self.input_lengths[idx]\n            request_cache_length = self.cache_lengths[idx]\n            max_input_length = max(max_input_length, request_input_length)\n            max_current_length = max(\n                max_current_length, request_cache_length + request_input_length\n            )\n\n            all_input_ids.append(self.all_input_ids[idx])\n\n            prompt_lengths.append(self.prompt_lengths[idx])\n            input_lengths.append(request_input_length)\n            cache_lengths.append(request_cache_length)\n            prefix_offsets.append(self.prefix_offsets[idx])\n            read_offsets.append(self.read_offsets[idx])\n\n            stopping_criteria = self.stopping_criterias[idx]\n            stopping_criterias.append(stopping_criteria)\n\n            prefill_logprob_tokens.append(self.prefill_logprob_tokens[idx])\n\n            ADAPTER_TO_INDEX = get_adapter_to_index()\n            adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)\n            adapter_set.add(adapter_index)\n\n            request_block_table = self.block_tables[idx]\n            num_blocks += len(request_block_table)\n            block_tables.append(request_block_table)\n\n            start_slot = self.cu_slots[idx]\n            
end_slot = self.cu_slots[idx + 1]\n            slot_length = end_slot - start_slot\n\n            # Set slice\n            slot_filtering_indices[start_slot:end_slot] = True\n\n            cu_slots.append(cumulative_slot_tokens + slot_length)\n\n            # Input ids if the request was part of a prefilling batch\n            # If the batch was decoding we can index into the tensor directly later\n            if self.prefilling:\n                input_ids.append(self.input_ids[idx])\n            else:\n                # Copy to tensor (CPU)\n                slot_indices[i] = cumulative_slot_tokens + request_cache_length\n\n            cumulative_slot_tokens += slot_length\n            max_blocks = max(max_blocks, len(request_block_table))\n            max_slots = max(max_slots, slot_length)\n\n        block_tables_tensor = self.block_tables_tensor[indices]\n        prompt_lengths_tensor = self.prompt_lengths_tensor[indices]\n\n        cu_slots = torch.tensor(cu_slots, dtype=torch.int64)\n\n        slots = self.slots[slot_filtering_indices]\n\n        if self.prefilling:\n            # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`\n            position_ids = None\n            slot_indices = None\n            cache_lengths_tensor = None\n            input_lengths_tensor = None\n            adapter_meta = None\n        else:\n            # Index into tensors\n            input_ids = self.input_ids[indices]\n            position_ids = self.position_ids[indices]\n            input_lengths_tensor = self.input_lengths_tensor[indices]\n            cache_lengths_tensor = self.cache_lengths_tensor[indices]\n\n            # Move to GPU now that we have the whole tensor\n            slot_indices = slot_indices.to(device)\n            if self.adapter_meta is not None:\n                adapter_indices = self.adapter_meta.adapter_indices[indices]\n                adapter_segments, adapter_segment_indices = find_segments(\n                    
adapter_indices\n                )\n                adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)\n                adapter_meta = AdapterBatchMetadata(\n                    adapter_indices=adapter_indices,\n                    adapter_set=adapter_set,\n                    adapter_segments=adapter_segments,\n                    segment_indices=adapter_segment_indices,\n                )\n            else:\n                adapter_meta = None\n        htorch.core.mark_step()\n        return type(self)(\n            batch_id=self.batch_id,\n            requests=requests,\n            requests_idx_mapping=requests_idx_mapping,\n            input_ids=input_ids,\n            position_ids=position_ids,\n            cu_seqlen_prefill=None,\n            prefill_cache_indices=None,\n            slot_indices=slot_indices,\n            block_tables=block_tables,\n            block_tables_tensor=block_tables_tensor,\n            slots=slots,\n            cu_slots=cu_slots,\n            max_input_length=max_input_length,\n            max_current_length=max_current_length,\n            prefilling=self.prefilling,\n            prefilling_mask=prefilling_mask,\n            prefill_head_indices=None,\n            prefill_next_token_indices=None,\n            prefill_cu_outlens=None,\n            prefill_logprob_tokens=prefill_logprob_tokens,\n            prompt_lengths=prompt_lengths,\n            prompt_lengths_tensor=prompt_lengths_tensor,\n            input_lengths=input_lengths,\n            input_lengths_tensor=input_lengths_tensor,\n            cache_lengths=cache_lengths,\n            cache_lengths_tensor=cache_lengths_tensor,\n            prefix_offsets=prefix_offsets,\n            read_offsets=read_offsets,\n            all_input_ids=all_input_ids,\n            all_input_ids_tensor=self.all_input_ids_tensor,\n            next_token_chooser=self.next_token_chooser,\n            stopping_criterias=stopping_criterias,\n            
top_n_tokens=self.top_n_tokens,\n            top_n_tokens_tensor=self.top_n_tokens_tensor,\n            num_blocks=num_blocks,\n            max_blocks=max_blocks,\n            speculative_ids=self.speculative_ids,\n            adapter_meta=adapter_meta,\n            hpu_attn_meta=None,\n            valid_indices=indices,\n            next_token_logits=self.next_token_logits,\n            speculative_logits=self.speculative_logits,\n        )\n\n    @classmethod\n    @tracer.start_as_current_span(\"concatenate\")\n    def concatenate(\n        cls, batches: List[\"FlashCausalLMBatch\"], padded_total_bs: int = 0\n    ) -> \"FlashCausalLMBatch\":\n        # Batch attributes\n        requests = []\n        requests_idx_mapping = {}\n\n        prefilling = False\n        num_blocks = 0\n        total_batch_size = 0\n        total_slots = 0\n        max_blocks = 0\n        max_length = 0\n        max_input_length = 0\n        max_current_length = 0\n        ADAPTER_TO_INDEX = get_adapter_to_index()\n        for b in batches:\n            total_batch_size += len(b)\n            max_blocks = max(max_blocks, b.max_blocks)\n            total_slots += len(b.slots)\n            num_blocks += b.num_blocks\n            speculative_length = (\n                b.speculative_ids.shape[1] if b.speculative_ids is not None else 0\n            )\n            max_input_length = max(max_input_length, b.max_input_length)\n            max_current_length = max(max_current_length, b.max_current_length)\n            max_length = max(\n                max_length,\n                max(\n                    prompt_length\n                    + stopping_criteria.max_new_tokens\n                    + speculative_length\n                    for prompt_length, stopping_criteria in zip(\n                        b.prompt_lengths, b.stopping_criterias\n                    )\n                ),\n            )\n            prefilling = prefilling or b.prefilling\n\n        slots = 
batches[0].slots.new_empty(total_slots)\n        cu_slots = torch.zeros(total_batch_size + 1, dtype=torch.int64)\n        if prefilling:\n            input_ids = []\n            # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`\n            position_ids = None\n            slot_indices = None\n            cache_lengths_tensor = None\n            input_lengths_tensor = None\n            adapter_meta = None\n            adapter_segment_builder = None\n        else:\n            if padded_total_bs == batches[0].input_ids.shape[0]:\n                input_ids = batches[0].input_ids\n            else:\n                input_ids = batches[0].input_ids.new_empty(total_batch_size)\n            if (\n                batches[0].position_ids is not None\n                and batches[0].position_ids.dim() == 2\n            ):\n                # Qwen2_vl case:\n                position_ids = batches[0].position_ids.new_empty(\n                    (total_batch_size, batches[0].position_ids.shape[-1])\n                )\n            else:\n                position_ids = batches[0].position_ids.new_empty(total_batch_size)\n            slot_indices = batches[0].slot_indices.new_empty(total_batch_size)\n            input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(\n                total_batch_size\n            )\n            cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty(\n                total_batch_size\n            )\n            if ADAPTER_TO_INDEX:\n                total_indices_size = sum(\n                    b.adapter_meta.adapter_indices.shape[0] for b in batches\n                )\n                adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(\n                    total_indices_size\n                )\n                adapter_segment_builder = SegmentConcatBuilder()\n                adapter_set = set()\n\n        prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(\n            
total_batch_size\n        )\n        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(\n            (total_batch_size, max_blocks)\n        )\n        all_input_ids_tensor = batches[0].all_input_ids_tensor\n        top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(\n            total_batch_size,\n        )\n\n        block_tables = []\n        cache_lengths = []\n        all_input_ids = []\n\n        prompt_lengths = []\n        input_lengths = []\n        prefix_offsets = []\n        read_offsets = []\n\n        prefill_logprob_tokens = []\n\n        next_token_chooser_parameters = []\n        fsm_grammar_states = []\n        stopping_criterias = []\n        top_n_tokens = []\n        prefilling_mask = []\n\n        # Cumulative length\n        cumulative_batch_size = 0\n        cumulative_slots = 0\n        cumulative_adapter_indices_size = 0\n\n        for i, batch in enumerate(batches):\n            requests.extend(batch.requests)\n            valid_bsize = len(batch)\n\n            if i == 0:\n                requests_idx_mapping = batch.requests_idx_mapping\n            else:\n                # We need to offset the mapping for each batch by the cumulative batch size\n                for k, v in batch.requests_idx_mapping.items():\n                    requests_idx_mapping[k] = v + cumulative_batch_size\n\n            start_index = cumulative_batch_size\n            end_index = cumulative_batch_size + valid_bsize\n\n            index = torch.tensor(list(range(start_index, end_index)), device=\"cpu\")\n            top_n_tokens_tensor.index_copy_(0, index, batch.top_n_tokens_tensor)\n            if i > 0:\n                all_input_ids_tensor.index_copy_(\n                    0,\n                    index.to(batch.all_input_ids_tensor.device),\n                    batch.all_input_ids_tensor[:valid_bsize, :],\n                )\n\n            block_tables_tensor[\n                start_index:end_index, : 
batch.block_tables_tensor.shape[1]\n            ] = batch.block_tables_tensor[:, :max_blocks]\n            prompt_lengths_tensor.index_copy_(0, index, batch.prompt_lengths_tensor)\n\n            slots_start_index = cumulative_slots\n            slots_end_index = cumulative_slots + len(batch.slots)\n            slot_index = torch.tensor(\n                list(range(slots_start_index, slots_end_index)),\n                device=batch.slots.device,\n            )\n\n            slots.index_copy_(0, slot_index, batch.slots)\n            cu_slots[start_index + 1 : end_index + 1] = (\n                batch.cu_slots[1:] + cumulative_slots\n            )\n\n            if not prefilling:\n                if padded_total_bs != batches[0].input_ids.shape[0] or i > 0:\n                    input_ids.index_copy_(\n                        0, index.to(input_ids.device), batch.input_ids[:valid_bsize]\n                    )\n                position_ids.index_copy_(0, index, batch.position_ids[:valid_bsize])\n                slot_indices.index_copy_(\n                    0, index, batch.slot_indices + cumulative_slots\n                )\n                input_lengths_tensor.index_copy_(\n                    0, index, batch.input_lengths_tensor[:valid_bsize]\n                )\n                cache_lengths_tensor.index_copy_(\n                    0, index, batch.cache_lengths_tensor[:valid_bsize]\n                )\n                if ADAPTER_TO_INDEX:\n                    adapter_start_index = cumulative_adapter_indices_size\n                    adapter_end_index = (\n                        cumulative_adapter_indices_size\n                        + batch.adapter_meta.adapter_indices.shape[0]\n                    )\n                    adapter_indices[adapter_start_index:adapter_end_index] = (\n                        batch.adapter_meta.adapter_indices\n                    )\n                    cumulative_adapter_indices_size = adapter_end_index\n                    
adapter_set.update(batch.adapter_meta.adapter_set)\n                    adapter_segment_builder.concat(\n                        batch.adapter_meta.adapter_segments,\n                        batch.adapter_meta.segment_indices,\n                    )\n            else:\n                if isinstance(batch.input_ids, torch.Tensor):\n                    batch.input_ids = batch.input_ids.view(-1, 1).tolist()\n                input_ids.extend(batch.input_ids)\n\n            prefilling_mask.extend(batch.prefilling_mask)\n            block_tables.extend(batch.block_tables)\n            cache_lengths.extend(batch.cache_lengths)\n            all_input_ids.extend(batch.all_input_ids)\n\n            prompt_lengths.extend(batch.prompt_lengths)\n            input_lengths.extend(batch.input_lengths)\n            prefix_offsets.extend(batch.prefix_offsets)\n            read_offsets.extend(batch.read_offsets)\n\n            prefill_logprob_tokens.extend(batch.prefill_logprob_tokens)\n\n            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])\n            fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)\n            stopping_criterias.extend(batch.stopping_criterias)\n\n            top_n_tokens.extend(batch.top_n_tokens)\n\n            # Update\n            cumulative_slots += len(batch.slots)\n            cumulative_batch_size += len(batch)\n\n        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(\n            next_token_chooser_parameters,\n            dtype=batches[0].next_token_chooser.dtype,\n            device=batches[0].next_token_chooser.device,\n            tokenizer=batches[0].next_token_chooser.tokenizer,\n            fsm_grammar_states=fsm_grammar_states,\n        )\n\n        # We skip computing the speculative_ids when the batch size is too large, so\n        # we must check that all batches have them, otherwise they must be discarded\n        if get_speculate() > 0 and all(b.speculative_ids is not 
None for b in batches):\n            speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0)\n        else:\n            speculative_ids = None\n\n        if ADAPTER_TO_INDEX and adapter_segment_builder is not None:\n            adapter_segments, adapter_segment_indices = adapter_segment_builder.build()\n            adapter_meta = AdapterBatchMetadata(\n                adapter_indices=adapter_indices,\n                adapter_set=adapter_set,\n                adapter_segments=adapter_segments,\n                segment_indices=adapter_segment_indices,\n            )\n\n        return cls(\n            batch_id=batches[0].batch_id,\n            requests=requests,\n            requests_idx_mapping=requests_idx_mapping,\n            input_ids=input_ids,\n            position_ids=position_ids,\n            cu_seqlen_prefill=None,\n            prefill_cache_indices=None,\n            slot_indices=slot_indices,\n            block_tables=block_tables,\n            block_tables_tensor=block_tables_tensor,\n            cache_lengths=cache_lengths,\n            cache_lengths_tensor=cache_lengths_tensor,\n            slots=slots,\n            cu_slots=cu_slots,\n            max_input_length=max_input_length,\n            max_current_length=max_current_length,\n            prefilling=prefilling,\n            prefilling_mask=prefilling_mask,\n            prefill_head_indices=None,\n            prefill_next_token_indices=None,\n            prefill_cu_outlens=None,\n            prefill_logprob_tokens=prefill_logprob_tokens,\n            prompt_lengths=prompt_lengths,\n            prompt_lengths_tensor=prompt_lengths_tensor,\n            input_lengths=input_lengths,\n            input_lengths_tensor=input_lengths_tensor,\n            prefix_offsets=prefix_offsets,\n            read_offsets=read_offsets,\n            all_input_ids=all_input_ids,\n            all_input_ids_tensor=all_input_ids_tensor,\n            next_token_chooser=next_token_chooser,\n            
stopping_criterias=stopping_criterias,\n            top_n_tokens=top_n_tokens,\n            top_n_tokens_tensor=top_n_tokens_tensor,\n            num_blocks=num_blocks,\n            max_blocks=max_blocks,\n            speculative_ids=speculative_ids,\n            adapter_meta=adapter_meta if ADAPTER_TO_INDEX else None,\n            hpu_attn_meta=None,\n            next_token_logits=None,\n            speculative_logits=None,\n            valid_indices=None,\n        )\n\n    def prepare_for_decode(\n        self, dtype, use_contiguous_pa, bucketing_ctx, pad_token_id, sliding_window\n    ):\n        block_num = [length // BLOCK_SIZE + 1 for length in self.cache_lengths]\n        block_tables = []\n        for i, bt in enumerate(self.block_tables):\n            block_tables.append(bt[0 : block_num[i]])\n        if bucketing_ctx is not None:\n            padded_bs = bucketing_ctx.get_padded_decode_batch_size(\n                self.input_ids.shape[0]\n            )\n        else:\n            padded_bs = self.input_ids.shape[0]\n        slots = self.slots[self.slot_indices]\n\n        block_list, block_groups, block_usage, _, block_bucket_size = (\n            generate_block_metadata(\n                dtype,\n                use_contiguous_pa,\n                slots,\n                block_tables,\n                bucketing_ctx,\n            )\n        )\n        meta = HPUPagedAttentionMetadata(\n            block_list=_async_h2d_tensor_copy(block_list),\n            block_groups=_async_h2d_tensor_copy(block_groups),\n            block_usage=_async_h2d_tensor_copy(block_usage),\n            block_mapping=None,\n            attn_bias=None,\n        )\n        if sliding_window is not None:\n            block_tables_in_window = []\n            for i, bt in enumerate(self.block_tables):\n                block_num_in_window = (\n                    sliding_window + 2 * BLOCK_SIZE - 2 - slots[i] % BLOCK_SIZE\n                ) // BLOCK_SIZE\n                
block_tables_in_window.append(\n                    bt[max(0, block_num[i] - block_num_in_window) : block_num[i]]\n                )\n            slots_in_window = []\n            for i, indice in enumerate(self.slot_indices):\n                start_idx = indice - self.cache_lengths[i]\n                mask = (\n                    indice\n                    - torch.arange(\n                        start_idx,\n                        indice + 1,\n                        device=self.slots.device,\n                    )\n                ) < sliding_window\n                slots_in_window.append(self.slots[start_idx : indice + 1][mask])\n            slots_in_window = torch.cat(slots_in_window, dim=0)\n            (\n                block_list_in_window,\n                block_groups_in_window,\n                block_usage_in_window,\n                slots_in_window_mask,\n                _,\n            ) = generate_block_metadata(\n                dtype,\n                use_contiguous_pa,\n                slots,\n                block_tables_in_window,\n                bucketing_ctx,\n                slots_in_window,\n                block_bucket_size,\n            )\n            meta.block_list_in_window = _async_h2d_tensor_copy(block_list_in_window)\n            meta.block_groups_in_window = _async_h2d_tensor_copy(block_groups_in_window)\n            meta.block_usage_in_window = _async_h2d_tensor_copy(block_usage_in_window)\n            meta.slots_in_window_mask = _async_h2d_tensor_copy(slots_in_window_mask)\n\n        self.hpu_attn_meta = trim_attn_metadata(meta)\n        self.input_ids = F.pad(\n            self.input_ids, (0, padded_bs - self.input_ids.shape[0]), value=pad_token_id\n        )\n\n        if self.position_ids.dim() == 2:\n            # Qwen VL case\n            self.position_ids = F.pad(\n                self.position_ids,\n                (0, 0, 0, padded_bs - self.position_ids.shape[0]),\n                value=1,\n            )\n        
else:\n            self.position_ids = F.pad(\n                self.position_ids, (0, padded_bs - self.position_ids.shape[0]), value=1\n            )\n        self.input_lengths_tensor = F.pad(\n            self.input_lengths_tensor,\n            (0, padded_bs - self.input_lengths_tensor.shape[0]),\n            value=0,\n        )\n        self.cache_lengths_tensor = F.pad(\n            self.cache_lengths_tensor,\n            (0, padded_bs - self.cache_lengths_tensor.shape[0]),\n            value=0,\n        )\n        if len(self.next_token_chooser.do_sample) != padded_bs:\n            next_token_chooser_parameters = []\n            next_token_chooser_parameters.extend([r.parameters for r in self.requests])\n            pad_next_token_chooser_parameters(next_token_chooser_parameters, padded_bs)\n            # update past grammar states\n            fsm_grammar_states = [0] * padded_bs\n\n            for i, req in enumerate(self.requests):\n                fsm_grammar_states[i] = self.next_token_chooser.fsm_grammar_states[i]\n\n            self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb(\n                next_token_chooser_parameters,\n                self.next_token_chooser.dtype,\n                self.next_token_chooser.device,\n                self.next_token_chooser.tokenizer,\n                fsm_grammar_states,\n            )\n\n    def prepare_for_prefill(\n        self, max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id\n    ):\n        # Prepare values if we need to continue prefilling\n        # Speculation must be ignored while we prefill even with chunking\n        # it simplifies everything\n        assert self.speculative_ids is None\n\n        # device = self.block_tables_tensor.device\n\n        # hpu does not support varlen for prefill, use sdpa instead. 
so need to pad input_tensor, position\n        # padding to left to work with sliding window\n        # use prefill_cache_indices to indicate the valid kv slot, update prefill_next_token_indices to indicate\n        # the right logit position\n        input_ids_padded_length = []\n        # need extra pad to match warmup seq\n        extra_pad = max_padded_input_len - self.max_input_length\n        extra_pad_bs = max_padded_bs - len(self)\n        device = \"hpu\"\n        if isinstance(self.input_ids, list) and len(self) > 1:\n            input_ids_padded_length = []\n            input_ids = []\n            for input_id in self.input_ids:\n                padded = self.max_input_length - len(input_id) + extra_pad\n                if padded > 0:\n                    input_id = [pad_token_id] * padded + input_id\n                input_ids.append(input_id)\n                input_ids_padded_length.append(padded)\n            input_ids = np.concatenate(input_ids, dtype=np.int64)\n            self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)\n        elif isinstance(self.input_ids, list):\n            input_ids = self.input_ids[0]\n            input_ids_padded_length.append(extra_pad)\n            input_ids = [pad_token_id] * extra_pad + input_ids\n            self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)\n        else:\n            input_ids = torch.full(\n                (max_padded_input_len * len(self),),\n                pad_token_id,\n                dtype=torch.int64,\n                device=self.input_ids.device,\n            )\n            src_pos = 0\n            for i in range(len(self)):\n                end_pos = (i + 1) * max_padded_input_len\n                start_pos = end_pos - self.input_lengths[i]\n                input_ids[start_pos:end_pos] = self.input_ids[\n                    src_pos : src_pos + self.input_lengths[i]\n                ]\n                input_ids_padded_length.append(\n         
           max_padded_input_len - self.input_lengths[i]\n                )\n                src_pos += self.input_lengths[i]\n            self.input_ids = input_ids\n\n        self.input_ids = F.pad(\n            self.input_ids, (0, extra_pad_bs * max_padded_input_len), value=pad_token_id\n        )\n\n        self.input_lengths_tensor = torch.tensor(self.input_lengths, dtype=torch.int32)\n\n        self.input_lengths_tensor = F.pad(\n            self.input_lengths_tensor, (0, extra_pad_bs), value=0\n        )\n\n        cu_seqlen_prefill = self.input_lengths_tensor.new_zeros(max_padded_bs + 1)\n        torch.cumsum(self.input_lengths_tensor, out=cu_seqlen_prefill[1:], dim=0)\n        self.cu_seqlen_prefill = cu_seqlen_prefill.to(torch.int32)\n        self.cache_lengths_tensor = torch.tensor(self.cache_lengths, dtype=torch.int32)\n        self.cache_lengths_tensor = F.pad(\n            self.cache_lengths_tensor, (0, extra_pad_bs), value=0\n        )\n\n        position_ids = []\n        slot_indices = []\n        prefill_cache_indices = []\n        all_prefill_logprobs = True\n        no_prefill_logprobs = True\n        prefill_cu_outlens = [0]\n\n        # Cumulative length\n        cumulative_length = 0\n        cumulative_slot_tokens = 0\n        prefill_out_cumulative_length = 0\n\n        adapter_indices_list = []\n        adapter_set = set()\n\n        for i, (\n            r,\n            cache_length,\n            input_length,\n            prompt_length,\n            request_prefilling,\n            blocks,\n        ) in enumerate(\n            zip(\n                self.requests,\n                self.cache_lengths,\n                self.input_lengths,\n                self.prompt_lengths,\n                self.prefilling_mask,\n                self.block_tables,\n            )\n        ):\n            next_chunk_length = input_length\n\n            # Position ids\n            request_position_ids = torch.arange(\n                cache_length, 
cache_length + input_length, dtype=torch.int32\n            )\n            request_position_ids = F.pad(\n                request_position_ids, (input_ids_padded_length[i], 0), value=1\n            )\n            position_ids.append(request_position_ids)\n\n            if not r.slots:\n                request_slots = [\n                    s\n                    for b in blocks\n                    for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)\n                ]\n            else:\n                request_slots = r.slots\n\n            request_slot_indices = torch.arange(\n                cache_length + cumulative_slot_tokens,\n                cache_length + cumulative_slot_tokens + input_length,\n                dtype=torch.int64,\n            )\n\n            slot_indices.append(request_slot_indices)\n\n            # Update\n            cumulative_slot_tokens += len(request_slots)\n\n            # Create tensor to slice into the kv tensor in prefill\n            # hpu need request_prefill_cache_indices to skip padding in kv cache\n            sliding_window = input_length\n            cumulative_length += input_ids_padded_length[i]\n            if sliding_window is not None:\n                request_prefill_cache_indices = torch.arange(\n                    cumulative_length + max(0, input_length - sliding_window),\n                    cumulative_length + input_length,\n                    dtype=torch.int64,\n                )\n\n            # Prefill logprobs is ignored if the request is done prefilling\n            prefill_logprobs = r.prefill_logprobs and request_prefilling\n\n            all_prefill_logprobs = all_prefill_logprobs and prefill_logprobs\n            no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs\n\n            if prefill_logprobs:\n                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)\n                prefill_out_cumulative_length += input_length\n            else:\n                
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)\n                prefill_out_cumulative_length += 1\n\n            prefill_cache_indices.append(request_prefill_cache_indices)\n\n            ADAPTER_TO_INDEX = get_adapter_to_index()\n            if ADAPTER_TO_INDEX:\n                adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)\n                adapter_indices_list.append(\n                    torch.full((next_chunk_length,), adapter_index)\n                )\n                adapter_set.add(adapter_index)\n\n            # Update\n            cumulative_length += next_chunk_length\n\n        if not all_prefill_logprobs and not no_prefill_logprobs:\n            prefill_head_indices = []\n            prefill_next_token_indices = []\n\n            # Cumulative length\n            cumulative_length = 0\n            prefill_out_cumulative_length = 0\n\n            for i, (\n                r,\n                input_length,\n                request_prefilling,\n            ) in enumerate(\n                zip(\n                    self.requests,\n                    self.input_lengths,\n                    self.prefilling_mask,\n                )\n            ):\n                # Prefill logprobs is ignored if the request is done prefilling\n                prefill_logprobs = r.prefill_logprobs and request_prefilling\n\n                if prefill_logprobs:\n                    prefill_head_indices.append(\n                        torch.arange(\n                            cumulative_length,\n                            cumulative_length + input_length,\n                            dtype=torch.int32,\n                        )\n                    )\n                    prefill_next_token_indices.append(\n                        prefill_out_cumulative_length + input_length - 1\n                    )\n                    prefill_out_cumulative_length += input_length\n                else:\n                    prefill_head_indices.append(\n              
          torch.tensor(\n                            [cumulative_length + input_length - 1],\n                            dtype=torch.int32,\n                        )\n                    )\n                    prefill_next_token_indices.append(prefill_out_cumulative_length)\n                    prefill_out_cumulative_length += 1\n\n                # Update\n                cumulative_length += input_length\n\n        if len(self) > 1:\n            if position_ids:\n                position_ids = torch.cat(position_ids)\n            if slot_indices:\n                slot_indices = torch.cat(slot_indices)\n            prefill_cache_indices = torch.cat(prefill_cache_indices)\n        else:\n            if position_ids:\n                position_ids = position_ids[0]\n            if slot_indices:\n                slot_indices = slot_indices[0]\n            prefill_cache_indices = prefill_cache_indices[0]\n\n        self.position_ids = position_ids\n        self.position_ids = F.pad(\n            self.position_ids, (0, extra_pad_bs * max_padded_input_len), value=1\n        )\n        self.slot_indices = slot_indices\n\n        self.prefill_cu_outlens = prefill_cu_outlens\n        self.prefill_cache_indices = torch.zeros_like(\n            self.input_ids, dtype=torch.bool, device=\"cpu\"\n        )\n        self.prefill_cache_indices[prefill_cache_indices] = True\n\n        if all_prefill_logprobs:\n            prefill_head_indices = None\n            prefill_next_token_indices = self.cu_seqlen_prefill[1:] - 1\n        elif no_prefill_logprobs:\n            prefill_head_indices = self.cu_seqlen_prefill[1:] - 1\n            prefill_next_token_indices = None\n        else:\n            prefill_head_indices = torch.cat(prefill_head_indices)\n            prefill_next_token_indices = torch.tensor(\n                prefill_next_token_indices, dtype=torch.int64\n            )\n\n        self.prefill_head_indices = prefill_head_indices\n        self.prefill_next_token_indices 
= prefill_next_token_indices\n        input_ids_padded_length_tensor = torch.cumsum(\n            torch.tensor(input_ids_padded_length, dtype=torch.int32),\n            dim=-1,\n        ).to(torch.int32)\n        input_ids_padded_length_tensor = F.pad(\n            input_ids_padded_length_tensor, (0, extra_pad_bs), value=0\n        )\n        if self.prefill_head_indices is not None:\n            self.prefill_head_indices = (\n                self.prefill_head_indices + input_ids_padded_length_tensor\n            )\n\n        if self.prefill_next_token_indices is not None:\n            self.prefill_next_token_indices = (\n                self.prefill_next_token_indices + input_ids_padded_length_tensor\n            )\n        all_input_ids_tensor = torch.full(\n            (max_padded_bs, max(max_total_tokens, self.all_input_ids_tensor.shape[-1])),\n            pad_token_id,\n            dtype=torch.int64,\n            device=\"hpu\",\n        )\n        for i in range(len(self)):\n            all_input_ids_tensor[i, : self.all_input_ids_tensor.shape[-1]] = (\n                self.all_input_ids_tensor[i]\n            )\n        self.all_input_ids_tensor = all_input_ids_tensor\n        if len(self.next_token_chooser.do_sample) != max_padded_bs:\n            next_token_chooser_parameters = []\n            next_token_chooser_parameters.extend([r.parameters for r in self.requests])\n            pad_next_token_chooser_parameters(\n                next_token_chooser_parameters, max_padded_bs\n            )\n            # update past grammar states\n            fsm_grammar_states = [0] * max_padded_bs\n\n            for i, req in enumerate(self.requests):\n                fsm_grammar_states[i] = self.next_token_chooser.fsm_grammar_states[i]\n\n            self.next_token_chooser = HeterogeneousNextTokenChooser.from_pb(\n                next_token_chooser_parameters,\n                self.next_token_chooser.dtype,\n                self.next_token_chooser.device,\n           
     self.next_token_chooser.tokenizer,\n                fsm_grammar_states,\n            )\n\n        if ADAPTER_TO_INDEX:\n            if adapter_set:\n                adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64)\n                adapter_segments, adapter_segment_indices = find_segments(\n                    adapter_indices\n                )\n            else:\n                adapter_indices = torch.zeros_like(self.input_ids)\n                adapter_segments = [0, len(adapter_indices)]\n                adapter_segment_indices = [len(adapter_indices) - 1]\n\n            adapter_segments = torch.tensor(adapter_segments, dtype=torch.int32)\n            self.adapter_meta = AdapterBatchMetadata(\n                adapter_indices=adapter_indices,\n                adapter_set=adapter_set,\n                adapter_segments=adapter_segments,\n                segment_indices=adapter_segment_indices,\n            )\n\n    def __len__(self):\n        return len(self.requests)\n\n\nADAPTER_LAYERS = [\n    \"q_proj\",\n    \"k_proj\",\n    \"v_proj\",\n    \"o_proj\",\n    \"gate_proj\",\n    \"up_proj\",\n    \"down_proj\",\n]\nROW_PARALLEL = {\"o_proj\", \"down_proj\", \"lm_head\"}\n\n\nclass FlashCausalLM(Model):\n    def __init__(\n        self,\n        model_id: str,\n        model_class,\n        revision: Optional[str] = None,\n        quantize: Optional[str] = None,\n        speculator: Optional[str] = None,\n        dtype: Optional[torch.dtype] = None,\n        trust_remote_code: bool = False,\n        lora_adapter_ids: Optional[list] = [],\n        tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer,\n        config_class: PreTrainedTokenizerBase = AutoConfig,\n        default_dtype=torch.float16,\n        aliases=None,\n        # Used for Santacoder override of config\n        num_kv_heads: Optional[int] = None,\n        # Deepseek V2 uses different QK and V dims.\n        head_size: Optional[int] = None,\n        
skip_special_tokens: bool = True,\n        kv_cache_dtype: Optional[torch.dtype] = None,\n        support_chunking: bool = True,\n    ):\n        self.quantize = quantize\n        self.process_group, rank, world_size = initialize_torch_distributed()\n        if world_size > 1:\n            self.process_group_cpu = torch.distributed.new_group(backend=\"gloo\")\n\n        device = torch.device(\"hpu\")\n        dtype = torch.bfloat16 if dtype is None else dtype\n\n        tokenizer = tokenizer_class.from_pretrained(\n            model_id,\n            revision=revision,\n            padding_side=\"left\",\n            truncation_side=\"left\",\n            trust_remote_code=trust_remote_code,\n        )\n        try:\n            generation_config = GenerationConfig.from_pretrained(\n                model_id, revision=revision, trust_remote_code=trust_remote_code\n            )\n            if isinstance(generation_config.eos_token_id, (list, set)):\n                # TODO Huge hack\n                tokenizer._eos_token_ids = set(generation_config.eos_token_id)\n        except Exception:\n            pass\n\n        config = config_class.from_pretrained(\n            model_id, revision=revision, trust_remote_code=trust_remote_code\n        )\n        config.quantize = quantize\n        config.speculator = speculator\n\n        torch.distributed.barrier(group=self.process_group)\n\n        weights_loader = get_loader(quantize, model_id, revision)\n        filenames = weight_files(model_id, revision=revision, extension=\".safetensors\")\n        weights = Weights(\n            filenames,\n            device,\n            dtype,\n            process_group=self.process_group,\n            aliases=aliases,\n            weights_loader=weights_loader,\n        )\n\n        prefix = None\n        model = model_class(prefix, config, weights)\n        torch.distributed.barrier(group=self.process_group)\n\n        # VLM models define the config we care about in their 
text_config\n        text_config = getattr(config, \"text_config\", None)\n        if text_config is not None:\n            config = text_config\n\n        if getattr(config, \"sliding_window\", None) is None:\n            config.sliding_window = None\n        if getattr(config, \"use_sliding_window\", True) is False:\n            config.sliding_window = None\n\n        self.num_layers = config.num_hidden_layers\n        self.num_heads = config.num_attention_heads // self.process_group.size()\n        self.config = config\n        # Validation is done in the model itself\n        if num_kv_heads is None:\n            num_kv_heads = getattr(config, \"num_key_value_heads\", None)\n            # GPT-2 workaround\n            if num_kv_heads is None:\n                num_kv_heads = getattr(config, \"n_head\", None)\n        if num_kv_heads is None:\n            raise ValueError(\"Cannot get the number of key/value heads\")\n        self.num_kv_heads = (\n            num_kv_heads // self.process_group.size()\n            if num_kv_heads // self.process_group.size() > 0\n            else num_kv_heads\n        )\n        assert self.num_kv_heads > 0\n\n        if head_size is None:\n            # Some models use GQA and different sizes for o_proj\n            # and q_proj, that allows for that.\n            if getattr(config, \"head_dim\", None) is not None:\n                self.head_size = config.head_dim\n            else:\n                self.head_size = config.hidden_size // config.num_attention_heads\n        else:\n            self.head_size = head_size\n\n        self.cuda_graphs = {}\n        self.kv_cache = []\n        self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype\n        self.bucketing_ctx = None\n        self.max_total_tokens = None\n        self.max_input_tokens = None\n        htorch.core.hpu_set_env()\n        if htorch.utils.internal.is_lazy():\n            htorch.hpu.wrap_in_hpu_graph(model, disable_tensor_cache=True)\n       
 environment.set_model_config(self.config)\n        self.use_contiguous_pa = (\n            os.environ.get(\"VLLM_CONTIGUOUS_PA\", \"true\").lower() == \"true\"\n        )\n        self.limit_hpu_graph = (\n            os.environ.get(\"LIMIT_HPU_GRAPH\", \"false\").lower() == \"true\"\n        )\n        self.skip_warmup = os.getenv(\"VLLM_SKIP_WARMUP\", \"false\").lower() == \"true\"\n        self.max_seq_len_to_capture = 8192\n        if tokenizer.pad_token_id is None:\n            if config.pad_token_id is not None:\n                tokenizer.pad_token_id = config.pad_token_id\n            elif config.eos_token_id is not None:\n                tokenizer.pad_token_id = (\n                    config.eos_token_id[0]\n                    if isinstance(config.eos_token_id, list)\n                    else config.eos_token_id\n                )\n            elif tokenizer.eos_token_id is not None:\n                tokenizer.pad_token_id = tokenizer.eos_token_id\n            else:\n                tokenizer.pad_token_id = 0\n        super().__init__(\n            model_id=model_id,\n            model=model,\n            tokenizer=tokenizer,\n            requires_padding=False,\n            dtype=dtype,\n            device=device,\n            rank=rank,\n            world_size=world_size,\n            sliding_window=config.sliding_window,\n            support_chunking=support_chunking,\n        )\n\n    @property\n    def batch_type(self) -> Type[FlashCausalLMBatch]:\n        return FlashCausalLMBatch\n\n    def max_past(self) -> int:\n        return getattr(self.model, \"max_past\", None)\n\n    def init_kv_cache(\n        self,\n        num_blocks: int,\n        num_layers: int,\n        num_heads: int,\n        head_size: int,\n        dtype: torch.dtype,\n        device: torch.device,\n    ):\n        self.kv_cache = []\n        empty_cache()\n        if self.config.model_type in [\"deepseek_v3\", \"deepseek_v2\"]:\n            self.kv_cache = [\n                
KVCompressCache(\n                    num_blocks=num_blocks,\n                    head_size=self.config.kv_lora_rank + self.config.qk_rope_head_dim,\n                    dtype=dtype,\n                    device=device,\n                )\n                for _ in range(num_layers)\n            ]\n        else:\n            self.kv_cache = [\n                KVCache(\n                    num_blocks=num_blocks,\n                    num_heads=num_heads,\n                    head_size=head_size,\n                    dtype=dtype,\n                    device=device,\n                )\n                for _ in range(num_layers)\n            ]\n\n    def warmup(\n        self,\n        batch: FlashCausalLMBatch,\n        max_input_tokens: Optional[int],\n        max_total_tokens: Optional[int],\n    ):\n        if os.environ.get(\"MAX_BATCH_SIZE\") is None:\n            raise RuntimeError(\n                \"MAX_BATCH_SIZE is not set, it should be set in the launcher \"\n                \"using `--max-batch-size xxx`\"\n            )\n        # The warmup batch is the biggest batch we could ever receive\n        self.kv_cache = []\n        empty_cache()\n        self.graphed_buckets = set()\n        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)\n        # Calculate the number of blocks that can be allocated with the free memory\n        dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()\n        if self.config.model_type in [\"deepseek_v3\", \"deepseek_v2\"]:\n            cache_block_size = BLOCK_SIZE * (\n                self.config.kv_lora_rank + self.config.qk_rope_head_dim\n            )\n        else:\n            cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size\n            cache_block_size = cache_block_size * 2\n        total_cache_size = self.num_layers * cache_block_size * dtype_size\n        free_memory = get_free_memory(self.device, TGI_WIGGLE_ROOM)\n        self.mem_reserved = 
int(free_memory * (1 - MEMORY_FRACTION))\n        graph_reserved_mem = (\n            float(os.environ.get(\"TGI_GRAPH_RESERVED_MEM\", \"0.1\"))\n            if htorch.utils.internal.is_lazy()\n            else 0\n        )\n        mem_used_from_graph = int(\n            (free_memory - self.mem_reserved) * graph_reserved_mem\n        )\n        log_master(\n            logger.info,\n            f\"Free memory on device {self.device}: {format_bytes(free_memory)} used_for_graph: {format_bytes(mem_used_from_graph)} ratio {graph_reserved_mem} reserved_for_runtime: {format_bytes(self.mem_reserved)}\",\n        )\n        if max_total_tokens is None:\n            max_total_tokens = sum(batch.input_lengths)\n\n        if max_input_tokens is None:\n            max_input_tokens = max_total_tokens - 1\n\n        self.max_total_tokens = max_total_tokens\n        self.max_input_tokens = max_input_tokens\n        try:\n            self.init_kv_cache(\n                batch.num_blocks,\n                self.num_layers,\n                self.num_kv_heads,\n                self.head_size,\n                self.kv_cache_dtype,\n                self.device,\n            )\n\n            batch_num_blocks = batch.num_blocks\n\n            num_tokens = batch.to_pb().current_tokens\n            synchronize(self.device)\n            _, _batch, _ = self.generate_token([batch])\n        except Exception:\n            raise RuntimeError(\n                f\"Not enough memory to handle {num_tokens} prefill tokens. 
\"\n                f\"You need to decrease `--max-batch-prefill-tokens`\"\n            )\n\n        synchronize(self.device)\n        free_memory = get_free_memory(self.device, TGI_WIGGLE_ROOM)\n\n        kv_memory = free_memory - self.mem_reserved - mem_used_from_graph\n        num_blocks = (\n            # Leave 5% for some wiggle room\n            int(kv_memory // total_cache_size)\n            # Add batch.num_blocks as we allocated it above, so it is included in the peak memory.\n            + batch_num_blocks\n        )\n\n        log_master(logger.info, f\"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}\")\n\n        self.kv_cache = []\n        empty_cache()\n        self.init_kv_cache(\n            num_blocks,\n            self.num_layers,\n            self.num_kv_heads,\n            self.head_size,\n            self.kv_cache_dtype,\n            self.device,\n        )\n        self.max_batch_prefill_tokens = get_max_prefill_tokens()\n        max_num_seqs = int(os.getenv(\"MAX_BATCH_SIZE\"))\n        HPUBucketingContext = get_bucketing_context()\n        # need to warmup one more step since block is allocated from 1\n        block_step = os.getenv(\"VLLM_DECODE_BLOCK_BUCKET_STEP\", BLOCK_SIZE)\n        max_total_tokens_aligned = math.ceil(\n            max_total_tokens / BLOCK_SIZE\n        ) * BLOCK_SIZE + math.ceil(block_step * BLOCK_SIZE / max_num_seqs)\n        model_max_length = self.tokenizer.model_max_length\n        max_position_embeddings = getattr(\n            self.config, \"max_position_embeddings\", model_max_length\n        )\n\n        self.bucketing_ctx = HPUBucketingContext(\n            max_num_seqs,\n            max_num_seqs,  # self.max_num_prefill_seqs, #TODO\n            BLOCK_SIZE,\n            max_num_seqs * max_total_tokens_aligned,\n            False,\n            min(model_max_length, max_position_embeddings),\n            max_input_tokens,\n            max_total_tokens_aligned,\n        )\n        max_blocks = max(\n            
BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE\n        )\n        self.bucketing_ctx.num_hpu_blocks = min(max_blocks, num_blocks)\n        synchronize(self.device)\n        if self.skip_warmup:\n            self.bucketing_ctx.generate_prompt_buckets()\n            self.bucketing_ctx.generate_decode_buckets(\n                self.bucketing_ctx.num_hpu_blocks\n            )\n            log_master(\n                logger.info, \"skip warmup hpu graph, not recommmended, may cause OOM\"\n            )\n            del _batch, batch\n            return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens\n        self.warmup_hpu_graph(batch)\n        del _batch, batch\n\n        return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens\n\n    def log_warmup(self, prefilling, i, max_i, batch_size, seq_len):\n        free_mem = format_bytes(HabanaMemoryProfiler.current_free_device_memory())\n        phase = \"Prompt\" if prefilling else \"Decode\"\n        dim = \"seq_len\" if prefilling else \"num_blocks\"\n        graphed_bucket = (batch_size, seq_len, prefilling)\n        bypass = graphed_bucket not in self.graphed_buckets\n        msg = (\n            f\"[Warmup][{phase}][{i+1}/{max_i}] \"\n            f\"batch_size:{batch_size} \"\n            f\"{dim}:{seq_len} \"\n            f\"bypass:{bypass} \"\n            f\"free_mem:{free_mem}\"\n            \", this may take a while...\"\n        )\n        log_master(logger.info, msg)\n\n    def use_graphs(self, prefill, seq_len, batch_size):\n        if self.limit_hpu_graph and prefill:\n            return False\n\n        if self.skip_warmup:\n            return True\n\n        return (batch_size, seq_len, prefill) in self.graphed_buckets\n\n    def align_workers(self, value, op):\n        if self.world_size <= 1:\n            return value\n        value_t = torch.tensor(value, device=\"cpu\")\n        torch.distributed.all_reduce(value_t, op=op, 
group=self.process_group_cpu)\n        return value_t.item()\n\n    def warmup_hpu_graph(self, batch):\n        prompt_graph_mem_ratio = float(os.environ.get(\"VLLM_GRAPH_PROMPT_RATIO\", \"0.3\"))\n        free_mem = HabanaMemoryProfiler.current_free_device_memory()\n        graph_free_mem = free_mem - self.mem_reserved\n        graph_free_mem = self.align_workers(\n            graph_free_mem, torch.distributed.ReduceOp.MIN\n        )\n        prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem\n        decode_available_memory = graph_free_mem - prompt_available_memory\n        msg = (\n            f\"Using {format_bytes(graph_free_mem)}\"\n            f\"/{format_bytes(free_mem)} \"\n            \"of free device memory for HPUGraphs, \"\n            f\"{format_bytes(prompt_available_memory)} for prompt and \"\n            f\"{format_bytes(decode_available_memory)} for decode \"\n            f\"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})\"\n        )\n        log_master(logger.info, msg)\n        start_time = time.time()\n        warmup_shape_count = 0\n        warmup_times = 3\n        self.bucketing_ctx.generate_prompt_buckets()\n\n        def ordering_function_min_tokens(b):\n            return (b[0] * b[1], b[1], b[0])\n\n        buckets = list(\n            sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)\n        )\n        total_batch_seq = 0.001\n        total_mem = 0\n        available_mem = prompt_available_memory\n        msg = (\n            f\"Prefill batch size list:{[bsz[0] for bsz in buckets]}\\n\"\n            f\"Prefill sequence length list:{[seq[1] for seq in buckets]}\\n\"\n        )\n        log_master(logger.info, msg)\n        for i, (batch_size, seq_len) in enumerate(buckets):\n            if batch_size * seq_len > self.max_batch_prefill_tokens:\n                continue\n            # Graph memory usage is proportional to seq dimension in a batch\n            batch_seq = batch_size * seq_len\n 
           mem_estimate = batch_seq / total_batch_seq * total_mem\n            graphed_bucket = (batch_size, seq_len, True)\n            if not (\n                mem_estimate >= available_mem or batch_seq > self.max_seq_len_to_capture\n            ):\n                if graphed_bucket not in self.graphed_buckets:\n                    self.graphed_buckets.add(graphed_bucket)\n            warmup_shape_count += 1\n            self.log_warmup(True, i, len(buckets), batch_size, seq_len)\n            with HabanaMemoryProfiler() as mem_prof:\n                for index in range(warmup_times):\n                    self.warmup_prefill(seq_len, batch_size, batch)\n                    synchronize(self.device)\n            used_mem = self.align_workers(\n                mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX\n            )\n            if graphed_bucket in self.graphed_buckets:\n                available_mem -= used_mem\n                total_mem += used_mem\n                total_batch_seq += batch_seq\n\n        log_master(logger.info, \"Prefill warmup successful.\\n\")\n\n        def ordering_function_max_bs(b):\n            return (-b[0], b[1])\n\n        self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)\n        buckets = list(\n            sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)\n        )\n        free_mem = HabanaMemoryProfiler.current_free_device_memory()\n        total_batch_seq = 0.001\n        total_mem = 0\n        available_mem = free_mem - self.mem_reserved\n        log_master(\n            logger.info, f\"Decode batch size list:{[bsz[0] for bsz in buckets]}\\n\"\n        )\n        for i, (batch_size, block_num) in enumerate(buckets):\n            if batch_size > block_num:\n                continue\n            # Graph memory usage is proportional to seq dimension in a batch\n            batch_seq = batch_size\n            mem_estimate = batch_seq / total_batch_seq * 
total_mem\n            graphed_bucket = (batch_size, block_num, False)\n            if not mem_estimate >= available_mem:\n                if graphed_bucket not in self.graphed_buckets:\n                    self.graphed_buckets.add(graphed_bucket)\n            warmup_shape_count += 1\n            self.log_warmup(False, i, len(buckets), batch_size, block_num)\n            with HabanaMemoryProfiler() as mem_prof:\n                for index in range(warmup_times):\n                    self.warmup_decode(batch_size, block_num, batch)\n                    synchronize(self.device)\n            used_mem = self.align_workers(\n                mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX\n            )\n            if graphed_bucket in self.graphed_buckets:\n                available_mem -= used_mem\n                total_mem += used_mem\n                total_batch_seq += batch_seq\n\n        log_master(logger.info, \"Decode warmup successful.\\n\")\n\n        log_master(\n            logger.info,\n            f\"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}\",\n        )\n\n    def warmup_prefill(\n        self, prompt_len: int, batch_size: int, batch: FlashCausalLMBatch\n    ):\n        input_ids = torch.zeros(prompt_len, dtype=batch.input_ids.dtype).repeat(\n            batch_size\n        )\n        position_ids = torch.arange(prompt_len, dtype=batch.position_ids.dtype).repeat(\n            batch_size\n        )\n        max_bt = (prompt_len // BLOCK_SIZE + 1) * batch_size\n        block_tables = torch.arange(max_bt, dtype=torch.int32).reshape(batch_size, -1)\n        slot_acc = []\n        for i in range(batch_size):\n            slots = []\n            for b in block_tables[i]:\n                slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))\n            slot_acc.extend(slots[:prompt_len])\n        slots = torch.tensor(slot_acc, dtype=batch.slots.dtype)\n\n        input_lengths = 
torch.ones(batch_size, dtype=torch.int32) * prompt_len\n        cu_seqlen_prefill = torch.zeros(batch_size + 1, dtype=torch.int32)\n        torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])\n\n        seqlen = Seqlen(\n            input_lengths=_async_h2d_tensor_copy(input_lengths),\n        )\n        lm_head_indices = input_lengths - 1\n        kwargs = {}\n        if htorch.utils.internal.is_lazy():\n            kwargs[\"bypass_hpu_graphs\"] = not self.use_graphs(\n                True, prompt_len, batch_size\n            )\n        if self.sliding_window is not None:\n            attn_mask = seqlen.make_sliding_window_bias(\n                input_lengths.tolist(),\n                self.sliding_window,\n                self.dtype,\n                prompt_len,\n                batch_size,\n            )\n            seqlen.attn_mask = _async_h2d_tensor_copy(attn_mask)\n\n        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.\n        self.model.forward(\n            input_ids=_async_h2d_tensor_copy(input_ids),\n            position_ids=_async_h2d_tensor_copy(position_ids),\n            cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),\n            kv_cache=self.kv_cache,\n            slots=_async_h2d_tensor_copy(slots),\n            seqlen=trim_seqlen_metadata(seqlen),\n            lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),\n            adapter_data=None,\n            hpu_attention_meta=None,\n            **kwargs,\n        )\n\n    def warmup_decode(self, batch_size: int, block_num: int, batch: FlashCausalLMBatch):\n        input_ids = torch.zeros(batch_size, dtype=batch.input_ids.dtype)\n        position_ids = torch.arange(batch_size, dtype=batch.position_ids.dtype)\n        blocks = [block_num // batch_size for _ in range(batch_size)]\n        blocks[0] += block_num % batch_size\n        block_tables = []\n        slots = []\n        start_idx = 0\n        
slot_indices = []\n\n        # fetch the last blocked to warmup block num\n        for i in range(batch_size):\n            block_array = list(range(start_idx, start_idx + blocks[i]))\n            slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)\n            slot_indices.append((start_idx + blocks[i]) * BLOCK_SIZE - 1)\n            block_tables.append(block_array)\n            start_idx += blocks[i]\n        input_lengths = torch.ones(batch_size, dtype=torch.int32)\n        cu_seqlen_prefill = torch.zeros(batch_size + 1, dtype=torch.int32)\n        torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])\n\n        seqlen = Seqlen(\n            input_lengths=_async_h2d_tensor_copy(input_lengths),\n        )\n        block_list, block_groups, block_usage, _, block_bucket_size = (\n            generate_block_metadata(\n                self.dtype,\n                self.use_contiguous_pa,\n                slots,\n                block_tables,\n                self.bucketing_ctx,\n            )\n        )\n        meta = HPUPagedAttentionMetadata(\n            block_list=_async_h2d_tensor_copy(block_list),\n            block_groups=_async_h2d_tensor_copy(block_groups),\n            block_usage=_async_h2d_tensor_copy(block_usage),\n            block_mapping=None,\n            attn_bias=None,\n        )\n        if self.sliding_window is not None:\n            block_tables_in_window = []\n            for i, bt in enumerate(block_tables):\n                block_num_in_window = (\n                    self.sliding_window + BLOCK_SIZE - 1\n                ) // BLOCK_SIZE\n                block_tables_in_window.append(\n                    bt[max(0, blocks[i] - block_num_in_window) : blocks[i]]\n                )\n            slots_in_window = []\n            start_idx = 0\n            for i, indice in enumerate(slot_indices):\n                mask = (\n                    indice - torch.arange(start_idx, indice + 1)\n                ) < self.sliding_window\n      
          slots_in_window.append(torch.arange(start_idx, indice + 1)[mask])\n                start_idx += blocks[i] * BLOCK_SIZE\n            slots_in_window = torch.cat(slots_in_window, dim=0)\n            (\n                block_list_in_window,\n                block_groups_in_window,\n                block_usage_in_window,\n                slots_in_window_mask,\n                _,\n            ) = generate_block_metadata(\n                self.dtype,\n                self.use_contiguous_pa,\n                slots,\n                block_tables_in_window,\n                self.bucketing_ctx,\n                slots_in_window,\n                block_bucket_size,\n            )\n            meta.block_list_in_window = _async_h2d_tensor_copy(block_list_in_window)\n            meta.block_groups_in_window = _async_h2d_tensor_copy(block_groups_in_window)\n            meta.block_usage_in_window = _async_h2d_tensor_copy(block_usage_in_window)\n            meta.slots_in_window_mask = _async_h2d_tensor_copy(slots_in_window_mask)\n\n        hpu_attention_meta = trim_attn_metadata(meta)\n        slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)\n        kwargs = {}\n        if htorch.utils.internal.is_lazy():\n            kwargs[\"bypass_hpu_graphs\"] = not self.use_graphs(\n                False, hpu_attention_meta.block_list.shape[0], batch_size\n            )\n        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.\n        self.model.forward(\n            input_ids=_async_h2d_tensor_copy(input_ids),\n            position_ids=_async_h2d_tensor_copy(position_ids),\n            cu_seqlen_prefill=None,\n            kv_cache=self.kv_cache,\n            slots=_async_h2d_tensor_copy(slots_tensor),\n            seqlen=trim_seqlen_metadata(seqlen),\n            lm_head_indices=None,\n            adapter_data=None,\n            hpu_attention_meta=hpu_attention_meta,\n            **kwargs,\n        )\n\n    
def forward(\n        self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        # Model Forward\n        if batch.speculative_ids is not None:\n            input_ids = batch.input_ids\n            position_ids = batch.position_ids\n            cu_seqlen_prefill = batch.cu_seqlen_prefill\n            kv_cache = self.kv_cache\n            block_tables = batch.block_tables_tensor\n            slots = batch.slots[batch.slot_indices]\n            input_lengths = batch.input_lengths_tensor\n            max_s = batch.max_current_length\n            lm_head_indices = batch.prefill_head_indices\n\n            speculative_ids = batch.speculative_ids\n\n            B, speculative_length = speculative_ids.shape\n            new_length = speculative_length + 1\n            new_input_ids = torch.cat(\n                [input_ids.unsqueeze(-1), speculative_ids], dim=1\n            ).reshape(-1)\n            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)\n            arange_int = arange.to(dtype=torch.int32)\n            new_position_ids = (\n                position_ids.unsqueeze(-1).expand(B, new_length) + arange\n            ).view(-1)\n\n            # Slots can be discontiguous when prefix caching is enabled, so we need to expand the slot_indices,\n            # then update the slots with the additional indices to ensure we're grabbing the ones that have been\n            # allocated\n            slot_indices = (\n                batch.slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int\n            ).view(-1)\n            slots = batch.slots[slot_indices]\n\n            input_lengths = (\n                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int\n            ).view(-1)\n\n            # Add Copy the block tables for all members\n            block_tables = (\n                block_tables.unsqueeze(1)\n                .expand(B, new_length, -1)\n        
        .reshape(B * new_length, -1)\n                .contiguous()\n            )\n            max_s = max_s + speculative_length\n\n            input_ids = new_input_ids\n            position_ids = new_position_ids\n        else:\n            input_ids = batch.input_ids\n            position_ids = batch.position_ids\n            cu_seqlen_prefill = batch.cu_seqlen_prefill\n            kv_cache = self.kv_cache\n            block_tables = batch.block_tables_tensor\n            slots = batch.slots[batch.slot_indices]\n            input_lengths = batch.input_lengths_tensor\n            max_s = batch.max_current_length\n            lm_head_indices = batch.prefill_head_indices\n\n        if cu_seqlen_prefill is None and self.max_past() is not None:\n            # In decode, not prefill, we're actually overwriting the KV-cache\n            # in a circular buffer mode.\n            # This makes sure the max_s for the decode pass is correct.\n            max_s = min(self.max_past(), max_s)\n        if batch.prefill_cache_indices is not None:\n            slots_pad = torch.zeros_like(input_ids, device=slots.device)\n            slots_pad[batch.prefill_cache_indices] = slots\n            slots = slots_pad\n        else:\n            slots_pad = torch.zeros_like(input_ids, device=slots.device)\n            slots_pad[: slots.shape[0]] = slots\n            slots = slots_pad\n        seqlen = Seqlen(\n            input_lengths=_async_h2d_tensor_copy(input_lengths),\n        )\n\n        kwargs = {}\n        batch_size = input_lengths.shape[0]\n        prompt_len = (\n            input_ids.shape[0] // batch_size\n            if batch.prefilling\n            else batch.hpu_attn_meta.block_list.shape[0]\n        )\n        if htorch.utils.internal.is_lazy():\n            kwargs[\"bypass_hpu_graphs\"] = not self.use_graphs(\n                batch.prefilling, prompt_len, batch_size\n            )\n        if self.sliding_window is not None and batch.prefilling:\n            
attn_mask = seqlen.make_sliding_window_bias(\n                input_lengths.tolist(),\n                self.sliding_window,\n                self.dtype,\n                prompt_len,\n                batch_size,\n            )\n            seqlen.attn_mask = _async_h2d_tensor_copy(attn_mask)\n\n        logits, speculative_logits = self.model.forward(\n            input_ids=input_ids,\n            position_ids=_async_h2d_tensor_copy(position_ids),\n            cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),\n            kv_cache=kv_cache,\n            slots=_async_h2d_tensor_copy(slots),\n            seqlen=trim_seqlen_metadata(seqlen),\n            lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),\n            # TODO not support adapter now, need the add in the future\n            adapter_data=None,\n            hpu_attention_meta=batch.hpu_attn_meta,\n            **kwargs,\n        )\n        return logits, speculative_logits\n\n    @tracer.start_as_current_span(\"generate_token\")\n    def generate_token(\n        self, batches: List[FlashCausalLMBatch]\n    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:\n\n        # In order to pipeline any actions on CPU we perform the operation in 3 main stages:\n        # Stage 1. 
Collect next token ids of any previously started generations\n        start = time.time_ns()\n        prev_batches = []\n        requests_to_generate = []\n        for batch_id, batch in enumerate(batches):\n            if batch.next_token_logits is not None:\n                prefill = batch.prefilling\n                if batch.prefilling:\n                    batch.prefilling = False\n                    batch.prefilling_mask = [False] * len(batch)\n\n                speculate = get_speculate()\n                (\n                    next_input_ids,\n                    next_token_logprobs,\n                    logprobs,\n                    accepted_ids,\n                    speculative_ids,\n                ) = batch.next_token_chooser(\n                    batch.all_input_ids_tensor[\n                        : batch.next_token_logits.shape[0], : batch.max_current_length\n                    ],\n                    batch.next_token_logits,\n                    speculate,\n                    batch.speculative_ids,\n                    batch.speculative_logits,\n                )\n\n                batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(\n                    batch.top_n_tokens,\n                    _async_h2d_tensor_copy(batch.top_n_tokens_tensor),\n                    logprobs,\n                    accepted_ids,\n                )\n                if batch.valid_indices is not None:\n                    # TODO speculative decoding handling missing\n                    index = torch.arange(\n                        0,\n                        len(batch.valid_indices),\n                        device=batch.all_input_ids_tensor.device,\n                    )\n                    batch.all_input_ids_tensor.index_copy_(\n                        0, index, batch.all_input_ids_tensor[batch.valid_indices]\n                    )\n                    padded_total_bs = self.bucketing_ctx.get_padded_decode_batch_size(\n                        
len(batch.valid_indices)\n                    )\n                    next_input_ids.index_copy_(\n                        0, index, next_input_ids[batch.valid_indices]\n                    )\n                    next_input_ids = next_input_ids[:padded_total_bs]\n\n                    next_token_logprobs.index_copy_(\n                        0, index, next_token_logprobs[batch.valid_indices]\n                    )\n                    accepted_ids.index_copy_(\n                        0, index, accepted_ids[batch.valid_indices]\n                    )\n                    if speculative_ids is not None:\n                        speculative_ids = speculative_ids[batch.valid_indices]\n                    batch.top_n_tokens_tensor = batch.top_n_tokens_tensor[\n                        batch.valid_indices\n                    ]\n                    top_n_tokens = []\n                    batch_top_token_ids_v = []\n                    batch_top_token_logprobs_v = []\n                    for i in batch.valid_indices:\n                        top_n_tokens.append(batch.top_n_tokens[i])\n                        batch_top_token_ids_v.append(batch_top_token_ids[i])\n                        batch_top_token_logprobs_v.append(batch_top_token_logprobs[i])\n                    batch_top_token_ids = batch_top_token_ids_v\n                    batch_top_token_logprobs = batch_top_token_logprobs_v\n                    batch.top_n_tokens = top_n_tokens\n                    batch.next_token_chooser = batch.next_token_chooser.filter(\n                        batch.valid_indices\n                    )\n                    batch.valid_indices = None\n\n                # Since we are done prefilling, all the tensors that were concatenating values for all the requests\n                # instantly become of shape [BATCH_SIZE]\n                if prefill:\n                    indices = batch.cu_seqlen_prefill[1:] - 1\n                    # pad in left\n                    if 
batch.prefill_cache_indices is not None:\n                        batch.position_ids = batch.position_ids[\n                            batch.prefill_cache_indices\n                        ][indices]\n                    else:\n                        batch.position_ids = batch.position_ids[indices]\n\n                    batch.slot_indices = batch.slot_indices[indices[: len(batch)]]\n                    if batch.adapter_meta is not None:\n                        batch.adapter_meta.adapter_indices = (\n                            batch.adapter_meta.adapter_indices[indices]\n                        )\n                # For each member of the batch\n                # Cumulative length\n\n                if batch.speculative_logits is not None:\n                    cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)\n                    torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])\n                    for i in range(len(batch)):\n                        batch.all_input_ids_tensor[\n                            i,\n                            batch.cache_lengths[i]\n                            + batch.input_lengths[i] : batch.cache_lengths[i]\n                            + batch.input_lengths[i]\n                            + accepted_ids[i],\n                        ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]\n                    batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]\n                    accepted_ids = accepted_ids.cpu()\n                    if batch.position_ids.dim() == 2:\n                        # Qwen2_vl case:\n                        batch.position_ids += accepted_ids.unsqueeze(-1)\n                    else:\n                        batch.position_ids += accepted_ids\n                    batch.cache_lengths_tensor += (\n                        batch.input_lengths_tensor + accepted_ids - 1\n                    )\n                    batch.input_lengths_tensor = torch.ones_like(\n                
        batch.input_lengths_tensor\n                    )\n                    batch.slot_indices += accepted_ids[: len(batch)]\n                else:\n                    index = batch.cache_lengths_tensor + batch.input_lengths_tensor\n                    index = F.pad(\n                        index, (0, next_input_ids.shape[0] - index.shape[0]), value=0\n                    )\n                    index = index.to(batch.all_input_ids_tensor.device)\n                    batch_idx = torch.arange(\n                        0,\n                        index.shape[0],\n                        dtype=torch.long,\n                        device=batch.all_input_ids_tensor.device,\n                    )\n                    batch.all_input_ids_tensor.index_put_(\n                        (batch_idx, index.long()), next_input_ids\n                    )\n                    batch.input_ids = next_input_ids\n                    batch.position_ids += 1\n                    batch.cache_lengths_tensor += batch.input_lengths_tensor\n                    batch.input_lengths_tensor = torch.ones_like(\n                        batch.input_lengths_tensor\n                    )\n                    batch.slot_indices += 1\n\n                batch.speculative_ids = speculative_ids\n\n                # Does a HPU <-> CPU sync internally\n                if prefill and batch.adapter_meta is not None:\n                    # adjust segment lengths to account for all request lengths being 1 during decoding\n                    adapter_segments, _ = find_segments(\n                        batch.adapter_meta.adapter_indices\n                    )\n                    batch.adapter_meta.adapter_segments = torch.tensor(\n                        adapter_segments,\n                        dtype=torch.int32,\n                        device=batch.adapter_meta.adapter_segments.device,\n                    )\n                prev_batches.append(\n                    {\n                        
\"next_token_ids\": next_input_ids,\n                        \"next_token_logprobs\": next_token_logprobs,\n                        \"accepted_ids\": accepted_ids,\n                    }\n                )\n                idx = len(prev_batches) - 1\n\n                for req_idx, req in enumerate(batch.requests):\n                    new_input_length = 1\n                    if batch.speculative_logits is not None:\n                        new_cache_length = (\n                            batch.cache_lengths[req_idx]\n                            + batch.input_lengths[req_idx]\n                            + accepted_ids[req_idx]\n                            - 1\n                        )\n                    else:\n                        new_cache_length = (\n                            batch.cache_lengths[req_idx] + batch.input_lengths[req_idx]\n                        )\n                    batch.cache_lengths[req_idx] = new_cache_length\n                    batch.max_input_length = max(\n                        batch.max_input_length, new_input_length\n                    )\n                    batch.input_lengths[req_idx] = new_input_length\n                    current_length = new_cache_length + new_input_length\n                    batch.max_current_length = max(\n                        batch.max_current_length, current_length\n                    )\n\n                    requests_to_generate.append(\n                        {\n                            \"idx\": idx,\n                            \"request_id\": req.id,\n                            \"prefix_offset\": batch.prefix_offsets[req_idx],\n                            \"read_offset\": batch.read_offsets[req_idx],\n                            \"stopping_criteria\": batch.stopping_criterias[req_idx],\n                            \"all_input_ids\": batch.all_input_ids[req_idx],\n                            \"do_sample\": batch.next_token_chooser.do_sample[req_idx],\n                            
\"seed\": batch.next_token_chooser.seeds[req_idx],\n                            \"top_n_tokens\": batch.top_n_tokens[req_idx],\n                            \"top_token_ids\": batch_top_token_ids[req_idx],\n                            \"top_token_logprobs\": batch_top_token_logprobs[req_idx],\n                        }\n                    )\n                if prefill:\n                    # We do not need prefill tensors anymore\n                    batch.cu_seqlen_prefill = None\n                    batch.prefill_cache_indices = None\n                    batch.prefill_cu_outlens = None\n                    batch.prefill_head_indices = None\n                    batch.prefill_next_token_indices = None\n                batch.next_token_logits = None\n                batch.speculative_ids = None\n\n        htorch.core.mark_step()\n        # Stage 2. Prepare new batch for speculative scheduling\n        if len(batches) > 1:\n            if self.bucketing_ctx is not None:\n                total_batch_size = 0\n                for b in batches:\n                    total_batch_size += len(b)\n                padded_total_bs = self.bucketing_ctx.get_padded_decode_batch_size(\n                    total_batch_size\n                )\n                batch = self.batch_type.concatenate(\n                    batches, padded_total_bs=padded_total_bs\n                )\n            else:\n                batch = self.batch_type.concatenate(batches)\n        else:\n            batch = batches[0]\n        prefill = batch.prefilling\n        if prefill:\n            if self.bucketing_ctx is not None:\n                batch.prepare_for_prefill(\n                    self.bucketing_ctx.get_padded_prompt_seq_len(\n                        batch.max_input_length\n                    ),\n                    self.bucketing_ctx.get_padded_prompt_batch_size(len(batch)),\n                    self.max_total_tokens,\n                    self.tokenizer.pad_token_id,\n                )\n        
    else:\n                batch.prepare_for_prefill(\n                    batch.max_input_length,\n                    len(batch),\n                    self.max_total_tokens,\n                    self.tokenizer.pad_token_id,\n                )\n        else:\n            batch.prepare_for_decode(\n                self.dtype,\n                self.use_contiguous_pa,\n                self.bucketing_ctx,\n                self.tokenizer.pad_token_id,\n                self.sliding_window,\n            )\n        if hasattr(self, \"set_inputs_embeds\") and callable(self.set_inputs_embeds):\n            self.set_inputs_embeds(batch)\n        prefill_logprobs = batch.prefill_next_token_indices is not None\n        # Update adapter indices for speculative tokens (if present)\n        adapter_meta = batch.adapter_meta\n        if adapter_meta is not None:\n            if batch.speculative_ids is not None:\n                B, speculative_length = batch.speculative_ids.shape\n                new_length = speculative_length + 1\n                adapter_indices = (\n                    adapter_meta.adapter_indices.unsqueeze(-1)\n                    .expand(B, new_length)\n                    .reshape(-1)\n                )\n                adapter_segments = adapter_meta.adapter_segments * new_length\n                adapter_meta = AdapterBatchMetadata(\n                    adapter_indices=adapter_indices,\n                    adapter_set=adapter_meta.adapter_set,\n                    adapter_segments=adapter_segments,\n                    segment_indices=adapter_meta.segment_indices,\n                )\n\n            # Assign pointers to adapter weights\n            # TODO(travis): don't update this if indices haven't changed\n            adapter_data = AdapterBatchData.from_meta(\n                adapter_meta,\n                self.layer_to_adapter_weights,\n                prefill,\n                batch.prefill_head_indices,\n            )\n        else:\n            
adapter_data = None\n\n        out, speculative_logits = self.forward(batch, adapter_data)\n\n        if prefill:\n            batch.next_token_logits = (\n                out[batch.prefill_next_token_indices] if prefill_logprobs else out\n            )\n            if speculative_logits is not None:\n                speculative_logits = (\n                    speculative_logits[batch.prefill_next_token_indices]\n                    if prefill_logprobs\n                    else speculative_logits\n                )\n        else:\n            prefill_logprobs = None\n            batch.next_token_logits = out\n        batch.speculative_logits = speculative_logits\n\n        # HPU->CPU sync\n        htorch.core.mark_step()\n        start_decode = time.time_ns()\n        for prev_batch in prev_batches:\n            prev_batch[\"next_token_logprobs\"] = prev_batch[\n                \"next_token_logprobs\"\n            ].tolist()\n            prev_batch[\"next_token_ids\"] = prev_batch[\"next_token_ids\"].tolist()\n            prev_batch[\"accepted_ids\"] = prev_batch[\"accepted_ids\"].tolist()\n        htorch.core.mark_step()\n        # Stage 3. 
Finish and return previous generations\n        # Results\n        generations: List[Generation] = []\n        stopped = len(requests_to_generate) > 0\n        # Reset max_input_length\n        batch.max_input_length = 0\n        # For each member of the batch\n        indexs = [0] * len(prev_batches)\n        idx_accept_ids = [0] * len(prev_batches)\n        for i, req_data in enumerate(requests_to_generate):\n            idx = req_data[\"idx\"]\n            request_id = req_data[\"request_id\"]\n            prefix_offset = req_data[\"prefix_offset\"]\n            read_offset = req_data[\"read_offset\"]\n            stopping_criteria = req_data[\"stopping_criteria\"]\n            all_input_ids = req_data[\"all_input_ids\"]\n            do_sample = req_data[\"do_sample\"]\n            seed = req_data[\"seed\"]\n            top_n_tokens = req_data[\"top_n_tokens\"]\n            n_accepted_ids = prev_batches[idx][\"accepted_ids\"][idx_accept_ids[idx]]\n            top_token_ids = req_data[\"top_token_ids\"]\n            top_token_logprobs = req_data[\"top_token_logprobs\"]\n            # Append next token to all tokens\n            next_token_texts = []\n            left = 0\n\n            if n_accepted_ids > 1:\n                log_master(logger.debug, f\"speculated ids {n_accepted_ids - 1}\")\n\n            current_stopped = False\n            index = indexs[idx]\n            for j in range(index, index + n_accepted_ids):\n                # Generated token\n                next_token_id = prev_batches[idx][\"next_token_ids\"][j]\n                all_input_ids.append(next_token_id)\n                next_token_text, prefix_offset, read_offset = self.decode_token(\n                    all_input_ids,\n                    prefix_offset,\n                    read_offset,\n                )\n                next_token_texts.append(next_token_text)\n\n                stop, reason = stopping_criteria(\n                    next_token_id,\n                    
next_token_text,\n                )\n\n                if stop:\n                    left = index + n_accepted_ids - j - 1\n                    current_stopped = True\n                    break\n                else:\n                    current_stopped = False\n            stopped = stopped and current_stopped\n\n            _next_token_ids = prev_batches[idx][\"next_token_ids\"][\n                index : index + n_accepted_ids - left\n            ]\n            _next_token_logprobs = prev_batches[idx][\"next_token_logprobs\"][\n                index : index + n_accepted_ids - left\n            ]\n\n            # Shard generations\n            # All generations will be appended in the rust sharded client\n            if request_id % self.world_size == self.rank:\n                if stop:\n                    # Decode generated tokens\n                    output_text, _, _ = self.decode_token(\n                        all_input_ids,\n                        prefix_offset=len(all_input_ids)\n                        - stopping_criteria.current_tokens\n                        - 1,\n                        read_offset=len(all_input_ids)\n                        - stopping_criteria.current_tokens,\n                        skip_special_tokens=True,\n                    )\n                    generated_text = GeneratedText(\n                        output_text,\n                        stopping_criteria.current_tokens,\n                        reason,\n                        seed if do_sample else None,\n                    )\n                else:\n                    generated_text = None\n\n                if top_n_tokens > 0:\n                    all_top_tokens = []\n                    for top_token_ids, top_token_logprobs in zip(\n                        top_token_ids, top_token_logprobs\n                    ):\n                        toptoken_texts = self.tokenizer.batch_decode(\n                            top_token_ids,\n                            
clean_up_tokenization_spaces=False,\n                            skip_special_tokens=False,\n                        )\n                        special_toptokens = [\n                            token_id in self.all_special_ids\n                            for token_id in top_token_ids\n                        ]\n                        top_tokens = Tokens(\n                            top_token_ids,\n                            top_token_logprobs,\n                            toptoken_texts,\n                            special_toptokens,\n                        )\n                        all_top_tokens.append(top_tokens)\n                    top_tokens = all_top_tokens\n                else:\n                    top_tokens = None\n\n                generation = Generation(\n                    request_id,\n                    None,\n                    Tokens(\n                        _next_token_ids,\n                        _next_token_logprobs,\n                        next_token_texts,\n                        [nid in self.all_special_ids for nid in _next_token_ids],\n                    ),\n                    generated_text,\n                    top_tokens,\n                )\n\n                generations.append(generation)\n\n                # accept each new token for this specific request since we may\n                # have more than one new token per request with speculative decoding\n                for next_token_id in _next_token_ids:\n                    batch.next_token_chooser = (\n                        batch.next_token_chooser.advance_grammar_single(\n                            i, next_token_id\n                        )\n                    )\n\n            # Update values\n            indexs[idx] += n_accepted_ids\n            idx_accept_ids[idx] += 1\n\n            batch.prefix_offsets[i] = prefix_offset\n            batch.read_offsets[i] = read_offset\n            batch.all_input_ids[i] = all_input_ids\n        htorch.core.mark_step()\n 
       if stopped:\n            # No need to return a batch if we know that all requests stopped\n            forward_ns = start_decode - start\n            decode_ns = time.time_ns() - start_decode\n            return generations, None, (forward_ns, decode_ns)\n\n        forward_ns = start_decode - start\n        decode_ns = time.time_ns() - start_decode\n        return generations, batch, (forward_ns, decode_ns)\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/flash_vlm_causal_lm.py",
    "content": "import torch\nfrom PIL import Image\nfrom io import BytesIO\nfrom dataclasses import dataclass\nfrom opentelemetry import trace\nfrom typing import Iterable, Optional, Tuple, List, Type, Dict\n\nfrom transformers import PreTrainedTokenizerBase\nfrom transformers.image_processing_utils import select_best_resolution\nfrom text_generation_server.pb import generate_pb2\nfrom text_generation_server.models.flash_causal_lm import (\n    FlashCausalLMBatch,\n    FlashCausalLM,\n    generate_block_metadata,\n)\nfrom text_generation_server.models.globals import PREFIX_CACHING, BLOCK_SIZE\nfrom loguru import logger\nfrom text_generation_server.utils.log import log_master\nfrom transformers import AutoProcessor\nfrom text_generation_server.layers.attention import (\n    Seqlen,\n    trim_seqlen_metadata,\n    _async_h2d_tensor_copy,\n    HPUPagedAttentionMetadata,\n    trim_attn_metadata,\n)\nimport habana_frameworks.torch as htorch\nimport time\nfrom text_generation_server.utils.import_utils import (\n    synchronize,\n)\nfrom vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes\n\ntracer = trace.get_tracer(__name__)\n\nIDEFICS2_FAKE_TOKEN = \"<fake_token_around_image>\"\nIDEFICS2_IMAGE_TOKEN = \"<image>\"\n\nIDEFICS3_IMAGE_TOKEN = \"<image>\"\nIDEFICS3_FAKE_IMAGE_TOKEN = \"<fake_token_around_image>\"\nIDEFICS3_GLOBAL_IMG_TOKEN = \"<global-img>\"\n\n\ndef prompt_split_image_llama4(aspect_ratio, num_patches_per_chunk):\n    \"\"\"\n    Create a structured string representation of image tokens\n\n    Args:\n       num_patches: Number of patches in the image\n\n    Returns:\n        String with appropriate image tokens\n    \"\"\"\n    img_string = \"<|image_start|>\"\n    ratio_h, ratio_w = aspect_ratio\n    if ratio_h * ratio_w > 1:\n        for yy in range(ratio_h):\n            for xx in range(ratio_w):\n                img_string += \"<|patch|>\" * num_patches_per_chunk\n                if xx < ratio_w - 1:\n                    img_string += 
\"<|tile_x_separator|>\"\n\n            img_string += \"<|tile_y_separator|>\"\n    img_string += \"<|image|>\"\n    img_string += \"<|patch|>\" * num_patches_per_chunk\n    img_string += \"<|image_end|>\"\n\n    return img_string\n\n\n# copied from: https://github.com/huggingface/transformers/blob/02ed609285c2448b3b54c31e362f2c389fa952ab/src/transformers/models/idefics3/processing_idefics3.py#L44-L60\ndef _prompt_split_image(\n    *,\n    image_seq_len: int,\n    image_rows: int,\n    image_cols: int,\n    fake_token_around_image: str,\n    image_token: str,\n    global_img_token: str,\n):\n    \"\"\"Prompt with expanded image tokens for when the image is split into patches.\"\"\"\n    text_split_images = \"\"\n    for n_h in range(image_rows):\n        for n_w in range(image_cols):\n            text_split_images += (\n                f\"{fake_token_around_image}\"\n                + f\"<row_{n_h + 1}_col_{n_w + 1}>\"\n                + f\"{image_token}\" * image_seq_len\n            )\n        text_split_images += \"\\n\"\n\n    text_split_images += (\n        f\"\\n{fake_token_around_image}\"\n        + f\"{global_img_token}\"\n        + f\"{image_token}\" * image_seq_len\n        + f\"{fake_token_around_image}\"\n    )\n    return text_split_images\n\n\ndef get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):\n    \"\"\"\n    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.\n\n    Args:\n        image_size (`tuple`):\n            The size of the input image in the format (height, width).\n        grid_pinpoints (`List`):\n            A list containing possible resolutions. 
Each item in the list should be a tuple or list\n            of the form `(height, width)`.\n        patch_size (`int`):\n            The size of each image patch.\n\n    Returns:\n        tuple: The shape of the image patch grid in the format (width, height).\n    \"\"\"\n    if not isinstance(grid_pinpoints, list):\n        raise ValueError(\"grid_pinpoints should be a list of tuples or lists\")\n\n    height, width = select_best_resolution(image_size, grid_pinpoints)\n    return height // patch_size, width // patch_size\n\n\ndef image_text_replacement(processor, image_input, config) -> str:\n    if config.model_type == \"idefics2\":\n        image_seq_len = 64\n        image_str = f\"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}\"\n        if processor.image_processor.do_image_splitting:\n            image_str *= 5\n        return image_str, IDEFICS2_FAKE_TOKEN\n    if config.model_type == \"idefics3\":\n        # TODO: implement this in a more general way\n        n_rows = image_input[\"rows\"][0][0]\n        n_cols = image_input[\"cols\"][0][0]\n        image_seq_len = int(\n            ((config.vision_config.image_size // config.vision_config.patch_size) ** 2)\n            / (config.scale_factor**2)\n        )\n        image_str = _prompt_split_image(\n            image_seq_len=image_seq_len,\n            image_rows=n_rows,\n            image_cols=n_cols,\n            fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN,\n            image_token=IDEFICS3_IMAGE_TOKEN,\n            global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,\n        )\n        return image_str, IDEFICS3_FAKE_IMAGE_TOKEN\n    elif config.model_type == \"llava_next\":\n        height, width = image_input[\"image_sizes\"][0]\n        num_features = get_number_of_features(height, width, config)\n\n        log_master(\n            logger.info,\n            f\"Found {num_features} features in image of resolution {height}x{width}\",\n        )\n        return 
\"<image>\" * num_features, \"<image>\"\n\n    elif config.model_type == \"paligemma\":\n        return \"<image>\" * config.text_config.num_image_tokens, \"<image>\"\n    elif config.model_type == \"qwen2_vl\":\n        grid_t, grid_h, grid_w = image_input[\"image_grid_thw\"][0]\n        num_pads = grid_t * grid_h * grid_w // 4\n        padding = \"<|image_pad|>\" * num_pads\n        return f\"<|vision_start|>{padding}<|vision_end|>\", \"<|vision_start|>\"\n    elif config.model_type == \"qwen2_5_vl\":\n        grid_t, grid_h, grid_w = image_input[\"image_grid_thw\"][0]\n        num_pads = grid_t * grid_h * grid_w // 4\n        padding = \"<|image_pad|>\" * num_pads\n        return f\"<|vision_start|>{padding}<|vision_end|>\", \"<|vision_start|>\"\n    elif config.model_type == \"gemma3\":\n        # TODO: get correct number of features via reviewing the Gemma3 architecture\n        # and calculating the number of image tokens\n        num_pads = 256\n        padding = \"<image_soft_token>\" * num_pads\n        return f\"\\n\\n<start_of_image>{padding}<end_of_image>\\n\\n\", \"<start_of_image>\"\n    elif config.model_type == \"llama4\":\n        patch_size = config.vision_config.patch_size\n        pixel_shuffle_ratio = config.vision_config.pixel_shuffle_ratio\n        downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))\n        aspect_ratios = image_input[\"aspect_ratios\"][0]\n        image_height, image_width = image_input[\"pixel_values\"][0].shape[-2:]\n\n        num_patches_per_chunk = int(\n            (image_height // patch_size)\n            * (image_width // patch_size)\n            // downsample_ratio\n        )\n        tokens_for_this_image = prompt_split_image_llama4(\n            aspect_ratios, num_patches_per_chunk\n        )\n\n        return tokens_for_this_image, \"<|image_start|>\"\n    else:\n        raise RuntimeError(f\"Unknown config {config.model_type} for multimodal\")\n\n\ndef image_text_replacement_fixup(config, text: str) -> 
str:\n    if config.model_type == \"idefics2\":\n        return text.replace(\n            f\"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}\", IDEFICS2_FAKE_TOKEN\n        )\n    return text\n\n\ndef preprocess_text(config, text: str) -> str:\n    if config.model_type == \"paligemma\":\n        return \"<bos>\" + text + \"\\n\"\n    return text\n\n\ndef preprocess_image(config, img):\n    model_type = config.model_type\n\n    if model_type in {\"qwen2_vl\", \"qwen2_5_vl\"} and img.width <= 20:\n        img = img.resize((img.width * 2, img.height * 2))\n    if model_type == \"paligemma\":\n        img = img.convert(\"RGB\")\n\n    if model_type not in {\"llava_next\", \"gemma3\", \"llama4\"}:\n        # TODO: check if this is needed\n        img = [img]\n\n    return img\n\n\ndef get_unpadded_features(\n    original_height: int,\n    original_width: int,\n    npatches: int,\n    num_patch_height: int,\n    num_patch_width: int,\n) -> Tuple[int, int]:\n    current_height = npatches * num_patch_height\n    current_width = npatches * num_patch_width\n\n    aspect_ratio: float = original_width / original_height\n    current_aspect_ratio: float = current_width / current_height\n\n    if aspect_ratio > current_aspect_ratio:\n        new_height = (original_height * current_width) // original_width\n        padding = (current_height - new_height) // 2\n        current_height = current_height - (2 * padding)\n    else:\n        new_width = (original_width * current_height) // original_height\n        padding = (current_width - new_width) // 2\n        current_width = current_width - (2 * padding)\n\n    unpadded_features = current_height * current_width\n    newline_features = current_height\n    return (unpadded_features, newline_features)\n\n\ndef get_number_of_features(height: int, width: int, config) -> int:\n    # From config\n    # Hardcoded for CLIP for now\n    # image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]\n    
image_grid_pinpoints = config.image_grid_pinpoints\n    image_size = config.vision_config.image_size\n    patch_size = config.vision_config.patch_size\n\n    assert image_size % patch_size == 0\n\n    npatches = image_size // patch_size\n\n    # Dimensions are intentionally swapped to be bug-compatible with\n    # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59\n    num_patch_width, num_patch_height = get_anyres_image_grid_shape(\n        [height, width],\n        image_grid_pinpoints,\n        image_size,\n    )\n    unpadded_features, newline_features = get_unpadded_features(\n        height, width, npatches, num_patch_height, num_patch_width\n    )\n    # The base patch covers the entire image\n    base_features = npatches**2\n    return unpadded_features + newline_features + base_features\n\n\ndef scatter_image_embeds(\n    embeds: torch.Tensor, is_embed: Optional[torch.Tensor]\n) -> torch.Tensor:\n    if is_embed is None:\n        return embeds\n\n    placeholders = embeds.new_full(\n        (is_embed.shape[0], embeds.shape[-1]),\n        fill_value=torch.nan,\n    )\n    placeholders[is_embed.to(embeds.device)] = embeds\n    return placeholders\n\n\ndef gather_image_embeds(\n    embeds: torch.Tensor, is_embed: Optional[torch.Tensor]\n) -> Optional[torch.Tensor]:\n    if is_embed is None:\n        return embeds\n    sel = embeds[is_embed.to(embeds.device)]\n    return sel if sel.numel() else None\n\n\n@dataclass\nclass ImagePositions:\n    offset: int\n    length: int\n    id: int\n    num_placeholder_tokens: int\n    is_embed: Optional[torch.Tensor] = None\n\n\nclass FlashVlmCausalLMBatch(FlashCausalLMBatch):\n    image_inputs: Optional[List[List[Dict[str, torch.Tensor]]]]\n    image_positions: Optional[List[List[ImagePositions]]]\n    encoder_cache: Optional[List[Dict[int, torch.Tensor]]]\n    pixel_values: Optional[List[torch.Tensor]]\n    pixel_attention_mask: Optional[List[torch.Tensor]]\n    image_sizes: Optional[List[Tuple[int, int]]]\n    
image_grid_thw: Optional[torch.Tensor]\n    cache_entries_to_free: List[Tuple[int, int]]\n    has_image_inputs: bool = False\n    inputs_embeds: Optional[torch.Tensor] = None\n\n    @classmethod\n    @tracer.start_as_current_span(\"concatenate\")\n    def concatenate(cls, batches, padded_total_bs: int = 0):\n        batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches, padded_total_bs)\n        batch.image_inputs = []\n        batch.image_positions = []\n        batch.encoder_cache = []\n        for b in batches:\n            if b.image_inputs is not None:\n                batch.image_inputs.extend(b.image_inputs)\n            else:\n                batch.image_inputs.append(None)\n            if b.image_positions is not None:\n                batch.image_positions.extend(b.image_positions)\n            else:\n                batch.image_positions.append(None)\n            if b.encoder_cache is not None:\n                batch.encoder_cache.extend(b.encoder_cache)\n            else:\n                batch.encoder_cache.append(None)\n\n        batch.pixel_values = None\n        batch.pixel_attention_mask = None\n        batch.image_sizes = None\n        batch.image_grid_thw = None\n        batch.inputs_embeds = None\n        # To be filled in prepare_for_prefill\n        batch.has_image_inputs = False\n        batch.cache_entries_to_free = []\n        return batch\n\n    @tracer.start_as_current_span(\"filter\")\n    def filter(self, request_ids: List[int]):\n        if len(request_ids) == 0:\n            raise ValueError(\"Batch must have at least one request\")\n\n        image_inputs = []\n        image_positions = []\n        encoder_cache = []\n\n        for request_id in request_ids:\n            idx = self.requests_idx_mapping[request_id]\n            image_inputs.append(self.image_inputs[idx])\n            image_positions.append(self.image_positions[idx])\n            encoder_cache.append(self.encoder_cache[idx])\n\n        batch = 
super().filter(request_ids)\n        batch.pixel_values = None\n        batch.pixel_attention_mask = None\n        batch.image_sizes = None\n        batch.image_grid_thw = None\n        batch.inputs_embeds = None\n        batch.image_inputs = image_inputs\n        batch.image_positions = image_positions\n        batch.encoder_cache = encoder_cache\n\n        # To be filled in prepare_for_prefill\n        batch.has_image_inputs = False\n        batch.cache_entries_to_free = []\n        return batch\n\n    @classmethod\n    def batch_tokenized_inputs(\n        cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config\n    ):\n        kwargs = {}\n        if (\n            hasattr(processor, \"image_processor_class\")\n            and processor.image_processor_class == \"Idefics3ImageProcessor\"\n        ):\n            kwargs[\"return_row_col_info\"] = True\n\n        max_length = 0\n        vocab = tokenizer.get_vocab()\n\n        if not hasattr(config, \"image_token_index\"):\n            config.image_token_index = config.image_token_id\n\n        batch_tokenized_inputs: List[List[int]] = []\n        batch_image_inputs: List[Optional[List[dict]]] = []\n        batch_image_positions: List[Optional[List[ImagePositions]]] = []\n\n        for r in requests:\n            text_parts = []\n            image_inputs = []\n            image_texts = []\n\n            image_id = 0\n\n            for chunk in r.input_chunks.chunks:\n                chunk_type = chunk.WhichOneof(\"chunk\")\n                if chunk_type == \"text\":\n                    text = preprocess_text(config, chunk.text)\n                    text_parts.append(text)\n                elif chunk_type == \"image\":\n                    img = Image.open(BytesIO(chunk.image.data))\n                    img = preprocess_image(config, img)\n\n                    image_input = processor.image_processor(\n                        [img], return_tensors=\"pt\", **kwargs\n                    )\n       
             image_inputs.append(image_input)\n\n                    img_text, img_start_token_str = image_text_replacement(\n                        processor, image_input, config\n                    )\n                    text_parts.append(img_text)\n\n                    image_texts.append([image_id, img_start_token_str, img_text])\n                    image_id += 1\n                else:\n                    raise RuntimeError(f\"Invalid chunk type {chunk_type}\")\n\n            full_text = image_text_replacement_fixup(config, \"\".join(text_parts))\n            input_ids = tokenizer(\n                full_text,\n                truncation=True,\n                max_length=r.truncate,\n                add_special_tokens=(\n                    r.add_special_tokens if config.model_type != \"paligemma\" else False\n                ),\n            )[\"input_ids\"]\n            max_length = max(max_length, len(input_ids))\n\n            if len(image_inputs) > 0:\n                img_start_token = vocab[image_texts[0][1]]\n                image_positions = cls.get_image_positions(\n                    input_ids, image_texts, img_start_token, config, tokenizer\n                )\n            else:\n                image_inputs = None\n                image_positions = None\n\n            batch_tokenized_inputs.append(input_ids)\n            batch_image_inputs.append(image_inputs)\n            batch_image_positions.append(image_positions)\n\n        return batch_tokenized_inputs, batch_image_inputs, batch_image_positions\n\n    @classmethod\n    def get_image_positions(\n        cls,\n        input_ids: List[int],\n        image_texts: List[Tuple[int, str, str]],\n        img_start_token: int,\n        config,\n        tokenizer: PreTrainedTokenizerBase,\n    ) -> List[ImagePositions]:\n        image_positions = []\n        num_images = len(image_texts)\n\n        input_ids_t = torch.as_tensor(input_ids)\n        img_start_token_pos = 
torch.where(input_ids_t.eq(img_start_token))[0]\n        num_tokens = input_ids_t.numel()\n\n        last_pos = 0\n        for i in range(num_images):\n            image_id, img_start_token_str, img_text = image_texts[i]\n            img_text = image_text_replacement_fixup(config, img_text)\n\n            if config.model_type == \"gemma3\":\n                img_text = img_text.replace(\"\\n\\n\", \"\")\n\n            tokens = tokenizer(img_text, add_special_tokens=False, return_tensors=\"pt\")[\n                \"input_ids\"\n            ][0]\n            length = tokens.numel()\n\n            assert (\n                length <= num_tokens\n            ), f\"{length} > {num_tokens} Image is truncated, try increasing --max-batch-prefill-tokens\"\n\n            pos = torch.searchsorted(img_start_token_pos, last_pos, right=False)\n            index = img_start_token_pos[pos]\n            assert torch.equal(\n                input_ids_t[index : index + length], tokens\n            ), \"Image tokens not found in input_ids\"\n\n            is_embed = tokens == config.image_token_index\n            num_placeholder_tokens = int(is_embed.sum())\n            if num_placeholder_tokens == length:\n                is_embed = None\n\n            pos = ImagePositions(\n                offset=index,\n                length=length,\n                id=image_id,\n                num_placeholder_tokens=num_placeholder_tokens,\n                is_embed=is_embed,\n            )\n\n            image_positions.append(pos)\n            last_pos = index + length\n\n            if (\n                config.model_type == \"idefics2\"\n                and i + 1 != num_images\n                and input_ids[last_pos] == config.image_token_index\n            ):\n                fake_token = last_pos - 1\n                fake_token_index = torch.searchsorted(\n                    img_start_token_pos, fake_token, right=False\n                )\n                img_start_token_pos[fake_token_index] 
= last_pos\n                image_texts[i + 1][2] = image_texts[i + 1][2][\n                    len(img_start_token_str) :\n                ]\n\n        return image_positions\n\n    @classmethod\n    def from_pb_processor(\n        cls,\n        pb: generate_pb2.Batch,\n        tokenizer: PreTrainedTokenizerBase,\n        processor,\n        config,\n        dtype: torch.dtype,\n        device: torch.device,\n    ) -> \"FlashVlmCausalLMBatch\":\n        batch_tokenized_inputs, image_inputs, image_positions = (\n            cls.batch_tokenized_inputs(pb.requests, tokenizer, processor, config)\n        )\n        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)\n        batch.image_inputs = image_inputs\n        batch.image_positions = image_positions\n        batch.encoder_cache = [{} for _ in range(len(pb.requests))]\n        if len(image_inputs):\n            batch.pixel_values = None\n            batch.pixel_attention_mask = None\n            batch.image_sizes = None\n            batch.image_grid_thw = None\n        return batch\n\n    def prepare_for_prefill(\n        self, max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id\n    ):\n        super().prepare_for_prefill(\n            max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id\n        )\n\n        self.has_image_inputs = False\n        self.cache_entries_to_free = []\n\n        self.pixel_values = []\n\n        assert (\n            len(self.cache_lengths)\n            == len(self.input_lengths)\n            == len(self.prefilling_mask)\n        ), \"Mismatch in lengths of cache_lengths, input_lengths, and prefilling_mask\"\n\n        for i, (\n            cache_length,\n            input_length,\n            request_prefilling,\n        ) in enumerate(\n            zip(\n                self.cache_lengths,\n                self.input_lengths,\n                self.prefilling_mask,\n            )\n        ):\n            if not 
request_prefilling or self.image_positions[i] is None:\n                continue\n\n            for image_position in self.image_positions[i]:\n                if image_position is None:\n                    continue\n                start_pos = image_position.offset\n                length = image_position.length\n\n                if start_pos >= cache_length + input_length:\n                    # No encoder input required at this step\n                    break\n                if start_pos + length <= cache_length:\n                    # The encode input is already processed\n                    continue\n\n                self.has_image_inputs = True\n\n                if image_position.id not in self.encoder_cache[i]:\n                    image_inputs = self.image_inputs[i][image_position.id]\n                    self.pixel_values.append((i, image_position.id, image_inputs))\n\n                    # Remove the image from the image_inputs\n                    self.image_inputs[i][image_position.id] = None\n\n        if not self.has_image_inputs:\n            self.pixel_values = None\n            self.pixel_attention_mask = None\n            self.image_sizes = None\n            self.image_grid_thw = None\n        else:\n            image_grid_thw_list = [\n                x[2][\"image_grid_thw\"]\n                for x in self.pixel_values\n                if \"image_grid_thw\" in x[2]\n            ]\n            if image_grid_thw_list:\n                self.image_grid_thw = torch.cat(image_grid_thw_list, dim=0)\n            else:\n                self.image_grid_thw = None\n\n    def update_encoder_cache(self, encoder_outputs, request_id, img_pos):\n        self.encoder_cache[request_id][img_pos.id] = scatter_image_embeds(\n            encoder_outputs, img_pos.is_embed\n        )\n\n    def gather_vision_embeds(self):\n        device = self.input_ids.device\n        chunks = []\n        for (\n            i,\n            cache_length,\n            
input_length,\n            request_prefilling,\n        ) in zip(\n            range(len(self.requests)),\n            self.cache_lengths,\n            self.input_lengths,\n            self.prefilling_mask,\n        ):\n            if not request_prefilling or self.image_positions[i] is None:\n                continue\n\n            for image_position in self.image_positions[i]:\n                if image_position is None:\n                    continue\n                start_pos = image_position.offset\n                length = image_position.length\n\n                if start_pos >= cache_length + input_length:\n                    # No encoder input required at this step\n                    break\n                if start_pos + length <= cache_length:\n                    # The encode input is already processed\n                    continue\n\n                start_idx = max(cache_length - start_pos, 0)\n                end_idx = min(cache_length - start_pos + input_length, length)\n\n                assert (\n                    image_position.id in self.encoder_cache[i]\n                ), f\"image_id {image_position.id} not in encoder_cache {self.encoder_cache[i]}\"\n                encoder_output = self.encoder_cache[i][image_position.id]\n\n                is_embed = image_position.is_embed\n                if is_embed is not None:\n                    is_embed = is_embed[start_idx:end_idx]\n\n                from loguru import logger\n\n                logger.info(\n                    f\"image_id {image_position.id} start_idx {start_idx} end_idx {end_idx}, length {length}\"\n                )\n\n                embeds = gather_image_embeds(\n                    encoder_output[start_idx:end_idx],\n                    is_embed=is_embed,\n                )\n                if embeds is not None:\n                    chunks.append(embeds)\n\n                if end_idx == length:\n                    self.cache_entries_to_free.append((i, image_position.id))\n   
                 self.image_positions[i][image_position.id] = None\n\n        if len(chunks) == 0:\n            return None\n        return torch.cat(chunks, dim=0).to(device)\n\n    def free_encoder_cache(self):\n        for i, image_id in self.cache_entries_to_free:\n            self.encoder_cache[i].pop(image_id, None)\n\n        self.cache_entries_to_free = []\n\n\nclass FlashVlmCausalLM(FlashCausalLM):\n    def __init__(\n        self,\n        model_id: str,\n        *,\n        processor_class=AutoProcessor,\n        processor_kwargs=None,\n        batch_class=FlashVlmCausalLMBatch,\n        revision,\n        trust_remote_code: bool,\n        support_chunking: bool = False,\n        **kwargs,\n    ):\n        if PREFIX_CACHING:\n            raise NotImplementedError(\"Vlm do not work with prefix caching yet\")\n        if processor_kwargs is None:\n            processor_kwargs = {}\n        self.processor = processor_class.from_pretrained(\n            model_id,\n            revision=revision,\n            trust_remote_code=trust_remote_code,\n            **processor_kwargs,\n        )\n        self.batch_class = batch_class\n        super().__init__(\n            model_id=model_id,\n            revision=revision,\n            trust_remote_code=trust_remote_code,\n            support_chunking=support_chunking,\n            **kwargs,\n        )\n\n    @property\n    def batch_type(self) -> Type[FlashVlmCausalLMBatch]:\n        return self.batch_class\n\n    def max_past(self) -> Optional[int]:\n        return getattr(self.model.text_model, \"max_past\", None)\n\n    def warmup_decode(\n        self, batch_size: int, block_num: int, batch: FlashVlmCausalLMBatch\n    ):\n        input_ids = torch.zeros(batch_size, dtype=batch.input_ids.dtype)\n        position_ids = torch.arange(batch_size, dtype=batch.position_ids.dtype)\n        if batch.position_ids is not None and batch.position_ids.dim() == 2:\n            # qwen2_vl and qwen2_5_vl case\n            
position_ids = position_ids.unsqueeze(-1).repeat(\n                (1, batch.position_ids.shape[-1])\n            )\n        blocks = [block_num // batch_size for _ in range(batch_size)]\n        blocks[0] += block_num % batch_size\n        block_tables = []\n        slots = []\n        start_idx = 0\n        slot_indices = []\n\n        # fetch the last blocked to warmup block num\n\n        for i in range(batch_size):\n            block_array = list(range(start_idx, start_idx + blocks[i]))\n            slots.append(BLOCK_SIZE * block_array[-1] + BLOCK_SIZE - 1)\n            block_tables.append(block_array)\n            slot_indices.append((start_idx + blocks[i]) * BLOCK_SIZE - 1)\n            start_idx += blocks[i]\n        input_lengths = torch.ones(batch_size, dtype=torch.int32)\n\n        seqlen = Seqlen(\n            input_lengths=_async_h2d_tensor_copy(input_lengths),\n        )\n        block_list, block_groups, block_usage, _, block_bucket_size = (\n            generate_block_metadata(\n                self.dtype,\n                self.use_contiguous_pa,\n                slots,\n                block_tables,\n                self.bucketing_ctx,\n            )\n        )\n        meta = HPUPagedAttentionMetadata(\n            block_list=_async_h2d_tensor_copy(block_list),\n            block_groups=_async_h2d_tensor_copy(block_groups),\n            block_usage=_async_h2d_tensor_copy(block_usage),\n            block_mapping=None,\n            attn_bias=None,\n        )\n        if self.sliding_window is not None:\n            block_tables_in_window = []\n            for i, bt in enumerate(block_tables):\n                block_num_in_window = (\n                    self.sliding_window + BLOCK_SIZE - 1\n                ) // BLOCK_SIZE\n                block_tables_in_window.append(\n                    bt[max(0, blocks[i] - block_num_in_window) : blocks[i]]\n                )\n            slots_in_window = []\n            start_idx = 0\n            for i, 
indice in enumerate(slot_indices):\n                mask = (\n                    indice - torch.arange(start_idx, indice + 1)\n                ) < self.sliding_window\n                slots_in_window.append(torch.arange(start_idx, indice + 1)[mask])\n                start_idx += blocks[i] * BLOCK_SIZE\n            slots_in_window = torch.cat(slots_in_window, dim=0)\n            (\n                block_list_in_window,\n                block_groups_in_window,\n                block_usage_in_window,\n                slots_in_window_mask,\n                _,\n            ) = generate_block_metadata(\n                self.dtype,\n                self.use_contiguous_pa,\n                slots,\n                block_tables_in_window,\n                self.bucketing_ctx,\n                slots_in_window,\n                block_bucket_size,\n            )\n            meta.block_list_in_window = _async_h2d_tensor_copy(block_list_in_window)\n            meta.block_groups_in_window = _async_h2d_tensor_copy(block_groups_in_window)\n            meta.block_usage_in_window = _async_h2d_tensor_copy(block_usage_in_window)\n            meta.slots_in_window_mask = _async_h2d_tensor_copy(slots_in_window_mask)\n\n        hpu_attention_meta = trim_attn_metadata(meta)\n        slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)\n        inputs_embeds = self.get_inputs_embeds(\n            input_ids=input_ids.to(self.device),\n        )\n        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.\n        self.model.forward(\n            inputs_embeds=inputs_embeds,\n            position_ids=_async_h2d_tensor_copy(position_ids),\n            cu_seqlen_prefill=None,\n            kv_cache=self.kv_cache,\n            slots=_async_h2d_tensor_copy(slots_tensor),\n            seqlen=trim_seqlen_metadata(seqlen),\n            hpu_attention_meta=hpu_attention_meta,\n            lm_head_indices=None,\n            
attention_mask=None,\n        )\n\n    def warmup_hpu_graph(self, batch: FlashVlmCausalLMBatch):\n        free_mem = HabanaMemoryProfiler.current_free_device_memory()\n        graph_free_mem = free_mem - self.mem_reserved\n        graph_free_mem = self.align_workers(\n            graph_free_mem, torch.distributed.ReduceOp.MIN\n        )\n        decode_available_memory = graph_free_mem\n        msg = (\n            f\"Using {format_bytes(graph_free_mem)}\"\n            f\"/{format_bytes(free_mem)} \"\n            \"of free device memory for HPUGraphs, \"\n            f\"{format_bytes(decode_available_memory)} for decode \"\n        )\n        log_master(logger.info, msg)\n        start_time = time.time()\n        warmup_shape_count = 0\n        warmup_times = 3\n\n        # only warmup decode, for prefill, image pixal size may change, make the warmup useless\n        def ordering_function_max_bs(b):\n            return (-b[0], b[1])\n\n        self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)\n        buckets = list(\n            sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)\n        )\n        total_batch_seq = 0.001\n        total_mem = 0\n        available_mem = decode_available_memory\n        log_master(\n            logger.info, f\"Decode batch size list:{[bsz[0] for bsz in buckets]}\\n\"\n        )\n        for i, (batch_size, block_num) in enumerate(buckets):\n            if batch_size > block_num:\n                continue\n            # Graph memory usage is proportional to seq dimension in a batch\n            batch_seq = batch_size\n            mem_estimate = batch_seq / total_batch_seq * total_mem\n            graphed_bucket = (batch_size, block_num, False)\n            if not mem_estimate >= available_mem:\n                if graphed_bucket not in self.graphed_buckets:\n                    self.graphed_buckets.add(graphed_bucket)\n            warmup_shape_count += 1\n            
self.log_warmup(False, i, len(buckets), batch_size, block_num)\n            with HabanaMemoryProfiler() as mem_prof:\n                for index in range(warmup_times):\n                    self.warmup_decode(batch_size, block_num, batch)\n                    synchronize(self.device)\n            used_mem = self.align_workers(\n                mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX\n            )\n            if graphed_bucket in self.graphed_buckets:\n\n                available_mem -= used_mem\n                total_mem += used_mem\n                total_batch_seq += batch_seq\n\n        log_master(logger.info, \"Decode warmup successful.\\n\")\n\n        log_master(\n            logger.info,\n            f\"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}\",\n        )\n\n    def get_vision_embeds(\n        self,\n        pixel_values: torch.Tensor,\n        pixel_attention_mask: torch.Tensor,\n        image_sizes: torch.Tensor,\n        image_grid_thw: torch.Tensor,\n    ):\n        embeds = self.model.get_vision_embeds(\n            pixel_values=pixel_values,\n            pixel_attention_mask=pixel_attention_mask,\n            image_sizes=image_sizes,\n            image_grid_thw=image_grid_thw,\n        )\n        return embeds\n\n    def get_inputs_embeds(\n        self,\n        input_ids: torch.Tensor,\n        vision_embeds: Optional[torch.Tensor] = None,\n    ):\n        return self.model.get_inputs_embeds(\n            input_ids=input_ids,\n            vision_embeds=vision_embeds,\n        )\n\n    def encode_images(self, batch):\n        if batch.pixel_values is not None:\n            device = batch.input_ids.device\n            for request_id, image_id, image_input in batch.pixel_values:\n                pixel_values = image_input[\"pixel_values\"].to(device)\n\n                if \"pixel_attention_mask\" in image_input:\n                    pixel_attention_mask = 
image_input[\"pixel_attention_mask\"].to(\n                        device\n                    )\n                else:\n                    pixel_attention_mask = None\n\n                if \"image_sizes\" in image_input:\n                    image_sizes = image_input[\"image_sizes\"].to(device)\n                else:\n                    image_sizes = None\n\n                if \"image_grid_thw\" in image_input:\n                    image_grid_thw = image_input[\"image_grid_thw\"]\n                else:\n                    image_grid_thw = None\n\n                encoder_outputs = self.get_vision_embeds(\n                    pixel_values=pixel_values,\n                    pixel_attention_mask=pixel_attention_mask,\n                    image_sizes=image_sizes,\n                    image_grid_thw=image_grid_thw,\n                )\n                batch.update_encoder_cache(\n                    encoder_outputs,\n                    request_id,\n                    batch.image_positions[request_id][image_id],\n                )\n\n        batch.pixel_values = None\n        batch.pixel_attention_mask = None\n        batch.image_sizes = None\n\n    def set_inputs_embeds(self, batch):\n        if batch.has_image_inputs:\n            self.encode_images(batch)\n            vision_embeds = batch.gather_vision_embeds()\n            batch.has_image_inputs = False\n        else:\n            vision_embeds = None\n\n        inputs_embeds = self.get_inputs_embeds(\n            batch.input_ids, vision_embeds=vision_embeds\n        )\n\n        batch.inputs_embeds = inputs_embeds\n\n    def forward(\n        self,\n        batch: FlashVlmCausalLMBatch,\n        adapter_data: Optional[Dict[str, torch.Tensor]] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        # Model Forward\n        if batch.speculative_ids is not None:\n            input_ids = batch.input_ids\n            position_ids = batch.position_ids\n            cu_seqlen_prefill = 
batch.cu_seqlen_prefill\n            kv_cache = self.kv_cache\n            block_tables = batch.block_tables_tensor\n            slots = batch.slots[batch.slot_indices]\n            input_lengths = batch.input_lengths_tensor\n            max_s = batch.max_current_length\n            lm_head_indices = batch.prefill_head_indices\n\n            speculative_ids = batch.speculative_ids\n\n            B, speculative_length = speculative_ids.shape\n            new_length = speculative_length + 1\n            new_input_ids = torch.cat(\n                [input_ids.unsqueeze(-1), speculative_ids], dim=1\n            ).reshape(-1)\n            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)\n            arange_int = arange.to(dtype=torch.int32)\n            new_position_ids = (\n                position_ids.unsqueeze(-1).expand(B, new_length) + arange\n            ).view(-1)\n            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)\n            input_lengths = (\n                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int\n            ).view(-1)\n\n            # Add Copy the block tables for all members\n            block_tables = (\n                block_tables.unsqueeze(1)\n                .expand(B, new_length, -1)\n                .reshape(B * new_length, -1)\n                .contiguous()\n            )\n            max_s = max_s + speculative_length\n\n            input_ids = new_input_ids\n            position_ids = new_position_ids\n        else:\n            input_ids = batch.input_ids\n            inputs_embeds = batch.inputs_embeds\n            position_ids = batch.position_ids\n            cu_seqlen_prefill = batch.cu_seqlen_prefill\n            kv_cache = self.kv_cache\n            block_tables = batch.block_tables_tensor\n            slots = batch.slots[batch.slot_indices]\n            input_lengths = batch.input_lengths_tensor\n            max_s = batch.max_current_length\n            
lm_head_indices = batch.prefill_head_indices\n\n        if self.model.config.model_type in {\"qwen2_vl\", \"qwen2_5_vl\"}:\n            if position_ids.dim() == 1 and batch.prefilling:\n                position_ids = self.model.get_position_ids(\n                    input_ids.cpu(), batch.image_grid_thw\n                )\n                batch.position_ids = position_ids\n\n        attention_mask = None\n        attention_mask_forward = None\n        if self.model.config.model_type == \"llama4\":\n            attention_mask = (input_ids != self.tokenizer.pad_token_id).long()\n            attention_mask_forward = attention_mask.view(input_lengths.shape[0], -1)\n\n        if cu_seqlen_prefill is None and self.max_past() is not None:\n            # In decode, not prefill, we're actually overwriting the KV-cache\n            # in a circular buffer mode.\n            # This makes sure the max_s for the decode pass is correct.\n            max_s = min(self.max_past(), max_s)\n\n        if batch.prefill_cache_indices is not None:\n            slots_pad = torch.zeros_like(input_ids, device=slots.device)\n            slots_pad[batch.prefill_cache_indices] = slots\n            slots = slots_pad\n        else:\n            slots_pad = torch.zeros_like(input_ids, device=slots.device)\n            slots_pad[: slots.shape[0]] = slots\n            slots = slots_pad\n\n        seqlen = Seqlen(\n            input_lengths=_async_h2d_tensor_copy(input_lengths),\n        )\n        kwargs = {}\n        batch_size = input_lengths.shape[0]\n        prompt_len = (\n            input_ids.shape[0] // batch_size\n            if batch.prefilling\n            else batch.hpu_attn_meta.block_list.shape[0]\n        )\n        if htorch.utils.internal.is_lazy():\n            kwargs[\"bypass_hpu_graphs\"] = not self.use_graphs(\n                batch.prefilling, prompt_len, batch_size\n            )\n        if self.sliding_window is not None:\n            attn_mask = 
seqlen.make_sliding_window_bias(\n                input_lengths.tolist(),\n                self.sliding_window,\n                self.dtype,\n                prompt_len,\n                batch_size,\n            )\n            seqlen.attn_mask = _async_h2d_tensor_copy(attn_mask)\n        logits, speculative_logits = self.model.forward(\n            inputs_embeds=inputs_embeds,\n            position_ids=_async_h2d_tensor_copy(position_ids),\n            cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),\n            kv_cache=kv_cache,\n            slots=_async_h2d_tensor_copy(slots),\n            seqlen=trim_seqlen_metadata(seqlen),\n            hpu_attention_meta=batch.hpu_attn_meta,\n            lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),\n            attention_mask=attention_mask_forward,\n            **kwargs,\n        )\n        batch.image_grid_thw = None\n        batch.free_encoder_cache()\n        return logits, speculative_logits\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/globals.py",
    "content": "import os\nfrom typing import Dict, Optional\nfrom loguru import logger\nfrom text_generation_server.utils.log import log_master\n\nREQUEST_LOGPROBS = os.getenv(\"REQUEST_LOGPROBS\", \"0\").lower() in {\"1\", \"true\"}\nATTENTION = os.getenv(\"ATTENTION\", \"paged\")\n# default_prefix_caching = \"1\" if ATTENTION in {\"flashinfer\", \"flashdecoding\"} else \"0\"\nPREFIX_CACHING = os.getenv(\"PREFIX_CACHING\", \"0\").lower() in {\n    \"1\",\n    \"true\",\n}\nlog_master(logger.info, f\"Using prefix caching = {PREFIX_CACHING}\")\n_expected = {\"paged\"}\nassert (\n    ATTENTION in _expected\n), f\"Attention is not valid {ATTENTION}, expected {_expected}\"\nlog_master(logger.info, f\"Using Attention = {ATTENTION}\")\n\nTGI_WIGGLE_ROOM = float(os.getenv(\"TGI_WIGGLE_ROOM\", \"0.90\"))\nassert TGI_WIGGLE_ROOM > 0\nassert TGI_WIGGLE_ROOM < 1\n\n# This is overridden by the cli\nBLOCK_SIZE: int\n\nBLOCK_SIZE = 128\n\n\n# This is overridden at model loading.\nglobal MODEL_ID\nMODEL_ID = None\n\n\ndef set_model_id(model_id: str):\n    global MODEL_ID\n    MODEL_ID = model_id\n\n\n# NOTE: eventually we should move this into the router and pass back the\n# index in all cases.\nADAPTER_TO_INDEX: Optional[Dict[str, int]] = None\n\n\ndef set_adapter_to_index(adapter_to_index: Dict[str, int]):\n    global ADAPTER_TO_INDEX\n    ADAPTER_TO_INDEX = adapter_to_index\n\n\ndef get_adapter_to_index():\n    global ADAPTER_TO_INDEX\n    return ADAPTER_TO_INDEX\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py",
    "content": "import torch\n\nimport numpy as np\n\nfrom typing import Iterable, Optional, Tuple, List, Dict\nfrom text_generation_server.pb.generate_pb2 import Request\nfrom io import BytesIO\nfrom PIL import Image\nfrom dataclasses import dataclass\nfrom opentelemetry import trace\nfrom transformers import (\n    PreTrainedTokenizerBase,\n)\nfrom text_generation_server.models.flash_causal_lm import (\n    generate_block_metadata,\n)\nfrom text_generation_server.models.flash_vlm_causal_lm import (\n    FlashVlmCausalLMBatch,\n    FlashVlmCausalLM,\n)\nfrom text_generation_server.pb import generate_pb2\nfrom text_generation_server.layers.attention import (\n    Seqlen,\n    trim_seqlen_metadata,\n    _async_h2d_tensor_copy,\n    HPUPagedAttentionMetadata,\n    trim_attn_metadata,\n)\nimport habana_frameworks.torch as htorch\nfrom loguru import logger\nfrom text_generation_server.models.globals import BLOCK_SIZE\nfrom text_generation_server.utils.import_utils import (\n    synchronize,\n)\nimport torch.nn.functional as F\nfrom text_generation_server.utils.log import log_master\nimport time\nimport os\nfrom vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes\n\ntracer = trace.get_tracer(__name__)\n\n\n@dataclass\nclass FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):\n    image_indices: List[int] = 42\n    aspect_ratio_ids: Optional[torch.Tensor] = None\n    aspect_ratio_mask: Optional[torch.Tensor] = None\n    cross_attention_states: Optional[torch.Tensor] = None\n\n    def prepare_for_prefill(\n        self, max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id\n    ):\n        super(FlashVlmCausalLMBatch, self).prepare_for_prefill(\n            max_padded_input_len, max_padded_bs, max_total_tokens, pad_token_id\n        )\n\n    @classmethod\n    @tracer.start_as_current_span(\"concatenate\")\n    def concatenate(cls, batches, padded_total_bs: int = 0):\n        batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches, 
padded_total_bs)\n        batch.pixel_values = None\n        batch.pixel_attention_mask = None\n\n        offset = 0\n        image_indices = []\n        attention_states = []\n        for b in batches:\n            if b.cross_attention_states is not None:\n                attention_states.append(b.cross_attention_states)\n            image_indices.extend([i + offset for i in b.image_indices])\n            offset += len(b.image_indices)\n        if len(attention_states) > 0:\n            assert len(image_indices) > 0\n            batch.cross_attention_states = torch.cat(attention_states, dim=0)\n            batch.image_indices = image_indices\n        else:\n            batch.cross_attention_states = None\n            batch.image_indices = []\n        return batch\n\n    @tracer.start_as_current_span(\"filter\")\n    def filter(self, request_ids: List[int]):\n        assert self.image_indices is not None\n        batch = super(FlashVlmCausalLMBatch, self).filter(request_ids)\n        assert self.image_indices is not None\n        indices = []\n        for i, request_id in enumerate(request_ids):\n            idx = self.requests_idx_mapping[request_id]\n            indices.append(idx)\n\n        offset = 0\n        new_image_indices = []\n        prev_i = None\n        for i in self.image_indices:\n            if i in indices:\n                new_image_indices.append(offset)\n                if i != prev_i:\n                    offset += 1\n                prev_i = i\n\n        batch.image_indices = new_image_indices\n        if len(new_image_indices) > 0:\n            assert max(new_image_indices) < self.cross_attention_states.shape[0]\n            assert offset <= self.cross_attention_states.shape[0]\n            batch.cross_attention_states = self.cross_attention_states[\n                new_image_indices\n            ]\n        else:\n            batch.cross_attention_states = None\n        batch.pixel_values = None\n        return batch\n\n    @classmethod\n   
 def batch_tokenized_inputs(\n        cls, requests: Iterable[Request], tokenizer, processor, config\n    ):\n        image_inputs = []\n        texts = []\n        image_indices = []\n        batch_tokenized_inputs = []\n\n        for i, r in enumerate(requests):\n            # Each input is encoded into a list, where each element of this input list is either a string or a URL\n            curr_text = \"\"\n            curr_image = None\n            curr_i = None\n            for chunk in r.input_chunks.chunks:\n                chunk_type = chunk.WhichOneof(\"chunk\")\n                if chunk_type == \"text\":\n                    curr_text += chunk.text\n                elif chunk_type == \"image\":\n                    image = Image.open(BytesIO(chunk.image.data))\n                    # TODO unsure about BOS\n                    curr_text += \"<|image|>\"\n                    image_input = processor.image_processor(image, return_tensors=\"pt\")\n                    curr_image = image_input\n                    curr_i = i\n                    # image_inputs.append(image_input)\n                    # image_indices.append(i)\n                else:\n                    raise RuntimeError(f\"Invalid chunk type {chunk_type}\")\n            texts.append(curr_text)\n            if curr_image is not None:\n                image_inputs.append(curr_image)\n                image_indices.append(curr_i)\n\n            input_ids = tokenizer(\n                curr_text,\n                truncation=True,\n                max_length=r.truncate,\n                add_special_tokens=r.add_special_tokens,\n            )[\"input_ids\"]\n            batch_tokenized_inputs.append(input_ids)\n        if image_inputs:\n            image_input = image_inputs[0]\n            new_image_inputs = {\n                \"pixel_values\": torch.cat(\n                    [img[\"pixel_values\"] for img in image_inputs], dim=0\n                ),\n            }\n            if \"aspect_ratio_ids\" in 
image_input:\n                new_image_inputs[\"aspect_ratio_ids\"] = torch.cat(\n                    [img[\"aspect_ratio_ids\"] for img in image_inputs], dim=0\n                )\n            if \"aspect_ratio_mask\" in image_input:\n                new_image_inputs[\"aspect_ratio_mask\"] = torch.cat(\n                    [img[\"aspect_ratio_mask\"] for img in image_inputs], dim=0\n                )\n            image_inputs = new_image_inputs\n            image_inputs[\"image_indices\"] = image_indices\n        else:\n            image_inputs = None\n\n        if image_inputs is not None:\n            assert len(image_indices) == image_inputs[\"pixel_values\"].shape[0]\n\n        return batch_tokenized_inputs, image_inputs\n\n    @classmethod\n    def from_pb_processor(\n        cls,\n        pb: generate_pb2.Batch,\n        tokenizer: PreTrainedTokenizerBase,\n        processor,\n        config,\n        dtype: torch.dtype,\n        device: torch.device,\n    ) -> \"FlashVlmCausalLMBatch\":\n        batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(\n            pb.requests, tokenizer, processor, config\n        )\n        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)\n        # XXX: <|image|> token is actually out of bounds and bugs out the logit processors.\n        batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp(\n            max=config.text_config.vocab_size - 1\n        )\n        if isinstance(batch.input_ids, list):\n            if len(batch) > 1:\n                input_ids = np.concatenate(batch.input_ids, dtype=np.int64)\n            else:\n                input_ids = batch.input_ids[0]\n            batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)\n\n        batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)\n\n        if image_inputs is not None:\n            batch.pixel_values = image_inputs[\"pixel_values\"].to(\n                
device=device, dtype=dtype\n            )\n            batch.aspect_ratio_ids = image_inputs[\"aspect_ratio_ids\"].to(device=device)\n            batch.aspect_ratio_mask = image_inputs[\"aspect_ratio_mask\"].to(\n                device=device\n            )\n            batch.image_indices = image_inputs[\"image_indices\"]\n        else:\n            batch.pixel_values = None\n            batch.aspect_ratio_ids = None\n            batch.aspect_ratio_mask = None\n            batch.image_indices = []\n        assert batch.image_indices is not None\n        return batch\n\n\ndef generate_cross_attention_states(\n    cross_attention_states, image_indices, input_lengths, pad_seq_len, prefilling\n):\n    if cross_attention_states is None:\n        return None, None\n    indices_list = []\n    if prefilling:\n        for i in image_indices:\n            indices_list.append(torch.arange(pad_seq_len * i, pad_seq_len * (i + 1)))\n        indices = torch.cat(indices_list, dim=0)\n    else:\n        indices = image_indices[:]\n    return indices, input_lengths.index_select(0, image_indices)\n\n\nclass FlashMllamaCausalLM(FlashVlmCausalLM):\n    def set_inputs_embeds(self, batch):\n        # Set the input embeddings to None, as we are using the input_ids for the model\n        batch.inputs_embeds = None\n\n    def warmup_decode(\n        self, batch_size: int, block_num: int, batch: FlashMllamaCausalLMBatch\n    ):\n        input_ids = torch.zeros(batch_size, dtype=batch.input_ids.dtype)\n        position_ids = torch.arange(batch_size, dtype=batch.position_ids.dtype)\n        blocks = [block_num // batch_size for _ in range(batch_size)]\n        blocks[0] += block_num % batch_size\n        block_tables = []\n        slots = []\n        start_idx = 0\n        slot_indices = []\n\n        # fetch the last blocked to warmup block num\n        for i in range(batch_size):\n            block_array = list(range(start_idx, start_idx + blocks[i]))\n            slots.append(BLOCK_SIZE * 
block_array[-1] + BLOCK_SIZE - 1)\n            block_tables.append(block_array)\n            slot_indices.append((start_idx + blocks[i]) * BLOCK_SIZE - 1)\n            start_idx += blocks[i]\n        input_lengths = torch.ones(batch_size, dtype=torch.int32)\n\n        seqlen = Seqlen(\n            input_lengths=_async_h2d_tensor_copy(input_lengths),\n        )\n        block_list, block_groups, block_usage, _, block_bucket_size = (\n            generate_block_metadata(\n                self.dtype,\n                self.use_contiguous_pa,\n                slots,\n                block_tables,\n                self.bucketing_ctx,\n            )\n        )\n        meta = HPUPagedAttentionMetadata(\n            block_list=_async_h2d_tensor_copy(block_list),\n            block_groups=_async_h2d_tensor_copy(block_groups),\n            block_usage=_async_h2d_tensor_copy(block_usage),\n            block_mapping=None,\n            attn_bias=None,\n        )\n\n        hpu_attention_meta = trim_attn_metadata(meta)\n        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.\n        image_indices = torch.tensor(batch.image_indices)\n        image_indices = image_indices.repeat(batch_size)\n        cross_attention_states = batch.cross_attention_states.repeat(batch_size, 1, 1)\n        indices, cross_attention_len = generate_cross_attention_states(\n            cross_attention_states, image_indices, input_lengths, 1, False\n        )\n        slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)\n        kwargs = {}\n        if htorch.utils.internal.is_lazy():\n            kwargs[\"bypass_hpu_graphs\"] = not self.use_graphs(\n                False, hpu_attention_meta.block_list.shape[0], batch_size\n            )\n        self.model.forward(\n            input_ids=_async_h2d_tensor_copy(input_ids),\n            position_ids=_async_h2d_tensor_copy(position_ids),\n            cu_seqlen_prefill=None,\n            
kv_cache=self.kv_cache,\n            slots=_async_h2d_tensor_copy(slots_tensor),\n            seqlen=trim_seqlen_metadata(seqlen),\n            hpu_attention_meta=hpu_attention_meta,\n            lm_head_indices=None,\n            adapter_data=None,\n            cross_attention_states=cross_attention_states,\n            indices=_async_h2d_tensor_copy(indices),\n            cross_attention_len=_async_h2d_tensor_copy(cross_attention_len),\n            **kwargs,\n        )\n\n    def warmup_prefill(\n        self, prompt_len: int, batch_size: int, batch: FlashMllamaCausalLMBatch\n    ):\n        input_ids = torch.zeros(prompt_len, dtype=batch.input_ids.dtype).repeat(\n            batch_size\n        )\n        position_ids = torch.arange(prompt_len, dtype=batch.position_ids.dtype).repeat(\n            batch_size\n        )\n        max_bt = (prompt_len // BLOCK_SIZE + 1) * batch_size\n        block_tables = torch.arange(max_bt, dtype=torch.int32).reshape(batch_size, -1)\n        slot_acc = []\n        for i in range(batch_size):\n            slots = []\n            for b in block_tables[i]:\n                slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))\n            slot_acc.extend(slots[:prompt_len])\n        slots = torch.tensor(slot_acc, dtype=batch.slots.dtype)\n\n        input_lengths = (\n            torch.ones(\n                batch_size,\n                dtype=torch.int32,\n            )\n            * prompt_len\n        )\n        cu_seqlen_prefill = torch.zeros(batch_size + 1, dtype=torch.int32)\n        torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])\n\n        lm_head_indices = input_lengths - 1\n\n        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.\n        image_indices = torch.tensor(batch.image_indices)\n        image_indices = image_indices.repeat(batch_size)\n        cross_attention_states = batch.cross_attention_states.repeat(batch_size, 1, 1)\n        
indices, cross_attention_len = generate_cross_attention_states(\n            cross_attention_states, image_indices, input_lengths, prompt_len, True\n        )\n        seqlen = Seqlen(\n            input_lengths=_async_h2d_tensor_copy(input_lengths),\n        )\n        kwargs = {}\n        if htorch.utils.internal.is_lazy():\n            kwargs[\"bypass_hpu_graphs\"] = not self.use_graphs(\n                True, prompt_len, batch_size\n            )\n        self.model.forward(\n            input_ids=_async_h2d_tensor_copy(input_ids),\n            position_ids=_async_h2d_tensor_copy(position_ids),\n            cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),\n            kv_cache=self.kv_cache,\n            slots=_async_h2d_tensor_copy(slots),\n            seqlen=trim_seqlen_metadata(seqlen),\n            hpu_attention_meta=None,\n            lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),\n            adapter_data=None,\n            cross_attention_states=cross_attention_states,\n            indices=_async_h2d_tensor_copy(indices),\n            cross_attention_len=_async_h2d_tensor_copy(cross_attention_len),\n            **kwargs,\n        )\n\n    def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch):\n        prompt_graph_mem_ratio = float(os.environ.get(\"VLLM_GRAPH_PROMPT_RATIO\", \"0.3\"))\n        free_mem = HabanaMemoryProfiler.current_free_device_memory()\n        graph_free_mem = free_mem - self.mem_reserved\n        graph_free_mem = self.align_workers(\n            graph_free_mem, torch.distributed.ReduceOp.MIN\n        )\n        prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem\n        decode_available_memory = graph_free_mem - prompt_available_memory\n        msg = (\n            f\"Using {format_bytes(graph_free_mem)}\"\n            f\"/{format_bytes(free_mem)} \"\n            \"of free device memory for HPUGraphs, \"\n            f\"{format_bytes(prompt_available_memory)} for prompt and \"\n            
f\"{format_bytes(decode_available_memory)} for decode \"\n            f\"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})\"\n        )\n        log_master(logger.info, msg)\n        start_time = time.time()\n        warmup_shape_count = 0\n        warmup_times = 3\n        self.bucketing_ctx.generate_prompt_buckets()\n\n        def ordering_function_min_tokens(b):\n            return (b[0] * b[1], b[1], b[0])\n\n        buckets = list(\n            sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)\n        )\n        graph_free_mem\n        total_batch_seq = 0.001\n        total_mem = 0\n        available_mem = prompt_available_memory\n        msg = (\n            f\"Prefill batch size list:{[bsz[0] for bsz in buckets]}\\n\"\n            f\"Prefill sequence length list:{[seq[1] for seq in buckets]}\\n\"\n        )\n        log_master(logger.info, msg)\n        for i, (batch_size, seq_len) in enumerate(buckets):\n            if batch_size * seq_len > self.max_batch_prefill_tokens:\n                continue\n            # Graph memory usage is proportional to seq dimension in a batch\n            batch_seq = batch_size * seq_len\n            mem_estimate = batch_seq / total_batch_seq * total_mem\n            graphed_bucket = (batch_size, seq_len, True)\n            if not (\n                mem_estimate >= available_mem or batch_seq > self.max_seq_len_to_capture\n            ):\n                if graphed_bucket not in self.graphed_buckets:\n                    self.graphed_buckets.add(graphed_bucket)\n            warmup_shape_count += 1\n            self.log_warmup(True, i, len(buckets), batch_size, seq_len)\n            with HabanaMemoryProfiler() as mem_prof:\n                for index in range(warmup_times):\n                    self.warmup_prefill(seq_len, batch_size, batch)\n                    synchronize(self.device)\n            used_mem = self.align_workers(\n                mem_prof.consumed_device_memory, 
torch.distributed.ReduceOp.MAX\n            )\n            if graphed_bucket in self.graphed_buckets:\n                available_mem -= used_mem\n                total_mem += used_mem\n                total_batch_seq += batch_seq\n\n        log_master(logger.info, \"Prefill warmup successful.\\n\")\n\n        def ordering_function_max_bs(b):\n            return (-b[0], b[1])\n\n        self.bucketing_ctx.generate_decode_buckets(self.bucketing_ctx.num_hpu_blocks)\n        buckets = list(\n            sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)\n        )\n        free_mem = HabanaMemoryProfiler.current_free_device_memory()\n        total_batch_seq = 0.001\n        total_mem = 0\n        available_mem = free_mem - self.mem_reserved\n        log_master(\n            logger.info, f\"Decode batch size list:{[bsz[0] for bsz in buckets]}\\n\"\n        )\n        for i, (batch_size, block_num) in enumerate(buckets):\n            if batch_size > block_num:\n                continue\n            # Graph memory usage is proportional to seq dimension in a batch\n            batch_seq = batch_size\n            mem_estimate = batch_seq / total_batch_seq * total_mem\n            graphed_bucket = (batch_size, block_num, False)\n            if not mem_estimate >= available_mem:\n                if graphed_bucket not in self.graphed_buckets:\n                    self.graphed_buckets.add(graphed_bucket)\n            warmup_shape_count += 1\n            self.log_warmup(False, i, len(buckets), batch_size, block_num)\n            with HabanaMemoryProfiler() as mem_prof:\n                for index in range(warmup_times):\n                    self.warmup_decode(batch_size, block_num, batch)\n                    synchronize(self.device)\n            used_mem = self.align_workers(\n                mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX\n            )\n            if graphed_bucket in self.graphed_buckets:\n                available_mem -= 
used_mem\n                total_mem += used_mem\n                total_batch_seq += batch_seq\n\n        log_master(logger.info, \"Decode warmup successful.\\n\")\n\n        log_master(\n            logger.info,\n            f\"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}\",\n        )\n\n    def forward(\n        self,\n        batch: FlashMllamaCausalLMBatch,\n        adapter_data: Optional[Dict[str, torch.Tensor]] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        # Model Forward\n        if batch.speculative_ids is not None:\n            input_ids = batch.input_ids\n            position_ids = batch.position_ids\n            cu_seqlen_prefill = batch.cu_seqlen_prefill\n            kv_cache = self.kv_cache\n            block_tables = batch.block_tables_tensor\n            slots = batch.slots[batch.slot_indices]\n            input_lengths = batch.input_lengths_tensor\n            max_s = batch.max_current_length\n            lm_head_indices = batch.prefill_head_indices\n\n            speculative_ids = batch.speculative_ids\n\n            B, speculative_length = speculative_ids.shape\n            new_length = speculative_length + 1\n            new_input_ids = torch.cat(\n                [input_ids.unsqueeze(-1), speculative_ids], dim=1\n            ).reshape(-1)\n            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)\n            arange_int = arange.to(dtype=torch.int32)\n            new_position_ids = (\n                position_ids.unsqueeze(-1).expand(B, new_length) + arange\n            ).view(-1)\n            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)\n            input_lengths = (\n                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int\n            ).view(-1)\n\n            # Add Copy the block tables for all members\n            block_tables = (\n                block_tables.unsqueeze(1)\n              
  .expand(B, new_length, -1)\n                .reshape(B * new_length, -1)\n                .contiguous()\n            )\n            max_s = max_s + speculative_length\n\n            input_ids = new_input_ids\n            position_ids = new_position_ids\n        else:\n            input_ids = batch.input_ids\n            position_ids = batch.position_ids\n            cu_seqlen_prefill = batch.cu_seqlen_prefill\n            kv_cache = self.kv_cache\n            block_tables = batch.block_tables_tensor\n            slots = batch.slots[batch.slot_indices]\n            input_lengths = batch.input_lengths_tensor\n            max_s = batch.max_current_length\n            lm_head_indices = batch.prefill_head_indices\n\n        if cu_seqlen_prefill is None and self.max_past() is not None:\n            # In decode, not prefill, we're actually overwriting the KV-cache\n            # in a circular buffer mode.\n            # This makes sure the max_s for the decode pass is correct.\n            max_s = min(self.max_past(), max_s)\n\n        if batch.pixel_values is not None:\n            cross_attention_states = self.model.vision_forward(\n                pixel_values=batch.pixel_values,\n                aspect_ratio_ids=batch.aspect_ratio_ids,\n                aspect_ratio_mask=batch.aspect_ratio_mask,\n            )\n            batch.cross_attention_states = cross_attention_states\n\n        cross_attention_states = batch.cross_attention_states\n\n        kwargs = {}\n        if htorch.utils.internal.is_lazy():\n            batch_size = input_lengths.shape[0]\n            seqlen = (\n                input_ids.shape[0] // batch_size\n                if batch.prefilling\n                else batch.hpu_attn_meta.block_list.shape[0]\n            )\n            kwargs[\"bypass_hpu_graphs\"] = not self.use_graphs(\n                batch.prefilling, seqlen, batch_size\n            )\n\n        if batch.prefill_cache_indices is not None:\n            slots_pad = 
torch.zeros_like(input_ids, device=slots.device)\n            slots_pad[batch.prefill_cache_indices] = slots\n            slots = slots_pad\n        else:\n            slots_pad = torch.zeros_like(input_ids, device=slots.device)\n            slots_pad[: slots.shape[0]] = slots\n            slots = slots_pad\n        orig_bs = len(batch)\n        padded_bs = batch.input_lengths_tensor.shape[0]\n        padded_input_len = input_ids.view(padded_bs, -1).shape[-1]\n        image_indices = torch.tensor(batch.image_indices)\n\n        if cross_attention_states is not None:\n            cross_attention_states = F.pad(\n                cross_attention_states,\n                (0, 0, 0, 0, 0, (padded_bs - orig_bs)),\n                value=0,\n            )\n        if len(image_indices) != 0:\n            pad_indices = torch.arange(orig_bs, padded_bs)\n            image_indices = torch.cat((image_indices, pad_indices), dim=0)\n\n        indices, cross_attention_len = generate_cross_attention_states(\n            cross_attention_states,\n            image_indices,\n            input_lengths,\n            padded_input_len,\n            batch.prefilling,\n        )\n        seqlen = Seqlen(\n            input_lengths=_async_h2d_tensor_copy(input_lengths),\n        )\n        logits, speculative_logits = self.model.forward(\n            input_ids=input_ids,\n            position_ids=_async_h2d_tensor_copy(position_ids),\n            cu_seqlen_prefill=_async_h2d_tensor_copy(cu_seqlen_prefill),\n            kv_cache=kv_cache,\n            slots=_async_h2d_tensor_copy(slots),\n            seqlen=trim_seqlen_metadata(seqlen),\n            hpu_attention_meta=batch.hpu_attn_meta,\n            lm_head_indices=_async_h2d_tensor_copy(lm_head_indices),\n            # TODO list\n            adapter_data=None,\n            cross_attention_states=cross_attention_states,\n            indices=_async_h2d_tensor_copy(indices),\n            
cross_attention_len=_async_h2d_tensor_copy(cross_attention_len),\n            **kwargs,\n        )\n        if batch.pixel_values is not None:\n            batch.pixel_values = None\n        return logits, speculative_logits\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/model.py",
    "content": "import inspect\nimport torch\n\nfrom abc import ABC, abstractmethod\nfrom typing import List, Tuple, Optional, TypeVar, Type, Dict\nfrom collections import defaultdict\nfrom transformers import PreTrainedTokenizerBase\n\nfrom text_generation_server.models.types import Batch, Generation\nfrom text_generation_server.models.globals import BLOCK_SIZE\nfrom text_generation_server.utils.speculate import get_speculate\nfrom text_generation_server.pb.generate_pb2 import InfoResponse\nfrom text_generation_server.adapters.weights import LayerAdapterWeights\nfrom text_generation_server.pb import generate_pb2\n\nBASE_MODEL_ADAPTER_ID = \"__base_model__\"\n\n\nB = TypeVar(\"B\", bound=Batch)\n\n\nclass Model(ABC):\n    def __init__(\n        self,\n        model_id: str,\n        model: torch.nn.Module,\n        tokenizer: PreTrainedTokenizerBase,\n        requires_padding: bool,\n        dtype: torch.dtype,\n        device: torch.device,\n        rank: int = 0,\n        world_size: int = 1,\n        sliding_window: Optional[int] = None,\n        speculate: Optional[int] = None,\n        adapter_id: str = BASE_MODEL_ADAPTER_ID,\n        support_chunking: bool = False,\n    ):\n        self.model_id = model_id\n        self.model = model.eval()\n        self.tokenizer = tokenizer\n\n        # all_special_ids is not set correctly if the rust tokenizer is unpacked\n        # TODO report this to transformers.\n        other_special_ids = {\n            id for id, token in tokenizer.added_tokens_decoder.items() if token.special\n        }\n        self.all_special_ids = set(tokenizer.all_special_ids)\n        self.all_special_ids.update(other_special_ids)\n        self.requires_padding = requires_padding\n        self.dtype = dtype\n        self.device = device\n        self.rank = rank\n        self.world_size = world_size\n        self.sliding_window = sliding_window if sliding_window != -1 else None\n\n        self.layer_to_adapter_weights: Dict[str, 
LayerAdapterWeights] = defaultdict(\n            LayerAdapterWeights\n        )\n        self.loaded_adapters = set()\n        self.static_adapter_id = adapter_id\n\n        if speculate is None:\n            speculate = get_speculate()\n        self.speculate = speculate\n\n        self.has_position_ids = (\n            inspect.signature(model.forward).parameters.get(\"position_ids\", None)\n            is not None\n        )\n\n        self.check_initialized()\n\n    @property\n    def info(self) -> InfoResponse:\n        if self.requires_padding and self.sliding_window is not None:\n            raise NotImplementedError(\"sliding_window is not implemented with padding\")\n\n        return InfoResponse(\n            requires_padding=self.requires_padding,\n            dtype=str(self.dtype),\n            device_type=self.device.type,\n            window_size=None,\n            speculate=self.speculate,\n            block_size=BLOCK_SIZE,\n        )\n\n    @property\n    @abstractmethod\n    def batch_type(self) -> Type[B]:\n        raise NotImplementedError\n\n    @abstractmethod\n    def generate_token(\n        self, batch: B\n    ) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]:\n        raise NotImplementedError\n\n    def warmup(\n        self, batch: generate_pb2.WarmupRequest\n    ) -> Tuple[Optional[int], Optional[int], Optional[int]]:\n        self.generate_token(batch)\n        return None, None, None\n\n    def decode_token(\n        self,\n        all_input_ids: List[int],\n        prefix_offset: int = 0,\n        read_offset: int = 0,\n        skip_special_tokens: bool = False,\n    ) -> Tuple[str, int, int]:\n        \"\"\"Hack to hopefully support generate_stream for the maximum number of tokenizers\"\"\"\n\n        # The prefix text is necessary only to defeat cleanup algorithms in the decode\n        # which decide to add a space or not depending on the surrounding ids.\n        prefix_text = self.tokenizer.decode(\n            
all_input_ids[prefix_offset:read_offset],\n            skip_special_tokens=skip_special_tokens,\n        )\n\n        new_text = self.tokenizer.decode(\n            all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens\n        )\n\n        if len(new_text) > len(prefix_text) and not new_text.endswith(\"�\"):\n            # utf-8 char at the end means it's a potential unfinished byte sequence\n            # from byte fallback tokenization.\n            # If it's in the middle, it's probably a real invalid id generated\n            # by the model\n            new_text = new_text[len(prefix_text) :]\n            return new_text, read_offset, len(all_input_ids)\n        else:\n            return \"\", prefix_offset, read_offset\n\n    def check_initialized(self):\n        uninitialized_parameters = []\n        for n, p in self.model.named_parameters():\n            if p.data.device == torch.device(\"meta\"):\n                uninitialized_parameters.append(n)\n        if uninitialized_parameters:\n            raise RuntimeError(\n                f\"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}\"\n            )\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/seq2seq_lm.py",
    "content": "import torch\nimport torch.distributed\nimport time\nfrom dataclasses import dataclass\nfrom opentelemetry import trace\nfrom transformers import (\n    AutoTokenizer,\n    AutoModelForSeq2SeqLM,\n    PreTrainedTokenizerBase,\n    AutoConfig,\n)\nfrom typing import Optional, Tuple, List, Type, Dict\nfrom text_generation_server.utils import (\n    initialize_torch_distributed,\n    weight_files,\n    Weights,\n)\nfrom text_generation_server.utils.chunks import concat_text_chunks\nfrom text_generation_server.utils.quantization import get_loader\nfrom text_generation_server.utils.tokens import batch_top_tokens\nfrom text_generation_server.models import Model\nfrom text_generation_server.models.types import (\n    GeneratedText,\n    Batch,\n    Generation,\n    Tokens,\n)\nfrom text_generation_server.pb import generate_pb2\nfrom text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling\n\ntracer = trace.get_tracer(__name__)\n\n\n@dataclass\nclass Seq2SeqLMBatch(Batch):\n    batch_id: int\n    requests: List[generate_pb2.Request]\n    requests_idx_mapping: Dict[int, int]\n\n    # Encoder values\n    input_ids: Optional[torch.Tensor]\n    attention_mask: torch.Tensor\n\n    # Decoder values\n    decoder_input_ids: torch.Tensor\n    decoder_attention_mask: Optional[torch.Tensor]\n    encoder_last_hidden_state: Optional[torch.Tensor]\n\n    # All tokens\n    all_decoder_input_ids: List[torch.Tensor]\n\n    # Seq2SeqLM keeps track of both encoder and decoder attention keys and values\n    past_key_values: Optional[List[Tuple]]\n\n    # Lengths of all generations present in the batch\n    input_lengths: List[int]\n    decoder_input_lengths: List[int]\n    prefix_offsets: List[int]\n    read_offsets: List[int]\n\n    # Generation helpers\n    next_token_choosers: List[NextTokenChooser]\n    stopping_criterias: List[StoppingCriteria]\n    top_n_tokens: List[int]\n    top_n_tokens_tensor: torch.Tensor\n\n    # Metadata used for padding\n  
  max_input_length: int\n    max_decoder_input_length: int\n    padding_right_offset: int\n\n    # Maximum number of tokens this batch will grow to\n    max_tokens: int\n\n    def to_pb(self) -> generate_pb2.CachedBatch:\n        \"\"\"Convert a Seq2SeqLMBatch to a text_generation_server.v1.CachedBatch protobuf\"\"\"\n        return generate_pb2.CachedBatch(\n            id=self.batch_id,\n            request_ids=[r.id for r in self.requests],\n            size=len(self),\n            max_tokens=self.max_tokens,\n        )\n\n    @classmethod\n    def from_pb(\n        cls,\n        pb: generate_pb2.Batch,\n        tokenizer: PreTrainedTokenizerBase,\n        dtype: torch.dtype,\n        device: torch.device,\n    ) -> \"Seq2SeqLMBatch\":\n        \"\"\"Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch\"\"\"\n        inputs = []\n        next_token_choosers = []\n        stopping_criterias = []\n        top_n_tokens = []\n        decoder_input_lengths = []\n        prefix_offsets = []\n        read_offsets = []\n        requests_idx_mapping = {}\n\n        # Parse batch\n        max_truncation = 0\n        padding_right_offset = 0\n        max_decode_tokens = 0\n        for i, r in enumerate(pb.requests):\n            inputs.append(concat_text_chunks(r.input_chunks.chunks))\n            requests_idx_mapping[r.id] = i\n            decoder_input_lengths.append(1)\n            next_token_choosers.append(\n                NextTokenChooser.from_pb(r.parameters, device, tokenizer)\n            )\n            stopping_criteria = StoppingCriteria.from_pb(\n                r.stopping_parameters, tokenizer\n            )\n            stopping_criterias.append(stopping_criteria)\n            top_n_tokens.append(r.top_n_tokens)\n            max_truncation = max(max_truncation, r.truncate)\n            max_decode_tokens += stopping_criteria.max_new_tokens\n            padding_right_offset = max(\n                padding_right_offset, 
stopping_criteria.max_new_tokens\n            )\n\n        # Tokenize batch\n        tokenized_inputs = tokenizer(\n            inputs,\n            return_tensors=\"pt\",\n            padding=True,\n            return_token_type_ids=False,\n            truncation=True,\n            max_length=max_truncation,\n        ).to(device)\n\n        input_lengths = tokenized_inputs[\"attention_mask\"].sum(1)\n        max_input_length = input_lengths.max()\n\n        # Decoder sequence only contains the bos_token\n        decoder_input_ids = (\n            torch.tensor(tokenizer.bos_token_id, device=device)\n            .repeat(len(pb.requests))\n            .view(-1, 1)\n        )\n        for _ in pb.requests:\n            prefix_offsets.append(0)\n            read_offsets.append(1)\n        all_decoder_input_ids = decoder_input_ids.view(-1).split(1)\n        top_n_tokens_tensor = torch.tensor(\n            top_n_tokens, device=device, dtype=torch.int64\n        )\n\n        max_tokens = len(inputs) * (max_input_length + max_decode_tokens)\n\n        return cls(\n            batch_id=pb.id,\n            requests=pb.requests,\n            requests_idx_mapping=requests_idx_mapping,\n            input_ids=tokenized_inputs[\"input_ids\"],\n            attention_mask=tokenized_inputs[\"attention_mask\"],\n            decoder_input_ids=decoder_input_ids,\n            all_decoder_input_ids=list(all_decoder_input_ids),\n            decoder_attention_mask=None,\n            encoder_last_hidden_state=None,\n            past_key_values=None,\n            input_lengths=input_lengths.tolist(),\n            decoder_input_lengths=decoder_input_lengths,\n            prefix_offsets=prefix_offsets,\n            read_offsets=read_offsets,\n            next_token_choosers=next_token_choosers,\n            stopping_criterias=stopping_criterias,\n            top_n_tokens=top_n_tokens,\n            top_n_tokens_tensor=top_n_tokens_tensor,\n            max_input_length=max_input_length.item(),\n 
           max_decoder_input_length=1,\n            padding_right_offset=padding_right_offset,\n            max_tokens=max_tokens,\n        )\n\n    @tracer.start_as_current_span(\"filter\")\n    def filter(self, request_ids: List[int]) -> Optional[\"Seq2SeqLMBatch\"]:\n        if len(request_ids) == 0:\n            raise ValueError(\"Batch must have at least one request\")\n        if len(request_ids) == len(self):\n            return self\n\n        keep_indices = []\n\n        # New values after filtering\n        requests_idx_mapping = {}\n        requests = []\n        input_lengths = []\n        decoder_input_lengths = []\n        prefix_offsets = []\n        read_offsets = []\n\n        all_decoder_input_ids = []\n\n        next_token_choosers = []\n        stopping_criterias = []\n        top_n_tokens = []\n\n        max_input_length = 0\n        max_decoder_input_length = 0\n        padding_right_offset = 0\n\n        total_remaining_decode_tokens = 0\n\n        for i, request_id in enumerate(request_ids):\n            idx = self.requests_idx_mapping[request_id]\n            requests_idx_mapping[request_id] = i\n            keep_indices.append(idx)\n\n            requests.append(self.requests[idx])\n            prefix_offsets.append(self.prefix_offsets[idx])\n            read_offsets.append(self.read_offsets[idx])\n\n            all_decoder_input_ids.append(self.all_decoder_input_ids[idx])\n\n            request_input_length = self.input_lengths[idx]\n            input_lengths.append(request_input_length)\n            max_input_length = max(max_input_length, request_input_length)\n\n            request_decoder_input_length = self.decoder_input_lengths[idx]\n            decoder_input_lengths.append(request_decoder_input_length)\n            max_decoder_input_length = max(\n                max_decoder_input_length, request_decoder_input_length\n            )\n\n            next_token_choosers.append(self.next_token_choosers[idx])\n            
stopping_criteria = self.stopping_criterias[idx]\n            stopping_criterias.append(stopping_criteria)\n            top_n_tokens.append(self.top_n_tokens[idx])\n            remaining_decode_tokens = (\n                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens\n            )\n            total_remaining_decode_tokens += remaining_decode_tokens\n            padding_right_offset = max(padding_right_offset, remaining_decode_tokens)\n\n        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached\n        self.decoder_input_ids = self.decoder_input_ids[keep_indices]\n        self.attention_mask = self.attention_mask[keep_indices, -max_input_length:]\n        if self.decoder_attention_mask is not None:\n            self.decoder_attention_mask = self.decoder_attention_mask[\n                keep_indices,\n                -(self.padding_right_offset + max_decoder_input_length) : (\n                    self.decoder_attention_mask.shape[1] - self.padding_right_offset\n                )\n                + padding_right_offset,\n            ]\n\n        self.encoder_last_hidden_state = self.encoder_last_hidden_state[\n            keep_indices, -max_input_length:\n        ]\n\n        # Ensure that past_key_values tensors can be updated in-place\n        if type(self.past_key_values[0]) is tuple:\n            self.past_key_values = [\n                [t for t in layer] for layer in self.past_key_values\n            ]\n\n        decoder_past_seq_len = max_decoder_input_length - 1\n        for layer in self.past_key_values:\n            layer[0] = layer[0][keep_indices, :, -decoder_past_seq_len:]\n            layer[1] = layer[1][keep_indices, :, -decoder_past_seq_len:]\n            layer[2] = layer[2][keep_indices, :, -max_input_length:]\n            layer[3] = layer[3][keep_indices, :, -max_input_length:]\n\n        top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]\n        max_tokens = (\n   
         len(request_ids) * (max_input_length + max_decoder_input_length)\n            + total_remaining_decode_tokens\n        )\n\n        self.requests = requests\n        self.requests_idx_mapping = requests_idx_mapping\n        self.input_ids = None\n        self.all_decoder_input_ids = all_decoder_input_ids\n        self.input_lengths = input_lengths\n        self.decoder_input_lengths = decoder_input_lengths\n        self.prefix_offsets = prefix_offsets\n        self.read_offsets = read_offsets\n        self.next_token_choosers = next_token_choosers\n        self.stopping_criterias = stopping_criterias\n        self.top_n_tokens = top_n_tokens\n        self.top_n_tokens_tensor = top_n_tokens_tensor\n        self.max_input_length = max_input_length\n        self.max_decoder_input_length = max_decoder_input_length\n        self.padding_right_offset = padding_right_offset\n        self.max_tokens = max_tokens\n\n        return self\n\n    @classmethod\n    @tracer.start_as_current_span(\"concatenate\")\n    def concatenate(cls, batches: List[\"Seq2SeqLMBatch\"]) -> \"Seq2SeqLMBatch\":\n        \"\"\"Concatenate multiple batches together by padding internal torch tensors\"\"\"\n\n        # Used for padding\n        total_batch_size = 0\n        max_input_length = 0\n        max_decoder_input_length = 0\n        padding_right_offset = 0\n        for batch in batches:\n            total_batch_size += len(batch)\n            max_input_length = max(max_input_length, batch.max_input_length)\n            max_decoder_input_length = max(\n                max_decoder_input_length, batch.max_decoder_input_length\n            )\n            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)\n\n        # Batch attributes\n        requests = []\n        requests_idx_mapping = {}\n        all_decoder_input_ids = []\n        input_lengths = []\n        decoder_input_lengths = []\n        prefix_offsets = []\n        read_offsets = []\n        
next_token_choosers = []\n        stopping_criterias = []\n        top_n_tokens = []\n        max_tokens = 0\n\n        # Batch tensors\n        attention_mask = None\n        decoder_input_ids = None\n        decoder_attention_mask = None\n        encoder_last_hidden_state = None\n        top_n_tokens_tensor = None\n        past_key_values = []\n\n        # Used for slicing correctly inside the tensors\n        # Equivalent to a cumsum on batch sizes\n        start_index = 0\n\n        for i, batch in enumerate(batches):\n            # Extend all list attributes\n            requests.extend(batch.requests)\n            all_decoder_input_ids.extend(batch.all_decoder_input_ids)\n            input_lengths.extend(batch.input_lengths)\n            decoder_input_lengths.extend(batch.decoder_input_lengths)\n            prefix_offsets.extend(batch.prefix_offsets)\n            read_offsets.extend(batch.read_offsets)\n            next_token_choosers.extend(batch.next_token_choosers)\n            stopping_criterias.extend(batch.stopping_criterias)\n            top_n_tokens.extend(batch.top_n_tokens)\n\n            if i == 0:\n                requests_idx_mapping = batch.requests_idx_mapping\n            else:\n                # We need to offset the mapping for each batch by the cumulative batch size\n                for k, v in batch.requests_idx_mapping.items():\n                    requests_idx_mapping[k] = v + start_index\n\n            # Slicing end index for this batch\n            end_index = start_index + len(batch)\n\n            # We only concatenate batches that did at least one step\n            if batch.encoder_last_hidden_state is None:\n                raise ValueError(\"Batch encoder_last_hidden_state cannot be None\")\n\n            # Create padded tensor\n            if attention_mask is None:\n                attention_mask = batch.attention_mask.new_zeros(\n                    (total_batch_size, max_input_length),\n                )\n            # Copy to 
correct indices\n            attention_mask[start_index:end_index, -batch.max_input_length :] = (\n                batch.attention_mask[:, -batch.max_input_length :]\n            )\n\n            # Create padded tensor\n            if decoder_input_ids is None:\n                decoder_input_ids = batch.decoder_input_ids.new_zeros(\n                    (total_batch_size, 1),\n                )\n            # Copy to correct indices\n            decoder_input_ids[start_index:end_index] = batch.decoder_input_ids\n\n            # Create padded tensor\n            if decoder_attention_mask is None:\n                # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here\n                decoder_attention_mask = batch.attention_mask.new_zeros(\n                    (total_batch_size, max_decoder_input_length + padding_right_offset),\n                )\n            # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated\n            # this batch. 
All generations are of length `batch.max_decoder_input_length`.\n            left_offset = max_decoder_input_length - batch.max_decoder_input_length\n            if batch.decoder_attention_mask is None:\n                decoder_attention_mask[\n                    start_index:end_index,\n                    left_offset:-padding_right_offset,\n                ] = 1\n            # If it exists, we need to index\n            else:\n                batch_left_offset = (\n                    batch.decoder_attention_mask.shape[1]\n                    - batch.max_decoder_input_length\n                    - batch.padding_right_offset\n                )\n                decoder_attention_mask[\n                    start_index:end_index,\n                    left_offset:-padding_right_offset,\n                ] = batch.decoder_attention_mask[\n                    :,\n                    batch_left_offset : -batch.padding_right_offset,\n                ]\n\n            # Create padded tensor\n            if encoder_last_hidden_state is None:\n                encoder_last_hidden_state = batch.encoder_last_hidden_state.new_zeros(\n                    (\n                        total_batch_size,\n                        max_input_length,\n                        batch.encoder_last_hidden_state.shape[-1],\n                    ),\n                )\n\n            if top_n_tokens_tensor is None:\n                top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(\n                    total_batch_size,\n                )\n            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor\n\n            # Copy to correct indices\n            encoder_last_hidden_state[\n                start_index:end_index, -batch.max_input_length :, :\n            ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]\n            batch.encoder_last_hidden_state = None\n\n            # Ensure that we can update tensors in-place\n            if 
isinstance(batch.past_key_values[0], tuple):\n                batch.past_key_values = [\n                    [t for t in layer] for layer in batch.past_key_values\n                ]\n\n            # Add eventual padding tokens that were added while concatenating\n            max_tokens += batch.max_tokens + (\n                max_input_length\n                - batch.max_input_length\n                + max_decoder_input_length\n                - batch.max_decoder_input_length\n            ) * len(batch)\n\n            start_index = end_index\n\n        # Determine shapes for new past kv tensors\n        first_past_kvs = batches[0].past_key_values\n        _, num_heads, _, head_dim = first_past_kvs[0][0].shape\n\n        padded_dec_t_shape = (\n            total_batch_size,\n            num_heads,\n            (max_decoder_input_length - 1),\n            head_dim,\n        )\n\n        padded_enc_t_shape = (\n            total_batch_size,\n            num_heads,\n            max_input_length,\n            head_dim,\n        )\n\n        # Iterate over attention layers\n        for j in range(len(first_past_kvs)):\n            past_key_values.append([])\n\n            # Decoder past\n            for k in range(0, 2):\n                # Initialize tensors\n                padded_past_values = first_past_kvs[j][k].new_zeros(padded_dec_t_shape)\n                past_key_values[j].append(padded_past_values)\n\n                start_index = 0\n                for batch in batches:\n                    t = batch.past_key_values[j][k]\n                    # Clear reference to the original tensor\n                    batch.past_key_values[j][k] = None\n                    # Slicing end index for this batch\n                    end_index = start_index + len(batch)\n                    # We slice the past keys and values to remove the padding from previous batches\n                    past_seq_len = batch.max_decoder_input_length - 1\n                    
padded_past_values[start_index:end_index, :, -past_seq_len:, :] = t[\n                        :, :, -past_seq_len:, :\n                    ]\n                    del t\n\n                    start_index = end_index\n\n            # Encoder past\n            for k in range(2, 4):\n                # Initialize tensors\n                padded_past_values = first_past_kvs[j][k].new_zeros(padded_enc_t_shape)\n                past_key_values[j].append(padded_past_values)\n\n                start_index = 0\n                for batch in batches:\n                    t = batch.past_key_values[j][k]\n                    # Clear reference to the original tensor\n                    batch.past_key_values[j][k] = None\n                    # Slicing end index for this batch\n                    end_index = start_index + len(batch)\n                    # We slice the past keys and values to remove the padding from previous batches\n                    padded_past_values[\n                        start_index:end_index, :, -batch.max_input_length :, :\n                    ] = t[:, :, -batch.max_input_length :, :]\n                    del t\n\n                    start_index = end_index\n\n        return cls(\n            batch_id=batches[0].batch_id,\n            requests=requests,\n            requests_idx_mapping=requests_idx_mapping,\n            input_ids=None,\n            attention_mask=attention_mask,\n            decoder_input_ids=decoder_input_ids,\n            all_decoder_input_ids=all_decoder_input_ids,\n            decoder_attention_mask=decoder_attention_mask,\n            encoder_last_hidden_state=encoder_last_hidden_state,\n            past_key_values=past_key_values,\n            input_lengths=input_lengths,\n            decoder_input_lengths=decoder_input_lengths,\n            prefix_offsets=prefix_offsets,\n            read_offsets=read_offsets,\n            next_token_choosers=next_token_choosers,\n            stopping_criterias=stopping_criterias,\n            
top_n_tokens=top_n_tokens,\n            top_n_tokens_tensor=top_n_tokens_tensor,\n            max_input_length=max_input_length,\n            max_decoder_input_length=max_decoder_input_length,\n            padding_right_offset=padding_right_offset,\n            max_tokens=max_tokens,\n        )\n\n    def __len__(self):\n        return len(self.requests)\n\n\nclass Seq2SeqLM(Model):\n    def __init__(\n        self,\n        model_id: str,\n        model_class,\n        revision: Optional[str] = None,\n        quantize: Optional[str] = None,\n        speculator: Optional[str] = None,\n        dtype: Optional[torch.dtype] = None,\n        default_dtype=torch.float16,\n        trust_remote_code: bool = False,\n        config_class=AutoConfig,\n        tokenizer_class=AutoTokenizer,\n        aliases=None,\n    ):\n        self.quantize = quantize\n        self.process_group, rank, world_size = initialize_torch_distributed()\n\n        device = torch.device(\"hpu\")\n        dtype = torch.bfloat16 if dtype is None else dtype\n\n        config = config_class.from_pretrained(\n            model_id,\n            revision=revision,\n            trust_remote_code=trust_remote_code,\n        )\n        config.quantize = quantize\n        config.speculator = speculator\n\n        tokenizer = tokenizer_class.from_pretrained(\n            model_id,\n            revision=revision,\n            padding_side=\"left\",\n            truncation_side=\"left\",\n            trust_remote_code=trust_remote_code,\n        )\n        tokenizer.bos_token_id = config.decoder_start_token_id\n\n        weights_loader = get_loader(\n            quantize=quantize, model_id=model_id, revision=revision\n        )\n        torch.distributed.barrier(group=self.process_group)\n        filenames = weight_files(model_id, revision=revision, extension=\".safetensors\")\n        weights = Weights(\n            filenames,\n            device=device,\n            dtype=dtype,\n            
process_group=self.process_group,\n            aliases=aliases,\n            weights_loader=weights_loader,\n        )\n        if config.quantize in [\"awq\", \"gptq\"]:\n            weights._set_gptq_params(model_id, revision)\n\n        model = model_class(config, weights)\n\n        torch.distributed.barrier(group=self.process_group)\n        super().__init__(\n            model_id=model_id,\n            model=model,\n            tokenizer=tokenizer,\n            requires_padding=True,\n            dtype=dtype,\n            device=device,\n            rank=rank,\n            world_size=world_size,\n        )\n\n    @classmethod\n    def fallback(\n        cls,\n        model_id: str,\n        revision: Optional[str] = None,\n        quantize: Optional[str] = None,\n        speculator: Optional[str] = None,\n        dtype: Optional[torch.dtype] = None,\n        trust_remote_code: bool = False,\n    ):\n        if speculator:\n            raise RuntimeError(\"Speculator decoding is not enabled for AutoModel\")\n\n        if torch.cuda.is_available():\n            device = torch.device(\"cuda\")\n            dtype = torch.float16 if dtype is None else dtype\n        else:\n            if quantize:\n                raise ValueError(\"quantization is not available on CPU\")\n\n            device = torch.device(\"cpu\")\n            dtype = torch.float32 if dtype is None else dtype\n\n        model = AutoModelForSeq2SeqLM.from_pretrained(\n            model_id,\n            revision=revision,\n            torch_dtype=dtype,\n            device_map=(\n                \"auto\"\n                if torch.cuda.is_available() and torch.cuda.device_count() > 1\n                else None\n            ),\n            load_in_8bit=quantize == \"bitsandbytes\",\n            trust_remote_code=trust_remote_code,\n        )\n        if torch.cuda.is_available() and torch.cuda.device_count() == 1:\n            model = model.cuda()\n\n        tokenizer = 
AutoTokenizer.from_pretrained(\n            model_id,\n            revision=revision,\n            padding_side=\"left\",\n            truncation_side=\"left\",\n            trust_remote_code=trust_remote_code,\n        )\n        tokenizer.bos_token_id = model.config.decoder_start_token_id\n\n        self = cls.__new__(\n            cls,\n        )\n        super().__init__(\n            self,\n            model_id=model_id,\n            model=model,\n            tokenizer=tokenizer,\n            requires_padding=True,\n            dtype=dtype,\n            device=device,\n        )\n        self.quantize = quantize\n        return self\n\n    @property\n    def batch_type(self) -> Type[Seq2SeqLMBatch]:\n        return Seq2SeqLMBatch\n\n    def forward(\n        self,\n        input_ids,\n        attention_mask,\n        decoder_input_ids,\n        decoder_attention_mask: Optional,\n        encoder_last_hidden_state: Optional,\n        past_key_values: Optional = None,\n    ) -> Tuple[\n        torch.Tensor,\n        Optional[torch.Tensor],\n        torch.Tensor,\n        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],\n    ]:\n        # Model Forward\n        outputs = self.model.forward(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            decoder_input_ids=decoder_input_ids,\n            decoder_attention_mask=decoder_attention_mask,\n            encoder_outputs=encoder_last_hidden_state,\n            past_key_values=past_key_values,\n            use_cache=True,\n        )\n        if isinstance(outputs, tuple):\n            # Our custom models\n            outputs, speculative_logits = outputs\n        else:\n            # Generic transformers models\n            speculative_logits = None\n        return (\n            outputs.logits,\n            speculative_logits,\n            outputs.encoder_last_hidden_state,\n            outputs.past_key_values,\n        )\n\n    
@tracer.start_as_current_span(\"generate_token\")\n    def generate_token(\n        self, batch: Seq2SeqLMBatch\n    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch], Tuple[int, int]]:\n        start = time.time_ns()\n        if batch.decoder_attention_mask is not None:\n            # slice to the correct shape\n            decoder_attention_mask = batch.decoder_attention_mask[\n                :, : -batch.padding_right_offset\n            ]\n        else:\n            decoder_attention_mask = None\n\n        # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`\n        # internally...\n        if batch.encoder_last_hidden_state is not None:\n            encoder_last_hidden_state = [batch.encoder_last_hidden_state]\n        else:\n            encoder_last_hidden_state = None\n\n        logits, speculative_logits, encoder_last_hidden_state, past = self.forward(\n            batch.input_ids,\n            batch.attention_mask,\n            batch.decoder_input_ids,\n            decoder_attention_mask,\n            encoder_last_hidden_state,\n            batch.past_key_values,\n        )\n\n        # Speculation is not active for seq2seq\n        accepted_ids = torch.ones_like(batch.decoder_input_ids)[:, 0]\n        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(\n            batch.top_n_tokens,\n            batch.top_n_tokens_tensor,\n            torch.log_softmax(logits[:, -1], -1),\n            accepted_ids,\n        )\n\n        start_decode = time.time_ns()\n\n        # Finished requests\n        generations: List[Generation] = []\n        stopped = True\n\n        # Zipped iterator\n        iterator = zip(\n            batch.requests,\n            batch.input_lengths,\n            batch.prefix_offsets,\n            batch.read_offsets,\n            batch.decoder_input_lengths,\n            logits,\n            batch.next_token_choosers,\n            batch.stopping_criterias,\n          
  batch.all_decoder_input_ids,\n            batch.top_n_tokens,\n            batch_top_token_ids,\n            batch_top_token_logprobs,\n        )\n\n        # For each member of the batch\n        for i, (\n            request,\n            input_length,\n            prefix_offset,\n            read_offset,\n            decoder_input_length,\n            logits,\n            next_token_chooser,\n            stopping_criteria,\n            all_decoder_input_ids,\n            top_n_tokens,\n            top_token_ids,\n            top_token_logprobs,\n        ) in enumerate(iterator):\n            # Select next token\n            next_token_id, logprobs = next_token_chooser(\n                all_decoder_input_ids.view(1, -1), logits[-1:, :]\n            )\n\n            # Append next token to decoder tokens\n            all_decoder_input_ids = torch.cat(\n                [all_decoder_input_ids, next_token_id.squeeze(1)]\n            )\n            new_decoder_input_length = decoder_input_length + 1\n\n            # Generated token\n            next_token_logprob = logprobs[-1, next_token_id]\n            next_token_id_squeezed = next_token_id.squeeze()\n            next_token_text, prefix_offset, read_offset = self.decode_token(\n                all_decoder_input_ids, prefix_offset, read_offset\n            )\n\n            # Evaluate stopping criteria\n            stop, reason = stopping_criteria(next_token_id, next_token_text)\n\n            if not stop:\n                stopped = False\n\n            # Shard generations\n            # All generations will be appended in the rust sharded client\n            if i % self.world_size == self.rank:\n                if stop:\n                    # Slice with decoder_input_length to remove padding\n                    # Decode all tokens\n                    output_text, _, _ = self.decode_token(\n                        all_decoder_input_ids,\n                        prefix_offset=len(all_decoder_input_ids)\n            
            - decoder_input_length\n                        - 1,\n                        read_offset=len(all_decoder_input_ids) - decoder_input_length,\n                        skip_special_tokens=True,\n                    )\n\n                    # Get seed\n                    if isinstance(next_token_chooser.choice, Sampling):\n                        seed = next_token_chooser.choice.seed\n                    else:\n                        seed = None\n\n                    generated_text = GeneratedText(\n                        output_text, stopping_criteria.current_tokens, reason, seed\n                    )\n                else:\n                    generated_text = None\n\n                # Prefill\n                if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:\n                    prefill_tokens = Tokens(\n                        [self.tokenizer.bos_token_id],\n                        [float(\"nan\")],\n                        [self.tokenizer.bos_token],\n                        [False],\n                    )\n                else:\n                    prefill_tokens = None\n\n                if top_n_tokens > 0:\n                    all_top_tokens = []\n                    for top_token_ids, top_token_logprobs in zip(\n                        top_token_ids, top_token_logprobs\n                    ):\n                        toptoken_texts = self.tokenizer.batch_decode(\n                            top_token_ids,\n                            clean_up_tokenization_spaces=False,\n                            skip_special_tokens=False,\n                        )\n                        special_toptokens = [\n                            token_id in self.all_special_ids\n                            for token_id in top_token_ids\n                        ]\n                        top_tokens = Tokens(\n                            top_token_ids,\n                            top_token_logprobs,\n                            
toptoken_texts,\n                            special_toptokens,\n                        )\n                        all_top_tokens.append(top_tokens)\n                    top_tokens = all_top_tokens\n                else:\n                    top_tokens = None\n\n                generation = Generation(\n                    request.id,\n                    prefill_tokens,\n                    Tokens(\n                        [next_token_id_squeezed],\n                        [next_token_logprob],\n                        [next_token_text],\n                        [next_token_id_squeezed.item() in self.all_special_ids],\n                    ),\n                    generated_text,\n                    top_tokens,\n                )\n\n                generations.append(generation)\n\n            # Update values\n            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(\n                next_token_id_squeezed.item()\n            )\n            batch.decoder_input_ids[i] = next_token_id\n            batch.all_decoder_input_ids[i] = all_decoder_input_ids\n            batch.input_lengths[i] = input_length\n            batch.decoder_input_lengths[i] = new_decoder_input_length\n            batch.prefix_offsets[i] = prefix_offset\n            batch.read_offsets[i] = read_offset\n            batch.max_input_length = max(batch.max_input_length, input_length)\n            batch.max_decoder_input_length = max(\n                batch.max_decoder_input_length, new_decoder_input_length\n            )\n\n        # We finished all generations in the batch; there is no next batch\n        if stopped:\n            forward_ns = start_decode - start\n            decode_ns = time.time_ns() - start_decode\n            return generations, None, (forward_ns, decode_ns)\n\n        # We don't need input_ids after the prefill forward\n        batch.input_ids = None\n        batch.encoder_last_hidden_state = encoder_last_hidden_state\n        
batch.past_key_values = past\n        # Update decoder_attention_mask as we added a new token to input_ids\n        if batch.decoder_attention_mask is not None:\n            batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1\n        batch.padding_right_offset -= 1\n\n        forward_ns = start_decode - start\n        decode_ns = time.time_ns() - start_decode\n        return generations, batch, (forward_ns, decode_ns)\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/models/types.py",
    "content": "import torch\n\nfrom abc import ABC, abstractmethod\nfrom dataclasses import dataclass\nfrom typing import List, Optional\n\nfrom transformers import PreTrainedTokenizerBase\n\nfrom text_generation_server.pb import generate_pb2\nfrom text_generation_server.pb.generate_pb2 import FinishReason\n\n\nclass Batch(ABC):\n    @abstractmethod\n    def to_pb(self) -> generate_pb2.CachedBatch:\n        raise NotImplementedError\n\n    @classmethod\n    @abstractmethod\n    def from_pb(\n        cls,\n        pb: generate_pb2.Batch,\n        tokenizer: PreTrainedTokenizerBase,\n        dtype: torch.dtype,\n        device: torch.device,\n    ) -> \"Batch\":\n        raise NotImplementedError\n\n    @abstractmethod\n    def filter(self, request_ids: List[int]) -> \"Batch\":\n        raise NotImplementedError\n\n    @classmethod\n    @abstractmethod\n    def concatenate(cls, batches: List[\"Batch\"]) -> \"Batch\":\n        raise NotImplementedError\n\n    @abstractmethod\n    def __len__(self):\n        raise NotImplementedError\n\n\n@dataclass\nclass GeneratedText:\n    text: str\n    generated_tokens: int\n    finish_reason: FinishReason\n    seed: Optional[int]\n\n    def to_pb(self) -> generate_pb2.GeneratedText:\n        return generate_pb2.GeneratedText(\n            text=self.text,\n            generated_tokens=self.generated_tokens,\n            finish_reason=self.finish_reason,\n            seed=self.seed,\n        )\n\n\n@dataclass\nclass Tokens:\n    token_ids: List[int]\n    logprobs: List[float]\n    texts: List[str]\n    is_special: List[bool]\n\n    def to_pb(self) -> generate_pb2.Tokens:\n        return generate_pb2.Tokens(\n            ids=self.token_ids,\n            logprobs=self.logprobs,\n            texts=self.texts,\n            is_special=self.is_special,\n        )\n\n    def __len__(self):\n        return len(self.token_ids)\n\n\n@dataclass\nclass Generation:\n    request_id: int\n    prefill_tokens: Optional[Tokens]\n    tokens: 
Tokens\n    generated_text: Optional[GeneratedText]\n    # Optional for now, since it's not yet supported for every model.\n    top_tokens: Optional[List[Tokens]]\n\n    def to_pb(self) -> generate_pb2.Generation:\n        return generate_pb2.Generation(\n            request_id=self.request_id,\n            prefill_tokens=(\n                self.prefill_tokens.to_pb() if self.prefill_tokens is not None else None\n            ),\n            tokens=self.tokens.to_pb(),\n            generated_text=(\n                self.generated_text.to_pb() if self.generated_text is not None else None\n            ),\n            top_tokens=(\n                [top_tokens.to_pb() for top_tokens in self.top_tokens]\n                if self.top_tokens is not None\n                else None\n            ),\n        )\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/pb/.gitignore",
    "content": "*.py\n*.pyi\n*.py-e\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/server.py",
    "content": "# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.\n\nimport asyncio\nimport os\nimport torch\nimport time\nimport signal\n\nfrom grpc import aio\nfrom loguru import logger\n\nfrom grpc_reflection.v1alpha import reflection\nfrom pathlib import Path\nfrom typing import List, Optional\n\nfrom text_generation_server.cache import Cache\nfrom text_generation_server.interceptor import ExceptionInterceptor\nfrom text_generation_server.models import Model, get_model_with_lora_adapters\nfrom text_generation_server.pb import generate_pb2_grpc, generate_pb2\nfrom text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor\nfrom text_generation_server.models.globals import set_model_id, ATTENTION\nfrom text_generation_server.models.globals import set_adapter_to_index\nfrom text_generation_server.utils.adapter import AdapterInfo\nfrom text_generation_server.utils.tokens import make_tokenizer_optional\nfrom text_generation_server.utils.prefill_chunking import set_max_prefill_tokens\nfrom text_generation_server.models import VLM_BATCH_TYPES\n\nfrom text_generation_server.utils.version import (\n    is_driver_compatible,\n    MIN_TGI_GAUDI_SYNAPSE_VERSION,\n)\n\n\nclass SignalHandler:\n    KEEP_PROCESSING = True\n\n    def __init__(self):\n        signal.signal(signal.SIGINT, self.exit_gracefully)\n        signal.signal(signal.SIGTERM, self.exit_gracefully)\n\n    def exit_gracefully(self, signum, frame):\n        print(f\"Exiting gracefully: Signal {signum}\")\n        self.KEEP_PROCESSING = False\n\n\nclass TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):\n    def __init__(\n        self,\n        model: Model,\n        cache: Cache,\n        server_urls: List[str],\n    ):\n        self.cache = cache\n        self.model = model\n        # Quantize is resolved during model loading\n        self.quantize = model.quantize\n        self.server_urls = server_urls\n        # For some reason, inference_mode does not work well 
with GLOO which we use on CPU\n        # TODO: The inferecemode set messes up the autograd op dispatch. And results in aten::matmul\n        # op not optimized issue. Will investigate further.\n        # if model.device.type == \"hpu\":\n        # Force inference mode for the lifetime of TextGenerationService\n        # self._inference_mode_raii_guard = torch._C._InferenceMode(True)\n\n    async def Info(self, request, context):\n        return self.model.info\n\n    async def Health(self, request, context):\n        if self.model.device.type == \"hpu\":\n            torch.zeros((2, 2)).to(\"hpu\")\n        return generate_pb2.HealthResponse()\n\n    async def ServiceDiscovery(self, request, context):\n        return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)\n\n    async def ClearCache(self, request, context):\n        if request.HasField(\"id\"):\n            self.cache.delete(request.id)\n        else:\n            self.cache.clear()\n        return generate_pb2.ClearCacheResponse()\n\n    async def FilterBatch(self, request, context):\n        batch = self.cache.pop(request.batch_id)\n        if batch is None:\n            raise ValueError(f\"Batch ID {request.batch_id} not found in cache.\")\n        filtered_batch = batch.filter(request.request_ids)\n        self.cache.set(filtered_batch)\n\n        return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())\n\n    async def Warmup(self, request, context):\n        if ATTENTION == \"paged\":\n            set_max_prefill_tokens(request.max_prefill_tokens)\n            if (\n                self.model.batch_type in VLM_BATCH_TYPES\n            ):  # Hack, i would rather use kwargs in the `from_pb` call\n                batch = self.model.batch_type.from_pb_processor(\n                    request.batch,\n                    self.model.tokenizer,\n                    self.model.processor,\n                    self.model.model.config,\n                    self.model.dtype,\n             
       self.model.device,\n                )\n            else:\n                batch = self.model.batch_type.from_pb(\n                    request.batch,\n                    self.model.tokenizer,\n                    self.model.dtype,\n                    self.model.device,\n                )\n\n            # Override default values with None for clearer semantics.\n            max_input_tokens = (\n                request.max_input_tokens\n                if request.HasField(\"max_input_tokens\")\n                else None\n            )\n            max_total_tokens = (\n                request.max_total_tokens\n                if request.HasField(\"max_total_tokens\")\n                else None\n            )\n            max_supported_total_tokens, max_input_tokens, max_total_tokens = (\n                self.model.warmup(batch, max_input_tokens, max_total_tokens)\n            )\n        else:\n            max_supported_total_tokens, max_input_tokens, max_total_tokens = (\n                self.model.warmup(request)\n            )\n\n            # W/A for the skip tokenizer path\n            # We need to call make_tokenizer_optional after the warmup,\n            # because router is not aware of that feature\n            make_tokenizer_optional(self.model.tokenizer)\n\n        return generate_pb2.WarmupResponse(\n            max_supported_total_tokens=max_supported_total_tokens,\n            max_input_tokens=max_input_tokens,\n            max_total_tokens=max_total_tokens,\n        )\n\n    async def Prefill(self, request, context):\n        start = time.time_ns()\n        if (\n            self.model.batch_type in VLM_BATCH_TYPES\n        ):  # Hack, i would rather use kwargs in the `from_pb` call\n            batch = self.model.batch_type.from_pb_processor(\n                request.batch,\n                self.model.tokenizer,\n                self.model.processor,\n                self.model.model.config,\n                self.model.dtype,\n                
self.model.device,\n            )\n        else:\n            batch = self.model.batch_type.from_pb(\n                request.batch, self.model.tokenizer, self.model.dtype, self.model.device\n            )\n\n        generations, next_batch, timings = self.model.generate_token([batch])\n        self.cache.set(next_batch)\n\n        return generate_pb2.PrefillResponse(\n            generations=[generation.to_pb() for generation in generations],\n            batch=next_batch.to_pb() if next_batch else None,\n            forward_ns=timings[0],\n            decode_ns=timings[1],\n            total_ns=time.time_ns() - start,\n        )\n\n    async def Decode(self, request, context):\n        start = time.time_ns()\n        if len(request.batches) == 0:\n            raise ValueError(\"Must provide at least one batch\")\n\n        batches = []\n        for batch_pb in request.batches:\n            batch = self.cache.pop(batch_pb.id)\n            if batch is None:\n                raise ValueError(f\"Batch ID {batch_pb.id} not found in cache.\")\n            batches.append(batch)\n\n        if len(batches) == 0:\n            raise ValueError(\"All batches are empty\")\n\n        generations, next_batch, timings = self.model.generate_token(batches)\n        self.cache.set(next_batch)\n\n        return generate_pb2.DecodeResponse(\n            generations=[generation.to_pb() for generation in generations],\n            batch=next_batch.to_pb() if next_batch else None,\n            concat_ns=None,\n            forward_ns=timings[0],\n            decode_ns=timings[1],\n            total_ns=time.time_ns() - start,\n        )\n\n\ndef serve(\n    model_id: str,\n    lora_adapters: Optional[List[AdapterInfo]],\n    revision: Optional[str],\n    sharded: bool,\n    quantize: Optional[str],\n    speculate: Optional[int],\n    dtype: Optional[str],\n    kv_cache_dtype: Optional[str],\n    trust_remote_code: bool,\n    uds_path: Path,\n    max_input_tokens: int,\n):\n    async def 
serve_inner(\n        model_id: str,\n        lora_adapters: Optional[List[AdapterInfo]],\n        revision: Optional[str],\n        sharded: bool = False,\n        quantize: Optional[str] = None,\n        speculate: Optional[int] = None,\n        dtype: Optional[str] = None,\n        kv_cache_dtype: Optional[str] = None,\n        trust_remote_code: bool = False,\n    ):\n        if not is_driver_compatible():\n            logger.warning(\n                f\"Current Synapse version is lower than the minimum version supported: {MIN_TGI_GAUDI_SYNAPSE_VERSION}, this could result in failures\"\n            )\n\n        unix_socket_template = \"unix://{}-{}\"\n        adapter_to_index = {}\n        logger.info(\"Server:server_inner: sharded ={}\".format(sharded))\n\n        if sharded:\n            rank = int(os.environ[\"RANK\"])\n            logger.info(\"Server:server_inner: rank ={}\".format(rank))\n            server_urls = [\n                unix_socket_template.format(uds_path, rank)\n                for rank in range(int(os.environ[\"WORLD_SIZE\"]))\n            ]\n            local_url = server_urls[int(os.environ[\"RANK\"])]\n        else:\n            local_url = unix_socket_template.format(uds_path, 0)\n            server_urls = [local_url]\n\n        logger.info(\n            \"Server:server_inner: data type = {}, local_url = {}\".format(\n                dtype, local_url\n            )\n        )\n        if dtype == \"bfloat16\" or None:\n            data_type = torch.bfloat16\n        else:\n            data_type = torch.float\n        if revision == \"None\":\n            revision = None\n        try:\n            model = get_model_with_lora_adapters(\n                model_id,\n                lora_adapters,\n                revision,\n                sharded,\n                quantize,\n                speculate,\n                data_type,\n                kv_cache_dtype,\n                trust_remote_code,\n                max_input_tokens,\n        
        adapter_to_index,\n            )\n\n        except Exception:\n            logger.exception(\"Error when initializing model\")\n            raise\n\n        set_adapter_to_index(adapter_to_index)\n        server = aio.server(\n            interceptors=[\n                ExceptionInterceptor(),\n                UDSOpenTelemetryAioServerInterceptor(),\n            ],\n            options=[\n                # Set the maximum possible message length: i32::MAX\n                (\"grpc.max_receive_message_length\", (1 << 31) - 1)\n            ],\n        )\n        generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(\n            TextGenerationService(model, Cache(), server_urls), server\n        )\n        SERVICE_NAMES = (\n            generate_pb2.DESCRIPTOR.services_by_name[\"TextGenerationService\"].full_name,\n            reflection.SERVICE_NAME,\n        )\n        reflection.enable_server_reflection(SERVICE_NAMES, server)\n        server.add_insecure_port(local_url)\n\n        await server.start()\n\n        logger.info(\"Server started at {}\".format(local_url))\n        signal_handler = SignalHandler()\n        while signal_handler.KEEP_PROCESSING:\n            await asyncio.sleep(0.5)\n\n    set_model_id(model_id)\n    asyncio.run(\n        serve_inner(\n            model_id,\n            lora_adapters,\n            revision,\n            sharded,\n            quantize,\n            speculate,\n            dtype,\n            kv_cache_dtype,\n            trust_remote_code,\n        )\n    )\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/tracing.py",
    "content": "import grpc\n\nfrom opentelemetry import trace\nfrom opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter\nfrom opentelemetry.instrumentation.grpc._aio_server import (\n    OpenTelemetryAioServerInterceptor,\n)\nfrom opentelemetry.semconv.trace import SpanAttributes\nfrom opentelemetry.sdk.resources import Resource\nfrom opentelemetry.sdk.trace import TracerProvider\nfrom opentelemetry.sdk.trace.export import (\n    BatchSpanProcessor,\n)\n\n\nclass UDSOpenTelemetryAioServerInterceptor(OpenTelemetryAioServerInterceptor):\n    def __init__(self):\n        super().__init__(trace.get_tracer(__name__))\n\n    def _start_span(self, handler_call_details, context, set_status_on_exception=False):\n        \"\"\"\n        Rewrite _start_span method to support Unix Domain Socket gRPC contexts\n        \"\"\"\n\n        # standard attributes\n        attributes = {\n            SpanAttributes.RPC_SYSTEM: \"grpc\",\n            SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[0],\n        }\n\n        # if we have details about the call, split into service and method\n        if handler_call_details.method:\n            service, method = handler_call_details.method.lstrip(\"/\").split(\"/\", 1)\n            attributes.update(\n                {\n                    SpanAttributes.RPC_METHOD: method,\n                    SpanAttributes.RPC_SERVICE: service,\n                }\n            )\n\n        # add some attributes from the metadata\n        metadata = dict(context.invocation_metadata())\n        if \"user-agent\" in metadata:\n            attributes[\"rpc.user_agent\"] = metadata[\"user-agent\"]\n\n        # We use gRPC over a UNIX socket\n        attributes.update({SpanAttributes.NET_TRANSPORT: \"unix\"})\n\n        return self._tracer.start_as_current_span(\n            name=handler_call_details.method,\n            kind=trace.SpanKind.SERVER,\n            attributes=attributes,\n            
set_status_on_exception=set_status_on_exception,\n        )\n\n\ndef setup_tracing(otlp_service_name: str, otlp_endpoint: str):\n    resource = Resource.create(attributes={\"service.name\": otlp_service_name})\n    span_exporter = OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True)\n    span_processor = BatchSpanProcessor(span_exporter)\n\n    trace.set_tracer_provider(TracerProvider(resource=resource))\n    trace.get_tracer_provider().add_span_processor(span_processor)\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/utils/__init__.py",
    "content": "# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.\n\nfrom text_generation_server.utils.convert import convert_file, convert_files\nfrom text_generation_server.utils.dist import initialize_torch_distributed\nfrom text_generation_server.utils.weights import Weights\nfrom text_generation_server.utils.peft import download_and_unload_peft\nfrom text_generation_server.utils.hub import (\n    weight_files,\n    weight_hub_files,\n    download_weights,\n    EntryNotFoundError,\n    LocalEntryNotFoundError,\n    RevisionNotFoundError,\n)\nfrom text_generation_server.utils.tokens import (\n    NextTokenChooser,\n    HeterogeneousNextTokenChooser,\n    StoppingCriteria,\n    StopSequenceCriteria,\n    FinishReason,\n    Sampling,\n    Greedy,\n    make_tokenizer_optional,\n    is_tokenizer_transparent,\n    pad_next_token_chooser_parameters,\n)\n\n__all__ = [\n    \"convert_file\",\n    \"convert_files\",\n    \"initialize_torch_distributed\",\n    \"weight_files\",\n    \"weight_hub_files\",\n    \"download_weights\",\n    \"download_and_unload_peft\",\n    \"EntryNotFoundError\",\n    \"HeterogeneousNextTokenChooser\",\n    \"LocalEntryNotFoundError\",\n    \"RevisionNotFoundError\",\n    \"Greedy\",\n    \"NextTokenChooser\",\n    \"Sampling\",\n    \"StoppingCriteria\",\n    \"StopSequenceCriteria\",\n    \"FinishReason\",\n    \"Weights\",\n    \"make_tokenizer_optional\",\n    \"is_tokenizer_transparent\",\n    \"pad_next_token_chooser_parameters\",\n]\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/utils/adapter.py",
    "content": "# Origin:   https://github.com/predibase/lorax\n# Path:     lorax/server/lorax_server/utils/adapter.py\n# License:  Apache License Version 2.0, January 2004\n\nimport warnings\nimport re\nfrom dataclasses import dataclass\nfrom functools import lru_cache\nfrom typing import TYPE_CHECKING, Set, Tuple, Optional, List\n\nfrom safetensors.torch import load_file\nfrom transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer\n\nfrom text_generation_server.utils.merges.strategies import merge_adapters\n\nfrom text_generation_server.utils import hub\nfrom text_generation_server.adapters.lora import LoraConfig\n\n\nif TYPE_CHECKING:\n    from text_generation_server.adapters.config import AdapterConfig, ModuleMap\n\n\nBASE_MODEL_ADAPTER_ID = \"__base_model__\"\n\n\n@dataclass\nclass AdapterInfo:\n    id: str\n    path: Optional[str]\n    revision: Optional[str] = None\n\n\n@dataclass\nclass AdapterParameters:\n    adapter_info: Tuple[AdapterInfo]\n    weights: Tuple[float]\n    merge_strategy: NotImplemented\n    density: float\n    majority_sign_method: NotImplemented\n\n\n@dataclass\nclass AdapterSource:\n    adapter_id: str\n    model_id: str\n    revision: str\n\n\ndef parse_lora_adapters(lora_adapters: Optional[str]) -> List[AdapterInfo]:\n    if not lora_adapters:\n        return []\n\n    adapter_list = []\n    for adapter in lora_adapters.split(\",\"):\n        adapter = adapter.strip()\n        if adapter.count(\"=\") > 1 or adapter.count(\"@\") > 1:\n            raise ValueError(f\"Invalid LoRA adapter format: {adapter}\")\n        match = re.match(r\"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$\", adapter)\n\n        if match:\n            adapter_id, path, revision = match.groups()\n            adapter_list.append(\n                AdapterInfo(id=adapter_id, path=path, revision=revision)\n            )\n        else:\n            raise ValueError(f\"Invalid LoRA adapter format: {adapter}\")\n    return adapter_list\n\n\ndef 
load_and_merge_adapters(\n    model_id: str,\n    adapter_parameters: AdapterParameters,\n    adapter_index: int,\n    weight_names: Tuple[str],\n    trust_remote_code: bool = False,\n) -> Tuple[\"ModuleMap\", \"AdapterConfig\", Set[str], PreTrainedTokenizer]:\n    if len(adapter_parameters.adapter_info) == 1:\n        adapter = next(iter(adapter_parameters.adapter_info))\n        return load_module_map(\n            model_id,\n            adapter.revision,\n            adapter.id,\n            adapter.path,\n            weight_names,\n            trust_remote_code,\n        )\n\n    adapter_params = AdapterParametersContainer(adapter_parameters, adapter_index)\n    return _load_and_merge(\n        model_id,\n        adapter_params,\n        weight_names,\n        trust_remote_code,\n    )\n\n\n@dataclass\nclass AdapterParametersContainer:\n    adapter_parameters: AdapterParameters\n    adapter_index: int\n\n    def __hash__(self) -> int:\n        return self.adapter_index\n\n\n@lru_cache(maxsize=32)\ndef _load_and_merge(\n    model_id: str,\n    adapter_params: AdapterParametersContainer,\n    weight_names: Tuple[str],\n    trust_remote_code: bool = False,\n) -> Tuple[\"ModuleMap\", \"AdapterConfig\", Set[str], PreTrainedTokenizer]:\n    params = adapter_params.adapter_parameters\n\n    adapters_to_merge = []\n    merged_weight_names = set()\n    tokenizer = None\n    for adapter in params.adapter_info:\n        if adapter.id == BASE_MODEL_ADAPTER_ID:\n            raise ValueError(\"Base model adapter cannot be merged.\")\n\n        module_map, adapter_config, adapter_weight_names, adapter_tokenizer = (\n            load_module_map(\n                model_id,\n                adapter.revision,\n                adapter.id,\n                adapter.path,\n                weight_names,\n                trust_remote_code,\n            )\n        )\n\n        adapters_to_merge.append((module_map, adapter_config))\n        merged_weight_names = 
merged_weight_names.union(adapter_weight_names)\n        if tokenizer is None:\n            tokenizer = adapter_tokenizer\n\n    if len(adapters_to_merge) == 0:\n        raise ValueError(\"No adapters to merge.\")\n\n    module_map, adapter_config = merge_adapters(adapters_to_merge, params)\n    return module_map, adapter_config, merged_weight_names, tokenizer\n\n\ndef check_architectures(\n    model_id: str,\n    adapter_id: str,\n    adapter_config: \"AdapterConfig\",\n    trust_remote_code: bool = False,\n):\n    try:\n        if not adapter_config.base_model_name_or_path:\n            # Avoid execution latency caused by the network connection retrying for AutoConfig.from_pretrained(None)\n            return\n\n        expected_config = AutoConfig.from_pretrained(\n            model_id, trust_remote_code=trust_remote_code\n        )\n        model_config = AutoConfig.from_pretrained(\n            adapter_config.base_model_name_or_path, trust_remote_code=trust_remote_code\n        )\n    except Exception as e:\n        warnings.warn(\n            f\"Unable to check architecture compatibility for adapter '{adapter_id}' \"\n            f\"against model '{model_id}'. Assuming they are compatible. Error: {e}\"\n        )\n        return\n\n    if model_config.architectures == expected_config.architectures:\n        warnings.warn(\n            f\"Adapter '{adapter_id}' was not trained on base model '{model_id}'. \"\n            f\"If you encounter issues, use --model-id '{adapter_config.base_model_name_or_path}' instead.\"\n        )\n    else:\n        # TODO(travis): revisit this when we support clasification heads which will not use CausalLM\n        raise ValueError(\n            f\"Adapter '{adapter_id}' is not compatible with model '{model_id}'. \"\n            f\"Architectures differ: {model_config.architectures} != {expected_config.architectures}. 
\"\n            f\"Use --model-id '{adapter_config.base_model_name_or_path}' instead.\"\n        )\n\n\n@lru_cache(maxsize=128)\ndef load_module_map(\n    model_id: str,\n    revision: str,\n    adapter_id: str,\n    adapter_path: Optional[str],\n    weight_names: Tuple[str],\n    trust_remote_code: bool = False,\n) -> Tuple[\"ModuleMap\", \"AdapterConfig\", Set[str], PreTrainedTokenizer]:\n    adapter_config = LoraConfig.load(adapter_path or adapter_id, None)\n\n    if not adapter_path and adapter_config.base_model_name_or_path != model_id:\n        check_architectures(model_id, adapter_id, adapter_config, trust_remote_code)\n\n    adapter_filenames = (\n        hub._weight_files_from_dir(adapter_path, extension=\".safetensors\")\n        if adapter_path\n        else hub._cached_weight_files(\n            adapter_id, revision=revision, extension=\".safetensors\"\n        )\n    )\n\n    # throw an error if no adapter weights are found\n    if not adapter_filenames:\n        raise FileNotFoundError(\n            f\"No adapter weights found for adapter '{adapter_id}' and revision '{revision}'.\"\n        )\n\n    try:\n        adapter_tokenizer = AutoTokenizer.from_pretrained(\n            adapter_config.config_path,\n            trust_remote_code=trust_remote_code,\n        )\n    except Exception:\n        # Adapter does not have a tokenizer, so fallback to base model tokenizer\n        adapter_tokenizer = None\n\n    # load adapter weights from all shards (should have relatively small memory footprint)\n    adapter_weights = {}\n    for filename in adapter_filenames:\n        adapter_weights.update(load_file(filename))\n\n    # map the model weights to the relevant adapter weights (LoRA A and B matrices)\n    module_map, adapter_weight_names = adapter_config.map_weights_for_model(\n        adapter_weights, weight_names\n    )\n    return module_map, adapter_config, adapter_weight_names, adapter_tokenizer\n\n\ndef get_attn_weights(i, layer):\n    qkv = 
layer.self_attn.query_key_value\n    weights = {}\n\n    for k in [\"q\", \"k\", \"v\"]:\n        key = (i, f\"{k}_proj\")\n        value = (f\"model.layers.{i}.self_attn.{k}_proj\", qkv)\n        weights[key] = value\n\n    # also add the qkv_proj weight for the adapter\n    weights[(i, \"qkv_proj\")] = (\n        f\"model.layers.{i}.self_attn.qkv_proj\",\n        qkv,\n    )\n\n    weights[(i, \"o_proj\")] = (\n        f\"model.layers.{i}.self_attn.o_proj\",\n        layer.self_attn.o_proj,\n    )\n\n    return weights\n\n\ndef get_mlp_weights(i, layer):\n    weights = {}\n    if hasattr(layer, \"mlp\"):\n        mlp = layer.mlp\n        if hasattr(mlp, \"gate_up_proj\"):\n            # handle combined gate_up_proj (e.g., for some LLaMA variants)\n            weights.update(\n                {\n                    (i, \"gate_proj\"): (\n                        f\"model.layers.{i}.mlp.gate_proj\",\n                        mlp.gate_up_proj,\n                    ),\n                    (i, \"up_proj\"): (f\"model.layers.{i}.mlp.up_proj\", mlp.gate_up_proj),\n                }\n            )\n        else:\n            # handle separate gate_proj, up_proj, and down_proj (e.g., for Gemma)\n            if hasattr(mlp, \"gate_proj\"):\n                weights[(i, \"gate_proj\")] = (\n                    f\"model.layers.{i}.mlp.gate_proj\",\n                    mlp.gate_proj,\n                )\n            if hasattr(mlp, \"up_proj\"):\n                weights[(i, \"up_proj\")] = (f\"model.layers.{i}.mlp.up_proj\", mlp.up_proj)\n\n        if hasattr(mlp, \"down_proj\"):\n            weights[(i, \"down_proj\")] = (\n                f\"model.layers.{i}.mlp.down_proj\",\n                mlp.down_proj,\n            )\n\n    return weights\n\n\n# build_layer_weight_lookup creates a mapping of model layers to their corresponding\n# weight tensors and paths. 
It builds a dictionary that maps layer identifiers to tuples\n# containing the weight tensor path and the actual layer object. This mapping is needed\n# for the lora adapter to know which weights to update when applying the adapter.\ndef build_layer_weight_lookup(model):\n    if hasattr(model, \"language_model\"):\n        m = model.language_model.model\n    elif hasattr(model, \"text_model\"):\n        m = model.text_model.model\n    else:\n        m = model.model\n\n    layer_weights = {}\n\n    for i, layer in enumerate(m.layers):\n        attn_weights = get_attn_weights(i, layer)\n        mlp_weights = get_mlp_weights(i, layer)\n\n        layer_weights.update(attn_weights)\n        layer_weights.update(mlp_weights)\n\n    lm_head = None\n    if hasattr(m, \"lm_head\"):\n        lm_head = m.lm_head\n    elif hasattr(model, \"lm_head\"):\n        lm_head = model.lm_head\n\n    if lm_head:\n        layer_weights[(0, \"lm_head\")] = (\"lm_head\", lm_head)\n\n    return layer_weights\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/utils/chunks.py",
    "content": "from typing import Iterable\n\nfrom loguru import logger\n\nfrom text_generation_server.pb import generate_pb2\n\n\ndef concat_text_chunks(chunks: Iterable[generate_pb2.InputChunk]) -> str:\n    \"\"\"\n    Concatenate text in text chunks. Non-text chunks are dropped.\n    \"\"\"\n    text = None\n    for chunk in chunks:\n        chunk_type = chunk.WhichOneof(\"chunk\")\n        if chunk_type == \"text\":\n            if text is None:\n                text = chunk.text\n            else:\n                raise NotImplementedError(\"Request contained more than one text chunk\")\n        else:\n            # We cannot reject this, e.g. warmup sends an image chunk.\n            logger.debug(f\"Encountered non-text chunk type {chunk_type}\")\n\n    if text is None:\n        raise NotImplementedError(\"Request without a text chunk\")\n\n    return text\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/utils/convert.py",
    "content": "import datetime\nimport torch\nimport os\n\nfrom loguru import logger\nfrom pathlib import Path\nfrom safetensors.torch import save_file, load_file, _find_shared_tensors, _is_complete\nfrom typing import List, Dict\nfrom collections import defaultdict\n\n\ndef _remove_duplicate_names(\n    state_dict: Dict[str, torch.Tensor],\n    *,\n    preferred_names: List[str] = None,\n    discard_names: List[str] = None,\n) -> Dict[str, List[str]]:\n    if preferred_names is None:\n        preferred_names = []\n    preferred_names = set(preferred_names)\n    if discard_names is None:\n        discard_names = []\n    discard_names = set(discard_names)\n\n    shareds = _find_shared_tensors(state_dict)\n    to_remove = defaultdict(list)\n    for shared in shareds:\n        complete_names = set(\n            [name for name in shared if _is_complete(state_dict[name])]\n        )\n        if not complete_names:\n            if len(shared) == 1:\n                # Force contiguous\n                name = list(shared)[0]\n                state_dict[name] = state_dict[name].clone()\n                complete_names = {name}\n            else:\n                raise RuntimeError(\n                    f\"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. 
Or open an issue.\"\n                )\n\n        keep_name = sorted(list(complete_names))[0]\n\n        # Mecanism to preferentially select keys to keep\n        # coming from the on-disk file to allow\n        # loading models saved with a different choice\n        # of keep_name\n        preferred = complete_names.difference(discard_names)\n        if preferred:\n            keep_name = sorted(list(preferred))[0]\n\n        if preferred_names:\n            preferred = preferred_names.intersection(complete_names)\n            if preferred:\n                keep_name = sorted(list(preferred))[0]\n        for name in sorted(shared):\n            if name != keep_name:\n                to_remove[keep_name].append(name)\n    return to_remove\n\n\ndef convert_file(pt_file: Path, sf_file: Path, discard_names: List[str]):\n    \"\"\"\n    Convert a pytorch file to a safetensors file\n    This will remove duplicate tensors from the file.\n\n    Unfortunately, this might not respect *transformers* convention.\n    Forcing us to check for potentially different keys during load when looking\n    for specific tensors (making tensor sharing explicit).\n    \"\"\"\n    loaded = torch.load(pt_file, map_location=\"cpu\", weights_only=True)\n    if \"state_dict\" in loaded:\n        loaded = loaded[\"state_dict\"]\n    to_removes = _remove_duplicate_names(loaded, discard_names=discard_names)\n\n    metadata = {\"format\": \"pt\"}\n    for kept_name, to_remove_group in to_removes.items():\n        for to_remove in to_remove_group:\n            if to_remove not in metadata:\n                metadata[to_remove] = kept_name\n            del loaded[to_remove]\n    # Force tensors to be contiguous\n    loaded = {k: v.contiguous() for k, v in loaded.items()}\n\n    dirname = os.path.dirname(sf_file)\n    os.makedirs(dirname, exist_ok=True)\n    save_file(loaded, sf_file, metadata=metadata)\n    reloaded = load_file(sf_file)\n    for k in loaded:\n        pt_tensor = loaded[k]\n        
sf_tensor = reloaded[k]\n        if not torch.equal(pt_tensor, sf_tensor):\n            raise RuntimeError(f\"The output tensors do not match for key {k}\")\n\n\ndef convert_files(pt_files: List[Path], sf_files: List[Path], discard_names: List[str]):\n    assert len(pt_files) == len(sf_files)\n\n    N = len(pt_files)\n    # We do this instead of using tqdm because we want to parse the logs with the launcher\n\n    for i, (pt_file, sf_file) in enumerate(zip(pt_files, sf_files)):\n        # Skip blacklisted files\n        if (\n            \"arguments\" in pt_file.name\n            or \"args\" in pt_file.name\n            or \"training\" in pt_file.name\n        ):\n            continue\n\n        start = datetime.datetime.now()\n        convert_file(pt_file, sf_file, discard_names)\n        elapsed = datetime.datetime.now() - start\n        logger.info(f\"Convert: [{i + 1}/{N}] -- Took: {elapsed}\")\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/utils/debug.py",
    "content": "# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.\n\nimport os\nimport glob\nimport time\n\nimport habana_frameworks.torch as htorch\nimport numpy as np\n\nSTART_TS = None\nDBG_TRACE_FILENAME = os.environ.get(\"DBG_TRACE_FILENAME\")\nif \"GRAPH_VISUALIZATION\" in os.environ:\n    for f in glob.glob(\".graph_dumps/*\"):\n        os.remove(f)\n\n\ndef to_gb_rounded(mem: float) -> float:\n    \"\"\"\n    Rounds and converts to GB.\n\n    Args:\n        mem (float): memory in bytes\n\n    Returns:\n        float: memory in GB rounded to the second decimal\n    \"\"\"\n    return np.round(mem / 1024**3, 2)\n\n\ndef count_hpu_graphs():\n    return len(glob.glob(\".graph_dumps/*PreGraph*\"))\n\n\ndef dbg_trace(tag, txt):\n    global START_TS\n    if DBG_TRACE_FILENAME is not None and int(os.getenv(\"RANK\", 0)) == 0:\n        if START_TS is None:\n            START_TS = time.perf_counter()\n        time_offset = time.perf_counter() - START_TS\n        mem_stats = htorch.hpu.memory.memory_stats()\n        mem_used = to_gb_rounded(mem_stats[\"InUse\"])\n        max_mem_used = to_gb_rounded(mem_stats[\"MaxInUse\"])\n        print(\n            f\"ts:{time_offset:.3f}s g:{count_hpu_graphs()} mu:{mem_used:.1f}GB \"\n            f\"mmu:{max_mem_used:.1f}GB | {tag} | {txt}\",\n            flush=True,\n            file=open(DBG_TRACE_FILENAME, \"a\"),\n        )\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/utils/dist.py",
    "content": "import os\nimport torch\nfrom torch.distributed import ProcessGroup\nfrom datetime import timedelta\nfrom loguru import logger\n\n# Tensor Parallelism settings\nRANK = int(os.getenv(\"RANK\", \"0\"))\nWORLD_SIZE = int(os.getenv(\"WORLD_SIZE\", \"1\"))\nMEMORY_FRACTION = float(os.getenv(\"HPU_MEMORY_FRACTION\", \"0.8\"))\n\n\nclass FakeBarrier:\n    def wait(self):\n        pass\n\n\nclass FakeGroup(ProcessGroup):\n    def __init__(self, rank, size):\n        self._rank = rank\n        self._size = size\n        super().__init__(rank, size)\n\n    def allreduce(self, *args, **kwargs):\n        return FakeBarrier()\n\n    def allgather(self, inputs, local_tensor, **kwargs):\n        assert (\n            len(inputs[0]) == len(local_tensor) == 1\n        ), f\"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors\"\n        for input_ in inputs:\n            input_[0].data = local_tensor[0].data\n        return FakeBarrier()\n\n    def barrier(self, *args, **kwargs):\n        return FakeBarrier()\n\n    def size(self):\n        return self._size\n\n    def rank(self):\n        return self._rank\n\n    def _get_backend_name(self):\n        return \"fake\"\n\n\ndef initialize_torch_distributed():\n    if WORLD_SIZE == 1:\n        return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE\n    else:\n        if os.getenv(\"DEBUG\", None) == \"1\":\n            return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE\n\n        if not torch.distributed.is_initialized():\n            # Call the init process.\n            torch.distributed.init_process_group(\n                backend=\"hccl\",\n                world_size=WORLD_SIZE,\n                rank=RANK,\n                timeout=timedelta(seconds=120),\n            )\n        else:\n            logger.warning(\"torch.distributed is already initialized.\")\n\n        return torch.distributed.group.WORLD, RANK, WORLD_SIZE\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/utils/hub.py",
    "content": "import time\nimport os\n\nfrom datetime import timedelta\nfrom loguru import logger\nfrom pathlib import Path\nfrom typing import Optional, List\n\nfrom huggingface_hub import file_download, hf_api, HfApi, hf_hub_download\nfrom huggingface_hub.constants import HUGGINGFACE_HUB_CACHE\nfrom huggingface_hub.utils import (\n    LocalEntryNotFoundError,\n    EntryNotFoundError,\n    RevisionNotFoundError,  # noqa # Import here to ease try/except in other part of the lib\n)\n\nWEIGHTS_CACHE_OVERRIDE = os.getenv(\"WEIGHTS_CACHE_OVERRIDE\", None)\nHF_HUB_OFFLINE = os.environ.get(\"HF_HUB_OFFLINE\", \"0\").lower() in [\"true\", \"1\", \"yes\"]\n\n\ndef _cached_weight_files(\n    model_id: str, revision: Optional[str], extension: str\n) -> List[str]:\n    \"\"\"Guess weight files from the cached revision snapshot directory\"\"\"\n    d = _get_cached_revision_directory(model_id, revision)\n    if not d:\n        return []\n    filenames = _weight_files_from_dir(d, extension)\n    return filenames\n\n\ndef _weight_hub_files_from_model_info(\n    info: hf_api.ModelInfo, extension: str\n) -> List[str]:\n    return [\n        s.rfilename\n        for s in info.siblings\n        if s.rfilename.endswith(extension)\n        and len(s.rfilename.split(\"/\")) == 1\n        and \"arguments\" not in s.rfilename\n        and \"args\" not in s.rfilename\n        and \"training\" not in s.rfilename\n    ]\n\n\ndef _weight_files_from_dir(d: Path, extension: str) -> List[str]:\n    # os.walk: do not iterate, just scan for depth 1, not recursively\n    # see _weight_hub_files_from_model_info, that's also what is\n    # done there with the len(s.rfilename.split(\"/\")) == 1 condition\n    root, _, files = next(os.walk(str(d)))\n    filenames = [\n        os.path.join(root, f)\n        for f in files\n        if f.endswith(extension)\n        and \"arguments\" not in f\n        and \"args\" not in f\n        and \"training\" not in f\n    ]\n    return filenames\n\n\ndef 
_get_cached_revision_directory(\n    model_id: str, revision: Optional[str]\n) -> Optional[Path]:\n    if revision is None:\n        revision = \"main\"\n\n    repo_cache = Path(HUGGINGFACE_HUB_CACHE) / Path(\n        file_download.repo_folder_name(repo_id=model_id, repo_type=\"model\")\n    )\n\n    if not repo_cache.is_dir():\n        # No cache for this model\n        return None\n\n    refs_dir = repo_cache / \"refs\"\n    snapshots_dir = repo_cache / \"snapshots\"\n\n    # Resolve refs (for instance to convert main to the associated commit sha)\n    if refs_dir.is_dir():\n        revision_file = refs_dir / revision\n        if revision_file.exists():\n            with revision_file.open() as f:\n                revision = f.read()\n\n    # Check if revision folder exists\n    if not snapshots_dir.exists():\n        return None\n    cached_shas = os.listdir(snapshots_dir)\n    if revision not in cached_shas:\n        # No cache for this revision and we won't try to return a random revision\n        return None\n\n    return snapshots_dir / revision\n\n\ndef weight_hub_files(\n    model_id: str, revision: Optional[str] = None, extension: str = \".safetensors\"\n) -> List[str]:\n    \"\"\"Get the weights filenames on the hub\"\"\"\n    api = HfApi()\n\n    if HF_HUB_OFFLINE:\n        filenames = _cached_weight_files(model_id, revision, extension)\n    else:\n        # Online case, fetch model info from the Hub\n        info = api.model_info(model_id, revision=revision)\n        filenames = _weight_hub_files_from_model_info(info, extension)\n\n    if not filenames:\n        raise EntryNotFoundError(\n            f\"No {extension} weights found for model {model_id} and revision {revision}.\",\n            None,\n        )\n\n    return filenames\n\n\ndef try_to_load_from_cache(\n    model_id: str, revision: Optional[str], filename: str\n) -> Optional[Path]:\n    \"\"\"Try to load a file from the Hugging Face cache\"\"\"\n\n    d = 
_get_cached_revision_directory(model_id, revision)\n    if not d:\n        return None\n\n    # Check if file exists in cache\n    cached_file = d / filename\n    return cached_file if cached_file.is_file() else None\n\n\ndef weight_files(\n    model_id: str, revision: Optional[str] = None, extension: str = \".safetensors\"\n) -> List[Path]:\n    \"\"\"Get the local files\"\"\"\n    # Local model\n    d = Path(model_id)\n    if d.exists() and d.is_dir():\n        local_files = _weight_files_from_dir(d, extension)\n        if not local_files:\n            raise FileNotFoundError(\n                f\"No local weights found in {model_id} with extension {extension}\"\n            )\n        return [Path(f) for f in local_files]\n\n    try:\n        filenames = weight_hub_files(model_id, revision, extension)\n    except EntryNotFoundError as e:\n        if extension != \".safetensors\":\n            raise e\n        # Try to see if there are pytorch weights\n        pt_filenames = weight_hub_files(model_id, revision, extension=\".bin\")\n        # Change pytorch extension to safetensors extension\n        # It is possible that we have safetensors weights locally even though they are not on the\n        # hub if we converted weights locally without pushing them\n        filenames = [\n            f\"{Path(f).stem.lstrip('pytorch_')}.safetensors\" for f in pt_filenames\n        ]\n\n    if WEIGHTS_CACHE_OVERRIDE is not None:\n        files = []\n        for filename in filenames:\n            p = Path(WEIGHTS_CACHE_OVERRIDE) / filename\n            if not p.exists():\n                raise FileNotFoundError(\n                    f\"File {p} not found in {WEIGHTS_CACHE_OVERRIDE}.\"\n                )\n            files.append(p)\n        return files\n\n    files = []\n    for filename in filenames:\n        cache_file = try_to_load_from_cache(\n            model_id, revision=revision, filename=filename\n        )\n        if cache_file is None:\n            raise 
LocalEntryNotFoundError(\n                f\"File {filename} of model {model_id} not found in \"\n                f\"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. \"\n                f\"Please run `text-generation-server download-weights {model_id}` first.\"\n            )\n        files.append(cache_file)\n\n    return files\n\n\ndef download_weights(\n    filenames: List[str], model_id: str, revision: Optional[str] = None\n) -> List[Path]:\n    \"\"\"Download the safetensors files from the hub\"\"\"\n\n    def download_file(fname, tries=5, backoff: int = 5):\n        local_file = try_to_load_from_cache(model_id, revision, fname)\n        if local_file is not None:\n            logger.info(f\"File {fname} already present in cache.\")\n            return Path(local_file)\n\n        for idx in range(tries):\n            try:\n                logger.info(f\"Download file: {fname}\")\n                stime = time.time()\n                local_file = hf_hub_download(\n                    filename=fname,\n                    repo_id=model_id,\n                    revision=revision,\n                    local_files_only=HF_HUB_OFFLINE,\n                )\n                logger.info(\n                    f\"Downloaded {local_file} in {timedelta(seconds=int(time.time() - stime))}.\"\n                )\n                return Path(local_file)\n            except Exception as e:\n                if idx + 1 == tries:\n                    raise e\n                logger.error(e)\n                logger.info(f\"Retrying in {backoff} seconds\")\n                time.sleep(backoff)\n                logger.info(f\"Retry {idx + 1}/{tries - 1}\")\n\n    # We do this instead of using tqdm because we want to parse the logs with the launcher\n    start_time = time.time()\n    files = []\n    for i, filename in enumerate(filenames):\n        file = download_file(filename)\n\n        elapsed = timedelta(seconds=int(time.time() - start_time))\n        remaining = 
len(filenames) - (i + 1)\n        eta = (elapsed / (i + 1)) * remaining if remaining > 0 else 0\n\n        logger.info(f\"Download: [{i + 1}/{len(filenames)}] -- ETA: {eta}\")\n        files.append(file)\n\n    return files\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/utils/import_utils.py",
    "content": "import torch\n\n\ndef get_hpu_free_memory(device, memory_fraction):\n    free_hpu_memory, _ = torch.hpu.mem_get_info()\n    return free_hpu_memory\n\n\ndef synchronize_hpu(device):\n    torch.hpu.synchronize()\n\n\ndef noop(*args, **kwargs):\n    pass\n\n\nempty_cache = noop\nsynchronize = synchronize_hpu\nget_free_memory = get_hpu_free_memory\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/utils/kernels.py",
    "content": "import importlib\n\nfrom loguru import logger\nfrom hf_kernels import load_kernel as hf_load_kernel\n\nfrom text_generation_server.utils.log import log_once\n\n\ndef load_kernel(*, module: str, repo_id: str):\n    \"\"\"\n    Load a kernel. First try to load it as the given module (e.g. for\n    local development), falling back to a locked Hub kernel.\n    \"\"\"\n    try:\n        m = importlib.import_module(module)\n        log_once(logger.info, f\"Using local module for `{module}`\")\n        return m\n    except ModuleNotFoundError:\n        return hf_load_kernel(repo_id=repo_id)\n\n\n__all__ = [\"load_kernel\"]\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/utils/log.py",
    "content": "from functools import lru_cache\nfrom text_generation_server.utils.dist import RANK\n\n\n@lru_cache(10)\ndef log_once(log, msg: str, master=True):\n    if master:\n        log_master(log, msg)\n    else:\n        log(msg)\n\n\ndef log_master(log, msg: str):\n    if RANK == 0:\n        log(msg)\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/utils/logits_process.py",
    "content": "import math\nimport torch\nimport habana_frameworks.torch.core as htcore\n\nfrom loguru import logger\nfrom typing import Dict\nfrom text_generation_server.pb.generate_pb2 import GrammarType\n\nfrom outlines.fsm.fsm import RegexFSM\nfrom outlines.fsm.json_schema import build_regex_from_schema\nfrom functools import lru_cache\nfrom typing import List, Optional, DefaultDict\nimport time\n\nfrom transformers import (\n    LogitsProcessor,\n    TemperatureLogitsWarper,\n    TopKLogitsWarper,\n    TopPLogitsWarper,\n    TypicalLogitsWarper,\n)\n\nmempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None\n\n\nclass StaticWarper:\n    def __init__(\n        self,\n        temperature=1.0,\n        top_k=None,\n        top_p=None,\n        typical_p=None,\n    ):\n        self.warpers = []\n\n        if temperature is not None and temperature != 1.0:\n            temperature = float(temperature)\n            self.warpers.append(TemperatureLogitsWarper(temperature))\n        if top_k is not None and top_k != 0:\n            self.warpers.append(TopKLogitsWarper(top_k=top_k))\n        if top_p is not None and top_p < 1.0:\n            self.warpers.append(TopPLogitsWarper(top_p=top_p))\n        if typical_p is not None and typical_p < 1.0:\n            self.warpers.append(TypicalLogitsWarper(mass=typical_p))\n\n        self.hpu_graph = None\n        self.static_scores = None\n        self.static_warped_scores = None\n        self.static_next_logprob = None\n\n    def __call__(self, scores):\n        if self.hpu_graph is None:\n            self.static_scores = scores.clone().contiguous()\n            self.static_warped_scores = scores.clone().contiguous()\n            self.static_next_logprob = scores.clone().contiguous()\n            self.hpu_graph = htcore.hpu.HPUGraph()\n\n            with htcore.hpu.graph(self.hpu_graph):\n                local_scores = self.static_scores\n                for warper in self.warpers:\n                    
local_scores = warper(None, local_scores)\n\n                self.static_warped_scores.copy_(local_scores)\n                # Compute logprobs\n                self.static_next_logprob.copy_(\n                    torch.log_softmax(self.static_warped_scores, -1)\n                )\n\n        self.static_scores.copy_(scores)\n        self.hpu_graph.replay()\n\n        return self.static_warped_scores, self.static_next_logprob\n\n\n@lru_cache(10)\ndef static_warper(\n    temperature: Optional[float],\n    top_k: Optional[int],\n    top_p: Optional[float],\n    typical_p: Optional[float],\n) -> StaticWarper:\n    return StaticWarper(\n        temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p\n    )\n\n\nclass HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):\n    r\"\"\"\n    [`LogitsProcessor`] enforcing an exponential penalty on repeated sequences.\n    This version allows for a separate value for each sample and runs inplace when possible.\n    It doesn't validate inputs.\n\n    Args:\n        repetition_penalty (`List[float]`):\n            The parameter for repetition penalty. 1.0 means no penalty. 
See [this\n            paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.\n    \"\"\"\n\n    def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):\n        self.penalty = penalty\n        self.penalty_tensor = torch.tensor(\n            penalty, dtype=dtype, device=device\n        ).unsqueeze(1)\n\n    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:\n        score = torch.gather(scores, 1, input_ids)\n\n        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability\n        score = torch.where(\n            score < 0, score * self.penalty_tensor, score / self.penalty_tensor\n        )\n\n        scores.scatter_(1, input_ids, score)\n        return scores\n\n    def filter(self, indices):\n        self.penalty = [self.penalty[i] for i in indices]\n        if any([x != 1.0 for x in self.penalty]):\n            self.penalty_tensor = self.penalty_tensor[indices]\n            return self\n        return None\n\n\nclass FrequencyPenaltyLogitsProcessor(LogitsProcessor):\n    r\"\"\"\n    Frequency penalty as defined by OpenAI\n\n    Args:\n        penalty (`float`):\n            The parameter for frequency penalty. 
0.0 means no penalty.\n    \"\"\"\n\n    def __init__(self, penalty: float):\n        self.penalty = penalty\n\n    def __call__(\n        self, input_ids: torch.LongTensor, scores: torch.FloatTensor\n    ) -> torch.FloatTensor:\n        score = torch.gather(scores, 1, input_ids)\n        # if score < 0 then penalty has to be multiplied to reduce the previous token probability\n        score = -torch.where(score < 0, score * self.penalty, score / self.penalty)\n        # set score to 0 where input_ids is a padding token\n        score *= input_ids.ne(0)\n\n        return scores.scatter_add_(1, input_ids, score)\n\n\nclass HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):\n    r\"\"\"\n    Frequency penalty as defined by OpenAI in\n    https://platform.openai.com/docs/guides/text-generation/parameter-details\n\n    Args:\n        frequency_penalty (`List[float]`):\n            The parameter for frequency penalty. 0.0 means no penalty.\n    \"\"\"\n\n    def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):\n        self.penalty = penalty\n        self.penalty_tensor = torch.tensor(\n            penalty, dtype=dtype, device=device\n        ).unsqueeze(1)\n\n    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:\n        batch_size, input_size = input_ids.size()\n        vocab_size = scores.size(1)\n\n        # Calculate the frequency for each token so far\n        token_freq = torch.zeros(\n            batch_size, vocab_size, dtype=scores.dtype, device=scores.device\n        )\n        token_freq.scatter_add_(\n            1,\n            input_ids,\n            torch.ones_like(input_ids, dtype=scores.dtype, device=scores.device),\n        )\n        token_freq /= input_size\n\n        # Apply the frequency penalty to logits\n        scores -= token_freq * self.penalty_tensor\n        return scores\n\n    def filter(self, indices):\n        self.penalty = [self.penalty[i] for i in indices]\n   
     if any([x != 0.0 for x in self.penalty]):\n            self.penalty_tensor = self.penalty_tensor[indices]\n            return self\n        return None\n\n\nclass HeterogeneousTemperatureLogitsWarper:\n    r\"\"\"\n    [`LogitsProcessor`] for temperature (exponential scaling output probability distribution).\n    This version allows for a separate value for each sample and runs inplace when possible.\n    It doesn't validate inputs.\n\n    Args:\n        temperature (`float`):\n            The value used to module the logits distribution.\n    \"\"\"\n\n    def __init__(\n        self, temperature: List[float], dtype: torch.dtype, device: torch.device\n    ):\n        self.temperature = temperature\n        self.temperature_tensor = torch.tensor(\n            temperature, dtype=dtype, device=device\n        ).unsqueeze(1)\n\n    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:\n        scores.div_(self.temperature_tensor)\n        return scores\n\n    def filter(self, indices):\n        self.temperature = [self.temperature[i] for i in indices]\n        if any([x != 1.0 for x in self.temperature]):\n            self.temperature_tensor = self.temperature_tensor[indices]\n            return self\n        return None\n\n\nclass HeterogeneousTopPLogitsWarper(LogitsProcessor):\n    \"\"\"\n    [`LogitsWarper`] that performs top-p, i.e. 
restricting to top tokens summing to prob_cut_off <= prob_cut_off.\n    This version allows for a separate value for each sample and runs inplace when possible.\n    It doesn't validate inputs.\n\n    Args:\n        top_p (`float`):\n            If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or\n            higher are kept for generation.\n        filter_value (`float`, *optional*, defaults to `-float(\"Inf\")`):\n            All filtered values will be set to this float value.\n        min_tokens_to_keep (`int`, *optional*, defaults to 1):\n            Minimum number of tokens that cannot be filtered.\n    \"\"\"\n\n    def __init__(\n        self,\n        top_p: List[float],\n        dtype: torch.dtype,\n        device: torch.device,\n        filter_value: float = -math.inf,\n        min_tokens_to_keep: int = 1,\n    ):\n        self.top_p = top_p\n        self.top_p_opposite = 1 - torch.tensor(\n            top_p, dtype=dtype, device=device\n        ).unsqueeze(1)\n        self.filter_value = filter_value\n        self.min_tokens_to_keep = min_tokens_to_keep\n\n    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:\n        sorted_logits, sorted_indices = torch.sort(scores, descending=False)\n        probs = sorted_logits.softmax(dim=-1)\n        # This is way faster for some reason\n        for i in range(probs.shape[0]):\n            probs[i] = probs[i].cumsum(dim=-1)\n\n        # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)\n        sorted_indices_to_remove = probs <= self.top_p_opposite\n        # Keep at least min_tokens_to_keep\n        sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0\n\n        # scatter sorted tensors to original indexing\n        indices_to_remove = sorted_indices_to_remove.scatter(\n            1, sorted_indices, sorted_indices_to_remove\n        )\n        warped_scores = 
scores.masked_fill_(indices_to_remove, self.filter_value)\n\n        return warped_scores\n\n    def filter(self, indices):\n        self.top_p = [self.top_p[i] for i in indices]\n        if any([x < 1.0 for x in self.top_p]):\n            self.top_p_opposite = self.top_p_opposite[indices]\n            return self\n        return None\n\n\nclass HeterogeneousTopKLogitsWarper(LogitsProcessor):\n    r\"\"\"\n    [`LogitsProcessor`] that performs top-k, i.e. restricting to the k highest probability elements.\n    This version allows for a separate value for each sample and runs inplace when possible.\n    It doesn't validate inputs.\n\n    Args:\n        top_k (`int`):\n            The number of highest probability vocabulary tokens to keep for top-k-filtering.\n        filter_value (`float`, *optional*, defaults to `-float(\"Inf\")`):\n            All filtered values will be set to this float value.\n        min_tokens_to_keep (`int`, *optional*, defaults to 1):\n            Minimum number of tokens that cannot be filtered.\n    \"\"\"\n\n    def __init__(\n        self,\n        top_k: List[int],\n        device: torch.device,\n        filter_value: float = -math.inf,\n        min_tokens_to_keep: int = 1,\n    ):\n        self.top_k = top_k\n        self.max_top_k = max(top_k)\n        # value - 1 as we will use top_k to index and python uses 0 based numbering\n        self.top_k_tensor = torch.tensor(\n            [max(x - 1, min_tokens_to_keep - 1) for x in top_k],\n            dtype=torch.int64,\n            device=device,\n        ).unsqueeze(1)\n\n        # 0 is a special value that disables top_k warping for this member of the batch\n        disabled = [x == 0 for x in top_k]\n\n        if any(disabled):\n            self.top_k_disabled_mask = torch.tensor(\n                disabled, dtype=torch.bool, device=device\n            ).view(-1, 1)\n        else:\n            self.top_k_disabled_mask = None\n\n        self.filter_value = filter_value\n\n    def 
__call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:\n        # If max_top_k is superior to the vocab, we need to clamp or the warper will fail\n        if scores.size(-1) < self.max_top_k:\n            max_top_k = scores.size(-1)\n            top_k = torch.clamp_max(self.top_k_tensor, max_top_k)\n        else:\n            max_top_k = self.max_top_k\n            top_k = self.top_k_tensor\n\n        # Get the kth score for each member of the batch\n        kth_scores = torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k)\n\n        # Mask member of kth_scores that do not want to use top_k warping\n        if self.top_k_disabled_mask is not None:\n            kth_scores.masked_fill_(self.top_k_disabled_mask, self.filter_value)\n\n        # Remove all tokens with a probability less than the last token of the top-k\n        indices_to_remove = scores < kth_scores\n        scores.masked_fill_(indices_to_remove, self.filter_value)\n        return scores\n\n    def filter(self, indices):\n        self.top_k = [self.top_k[i] for i in indices]\n        disabled = [x == 0 for x in self.top_k]\n\n        if not all(disabled):\n            self.top_k_tensor = self.top_k_tensor[indices]\n            self.max_top_k = max(self.top_k)\n\n            if self.top_k_disabled_mask is not None:\n                self.top_k_disabled_mask = (\n                    self.top_k_disabled_mask[indices] if any(disabled) else None\n                )\n\n            return self\n        return None\n\n\nclass HeterogeneousTypicalLogitsWarper(LogitsProcessor):\n    r\"\"\"\n    [`LogitsProcessor`] that performs typical decoding. 
See [Typical Decoding for Natural Language\n    Generation](https://arxiv.org/abs/2202.00666) for more information.\n    This version allows for a separate value for each sample and runs inplace when possible.\n    It doesn't validate inputs.\n\n    Args:\n        mass (`float`):\n            Value of typical_p between 0 and 1 inclusive, defaults to 0.9.\n        filter_value (`float`, *optional*, defaults to `-float(\"Inf\")`):\n            All filtered values will be set to this float value.\n        min_tokens_to_keep (`int`, *optional*, defaults to 1):\n            Minimum number of tokens that cannot be filtered.\n    \"\"\"\n\n    def __init__(\n        self,\n        mass: List[float],\n        dtype: torch.dtype,\n        device: torch.device,\n        filter_value: float = -math.inf,\n        min_tokens_to_keep: int = 1,\n    ):\n        self.mass = mass\n        self.mass_tensor = torch.tensor(mass, dtype=dtype, device=device).unsqueeze(1)\n\n        # 1 is a special value that disables typical_p warping for this member of the batch\n        disabled = [x == 1.0 for x in mass]\n\n        if any(disabled):\n            self.disabled_mask = torch.tensor(disabled, dtype=torch.bool, device=device)\n        else:\n            self.disabled_mask = None\n\n        self.filter_value = filter_value\n        self.min_tokens_to_keep = min_tokens_to_keep\n\n    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:\n        # calculate entropy\n        normalized = torch.nn.functional.log_softmax(scores, dim=-1)\n        p = torch.exp(normalized)\n        ent = -(normalized * p).nansum(-1, keepdim=True)\n\n        # shift and sort\n        shifted_scores = torch.abs((-normalized) - ent)\n        sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)\n        sorted_logits = scores.gather(-1, sorted_indices)\n        probs = sorted_logits.softmax(dim=-1)\n        # This is way faster for some reason\n        for i in 
range(probs.shape[0]):\n            probs[i] = probs[i].cumsum(dim=-1)\n\n        # Remove tokens with cumulative mass above the threshold\n        last_ind = (probs < self.mass_tensor).sum(dim=1)\n        last_ind[last_ind < 0] = 0\n\n        if self.disabled_mask is not None:\n            last_ind.masked_fill_(self.disabled_mask, scores.shape[-1] - 1)\n\n        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(\n            1, last_ind.view(-1, 1)\n        )\n        if self.min_tokens_to_keep > 1:\n            # Keep at least min_tokens_to_keep: never remove the first min_tokens_to_keep entries of the sorted order\n            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0\n        indices_to_remove = sorted_indices_to_remove.scatter(\n            1, sorted_indices, sorted_indices_to_remove\n        )\n\n        warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)\n\n        return warped_scores\n\n    def filter(self, indices):\n        self.mass = [self.mass[i] for i in indices]\n        disabled = [x == 1.0 for x in self.mass]\n\n        if not all(disabled):\n            self.mass_tensor = self.mass_tensor[indices]\n\n            if self.disabled_mask is not None:\n                self.disabled_mask = (\n                    self.disabled_mask[indices] if any(disabled) else None\n                )\n\n            return self\n        return None\n\n\nclass HeterogeneousProcessorWrapper(LogitsProcessor):\n    r\"\"\"\n    A wrapper for logit warpers or processors without heterogeneous parameter support.\n    Args:\n        processors (`Dict[int, LogitsProcessor]`):\n            A mapping of sample indices to logit warpers or processors, to be run sequentially.\n    \"\"\"\n\n    def __init__(\n        self,\n        processors: Dict[int, LogitsProcessor],\n    ):\n        self.processors = processors\n\n    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:\n        for i, processor in 
self.processors.items():\n            scores[i : i + 1] = processor(input_ids[i : i + 1], scores[i : i + 1])\n        return scores\n\n    def filter(self, indices):\n        new_processors = {}\n        for i, idx in enumerate(indices):\n            if idx in self.processors:\n                new_processors[i] = self.processors[idx]\n\n        if new_processors:\n            self.processors = new_processors\n            return self\n        return None\n\n\nclass GrammarLogitProcessor(LogitsProcessor):\n    fsm_state: DefaultDict[int, int]\n    fsm: RegexFSM\n\n    def __init__(self, tokenizer, device, grammar, grammar_type):\n        self.device = device\n        self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)\n        self.fsm = GrammarLogitProcessor._cached_compile_fsm(\n            grammar_type, grammar, self.tokenizer\n        )\n\n    def __call__(\n        self,\n        logits: torch.Tensor,\n        fsm_grammar_state: int,\n    ):\n        if fsm_grammar_state == -1 or self.fsm is None:\n            return logits\n        allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state)\n        mask = torch.full_like(logits, -math.inf)\n        mask[:, allowed_tokens] = 0\n        biased_scores = logits + mask\n        return biased_scores\n\n    def advance(self, next_token_id, fsm_grammar_state):\n        return GrammarLogitProcessor._advance(\n            next_token_id, fsm_grammar_state, self.fsm\n        )\n\n    @staticmethod\n    def _advance(next_token_id, fsm_grammar_state, fsm):\n        if fsm_grammar_state == -1:\n            return fsm_grammar_state\n        return fsm.next_state(fsm_grammar_state, next_token_id)\n\n    # TODO: move grammar compilation into the router\n    @staticmethod\n    @lru_cache(maxsize=32, typed=True)\n    def _cached_compile_fsm(grammar_type, schema, tokenizer):\n        start_time = time.time()\n        if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:\n            schema = 
build_regex_from_schema(schema)\n        elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:\n            pass  # schema is already a regex just here for clarity\n        fsm = RegexFSM(schema, tokenizer)\n        logger.debug(f\"Compiled FSM in {time.time() - start_time:.2f}s\")\n        return fsm\n\n    @staticmethod\n    @lru_cache(maxsize=32, typed=True)\n    def _cached_adapt_tokenizer(tokenizer):\n        \"\"\"Adapt tokenizer to work with the FSM.\n\n        The API of Outlines tokenizers is slightly different to that of\n        `transformers`. In addition we need to handle the missing spaces to\n        Llama's tokenizer to be able to compile FSMs for this model.\n\n        \"\"\"\n        start_time = time.time()\n        tokenizer.vocabulary = tokenizer.get_vocab()\n        tokenizer.special_tokens = set(tokenizer.all_special_tokens)\n\n        def convert_token_to_string(token: str) -> str:\n            from transformers.file_utils import SPIECE_UNDERLINE\n\n            string = tokenizer.convert_tokens_to_string([token])\n\n            # A hack to handle missing spaces to HF's Llama tokenizers\n            if token.startswith(SPIECE_UNDERLINE) or token == \"<0x20>\":\n                return \" \" + string\n\n            return string\n\n        tokenizer.convert_token_to_string = convert_token_to_string\n        logger.debug(f\"Adapted tokenizer in {time.time() - start_time:.2f}s\")\n        return tokenizer\n\n\nclass HeterogeneousGrammarLogitProcessor(LogitsProcessor):\n    def __init__(self, tokenizer, device, grammars, grammar_types):\n        self.device = device\n        self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)\n        self.fsms = []\n        for grammar, grammar_type in zip(grammars, grammar_types):\n            if len(grammar) == 0:\n                self.fsms.append(None)\n                continue\n            fsm = GrammarLogitProcessor._cached_compile_fsm(\n                grammar_type, grammar, 
self.tokenizer\n            )\n            self.fsms.append(fsm)\n\n    def __call__(\n        self,\n        logits: torch.Tensor,\n        fsm_grammar_states: List[int],\n    ):\n        mask = torch.full_like(logits, -math.inf)\n        for i in range(logits.shape[0]):\n            fsm = self.fsms[i]\n            if fsm is None:\n                continue\n            allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])\n            mask[i, allowed_tokens] = 0\n            logits[i] += mask[i]\n        return logits\n\n    def advance_batch(self, next_token_ids, fsm_grammar_states):\n        return [\n            GrammarLogitProcessor._advance(\n                next_token_ids[i], fsm_grammar_states[i], self.fsms[i]\n            )\n            for i in range(len(next_token_ids))\n        ]\n\n    def advance_at_index(self, next_token_id, fsm_grammar_state, index):\n        if self.fsms[index] is None:\n            return fsm_grammar_state\n        return GrammarLogitProcessor._advance(\n            next_token_id, fsm_grammar_state, self.fsms[index]\n        )\n\n    def filter(self, indices):\n        new_fsms = []\n        for i in indices:\n            new_fsms.append(self.fsms[i])\n        self.fsms = new_fsms\n        return self\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/utils/merges/strategies.py",
    "content": "import copy\nfrom abc import ABC\nfrom collections import defaultdict\nfrom typing import TYPE_CHECKING, Dict, List, Tuple, Type, Union\nfrom text_generation_server.utils.merges.utils import (\n    calculate_majority_sign_mask,\n    disjoint_merge,\n    prune,\n)\nimport torch\n\nif TYPE_CHECKING:\n    from text_generation_server.adapters.lora import LoraConfig\n    from text_generation_server.utils.adapter import ModuleMap\n\n\nclass AdapterParameters:\n    def __init__(\n        self, adapter_ids, weights, merge_strategy, density, majority_sign_method\n    ):\n        self.adapter_ids = adapter_ids\n        self.weights = weights\n        self.merge_strategy = merge_strategy\n        self.density = density\n        self.majority_sign_method = majority_sign_method\n\n\ndef _apply_weights(\n    tensors: Union[torch.Tensor, List[torch.Tensor]], w: torch.Tensor\n) -> torch.Tensor:\n    if isinstance(tensors, torch.Tensor):\n        t = tensors\n    else:\n        t = torch.stack(tensors, dim=0)\n\n    # element-wise weighting of each task tensor\n    # need to unsqueeze weights to match task tensor dimensions\n    # for multiplication to apply element-wise\n    while len(t.shape) > len(w.shape):\n        w = w.unsqueeze(-1)\n    return t * w\n\n\nclass MergeStrategy(ABC):\n    def merge(\n        self, task_tensors: List[torch.Tensor], weights: torch.Tensor\n    ) -> torch.Tensor:\n        raise NotImplementedError()\n\n\nclass LinearMerge(MergeStrategy):\n    def __init__(self, **kwargs):\n        pass\n\n    def merge(\n        self, task_tensors: List[torch.Tensor], weights: torch.Tensor\n    ) -> torch.Tensor:\n        weighted_task_tensors = _apply_weights(task_tensors, weights)\n        return weighted_task_tensors.sum(dim=0)\n\n\nclass TiesMerge(MergeStrategy):\n    def __init__(self, density: float, majority_sign_method: str = \"total\", **kwargs):\n        self.density = density\n        self.majority_sign_method = majority_sign_method\n\n    
def merge(\n        self, task_tensors: List[torch.Tensor], weights: torch.Tensor\n    ) -> torch.Tensor:\n        # sparsify\n        task_tensors = [\n            prune(tensor, self.density, method=\"magnitude\") for tensor in task_tensors\n        ]\n        task_tensors = torch.stack(task_tensors, dim=0)\n\n        # elect sign before applying weights\n        majority_sign_mask = calculate_majority_sign_mask(\n            task_tensors, method=self.majority_sign_method\n        )\n        weighted_task_tensors = _apply_weights(task_tensors, weights)\n\n        # disjoint merge\n        return disjoint_merge(weighted_task_tensors, majority_sign_mask)\n\n\nclass DareLinearMerge(MergeStrategy):\n    def __init__(self, density: float, **kwargs):\n        self.density = density\n\n    def merge(\n        self, task_tensors: List[torch.Tensor], weights: torch.Tensor\n    ) -> torch.Tensor:\n        # sparsify\n        task_tensors = [\n            prune(tensor, self.density, method=\"random\", rescale=True)\n            for tensor in task_tensors\n        ]\n        weighted_task_tensors = _apply_weights(task_tensors, weights)\n        return weighted_task_tensors.sum(dim=0)\n\n\nclass DareTiesMerge(MergeStrategy):\n    def __init__(self, density: float, majority_sign_method: str = \"total\", **kwargs):\n        self.density = density\n        self.majority_sign_method = majority_sign_method\n\n    def merge(\n        self, task_tensors: List[torch.Tensor], weights: torch.Tensor\n    ) -> torch.Tensor:\n        # sparsify\n        task_tensors = [\n            prune(tensor, self.density, method=\"random\", rescale=True)\n            for tensor in task_tensors\n        ]\n        task_tensors = torch.stack(task_tensors, dim=0)\n\n        # elect sign before applying weights\n        majority_sign_mask = calculate_majority_sign_mask(\n            task_tensors, method=self.majority_sign_method\n        )\n        weighted_task_tensors = _apply_weights(task_tensors, 
weights)\n\n        # disjoint merge\n        mixed_task_tensors = disjoint_merge(weighted_task_tensors, majority_sign_mask)\n        return mixed_task_tensors\n\n\nstrategy_registry: Dict[str, Type[MergeStrategy]] = {\n    \"linear\": LinearMerge,\n    \"ties\": TiesMerge,\n    \"dare_linear\": DareLinearMerge,\n    \"dare_ties\": DareTiesMerge,\n}\n\n\ndef merge_adapters(\n    adapters: List[Tuple[\"ModuleMap\", \"LoraConfig\"]],\n    merge_params: AdapterParameters,\n) -> Tuple[\"ModuleMap\", \"LoraConfig\"]:\n    # strategy_name = MergeStrategyEnum.Name(merge_params.merge_strategy).lower()\n    strategy_name = \"linear\"\n\n    weights = merge_params.weights\n    if not weights:\n        weights = torch.ones(len(adapters))\n    else:\n        weights = torch.tensor(weights)\n\n    merge_config = {\n        \"density\": merge_params.density,\n        # \"majority_sign_method\": MajoritySignMethodEnum.Name(\n        #     merge_params.majority_sign_method\n        # ).lower(),\n        \"majority_sign_method\": \"total\",\n    }\n    merge_strategy = strategy_registry[strategy_name](**merge_config)\n\n    module_maps: Dict[str, Dict[str, Dict[str, List[torch.Tensor]]]] = defaultdict(\n        lambda: defaultdict(lambda: defaultdict(list))\n    )\n    lora_configs = []\n    weight_name_to_adapter_idx = defaultdict(list)\n\n    # input is list of (module_map, lora_config) tuples\n    # convert into dict[k][param_name] -> list of tensors\n    for idx, (module_map, lora_config) in enumerate(adapters):\n        for weight_name, data in module_map.items():\n            weight_name_to_adapter_idx[weight_name].append(idx)\n            for k, (param_data, param_name) in data.items():\n                module_maps[weight_name][k][param_name].append(param_data)\n        lora_configs.append(lora_config)\n\n    # validate lora configs are compatible\n    _validate_lora_configs(lora_configs)\n\n    # merge tensors for each module such that we have a single ModuleMap:\n    # 
dict[k] -> merged tensor\n    merged_module_map: \"ModuleMap\" = defaultdict(dict)\n    for weight_name, data in module_maps.items():\n        indices = weight_name_to_adapter_idx[weight_name]\n        param_weights = weights[indices]\n        for k, param_data in data.items():\n            for param_name, tensors in param_data.items():\n                merged_tensor = merge_strategy.merge(tensors, param_weights)\n                merged_module_map[weight_name][k] = (merged_tensor, param_name)\n\n    # merge lora configs\n    merged_lora_config = _merge_lora_configs(lora_configs)\n\n    return merged_module_map, merged_lora_config\n\n\ndef _validate_lora_configs(lora_configs: List[\"LoraConfig\"]):\n    # check that all configs have the same rank\n    ranks = set(lora_config.r for lora_config in lora_configs)\n    if len(ranks) > 1:\n        raise ValueError(\n            f\"unable to merge adapters, lora configs have different ranks: {ranks}\"\n        )\n\n    if all(len(lora_config.target_modules) == 0 for lora_config in lora_configs):\n        raise ValueError(\n            \"unable to merge adapters, lora configs have no target modules\"\n        )\n\n\ndef _merge_lora_configs(lora_configs: List[\"LoraConfig\"]) -> \"LoraConfig\":\n    merged_lora_config = copy.copy(lora_configs[0])\n\n    # merge target modules as a union operation\n    merged_target_modules = sorted(\n        set(\n            module\n            for lora_config in lora_configs\n            for module in lora_config.target_modules\n        )\n    )\n    merged_lora_config.target_modules = merged_target_modules\n\n    return merged_lora_config\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/utils/merges/utils.py",
    "content": "# coding=utf-8\n# From: https://github.com/huggingface/peft/pull/1364\n# Copyright 2024-present the HuggingFace Inc. team.\n# Modifications by Predibase, Inc.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Literal\n\nimport torch\n\n\ndef magnitude_based_pruning(tensor: torch.Tensor, density: float) -> torch.Tensor:\n    \"\"\"\n    Prune the smallest values of the task tensors and retain the top-k values based on the specified fraction\n    `density`.\n\n    Args:\n    tensor (`torch.Tensor`):The tensor to prune.\n    density (`float`):The fraction of values to preserve. Should be in [0,1].\n    \"\"\"\n    mask = torch.zeros_like(tensor).reshape(-1)\n    k = int(density * tensor.reshape(-1).shape[0])\n    top_k = torch.topk(tensor.abs().reshape(-1), k=k, largest=True)\n    mask[top_k[1]] = 1\n    return tensor * mask.reshape(tensor.shape)\n\n\ndef random_pruning(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor:\n    \"\"\"\n    Prune the smallest values of the task tensors and retain the top-k values based on the specified fraction\n    `density`.\n\n    Args:\n    tensor (`torch.Tensor`):The tensor to prune.\n    density (`float`):The fraction of values to preserve. 
Should be in [0,1].\n    rescale (`bool`):Whether to rescale the result to preserve the expected value of the original tensor.\n    \"\"\"\n    mask = torch.bernoulli(torch.full_like(input=tensor, fill_value=density))\n    pruned_tensor = tensor * mask\n    if rescale:\n        pruned_tensor = torch.div(input=pruned_tensor, other=density)\n    return pruned_tensor\n\n\ndef prune(\n    tensor: torch.Tensor,\n    density: float,\n    method: Literal[\"magnitude\", \"random\"],\n    rescale: bool = False,\n) -> torch.Tensor:\n    \"\"\"\n    Prune the values of task tensors based on the `method`.\n\n    Args:\n    tensor (`torch.Tensor`):The tensor to prune.\n    density (`float`):The fraction of values to preserve. Should be in [0,1].\n    method (`str`):The method to use to prune. Should be one of [\"magnitude\", \"random\"].\n    rescale (`bool`):Whether to rescale the result to preserve the expected value of the original tensor.\n    \"\"\"\n    if density >= 1:\n        return tensor\n    elif density < 0:\n        raise ValueError(f\"Density should be >= 0, got {density}\")\n    if method == \"magnitude\":\n        return magnitude_based_pruning(tensor, density)\n    elif method == \"random\":\n        return random_pruning(tensor, density, rescale=rescale)\n    else:\n        raise ValueError(f\"Unknown method {method}\")\n\n\ndef calculate_majority_sign_mask(\n    tensor: torch.Tensor, method: Literal[\"total\", \"frequency\"] = \"total\"\n):\n    \"\"\"\n    Get the mask of the majority sign across the task tensors. Task tensors are stacked on dimension 0.\n\n    Args:\n    tensor (`torch.Tensor`):The tensor to get the mask from.\n    method (`str`):The method to use to get the mask. 
Should be one of [\"total\", \"frequency\"].\n    \"\"\"\n\n    sign = tensor.sign()\n    if method == \"total\":\n        sign_magnitude = (sign * tensor.abs()).sum(dim=0)\n    elif method == \"frequency\":\n        sign_magnitude = sign.sum(dim=0)\n    else:\n        raise RuntimeError(f'Unimplemented mask method \"{method}\"')\n    majority_sign = torch.where(sign_magnitude >= 0, 1, -1)\n    return sign == majority_sign\n\n\ndef disjoint_merge(task_tensors, majority_sign_mask):\n    mixed_task_tensors = (task_tensors * majority_sign_mask).sum(dim=0)\n    num_params_preserved = majority_sign_mask.sum(dim=0)\n    return mixed_task_tensors / torch.clamp(num_params_preserved, min=1.0)\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/utils/peft.py",
    "content": "import os\nfrom typing import Union\nfrom loguru import logger\nimport torch\n\nfrom transformers import AutoTokenizer\nfrom peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM\n\n\ndef download_and_unload_peft(model_id, revision, trust_remote_code):\n    torch_dtype = torch.float16\n\n    logger.info(\"Trying to load a Peft model. It might take a while without feedback\")\n    try:\n        model = AutoPeftModelForCausalLM.from_pretrained(\n            model_id,\n            revision=revision,\n            torch_dtype=torch_dtype,\n            trust_remote_code=trust_remote_code,\n            low_cpu_mem_usage=True,\n        )\n    except Exception:\n        model = AutoPeftModelForSeq2SeqLM.from_pretrained(\n            model_id,\n            revision=revision,\n            torch_dtype=torch_dtype,\n            trust_remote_code=trust_remote_code,\n            low_cpu_mem_usage=True,\n        )\n    logger.info(\"Peft model detected.\")\n    logger.info(\"Merging the lora weights.\")\n\n    base_model_id = model.peft_config[\"default\"].base_model_name_or_path\n\n    model = model.merge_and_unload()\n\n    os.makedirs(model_id, exist_ok=True)\n    cache_dir = model_id\n    logger.info(f\"Saving the newly created merged model to {cache_dir}\")\n    tokenizer = AutoTokenizer.from_pretrained(\n        base_model_id, trust_remote_code=trust_remote_code\n    )\n    model.save_pretrained(cache_dir, safe_serialization=True)\n    model.config.save_pretrained(cache_dir)\n    tokenizer.save_pretrained(cache_dir)\n\n\ndef download_peft(\n    model_id: Union[str, os.PathLike], revision: str, trust_remote_code: bool\n):\n    torch_dtype = torch.float16\n    try:\n        _model = AutoPeftModelForCausalLM.from_pretrained(\n            model_id,\n            revision=revision,\n            torch_dtype=torch_dtype,\n            trust_remote_code=trust_remote_code,\n            low_cpu_mem_usage=True,\n        )\n    except Exception:\n        _model = 
AutoPeftModelForSeq2SeqLM.from_pretrained(\n            model_id,\n            revision=revision,\n            torch_dtype=torch_dtype,\n            trust_remote_code=trust_remote_code,\n            low_cpu_mem_usage=True,\n        )\n    logger.info(\"Peft model downloaded.\")\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/utils/prefill_chunking.py",
    "content": "from typing import Optional\n\nSUPPORT_CHUNKING: Optional[bool] = None\nMAX_PREFILL_TOKENS: Optional[int] = None\n\n\ndef set_support_chunking(support_chunking: bool):\n    global SUPPORT_CHUNKING\n    SUPPORT_CHUNKING = support_chunking\n\n\ndef get_support_chunking() -> bool:\n    global SUPPORT_CHUNKING\n    return SUPPORT_CHUNKING\n\n\ndef set_max_prefill_tokens(max_prefill_tokens: int):\n    global MAX_PREFILL_TOKENS\n    MAX_PREFILL_TOKENS = max_prefill_tokens\n\n\ndef get_max_prefill_tokens() -> int:\n    global MAX_PREFILL_TOKENS\n    return MAX_PREFILL_TOKENS\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/utils/quantization.py",
    "content": "import json\nimport os\nfrom dataclasses import dataclass\nfrom typing import Optional, List\n\nfrom huggingface_hub import hf_hub_download\nfrom text_generation_server.utils.weights import (\n    WeightsLoader,\n)\n\n\n# TODO: Split this config to have a single config type per quant method\n@dataclass\nclass _QuantizerConfig:\n    bits: int\n    checkpoint_format: Optional[str]\n    desc_act: bool\n    groupsize: int\n    quant_method: str\n    sym: bool\n    weight_block_size: Optional[List[int]]\n    modules_to_not_convert: List[str]\n\n\n@dataclass\nclass _FP8QuantizerConfig:\n    activation_scale_ub: float\n\n\ndef _get_config_json(model_id: str, revision: Optional[str], filename: str):\n    if os.path.exists(\n        os.path.join(\n            model_id,\n        )\n    ):\n        filename = os.path.join(model_id, filename)\n    else:\n        filename = hf_hub_download(model_id, filename=filename, revision=revision)\n    with open(filename, \"r\") as f:\n        return json.load(f)\n\n\n# We should probably do this with Pydantic JSON deserialization,\n# but for now we'll stay close to the old _set_gptq_params.\ndef _get_quantizer_config(model_id, revision):\n    bits = 4\n    groupsize = -1\n    quant_method = \"gptq\"\n    checkpoint_format = None\n    sym = False\n    desc_act = False\n    weight_block_size = None\n    modules_to_not_convert = []\n\n    filename = \"config.json\"\n    try:\n        data = _get_config_json(model_id, revision, filename)\n        # FP8 config\n        if data[\"quantization_config\"][\"quant_method\"] == \"fbgemm_fp8\":\n            return _FP8QuantizerConfig(\n                activation_scale_ub=data[\"quantization_config\"][\"activation_scale_ub\"]\n            )\n        weight_block_size = data[\"quantization_config\"].get(\"weight_block_size\", None)\n\n        if \"zero_point\" in data[\"quantization_config\"]:\n            sym = not data[\"quantization_config\"][\"zero_point\"]\n            
quant_method = \"awq\"\n        elif \"sym\" in data[\"quantization_config\"]:\n            sym = data[\"quantization_config\"][\"sym\"]\n\n        bits = data[\"quantization_config\"][\"bits\"]\n        groupsize = data[\"quantization_config\"][\"group_size\"]\n        # Order is important here, desc_act is missing on some real models\n        quant_method = data[\"quantization_config\"][\"quant_method\"]\n        checkpoint_format = data[\"quantization_config\"].get(\"checkpoint_format\")\n        desc_act = data[\"quantization_config\"].get(\"desc_act\", False)\n        modules_to_not_convert = data[\"quantization_config\"].get(\n            \"modules_to_not_convert\", []\n        )\n        if modules_to_not_convert is None:\n            modules_to_not_convert = []\n    except Exception:\n        filename = \"quantize_config.json\"\n        try:\n            data = _get_config_json(model_id, revision, filename)\n            bits = data[\"bits\"]\n            groupsize = data[\"group_size\"]\n\n            if \"zero_point\" in data:\n                sym = not data[\"zero_point\"]\n                quant_method = \"awq\"\n            elif \"sym\" in data:\n                sym = data[\"sym\"]\n\n            desc_act = data[\"desc_act\"]\n            if \"version\" in data and data[\"version\"] == \"GEMM\":\n                quant_method = \"awq\"\n        except Exception:\n            filename = \"quant_config.json\"\n            try:\n                data = _get_config_json(model_id, revision, filename)\n                bits = data[\"w_bit\"]\n                groupsize = data[\"q_group_size\"]\n                desc_act = data[\"desc_act\"]\n                if \"version\" in data and data[\"version\"] == \"GEMM\":\n                    quant_method = \"awq\"\n            except Exception:\n                pass\n\n    return _QuantizerConfig(\n        bits=bits,\n        groupsize=groupsize,\n        quant_method=quant_method,\n        
checkpoint_format=checkpoint_format,\n        sym=sym,\n        desc_act=desc_act,\n        weight_block_size=weight_block_size,\n        modules_to_not_convert=modules_to_not_convert,\n    )\n\n\ndef get_loader(\n    quantize: Optional[str], model_id: str, revision: Optional[str]\n) -> WeightsLoader:\n    if quantize == \"compressed-tensors\":\n        config = _get_config_json(model_id, revision, \"config.json\")\n        from text_generation_server.layers.compressed_tensors import (\n            CompressedTensorsLoader,\n        )\n\n        return CompressedTensorsLoader(config)\n    quantizer_config = _get_quantizer_config(model_id, revision)\n    if quantize in {\"awq\", \"gptq\"}:\n        from text_generation_server.layers.gptq import GPTQWeightsLoader\n\n        # TODO: improve check once we have one config type per quantize value\n        if not isinstance(quantizer_config, _QuantizerConfig):\n            raise ValueError(\n                f\"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config.\"\n            )\n\n        return GPTQWeightsLoader(\n            bits=quantizer_config.bits,\n            desc_act=quantizer_config.desc_act,\n            groupsize=quantizer_config.groupsize,\n            quant_method=quantizer_config.quant_method,\n            quantize=quantize,\n            sym=quantizer_config.sym,\n            modules_to_not_convert=quantizer_config.modules_to_not_convert,\n        )\n    elif quantize == \"fp8\" or quantize is None:\n        from text_generation_server.layers.fp8 import HybridFP8UnquantLoader\n\n        # Since the default for the quantize config is _QuantizerConfig,\n        # we need to add this check to not get an attribute error\n        activation_scale_ub = None\n        weight_block_size = quantizer_config.weight_block_size\n        if isinstance(quantizer_config, _FP8QuantizerConfig):\n            activation_scale_ub = quantizer_config.activation_scale_ub\n\n        return 
HybridFP8UnquantLoader(\n            activation_scale_ub,\n            to_fp8=quantize == \"fp8\",\n            weight_block_size=weight_block_size,\n        )\n    else:\n        raise ValueError(f\"Unknown quantization method: {quantize}\")\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/utils/segments.py",
    "content": "# Origin:   https://github.com/predibase/lorax\n# Path:     lorax/server/lorax_server/utils/segments.py\n# License:  Apache License Version 2.0, January 2004\n\nfrom typing import List, Tuple, Union\n\nimport torch\n\n\ndef find_segments(\n    adapter_indices: Union[torch.Tensor, List[int]],\n) -> Tuple[List[int], List[int]]:\n    segments = [0]\n    segment_indices = []\n\n    if isinstance(adapter_indices, torch.Tensor):\n        # Calling .item() repeatedly on CUDA tensor is very slow, so we move it to CPU first\n        adapter_indices = adapter_indices.cpu().tolist()\n\n    start_index = 0\n    for i in range(1, len(adapter_indices)):\n        if adapter_indices[i] != adapter_indices[i - 1]:\n            segments.append(i)\n            segment_indices.append(adapter_indices[i - 1])\n            start_index = i\n\n    # Handle the last segment\n    if start_index < len(adapter_indices):\n        segments.append(len(adapter_indices))\n        segment_indices.append(adapter_indices[-1])\n\n    return segments, segment_indices\n\n\nclass SegmentConcatBuilder:\n    def __init__(self):\n        self.adapter_segment_indices = []\n        self.adapter_segment_tensors = []\n\n    def concat(self, adapter_segments: torch.Tensor, segment_indices: List[int]):\n        # Update adapter segments\n        if self.adapter_segment_tensors:\n            # Because we have already processed at least one batch, remove the 0 start index\n            # from this batch denoting the beginning of the segment, then offset all segment\n            # positions by the value of the last segment in the previous batch to account for\n            # the concatenation.\n            adapter_segments = (\n                adapter_segments[1:] + self.adapter_segment_tensors[-1][-1]\n            )\n\n        if (\n            self.adapter_segment_indices\n            and self.adapter_segment_indices[-1] == segment_indices[0]\n        ):\n            # If the last segment in the 
previous batch is the same as the first segment in this batch,\n            # then we merge them together into a single segment. In effect, this means removing it from\n            # the segment indices of this batch, and extending the segment span by removing the segment\n            # end index from the previous batch.\n            segment_indices = segment_indices[1:]\n            self.adapter_segment_tensors[-1] = self.adapter_segment_tensors[-1][:-1]\n\n        self.adapter_segment_indices.extend(segment_indices)\n        self.adapter_segment_tensors.append(adapter_segments)\n\n    def build(self) -> Tuple[torch.Tensor, List[int]]:\n        return torch.concat(self.adapter_segment_tensors), self.adapter_segment_indices\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/utils/sgmv.py",
    "content": "# Origin:   https://github.com/predibase/lorax\n# Path:     lorax/server/lorax_server/utils/sgmv.py\n# License:  Apache License Version 2.0, January 2004\n\nimport os\nimport warnings\nfrom functools import lru_cache\nfrom typing import List, Tuple\n\nimport torch\nimport torch.nn.functional as F\n\ntry:\n    import punica_kernels as _kernels\n\n    HAS_SGMV = not bool(os.environ.get(\"DISABLE_SGMV\", \"\"))\nexcept ImportError:\n    warnings.warn(\"Could not import SGMV kernel from Punica, falling back to loop.\")\n    _kernels = None\n    HAS_SGMV = False\n\n\nMIN_SGMV_RANK = 8\nMIN_RANK_CUSTOM = 16\nMAX_RANK_CUSTOM = 128\nSGMV_BLOCK_SIZE = 16\nBGMV_MAX_RANK = 64\n\n\ndef has_sgmv() -> bool:\n    return HAS_SGMV\n\n\ndef pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor:\n    \"\"\"Pad a tensor to the minimum rank for SGMV and the nearest multiple of the SGMV block size.\"\"\"\n    if not has_sgmv():\n        return t\n\n    # tensor parallelism will result in effective rank being divided by world_size,\n    # so we need to scale the min rank to offset that effect\n    min_rank = MIN_SGMV_RANK * world_size\n\n    # if we're at or below the min rank, pad up to the min rank\n    # otherwise, pad to the nearest multiple of the block size\n    current_rank = t.size(dim)\n    target_rank = (\n        min_rank\n        if current_rank <= min_rank\n        else (current_rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE\n    )\n    if current_rank == target_rank:\n        return t\n\n    pad_size = target_rank - current_rank\n\n    # see complicatd pad syntax here: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n    pad = [0, 0] * t.dim()\n    pad[(t.dim() - dim - 1) * 2 + 1] = pad_size\n    pad = tuple(pad)\n\n    return F.pad(t, pad, mode=\"constant\", value=0.0)\n\n\ndef use_cutlass_shrink(lora_rank: int) -> bool:\n    return lora_rank < MIN_RANK_CUSTOM\n\n\ndef orient_for_rank(t: torch.Tensor, 
rank: int) -> torch.Tensor:\n    if MIN_RANK_CUSTOM <= rank <= MAX_RANK_CUSTOM:\n        return t.transpose(0, 1)\n    return t\n\n\n# Source: https://github.com/punica-ai/punica/blob/master/src/punica/ops/__init__.py\ndef add_lora_sgmv_cutlass(\n    y: torch.Tensor,\n    x: torch.Tensor,\n    wa_ptr: torch.Tensor,\n    wb_ptr: torch.Tensor,\n    s_start: torch.Tensor,\n    s_end: torch.Tensor,\n    layer_idx: int,\n    lora_rank: int,\n):\n    \"\"\"\n    Semantics:\n        y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]).T @ deref(wb_ptr[i])\n\n    Args:\n        y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.\n        x: Shape: `[B, H1]`. Input vectors.\n        wa_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\\\n            Weight matrix shape: `[num_layers, R, H1]`.\n        wb_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\\\n            Weight matrix shape: `[num_layers, R, H2]`.\n        s_start: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices start indices.\n        s_end: Shape: `[S]`, DType: torch.int32. 
Indptr of the weight matrices end indices.\n        layer_idx: Layer index of the weight matrices.\n    \"\"\"\n    if lora_rank < MIN_RANK_CUSTOM or lora_rank > MAX_RANK_CUSTOM:\n        # Custom SGMV shrink only supports rank 16, 32, 64, 128\n        _add_lora_sgmv_cutlass_legacy(\n            y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank\n        )\n        return\n\n    tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device)\n    tmp2_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))\n    tmp2 = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device)\n    v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)\n    _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx)\n    _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp2, layer_idx)\n\n\ndef _add_lora_sgmv_cutlass_legacy(\n    y: torch.Tensor,\n    x: torch.Tensor,\n    wa_ptr: torch.Tensor,\n    wb_ptr: torch.Tensor,\n    s_start: torch.IntTensor,\n    s_end: torch.IntTensor,\n    layer_idx: int,\n    lora_rank: int,\n):\n    tmp_size = _kernels.sgmv_cutlass_tmp_size(wa_ptr.size(0))\n    tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device)\n    v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)\n    _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)\n    _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)\n\n\n@lru_cache(maxsize=1)\ndef get_tmp_tensor(device: torch.device) -> torch.Tensor:\n    return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device)\n\n\n@lru_cache(maxsize=32)\ndef get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor:\n    tmp_size = _kernels.sgmv_cutlass_tmp_size(size)\n    return torch.empty((tmp_size,), dtype=torch.uint8, device=device)\n\n\ndef get_tmp_tensor_for_size_no_kernels(size: int, device: torch.device) -> torch.Tensor:\n    return torch.empty((size,), dtype=torch.uint8, device=device)\n\n\ndef 
get_tmp_expand_size(size: int) -> int:\n    return _kernels.sgmv_cutlass_tmp_size(size)\n\n\ndef get_tmp_tensors(\n    nsegments: int, lora_rank: int, device: torch.device\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    use_cutlass = use_cutlass_shrink(lora_rank) and has_sgmv()\n    has_sgmv_available = has_sgmv()\n\n    if use_cutlass:\n        tmp = get_tmp_tensor_for_size(nsegments, device)\n        return tmp, tmp\n    elif has_sgmv_available:\n        return get_tmp_tensor(device), get_tmp_tensor_for_size(nsegments, device)\n    else:\n        tmp = get_tmp_tensor_for_size(nsegments, device)\n        return tmp, tmp\n\n\ndef lora_a_sgmv_cutlass(\n    x: torch.Tensor,\n    tmp: torch.Tensor,\n    wa_ptr: torch.Tensor,\n    s_start: torch.IntTensor,\n    s_end: torch.IntTensor,\n    layer_idx: int,\n    lora_rank: int,\n) -> torch.Tensor:\n    v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)\n    if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM:\n        _kernels.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)\n    else:\n        _kernels.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)\n    return v\n\n\ndef lora_b_sgmv_cutlass(\n    y: torch.Tensor,\n    v: torch.Tensor,\n    tmp: torch.Tensor,\n    wb_ptr: torch.Tensor,\n    s_start: torch.IntTensor,\n    s_end: torch.IntTensor,\n    layer_idx: int,\n):\n    _kernels.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)\n\n\n\"\"\"\nSemantics:\n    y[i] += (\n        x[i].unsqueeze(0)\n        @ wa_T_all[indices[i], layer_idx, :, :].transpose(-1, -2)\n        @ wb_T_all[indices[i], layer_idx, :, :].transpose(-1, -2)\n        * scale\n    ).squeeze(0)\n\nArgs:\n    y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.\n    v: Shape: `[B, R]`. Temporary vector.\n    x: Shape: `[B, H1]`. Input vectors.\n    wa_T_all: Shape: `[None, L, R, H1]`. All of the transposed LoRA A matrices.\n    wb_T_all: Shape: `[None, L, H2, R]`. 
All of the transposed LoRA B matrices.\n    indicies: Shape: `[B]`. Indices of the LoRA weights.\n    layer_idx: Layer index of LoRA weights.\n    scale: Scaling factor.\n\"\"\"\n\n\ndef add_lora_a_bgmv(\n    v: torch.Tensor,\n    x: torch.Tensor,\n    wa_T_all: torch.Tensor,\n    indicies: torch.LongTensor,\n    layer_idx: int,\n):\n    _kernels.dispatch_bgmv(v, x, wa_T_all, indicies, layer_idx, 1.0)\n\n\ndef add_lora_b_bgmv(\n    y: torch.Tensor,\n    v: torch.Tensor,\n    wb_T_all: torch.Tensor,\n    indicies: torch.LongTensor,\n    layer_idx: int,\n):\n    _kernels.dispatch_bgmv(y, v, wb_T_all, indicies, layer_idx, 1.0)\n\n\ndef segmented_matmul(\n    y: torch.Tensor,\n    x: torch.Tensor,\n    w: List[torch.Tensor],\n    b: List[torch.Tensor],\n    s_start: torch.IntTensor,\n    s_end: torch.IntTensor,\n):\n    for i in range(len(w)):\n        if s_end[i] - s_start[i] <= 0:\n            continue\n\n        xi = x[s_start[i] : s_end[i]]\n        wi = w[i]\n        bi = b[i]\n        y[s_start[i] : s_end[i]] = F.linear(xi, wi, bi)\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/utils/speculate.py",
    "content": "SPECULATE = None\n\n\ndef get_speculate() -> int:\n    global SPECULATE\n    return SPECULATE\n\n\ndef set_speculate(speculate: int):\n    global SPECULATE\n    SPECULATE = speculate\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/utils/tokens.py",
    "content": "# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.\n\nimport re\nfrom typing import List, Optional, Tuple, Set, Union\n\nimport torch\nfrom text_generation_server.pb import generate_pb2\nfrom text_generation_server.pb.generate_pb2 import FinishReason, GrammarType\nfrom text_generation_server.utils.logits_process import (\n    FrequencyPenaltyLogitsProcessor,\n    GrammarLogitProcessor,\n    HeterogeneousProcessorWrapper,\n    HeterogeneousRepetitionPenaltyLogitsProcessor,\n    HeterogeneousFrequencyPenaltyLogitsProcessor,\n    HeterogeneousTemperatureLogitsWarper,\n    HeterogeneousTopKLogitsWarper,\n    HeterogeneousTopPLogitsWarper,\n    HeterogeneousTypicalLogitsWarper,\n    HeterogeneousGrammarLogitProcessor,\n    static_warper,\n)\nfrom text_generation_server.utils.watermark import WatermarkLogitsProcessor\nfrom transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor\nimport os\n\n\nclass NextTokenChooser:\n    def __init__(\n        self,\n        watermark: bool = False,\n        temperature: float = 1.0,\n        repetition_penalty: float = 1.0,\n        frequency_penalty: float = 0.0,\n        top_k: Optional[int] = None,\n        top_p: Optional[float] = None,\n        typical_p: Optional[float] = None,\n        do_sample: bool = False,\n        seed: int = 0,\n        device: str = \"cpu\",\n        tokenizer: Optional[PreTrainedTokenizerBase] = None,\n        grammar: str = \"\",\n        grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE,\n        fsm_grammar_state: int = 0,\n    ):\n        self.watermark_processor = (\n            WatermarkLogitsProcessor(device=device) if watermark else None\n        )\n        self.repetition_processor = (\n            RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)\n            if repetition_penalty and repetition_penalty != 1.0\n            else None\n        )\n        self.frequency_processor = (\n            
FrequencyPenaltyLogitsProcessor(penalty=frequency_penalty)\n            if frequency_penalty and frequency_penalty != 0.0\n            else None\n        )\n        self.grammar_processor = (\n            GrammarLogitProcessor(tokenizer, device, grammar, grammar_type)\n            if grammar != \"\"\n            else None\n        )\n        self.tokenizer = tokenizer\n\n        has_warpers = (\n            (temperature is not None and temperature != 1.0)\n            or (top_k is not None and top_k != 0)\n            or (top_p is not None and top_p < 1.0)\n            or (typical_p is not None and typical_p < 1.0)\n        )\n        if has_warpers:\n            self.static_warper = static_warper(\n                temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p\n            )\n        else:\n            self.static_warper = None\n\n        sampling = do_sample or has_warpers\n\n        self.choice = Sampling(seed, device) if sampling else Greedy()\n        self.fsm_grammar_state = fsm_grammar_state\n        self.grammar = grammar\n\n    def __call__(self, input_ids, scores):\n        if self.watermark_processor is not None:\n            scores = self.watermark_processor(input_ids, scores)\n        if self.repetition_processor is not None:\n            scores = self.repetition_processor(input_ids, scores)\n        if self.frequency_processor is not None:\n            scores = self.frequency_processor(input_ids, scores)\n        if self.grammar_processor is not None:\n            scores = self.grammar_processor(scores, self.fsm_grammar_state)\n\n        if self.static_warper is None:\n            next_logprob = torch.log_softmax(scores, -1)\n        else:\n            scores, next_logprob = self.static_warper(scores)\n\n        next_id = self.choice(scores[-1]).view(1, 1)\n\n        return next_id, next_logprob\n\n    def advance_grammar(self, next_id: int):\n        if self.grammar_processor is not None:\n            self.fsm_grammar_state = 
self.grammar_processor.advance(\n                next_id, self.fsm_grammar_state\n            )\n        return self\n\n    @classmethod\n    def from_pb(\n        cls,\n        pb: generate_pb2.NextTokenChooserParameters,\n        device: torch.device,\n        tokenizer: PreTrainedTokenizerBase,\n    ) -> \"NextTokenChooser\":\n        return NextTokenChooser(\n            watermark=pb.watermark,\n            temperature=pb.temperature,\n            repetition_penalty=pb.repetition_penalty,\n            frequency_penalty=pb.frequency_penalty,\n            top_k=pb.top_k,\n            top_p=pb.top_p,\n            typical_p=pb.typical_p,\n            do_sample=pb.do_sample,\n            seed=pb.seed,\n            device=device,\n            tokenizer=tokenizer,\n            grammar=pb.grammar,\n            grammar_type=pb.grammar_type,\n        )\n\n\nclass StopSequenceCriteria:\n    def __init__(self, stop_sequence: str):\n        stop_sequence = re.escape(stop_sequence)\n        self.regex = re.compile(f\"{stop_sequence}$\")\n\n    def __call__(self, output: str) -> bool:\n        if self.regex.findall(output):\n            return True\n        return False\n\n\nclass StoppingCriteria:\n    def __init__(\n        self,\n        eos_token_ids: Optional[Union[Set[int], int]],\n        stop_sequence_criterias: List[StopSequenceCriteria],\n        max_new_tokens: int = 20,\n        ignore_eos_token: bool = False,\n    ):\n        if eos_token_ids is None:\n            eos_token_ids = set()\n        elif isinstance(eos_token_ids, int):\n            eos_token_ids = set([eos_token_ids])\n        elif isinstance(eos_token_ids, set):\n            eos_token_ids = eos_token_ids\n        else:\n            raise RuntimeError(\n                f\"eos_token_ids is of invalid type {type(eos_token_ids)}, expected int, None or set[int]\"\n            )\n        self.eos_token_ids = eos_token_ids\n        self.stop_sequence_criterias = stop_sequence_criterias\n        
self.max_new_tokens = max_new_tokens\n        self.current_tokens = 0\n        self.current_output = \"\"\n\n        if os.getenv(\"TEXT_GENERATION_SERVER_IGNORE_EOS_TOKEN\", \"false\") == \"true\":\n            self.ignore_eos_token = True\n        else:\n            self.ignore_eos_token = ignore_eos_token\n\n    def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:\n        self.current_tokens += 1\n        if self.current_tokens >= self.max_new_tokens:\n            return True, FinishReason.FINISH_REASON_LENGTH\n\n        if isinstance(last_token, torch.Tensor):\n            last_token = last_token.item()\n\n        if not self.ignore_eos_token and last_token in self.eos_token_ids:\n            return True, FinishReason.FINISH_REASON_EOS_TOKEN\n\n        if self.stop_sequence_criterias:\n            self.current_output += last_output\n            # There is no need to keep an output that is too long\n            if len(self.current_output) > 300:\n                # Slice to -200 to avoid doing it all the time\n                self.current_output = self.current_output[-200:]\n            for stop_sequence_criteria in self.stop_sequence_criterias:\n                if stop_sequence_criteria(self.current_output):\n                    return True, FinishReason.FINISH_REASON_STOP_SEQUENCE\n\n        return False, None\n\n    @classmethod\n    def from_pb(\n        cls,\n        pb: generate_pb2.StoppingCriteriaParameters,\n        tokenizer: PreTrainedTokenizerBase,\n    ) -> \"StoppingCriteria\":\n        stop_sequence_criterias = [\n            StopSequenceCriteria(sequence) for sequence in pb.stop_sequences\n        ]\n        # TODO Hack because eos_token_id cannot be what we want.\n        eos_token_id = getattr(tokenizer, \"_eos_token_ids\", tokenizer.eos_token_id)\n        return StoppingCriteria(\n            eos_token_id,\n            stop_sequence_criterias,\n            pb.max_new_tokens,\n            pb.ignore_eos_token,\n   
     )\n\n\ndef create_n_gram_speculation(\n    input_ids: torch.Tensor,\n    next_ids: torch.Tensor,\n    accepted_ids: torch.Tensor,\n    speculate: int,\n    verbose: bool,\n):\n    # Very trivial approach, find first match in the string.\n    # This is much less refined than actual n-gram but seems to work\n    # relatively OK in grounded mode and is by far much faster with\n    # much less worst case complexity as everything happens on device.\n    B = accepted_ids.shape[0]\n    device = input_ids.device\n    seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]\n    indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1\n    all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(\n        speculate, device=device\n    )\n    all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)\n\n    speculative_ids = input_ids.gather(dim=-1, index=all_indices)\n    return speculative_ids\n\n\nclass HeterogeneousNextTokenChooser:\n    def __init__(\n        self,\n        dtype: torch.dtype,\n        device: torch.device,\n        watermark: List[bool],\n        temperature: List[float],\n        repetition_penalty: List[float],\n        frequency_penalty: List[float],\n        top_k: List[int],\n        top_p: List[float],\n        typical_p: List[float],\n        do_sample: List[bool],\n        seeds: List[int],\n        tokenizer: PreTrainedTokenizerBase,\n        grammars: List[str],\n        grammar_types: List[int],\n        fsm_grammar_states: List[int],\n        quantization_enabled: bool,\n    ):\n        warpers = []\n\n        # TODO: enable watermark with FP8 quantization\n        self.watermark_processor = (\n            HeterogeneousProcessorWrapper(\n                {\n                    i: WatermarkLogitsProcessor(device=device)\n                    for i, do_watermark in enumerate(watermark)\n                    if do_watermark\n                }\n            )\n            if any(watermark) and not 
quantization_enabled\n            else None\n        )\n\n        self.repetition_processor = (\n            HeterogeneousRepetitionPenaltyLogitsProcessor(\n                repetition_penalty, dtype, device\n            )\n            if any([x != 1.0 for x in repetition_penalty])\n            else None\n        )\n\n        self.frequency_processor = (\n            HeterogeneousFrequencyPenaltyLogitsProcessor(\n                frequency_penalty, dtype, device\n            )\n            if any([x != 0.0 for x in frequency_penalty])\n            else None\n        )\n\n        self.grammar_processor = (\n            HeterogeneousGrammarLogitProcessor(\n                tokenizer, device, grammars, grammar_types\n            )\n            if any([grammar != \"\" for grammar in grammars])\n            else None\n        )\n\n        if any(x != 1.0 for x in temperature):\n            do_sample = [\n                sample or x != 1.0 for x, sample in zip(temperature, do_sample)\n            ]\n            warpers.append(\n                HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)\n            )\n\n        if any(x != 0 for x in top_k):\n            do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]\n            warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))\n\n        if any(x < 1.0 for x in top_p):\n            do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]\n            warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device))\n\n        if any(x < 1.0 for x in typical_p):\n            do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)]\n            warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device))\n\n        self.warpers = warpers\n\n        if any(do_sample):\n            self.choice = HeterogeneousSampling(do_sample, seeds, device)\n        else:\n            self.choice = Greedy()\n\n        self.seeds = seeds\n        
self.do_sample = do_sample\n        self.dtype = dtype\n        self.device = device\n        self.tokenizer = tokenizer\n        self.fsm_grammar_states = fsm_grammar_states\n        self.grammars = grammars\n        self.grammar_types = grammar_types\n\n    def __call__(\n        self,\n        input_ids: torch.Tensor,\n        scores: torch.Tensor,\n        speculate: int,\n        speculated_ids: Optional[torch.Tensor] = None,\n        speculative_scores: Optional[torch.Tensor] = None,\n        verbose=False,\n    ):\n        if speculated_ids is not None:\n            B = scores.shape[0] // (speculated_ids.shape[1] + 1)\n            S = speculated_ids.shape[1] + 1\n            scores = scores.view(B, S, -1)\n        else:\n            B = scores.shape[0]\n            S = 1\n            scores = scores.view(B, S, -1)\n\n        next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)\n\n        for j in range(S):\n            _scores = scores[:, j]\n            if self.watermark_processor is not None:\n                _scores = self.watermark_processor(input_ids, _scores)\n            if self.repetition_processor is not None:\n                _scores = self.repetition_processor(input_ids, _scores)\n            if self.frequency_processor is not None:\n                _scores = self.frequency_processor(input_ids, _scores)\n            if self.grammar_processor is not None:\n                _scores = self.grammar_processor(_scores, self.fsm_grammar_states)\n            for warper in self.warpers:\n                _scores = warper(input_ids, _scores)\n            _next_ids = self.choice(_scores)\n            scores[:, j] = _scores\n            next_ids[:, j] = _next_ids\n        next_ids = next_ids.view(B * S)\n        allscores = scores.view(B * S, -1)\n        alllogprobs = torch.log_softmax(allscores, -1)\n\n        if speculated_ids is not None:\n            accepted_ids = []\n            B = next_ids.shape[0] // (speculated_ids.shape[1] + 1)\n   
         S = speculated_ids.shape[1] + 1\n            indices = []\n            for i in range(B):\n                _next_ids = next_ids[i * S : (i + 1) * S]\n                _speculated_ids = speculated_ids[i]\n                validate_speculative = _next_ids[:-1] == _speculated_ids\n                index = i * S\n                accepted = 1\n                # First is always valid\n                indices.append(index)\n                for valid in validate_speculative.tolist():\n                    if valid:\n                        index += 1\n                        accepted += 1\n                        indices.append(index)\n                    else:\n                        break\n                accepted_ids.append(accepted)\n\n            accepted_ids = torch.tensor(\n                accepted_ids, device=input_ids.device, dtype=input_ids.dtype\n            )\n            next_ids = next_ids[indices]\n            logprobs = alllogprobs[indices]\n            indices = torch.arange(B, device=input_ids.device) * S\n            if speculative_scores is not None:\n                speculative_scores = speculative_scores[indices + accepted_ids - 1]\n        else:\n            accepted_ids = torch.ones_like(next_ids)\n            logprobs = alllogprobs\n\n        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)\n\n        if speculate > 0:\n            if speculative_scores is not None:\n                # Medusa provided some scores\n                speculative_ids = Greedy()(speculative_scores)\n            else:\n                # n-gram\n                speculative_ids = create_n_gram_speculation(\n                    input_ids, next_ids, accepted_ids, speculate, verbose\n                )\n        else:\n            speculative_ids = None\n\n        return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids\n\n    def advance_grammar(self, next_ids: List[int]):\n        if self.grammar_processor is not None:\n           
 other_new_states = self.grammar_processor.advance_batch(\n                next_ids, self.fsm_grammar_states\n            )\n            self.fsm_grammar_states = other_new_states\n        return self\n\n    def advance_grammar_single(self, grammar_state_index: int, next_id: int):\n        if self.grammar_processor is not None:\n            self.fsm_grammar_states[grammar_state_index] = (\n                self.grammar_processor.advance_at_index(\n                    next_id,\n                    self.fsm_grammar_states[grammar_state_index],\n                    grammar_state_index,\n                )\n            )\n        return self\n\n    def advance_grammar_single_with_past_state(\n        self, grammar_state_index: int, next_id: torch.Tensor, past_state: int\n    ):\n        if self.grammar_processor is not None:\n            next_id = next_id.item()\n            self.fsm_grammar_states[grammar_state_index] = (\n                self.grammar_processor.advance_at_index(\n                    next_id,\n                    past_state,\n                    grammar_state_index,\n                )\n            )\n        return self\n\n    def filter(self, indices):\n        if self.watermark_processor is not None:\n            self.watermark_processor = self.watermark_processor.filter(indices)\n\n        if self.repetition_processor is not None:\n            self.repetition_processor = self.repetition_processor.filter(indices)\n\n        if self.frequency_processor is not None:\n            self.frequency_processor = self.frequency_processor.filter(indices)\n\n        if self.grammar_processor is not None:\n            self.grammar_processor = self.grammar_processor.filter(indices)\n\n        filtered_warpers = []\n        for warper in self.warpers:\n            filtered_warper = warper.filter(indices)\n            if filtered_warper is not None:\n                filtered_warpers.append(filtered_warper)\n        self.warpers = filtered_warpers\n\n        self.seeds 
= [self.seeds[i] for i in indices]\n        self.do_sample = [self.do_sample[i] for i in indices]\n\n        new_grammars = []\n        new_fsm_grammar_states = []\n        new_grammar_types = []\n        for i in indices:\n            new_grammars.append(self.grammars[i])\n            new_fsm_grammar_states.append(self.fsm_grammar_states[i])\n            new_grammar_types.append(self.grammar_types[i])\n\n        self.grammars = new_grammars\n        self.fsm_grammar_states = new_fsm_grammar_states\n        self.grammar_types = new_grammar_types\n\n        if any(self.do_sample):\n            self.choice.filter(indices)\n        else:\n            self.choice = Greedy()\n\n        return self\n\n    @classmethod\n    def from_pb(\n        cls,\n        pb: List[generate_pb2.NextTokenChooserParameters],\n        dtype: torch.dtype,\n        device: torch.device,\n        tokenizer: PreTrainedTokenizerBase,\n        fsm_grammar_states: Optional[List[int]] = None,\n        quantization_enabled: bool = False,\n    ) -> \"HeterogeneousNextTokenChooser\":\n        return HeterogeneousNextTokenChooser(\n            watermark=[pb_.watermark for pb_ in pb],\n            temperature=[pb_.temperature for pb_ in pb],\n            repetition_penalty=[pb_.repetition_penalty for pb_ in pb],\n            frequency_penalty=[pb_.frequency_penalty for pb_ in pb],\n            top_k=[pb_.top_k for pb_ in pb],\n            top_p=[pb_.top_p for pb_ in pb],\n            typical_p=[pb_.typical_p for pb_ in pb],\n            do_sample=[pb_.do_sample for pb_ in pb],\n            seeds=[pb_.seed for pb_ in pb],\n            device=device,\n            dtype=dtype,\n            tokenizer=tokenizer,\n            grammars=[pb_.grammar for pb_ in pb],\n            grammar_types=[pb_.grammar_type for pb_ in pb],\n            fsm_grammar_states=(\n                fsm_grammar_states if fsm_grammar_states else [0] * len(pb)\n            ),\n            quantization_enabled=quantization_enabled,\n    
    )\n\n\ndef pad_next_token_chooser_parameters(\n    parameters: List[generate_pb2.NextTokenChooserParameters],\n    expected_size: int,\n) -> List[generate_pb2.NextTokenChooserParameters]:\n    # disable all logits processors to minimize padding overhead\n    empty_parameters = generate_pb2.NextTokenChooserParameters(\n        temperature=1.0,\n        top_k=0,\n        top_p=1.0,\n        typical_p=1.0,\n        do_sample=False,\n        seed=0,\n        repetition_penalty=1.0,\n        frequency_penalty=0.0,\n        watermark=False,\n        grammar=\"\",\n        grammar_type=0,\n    )\n    parameters.extend([empty_parameters] * (expected_size - len(parameters)))\n    return parameters\n\n\nclass Sampling:\n    def __init__(self, seed: int, device: str = \"cpu\"):\n        if device in [\"hpu\", torch.device(\"hpu\")]:\n            import habana_frameworks.torch.hpu.random as htrandom\n\n            self.generator = htrandom.default_generators[0].manual_seed(seed)\n        else:\n            self.generator = torch.Generator(\"cpu\")\n            self.generator.manual_seed(seed)\n        self.seed = seed\n\n    def __call__(self, logits):\n        probs = torch.nn.functional.softmax(logits, -1)\n        # Avoid GPU<->CPU sync done by torch multinomial\n        # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637\n        q = torch.empty_like(probs).exponential_(1, generator=self.generator)\n        return probs.div_(q).argmax()\n\n\nclass Greedy:\n    def __call__(self, logits):\n        return logits.argmax(dim=-1)\n\n\nclass HeterogeneousSampling:\n    r\"\"\"\n    Mixed greedy and probabilistic sampling. 
Compute both and pick the right one for each sample.\n    \"\"\"\n\n    def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device):\n        self.seeds = seeds\n\n        self.greedy_indices = []\n        self.sampling_mapping = {}\n        for i, (sample, seed) in enumerate(zip(do_sample, seeds)):\n            if sample:\n                self.sampling_mapping[i] = Sampling(seed, device)\n            else:\n                self.greedy_indices.append(i)\n\n        self.greedy = Greedy()\n\n    def __call__(self, logits):\n        out = torch.zeros(logits.shape[0], dtype=torch.int64, device=logits.device)\n        if self.greedy_indices:\n            # Computing for all indices is faster than slicing\n            torch.argmax(logits, -1, out=out)\n\n        for i, sampling in self.sampling_mapping.items():\n            out[i] = sampling(logits[i])\n        return out\n\n    def filter(self, indices):\n        new_greedy_indices = []\n        new_sampling_mapping = {}\n        for i, idx in enumerate(indices):\n            if idx in self.sampling_mapping:\n                new_sampling_mapping[i] = self.sampling_mapping[idx]\n            else:\n                new_greedy_indices.append(i)\n\n        self.greedy_indices = new_greedy_indices\n        self.sampling_mapping = new_sampling_mapping\n        return self\n\n\ndef batch_top_tokens(\n    top_n_tokens: List[int],\n    top_n_tokens_tensor: torch.Tensor,\n    logprobs: torch.Tensor,\n    accepted_ids: torch.Tensor,\n) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:\n    \"\"\"Find the top n most likely tokens for a batch of generations.\n\n    When multiple tokens have equal probabilities and they don't all fit, the\n    remaining tokens are also returned.\n    \"\"\"\n    max_top_n = max(top_n_tokens)\n    # Early exit when top_n_tokens is not used\n    if max_top_n == 0:\n        return [[[]]] * len(top_n_tokens), [[[]]] * len(top_n_tokens)\n\n    batch_size = 
accepted_ids.shape[0]\n    speculate_size = logprobs.shape[0] // batch_size\n    top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(speculate_size)\n    # Ensure top_n doesn't exceed vocab size\n    top_n_tokens = [\n        min(tok, logprobs.size(-1))\n        for tok in top_n_tokens\n        for _ in range(speculate_size)\n    ]\n\n    # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2\n    # Sorted topk is faster than torch.sort() since we only need a small subset\n    sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=-1, sorted=True).values\n\n    nth_highest = torch.gather(\n        sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)\n    )\n    nth_highest[nth_highest == -float(\"inf\")] = torch.finfo(logprobs.dtype).min\n\n    # Find the new \"fuzzy\" top n values\n    top_n_indices = (logprobs >= nth_highest).nonzero()\n    _, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)\n\n    k = 1 if top_n_ishes.numel() == 0 else top_n_ishes.max()\n    # Take a new topk for these new max n values\n    top_k = torch.topk(logprobs, k=k, dim=1, sorted=True)\n\n    top_n_ishes = top_n_ishes.tolist()\n    top_indices = top_k.indices.tolist()\n    top_values = top_k.values.tolist()\n\n    batch_top_token_ids = []\n    batch_top_token_logprobs = []\n    accepted_ids_list = accepted_ids.tolist()\n    for i, n_accepted_ids in enumerate(accepted_ids_list):\n        start = speculate_size * i\n        stop = speculate_size * (i + 1)\n        _top_indices = top_indices[start:stop]\n        _top_values = top_values[start:stop]\n        _top_n_ishes = top_n_ishes[start:stop]\n        _top_n_tokens = top_n_tokens[start:stop]\n\n        _top_indices = _top_indices[:n_accepted_ids]\n        _top_values = _top_values[:n_accepted_ids]\n        _top_n_ishes = _top_n_ishes[:n_accepted_ids]\n        _top_n_tokens = _top_n_tokens[:n_accepted_ids]\n\n    
    row_top_token_ids = []\n        row_top_token_logprobs = []\n\n        for idxs, vals, n, req_n in zip(\n            _top_indices, _top_values, _top_n_ishes, _top_n_tokens\n        ):\n            indices = idxs[:n] if req_n > 0 else []\n            values = vals[:n] if req_n > 0 else []\n\n            row_top_token_ids.append(indices)\n            row_top_token_logprobs.append(values)\n\n        batch_top_token_ids.append(row_top_token_ids)\n        batch_top_token_logprobs.append(row_top_token_logprobs)\n\n    return batch_top_token_ids, batch_top_token_logprobs\n\n\ndef make_tokenizer_optional(tokenizer):\n    class _(type(tokenizer)):\n        def __call__(\n            self,\n            text,\n            return_tensors,\n            padding,\n            return_token_type_ids,\n            truncation,\n            max_length,\n        ):\n            assert (\n                return_tensors == \"pt\"\n            ), \"inccorrect input arguments when calling TransparentTokenizer\"\n            assert (\n                padding == \"max_length\" or padding == \"longest\"\n            ), \"inccorrect input arguments when calling TransparentTokenizer\"\n            assert (\n                not return_token_type_ids\n            ), \"inccorrect input arguments when calling TransparentTokenizer\"\n            assert (\n                truncation\n            ), \"inccorrect input arguments when calling TransparentTokenizer\"\n\n            def str_token_to_int(i):\n                if i == \"?\":\n                    return tokenizer.pad_token_id\n                else:\n                    return int(i)\n\n            all_tokens = [\n                [str_token_to_int(i.strip()) for i in inner_text.split(\",\")]\n                for inner_text in text\n            ]\n            if padding == \"longest\":\n                max_length = max(len(tokens) for tokens in all_tokens)\n            return {\n                \"input_ids\": torch.tensor(\n                  
  [\n                        [tokenizer.pad_token_id] * (max_length - len(tokens)) + tokens\n                        for tokens in all_tokens\n                    ]\n                ),\n                \"attention_mask\": torch.tensor(\n                    [\n                        [0] * (max_length - len(tokens)) + [1] * len(tokens)\n                        for tokens in all_tokens\n                    ]\n                ),\n            }\n\n        def decode(\n            self,\n            token_ids,\n            skip_special_tokens: bool = False,\n            clean_up_tokenization_spaces: bool = None,\n            **kwargs,\n        ) -> str:\n            # I don't think this method is used anywhere and should be removed when doing refactoring\n            return \",\".join(str(i) for i in to_py_obj(token_ids))  # noqa: F821\n\n    if os.getenv(\"SKIP_TOKENIZER_IN_TGI\", \"false\").lower() == \"true\":\n        tokenizer.__class__ = _\n        tokenizer.is_transparent = True\n\n\ndef is_tokenizer_transparent(tokenizer):\n    return hasattr(tokenizer, \"is_transparent\") and tokenizer.is_transparent is True\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/utils/version.py",
    "content": "from packaging.version import Version\nfrom packaging import version\nimport subprocess\n\n\ndef get_driver_version():\n    \"\"\"\n    Returns the driver version.\n    \"\"\"\n    # Enable console printing for `hl-smi` check\n    output = subprocess.run(\n        \"hl-smi\",\n        shell=True,\n        text=True,\n        stdout=subprocess.PIPE,\n        stderr=subprocess.PIPE,\n        env={\"ENABLE_CONSOLE\": \"true\"},\n    )\n    if output.returncode == 0 and output.stdout:\n        return version.parse(\n            output.stdout.split(\"\\n\")[2]\n            .replace(\" \", \"\")\n            .split(\":\")[1][:-1]\n            .split(\"-\")[0]\n        )\n    return None\n\n\nMIN_TGI_GAUDI_SYNAPSE_VERSION = Version(\"1.19.0\")\n\n\ndef is_driver_compatible():\n    driver_version = get_driver_version()\n    if driver_version is not None:\n        if driver_version < MIN_TGI_GAUDI_SYNAPSE_VERSION:\n            return False\n    return True\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/utils/watermark.py",
    "content": "# coding=utf-8\n# Copyright 2023 Authors of \"A Watermark for Large Language Models\"\n# available at https://arxiv.org/abs/2301.10226\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport os\n\nimport torch\nfrom transformers import LogitsProcessor\nfrom typing import List, Union\n\nGAMMA = float(os.getenv(\"WATERMARK_GAMMA\", 0.5))\nDELTA = float(os.getenv(\"WATERMARK_DELTA\", 2.0))\n\n\nclass WatermarkLogitsProcessor(LogitsProcessor):\n    def __init__(\n        self,\n        gamma: float = GAMMA,\n        delta: float = DELTA,\n        hash_key: int = 15485863,  # just a large prime number to create a rng seed with sufficient bit width\n        device: str = \"cpu\",\n    ):\n        # watermarking parameters\n        self.gamma = gamma\n        self.delta = delta\n        self.rng = torch.Generator(device=\"cpu\")\n        self.hash_key = hash_key\n\n    def _seed_rng(self, input_ids: Union[List[int], torch.LongTensor]):\n        if isinstance(input_ids, list):\n            assert (\n                len(input_ids) >= 1\n            ), \"requires at least a 1 token prefix sequence to seed rng\"\n            prev_token = input_ids[-1]\n        else:\n            assert len(input_ids) == 1\n            input_ids = input_ids[0]\n            assert (\n                input_ids.shape[-1] >= 1\n            ), \"requires at least a 1 token prefix sequence to seed rng\"\n            prev_token = input_ids[-1].item()\n        
self.rng.manual_seed(self.hash_key * prev_token)\n\n    def _get_greenlist_ids(\n        self,\n        input_ids: Union[List[int], torch.LongTensor],\n        max_value: int,\n        device: torch.device,\n    ) -> List[int]:\n        # seed the rng using the previous tokens/prefix\n        self._seed_rng(input_ids)\n\n        greenlist_size = int(max_value * self.gamma)\n        vocab_permutation = torch.randperm(max_value, device=device, generator=self.rng)\n        greenlist_ids = vocab_permutation[:greenlist_size]\n        return greenlist_ids\n\n    @staticmethod\n    def _calc_greenlist_mask(\n        scores: torch.FloatTensor, greenlist_token_ids\n    ) -> torch.BoolTensor:\n        green_tokens_mask = torch.zeros_like(scores)\n        green_tokens_mask[-1, greenlist_token_ids] = 1\n        final_mask = green_tokens_mask.bool()\n        return final_mask\n\n    @staticmethod\n    def _bias_greenlist_logits(\n        scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float\n    ) -> torch.Tensor:\n        scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias\n        return scores\n\n    def __call__(\n        self, input_ids: Union[List[int], torch.LongTensor], scores: torch.FloatTensor\n    ) -> torch.FloatTensor:\n        greenlist_ids = self._get_greenlist_ids(\n            input_ids, scores.shape[-1], scores.device\n        )\n        green_tokens_mask = self._calc_greenlist_mask(\n            scores=scores, greenlist_token_ids=greenlist_ids\n        )\n\n        scores = self._bias_greenlist_logits(\n            scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta\n        )\n        return scores\n"
  },
  {
    "path": "backends/gaudi/server/text_generation_server/utils/weights.py",
    "content": "import torch\n\nfrom abc import ABC, abstractmethod\nfrom contextlib import contextmanager\nfrom pathlib import Path\nfrom typing import Dict, List, Optional, Union, Type\nfrom safetensors import safe_open\nfrom dataclasses import dataclass\n\n\nclass WeightsLoader(ABC):\n    \"\"\"\n    Instances of this type implement higher-level weight loading.\n\n    At a low-level, every weight is stored in the Safetensors format.\n    The interpretation of weights may be different however, for instance\n    could be packed, quantized weights. Loaders are responsible for\n    interpreting the raw tensors, sharding tensors in a manner compatible\n    with the format, etc.\n    \"\"\"\n\n    @abstractmethod\n    def get_weights(self, weights: \"Weights\", prefix: str):\n        \"\"\"\n        Get weights at the given prefix and apply without tensor paralllism.\n        \"\"\"\n        ...\n\n    @abstractmethod\n    def get_weights_col_packed(\n        self,\n        weights: \"Weights\",\n        prefix: str,\n        block_sizes: Union[int, List[int]],\n    ):\n        \"\"\"\n        Get the packed weights at the given prefix with column-splitting for\n        tensor parallelism. This method should be used when multiple different\n        weights are packed into a tensor, for instance, query/key/value\n        weights or a gate/up projection.\n\n        The `block_sizes` determines the proportions of the packed tensors.\n        The columns are split in equally sized blocks when `block_sizes` is an\n        `int`, or in blocks proportional given to the sizes. 
For instance\n        `[2, 1, 1]` will divide an input with dimensionality `1024` in\n        `[512, 256, 256]`.\n        \"\"\"\n        ...\n\n    def get_weights_col(self, weights: \"Weights\", prefix: str):\n        \"\"\"\n        Get weights at the given prefix and apply column-splitting for tensor\n        paralllism.\n        \"\"\"\n        return weights.get_multi_weights_col([prefix], 0)\n\n    @abstractmethod\n    def get_multi_weights_col(self, weights: \"Weights\", prefixes: List[str], dim: int):\n        \"\"\"\n        Get the weights at the given prefixes, column-split them for tensor\n        parallelim, and then concatenate the weights along the given dimension.\n        \"\"\"\n        ...\n\n    @abstractmethod\n    def get_multi_weights(self, weights: \"Weights\", prefixes: List[str], dim: int):\n        \"\"\"\n        Get the weights at the given prefixes, column-split them for tensor\n        parallelim, and then concatenate the weights along the given dimension.\n        \"\"\"\n        ...\n\n    @abstractmethod\n    def get_weights_row(self, weights: \"Weights\", prefix: str):\n        \"\"\"\n        Get the weights at the given prefix and apply row-splitting for tensor\n        parallism.\n        \"\"\"\n        ...\n\n\nclass Weight(ABC):\n    \"\"\"Instances of this type implement unquantized/quantized/to-be\n    quantized weights.\"\"\"\n\n    @abstractmethod\n    def get_linear(self, bias: torch.Tensor):\n        \"\"\"Create a linear layer from this weight.\"\"\"\n        ...\n\n\n@dataclass\nclass UnquantizedWeight(Weight):\n    weight: torch.Tensor\n\n    def get_linear(self, bias: torch.Tensor):\n        from text_generation_server.layers.linear import FastLinear\n\n        return FastLinear(self.weight, bias)\n\n\nclass DefaultWeightsLoader(WeightsLoader):\n    \"\"\"Weight loader that loads (unquantized) Torch tensors.\"\"\"\n\n    def __init__(self, weight_class: Type[UnquantizedWeight]):\n        \"\"\"Create a loader. 
Weights will be wrapped using the given `weights_class`,\n        normally this will be `UnquantizedWeight`, but a quantizer-specific class\n        such as `Fp8Weight` can be used to quantize the weights during loading.\n        \"\"\"\n        self.weight_class = weight_class\n\n    \"\"\"\n    Loader that uses tensors as-is with the exception of applying sharding\n    and/or concatenation.\n    \"\"\"\n\n    def get_weights(self, weights: \"Weights\", prefix: str):\n        return weights.get_tensor(f\"{prefix}.weight\")\n\n    def get_weights_col_packed(\n        self,\n        weights: \"Weights\",\n        prefix: str,\n        block_sizes: Union[int, List[int]],\n    ):\n        return self.weight_class(\n            weights.get_packed_sharded(\n                f\"{prefix}.weight\", dim=0, block_sizes=block_sizes\n            ),\n        )\n\n    def get_multi_weights_col(self, weights: \"Weights\", prefixes: List[str], dim: int):\n        w = [weights.get_sharded(f\"{p}.weight\", dim=0) for p in prefixes]\n        return self.weight_class(torch.cat(w, dim=dim))\n\n    def get_weights_row(self, weights: \"Weights\", prefix: str):\n        return self.weight_class(\n            weights.get_sharded(f\"{prefix}.weight\", dim=1),\n        )\n\n    def get_multi_weights(self, weights: \"Weights\", prefixes: List[str], dim: int):\n        w = [weights.get_tensor(f\"{p}.weight\") for p in prefixes]\n        return self.weight_class(torch.cat(w, dim=dim))\n\n\nclass Weights:\n    def __init__(\n        self,\n        filenames: List[Path],\n        device,\n        dtype,\n        process_group,\n        weights_loader: WeightsLoader,\n        aliases: Optional[Dict[str, List[str]]] = None,\n        prefix: Optional[str] = None,\n    ):\n        routing = {}\n        for filename in filenames:\n            with safe_open(filename, framework=\"pytorch\") as f:\n                for k in f.keys():\n                    if k in routing:\n                        raise 
RuntimeError(\n                            f\"Key {k} was found in multiple files: {filename} and {routing[k]}\"\n                        )\n                    routing[k] = filename\n        if aliases is None:\n            aliases = {}\n        self.aliases = aliases\n        self.routing = routing\n        self.device = device\n        self.dtype = dtype\n        self.process_group = process_group\n        self.prefix = prefix\n        self.weights_loader = weights_loader\n        self._handles = {}\n\n    def _get_handle(self, filename):\n        if filename not in self._handles:\n            f = safe_open(filename, framework=\"pytorch\")\n            self._handles[filename] = f\n\n        return self._handles[filename]\n\n    def get_filename(self, tensor_name: str) -> (str, str):\n        names = [tensor_name]\n        if self.prefix is not None:\n            prefixed = f\"{self.prefix}.{tensor_name}\"\n            names.append(prefixed)\n        for name in names:\n            filename = self.routing.get(name, None)\n            if filename is not None:\n                return str(filename), name\n\n            aliases = self.aliases.get(name, [])\n            for alias in aliases:\n                filename = self.routing.get(alias, None)\n                if filename is not None:\n                    return str(filename), alias\n        raise RuntimeError(f\"weight {tensor_name} does not exist\")\n\n    def _get_slice(self, tensor_name: str):\n        filename, tensor_name = self.get_filename(tensor_name)\n        f = self._get_handle(filename)\n        slice_ = f.get_slice(tensor_name)\n        return slice_\n\n    def has_tensor(self, tensor_name: str):\n        try:\n            self.get_filename(tensor_name)\n        except Exception:\n            return False\n        return True\n\n    def get_shape(self, tensor_name: str):\n        return self._get_slice(tensor_name).get_shape()\n\n    def get_tensor(\n        self, tensor_name: str, to_device: bool = 
True, to_dtype: bool = True\n    ) -> torch.Tensor:\n        filename, tensor_name = self.get_filename(tensor_name)\n        f = self._get_handle(filename)\n        tensor = f.get_tensor(tensor_name)\n        # Special case for gptq which shouldn't convert\n        # u4 which are disguised as int32. Exl2 uses int16\n        # as well. FP8 uses torch.float8_e4m3fn\n        if (\n            tensor.dtype\n            not in [\n                torch.float8_e4m3fn,\n                torch.int8,\n                torch.int16,\n                torch.int32,\n                torch.int64,\n            ]\n            and to_dtype\n        ):\n            tensor = tensor.to(dtype=self.dtype)\n        if to_device:\n            tensor = tensor.to(device=self.device)\n        return tensor\n\n    def get_partial_sharded(\n        self, tensor_name: str, dim: int, to_device=True, to_dtype=True\n    ):\n        filename, tensor_name = self.get_filename(tensor_name)\n        f = self._get_handle(filename)\n        slice_ = f.get_slice(tensor_name)\n        world_size = self.process_group.size()\n        rank = self.process_group.rank()\n\n        size = slice_.get_shape()[dim]\n        block_size = (size + world_size - 1) // world_size\n        start = rank * block_size\n        stop = (rank + 1) * block_size\n\n        if dim == 0:\n            tensor = slice_[start:stop]\n        elif dim == 1:\n            tensor = slice_[:, start:stop]\n        else:\n            raise NotImplementedError(\"Let's make that generic when needed\")\n        # Special case for gptq which shouldn't convert\n        # u4 which are disguised as int32. 
exl2 uses int16.\n        # FP8 uses torch.float8_e4m3fn.\n        if (\n            tensor.dtype\n            not in (torch.float8_e4m3fn, torch.int8, torch.int16, torch.int32)\n            and to_dtype\n        ):\n            tensor = tensor.to(dtype=self.dtype)\n        if to_device:\n            tensor = tensor.to(device=self.device)\n        return tensor\n\n    def get_sharded(self, tensor_name: str, dim: int, to_device=True, to_dtype=True):\n        filename, tensor_name = self.get_filename(tensor_name)\n        f = self._get_handle(filename)\n        slice_ = f.get_slice(tensor_name)\n        world_size = self.process_group.size()\n        size = slice_.get_shape()[dim]\n        assert (\n            size % world_size == 0\n        ), f\"The choosen size {size} is not compatible with sharding on {world_size} shards\"\n        return self.get_partial_sharded(\n            tensor_name, dim, to_device=to_device, to_dtype=to_dtype\n        )\n\n    def get_packed_sharded(\n        self,\n        tensor_name: str,\n        dim: int,\n        block_sizes: Union[int, List[int]],\n        to_dtype=True,\n    ) -> torch.Tensor:\n        \"\"\"\n        Get a shard from a tensor that packs multiple tensors.\n\n        When a tensor packs multiple tensors (such as QKV or an up\n        projection + gate projection), sharding with `get_sharded` is not\n        safe since it would not split the packed tensors across shards.\n\n        This method shards a tensor, such that the packed tensors are\n        split across shards.\n\n        The columns are split in equally sized blocks when blocks is an `int`, or\n        in blocks proportional given to the sizes. For instance `[2, 1, 1]` will\n        divide an input with dimensionality `1024` in `[512, 256, 256]`. This is\n        convenient for e.g. 
splitting QKV without knowing the storage details of\n        quantized weights.\n        \"\"\"\n        slice_ = self._get_slice(tensor_name)\n        total_size = slice_.get_shape()[dim]\n        block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=block_sizes)\n\n        world_size = self.process_group.size()\n        rank = self.process_group.rank()\n\n        tensors_slices = []\n        block_offset = 0\n        for block_size in block_sizes:\n            assert (\n                block_size % world_size == 0\n            ), f\"Prepacked tensor cannot be sharded across {world_size} shards\"\n            shard_block_size = block_size // world_size\n            start = rank * shard_block_size\n            stop = (rank + 1) * shard_block_size\n            tensors_slices += range(block_offset + start, block_offset + stop)\n            block_offset += block_size\n\n        if dim == 0:\n            tensor = slice_[tensors_slices, ...]\n        elif dim == 1 or dim == -2:\n            tensor = slice_[:, tensors_slices, ...]\n        elif dim == 2 or dim == -1:\n            tensor = slice_[..., tensors_slices]\n        else:\n            raise ValueError(f\"Unsupported dim {dim}, only dim 0, 1 or 2 are supported\")\n\n        tensor = tensor.to(device=self.device)\n\n        # Avoid casting quantizer dtypes.\n        if (\n            tensor.dtype\n            not in [\n                torch.float8_e4m3fn,\n                torch.int8,\n                torch.int16,\n                torch.int32,\n                torch.int64,\n            ]\n            and to_dtype\n        ):\n            tensor = tensor.to(dtype=self.dtype)\n\n        return tensor\n\n    def get_weights(self, prefix: str):\n        return self.weights_loader.get_weights(self, prefix)\n\n    def get_weights_col_packed_qkv(\n        self,\n        prefix: str,\n        num_heads: int,\n        num_key_value_heads: int,\n    ):\n        return self.get_weights_col_packed(\n            
prefix, [num_heads, num_key_value_heads, num_key_value_heads]\n        )\n\n    def get_weights_col_packed_gate_up(self, prefix: str):\n        return self.get_weights_col_packed(prefix, 2)\n\n    def get_weights_col_packed(self, prefix: str, block_sizes: Union[int, List[int]]):\n        \"\"\"\n        The columns are split in equally sized blocks when blocks is an `int`, or\n        in blocks proportional given to the sizes. For instance `[2, 1, 1]` will\n        divide an input with dimensionality `1024` in `[512, 256, 256]`. This is\n        convenient for e.g. splitting QKV without knowing the storage details of\n        quantized weights.\n        \"\"\"\n        return self.weights_loader.get_weights_col_packed(self, prefix, block_sizes)\n\n    def get_weights_col(self, prefix: str):\n        return self.weights_loader.get_weights_col(self, prefix)\n\n    def get_multi_weights_col(self, prefixes: List[str], dim: int):\n        return self.weights_loader.get_multi_weights_col(self, prefixes, dim)\n\n    def get_tensor_shard(self, var, dim):\n        world_size = self.process_group.size()\n        rank = self.process_group.rank()\n        block_size = var.size()[dim] // world_size\n        start = rank * block_size\n        stop = (rank + 1) * block_size\n        if dim == 0:\n            tensor = var[start:stop]\n        elif dim == 1:\n            tensor = var[:, start:stop]\n        else:\n            raise NotImplementedError(\"Let's make that generic when needed\")\n        tensor = tensor.to(dtype=self.dtype)\n        tensor = tensor.to(device=self.device)\n        return tensor\n\n    def get_weights_row(self, prefix: str):\n        return self.weights_loader.get_weights_row(self, prefix)\n\n    def get_multi_weights(self, prefixes: List[str], dim: int):\n        return self.weights_loader.get_multi_weights(self, prefixes, dim)\n\n    @contextmanager\n    def use_loader(self, weights_loader: WeightsLoader):\n        \"\"\"\n        This method is a 
context manager that can be used to use `Weights` with\n        a different loader for the duration of the context.\n        \"\"\"\n\n        old_loader = self.weights_loader\n        self.weights_loader = weights_loader\n        try:\n            yield\n        finally:\n            self.weights_loader = old_loader\n\n    @property\n    def loader(self):\n        return self.weights_loader\n\n\ndef _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:\n    \"\"\"\n    Convert block count or proportions to block sizes.\n\n    This function accepts\n\n    - The number of blocks (int), in which case the block size is\n      total_size//blocks; or\n    - A list of block sizes (List[int]).\n\n    In the latter case, if sum(blocks) < total_size, the ratios between\n    the block sizes will be preserved. For instance, if blocks is\n    [2, 1, 1] and total_size is 1024, the returned block sizes are\n    [512, 256, 256].\n    \"\"\"\n    if isinstance(blocks, list):\n        total_blocks = sum(blocks)\n        assert (\n            total_size % total_blocks == 0\n        ), f\"Cannot split {total_size} in proportional blocks: {blocks}\"\n        part_size = total_size // total_blocks\n        return [part_size * block for block in blocks]\n    else:\n        assert total_size % blocks == 0, f\"Prepacked is not divisible by {blocks}\"\n        single_size = total_size // blocks\n        return [single_size] * blocks\n"
  },
  {
    "path": "backends/gaudi/tgi-entrypoint.sh",
    "content": "#!/bin/bash\n\nldconfig 2>/dev/null || echo 'unable to refresh ld cache, not a big deal in most cases'\n\n# Check if --sharded argument is present in the command line arguments\nif [[ \"$*\" == *\"--sharded true\"* ]]; then\n  echo 'setting PT_HPU_ENABLE_LAZY_COLLECTIVES=1 for sharding'\n  export PT_HPU_ENABLE_LAZY_COLLECTIVES=1\nfi\n\ntext-generation-launcher $@\n"
  },
  {
    "path": "backends/grpc-metadata/Cargo.toml",
    "content": "[package]\nname = \"grpc-metadata\"\nversion = \"0.1.0\"\nedition = \"2021\"\n\n[dependencies]\nopentelemetry = \"^0.20\"\ntonic = \"^0.10\"\ntracing = \"^0.1\"\ntracing-opentelemetry = \"^0.21\"\n"
  },
  {
    "path": "backends/grpc-metadata/src/lib.rs",
    "content": "//! A crate to extract and inject a OpenTelemetry context from and to a gRPC request.\n//! Inspired by: https://github.com/open-telemetry/opentelemetry-rust gRPC examples\n\nuse opentelemetry::global;\nuse opentelemetry::propagation::Injector;\nuse tracing_opentelemetry::OpenTelemetrySpanExt;\n\n/// Inject context in the metadata of a gRPC request.\nstruct MetadataInjector<'a>(pub &'a mut tonic::metadata::MetadataMap);\n\nimpl Injector for MetadataInjector<'_> {\n    /// Set a key and value in the MetadataMap.  Does nothing if the key or value are not valid inputs\n    fn set(&mut self, key: &str, value: String) {\n        if let Ok(key) = tonic::metadata::MetadataKey::from_bytes(key.as_bytes()) {\n            if let Ok(val) = value.parse() {\n                self.0.insert(key, val);\n            }\n        }\n    }\n}\n\n/// Get a context from the global context and inject the span into a gRPC request's metadata.\nfn inject(metadata: &mut tonic::metadata::MetadataMap) {\n    global::get_text_map_propagator(|propagator| {\n        propagator.inject_context(\n            &tracing::Span::current().context(),\n            &mut MetadataInjector(metadata),\n        )\n    })\n}\n\npub trait InjectTelemetryContext {\n    fn inject_context(self) -> Self;\n}\n\nimpl<T> InjectTelemetryContext for tonic::Request<T> {\n    fn inject_context(mut self) -> Self {\n        inject(self.metadata_mut());\n        self\n    }\n}\n"
  },
  {
    "path": "backends/llamacpp/Cargo.toml",
    "content": "[package]\nname = \"text-generation-router-llamacpp\"\nversion.workspace = true\nedition.workspace = true\nauthors.workspace = true\nhomepage.workspace = true\n\n[build-dependencies]\nbindgen = \"0.71.1\"\npkg-config = \"0.3.31\"\n\n[dependencies]\nasync-trait = \"0.1.85\"\nclap = \"4.5.27\"\nhf-hub.workspace = true\nnum_cpus = \"1.16.0\"\ntext-generation-router = { path = \"../../router\" }\nthiserror = \"2.0.11\"\ntokenizers.workspace = true\ntokio = { version = \"1.43.0\", features = [\"process\"] }\ntokio-stream = \"0.1.17\"\ntracing = \"0.1.41\"\n"
  },
  {
    "path": "backends/llamacpp/README.md",
    "content": "# Llamacpp backend\n\nIf all your dependencies are installed at the system level, running\ncargo build should be sufficient. However, if you want to experiment\nwith different versions of llama.cpp, some additional setup is required.\n\n## Install llama.cpp\n\n    LLAMACPP_PREFIX=$(pwd)/llama.cpp.out\n\n    git clone https://github.com/ggerganov/llama.cpp\n    cd llama.cpp\n    cmake -B build \\\n        -DCMAKE_INSTALL_PREFIX=\"$LLAMACPP_PREFIX\" \\\n        -DLLAMA_BUILD_COMMON=OFF \\\n        -DLLAMA_BUILD_TESTS=OFF \\\n        -DLLAMA_BUILD_EXAMPLES=OFF \\\n        -DLLAMA_BUILD_SERVER=OFF\n    cmake --build build --config Release -j\n    cmake --install build\n\n## Build TGI\n\n    PKG_CONFIG_PATH=\"$LLAMACPP_PREFIX/lib/pkgconfig\" cargo build\n"
  },
  {
    "path": "backends/llamacpp/build.rs",
    "content": "use bindgen::callbacks::{ItemInfo, ParseCallbacks};\nuse std::env;\nuse std::path::PathBuf;\n\n#[derive(Debug)]\nstruct PrefixStripper;\n\nimpl ParseCallbacks for PrefixStripper {\n    fn generated_name_override(&self, item_info: ItemInfo<'_>) -> Option<String> {\n        item_info.name.strip_prefix(\"llama_\").map(str::to_string)\n    }\n}\n\nfn main() {\n    if let Some(cuda_version) = option_env!(\"CUDA_VERSION\") {\n        let mut version: Vec<&str> = cuda_version.split('.').collect();\n        if version.len() > 2 {\n            version.pop();\n        }\n        let cuda_version = format!(\"cuda-{}\", version.join(\".\"));\n        pkg_config::Config::new().probe(&cuda_version).unwrap();\n    }\n    let llama = pkg_config::Config::new().probe(\"llama\").unwrap();\n\n    for path in &llama.link_paths {\n        println!(\"cargo:rustc-link-arg=-Wl,-rpath,{}\", path.display());\n    }\n    if cfg!(target_os = \"linux\") {\n        println!(\"cargo:rustc-link-arg=-Wl,--disable-new-dtags\");\n    }\n    let bindings = bindgen::Builder::default()\n        .clang_args(\n            llama\n                .include_paths\n                .iter()\n                .map(|p| format!(\"-I{}\", p.display())),\n        )\n        .header_contents(\"llama_bindings.h\", \"#include <llama.h>\")\n        .prepend_enum_name(false)\n        .parse_callbacks(Box::new(PrefixStripper))\n        .parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))\n        .generate()\n        .expect(\"Unable to generate bindings\");\n\n    let out_path = PathBuf::from(env::var(\"OUT_DIR\").unwrap());\n    bindings\n        .write_to_file(out_path.join(\"llamacpp.rs\"))\n        .expect(\"Couldn't write bindings!\");\n}\n"
  },
  {
    "path": "backends/llamacpp/requirements.txt",
    "content": "transformers==4.49\nhuggingface-hub==0.28.1\nhf-transfer==0.1.9\ntorch==2.6.0\n"
  },
  {
    "path": "backends/llamacpp/src/backend.rs",
    "content": "use crate::llamacpp;\n\nuse async_trait::async_trait;\nuse std::ffi::CString;\nuse std::mem::replace;\nuse std::str::FromStr;\nuse std::sync::{mpsc, Once};\nuse text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};\nuse text_generation_router::validation::ValidGenerateRequest;\nuse text_generation_router::{FinishReason, Token};\nuse thiserror::Error;\nuse tokenizers::Tokenizer;\nuse tokio::sync::mpsc::{unbounded_channel, UnboundedSender};\nuse tokio::sync::{oneshot, watch};\nuse tokio::task::{spawn, spawn_blocking};\nuse tokio::time::{timeout, Duration, Instant};\nuse tokio_stream::wrappers::UnboundedReceiverStream;\nuse tracing::instrument;\nuse tracing::{debug, error, info, trace, warn};\n\n#[derive(Debug, Clone, Copy)]\npub enum LlamacppSplitMode {\n    GPU(usize),\n    Layer,\n    Row,\n}\n\nimpl FromStr for LlamacppSplitMode {\n    type Err = String;\n    fn from_str(s: &str) -> Result<Self, Self::Err> {\n        match s.to_lowercase().as_str() {\n            \"layer\" => Ok(LlamacppSplitMode::Layer),\n            \"row\" => Ok(LlamacppSplitMode::Row),\n            _ => match s.parse::<usize>() {\n                Ok(n) => Ok(LlamacppSplitMode::GPU(n)),\n                Err(_) => Err(\"Choose a GPU number or `layer` or `row`\".to_string()),\n            },\n        }\n    }\n}\n\n#[derive(Debug, Clone, Copy, clap::ValueEnum)]\npub enum LlamacppNuma {\n    Disabled,\n    Distribute,\n    Isolate,\n    Numactl,\n    Mirror,\n}\n\n#[allow(non_camel_case_types)]\n#[derive(Debug, Clone, Copy, clap::ValueEnum)]\npub enum LlamacppGGMLType {\n    F32,\n    F16,\n    Q4_0,\n    Q4_1,\n    Q5_0,\n    Q5_1,\n    Q8_0,\n    Q8_1,\n    Q2_K,\n    Q3_K,\n    Q4_K,\n    Q5_K,\n    Q6_K,\n    Q8_K,\n    IQ2_XXS,\n    IQ2_XS,\n    IQ3_XXS,\n    IQ1_S,\n    IQ4_NL,\n    IQ3_S,\n    IQ2_S,\n    IQ4_XS,\n    I8,\n    I16,\n    I32,\n    I64,\n    F64,\n    IQ1_M,\n    BF16,\n    TQ1_0,\n    TQ2_0,\n}\n\n// TODO: macro\nimpl 
LlamacppGGMLType {\n    fn to_ggml_type(self) -> llamacpp::ggml_type {\n        match self {\n            LlamacppGGMLType::F32 => llamacpp::GGML_TYPE_F32,\n            LlamacppGGMLType::F16 => llamacpp::GGML_TYPE_F16,\n            LlamacppGGMLType::Q4_0 => llamacpp::GGML_TYPE_Q4_0,\n            LlamacppGGMLType::Q4_1 => llamacpp::GGML_TYPE_Q4_1,\n            LlamacppGGMLType::Q5_0 => llamacpp::GGML_TYPE_Q5_0,\n            LlamacppGGMLType::Q5_1 => llamacpp::GGML_TYPE_Q5_1,\n            LlamacppGGMLType::Q8_0 => llamacpp::GGML_TYPE_Q8_0,\n            LlamacppGGMLType::Q8_1 => llamacpp::GGML_TYPE_Q8_1,\n            LlamacppGGMLType::Q2_K => llamacpp::GGML_TYPE_Q2_K,\n            LlamacppGGMLType::Q3_K => llamacpp::GGML_TYPE_Q3_K,\n            LlamacppGGMLType::Q4_K => llamacpp::GGML_TYPE_Q4_K,\n            LlamacppGGMLType::Q5_K => llamacpp::GGML_TYPE_Q5_K,\n            LlamacppGGMLType::Q6_K => llamacpp::GGML_TYPE_Q6_K,\n            LlamacppGGMLType::Q8_K => llamacpp::GGML_TYPE_Q8_K,\n            LlamacppGGMLType::IQ2_XXS => llamacpp::GGML_TYPE_IQ2_XXS,\n            LlamacppGGMLType::IQ2_XS => llamacpp::GGML_TYPE_IQ2_XS,\n            LlamacppGGMLType::IQ3_XXS => llamacpp::GGML_TYPE_IQ3_XXS,\n            LlamacppGGMLType::IQ1_S => llamacpp::GGML_TYPE_IQ1_S,\n            LlamacppGGMLType::IQ4_NL => llamacpp::GGML_TYPE_IQ4_NL,\n            LlamacppGGMLType::IQ3_S => llamacpp::GGML_TYPE_IQ3_S,\n            LlamacppGGMLType::IQ2_S => llamacpp::GGML_TYPE_IQ2_S,\n            LlamacppGGMLType::IQ4_XS => llamacpp::GGML_TYPE_IQ4_XS,\n            LlamacppGGMLType::I8 => llamacpp::GGML_TYPE_I8,\n            LlamacppGGMLType::I16 => llamacpp::GGML_TYPE_I16,\n            LlamacppGGMLType::I32 => llamacpp::GGML_TYPE_I32,\n            LlamacppGGMLType::I64 => llamacpp::GGML_TYPE_I64,\n            LlamacppGGMLType::F64 => llamacpp::GGML_TYPE_F64,\n            LlamacppGGMLType::IQ1_M => llamacpp::GGML_TYPE_IQ1_M,\n            LlamacppGGMLType::BF16 => llamacpp::GGML_TYPE_BF16,\n     
       LlamacppGGMLType::TQ1_0 => llamacpp::GGML_TYPE_TQ1_0,\n            LlamacppGGMLType::TQ2_0 => llamacpp::GGML_TYPE_TQ2_0,\n        }\n    }\n}\n\npub struct LlamacppConfig {\n    pub model_gguf: String,\n    pub max_batch_total_tokens: usize,\n    pub max_physical_batch_total_tokens: usize,\n    pub max_batch_size: usize,\n    pub batch_timeout: Duration,\n    pub n_threads: usize,\n    pub n_threads_batch: usize,\n    pub n_gpu_layers: usize,\n    pub split_mode: LlamacppSplitMode,\n    pub numa: LlamacppNuma,\n    pub defrag_threshold: f32,\n    pub use_mmap: bool,\n    pub use_mlock: bool,\n    pub offload_kqv: bool,\n    pub flash_attention: bool,\n    pub type_k: LlamacppGGMLType,\n    pub type_v: LlamacppGGMLType,\n}\n\n#[derive(Debug)]\nstruct LlamacppRequest {\n    input_ids: Vec<i32>,\n    top_k: i32,\n    top_p: f32,\n    typical_p: f32,\n    min_keep: usize,\n    temp: f32,\n    seed: u32,\n    penalty_last_n: i32,\n    penalty_repeat: f32,\n    penalty_freq: f32,\n    penalty_present: f32,\n    max_new_tokens: usize,\n    tx: UnboundedSender<Result<InferStreamResponse, InferError>>,\n    time: Instant,\n}\n\npub struct LlamacppBackend {\n    tx: UnboundedSender<LlamacppRequest>,\n    status: watch::Receiver<bool>,\n}\n\nimpl LlamacppRequest {\n    fn new(\n        from: &ValidGenerateRequest,\n        tx: UnboundedSender<Result<InferStreamResponse, InferError>>,\n    ) -> Option<Self> {\n        from.input_ids.as_ref().map(|input_ids| LlamacppRequest {\n            input_ids: input_ids.iter().map(|&x| x as i32).collect(),\n            top_k: from.parameters.top_k as _,\n            top_p: from.parameters.top_p as _,\n            typical_p: from.parameters.typical_p as _,\n            min_keep: 0, // disabled\n            temp: from.parameters.temperature as _,\n            seed: from.parameters.seed as _,\n            penalty_last_n: 64, // 0 = disabled, -1 = context size\n            penalty_repeat: from.parameters.repetition_penalty as _,\n      
      penalty_freq: from.parameters.frequency_penalty as _,\n            penalty_present: 0.0, // disabled\n            max_new_tokens: from.stopping_parameters.max_new_tokens as _,\n            tx,\n            time: Instant::now(),\n        })\n    }\n}\n\nstruct Llamacpp {\n    model: *mut llamacpp::llama_model,\n    ctx: *mut llamacpp::llama_context,\n    vocab: *const llamacpp::llama_vocab,\n    logprobs: Vec<llamacpp::llama_token_data>,\n    batch: llamacpp::llama_batch,\n}\n\nextern \"C\" fn llamacpp_log_callback(\n    level: llamacpp::ggml_log_level,\n    msg: *const std::os::raw::c_char,\n    _user_data: *mut std::os::raw::c_void,\n) {\n    let cmsg = unsafe { std::ffi::CStr::from_ptr(msg) };\n    let rmsg = cmsg.to_string_lossy().trim_end_matches('\\n').to_string();\n\n    match level {\n        llamacpp::GGML_LOG_LEVEL_DEBUG => debug!(target: \"llamacpp\", \"{}\", rmsg),\n        llamacpp::GGML_LOG_LEVEL_INFO => info!(target: \"llamacpp\", \"{}\", rmsg),\n        llamacpp::GGML_LOG_LEVEL_WARN => warn!(target: \"llamacpp\", \"{}\", rmsg),\n        llamacpp::GGML_LOG_LEVEL_ERROR => error!(target: \"llamacpp\", \"{}\", rmsg),\n        _ => trace!(target: \"llamacpp\", \"{}\", rmsg),\n    }\n}\n\nimpl Llamacpp {\n    fn new(conf: LlamacppConfig) -> Result<Self, BackendError> {\n        let gguf = CString::new(conf.model_gguf)?;\n\n        let model = unsafe {\n            let mut params = llamacpp::model_default_params();\n            params.n_gpu_layers = conf.n_gpu_layers as _;\n            params.split_mode = match conf.split_mode {\n                LlamacppSplitMode::GPU(_) => llamacpp::LLAMA_SPLIT_MODE_NONE,\n                LlamacppSplitMode::Layer => llamacpp::LLAMA_SPLIT_MODE_LAYER,\n                LlamacppSplitMode::Row => llamacpp::LLAMA_SPLIT_MODE_ROW,\n            };\n            params.main_gpu = match conf.split_mode {\n                LlamacppSplitMode::GPU(n) => n as _,\n                _ => 0,\n            };\n            params.use_mmap = 
conf.use_mmap;\n            params.use_mlock = conf.use_mlock;\n            llamacpp::model_load_from_file(gguf.as_ptr(), params)\n        };\n        if model.is_null() {\n            return Err(BackendError::Llamacpp(\"Failed to load model\".to_string()));\n        }\n        let ctx = unsafe {\n            let mut params = llamacpp::context_default_params();\n            params.n_ctx = conf.max_batch_total_tokens as _;\n            params.n_batch = conf.max_batch_total_tokens as _;\n            params.n_ubatch = conf.max_physical_batch_total_tokens as _;\n            params.n_seq_max = conf.max_batch_size as _;\n            params.n_threads = conf.n_threads as _;\n            params.n_threads_batch = conf.n_threads_batch as _;\n            params.defrag_thold = conf.defrag_threshold;\n            params.offload_kqv = conf.offload_kqv;\n            params.flash_attn = conf.flash_attention;\n            params.type_k = conf.type_k.to_ggml_type();\n            params.type_v = conf.type_v.to_ggml_type();\n            params.no_perf = true;\n            llamacpp::init_from_model(model, params)\n        };\n        if ctx.is_null() {\n            return Err(BackendError::Llamacpp(\"Failed to init context\".to_string()));\n        }\n        let vocab = unsafe { llamacpp::model_get_vocab(model) };\n        if vocab.is_null() {\n            return Err(BackendError::Llamacpp(\"Failed to get vocab\".to_string()));\n        }\n        let n_tokens = unsafe { llamacpp::vocab_n_tokens(vocab) };\n        let mut logprobs = Vec::with_capacity(n_tokens as usize);\n\n        for token in 0..n_tokens {\n            logprobs.push(llamacpp::llama_token_data {\n                id: token,\n                logit: 0.0,\n                p: 0.0,\n            });\n        }\n        let batch = unsafe { llamacpp::batch_init(conf.max_batch_total_tokens as _, 0, 1) };\n        Ok(Llamacpp {\n            model,\n            ctx,\n            vocab,\n            logprobs,\n            
batch,\n        })\n    }\n\n    fn decode(&mut self) -> i32 {\n        unsafe { llamacpp::decode(self.ctx, self.batch) }\n    }\n\n    fn clear_kv_cache(&mut self, seq_id: llamacpp::llama_seq_id) {\n        unsafe {\n            llamacpp::kv_cache_seq_rm(self.ctx, seq_id, -1, -1);\n        }\n    }\n\n    fn batch_push(\n        &mut self,\n        token: llamacpp::llama_token,\n        pos: llamacpp::llama_pos,\n        seq_id: llamacpp::llama_seq_id,\n        logits: bool,\n    ) -> usize {\n        let n = self.batch.n_tokens as usize;\n        unsafe {\n            *self.batch.token.add(n) = token;\n            *self.batch.pos.add(n) = pos;\n            *self.batch.n_seq_id.add(n) = 1;\n            *(*self.batch.seq_id.add(n)).add(0) = seq_id;\n            *self.batch.logits.add(n) = logits as i8;\n        }\n        self.batch.n_tokens += 1;\n        n\n    }\n}\n\nimpl Drop for Llamacpp {\n    fn drop(&mut self) {\n        if !self.ctx.is_null() {\n            unsafe { llamacpp::free(self.ctx) };\n        }\n        if !self.model.is_null() {\n            unsafe { llamacpp::model_free(self.model) };\n        }\n        unsafe { llamacpp::batch_free(self.batch) };\n    }\n}\n\nstruct LlamacppSampler {\n    chain: *mut llamacpp::llama_sampler,\n}\n\nimpl LlamacppSampler {\n    fn new(req: &LlamacppRequest) -> Option<Self> {\n        let chain = unsafe {\n            let params = llamacpp::sampler_chain_default_params();\n            llamacpp::sampler_chain_init(params)\n        };\n        if chain.is_null() {\n            error!(\"Failed to init sampler\");\n            return None;\n        }\n        let (top_k, top_p, typical_p, temp, penalties, dist) = unsafe {\n            (\n                llamacpp::sampler_init_top_k(req.top_k),\n                llamacpp::sampler_init_top_p(req.top_p, req.min_keep),\n                llamacpp::sampler_init_typical(req.typical_p, req.min_keep),\n                llamacpp::sampler_init_temp(req.temp),\n                
llamacpp::sampler_init_penalties(\n                    req.penalty_last_n,\n                    req.penalty_repeat,\n                    req.penalty_freq,\n                    req.penalty_present,\n                ),\n                llamacpp::sampler_init_dist(req.seed),\n            )\n        };\n        let all = &[\n            (\"top_k\", top_k),\n            (\"top_p\", top_p),\n            (\"typical_p\", typical_p),\n            (\"temp\", temp),\n            (\"penalties\", penalties),\n            (\"dist\", dist),\n        ];\n        let mut failed = false;\n\n        for (k, v) in all {\n            if v.is_null() {\n                error!(\"Failed to init {k} sampler\");\n                failed = true;\n            } else {\n                unsafe { llamacpp::sampler_chain_add(chain, *v) };\n            }\n        }\n        if failed {\n            unsafe { llamacpp::sampler_free(chain) };\n            None\n        } else {\n            Some(LlamacppSampler { chain })\n        }\n    }\n\n    fn sample(&self, llamacpp: &mut Llamacpp, idx: usize) -> (llamacpp::llama_token, f32) {\n        let logits = unsafe { llamacpp::get_logits_ith(llamacpp.ctx, idx as _) };\n        for (token, logprob) in llamacpp.logprobs.iter_mut().enumerate() {\n            *logprob = llamacpp::llama_token_data {\n                id: token as _,\n                logit: unsafe { *logits.add(token) },\n                p: 0.0,\n            };\n        }\n        let mut view = llamacpp::llama_token_data_array {\n            data: llamacpp.logprobs.as_mut_ptr(),\n            size: llamacpp.logprobs.len(),\n            selected: -1,\n            sorted: false,\n        };\n        unsafe {\n            llamacpp::sampler_apply(self.chain, &mut view);\n            let logprob = *view.data.offset(view.selected as _);\n            llamacpp::sampler_accept(self.chain, logprob.id);\n            (logprob.id, logprob.p.ln())\n        }\n    }\n}\n\nimpl Drop for LlamacppSampler {\n    fn 
drop(&mut self) {\n        if !self.chain.is_null() {\n            unsafe { llamacpp::sampler_free(self.chain) };\n        }\n    }\n}\n\nstruct LlamacppSeq {\n    id: usize,\n    batch_pos: usize,\n    token: llamacpp::llama_token,\n    pos: llamacpp::llama_pos,\n    sampler: LlamacppSampler,\n    text: String,\n    n_new_tokens: usize,\n    running: bool,\n}\n\nstatic INIT: Once = Once::new();\n\nimpl LlamacppBackend {\n    pub fn new(\n        conf: LlamacppConfig,\n        tokenizer: Tokenizer,\n    ) -> (\n        Self,\n        oneshot::Receiver<Result<(), BackendError>>,\n        watch::Sender<bool>,\n    ) {\n        // Setup llama & export logs, once and for all\n        INIT.call_once(|| unsafe {\n            llamacpp::log_set(Some(llamacpp_log_callback), std::ptr::null_mut());\n            llamacpp::backend_init();\n            llamacpp::numa_init(match conf.numa {\n                LlamacppNuma::Disabled => llamacpp::GGML_NUMA_STRATEGY_DISABLED,\n                LlamacppNuma::Distribute => llamacpp::GGML_NUMA_STRATEGY_DISTRIBUTE,\n                LlamacppNuma::Isolate => llamacpp::GGML_NUMA_STRATEGY_ISOLATE,\n                LlamacppNuma::Numactl => llamacpp::GGML_NUMA_STRATEGY_NUMACTL,\n                LlamacppNuma::Mirror => llamacpp::GGML_NUMA_STRATEGY_MIRROR,\n            });\n        });\n\n        let (status_tx, status_rx) = watch::channel(false);\n        let (shutdown_tx, shutdown_rx) = watch::channel(false);\n        let (ok_tx, ok_rx) = oneshot::channel();\n        let (tx, mut rx) = unbounded_channel::<LlamacppRequest>();\n        let (sync_tx, sync_rx) = mpsc::channel();\n\n        spawn(async move {\n            let mut n_tokens = 0;\n            let mut requests = Vec::with_capacity(conf.max_batch_size);\n\n            let flush = |requests: &mut Vec<_>, n_tokens: &mut usize| {\n                if !requests.is_empty() {\n                    let _ =\n                        sync_tx.send(replace(requests, 
Vec::with_capacity(conf.max_batch_size)));\n                    *n_tokens = 0;\n                }\n            };\n            loop {\n                match timeout(conf.batch_timeout, rx.recv()).await {\n                    Ok(Some(request)) => {\n                        let n_tokens_to_add = request.input_ids.len();\n\n                        if n_tokens + n_tokens_to_add > conf.max_batch_total_tokens {\n                            flush(&mut requests, &mut n_tokens);\n                        }\n                        n_tokens += n_tokens_to_add;\n                        requests.push(request);\n\n                        if requests.len() == conf.max_batch_size {\n                            flush(&mut requests, &mut n_tokens);\n                        }\n                    }\n                    Ok(None) => break,                             // closed\n                    Err(_) => flush(&mut requests, &mut n_tokens), // timeout\n                }\n            }\n        });\n\n        spawn_blocking(move || {\n            let mut llamacpp = match Llamacpp::new(conf) {\n                Ok(v) => {\n                    let _ = ok_tx.send(Ok(()));\n                    v\n                }\n                Err(e) => {\n                    let _ = ok_tx.send(Err(e));\n                    return;\n                }\n            };\n            let vocab = tokenizer.get_added_vocabulary();\n\n            // health() returns true\n            let _ = status_tx.send(true);\n\n            while let Ok(requests) = sync_rx.recv() {\n                if *shutdown_rx.borrow() {\n                    break;\n                }\n                let start_time = Instant::now();\n                let mut seqs: Vec<LlamacppSeq> = Vec::with_capacity(requests.len());\n                llamacpp.batch.n_tokens = 0;\n\n                for (seq_id, request) in requests.iter().enumerate() {\n                    debug!(\"Request: {:?}\", request);\n                    // TODO remove this\n   
                 let sampler = match LlamacppSampler::new(request) {\n                        Some(sampler) => sampler,\n                        _ => {\n                            let _ = request.tx.send(Err(InferError::IncompleteGeneration));\n                            continue;\n                        }\n                    };\n                    let last_pos = request.input_ids.len() - 1;\n\n                    for (pos, &token_id) in request.input_ids.iter().enumerate() {\n                        llamacpp.batch_push(\n                            token_id as llamacpp::llama_token,\n                            pos as llamacpp::llama_pos,\n                            seq_id as llamacpp::llama_seq_id,\n                            pos == last_pos, // check samplers\n                        );\n                    }\n                    seqs.push(LlamacppSeq {\n                        id: seq_id,\n                        batch_pos: llamacpp.batch.n_tokens as usize - 1,\n                        token: llamacpp::LLAMA_TOKEN_NULL,\n                        pos: last_pos as llamacpp::llama_pos + 1,\n                        sampler,\n                        text: String::with_capacity(1024),\n                        n_new_tokens: 0,\n                        running: true,\n                    });\n                }\n                while llamacpp.batch.n_tokens > 0 {\n                    if llamacpp.decode() != 0 {\n                        warn!(\"llama_decode failed, clearing kv cache\");\n                        llamacpp.clear_kv_cache(-1);\n                        for seq in seqs.iter_mut() {\n                            let _ = requests[seq.id]\n                                .tx\n                                .send(Err(InferError::IncompleteGeneration));\n                            seq.running = false;\n                        }\n                        break;\n                    }\n                    for seq in seqs.iter_mut() {\n                        if 
!seq.running {\n                            continue;\n                        }\n                        let (next, logprob) = seq.sampler.sample(&mut llamacpp, seq.batch_pos);\n                        seq.n_new_tokens += 1;\n                        seq.token = next;\n\n                        let piece = match tokenizer.decode(&[next as u32], false) {\n                            Ok(piece) => piece,\n                            Err(e) => {\n                                error!(\"Failed to decode token: {e}\");\n                                let _ = requests[seq.id]\n                                    .tx\n                                    .send(Err(InferError::IncompleteGeneration));\n                                seq.running = false;\n                                continue;\n                            }\n                        };\n                        let special = vocab.is_special_token(&piece);\n\n                        if !special {\n                            seq.text.push_str(&piece);\n                        }\n                        let token = Token {\n                            id: next as _,\n                            text: piece,\n                            logprob,\n                            special,\n                        };\n                        let finish: Option<FinishReason> = {\n                            if unsafe { llamacpp::vocab_is_eog(llamacpp.vocab, next) } {\n                                Some(FinishReason::EndOfSequenceToken)\n                            } else if seq.n_new_tokens == requests[seq.id].max_new_tokens {\n                                Some(FinishReason::Length)\n                            } else {\n                                None\n                            }\n                        };\n                        if let Some(reason) = finish {\n                            let _ = requests[seq.id].tx.send(Ok(InferStreamResponse::End {\n                                token,\n           
                     top_tokens: vec![],\n                                generated_text: GeneratedText {\n                                    text: seq.text.clone(),\n                                    generated_tokens: seq.n_new_tokens as _,\n                                    finish_reason: reason,\n                                    seed: Some(requests[seq.id].seed as _),\n                                },\n                                start: start_time,\n                                queued: requests[seq.id].time,\n                            }));\n                            seq.running = false;\n                            continue;\n                        }\n                        let _ = requests[seq.id]\n                            .tx\n                            .send(Ok(InferStreamResponse::Intermediate {\n                                token,\n                                top_tokens: vec![],\n                            }));\n                    }\n                    // generate a new batch\n                    llamacpp.batch.n_tokens = 0;\n\n                    for seq in seqs.iter_mut() {\n                        if seq.running {\n                            seq.batch_pos =\n                                llamacpp.batch_push(seq.token, seq.pos, seq.id as _, true);\n                            seq.pos += 1;\n                        } else {\n                            llamacpp.clear_kv_cache(seq.id as _);\n                        }\n                    }\n                }\n            }\n        });\n        (\n            Self {\n                tx,\n                status: status_rx,\n            },\n            ok_rx,\n            shutdown_tx,\n        )\n    }\n}\n\n#[async_trait]\nimpl Backend for LlamacppBackend {\n    #[instrument(skip_all)]\n    fn schedule(\n        &self,\n        request: ValidGenerateRequest,\n    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {\n        
debug!(?request);\n        let (tx, rx) = unbounded_channel::<Result<InferStreamResponse, InferError>>();\n        match LlamacppRequest::new(&request, tx) {\n            Some(v) => match self.tx.send(v) {\n                Err(e) => Err(InferError::GenerationError(e.to_string())),\n                _ => Ok(UnboundedReceiverStream::new(rx)),\n            },\n            _ => Err(InferError::GenerationError(\"Bad request\".to_string())),\n        }\n    }\n\n    async fn health(&self, _: bool) -> bool {\n        *self.status.borrow()\n    }\n\n    fn name(&self) -> &'static str {\n        \"llamacpp\"\n    }\n}\n\n#[derive(Debug, Error)]\npub enum BackendError {\n    #[error(\"CString error: {0}\")]\n    CStringError(#[from] std::ffi::NulError),\n    #[error(\"Llamacpp error: {0}\")]\n    Llamacpp(String),\n}\n"
  },
  {
    "path": "backends/llamacpp/src/llamacpp.rs",
    "content": "#![allow(non_upper_case_globals)]\n#![allow(non_camel_case_types)]\n#![allow(non_snake_case)]\n#![allow(dead_code)]\ninclude!(concat!(env!(\"OUT_DIR\"), \"/llamacpp.rs\"));\n"
  },
  {
    "path": "backends/llamacpp/src/main.rs",
    "content": "mod backend;\nmod llamacpp;\nmod quantize;\n\nuse quantize::QuantizeType;\n\nuse backend::{\n    BackendError, LlamacppBackend, LlamacppConfig, LlamacppGGMLType, LlamacppNuma,\n    LlamacppSplitMode,\n};\nuse clap::Parser;\nuse hf_hub::api::tokio::ApiBuilder;\nuse hf_hub::{Repo, RepoType};\nuse std::path::Path;\nuse text_generation_router::{logging, server, usage_stats};\nuse thiserror::Error;\nuse tokenizers::Tokenizer;\nuse tokio::process::Command;\nuse tokio::sync::oneshot::error::RecvError;\nuse tracing::{error, warn};\n\n/// Backend Configuration\n#[derive(Parser, Debug)]\n#[clap(author, version, about, long_about = None)]\nstruct Args {\n    /// Name of the model to load.\n    #[clap(long, env)]\n    model_id: String,\n\n    /// Revision of the model.\n    #[clap(default_value = \"main\", long, env)]\n    revision: String,\n\n    /// Path to the GGUF model file for inference.\n    #[clap(long, env)]\n    model_gguf: Option<String>,\n\n    /// Number of threads to use for generation.\n    #[clap(long, env)]\n    n_threads: Option<usize>,\n\n    /// Number of threads to use for batch processing.\n    #[clap(long, env)]\n    n_threads_batch: Option<usize>,\n\n    /// Number of layers to store in VRAM.\n    #[clap(default_value = \"0\", long, env)]\n    n_gpu_layers: usize,\n\n    /// Split the model across multiple GPUs.\n    #[clap(default_value = \"layer\", long, env)]\n    split_mode: LlamacppSplitMode,\n\n    /// Defragment the KV cache if holes/size > threshold.\n    #[clap(default_value = \"-1.0\", long, env)]\n    defrag_threshold: f32,\n\n    /// Enable NUMA optimizations.\n    #[clap(default_value = \"disabled\", value_enum, long, env)]\n    numa: LlamacppNuma,\n\n    /// Use memory mapping for the model.\n    #[clap(long, env)]\n    disable_mmap: bool,\n\n    /// Use memory locking to prevent swapping.\n    #[clap(long, env)]\n    use_mlock: bool,\n\n    /// Enable offloading of KQV operations to the GPU.\n    #[clap(long, env)]\n    
disable_offload_kqv: bool,\n\n    /// Disable flash attention. (EXPERIMENTAL)\n    #[clap(long, env)]\n    disable_flash_attention: bool,\n\n    /// Data type used for K cache.\n    #[clap(default_value = \"f16\", value_enum, long, env)]\n    type_k: LlamacppGGMLType,\n\n    /// Data type used for V cache.\n    #[clap(default_value = \"f16\", value_enum, long, env)]\n    type_v: LlamacppGGMLType,\n\n    /// Number of tokenizer workers used for payload validation and truncation.\n    #[clap(default_value = \"2\", long, env)]\n    validation_workers: usize,\n\n    /// Maximum number of concurrent requests.\n    #[clap(long, env)]\n    max_concurrent_requests: Option<usize>,\n\n    /// Maximum number of input tokens per request.\n    #[clap(default_value = \"1024\", long, env)]\n    max_input_tokens: usize,\n\n    /// Maximum number of total tokens (input + output) per request.\n    #[clap(default_value = \"2048\", long, env)]\n    max_total_tokens: usize,\n\n    /// Maximum number of tokens in a batch.\n    #[clap(long, env)]\n    max_batch_total_tokens: Option<usize>,\n\n    /// Maximum number of tokens in a physical batch.\n    #[clap(long, env)]\n    max_physical_batch_total_tokens: Option<usize>,\n\n    /// Maximum number of requests per batch.\n    #[clap(long, env)]\n    max_batch_size: Option<usize>,\n\n    /// IP address to listen on.\n    #[clap(default_value = \"0.0.0.0\", long)]\n    hostname: String,\n\n    /// Port to listen on.\n    #[clap(default_value = \"3000\", long, short, env)]\n    port: u16,\n\n    #[clap(default_value = \"9000\", long, short, env)]\n    prometheus_port: u16,\n\n    /// Enable JSON output format.\n    #[clap(long, env)]\n    json_output: bool,\n\n    /// OTLP endpoint for telemetry data.\n    #[clap(long, env)]\n    otlp_endpoint: Option<String>,\n\n    /// Service name for OTLP telemetry.\n    #[clap(default_value = \"text-generation-inference.router\", long, env)]\n    otlp_service_name: String,\n\n    /// 
Allowed origins for CORS.\n    #[clap(long, env)]\n    cors_allow_origin: Option<Vec<String>>,\n\n    /// Path to the tokenizer configuration file.\n    #[clap(long, env)]\n    tokenizer_config_path: Option<String>,\n\n    /// Disable grammar support.\n    #[clap(long, env)]\n    disable_grammar_support: bool,\n\n    /// Maximum number of inputs per request.\n    #[clap(default_value = \"4\", long, env)]\n    max_client_batch_size: usize,\n\n    /// Level of usage statistics collection.\n    #[clap(default_value = \"on\", long, env)]\n    usage_stats: usage_stats::UsageStatsLevel,\n\n    /// Maximum payload size in bytes.\n    #[clap(default_value = \"2000000\", long, env)]\n    payload_limit: usize,\n\n    /// Maximum image fetch size in bytes.\n    #[clap(default_value = \"1073741824\", long, env)]\n    max_image_fetch_size: usize,\n}\n\n#[tokio::main]\nasync fn main() -> Result<(), RouterError> {\n    let args = Args::parse();\n\n    logging::init_logging(args.otlp_endpoint, args.otlp_service_name, args.json_output);\n\n    let n_threads = match args.n_threads {\n        Some(0) | None => num_cpus::get(),\n        Some(threads) => threads,\n    };\n    let n_threads_batch = match args.n_threads_batch {\n        Some(0) | None => n_threads,\n        Some(threads) => threads,\n    };\n    let max_batch_size = match args.max_batch_size {\n        Some(0) | None => n_threads_batch,\n        Some(threads) => threads,\n    };\n    let max_batch_total_tokens = match args.max_batch_total_tokens {\n        None => max_batch_size * args.max_total_tokens,\n        Some(size) => size,\n    };\n    let max_physical_batch_total_tokens = match args.max_physical_batch_total_tokens {\n        None => max_batch_total_tokens,\n        Some(size) => size,\n    };\n    let max_concurrent_requests = match args.max_concurrent_requests {\n        None => max_batch_size * 2,\n        Some(size) => size,\n    };\n    if args.max_input_tokens >= args.max_total_tokens {\n        return 
Err(RouterError::ArgumentValidation(\n            \"`max_input_tokens` must be < `max_total_tokens`\".to_string(),\n        ));\n    }\n    if args.max_total_tokens > max_batch_total_tokens {\n        return Err(RouterError::ArgumentValidation(\n            \"`max_total_tokens` must be <= `max_batch_total_tokens`\".to_string(),\n        ));\n    }\n    if max_batch_size * args.max_total_tokens > max_batch_total_tokens {\n        return Err(RouterError::ArgumentValidation(\n            \"`max_batch_size` * `max_total_tokens` must be <= `max_batch_total_tokens`\".to_string(),\n        ));\n    }\n\n    let api_builder = || {\n        let mut builder = ApiBuilder::new().with_progress(true);\n\n        if let Ok(cache_dir) = std::env::var(\"HUGGINGFACE_HUB_CACHE\") {\n            builder = builder.with_cache_dir(cache_dir.into());\n        }\n        if let Ok(token) = std::env::var(\"HF_TOKEN\") {\n            builder = builder.with_token(token.into());\n        }\n        if let Ok(origin) = std::env::var(\"HF_HUB_USER_AGENT_ORIGIN\") {\n            builder = builder.with_user_agent(\"origin\", origin.as_str());\n        }\n        builder\n    };\n    let api_repo = api_builder().build()?.repo(Repo::with_revision(\n        args.model_id.clone(),\n        RepoType::Model,\n        args.revision.clone(),\n    ));\n\n    let tokenizer_path = api_repo.get(\"tokenizer.json\").await?;\n    let tokenizer = Tokenizer::from_file(&tokenizer_path)?;\n\n    let model_gguf = if let Some(model_gguf) = args.model_gguf {\n        model_gguf\n    } else {\n        let model_gguf = format!(\"models/{}/model.gguf\", args.model_id);\n        let model_gguf_path = Path::new(&model_gguf);\n\n        if !model_gguf_path.exists() {\n            let tmp_gguf = \"models/tmp.gguf\";\n\n            if let Some(parent) = Path::new(model_gguf_path).parent() {\n                std::fs::create_dir_all(parent)?;\n            }\n            let cache_path = tokenizer_path.parent().unwrap();\n\n      
      for sibling in api_repo.info().await?.siblings {\n                let _ = api_repo.get(&sibling.rfilename).await?;\n            }\n            let status = Command::new(\"convert_hf_to_gguf.py\")\n                .arg(\"--outfile\")\n                .arg(tmp_gguf)\n                .arg(cache_path)\n                .spawn()?\n                .wait()\n                .await?;\n\n            if !status.success() {\n                let exit_code = status.code().unwrap_or(-1);\n                error!(\"Failed to generate GGUF, exit code: {}\", exit_code);\n                return Err(RouterError::CommandError(exit_code));\n            }\n            quantize::model(tmp_gguf, &model_gguf, QuantizeType::MostlyQ4_0, n_threads)\n                .map_err(RouterError::QuantizeError)?;\n        }\n        model_gguf\n    };\n\n    let (backend, ok, shutdown) = LlamacppBackend::new(\n        LlamacppConfig {\n            model_gguf,\n            n_threads,\n            n_threads_batch,\n            n_gpu_layers: args.n_gpu_layers,\n            split_mode: args.split_mode,\n            defrag_threshold: args.defrag_threshold,\n            numa: args.numa,\n            use_mmap: !args.disable_mmap,\n            use_mlock: args.use_mlock,\n            flash_attention: !args.disable_flash_attention,\n            type_k: args.type_k,\n            type_v: args.type_v,\n            offload_kqv: !args.disable_offload_kqv,\n            max_batch_total_tokens,\n            max_physical_batch_total_tokens,\n            max_batch_size,\n            batch_timeout: tokio::time::Duration::from_millis(5),\n        },\n        tokenizer,\n    );\n    ok.await??;\n\n    if cfg!(debug_assertions) {\n        warn!(\"Graceful shutdown disabled!\");\n        let _ = tokio::task::spawn(async move {\n            let _ = tokio::signal::ctrl_c().await;\n            let _ = shutdown.send(true);\n        });\n    }\n\n    server::run(\n        backend,\n        max_concurrent_requests,\n        0, // 
max_best_of\n        0, // max_stop_sequences\n        0, // max_top_n_tokens\n        args.max_input_tokens,\n        args.max_total_tokens,\n        args.validation_workers,\n        None,          // api_key\n        args.model_id, // tokenizer_name\n        args.tokenizer_config_path,\n        Some(args.revision),\n        false, // trust_remote_code\n        args.hostname,\n        args.port,\n        args.cors_allow_origin,\n        false, // ngrok,\n        None,  // ngrok_authtoken,\n        None,  // ngrok_edge,\n        args.disable_grammar_support,\n        args.max_client_batch_size,\n        args.usage_stats,\n        args.payload_limit,\n        args.max_image_fetch_size,\n        args.prometheus_port,\n    )\n    .await?;\n    Ok(())\n}\n\n#[derive(Debug, Error)]\nenum RouterError {\n    #[error(\"Argument validation error: {0}\")]\n    ArgumentValidation(String),\n    #[error(\"Tokenizer error: {0}\")]\n    Tokenizer(#[from] tokenizers::Error),\n    #[error(\"Backend error: {0}\")]\n    Backend(#[from] BackendError),\n    #[error(\"WebServer error: {0}\")]\n    WebServer(#[from] server::WebServerError),\n    #[error(\"Recv error: {0}\")]\n    RecvError(#[from] RecvError),\n    #[error(\"Io error: {0}\")]\n    IoError(#[from] std::io::Error),\n    #[error(\"Var error: {0}\")]\n    VarError(#[from] std::env::VarError),\n    #[error(\"Quantize error: {0}\")]\n    QuantizeError(String),\n    #[error(\"Command error: {0}\")]\n    CommandError(i32),\n    #[error(\"HF hub error: {0}\")]\n    HubError(#[from] hf_hub::api::tokio::ApiError),\n}\n"
  },
  {
    "path": "backends/llamacpp/src/quantize.rs",
    "content": "use crate::llamacpp;\n\nuse std::ffi::CString;\n\n#[repr(u32)]\n#[derive(Debug, Clone, Copy)]\npub enum QuantizeType {\n    MostlyQ4_0 = 2,\n}\n\npub fn model(\n    input_path: &str,\n    output_path: &str,\n    ftype: QuantizeType,\n    n_threads: usize,\n) -> Result<(), String> {\n    let c_input_path =\n        CString::new(input_path).map_err(|e| format!(\"Failed to convert input path: {}\", e))?;\n\n    let c_output_path =\n        CString::new(output_path).map_err(|e| format!(\"Failed to convert output path: {}\", e))?;\n\n    let result = unsafe {\n        let mut params = llamacpp::model_quantize_default_params();\n        params.nthread = n_threads as _;\n        params.ftype = ftype as _;\n        params.quantize_output_tensor = true;\n        llamacpp::model_quantize(c_input_path.as_ptr(), c_output_path.as_ptr(), &params)\n    };\n    if result == 0 {\n        Ok(())\n    } else {\n        Err(format!(\"Quantization failed, error code: {}\", result))\n    }\n}\n"
  },
  {
    "path": "backends/neuron/Cargo.toml",
    "content": "[workspace]\nmembers = [\n  \"backends/v2\",\n  \"backends/grpc-metadata\",\n  \"launcher\",\n  \"router\"\n]\ndefault-members = [\n  \"backends/v2\",\n  \"backends/grpc-metadata\",\n  \"launcher\",\n  \"router\"\n]\nresolver = \"2\"\n\n[workspace.package]\nversion = \"3.0.0\"\nedition = \"2021\"\nauthors = [\"Olivier Dehaene\"]\nhomepage = \"https://github.com/huggingface/text-generation-inference\"\n\n[workspace.dependencies]\nbase64 = \"0.22.0\"\ntokenizers = { version = \"0.20.0\", features = [\"http\"] }\nhf-hub = { version = \"0.4.2\", features = [\"tokio\"] }\nmetrics = { version = \"0.23.0\" }\nmetrics-exporter-prometheus = { version = \"0.15.1\", features = [] }\nminijinja = { version = \"2.2.0\", features = [\"json\"] }\nminijinja-contrib = { version = \"2.0.2\", features = [\"pycompat\"] }\npyo3 = { version = \"0.22.2\", features = [\"auto-initialize\"] }\n\n[profile.release]\nincremental = true\n\n[profile.release-binary]\ninherits = \"release\"\ndebug = 1\nincremental = true\npanic = \"abort\"\n\n[profile.release-opt]\ninherits = \"release\"\ndebug = 0\nincremental = false\nlto = \"fat\"\nopt-level = 3\ncodegen-units = 1\n"
  },
  {
    "path": "backends/neuron/Makefile",
    "content": "#  Copyright 2025 The HuggingFace Team. All rights reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\n\nmkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))\nmkfile_dir := $(dir $(mkfile_path))\nroot_dir := \"${mkfile_dir}/../..\"\n\n.PHONY:\timage install_server test_server test_integration\n\nVERSION := $(shell gawk 'match($$0, /^version = \"(.*)\"/, a) {print a[1]}' ${root_dir}/Cargo.toml)\n\nimage:\n\tdocker build --rm -f ${root_dir}/Dockerfile.neuron \\\n\t\t\t\t --ulimit nofile=100000:100000 \\\n\t\t\t\t --build-arg VERSION=$(VERSION) \\\n\t\t\t\t -t text-generation-inference:$(VERSION)-neuron ${root_dir}\n\tdocker tag text-generation-inference:$(VERSION)-neuron text-generation-inference:latest-neuron\n\ninstall_server:\n\tmake -C ${mkfile_dir}/server install VERSION:=${VERSION}\n\ntest_server: install_server\n\tpython -m pip install -r ${mkfile_dir}/tests/requirements.txt\n\tpython -m pytest -sv ${mkfile_dir}/tests/server\n"
  },
  {
    "path": "backends/neuron/README.md",
    "content": "# Text-generation-inference - Neuron backend for AWS Trainium and Inferentia2\n\n## Description\n\nThis is the TGI backend for the AWS Neuron Trainium and Inferentia family of chips.\n\nThis backend is composed of:\n- the AWS Neuron SDK,\n- the legacy v2 TGI launcher and router,\n- a Neuron-specific inference server for text generation.\n\n## Usage\n\nPlease refer to the official [documentation](https://huggingface.co/docs/text-generation-inference/backends/neuron).\n\n## Build your own image\n\nThe simplest way to build TGI with the Neuron backend is to use the provided `Makefile`:\n\n```shell\n$ make -C backends/neuron image\n```\n\nAlternatively, you can build the image directly from the top directory using a command similar to the one defined\nin the `Makefile` under the `image` target.\n"
  },
  {
    "path": "backends/neuron/server/.gitignore",
    "content": "build\n"
  },
  {
    "path": "backends/neuron/server/Makefile",
    "content": "# Initialize base variables\nSHELL := /bin/bash\npkg_name := text_generation_server\nBUILDDIR ?= $(CURDIR)/build\nVERSION ?= 0.0.1\nmkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))\nmkfile_dir := $(dir $(mkfile_path))\npkg_dir := $(BUILDDIR)/$(pkg_name)\npy_version := $(subst -,.,${VERSION})\npkg_dist := ${BUILDDIR}/dist/${pkg_name}-$(py_version).tar.gz\n\nclean:\n\trm -rf $(BUILDDIR)/*\n\n${BUILDDIR}:\n\tinstall -d $@\n\n# List static sources to be deployed in the package\nsrc_dir := $(mkfile_dir)/$(pkg_name)\nsources := $(wildcard $(src_dir)/*.py)\ndeployed_sources := $(subst $(src_dir), $(pkg_dir), $(sources))\n\n# Static files are just copied\n\ndefine COPY\n\tcp -f $< $@\nendef\n\n# We use a PHONY target to represent the VERSION\n.PHONY: VERSION\n\nVERSION: ${BUILDDIR}\n\t# The trick is to compare the value of the variable with the content of a file in the build directory\n\t@if [[ `cat ${BUILDDIR}/VERSION 2>&1` != '$(VERSION)' ]]; then echo -n $(VERSION) >${BUILDDIR}/VERSION; fi\n\n# Depending on the PHONY VERSION target makes sure the pyproject.toml is regenerated if the version changes\n$(BUILDDIR)/pyproject.toml: $(mkfile_dir)/pyproject.toml VERSION\n\tmkdir -p $(BUILDDIR)\n\t$(COPY)\n\tsed -i -e 's/version = \"VERSION\"/version = \\\"${VERSION}\\\"/' $@\n\n$(pkg_dir)/%.py: $(src_dir)/%.py\n\tmkdir -p $(pkg_dir)\n\t$(COPY)\n\n# Generated files are produced by grpcio tools\n\n# If not provided, get local proto files\nifndef PROTODIR\nPROTODIR := $(mkfile_dir)/../../../proto\nendif\n\n# Three python files are generated for each protobuf\nprotobufs := $(PROTODIR)/generate.proto\npkg_pb_dir := $(pkg_dir)/pb\ngenerated_sources_base := $(foreach proto, $(protobufs), $(proto:.proto=_pb2.py))\ngenerated_sources := $(subst $(PROTODIR), $(pkg_pb_dir), $(generated_sources_base))\ngenerated_sources += $(subst $(PROTODIR), $(pkg_pb_dir), $(generated_sources_base:.py=.pyi))\ngenerated_sources += $(subst $(PROTODIR), $(pkg_pb_dir), 
$(generated_sources_base:.py=_grpc.py))\n\n$(pkg_pb_dir)/%_pb2.py $(pkg_pb_dir)/%_pb2.pyi $(pkg_pb_dir)/%_pb2_grpc.py: $(PROTODIR)/%.proto\n\tmkdir -p $(pkg_pb_dir)\n\tpython -m grpc_tools.protoc -I$(PROTODIR) --python_out=$(pkg_pb_dir) \\\n\t\t--grpc_python_out=$(pkg_pb_dir) --mypy_out=$(pkg_pb_dir) $^\n\tsed -i -e 's/^\\(import.*pb2\\)/from . \\1/g' $(pkg_pb_dir)/$*_pb2_grpc.py\n\n${pkg_dist}: $(BUILDDIR)/pyproject.toml $(deployed_sources) $(generated_sources)\n\tpython -m build $(BUILDDIR)\n\npackage: ${pkg_dist}\n\ninstall: ${pkg_dist}\n\tpython3 -m pip uninstall -y ${pkg_name}\n\tpython3 -m pip install ${pkg_dist}\n"
  },
  {
    "path": "backends/neuron/server/build-requirements.txt",
    "content": "build\ngrpcio-tools==1.53.0\nmypy-protobuf\n"
  },
  {
    "path": "backends/neuron/server/pyproject.toml",
    "content": "[build-system]\nrequires = [\"setuptools>=78.1\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"text-generation-server\"\nversion = \"VERSION\"\nauthors = [{name=\"David Corvoysier\", email=\"david@huggingface.co\" }]\ndescription = \"TGI compatible inference server for AWS Neuronx platforms\"\ndependencies = [\n    'protobuf > 3.20.1, < 4',\n    'grpcio == 1.57.0',\n    'grpcio-status == 1.48.2',\n    'grpcio-reflection == 1.48.2',\n    'grpc-interceptor == 0.15.2',\n    'typer == 0.6.1',\n    'safetensors',\n    'loguru == 0.6.0',\n    'optimum-neuron[neuronx] >= 0.0.28',\n]\n\n[tool.setuptools]\npackages = [\"text_generation_server\", \"text_generation_server.pb\"]\n\n[project.scripts]\ntext-generation-server = 'text_generation_server.cli:app'\n"
  },
  {
    "path": "backends/neuron/server/text_generation_server/cli.py",
    "content": "import sys\nfrom typing import Optional\n\nimport typer\nfrom loguru import logger\n\n\napp = typer.Typer()\n\n\n@app.command()\ndef serve(\n    model_id: str,\n    revision: Optional[str] = None,\n    sharded: bool = False,\n    trust_remote_code: bool = None,\n    uds_path: str = \"/tmp/text-generation-server\",\n    logger_level: str = \"INFO\",\n    json_output: bool = False,\n    otlp_endpoint: Optional[str] = None,\n    otlp_service_name: str = \"text-generation-inference.server\",\n    max_input_tokens: Optional[int] = None,\n):\n    \"\"\"This is the main entry-point for the server CLI.\n\n    Args:\n        model_id (`str`):\n            The *model_id* of a model on the HuggingFace hub or the path to a local model.\n        revision (`Optional[str]`, defaults to `None`):\n            The revision of the model on the HuggingFace hub.\n        sharded (`bool`):\n            Whether the model must be sharded or not. Kept for compatibility with the\n            text-generation-launcher, but must be set to False.\n        trust-remote-code (`bool`):\n            Kept for compatibility with text-generation-launcher. Ignored.\n        uds_path (`Union[Path, str]`):\n            The local path on which the server will expose its google RPC services.\n        logger_level (`str`):\n            The server logger level. 
Defaults to *INFO*.\n        json_output (`bool`):\n            Use JSON format for log serialization.\n        otlp_endpoint (`Optional[str]`, defaults to `None`):\n            The Open Telemetry endpoint to use.\n        otlp_service_name (`Optional[str]`, defaults to `None`):\n            The name to use when pushing data to the Open Telemetry endpoint.\n        max_input_tokens (`Optional[int]`, defaults to `None`):\n            The maximum number of input tokens each request should contain.\n    \"\"\"\n    if sharded:\n        raise ValueError(\"Sharding is not supported.\")\n    # Remove default handler\n    logger.remove()\n    logger.add(\n        sys.stdout,\n        format=\"{message}\",\n        filter=\"text_generation_server\",\n        level=logger_level,\n        serialize=json_output,\n        backtrace=True,\n        diagnose=False,\n    )\n\n    if trust_remote_code is not None:\n        logger.warning(\n            \"'trust_remote_code' argument is not supported and will be ignored.\"\n        )\n\n    # Import here after the logger is added to log potential import exceptions\n    from .server import serve\n\n    serve(model_id, revision, uds_path)\n\n\n@app.command()\ndef download_weights(\n    model_id: str,\n    revision: Optional[str] = None,\n    logger_level: str = \"INFO\",\n    json_output: bool = False,\n    auto_convert: Optional[bool] = None,\n    extension: Optional[str] = None,\n    trust_remote_code: Optional[bool] = None,\n    merge_lora: Optional[bool] = None,\n):\n    \"\"\"Download the model weights.\n\n    This command will be called by text-generation-launcher before serving the model.\n    \"\"\"\n    # Remove default handler\n    logger.remove()\n    logger.add(\n        sys.stdout,\n        format=\"{message}\",\n        filter=\"text_generation_server\",\n        level=logger_level,\n        serialize=json_output,\n        backtrace=True,\n        diagnose=False,\n    )\n\n    if extension is not None:\n        
logger.warning(\"'extension' argument is not supported and will be ignored.\")\n    if trust_remote_code is not None:\n        logger.warning(\n            \"'trust_remote_code' argument is not supported and will be ignored.\"\n        )\n    if auto_convert is not None:\n        logger.warning(\"'auto_convert' argument is not supported and will be ignored.\")\n    if merge_lora is not None:\n        logger.warning(\"'merge_lora' argument is not supported and will be ignored.\")\n\n    # Import here after the logger is added to log potential import exceptions\n    from .model import fetch_model\n\n    fetch_model(model_id, revision)\n"
  },
  {
    "path": "backends/neuron/server/text_generation_server/generator.py",
    "content": "import copy\nimport logging\nimport time\nfrom abc import ABC\nfrom enum import Enum\nfrom typing import List, Optional, Tuple\n\nimport torch\nfrom loguru import logger\nfrom transformers import AutoTokenizer, PreTrainedTokenizerBase\nfrom optimum.neuron.configuration_utils import NeuronConfig\nfrom transformers.generation import GenerationConfig\n\nfrom optimum.neuron import NeuronModelForCausalLM\nfrom optimum.neuron.generation import TokenSelector\n\nfrom .model import get_export_kwargs_from_env\nfrom .pb.generate_pb2 import (\n    Batch,\n    CachedBatch,\n    FinishReason,\n    GeneratedText,\n    Generation,\n    InfoResponse,\n    Request,\n    Tokens,\n)\n\n\n# Disable optimum-neuron warnings as it seems to block the server after a while\noptimum_logger = logging.getLogger(\"optimum.neuron\")\noptimum_logger.setLevel(\"CRITICAL\")\n\n\nclass Generator(ABC):\n    \"\"\"An abstract class to represent the workhorse behind TextGenerationService.\n\n    Ideally, it should not rely on protobuf constructs, but in a first step it does.\n    Implementations would typically need a model and a tokenizer to implement the Generator methods.\n    \"\"\"\n\n    @property\n    def info(self) -> InfoResponse:\n        \"\"\"This should simply return the expected InfoResponse\"\"\"\n        raise NotImplementedError\n\n    def warmup(self, batch: Batch) -> int:\n        \"\"\"Verify if the hardware can support the target load.\n\n        Args:\n            batch (`Batch`):\n                A batch corresponding to the maximum number of concurrent requests.\n\n        Return:\n            The maximum number of tokens the model supports.\n        \"\"\"\n        raise NotImplementedError\n\n    def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:\n        \"\"\"Prefill is called whenever new requests need to be added.\n\n        When this method returns successfully, a decode method will follow\n        with both the current and newly 
prefilled batch(es).\n\n        Args:\n            batch (`Batch`):\n                A batch containing the new requests.\n\n        Return:\n            A list of `Generation` for each request and a `CachedBatch` containing all pending requests.\n        \"\"\"\n        raise NotImplementedError\n\n    def decode(self, batches: List[Batch]) -> Tuple[List[Generation], CachedBatch]:\n        \"\"\"Decode after a prefill or another decode.\"\"\"\n        raise NotImplementedError\n\n    def filter(self, batch_id: int, request_ids: List[int]) -> CachedBatch:\n        \"\"\"Remove requests that are not listed from the specified batch\"\"\"\n        raise NotImplementedError\n\n    def clear(self):\n        \"\"\"Remove all requests from the generator\"\"\"\n        raise NotImplementedError\n\n    @classmethod\n    def from_pretrained(cls, model_id: str, revision: Optional[str]):\n        \"\"\"Factory method \"a la transformers\" \"\"\"\n        raise NotImplementedError\n\n\nclass Slot:\n    \"\"\"Represents a slot in a static batch\"\"\"\n\n    class State(Enum):\n        EMPTY = 0\n        PAUSE = 1\n        READY = 2\n\n    def __init__(self, id: int, tokenizer: PreTrainedTokenizerBase):\n        self._id = id\n        self._tokenizer = tokenizer\n        self.clear()\n\n    def clear(self):\n        \"\"\"Clear the slot and mark it as available.\"\"\"\n        self._state = Slot.State.EMPTY\n        self._batch_id = None\n        self._request_id = None\n        self._inputs = \"\"\n        self._truncate = 0\n        self._generation_config = None\n        self._tokens = []\n        self._mask = torch.tensor([])\n        self._selector = None\n        self._generated_tokens = 0\n        self._next_text_token_start = 0\n        self._next_text_token_end = 0\n        self._generated_text = \"\"\n        self._next_text = \"\"\n\n    @property\n    def id(self) -> int:\n        return self._id\n\n    @property\n    def state(self) -> \"Slot.State\":\n        return 
self._state\n\n    @property\n    def batch_id(self) -> int:\n        return self._batch_id\n\n    @property\n    def request_id(self) -> int:\n        return self._request_id\n\n    @property\n    def cached_text(self) -> str:\n        return self._inputs + self._generated_text\n\n    @property\n    def generation_config(self) -> GenerationConfig:\n        return self._generation_config\n\n    @property\n    def generated_tokens(self) -> int:\n        return self._generated_tokens\n\n    def assign(\n        self, batch_id: int, request: Request, generation_config: GenerationConfig\n    ):\n        \"\"\"Assign a request to a slot.\n\n        Args:\n            request (`Request`):\n                The request to be assigned. Contains the inputs and tokens selection parameters.\n            generation_config (`transformers.GenerationConfig`):\n                The base generation config (might be modified by the request generation parameters).\n        \"\"\"\n        self._state = Slot.State.READY\n        self._batch_id = batch_id\n        self._request_id = request.id\n        self._inputs = request.inputs\n        if request.truncate:\n            self._truncate = request.truncate\n        self._generation_config = copy.deepcopy(generation_config)\n        # Update generation config with request parameters\n        self._generation_config.do_sample = request.parameters.do_sample\n        if self._generation_config.do_sample:\n            if request.parameters.temperature != 0:\n                self._generation_config.temperature = request.parameters.temperature\n            if request.parameters.top_k != 0:\n                self._generation_config.top_k = request.parameters.top_k\n            if request.parameters.top_p != 0:\n                self._generation_config.top_p = request.parameters.top_p\n            if request.parameters.typical_p != 0:\n                self._generation_config.typical_p = request.parameters.typical_p\n        else:\n            # 
Set the sampling parameters to emulate greedy decoding when using on-device sampling\n            self._generation_config.temperature = 1.0\n            self._generation_config.top_k = 1\n            self._generation_config.top_p = 1.0\n            self._generation_config.typical_p = 1.0\n        if request.parameters.repetition_penalty != 0:\n            self._generation_config.repetition_penalty = (\n                request.parameters.repetition_penalty\n            )\n        self.seed = request.parameters.seed\n        self._generation_config.max_new_tokens = (\n            request.stopping_parameters.max_new_tokens\n        )\n        self._max_new_tokens = self._generation_config.max_new_tokens\n        stop_strings = request.stopping_parameters.stop_sequences\n        if stop_strings:\n            self._generation_config.stop_strings = stop_strings\n\n    def reset(\n        self,\n        input_ids: torch.LongTensor,\n        attention_mask: torch.LongTensor,\n        selector: TokenSelector,\n    ):\n        \"\"\"Reset the slot for the next generation.\n\n        Args:\n            input_ids: (`torch.LongTensor`):\n                The new input_ids to use to generate the next token.\n            attention_mask: (`torch.LongTensor`):\n                The new attention_mask to use to generate the next token.\n            selector: (`optimum.neuron.generation.TokenSelector`):\n                An object implementing the updated token selection logic.\n        \"\"\"\n        self._tokens = input_ids.clone()\n        self._next_text_token_start = 0\n        self._next_text_token_end = torch.numel(self._tokens)\n        self._next_text = \"\"\n        self._mask = attention_mask.clone()\n        self._selector = selector\n\n    def pause(self):\n        \"\"\"Mark the current slot as paused for generation.\n\n        Note that the KV cache for this slot will still be filled.\n        \"\"\"\n        self._state = Slot.State.PAUSE\n\n    def resume(self):\n      
  \"\"\"Mark the slot as ready for generation.\"\"\"\n        self._state = Slot.State.READY\n\n    def _decode_next_tokens(\n        self,\n    ) -> str:\n        \"\"\"Hack to hopefully support generate_stream for the maximum number of tokenizers\"\"\"\n        # We need to include the tokens that produced the last text to defeat cleanup algorithms in the decode\n        # which decide to add a space or not depending on the surrounding ids.\n        new_text = self._tokenizer.decode(\n            self._tokens[self._next_text_token_start :], skip_special_tokens=False\n        )\n        if new_text.endswith(\"�\"):\n            # utf-8 char at the end means it's a potential unfinished byte sequence\n            # from byte fallback tokenization.\n            return \"\"\n\n        # Compare the generated text with the one using only the tokens producing the last one\n        last_text = self._tokenizer.decode(\n            self._tokens[self._next_text_token_start : self._next_text_token_end],\n            skip_special_tokens=False,\n        )\n        if len(new_text) == len(last_text):\n            # Nothing new was actually generated\n            return \"\"\n        # Return the decoded text and store its token offsets\n        self._next_text_token_start = self._next_text_token_end\n        self._next_text_token_end = torch.numel(self._tokens)\n        return new_text[len(last_text) :]\n\n    def append(self, next_token: int) -> str:\n        \"\"\"Append a new generated token to this slot\n\n        The new token is added to the list of generated tokens, which impacts\n        directly the generated_text and stopped property.\n\n        The new token is however not added immediately to the slot inputs: it will\n        be added later on when it has effectively been used to produce the next token.\n\n        Args:\n            next_token (`int`):\n                The newly generated token.\n\n        Return:\n            The corresponding decoded text (if 
any).\n        \"\"\"\n        self._tokens = torch.cat([self._tokens, torch.LongTensor([next_token])])\n        self._mask = torch.cat([self._mask, torch.LongTensor([1])])\n        self._generated_tokens += 1\n        next_text = self._decode_next_tokens()\n        # Now that a new token has been generated, we can append the previous one to the generated text\n        self._generated_text += self._next_text\n        self._next_text = next_text\n        return next_text\n\n    def select(\n        self, input_ids: torch.LongTensor, logits: torch.Tensor\n    ) -> torch.LongTensor:\n        \"\"\"Select the next token from the candidate logits.\n\n        Args:\n            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n                The sequence used as a prompt for the generation (not used in all generation modes).\n            logits (`torch.Tensor` of shape `(batch_size, sequence_length)`):\n                The logits corresponding to the generated tokens.\n\n        Return:\n            `torch.LongTensor`: A scalar torch.LongTensor` containing the selected token.\n        \"\"\"\n        return self._selector.select(input_ids, logits)[0]\n\n    @property\n    def stopped(self) -> bool:\n        # Transformers stopping criteria expects a batch of input ids\n        input_ids = torch.unsqueeze(self._tokens, dim=0)\n        return self._selector.stopping_criteria(input_ids, None)\n\n    @property\n    def generated_text(self) -> str:\n        return self._generated_text + self._next_text\n\n    @property\n    def next_token(self) -> int:\n        return None if len(self._tokens) == 0 else self._tokens[-1]\n\n    @property\n    def attention_mask(self) -> torch.LongTensor:\n        return self._mask\n\n    @property\n    def max_token(self) -> int:\n        return self._generation_config.max_length\n\n    @property\n    def max_new_tokens(self) -> int:\n        # The current value of max_new_tokens: might be different of the target 
max_new_tokens\n        # if the slot has been paused and resumed.\n        return self._generation_config.max_new_tokens\n\n    @property\n    def truncate(self) -> int:\n        return self._truncate\n\n\nclass NeuronGenerator(Generator):\n    \"\"\"A Generator for Neuron models.\"\"\"\n\n    def __init__(\n        self,\n        model: NeuronModelForCausalLM,\n        tokenizer: PreTrainedTokenizerBase,\n    ):\n        self.model = model\n        if not isinstance(self.model, NeuronModelForCausalLM):\n            raise ValueError(\"The model must be a NeuronModelForCausalLM.\")\n        if (\n            model.neuron_config.batch_size > 1\n            and not model.neuron_config.continuous_batching\n        ):\n            raise ValueError(\n                \"The neuron model must be compiled with continuous_batching=True.\"\n            )\n        # Specify padding and truncation options for decoder-only architecture\n        tokenizer.pad_token_id = tokenizer.eos_token_id\n        tokenizer.padding_side = \"left\"\n        tokenizer.truncation_side = \"left\"\n        self.tokenizer = tokenizer\n        self.special_tokens = self.tokenizer.all_special_ids\n        self.slots = [\n            Slot(i, tokenizer) for i in range(self.model.neuron_config.batch_size)\n        ]\n        self.batch_id = 0\n\n    @property\n    def on_device_sampling(self) -> bool:\n        return getattr(self.model.neuron_config, \"on_device_sampling\", False)\n\n    @property\n    def info(self) -> InfoResponse:\n        \"\"\"Returns the expected InfoResponse.\"\"\"\n        dtype = getattr(self.model.config, \"torch_dtype\", \"float32\")\n        return InfoResponse(\n            requires_padding=True,\n            dtype=str(dtype),\n            device_type=\"xla\",\n        )\n\n    def warmup(self, batch: Batch) -> int:\n        \"\"\"Verify if the hardware can support the target load.\n\n        Args:\n            batch (`Batch`):\n                A batch corresponding to the 
maximum number of concurrent requests.\n\n        Return:\n            The maximum number of tokens the model supports.\n        \"\"\"\n        # Just check that the warmup request parameters match the model capacity\n        batch_size = self.model.neuron_config.batch_size\n        if len(batch.requests) > batch_size:\n            raise ValueError(\n                f\"Inconsistent batch_size configuration: Please make sure the batch_size in the compiled model (currently {batch_size}) matches the batch_size passed to TGI.  The compiled model.neuron_config.batch_size is usually in the neuron section of the model config.json file. You may also have passed it into optimum-cli during the compilation process.  The batch size for TGI is usually set in the environment as MAX_BATCH_SIZE.\"\n            )\n        self.prefill(batch)\n        self.clear()\n        return (\n            self.model.neuron_config.batch_size\n            * self.model.neuron_config.sequence_length\n        )\n\n    def max_prefill_length(self) -> int:\n        if hasattr(self.model.neuron_config, \"max_context_length\"):\n            return self.model.neuron_config.max_context_length\n        return self.model.neuron_config.sequence_length\n\n    def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:\n        \"\"\"Prefill new requests.\n\n        Args:\n            batch (`Batch`):\n                A batch containing the new requests.\n\n        Return:\n            A list of `Generation` for each request and a `CachedBatch` containing all pending requests.\n        \"\"\"\n        slots = {state: [] for state in Slot.State}\n        for slot in self.slots:\n            slots[slot.state].append(slot)\n        active_slots = slots[Slot.State.READY]\n        empty_slots = slots[Slot.State.EMPTY]\n        if len(empty_slots) < len(batch.requests):\n            raise ValueError(\n                f\"Cannot prefill {len(batch.requests)} new request(s) with only {len(empty_slots)} 
empty slots.\"\n                f\" Please align max_batch_size with the static batch size: {self.model.neuron_config.batch_size}.\"\n            )\n        # Assign each request to an empty slot\n        logger.debug(\n            f\"Prefilling {len(batch.requests)} new request(s) with {len(empty_slots)} empty slot(s)\"\n        )\n        new_slots = []\n        for request in batch.requests:\n            slot = empty_slots.pop()\n            slot.assign(self.batch_id, request, self.model.generation_config)\n            new_slots.append(slot)\n            logger.debug(\n                f\"Request {slot.request_id} assigned to slot {slot.id} with and max_new_tokens {slot.max_new_tokens}\"\n            )\n        prefill_slots = new_slots\n        seq_ids = torch.tensor([slot.id for slot in prefill_slots])\n        # Reconstruct the full inputs (without padding) as seen by the model.\n        # This comprises:\n        # - the inputs for new requests,\n        # - only when rebuilding the cache, the inputs and the generated text that has already\n        # been cached (i.e. 
excluding the last generated token) for unfinished requests.\n        inputs = []\n        max_length = 0\n        for slot in prefill_slots:\n            inputs.append(slot.cached_text)\n            # Apply truncation, making sure we fit into static dimensions\n            if slot.truncate == 0:\n                max_length = self.max_prefill_length()\n            elif (\n                slot.truncate > max_length and slot.truncate < self.max_prefill_length()\n            ):\n                max_length = slot.truncate\n        # Tokenize with padding and truncation\n        padded_inputs = self.tokenizer(\n            inputs,\n            return_tensors=\"pt\",\n            padding=True,\n            truncation=True,\n            max_length=max_length,\n        )\n        input_ids = padded_inputs.input_ids\n        attention_mask = padded_inputs.attention_mask\n        sampling_params = (\n            torch.zeros(input_ids.shape[0], 3) if self.on_device_sampling else None\n        )\n        # Pause previously active slots during generation\n        for slot in active_slots:\n            slot.pause()\n        # Each slot must be reset with the padded inputs and masks\n        for i, slot in enumerate(prefill_slots):\n            if slot.state != slot.state.EMPTY:\n                if slot.truncate > 0 and slot.truncate < input_ids.shape[-1]:\n                    # Apply per-request truncation\n                    input_ids[i, : -slot.truncate] = self.tokenizer.pad_token_id\n                    attention_mask[i, : -slot.truncate] = 0\n                slot_input_ids = input_ids[i : i + 1, :]\n                # Padded input ids are also required to set logits processors and stopping criterias\n                selector = TokenSelector.create(\n                    slot_input_ids,\n                    slot.generation_config,\n                    self.model,\n                    self.model.neuron_config.sequence_length,\n                    tokenizer=self.tokenizer,\n     
               seed=slot.seed,\n                )\n                slot_input_ids = slot_input_ids.squeeze(dim=0).type(torch.int64)\n                slot_attention_mask = attention_mask[i]\n                slot.reset(slot_input_ids, slot_attention_mask, selector)\n                if sampling_params is not None:\n                    sampling_params[i, 0] = slot.generation_config.top_k\n                    sampling_params[i, 1] = slot.generation_config.top_p\n                    sampling_params[i, 2] = slot.generation_config.temperature\n        # Note: when rebuilding cache on prefill, the new tokens on paused slots will be ignored,\n        # as they have already been generated and sent back in the last decode.\n        model_inputs = self.model.prepare_inputs_for_prefill(\n            input_ids,\n            attention_mask=attention_mask,\n            seq_ids=seq_ids,\n            sampling_params=sampling_params,\n        )\n        tokens_or_logits = self.model(**model_inputs)[0]\n        generation, next_batch = self._generate_token(\n            prefill_slots, self.batch_id, tokens_or_logits, input_ids\n        )\n        self.batch_id += 1\n        # Reactivate previously active slots for the next decode\n        for i, slot in enumerate(active_slots):\n            slot.resume()\n        logger.debug(\"Model ready for decoding\")\n        if next_batch is not None:\n            logger.debug(\n                f\"Next batch is {next_batch.id} with requests: {next_batch.request_ids}\"\n            )\n        return generation, next_batch\n\n    def decode(\n        self, batches: List[CachedBatch]\n    ) -> Tuple[List[Generation], CachedBatch]:\n        \"\"\"Decode the specified prefilled requests.\n\n        Args:\n            batches (`List[CachedBatch]`):\n                A list of previous batches containing the prefilled requests.\n\n        Return:\n            A list of `Generation` for each request and a `CachedBatch` containing all pending requests.\n   
     \"\"\"\n        # batches contains a list composed of:\n        # - the batch id returned by the last decode,\n        # - the batch id(s) returned by the last prefill(s)\n        # Batches are always concatenated during prefill, so we can\n        # just carry on with decoding. We adopt the id of the first\n        # batch in the list as our next batch id.\n        next_batch_id = batches[0].id\n        request_ids = []\n        for batch in batches:\n            request_ids += batch.request_ids\n        cleared_request_ids = []\n        for slot in self.slots:\n            if slot.state == slot.State.READY and slot.request_id not in request_ids:\n                cleared_request_ids.append(slot.request_id)\n                slot.clear()\n        if len(cleared_request_ids) > 0:\n            logger.info(\n                f\"Clearing slot for requests {cleared_request_ids} as they are not requested.\"\n            )\n        active_slots = [slot for slot in self.slots if slot.state == slot.State.READY]\n        if len(active_slots) < len(request_ids):\n            raise ValueError(\n                \"Unable to decode tokens for non-prefilled batches (probably due to a previous failure)\"\n            )\n        decode_slots = active_slots\n        seq_ids = torch.tensor([slot.id for slot in decode_slots])\n        # Reconstruct input_ids and attention_mask from decode slots\n        n_slots = len(decode_slots)\n        input_ids = torch.full(\n            [n_slots, 1], fill_value=self.tokenizer.eos_token_id, dtype=torch.int64\n        )\n        max_length = 0\n        for slot in decode_slots:\n            max_length = max(max_length, slot.attention_mask.size(-1))\n        attention_mask = torch.zeros([n_slots, max_length], dtype=torch.int64)\n        sampling_params = torch.zeros(n_slots, 3) if self.on_device_sampling else None\n        for i, slot in enumerate(decode_slots):\n            if slot.state != Slot.State.EMPTY:\n                # input_ids are 
simply the tokens generated by the last decode or prefill requests (other tokens are cached)\n                input_ids[i, 0] = slot.next_token\n                attention_mask[i, : slot.attention_mask.size(-1)] = slot.attention_mask\n                if sampling_params is not None:\n                    sampling_params[i, 0] = slot.generation_config.top_k\n                    sampling_params[i, 1] = slot.generation_config.top_p\n                    sampling_params[i, 2] = slot.generation_config.temperature\n        model_inputs = self.model.prepare_inputs_for_decode(\n            input_ids,\n            attention_mask=attention_mask,\n            seq_ids=seq_ids,\n            sampling_params=sampling_params,\n        )\n        tokens_or_logits = self.model(**model_inputs)[0]\n        return self._generate_token(\n            decode_slots, next_batch_id, tokens_or_logits, input_ids\n        )\n\n    def _generate_token(\n        self,\n        slots: List[Slot],\n        next_batch_id: int,\n        tokens_or_logits: torch.Tensor,\n        input_ids: torch.LongTensor,\n    ) -> Tuple[List[Generation], CachedBatch]:\n        generations = []\n        active_slots = False\n        for i, slot in enumerate(slots):\n            if slot.state != Slot.State.READY:\n                continue\n            request_id = slot.request_id\n            slot_input_ids = input_ids[i : i + 1, :]\n            if self.on_device_sampling:\n                next_token = tokens_or_logits[i]\n            else:\n                next_token_logits = tokens_or_logits[i : i + 1, -1, :]\n                next_token = slot.select(slot_input_ids, next_token_logits)\n            next_token_text = slot.append(next_token)\n            generated_text = None\n            finish_reason = None\n            if next_token == self.tokenizer.eos_token_id:\n                finish_reason = FinishReason.FINISH_REASON_EOS_TOKEN\n            elif slot.stopped:\n                if slot.generated_tokens == 
slot.max_new_tokens:\n                    finish_reason = FinishReason.FINISH_REASON_LENGTH\n                else:\n                    finish_reason = FinishReason.FINISH_REASON_STOP_SEQUENCE\n            if finish_reason is not None:\n                # We must include the generated text for each finished sequence in the response\n                generated_text = GeneratedText(\n                    text=slot.generated_text,\n                    generated_tokens=slot.generated_tokens,\n                    finish_reason=finish_reason,\n                )\n                logger.debug(\n                    f\"Decode complete for request {request_id} with {slot.generated_tokens} tokens\"\n                )\n                # mark the slot as available\n                slot.clear()\n            else:\n                active_slots = True\n            generations.append(\n                Generation(\n                    request_id=request_id,\n                    prefill_tokens=None,\n                    tokens=Tokens(\n                        ids=[next_token],\n                        logprobs=[0],\n                        texts=[next_token_text],\n                        is_special=[next_token in self.special_tokens],\n                    ),\n                    generated_text=generated_text,\n                )\n            )\n        batch = None\n        if active_slots:\n            # Whatever initial batch these requests came from, we always return all pending requests in a single batch\n            request_ids = [\n                slot.request_id for slot in self.slots if slot.state == Slot.State.READY\n            ]\n            batch = self._cached_batch(next_batch_id, request_ids)\n        else:\n            logger.debug(\"No more pending requests\")\n        return generations, batch\n\n    def _cached_batch(self, batch_id: int, request_ids: List):\n        size = len(request_ids)\n        max_tokens = size * self.model.neuron_config.sequence_length\n        
return CachedBatch(\n            id=batch_id, request_ids=request_ids, size=size, max_tokens=max_tokens\n        )\n\n    def filter(self, batch_id: int, keep_request_ids: List[int]) -> CachedBatch:\n        \"\"\"Remove requests that are not listed from the specified batch\n\n        Args:\n            batch_id (`int`):\n                The id of a cached batch.\n            keep_request_ids (`List[int]`):\n                The list of requests that must be kept.\n\n        Return:\n            A `CachedBatch` containing the pending requests.\n        \"\"\"\n        keep_slot_ids = [\n            slot.id for slot in self.slots if slot.request_id in keep_request_ids\n        ]\n        self._clear(keep_slot_ids)\n        return self._cached_batch(batch_id, keep_request_ids)\n\n    def clear(self, batch_id: Optional[int] = None):\n        \"\"\"Remove a subset or all requests from the generator\"\"\"\n        keep_ids = []\n        if batch_id is not None:\n            keep_ids = [slot.id for slot in self.slots if slot.batch_id != batch_id]\n        return self._clear(keep_ids)\n\n    def _clear(self, keep_slot_ids: List):\n        for slot in self.slots:\n            if slot.state != Slot.State.EMPTY and slot.id not in keep_slot_ids:\n                logger.debug(f\"Removing slot {slot.id} with request {slot.request_id}\")\n                slot.clear()\n\n    @classmethod\n    def from_pretrained(cls, model_id: str, revision: str = None):\n        \"\"\"Instantiate a NeuronGenerator.\n\n        Args:\n            model_id (`str`):\n                A hub model id or the path to a local model. 
This path must also contain a Tokenizer.\n            revision (`Optional[str]`, defaults to `None`):\n                The revision of the model on the HuggingFace hub.\n\n        Returns:\n            A NeuronGenerator.\n        \"\"\"\n        try:\n            neuron_config = NeuronConfig.from_pretrained(model_id, revision=revision)\n        except Exception as e:\n            logger.debug(\n                \"NeuronConfig.from_pretrained failed for model %s, revision %s: %s\",\n                model_id,\n                revision,\n                e,\n            )\n            neuron_config = None\n        start = time.time()\n        if neuron_config is None:\n            export_kwargs = get_export_kwargs_from_env()\n            logger.info(f\"Exporting model to neuron with config: {export_kwargs}.\")\n            model = NeuronModelForCausalLM.from_pretrained(\n                model_id,\n                revision=revision,\n                low_cpu_mem_usage=True,\n                export=True,\n                **export_kwargs,\n            )\n        else:\n            logger.info(\n                \"Loading model on neuron devices (this can take a few minutes).\"\n            )\n            model = NeuronModelForCausalLM.from_pretrained(\n                model_id, low_cpu_mem_usage=True, revision=revision\n            )\n        end = time.time()\n        logger.info(f\"Model successfully loaded in {end - start:.2f} s.\")\n        tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)\n        return cls(model, tokenizer)\n"
  },
  {
    "path": "backends/neuron/server/text_generation_server/interceptor.py",
    "content": "from typing import Any, Callable\n\nimport grpc\nfrom google.rpc import code_pb2, status_pb2\nfrom grpc_interceptor.server import AsyncServerInterceptor\nfrom grpc_status import rpc_status\nfrom loguru import logger\n\n\nclass ExceptionInterceptor(AsyncServerInterceptor):\n    async def intercept(\n        self,\n        method: Callable,\n        request_or_iterator: Any,\n        context: grpc.ServicerContext,\n        method_name: str,\n    ) -> Any:\n        try:\n            response = method(request_or_iterator, context)\n            return await response\n        except Exception as err:\n            method_name = method_name.split(\"/\")[-1]\n            logger.exception(f\"Method {method_name} encountered an error.\")\n\n            await context.abort_with_status(\n                rpc_status.to_status(\n                    status_pb2.Status(code=code_pb2.INTERNAL, message=str(err))\n                )\n            )\n"
  },
  {
    "path": "backends/neuron/server/text_generation_server/model.py",
    "content": "import os\nimport shutil\nimport time\nfrom typing import Optional\n\nfrom huggingface_hub import snapshot_download\nfrom huggingface_hub.constants import HF_HUB_CACHE\nfrom loguru import logger\n\nfrom optimum.neuron.cache import get_hub_cached_entries\nfrom optimum.neuron.configuration_utils import NeuronConfig\n\n\nfrom .tgi_env import check_env_and_neuron_config_compatibility\n\n\ndef get_export_kwargs_from_env():\n    batch_size = os.environ.get(\"MAX_BATCH_SIZE\", None)\n    if batch_size is not None:\n        batch_size = int(batch_size)\n    sequence_length = os.environ.get(\"MAX_TOTAL_TOKENS\", None)\n    if sequence_length is not None:\n        sequence_length = int(sequence_length)\n    num_cores = os.environ.get(\"HF_NUM_CORES\", None)\n    if num_cores is not None:\n        num_cores = int(num_cores)\n    auto_cast_type = os.environ.get(\"HF_AUTO_CAST_TYPE\", None)\n    return {\n        \"batch_size\": batch_size,\n        \"sequence_length\": sequence_length,\n        \"num_cores\": num_cores,\n        \"auto_cast_type\": auto_cast_type,\n    }\n\n\ndef is_cached(model_id):\n    # Look for cached entries for the specified model\n    in_cache = False\n    entries = get_hub_cached_entries(model_id)\n    # Look for compatible entries\n    for entry in entries:\n        if check_env_and_neuron_config_compatibility(\n            entry, check_compiler_version=True\n        ):\n            in_cache = True\n            break\n    return in_cache\n\n\ndef log_cache_size():\n    path = HF_HUB_CACHE\n    if os.path.exists(path):\n        usage = shutil.disk_usage(path)\n        gb = 2**30\n        logger.info(\n            f\"Cache disk [{path}]: total = {usage.total / gb:.2f} G, free = {usage.free / gb:.2f} G\"\n        )\n    else:\n        raise ValueError(f\"The cache directory ({path}) does not exist.\")\n\n\ndef fetch_model(\n    model_id: str,\n    revision: Optional[str] = None,\n) -> str:\n    \"\"\"Fetch a neuron model.\n\n    Args:\n  
      model_id (`str`):\n            The *model_id* of a model on the HuggingFace hub or the path to a local model.\n        revision (`Optional[str]`, defaults to `None`):\n            The revision of the model on the HuggingFace hub.\n\n    Returns:\n        A string corresponding to the model_id or path.\n    \"\"\"\n    if not os.path.isdir(\"/sys/class/neuron_device/\"):\n        raise SystemError(\"No neuron cores detected on the host.\")\n    if os.path.isdir(model_id) and revision is not None:\n        logger.warning(\n            \"Revision {} ignored for local model at {}\".format(revision, model_id)\n        )\n        revision = None\n    # Download the model from the Hub (HUGGING_FACE_HUB_TOKEN must be set for a private or gated model)\n    # Note that the model may already be present in the cache.\n    try:\n        neuron_config = NeuronConfig.from_pretrained(model_id, revision=revision)\n    except Exception as e:\n        logger.debug(\n            \"NeuronConfig.from_pretrained failed for model %s, revision %s: %s\",\n            model_id,\n            revision,\n            e,\n        )\n        neuron_config = None\n    if neuron_config is not None:\n        if os.path.isdir(model_id):\n            return model_id\n        # Prefetch the neuron model from the Hub\n        logger.info(\n            f\"Fetching revision [{revision}] for neuron model {model_id} under {HF_HUB_CACHE}\"\n        )\n        log_cache_size()\n        return snapshot_download(model_id, revision=revision, ignore_patterns=\"*.bin\")\n    # Model needs to be exported: look for compatible cached entries on the hub\n    if not is_cached(model_id):\n        hub_cache_url = \"https://huggingface.co/aws-neuron/optimum-neuron-cache\"\n        neuron_export_url = \"https://huggingface.co/docs/optimum-neuron/main/en/guides/export_model#exporting-neuron-models-using-neuronx-tgi\"\n        error_msg = (\n            f\"No cached version found for {model_id} with 
{get_export_kwargs_from_env()}. \"\n            f\"You can start a discussion to request it on {hub_cache_url}. \"\n            f\"Alternatively, you can export your own neuron model as explained in {neuron_export_url}\"\n        )\n        raise ValueError(error_msg)\n    logger.warning(\n        f\"{model_id} is not a neuron model: it will be exported using cached artifacts.\"\n    )\n    if os.path.isdir(model_id):\n        return model_id\n    # Prefetch weights, tokenizer and generation config so that they are in cache\n    log_cache_size()\n    start = time.time()\n    snapshot_path = snapshot_download(\n        model_id, revision=revision, ignore_patterns=\"*.bin\"\n    )\n    end = time.time()\n    logger.info(f\"Model weights fetched in {end - start:.2f} s.\")\n    log_cache_size()\n    return snapshot_path\n"
  },
  {
    "path": "backends/neuron/server/text_generation_server/server.py",
    "content": "import asyncio\nfrom pathlib import Path\nfrom typing import List\n\nfrom grpc import aio\nfrom grpc_reflection.v1alpha import reflection\nfrom loguru import logger\n\nfrom .generator import Generator, NeuronGenerator\nfrom .interceptor import ExceptionInterceptor\nfrom .pb import generate_pb2, generate_pb2_grpc\n\n\nclass TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):\n    def __init__(self, generator: Generator, server_urls: List[str]):\n        self.generator = generator\n        self.server_urls = server_urls\n\n    async def Info(self, request, context):\n        return self.generator.info\n\n    async def Health(self, request, context):\n        return generate_pb2.HealthResponse()\n\n    async def ServiceDiscovery(self, request, context):\n        return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)\n\n    async def ClearCache(self, request, context):\n        if request.HasField(\"id\"):\n            self.generator.clear(request.id)\n        else:\n            self.generator.clear()\n        return generate_pb2.ClearCacheResponse()\n\n    async def FilterBatch(self, request, context):\n        filtered_batch = self.generator.filter(request.batch_id, request.request_ids)\n        return generate_pb2.FilterBatchResponse(batch=filtered_batch)\n\n    async def Warmup(self, request, context):\n        max_tokens = self.generator.warmup(request.batch)\n        return generate_pb2.WarmupResponse(max_supported_total_tokens=max_tokens)\n\n    async def Prefill(self, request, context):\n        generations, batch = self.generator.prefill(request.batch)\n        return generate_pb2.PrefillResponse(generations=generations, batch=batch)\n\n    async def Decode(self, request, context):\n        generations, batch = self.generator.decode(request.batches)\n        return generate_pb2.DecodeResponse(generations=generations, batch=batch)\n\n\ndef serve(\n    model_id: str,\n    revision: str,\n    uds_path: Path,\n):\n  
  async def serve_inner(model_id: str, revision: str):\n        unix_socket_template = \"unix://{}-{}\"\n        local_url = unix_socket_template.format(uds_path, 0)\n        server_urls = [local_url]\n\n        try:\n            generator = NeuronGenerator.from_pretrained(model_id, revision)\n        except Exception:\n            logger.exception(\"Error when initializing model\")\n            raise\n\n        server = aio.server(interceptors=[ExceptionInterceptor()])\n        generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(\n            TextGenerationService(generator, server_urls), server\n        )\n        SERVICE_NAMES = (\n            generate_pb2.DESCRIPTOR.services_by_name[\"TextGenerationService\"].full_name,\n            reflection.SERVICE_NAME,\n        )\n        reflection.enable_server_reflection(SERVICE_NAMES, server)\n        server.add_insecure_port(local_url)\n\n        await server.start()\n\n        logger.info(\"Server started at {}\".format(local_url))\n\n        try:\n            await server.wait_for_termination()\n        except KeyboardInterrupt:\n            logger.info(\"Signal received. Shutting down\")\n            await server.stop(0)\n\n    asyncio.run(serve_inner(model_id, revision))\n"
  },
  {
    "path": "backends/neuron/server/text_generation_server/tgi_env.py",
    "content": "#!/usr/bin/env python\n\nimport argparse\nimport logging\nimport os\nimport sys\nfrom typing import Any, Dict, List, Optional\n\nfrom optimum.neuron.modeling_decoder import get_available_cores\nfrom optimum.neuron.cache import get_hub_cached_entries\nfrom optimum.neuron.configuration_utils import NeuronConfig\nfrom optimum.neuron.utils.version_utils import get_neuronxcc_version\nfrom optimum.neuron.utils import map_torch_dtype\n\n\nlogger = logging.getLogger(__name__)\n\ntgi_router_env_vars = [\n    \"MAX_BATCH_SIZE\",\n    \"MAX_TOTAL_TOKENS\",\n    \"MAX_INPUT_TOKENS\",\n    \"MAX_BATCH_PREFILL_TOKENS\",\n]\ntgi_server_env_vars = [\"HF_NUM_CORES\", \"HF_AUTO_CAST_TYPE\"]\n\n\n# By the end of this script all env var should be specified properly\ntgi_env_vars = tgi_server_env_vars + tgi_router_env_vars\n\navailable_cores = get_available_cores()\nneuronxcc_version = get_neuronxcc_version()\n\n\ndef parse_cmdline_and_set_env(argv: List[str] = None) -> argparse.Namespace:\n    parser = argparse.ArgumentParser()\n    if not argv:\n        argv = sys.argv\n    # All these are params passed to tgi and intercepted here\n    parser.add_argument(\n        \"--max-input-tokens\",\n        type=int,\n        default=os.getenv(\"MAX_INPUT_TOKENS\", os.getenv(\"MAX_INPUT_LENGTH\", 0)),\n    )\n    parser.add_argument(\n        \"--max-total-tokens\", type=int, default=os.getenv(\"MAX_TOTAL_TOKENS\", 0)\n    )\n    parser.add_argument(\n        \"--max-batch-size\", type=int, default=os.getenv(\"MAX_BATCH_SIZE\", 0)\n    )\n    parser.add_argument(\n        \"--max-batch-prefill-tokens\",\n        type=int,\n        default=os.getenv(\"MAX_BATCH_PREFILL_TOKENS\", 0),\n    )\n    parser.add_argument(\"--model-id\", type=str, default=os.getenv(\"MODEL_ID\"))\n    parser.add_argument(\"--revision\", type=str, default=os.getenv(\"REVISION\"))\n\n    args = parser.parse_known_args(argv)[0]\n\n    if not args.model_id:\n        raise Exception(\n            \"No model 
id provided ! Either specify it using --model-id cmdline or MODEL_ID env var\"\n        )\n\n    # Override env with cmdline params\n    os.environ[\"MODEL_ID\"] = args.model_id\n\n    # Set all tgi router and tgi server values to consistent values as early as possible\n    # from the order of the parser defaults, the tgi router value can override the tgi server ones\n    if args.max_total_tokens > 0:\n        os.environ[\"MAX_TOTAL_TOKENS\"] = str(args.max_total_tokens)\n\n    if args.max_input_tokens > 0:\n        os.environ[\"MAX_INPUT_TOKENS\"] = str(args.max_input_tokens)\n\n    if args.max_batch_size > 0:\n        os.environ[\"MAX_BATCH_SIZE\"] = str(args.max_batch_size)\n\n    if args.max_batch_prefill_tokens > 0:\n        os.environ[\"MAX_BATCH_PREFILL_TOKENS\"] = str(args.max_batch_prefill_tokens)\n\n    if args.revision:\n        os.environ[\"REVISION\"] = str(args.revision)\n\n    return args\n\n\ndef neuron_config_to_env(neuron_config):\n    if isinstance(neuron_config, NeuronConfig):\n        neuron_config = neuron_config.to_dict()\n    with open(os.environ[\"ENV_FILEPATH\"], \"w\") as f:\n        f.write(\"export MAX_BATCH_SIZE={}\\n\".format(neuron_config[\"batch_size\"]))\n        f.write(\"export MAX_TOTAL_TOKENS={}\\n\".format(neuron_config[\"sequence_length\"]))\n        f.write(\"export HF_NUM_CORES={}\\n\".format(neuron_config[\"tp_degree\"]))\n        config_key = (\n            \"auto_cast_type\" if \"auto_cast_type\" in neuron_config else \"torch_dtype\"\n        )\n        auto_cast_type = neuron_config[config_key]\n        f.write(\"export HF_AUTO_CAST_TYPE={}\\n\".format(auto_cast_type))\n        max_input_tokens = os.getenv(\"MAX_INPUT_TOKENS\")\n        if not max_input_tokens:\n            max_input_tokens = int(neuron_config[\"sequence_length\"]) // 2\n            if max_input_tokens == 0:\n                raise Exception(\"Model sequence length should be greater than 1\")\n        f.write(\"export 
MAX_INPUT_TOKENS={}\\n\".format(max_input_tokens))\n        max_batch_prefill_tokens = os.getenv(\"MAX_BATCH_PREFILL_TOKENS\")\n        if not max_batch_prefill_tokens:\n            max_batch_prefill_tokens = int(neuron_config[\"batch_size\"]) * int(\n                max_input_tokens\n            )\n        f.write(\"export MAX_BATCH_PREFILL_TOKENS={}\\n\".format(max_batch_prefill_tokens))\n\n\ndef sort_neuron_configs(dictionary):\n    return -dictionary[\"tp_degree\"], -dictionary[\"batch_size\"]\n\n\ndef lookup_compatible_cached_model(\n    model_id: str, revision: Optional[str]\n) -> Optional[Dict[str, Any]]:\n    # Reuse the same mechanic as the one in use to configure the tgi server part\n    # The only difference here is that we stay as flexible as possible on the compatibility part\n    entries = get_hub_cached_entries(model_id)\n\n    logger.debug(\n        \"Found %d cached entries for model %s, revision %s\",\n        len(entries),\n        model_id,\n        revision,\n    )\n\n    all_compatible = []\n    for entry in entries:\n        if check_env_and_neuron_config_compatibility(\n            entry, check_compiler_version=True\n        ):\n            all_compatible.append(entry)\n\n    if not all_compatible:\n        logger.debug(\n            \"No compatible cached entry found for model %s, env %s, available cores %s, neuronxcc version %s\",\n            model_id,\n            get_env_dict(),\n            available_cores,\n            neuronxcc_version,\n        )\n        return None\n\n    logger.info(\"%d compatible neuron cached models found\", len(all_compatible))\n\n    all_compatible = sorted(all_compatible, key=sort_neuron_configs)\n\n    entry = all_compatible[0]\n\n    return entry\n\n\ndef check_env_and_neuron_config_compatibility(\n    neuron_config_dict: Dict[str, Any], check_compiler_version: bool\n) -> bool:\n    logger.debug(\n        \"Checking the provided neuron config %s is compatible with the local setup and provided 
environment\",\n        neuron_config_dict,\n    )\n\n    # Local setup compat checks\n    if neuron_config_dict[\"tp_degree\"] > available_cores:\n        logger.debug(\n            \"Not enough neuron cores available to run the provided neuron config\"\n        )\n        return False\n\n    if (\n        check_compiler_version\n        and neuron_config_dict[\"neuronxcc_version\"] != neuronxcc_version\n    ):\n        logger.debug(\n            \"Compiler version conflict, the local one (%s) differs from the one used to compile the model (%s)\",\n            neuronxcc_version,\n            neuron_config_dict[\"neuronxcc_version\"],\n        )\n        return False\n\n    batch_size = os.getenv(\"MAX_BATCH_SIZE\", None)\n    if batch_size is not None and neuron_config_dict[\"batch_size\"] < int(batch_size):\n        logger.debug(\n            \"The provided MAX_BATCH_SIZE (%s) is higher than the neuron config batch size (%s)\",\n            os.getenv(\"MAX_BATCH_SIZE\"),\n            neuron_config_dict[\"batch_size\"],\n        )\n        return False\n    max_total_tokens = os.getenv(\"MAX_TOTAL_TOKENS\", None)\n    if max_total_tokens is not None and neuron_config_dict[\"sequence_length\"] < int(\n        max_total_tokens\n    ):\n        logger.debug(\n            \"The provided MAX_TOTAL_TOKENS (%s) is higher than the neuron config sequence length (%s)\",\n            max_total_tokens,\n            neuron_config_dict[\"sequence_length\"],\n        )\n        return False\n    num_cores = os.getenv(\"HF_NUM_CORES\", None)\n    if num_cores is not None and neuron_config_dict[\"tp_degree\"] < int(num_cores):\n        logger.debug(\n            \"The provided HF_NUM_CORES (%s) is higher than the neuron config tp degree (%s)\",\n            num_cores,\n            neuron_config_dict[\"tp_degree\"],\n        )\n        return False\n    auto_cast_type = os.getenv(\"HF_AUTO_CAST_TYPE\", None)\n    if auto_cast_type is not None:\n        config_key = (\n            
\"auto_cast_type\"\n            if \"auto_cast_type\" in neuron_config_dict\n            else \"torch_dtype\"\n        )\n        neuron_config_value = map_torch_dtype(str(neuron_config_dict[config_key]))\n        env_value = map_torch_dtype(auto_cast_type)\n        if env_value != neuron_config_value:\n            logger.debug(\n                \"The provided auto cast type and the neuron config param differ (%s != %s)\",\n                env_value,\n                neuron_config_value,\n            )\n            return False\n    max_input_tokens = int(\n        os.getenv(\"MAX_INPUT_TOKENS\", os.getenv(\"MAX_INPUT_LENGTH\", 0))\n    )\n    if max_input_tokens > 0:\n        if \"max_context_length\" in neuron_config_dict:\n            sequence_length = neuron_config_dict[\"max_context_length\"]\n        else:\n            sequence_length = neuron_config_dict[\"sequence_length\"]\n        if max_input_tokens >= sequence_length:\n            logger.debug(\n                \"Specified max input tokens is not compatible with config sequence length ( %s >= %s)\",\n                max_input_tokens,\n                sequence_length,\n            )\n            return False\n\n    return True\n\n\ndef get_env_dict() -> Dict[str, str]:\n    d = {}\n    for k in tgi_env_vars:\n        d[k] = os.getenv(k)\n    return d\n\n\ndef get_neuron_config_for_model(\n    model_name_or_path: str, revision: Optional[str] = None\n) -> NeuronConfig:\n    try:\n        neuron_config = NeuronConfig.from_pretrained(\n            model_name_or_path, revision=revision\n        )\n    except Exception as e:\n        logger.debug(\n            \"NeuronConfig.from_pretrained failed for model %s, revision %s: %s\",\n            model_name_or_path,\n            revision,\n            e,\n        )\n        neuron_config = None\n    if neuron_config is not None:\n        compatible = check_env_and_neuron_config_compatibility(\n            neuron_config.to_dict(), 
check_compiler_version=False\n        )\n        if not compatible:\n            env_dict = get_env_dict()\n            msg = (\n                \"Invalid neuron config and env. Config {}, env {}, available cores {}, neuronxcc version {}\"\n            ).format(neuron_config, env_dict, available_cores, neuronxcc_version)\n            logger.error(msg)\n            raise Exception(msg)\n    else:\n        neuron_config = lookup_compatible_cached_model(model_name_or_path, revision)\n\n    return neuron_config\n"
  },
  {
    "path": "backends/neuron/tests/conftest.py",
    "content": "pytest_plugins = [\"fixtures.model\"]\n"
  },
  {
    "path": "backends/neuron/tests/fixtures/model.py",
    "content": "import copy\nimport logging\nimport subprocess\nimport sys\nfrom tempfile import TemporaryDirectory\n\nimport os\nimport pytest\nfrom transformers import AutoTokenizer\n\n\nfrom optimum.neuron.cache import synchronize_hub_cache\n\n\nlogging.basicConfig(\n    level=logging.INFO,\n    format=\"[%(asctime)s] %(levelname)s [%(filename)s.%(funcName)s:%(lineno)d] %(message)s\",\n    stream=sys.stdout,\n)\nlogger = logging.getLogger(__file__)\n\n\nOPTIMUM_CACHE_REPO_ID = \"optimum-internal-testing/neuron-testing-cache\"\n\n\n# All model configurations below will be added to the neuron_model_config fixture\nMODEL_CONFIGURATIONS = {\n    \"llama\": {\n        \"model_id\": \"unsloth/Llama-3.2-1B-Instruct\",\n        \"export_kwargs\": {\n            \"batch_size\": 4,\n            \"sequence_length\": 4096,\n            \"num_cores\": 2,\n            \"auto_cast_type\": \"bf16\",\n        },\n    },\n    \"qwen2\": {\n        \"model_id\": \"Qwen/Qwen2.5-0.5B\",\n        \"export_kwargs\": {\n            \"batch_size\": 4,\n            \"sequence_length\": 4096,\n            \"num_cores\": 2,\n            \"auto_cast_type\": \"bf16\",\n        },\n    },\n    \"granite\": {\n        \"model_id\": \"ibm-granite/granite-3.1-2b-instruct\",\n        \"export_kwargs\": {\n            \"batch_size\": 4,\n            \"sequence_length\": 4096,\n            \"num_cores\": 2,\n            \"auto_cast_type\": \"bf16\",\n        },\n    },\n}\n\n\ndef export_model(model_id, export_kwargs, neuron_model_path):\n    export_command = [\n        \"optimum-cli\",\n        \"export\",\n        \"neuron\",\n        \"-m\",\n        model_id,\n        \"--task\",\n        \"text-generation\",\n    ]\n    for kwarg, value in export_kwargs.items():\n        export_command.append(f\"--{kwarg}\")\n        export_command.append(str(value))\n    export_command.append(neuron_model_path)\n    logger.info(f\"Exporting {model_id} with {export_kwargs}\")\n    try:\n        
subprocess.run(export_command, check=True)\n    except subprocess.CalledProcessError as e:\n        raise ValueError(f\"Failed to export model: {e}\")\n\n\n@pytest.fixture(scope=\"session\", params=MODEL_CONFIGURATIONS.keys())\ndef neuron_model_config(request):\n    \"\"\"Expose a pre-trained neuron model\n\n    The fixture exports a model locally and returns a dictionary containing:\n    - a configuration name,\n    - the original model id,\n    - the export parameters,\n    - the neuron model local path.\n\n    For each exposed model, the local directory is maintained for the duration of the\n    test session and cleaned up afterwards.\n\n    \"\"\"\n    config_name = request.param\n    model_config = copy.deepcopy(MODEL_CONFIGURATIONS[request.param])\n    model_id = model_config[\"model_id\"]\n    export_kwargs = model_config[\"export_kwargs\"]\n    with TemporaryDirectory() as neuron_model_path:\n        export_model(model_id, export_kwargs, neuron_model_path)\n        synchronize_hub_cache(cache_repo_id=OPTIMUM_CACHE_REPO_ID)\n        tokenizer = AutoTokenizer.from_pretrained(model_id)\n        tokenizer.save_pretrained(neuron_model_path)\n        del tokenizer\n        # Add dynamic parameters to the model configuration\n        model_config[\"neuron_model_path\"] = neuron_model_path\n        # Also add model configuration name to allow tests to adapt their expectations\n        model_config[\"name\"] = config_name\n        # Yield instead of returning to keep a reference to the temporary directory.\n        # It will go out of scope and be released only once all tests needing the fixture\n        # have been completed.\n        logger.info(f\"{config_name} ready for testing ...\")\n        os.environ[\"CUSTOM_CACHE_REPO\"] = OPTIMUM_CACHE_REPO_ID\n        yield model_config\n        logger.info(f\"Done with {config_name}\")\n\n\n@pytest.fixture(scope=\"module\")\ndef neuron_model_path(neuron_model_config):\n    yield 
neuron_model_config[\"neuron_model_path\"]\n"
  },
  {
    "path": "backends/neuron/tests/prune_test_models.py",
    "content": "from argparse import ArgumentParser\nfrom huggingface_hub import HfApi\n\n\ndef main():\n    parser = ArgumentParser()\n    parser.add_argument(\"--yes\", action=\"store_true\", default=False)\n    args = parser.parse_args()\n    api = HfApi()\n    models = api.list_models(search=\"optimum-internal-testing/neuron-tgi-testing\")\n    for model in models:\n        if args.yes:\n            delete = True\n        else:\n            answer = input(f\"Do you want to delete {model.id} [y/N] ?\")\n            delete = answer == \"y\"\n        if delete:\n            api.delete_repo(model.id)\n            print(f\"Deleted {model.id}.\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "backends/neuron/tests/pytest.ini",
    "content": "[pytest]\nasyncio_mode = auto\n"
  },
  {
    "path": "backends/neuron/tests/requirements.txt",
    "content": "#  Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n#  Licensed under the Apache License, Version 2.0 (the \"License\");\n#  you may not use this file except in compliance with the License.\n#  You may obtain a copy of the License at\n#\n#      http://www.apache.org/licenses/LICENSE-2.0\n#\n#  Unless required by applicable law or agreed to in writing, software\n#  distributed under the License is distributed on an \"AS IS\" BASIS,\n#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n#  See the License for the specific language governing permissions and\n#  limitations under the License.\ntext-generation >= 0.6.0\npytest >= 7.4.0\npytest-asyncio >= 0.21.1\nrequests < 2.32.0\ndocker >= 6.1.3\nLevenshtein\n"
  },
  {
    "path": "backends/neuron/tests/server/helpers.py",
    "content": "from text_generation_server.generator import NeuronGenerator\nfrom text_generation_server.pb.generate_pb2 import (\n    Batch,\n    NextTokenChooserParameters,\n    Request,\n    StoppingCriteriaParameters,\n)\n\n\ndef create_request(\n    id: int,\n    inputs: str,\n    truncate: int = 0,\n    max_new_tokens: int = 20,\n    do_sample: bool = False,\n    top_k: int = 50,\n    top_p: float = 0.9,\n    temperature: float = 1.0,\n    seed: int = 42,\n    repetition_penalty: float = 1.0,\n):\n    parameters = NextTokenChooserParameters(\n        temperature=temperature,\n        top_k=top_k,\n        top_p=top_p,\n        do_sample=do_sample,\n        seed=seed,\n        repetition_penalty=repetition_penalty,\n    )\n    stopping_parameters = StoppingCriteriaParameters(max_new_tokens=max_new_tokens)\n    return Request(\n        id=id,\n        inputs=inputs,\n        truncate=truncate,\n        parameters=parameters,\n        stopping_parameters=stopping_parameters,\n    )\n\n\ndef check_prefill(\n    input_text,\n    expected_token_id,\n    expected_token_text,\n    do_sample,\n    batch_size,\n    model_path,\n):\n    \"\"\"Verify that a prefill for a single request generates the expected output.\"\"\"\n    generator = NeuronGenerator.from_pretrained(model_path)\n    assert generator.model.batch_size >= batch_size\n    requests = []\n    max_new_tokens = 20\n    for i in range(batch_size):\n        requests.append(\n            create_request(\n                id=0,\n                inputs=input_text,\n                do_sample=do_sample,\n                max_new_tokens=max_new_tokens,\n            )\n        )\n    # Let's be pessimistic when estimating max_tokens\n    batch_size * (len(input_text) + max_new_tokens)\n    max_length = generator.model.max_length\n    batch = Batch(\n        id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length\n    )\n    generations, next_batch = generator.prefill(batch)\n    assert 
next_batch.size == batch_size\n    # Whatever was passed as max_tokens, the server will correct it\n    # because of static batching\n    assert next_batch.max_tokens == batch_size * max_length\n    assert len(generations) == batch_size\n    for g in generations:\n        tokens = g.tokens\n        assert tokens.ids == [expected_token_id]\n        assert tokens.texts == [expected_token_text]\n\n\ndef check_decode_single(\n    input_text, max_new_tokens, generated_text, do_sample, model_path\n):\n    \"\"\"Verify that a decoding for a single request generates the expected output.\"\"\"\n    generator = NeuronGenerator.from_pretrained(model_path)\n    request = create_request(\n        id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=do_sample\n    )\n    max_length = generator.model.max_length\n    batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)\n    generations, next_batch = generator.prefill(batch)\n    # We already generated one token: call decode max_new_tokens - 1 times\n    for _ in range(max_new_tokens - 1):\n        assert next_batch.size == 1\n        assert next_batch.max_tokens == max_length\n        assert len(generations) == 1\n        assert len(generations[0].tokens.ids) == 1\n        generations, next_batch = generator.decode([next_batch])\n    assert next_batch is None\n    assert len(generations) == 1\n    output = generations[0].generated_text\n    assert output.generated_tokens == max_new_tokens\n    assert output.finish_reason == 0\n    assert output.text == generated_text\n\n\ndef check_decode_multiple(model_path):\n    \"\"\"Verify that two requests added to the batch at different generation steps\n    generate the same outputs (continuous batching).\n    \"\"\"\n    generator = NeuronGenerator.from_pretrained(model_path)\n    assert generator.model.batch_size > 1\n    input_text = \"Once upon a time\"\n    max_new_tokens = 20\n    # Prefill a single request, remembering the generated token\n    tokens = 
{0: [], 1: []}\n    request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens)\n    max_length = generator.model.max_length\n    batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)\n    generations, next_batch = generator.prefill(batch)\n    assert next_batch.size == 1\n    assert len(generations) == 1\n    g = generations[0]\n    tokens[g.request_id].append(g.tokens.ids[0])\n    assert len(tokens[0]) == 1\n    # Decode a few tokens\n    gen_tokens = 4\n    for _ in range(gen_tokens - 1):\n        generations, next_batch = generator.decode([next_batch])\n        assert len(generations) == 1\n        g = generations[0]\n        tokens[g.request_id].append(g.tokens.ids[0])\n    assert len(tokens[0]) == gen_tokens\n    assert next_batch.size == 1\n    # Add a second request\n    request = create_request(id=1, inputs=input_text, max_new_tokens=max_new_tokens)\n    batch = Batch(id=1, requests=[request], size=1, max_tokens=max_length)\n    generations, next_batch_1 = generator.prefill(batch)\n    assert next_batch_1.size == 1\n    # We should have generated only a single token\n    assert len(generations) == 1\n    g = generations[0]\n    tokens[g.request_id].append(g.tokens.ids[0])\n    assert len(tokens[0]) == gen_tokens\n    assert len(tokens[1]) == 1\n    # Decode more tokens until we reach the maximum for the first request\n    batches = [next_batch, next_batch_1]\n    for _ in range(max_new_tokens - gen_tokens):\n        generations, next_batch = generator.decode(batches)\n        for g in generations:\n            tokens[g.request_id].append(g.tokens.ids[0])\n        batches = [next_batch]\n    # Verify we now only have one pending request\n    assert next_batch.size == 1\n    assert len(tokens[0]) == max_new_tokens\n    assert len(tokens[1]) == max_new_tokens - gen_tokens + 1\n    # Verify we have the output for the first request\n    for g in generations:\n        if g.request_id == 0:\n            output = 
g.generated_text\n            assert output.text != \"\"\n            assert output.generated_tokens == max_new_tokens\n            generated_text = output.text\n    # Continue decoding until the end of the second request\n    for _ in range(gen_tokens - 1):\n        generations, next_batch = generator.decode([next_batch])\n        assert len(generations) == 1\n        g = generations[0]\n        tokens[g.request_id].append(g.tokens.ids[0])\n    assert next_batch is None\n    output = generations[0].generated_text\n    assert output.generated_tokens == max_new_tokens\n    assert tokens[0] == tokens[1]\n    assert output.text == generated_text\n"
  },
  {
    "path": "backends/neuron/tests/server/test_cached_model.py",
    "content": "import os\nimport pytest\n\nfrom text_generation_server.generator import NeuronGenerator\nfrom text_generation_server.model import fetch_model, is_cached\n\n\n@pytest.fixture(scope=\"module\")\ndef cached_model_id(neuron_model_config) -> str:\n    \"\"\"\n    Fixture to provide a cached model ID for testing.\n    This assumes the model is already cached in the local environment.\n    \"\"\"\n    export_kwargs = neuron_model_config[\"export_kwargs\"]\n    os.environ[\"MAX_BATCH_SIZE\"] = str(export_kwargs[\"batch_size\"])\n    os.environ[\"MAX_TOTAL_TOKENS\"] = str(export_kwargs[\"sequence_length\"])\n    os.environ[\"HF_AUTO_CAST_TYPE\"] = export_kwargs[\"auto_cast_type\"]\n    os.environ[\"HF_NUM_CORES\"] = str(export_kwargs[\"num_cores\"])\n    yield neuron_model_config[\"model_id\"]\n    os.environ.pop(\"MAX_BATCH_SIZE\", None)\n    os.environ.pop(\"MAX_TOTAL_TOKENS\", None)\n    os.environ.pop(\"HF_AUTO_CAST_TYPE\", None)\n    os.environ.pop(\"HF_NUM_CORES\", None)\n\n\ndef test_model_is_cached(cached_model_id):\n    assert is_cached(cached_model_id), f\"Model {cached_model_id} is not cached\"\n\n\ndef test_fetch_cached_model(cached_model_id: str):\n    model_path = fetch_model(cached_model_id)\n    assert os.path.exists(\n        model_path\n    ), f\"Model {cached_model_id} was not fetched successfully\"\n    assert os.path.isdir(model_path), f\"Model {cached_model_id} is not a directory\"\n\n\ndef test_generator_from_cached_model(cached_model_id: str):\n    generator = NeuronGenerator.from_pretrained(model_id=cached_model_id)\n    assert generator is not None, \"Generator could not be created from cached model\"\n    assert generator.model is not None, \"Generator model is not initialized\"\n    assert generator.tokenizer is not None, \"Generator tokenizer is not initialized\"\n"
  },
  {
    "path": "backends/neuron/tests/server/test_continuous_batching.py",
    "content": "from helpers import create_request\nfrom text_generation_server.generator import NeuronGenerator\nfrom text_generation_server.pb.generate_pb2 import Batch\n\n\ndef test_continuous_batching_two_requests(neuron_model_config):\n    \"\"\"Verify that two requests added to the batch at different generation steps\n    generate the same outputs (continuous batching).\n    \"\"\"\n    neuron_model_path = neuron_model_config[\"neuron_model_path\"]\n    generator = NeuronGenerator.from_pretrained(neuron_model_path)\n    assert generator.model.neuron_config.batch_size > 1\n    input_text = \"Once upon a time\"\n    max_new_tokens = 20\n    # Prefill a single request, remembering the generated token\n    tokens = {0: [], 1: []}\n    request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens)\n    max_length = generator.model.neuron_config.sequence_length\n    batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)\n    generations, next_batch = generator.prefill(batch)\n    assert next_batch.size == 1\n    assert len(generations) == 1\n    g = generations[0]\n    tokens[g.request_id].append(g.tokens.ids[0])\n    assert len(tokens[0]) == 1\n    # Decode a few tokens\n    gen_tokens = 4\n    for _ in range(gen_tokens - 1):\n        generations, next_batch = generator.decode([next_batch])\n        assert len(generations) == 1\n        g = generations[0]\n        tokens[g.request_id].append(g.tokens.ids[0])\n    assert len(tokens[0]) == gen_tokens\n    assert next_batch.size == 1\n    # Add a second request\n    request = create_request(id=1, inputs=input_text, max_new_tokens=max_new_tokens)\n    batch = Batch(id=1, requests=[request], size=1, max_tokens=max_length)\n    generations, next_batch_1 = generator.prefill(batch)\n    assert next_batch_1.size == 1\n    # We should have generated only a single token\n    assert len(generations) == 1\n    g = generations[0]\n    tokens[g.request_id].append(g.tokens.ids[0])\n    assert 
len(tokens[0]) == gen_tokens\n    assert len(tokens[1]) == 1\n    # Decode more tokens until we reach the maximum for the first request\n    batches = [next_batch, next_batch_1]\n    for _ in range(max_new_tokens - gen_tokens):\n        generations, next_batch = generator.decode(batches)\n        for g in generations:\n            tokens[g.request_id].append(g.tokens.ids[0])\n        batches = [next_batch]\n    # Verify we now only have one pending request\n    assert next_batch.size == 1\n    assert len(tokens[0]) == max_new_tokens\n    assert len(tokens[1]) == max_new_tokens - gen_tokens + 1\n    # Verify we have the output for the first request\n    for g in generations:\n        if g.request_id == 0:\n            output = g.generated_text\n            assert output.text != \"\"\n            assert output.generated_tokens == max_new_tokens\n            generated_text = output.text\n    # Continue decoding until the end of the second request\n    for _ in range(gen_tokens - 1):\n        generations, next_batch = generator.decode([next_batch])\n        assert len(generations) == 1\n        g = generations[0]\n        tokens[g.request_id].append(g.tokens.ids[0])\n    assert next_batch is None\n    output = generations[0].generated_text\n    assert output.generated_tokens == max_new_tokens\n    assert tokens[0] == tokens[1]\n    assert output.text == generated_text\n"
  },
  {
    "path": "backends/neuron/tests/server/test_decode.py",
    "content": "from helpers import create_request\nfrom text_generation_server.generator import NeuronGenerator\nfrom text_generation_server.pb.generate_pb2 import Batch\n\n\ndef test_decode(neuron_model_config):\n    \"\"\"Verify that a decoding for a single request generates the expected output.\"\"\"\n    config_name = neuron_model_config[\"name\"]\n    neuron_model_path = neuron_model_config[\"neuron_model_path\"]\n    generator = NeuronGenerator.from_pretrained(neuron_model_path)\n    for do_sample in [True, False]:\n        mode = \"sample\" if do_sample else \"greedy\"\n        print(f\"{config_name}[{mode}]\")\n        generated_text = _test_decode(config_name, generator, do_sample)\n        if not do_sample:\n            expected_text = {\n                \"llama\": \" The world was holding its breath as the world's top scientists and engineers gathered at the secret underground facility\",\n                \"qwen2\": \" I was sitting in my room, staring at the clock, when a knock at the door. 
I\",\n                \"granite\": \"\\n\\nThis opening line is from George Orwell's dystopian novel, \\\"1\",\n            }[config_name]\n            assert generated_text == expected_text\n        generator.clear()\n\n\ndef _test_decode(config_name, generator, do_sample):\n    input_text = (\n        \"It was a bright cold day in April, and the clocks were striking thirteen.\"\n    )\n    max_new_tokens = 20\n    request = create_request(\n        id=0,\n        inputs=input_text,\n        max_new_tokens=max_new_tokens,\n        do_sample=do_sample,\n        temperature=0.9,\n    )\n    max_length = generator.model.neuron_config.sequence_length\n    batch = Batch(id=0, requests=[request], size=1, max_tokens=max_length)\n    generations, next_batch = generator.prefill(batch)\n    # We already generated one token: call decode max_new_tokens - 1 times\n    for _ in range(max_new_tokens - 1):\n        assert next_batch.size == 1\n        assert next_batch.max_tokens == max_length\n        assert len(generations) == 1\n        assert len(generations[0].tokens.ids) == 1\n        generations, next_batch = generator.decode([next_batch])\n    assert next_batch is None\n    assert len(generations) == 1\n    output = generations[0].generated_text\n    assert output.generated_tokens == max_new_tokens\n    assert output.finish_reason == 0\n    return output.text\n"
  },
  {
    "path": "backends/neuron/tests/server/test_generator_slot.py",
    "content": "import pytest\nimport torch\nfrom text_generation_server.generator import Slot\nfrom text_generation_server.pb.generate_pb2 import Request\nfrom transformers import AutoTokenizer, GenerationConfig\n\n\nTOKENIZERS = [\"NousResearch/Llama-2-7b-hf\", \"gpt2\"]\n\n\n@pytest.fixture(params=TOKENIZERS)\ndef tokenizer(request):\n    t = AutoTokenizer.from_pretrained(request.param)\n    t.padding_side = \"left\"\n    t.pad_token_id = t.eos_token_id\n    return t\n\n\n@pytest.mark.parametrize(\n    \"input_text, generated_text\",\n    [\n        [\n            \"It was a bright cold day in April, and the clocks were striking thirteen.\",\n            \" Winston Smith, his chin nuzzled into his breast in an effort to escape the vile wind,\"\n            \" slipped quickly through the glass doors of Victory Mansions, though not quickly enough\"\n            \" to prevent a swirl of gritty dust from entering along with him.\",\n        ],\n        [\"This sentence is written in chinese:\", \"我很感谢你的热情\"],\n        [\"Some text might contain a lot of emojis like 😃\", \"😍💪 👉 👀\"],\n    ],\n    ids=[\"spaces\", \"chinese-utf8\", \"emojis\"],\n)\ndef test_decode_streaming(tokenizer, input_text, generated_text):\n    slot = Slot(0, tokenizer)\n    request = Request(id=0, inputs=input_text)\n    slot.assign(0, request, GenerationConfig())\n    assert slot.cached_text == input_text\n\n    inputs = tokenizer(\n        input_text,\n        padding=\"max_length\",\n        max_length=len(input_text) + 1,\n        return_tensors=\"pt\",\n    )\n    input_ids = inputs[\"input_ids\"][0]\n    attention_mask = inputs[\"attention_mask\"][0]\n    generated_tokens = tokenizer(generated_text, add_special_tokens=False)[\"input_ids\"]\n\n    # We need to regenerate the full text as the tokenizer might change it (extra spaces might be added)\n    all_input_ids = torch.cat([input_ids, torch.tensor(generated_tokens)])\n    full_text = tokenizer.decode(all_input_ids, 
skip_special_tokens=True)\n    regenerated_text = full_text[len(input_text) :]\n\n    # Initialize the slot with the inputs\n    slot.reset(input_ids, attention_mask, selector=None)\n\n    assert slot.generated_tokens == 0\n\n    # Simulate an iterative generation (i.e. don't call select and use known tokens instead)\n    decoded_text = \"\"\n    for i in range(len(generated_tokens)):\n        text = slot.append(generated_tokens[i])\n        assert slot.generated_tokens == i + 1\n        decoded_text += text\n\n    assert decoded_text == regenerated_text\n"
  },
  {
    "path": "backends/neuron/tests/server/test_info.py",
    "content": "from text_generation_server.generator import NeuronGenerator\n\n\ndef test_info(neuron_model_path):\n    generator = NeuronGenerator.from_pretrained(neuron_model_path)\n    info = generator.info\n    assert info.requires_padding is True\n    assert info.device_type == \"xla\"\n    assert info.window_size == 0\n    assert info.speculate == 0\n"
  },
  {
    "path": "backends/neuron/tests/server/test_prefill.py",
    "content": "from helpers import create_request\nfrom text_generation_server.generator import NeuronGenerator\nfrom text_generation_server.pb.generate_pb2 import Batch\n\n\ndef test_prefill(neuron_model_config):\n    \"\"\"Verify that a prefill for a single request generates the expected output.\"\"\"\n    config_name = neuron_model_config[\"name\"]\n    neuron_model_path = neuron_model_config[\"neuron_model_path\"]\n    generator = NeuronGenerator.from_pretrained(neuron_model_path)\n    max_batch_size = 4\n    assert generator.model.neuron_config.batch_size >= max_batch_size\n    for num_requests in [1, max_batch_size]:\n        for do_sample in [True, False]:\n            mode = \"sample\" if do_sample else \"greedy\"\n            print(f\"[{mode}]: {num_requests} requests\")\n            _test_prefill(config_name, generator, num_requests, do_sample)\n            generator.clear()\n\n\ndef _test_prefill(config_name, generator, batch_size, do_sample):\n    requests = []\n    max_new_tokens = 20\n    input_text = (\n        \"It was a bright cold day in April, and the clocks were striking thirteen.\"\n    )\n    for i in range(batch_size):\n        requests.append(\n            create_request(\n                id=i,\n                inputs=input_text,\n                do_sample=do_sample,\n                max_new_tokens=max_new_tokens,\n            )\n        )\n    # Let's be pessimistic when estimating max_tokens\n    max_length = generator.max_prefill_length()\n    batch = Batch(\n        id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length\n    )\n    generations, next_batch = generator.prefill(batch)\n    assert next_batch.size == batch_size\n    # Whatever was passed as max_tokens, the server will correct it\n    # because of static batching\n    assert next_batch.max_tokens == batch_size * max_length\n    assert len(generations) == batch_size\n    expectations = {\n        \"llama\": [578, \" The\"],\n        \"qwen2\": [358, \" 
I\"],\n        \"granite\": [203, \"\\n\"],\n    }[config_name]\n    # Greedy mode should always generate the same output\n    if not do_sample:\n        for g in generations:\n            tokens = g.tokens\n            assert tokens.ids[0] == expectations[0]\n            assert tokens.texts[0] == expectations[1]\n\n\ndef test_prefill_truncate(neuron_model_config):\n    config_name = neuron_model_config[\"name\"]\n    neuron_model_path = neuron_model_config[\"neuron_model_path\"]\n    generator = NeuronGenerator.from_pretrained(neuron_model_path)\n    batch_size = generator.model.neuron_config.batch_size\n    # We apply truncation to all requests but the first one\n    truncate = [\n        None,\n    ] + [i * 3 for i in range(1, batch_size)]\n    input_text = (\n        \"Two gin-scented tears trickled down the sides of his nose.\"\n        \" But it was all right, everything was all right, the struggle was finished.\"\n        \" He had won the victory over himself. He loved Big Brother.\"\n    )\n    requests = []\n    for i in range(batch_size):\n        requests.append(create_request(id=i, inputs=input_text, truncate=truncate[i]))\n    max_length = generator.max_prefill_length()\n    batch = Batch(\n        id=0, requests=requests, size=batch_size, max_tokens=batch_size * max_length\n    )\n    generations, _ = generator.prefill(batch)\n    # Even if the input text is identical for all requests, the first generated token might\n    # be different because of the truncation\n    expectations = {\n        \"llama\": [\" He\", \"iens\", \"\\x08\", \" He\"],\n        \"qwen2\": [\" He\", \"<|endoftext|>\", \" \", \" The\"],\n        \"granite\": [\"\\n\", \"\\n\", \"\\n\", \"\\n\"],\n    }[config_name]\n    for i, g in enumerate(generations):\n        tokens = g.tokens\n        assert (\n            tokens.texts[0] == expectations[i]\n        ), f\"Request {i} expected [{expectations[i]}], got [{tokens.texts[0]}]\"\n"
  },
  {
    "path": "backends/neuron/tests/test_entry_point.py",
    "content": "import os\nimport pytest\nfrom tempfile import TemporaryDirectory\n\nfrom optimum.neuron.models.inference.nxd.backend.config import NxDNeuronConfig\nfrom optimum.neuron.utils import map_torch_dtype\n\nfrom text_generation_server.tgi_env import (\n    get_neuron_config_for_model,\n    lookup_compatible_cached_model,\n    neuron_config_to_env,\n)\n\n\ndef test_get_neuron_config_for_model(neuron_model_config):\n    neuron_model_path = neuron_model_config[\"neuron_model_path\"]\n    export_kwargs = neuron_model_config[\"export_kwargs\"]\n    os.environ[\"MAX_BATCH_SIZE\"] = str(export_kwargs[\"batch_size\"])\n    os.environ[\"MAX_TOTAL_TOKENS\"] = str(export_kwargs[\"sequence_length\"])\n    os.environ[\"HF_AUTO_CAST_TYPE\"] = export_kwargs[\"auto_cast_type\"]\n    os.environ[\"HF_NUM_CORES\"] = str(export_kwargs[\"num_cores\"])\n    neuron_config = get_neuron_config_for_model(neuron_model_path)\n    assert neuron_config is not None\n    assert neuron_config.batch_size == export_kwargs[\"batch_size\"]\n    assert neuron_config.sequence_length == export_kwargs[\"sequence_length\"]\n    assert neuron_config.tp_degree == export_kwargs[\"num_cores\"]\n    if isinstance(neuron_config, NxDNeuronConfig):\n        assert map_torch_dtype(neuron_config.torch_dtype) == map_torch_dtype(\n            export_kwargs[\"auto_cast_type\"]\n        )\n    else:\n        assert map_torch_dtype(neuron_config.auto_cast_type) == map_torch_dtype(\n            export_kwargs[\"auto_cast_type\"]\n        )\n\n\n@pytest.mark.parametrize(\"model_id\", [\"unsloth/Llama-3.2-1B-Instruct\"])\ndef test_lookup_compatible_cached_model(model_id: str):\n    neuron_config = lookup_compatible_cached_model(model_id, None)\n    assert neuron_config is not None\n\n\ndef test_neuron_config_to_env(neuron_model_config) -> None:\n    neuron_model_path = neuron_model_config[\"neuron_model_path\"]\n    neuron_config = get_neuron_config_for_model(neuron_model_path)\n    with TemporaryDirectory() as 
temp_dir:\n        os.environ[\"ENV_FILEPATH\"] = os.path.join(temp_dir, \"env.sh\")\n        neuron_config_to_env(neuron_config)\n        with open(os.environ[\"ENV_FILEPATH\"], \"r\") as env_file:\n            env_content = env_file.read()\n            assert f\"export MAX_BATCH_SIZE={neuron_config.batch_size}\" in env_content\n            assert (\n                f\"export MAX_TOTAL_TOKENS={neuron_config.sequence_length}\"\n                in env_content\n            )\n            assert f\"export HF_NUM_CORES={neuron_config.tp_degree}\" in env_content\n            if hasattr(neuron_config, \"torch_dtype\"):\n                auto_cast_type = str(map_torch_dtype(neuron_config.torch_dtype)).split(\n                    \".\"\n                )[-1]\n            else:\n                auto_cast_type = neuron_config.auto_cast_type\n            assert f\"export HF_AUTO_CAST_TYPE={auto_cast_type}\" in env_content\n"
  },
  {
    "path": "backends/neuron/tgi-entrypoint.sh",
    "content": "#!/bin/bash\nset -e -o pipefail -u\n\nexport ENV_FILEPATH=$(mktemp)\n\ntrap \"rm -f ${ENV_FILEPATH}\" EXIT\n\ntouch $ENV_FILEPATH\n\nSCRIPT_DIR=$( cd -- \"$( dirname -- \"${BASH_SOURCE[0]}\" )\" &> /dev/null && pwd )\n\n${SCRIPT_DIR}/tgi_entry_point.py $@\n\nsource $ENV_FILEPATH\n\nexec text-generation-launcher $@\n"
  },
  {
    "path": "backends/neuron/tgi_entry_point.py",
    "content": "#!/usr/bin/env python\n\nimport logging\nimport os\nimport sys\n\n\nfrom text_generation_server.tgi_env import (\n    available_cores,\n    get_env_dict,\n    get_neuron_config_for_model,\n    neuron_config_to_env,\n    neuronxcc_version,\n    parse_cmdline_and_set_env,\n    tgi_env_vars,\n)\n\n\nlogger = logging.getLogger(__name__)\n\n\ndef main():\n    \"\"\"\n    This script determines proper default TGI env variables for the neuron precompiled models to\n    work properly\n    :return:\n    \"\"\"\n    args = parse_cmdline_and_set_env()\n\n    for env_var in tgi_env_vars:\n        if not os.getenv(env_var):\n            break\n    else:\n        logger.info(\n            \"All env vars %s already set, skipping, user know what they are doing\",\n            tgi_env_vars,\n        )\n        sys.exit(0)\n\n    neuron_config = get_neuron_config_for_model(args.model_id, args.revision)\n\n    if not neuron_config:\n        msg = (\n            \"No compatible neuron config found. Provided env {}, available cores {}, neuronxcc version {}\"\n        ).format(get_env_dict(), available_cores, neuronxcc_version)\n        logger.error(msg)\n        raise Exception(msg)\n\n    neuron_config_to_env(neuron_config)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "backends/trtllm/CMakeLists.txt",
    "content": "cmake_minimum_required(VERSION 3.20)\n\nif (CMAKE_VERSION VERSION_GREATER_EQUAL \"3.24.0\")\n    cmake_policy(SET CMP0135 NEW)\nendif ()\n\nproject(tgi-trtllm-backend VERSION 1.0.0)\nset(CMAKE_CXX_STANDARD 23)\n\ninclude(FetchContent)\ninclude(ExternalProject)\ninclude(CheckCXXCompilerFlag)\n\noption(TGI_TRTLLM_BACKEND_BUILD_TESTS \"Enable building the unittests suite\" OFF)\noption(TGI_TRTLLM_BACKEND_BUILD_EXAMPLES \"Enable building the examples suite\" OFF)\noption(TGI_TRTLLM_BACKEND_BUILD_USE_LLD \"Enable lld linker instead of ld\" OFF)\nset(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST \"89-real\" CACHE STRING \"List of CUDA architectures to support\")\nset(TGI_TRTLLM_BACKEND_TRT_ROOT \"/usr/local/tensorrt\" CACHE STRING \"Path where TensorRT libraries and headers are located\")\nset(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR \"${TGI_TRTLLM_BACKEND_TRT_ROOT}/include\" CACHE STRING \"Path where TensorRT headers are located\")\nset(TGI_TRTLLM_BACKEND_TRT_LIB_DIR \"${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib\" CACHE STRING \"Path where TensorRT libraries are located\")\n\n# We are using nvidia-ml to query at runtime device information to enable some architecture-specific features\nfind_package(CUDAToolkit 12.6 REQUIRED COMPONENTS CUDA::cudart CUDA::nvml)\nfind_package(MPI REQUIRED)\n\n#### External dependencies ####\ninclude(cmake/json.cmake)\ninclude(cmake/spdlog.cmake)\ninclude(cmake/trtllm.cmake)\n\nif (CMAKE_BUILD_TYPE STREQUAL \"Debug\")\n    set(TGI_TRTLLM_BACKEND_DEBUG ON)\n    add_compile_definitions(TGI_TRTLLM_BACKEND_DEBUG=1)\n    add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_TRACE)\nendif ()\n\nif (${TGI_TRTLLM_BACKEND_BUILD_USE_LLD})\n    message(STATUS \"Using lld linker\")\n    add_link_options(\"-fuse-ld=lld\")\nendif ()\n\n# Let's build TRTLLM as part of CMake\nadd_subdirectory(\"${trtllm_SOURCE_DIR}/cpp\" \"${trtllm_SOURCE_DIR}/..\")\n\n# Tell CMake to need try to override the RPATH for executorWorker as it has not information on how to do 
so\nset_target_properties(executorWorker PROPERTIES SKIP_BUILD_RPATH TRUE)\n\n# TGI TRTLLM Backend definition\nadd_library(tgi_trtllm_backend_impl STATIC csrc/hardware.hpp csrc/backend.hpp csrc/backend.cpp)\ninclude_directories(${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR})\ntarget_include_directories(tgi_trtllm_backend_impl PRIVATE\n        $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/csrc>\n        #        $<INSTALL_INTERFACE:csrc>\n)\ntarget_include_directories(tgi_trtllm_backend_impl PUBLIC \"${trtllm_SOURCE_DIR}/cpp/include\")\ntarget_link_libraries(tgi_trtllm_backend_impl PRIVATE CUDA::cudart CUDA::nvml)\ntarget_link_libraries(tgi_trtllm_backend_impl PUBLIC nlohmann_json::nlohmann_json spdlog::spdlog)\ntarget_link_libraries(tgi_trtllm_backend_impl PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper)\n\n# This install all the artifacts in CMAKE_INSTALL_PREFIX under include/ lib/ bin/ to make easy to link / find it back\ninstall(TARGETS tgi_trtllm_backend_impl)\n#install(TARGETS cutlass_src fb_gemm_src fpA_intB_gemm_src gemm_swiglu_sm90_src kernels_src)\ninstall(TARGETS decoder_attention_0 decoder_attention_1)\ninstall(TARGETS tensorrt_llm nvinfer_plugin_tensorrt_llm decoder_attention_src executorWorker)\ninstall(FILES ${TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH} TYPE LIB)\nif (NOT ${TGI_TRTLLM_BACKEND_DEBUG})\n    install(FILES ${TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH} TYPE LIB)\nendif ()\n\n\n#### Unit Tests ####\nif (${TGI_TRTLLM_BACKEND_BUILD_TESTS} AND CMAKE_BUILD_TYPE MATCHES \"Debug\")\n    message(STATUS \"Building tests\")\n    option(TGI_TRTLLM_BACKEND_ENABLE_ASAN \"Enable AddressSanitizer\")\n    option(TGI_TRTLLM_BACKEND_ENABLE_UBSAN \"Enable UndefinedSanitizer\")\n\n    FetchContent_Declare(\n            Catch2\n            URL https://github.com/catchorg/Catch2/archive/refs/tags/v3.7.1.tar.gz\n    )\n    FetchContent_MakeAvailable(Catch2)\n\n    # This attempt to detect if the compiler can emit warning if it can't apply return value 
optimization from a function\n    check_cxx_compiler_flag(\"-Wnrvo\" COMPILER_SUPPORT_WARNING_ON_NVRO)\n    if (${COMPILER_SUPPORT_WARNING_ON_NVRO})\n        message(STATUS \"Enabling non-NVRO detection\")\n        target_compile_options(tgi_trtllm_backend_impl PRIVATE -Wnrvo)\n    endif ()\n    target_compile_options(tgi_trtllm_backend_impl PRIVATE -Wall)\n\n    cmake_path(GET TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH PARENT_PATH TRTLLM_NVRTC_WRAPPER_PARENT_LIBRARY_PATH)\n    message(STATUS \"Adding linking path: ${TRTLLM_NVRTC_WRAPPER_PARENT_LIBRARY_PATH}\")\n\n    add_executable(tgi_trtllm_backend_tests tests/test_hardware.cpp tests/test_backend.cpp)\n\n    #    target_compile_options(tgi_trtllm_backend_tests PRIVATE -Werror)\n    target_link_directories(tgi_trtllm_backend_tests PRIVATE \"${TRTLLM_NVRTC_WRAPPER_PARENT_LIBRARY_PATH}\")\n    target_include_directories(tgi_trtllm_backend_tests PUBLIC \"${trtllm_SOURCE_DIR}/cpp/include\")\n    target_include_directories(tgi_trtllm_backend_tests PUBLIC \"csrc/\")\n    target_link_libraries(tgi_trtllm_backend_tests PRIVATE ${TRTLLM_LIBS} CUDA::cudart CUDA::nvml)\n    target_link_libraries(tgi_trtllm_backend_tests PUBLIC Catch2::Catch2WithMain nlohmann_json::nlohmann_json spdlog::spdlog tgi_trtllm_backend_impl)\n    target_link_libraries(tgi_trtllm_backend_tests PRIVATE tensorrt_llm nvinfer_plugin_tensorrt_llm tensorrt_llm_nvrtc_wrapper)\n\n    if (${TGI_TRTLLM_BACKEND_ENABLE_ASAN})\n        message(STATUS \"Enabled AddressSanitizer\")\n        target_link_options(tgi_trtllm_backend_tests BEFORE PUBLIC -fsanitize=address)\n    endif ()\n\n    if (${TGI_TRTLLM_BACKEND_ENABLE_UBSAN})\n        message(STATUS \"Enabled UndefinedSanitizer\")\n        target_link_options(tgi_trtllm_backend_tests BEFORE PUBLIC -fsanitize=undefined)\n    endif ()\n\n    install(TARGETS tgi_trtllm_backend_tests)\n\n    #    list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras)\n    #    include(CTest)\n    #    include(Catch)\n    #    
catch_discover_tests(tgi_trtllm_backend_tests)\nendif ()\n"
  },
  {
    "path": "backends/trtllm/Cargo.toml",
    "content": "[package]\nname = \"text-generation-backends-trtllm\"\nversion.workspace = true\nedition.workspace = true\nauthors.workspace = true\nhomepage.workspace = true\n\n[dependencies]\nasync-trait = \"0.1\"\nclap = { version = \"4.5\", features = [\"derive\"] }\ncxx = \"1.0\"\nhashbrown = \"0.15\"\nhf-hub = { workspace = true }\ntext-generation-router = { path = \"../../router\" }\ntokenizers = { workspace = true }\ntokio = { version = \"1.43.0\", features = [\"rt\", \"rt-multi-thread\", \"parking_lot\", \"signal\", \"sync\"] }\ntokio-stream = \"0.1.17\"\nthiserror = \"1.0.63\"\ntracing = \"0.1\"\npyo3 = { workspace = true }\n\n[build-dependencies]\ncmake = \"0.1\"\ncxx-build = { version = \"1.0\", features = [\"parallel\"] }\npkg-config = \"0.3\"\n"
  },
  {
    "path": "backends/trtllm/README.md",
    "content": "# Text Generation Inference - TensorRT-LLM Backend Implementation\n\n## Description\n\nThis folder provides the sources of the TensorRT-LLM backend implementation powered by TensorRT-LLM Executor new API\n\n## Simplified Request Sequence\n\n```mermaid\nsequenceDiagram\n    actor User\n    participant TextGenerationInference.HttpServer\n    participant TextGenerationInference.TensorRtLlmBackend\n    participant TextGenerationInference.TensorRtLlmWorkerThread\n    participant TensorRtLlm.Executor\n    participant Nvidia.Gpu\n    User ->> TextGenerationInference.HttpServer: POST /generate\n    TextGenerationInference.HttpServer ->> TextGenerationInference.TensorRtLlmBackend: Validate and forward inputs & parameters\n    TextGenerationInference.TensorRtLlmBackend ->> TextGenerationInference.TensorRtLlmWorkerThread: Allocate a new context and spawn a new thread to handle the request\n    TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Submit the request to the In-Flight Batcher\n    activate Nvidia.Gpu\n    TensorRtLlm.Executor ->> Nvidia.Gpu: Add the request to the poll for execution\n    TensorRtLlm.Executor -->> TextGenerationInference.TensorRtLlmWorkerThread: Response with an unique request identifier\n    rect rgb(10, 92, 54)\n        loop every 100us\n            rect rgb(15, 81, 50)\n                alt Acquire lock to query executor\n                    TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Poll request number of new token(s) generated\n                else There are new generated tokens\n                    TextGenerationInference.TensorRtLlmWorkerThread ->> TensorRtLlm.Executor: Retrieve newly generated tokens\n                    TensorRtLlm.Executor -->> TextGenerationInference.TensorRtLlmWorkerThread: Return decoded token information and potential error (omitted)\n                    rect rgb(11, 110, 79)\n                        alt Generated token is final\n                      
      TensorRtLlm.Executor ->> Nvidia.Gpu: Remove request from the scheduler and from the GPU\n                            TextGenerationInference.TensorRtLlmWorkerThread -->> User: Stream the remaining decoded tokens and flush the connection\n                        else Generated token is not final\n                            TextGenerationInference.TensorRtLlmWorkerThread -->> User: Stream token back to the user as they get decoded\n                        end\n                    end\n                end\n            end\n            deactivate Nvidia.Gpu\n        end\n    end\n\n```\n"
  },
  {
    "path": "backends/trtllm/build.rs",
    "content": "use cxx_build::CFG;\nuse pkg_config;\nuse std::env;\nuse std::env::consts::ARCH;\nuse std::path::{absolute, PathBuf};\nuse std::sync::LazyLock;\n\nconst ADDITIONAL_BACKEND_LINK_LIBRARIES: [&str; 1] = [\"spdlog\"];\nconst CUDA_ARCH_LIST: Option<&str> = option_env!(\"CUDA_ARCH_LIST\");\nconst CUDA_REQUIRED_VERSION: &str = \"12.8\";\nconst MPI_REQUIRED_VERSION: &str = \"4.1\";\nconst INSTALL_PREFIX: Option<&str> = option_env!(\"CMAKE_INSTALL_PREFIX\");\nconst TENSORRT_ROOT_DIR: Option<&str> = option_env!(\"TENSORRT_ROOT_DIR\");\nconst NCCL_ROOT_DIR: Option<&str> = option_env!(\"NCCL_ROOT_DIR\");\n\nconst IS_GHA_BUILD: LazyLock<bool> = LazyLock::new(|| {\n    option_env!(\"SCCACHE_GHA_ENABLED\").map_or(false, |value| match value.to_lowercase().as_str() {\n        \"on\" => true,\n        \"true\" => true,\n        \"1\" => true,\n        _ => false,\n    })\n});\n\n// Dependencies\nconst BACKEND_DEPS: &str = \"tgi_trtllm_backend_impl\";\nconst CUDA_TRANSITIVE_DEPS: [&str; 4] = [\"cuda\", \"cudart\", \"cublas\", \"nvidia-ml\"];\nconst TENSORRT_LLM_TRANSITIVE_DEPS: [(&str, &str); 5] = [\n    (\"dylib\", \"tensorrt_llm\"),\n    (\"dylib\", \"tensorrt_llm_nvrtc_wrapper\"),\n    (\"dylib\", \"nvinfer_plugin_tensorrt_llm\"),\n    (\"dylib\", \"decoder_attention_0\"),\n    (\"dylib\", \"decoder_attention_1\"),\n];\n\nmacro_rules! 
probe {\n    ($name: expr, $version: expr) => {\n        if let Err(_) = pkg_config::probe_library($name) {\n            pkg_config::probe_library(&format!(\"{}-{}\", $name, $version))\n                .expect(&format!(\"Failed to locate {}\", $name));\n        }\n    };\n}\n\nfn get_compiler_flag(\n    switch: bool,\n    true_case: &'static str,\n    false_case: &'static str,\n) -> &'static str {\n    match switch {\n        true => true_case,\n        false => false_case,\n    }\n}\n\nfn get_library_architecture() -> &'static str {\n    let os = env::var(\"CARGO_CFG_TARGET_OS\").unwrap();\n    let arch = env::var(\"CARGO_CFG_TARGET_ARCH\").unwrap();\n    let env = env::var(\"CARGO_CFG_TARGET_ENV\").unwrap();\n\n    match os.as_str() {\n        \"linux\" => {\n            if env != \"gnu\" {\n                panic!(\"unsupported linux ABI {env}, only 'gnu' is supported\")\n            }\n\n            match arch.as_str() {\n                \"x86_64\" => \"x86_64-linux-gnu\",\n                \"aarch64\" => \"aarch64-linux-gnu\",\n                _ => panic!(\"unsupported linux architecture {arch}\"),\n            }\n        }\n        \"windows\" => {\n            if env != \"msvc\" {\n                panic!(\"unsupported windows ABI {env}, only 'msvc' is supported\")\n            }\n\n            match arch.as_str() {\n                \"x86_64\" => \"x86_64-windows-msvc\",\n                _ => panic!(\"unsupported windows architecture {arch}\"),\n            }\n        }\n        _ => panic!(\"unsupported OS {os}\"),\n    }\n}\n\nfn build_backend(is_debug: bool, opt_level: &str, out_dir: &PathBuf) -> (PathBuf, PathBuf) {\n    // Build the backend implementation through CMake\n    let install_path = INSTALL_PREFIX.unwrap_or(\"/usr/local/tgi\");\n    let tensorrt_path = TENSORRT_ROOT_DIR.unwrap_or(\"/usr/local/tensorrt\");\n    let cuda_arch_list = CUDA_ARCH_LIST.unwrap_or(\"75-real;80-real;86-real;89-real;90-real\");\n\n    let mut install_path = 
PathBuf::from(install_path);\n    if !install_path.is_absolute() {\n        install_path = absolute(out_dir).expect(\"cannot happen\").join(install_path);\n    }\n\n    let mut config = cmake::Config::new(\".\");\n    config\n        .uses_cxx11()\n        .generator(\"Ninja\")\n        .profile(match is_debug {\n            true => \"Debug\",\n            false => \"Release\",\n        })\n        .env(\"OPT_LEVEL\", opt_level)\n        .define(\"CMAKE_INSTALL_PREFIX\", &install_path)\n        .define(\"CMAKE_CUDA_COMPILER\", \"/usr/local/cuda/bin/nvcc\")\n        .define(\"CMAKE_LIBRARY_ARCHITECTURE\", get_library_architecture())\n        .define(\"TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST\", cuda_arch_list)\n        .define(\n            \"TGI_TRTLLM_BACKEND_DEBUG\",\n            get_compiler_flag(is_debug, \"ON\", \"OFF\"),\n        )\n        .define(\"TGI_TRTLLM_BACKEND_TRT_ROOT\", tensorrt_path);\n\n    if is_debug || *IS_GHA_BUILD {\n        config.define(\"TGI_TRTLLM_BACKEND_BUILD_TESTS\", \"ON\");\n    }\n\n    if option_env!(\"USE_LLD_LINKER\").is_some() {\n        println!(\"cargo:warning=Using lld linker\");\n        config.define(\"TGI_TRTLLM_BACKEND_BUILD_USE_LLD\", \"ON\");\n    }\n\n    if (is_debug && option_env!(\"ENABLE_ASAN\").is_some()) || *IS_GHA_BUILD {\n        println!(\"cargo:warning=Enabling Address Sanitizer\");\n        config.define(\"TGI_TRTLLM_BACKEND_ENABLE_ASAN\", \"ON\");\n    }\n\n    if (is_debug && option_env!(\"ENABLE_UBSAN\").is_some()) || *IS_GHA_BUILD {\n        println!(\"cargo:warning=Enabling Undefined Sanitizer\");\n        config.define(\"TGI_TRTLLM_BACKEND_ENABLE_UBSAN\", \"ON\");\n    }\n\n    if let Some(nvcc_host_compiler) = option_env!(\"CMAKE_CUDA_HOST_COMPILER\") {\n        config.define(\"CMAKE_CUDA_HOST_COMPILER\", nvcc_host_compiler);\n    }\n\n    if let Some(wrapper) = option_env!(\"RUSTC_WRAPPER\") {\n        println!(\"cargo:warning=Using caching tool: {wrapper}\");\n        
config.define(\"CMAKE_C_COMPILER_LAUNCHER\", wrapper);\n        config.define(\"CMAKE_CXX_COMPILER_LAUNCHER\", wrapper);\n        config.define(\"CMAKE_CUDA_COMPILER_LAUNCHER\", wrapper);\n    }\n\n    // Allow to override which Python to use ...\n    if let Some(python3) = option_env!(\"Python3_EXECUTABLE\") {\n        config.define(\"Python3_EXECUTABLE\", python3);\n    }\n\n    config.build();\n\n    // Additional transitive CMake dependencies\n    let deps_folder = out_dir.join(\"build\").join(\"_deps\");\n    for dependency in ADDITIONAL_BACKEND_LINK_LIBRARIES {\n        let dep_name = match is_debug {\n            true => format!(\"{}d\", dependency),\n            false => String::from(dependency),\n        };\n        let dep_path = deps_folder.join(format!(\"{}-build\", dependency));\n        println!(\"cargo:rustc-link-search={}\", dep_path.display());\n        println!(\"cargo:rustc-link-lib=static={}\", dep_name);\n    }\n\n    // Emit linkage information from the artifacts we just built\n    for path in [\"lib\", \"lib64\"] {\n        let install_lib_path = install_path.join(path);\n        println!(\n            r\"cargo:warning=Adding link search path: {}\",\n            install_lib_path.display()\n        );\n        println!(r\"cargo:rustc-link-search={}\", install_lib_path.display());\n    }\n    (PathBuf::from(install_path), deps_folder)\n}\n\nfn build_ffi_layer(deps_folder: &PathBuf, is_debug: bool) {\n    CFG.include_prefix = \"backends/trtllm\";\n    cxx_build::bridge(\"src/lib.rs\")\n        .static_flag(true)\n        .std(\"c++23\")\n        .include(deps_folder.join(\"spdlog-src\").join(\"include\"))\n        .include(deps_folder.join(\"json-src\").join(\"include\"))\n        .include(deps_folder.join(\"trtllm-src\").join(\"cpp\").join(\"include\"))\n        .include(\"/usr/local/cuda/include\")\n        .include(\"/usr/local/tensorrt/include\")\n        .include(\"csrc/\")\n        .file(\"csrc/ffi.hpp\")\n        .define(\n            
\"TGI_TRTLLM_BACKEND_DEBUG\",\n            get_compiler_flag(is_debug, \"ON\", \"OFF\"),\n        )\n        .compile(\"tgi_trtllm_backend\");\n\n    println!(\"cargo:rerun-if-changed=CMakeLists.txt\");\n    println!(\"cargo:rerun-if-changed=cmake/trtllm.cmake\");\n    println!(\"cargo:rerun-if-changed=cmake/json.cmake\");\n    println!(\"cargo:rerun-if-changed=cmake/spdlog.cmake\");\n    println!(\"cargo:rerun-if-changed=csrc/backend.hpp\");\n    println!(\"cargo:rerun-if-changed=csrc/backend.cpp\");\n    println!(\"cargo:rerun-if-changed=csrc/hardware.hpp\");\n    println!(\"cargo:rerun-if-changed=csrc/ffi.hpp\");\n}\n\nfn main() {\n    // Misc variables\n    let out_dir = PathBuf::from(env::var(\"OUT_DIR\").unwrap());\n    let build_profile = env::var(\"PROFILE\").unwrap();\n    let (is_debug, opt_level) = match build_profile.as_ref() {\n        \"debug\" => (true, \"0\"),\n        \"dev\" => (true, \"0\"),\n        _ => (false, \"3\"),\n    };\n\n    // Build the backend\n    let (_backend_path, deps_folder) = build_backend(is_debug, opt_level, &out_dir);\n\n    // Build the FFI layer calling the backend above\n    build_ffi_layer(&deps_folder, is_debug);\n\n    // Emit linkage search path\n    probe!(\"ompi\", MPI_REQUIRED_VERSION);\n\n    // Probe CUDA & co. 
with pkg-config\n    CUDA_TRANSITIVE_DEPS.iter().for_each(|name| {\n        probe!(name, CUDA_REQUIRED_VERSION);\n    });\n\n    // NCCL is slightly trickier because it might not provide a pkg-config file\n    let nccl_library_path_default = format!(\"/usr/local/{}-linux-gnu\", ARCH);\n    let nccl_library_path = NCCL_ROOT_DIR.unwrap_or(&nccl_library_path_default);\n    println!(r\"cargo:rustc-link-search=native={}\", nccl_library_path);\n    println!(\"cargo:rustc-link-lib=dylib=nccl\");\n\n    // TensorRT\n    let tensort_library_path = TENSORRT_ROOT_DIR.unwrap_or(\"/usr/local/tensorrt/lib\");\n    println!(r\"cargo:rustc-link-search=native={}\", tensort_library_path);\n    println!(\"cargo:rustc-link-lib=dylib=nvinfer\");\n\n    // TensorRT-LLM\n    TENSORRT_LLM_TRANSITIVE_DEPS\n        .iter()\n        .for_each(|(link_type, name)| {\n            println!(\"cargo:rustc-link-lib={}={}\", link_type, name);\n        });\n\n    // Backend\n    println!(\"cargo:rustc-link-lib=static={}\", &BACKEND_DEPS);\n}\n"
  },
  {
    "path": "backends/trtllm/cmake/json.cmake",
    "content": "fetchcontent_declare(\n        json\n#        DOWNLOAD_EXTRACT_TIMESTAMP\n        URL https://github.com/nlohmann/json/archive/refs/tags/v3.11.3.tar.gz\n)\nfetchcontent_makeavailable(json)\n"
  },
  {
    "path": "backends/trtllm/cmake/spdlog.cmake",
    "content": "set(SPDLOG_USE_FMT ON)\nset(SPDLOG_BUILD_SHARED OFF)\nset(SPDLOG_FMT_EXTERNAL OFF)\n\n# Select the compile-time spdlog log level (SPDLOG_ACTIVE_LEVEL) according to the build type\nif (${CMAKE_BUILD_TYPE} STREQUAL \"Debug\")\n    add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_TRACE)\nelse ()\n    add_compile_definitions(SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG)\nendif ()\n\nfetchcontent_declare(\n        spdlog\n        #        DOWNLOAD_EXTRACT_TIMESTAMP\n        URL https://github.com/gabime/spdlog/archive/refs/tags/v1.15.0.tar.gz\n)\nfetchcontent_makeavailable(spdlog)\n"
  },
  {
    "path": "backends/trtllm/cmake/trtllm.cmake",
    "content": "set(TRT_INCLUDE_DIR ${TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR})\nset(TRT_LIB_DIR ${TGI_TRTLLM_BACKEND_TRT_LIB_DIR})\n\nset(USE_CXX11_ABI ON)\nset(BUILD_PYT OFF)\nset(BUILD_PYBIND OFF)\nset(BUILD_MICRO_BENCHMARKS OFF)\nset(BUILD_BENCHMARKS OFF)\nset(BUILD_TESTS OFF)\nset(CMAKE_CUDA_ARCHITECTURES ${TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST})\n\nmessage(STATUS \"Building for CUDA Architectures: ${CMAKE_CUDA_ARCHITECTURES}\")\n\nset(ENABLE_UCX OFF)\nif (${CMAKE_BUILD_TYPE} STREQUAL \"Debug\")\n    set(FAST_BUILD ON)\n    set(NVTX_DISABLE ON)\n    set(INDEX_RANGE_CHECK ON)\nelse ()\n    set(FAST_BUILD OFF)\n    set(FAST_MATH ON)\n    set(NVTX_DISABLE OFF)\n    set(INDEX_RANGE_CHECK OFF)\nendif ()\n\nfind_package(Python3 REQUIRED Interpreter)\n\nfetchcontent_declare(\n        trtllm\n        GIT_REPOSITORY https://github.com/nvidia/TensorRT-LLM.git\n        GIT_TAG v0.17.0\n        GIT_SHALLOW ON\n        DOWNLOAD_EXTRACT_TIMESTAMP\n)\nfetchcontent_makeavailable(trtllm)\n\nmessage(STATUS \"Found TensorRT-LLM: ${trtllm_SOURCE_DIR}\")\nexecute_process(COMMAND git lfs install WORKING_DIRECTORY \"${trtllm_SOURCE_DIR}/\")\nexecute_process(COMMAND git lfs pull WORKING_DIRECTORY \"${trtllm_SOURCE_DIR}/\")\n\n# TRTLLM use a JIT based *precompiled* library to generate some specific kernels, we are generating the path to this one here\nset(TRTLLM_NVRTC_LIBRARY_NAME \"${CMAKE_SHARED_LIBRARY_PREFIX}tensorrt_llm_nvrtc_wrapper${CMAKE_SHARED_LIBRARY_SUFFIX}\" CACHE INTERNAL \"nvrtc wrapper library name\")\nset(TRTLLM_NVRTC_WRAPPER_LIBRARY_PATH \"${trtllm_SOURCE_DIR}/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/${CMAKE_LIBRARY_ARCHITECTURE}/${TRTLLM_NVRTC_LIBRARY_NAME}\"\n        CACHE INTERNAL \"nvrtc wrapper library path\")\n\n# The same Executor Static library\nset(TRTLLM_EXECUTOR_STATIC_LIBRARY_NAME \"${CMAKE_SHARED_LIBRARY_PREFIX}tensorrt_llm_executor_static${CMAKE_STATIC_LIBRARY_SUFFIX}\" CACHE INTERNAL \"executor_static library 
name\")\nset(TRTLLM_EXECUTOR_STATIC_LIBRARY_PATH \"${trtllm_SOURCE_DIR}/cpp/tensorrt_llm/executor/${CMAKE_LIBRARY_ARCHITECTURE}/${TRTLLM_EXECUTOR_STATIC_LIBRARY_NAME}\" CACHE INTERNAL \"executor_static library path\")\n"
  },
  {
    "path": "backends/trtllm/cmake/utils/detect_cuda_arch.cu",
    "content": ""
  },
  {
    "path": "backends/trtllm/csrc/backend.cpp",
    "content": "#include <ranges>\n\n#include <nlohmann/json.hpp>\n\n#include \"backend.hpp\"\n#include \"hardware.hpp\"\n\nnamespace huggingface::tgi::backends::trtllm {\n    tle::ParallelConfig backend_workspace_t::parallel_config() const {\n        // Single engine (TP = PP = 1) -> using leader mode (no MPI involved)\n        const auto world_size = config_[\"/pretrained_config/mapping/world_size\"_json_pointer].get<size_t>();\n\n        auto mode = tle::CommunicationMode::kLEADER;\n        std::optional<tle::OrchestratorConfig> orchestratorConfig = std::nullopt;\n\n        if (world_size > 1) {\n            SPDLOG_INFO(\"Detected sharded engine deployment, using orchestrator mode\");\n            mode = tle::CommunicationMode::kORCHESTRATOR;\n            orchestratorConfig = std::make_optional<tle::OrchestratorConfig>(true, executor_worker_path_, nullptr,\n                                                                             true);\n        } else {\n            SPDLOG_INFO(\"Detected single engine deployment, using leader mode\");\n        }\n\n        return tle::ParallelConfig(tle::CommunicationType::kMPI, mode, std::nullopt, std::nullopt, orchestratorConfig);\n    }\n\n\n    tle::ExecutorConfig backend_workspace_t::executor_config() const {\n        // Retrieve the compute capabilities to enable some options at runtime\n        const auto compute_capabilities = hardware::cuda::compute_capabilities_t();\n\n        // Allocate the config\n        tle::ExecutorConfig executor_config(/* maxBeamWidth = */ 1);\n\n        // Set the parallel config as inferred\n        executor_config.setParallelConfig(parallel_config());\n\n        // Define some configuration variables\n        executor_config.setKvCacheConfig(tle::KvCacheConfig(true));\n        executor_config.setEnableChunkedContext(compute_capabilities.is_at_least_ampere());\n        executor_config.setSchedulerConfig(tle::SchedulerConfig(tle::CapacitySchedulerPolicy::kMAX_UTILIZATION));\n        
return executor_config;\n    }\n\n    backend_t::backend_t(std::filesystem::path &engines_folder, std::filesystem::path &executor_worker_path)\n            : workspace(engines_folder, executor_worker_path), executor_(executor_factory_initializer(workspace)) {}\n\n    size_t backend_t::num_tokens_ready() const noexcept {\n        return executor_.getNumResponsesReady();\n    }\n\n    std::expected<request_id_t, backend_error_t>\n    backend_t::submit(std::span<const token_id_t> token_ids, const generation_params_t g_params,\n                      const sampling_params_t s_params) noexcept {\n        SPDLOG_DEBUG(\"Submit {:d} tokens for scheduling ({}, {})\", token_ids.size(), g_params, s_params);\n        return executor_.enqueueRequest(tle::Request{\n                {token_ids.begin(), token_ids.end()},  // Making actual copy of the tokens\n                static_cast<tle::SizeType32>(g_params.max_new_tokens),\n                true,\n                (tle::SamplingConfig) s_params,\n                tle::OutputConfig{ /* returnLogProbs= */ true},\n                std::nullopt,\n                std::nullopt,\n                std::nullopt,\n                std::nullopt,\n                workspace.generation_config().stop_words\n        });\n    }\n\n    std::vector<tle::Response> backend_t::pull_tokens() noexcept {\n        SPDLOG_TRACE(FMT_STRING(\"Pulling out tokens ({:d} available)\"), num_tokens_ready());\n        return executor_.awaitResponses();\n    }\n\n    void backend_t::cancel(request_id_t request_id) noexcept {\n        SPDLOG_TRACE(FMT_STRING(\"Cancelling request: {:d}\"), request_id);\n        executor_.cancelRequest(request_id);\n    }\n}\n"
  },
  {
    "path": "backends/trtllm/csrc/backend.hpp",
    "content": "#ifndef TGI_BACKEND_TRTLLM\n#define TGI_BACKEND_TRTLLM\n\n#include <cmath>\n#include <cstdint>\n#include <expected>\n#include <fstream>\n#include <list>\n#include <span>\n\n#include <nlohmann/json.hpp>\n#include <spdlog/spdlog.h>\n#include <spdlog/fmt/fmt.h>\n\n#include <tensorrt_llm/executor/executor.h>\n\nnamespace huggingface::tgi::backends::trtllm {\n    namespace tle = tensorrt_llm::executor;\n    using json = nlohmann::json;\n    using request_id_t = uint64_t;\n    using token_id_t = tle::TokenIdType;\n\n    /**\n     * Represent the parameters used for generation\n     */\n    struct generation_params_t {\n        uint32_t max_new_tokens;\n    };\n\n    /**\n     * Represent the parameters used to sample tokens from the logit distribution\n     */\n    struct sampling_params_t {\n        uint32_t top_k;\n        float_t top_p;\n        float_t repetition_penalty;\n        float_t frequency_penalty;\n        float_t temperature;\n        uint64_t seed;\n\n        constexpr explicit operator tle::SamplingConfig() const {\n            return tle::SamplingConfig{\n                    1,\n                    top_k,\n                    top_p,\n                    std::nullopt,\n                    std::nullopt,\n                    std::nullopt,\n                    seed,\n                    temperature,\n                    std::nullopt,\n                    std::nullopt,\n                    repetition_penalty,\n                    std::nullopt,\n                    frequency_penalty,\n                    std::nullopt\n            };\n        }\n    };\n\n    /**\n     * Represent possible values from transformers generation `generation_config.json`.\n     * It usually stores default sampling parameters to use, such as top_p, temperature, etc.\n     */\n    struct generation_config_t {\n        float_t top_p;\n        float_t temperature;\n        std::list<std::vector<int32_t>> stop_words;\n\n        constexpr explicit 
generation_config_t(const json &config) :\n                top_p(config.value(\"top_p\", 1.0f)), temperature(config.value(\"temperature\", 1.0f)), stop_words(0) {\n            if (config.contains(\"/eos_token_id\"_json_pointer) && config[\"/eos_token_id\"_json_pointer].is_array()) {\n                const auto &eos_token_id = config[\"/eos_token_id\"_json_pointer];\n                std::for_each(eos_token_id.begin(), eos_token_id.end(), [this](const auto token_id) {\n                    stop_words.emplace_back(1, token_id.template get<int32_t>());\n                });\n\n                SPDLOG_DEBUG(\"Detected {:d} predefined stop_words from generation_config.json\", stop_words.size());\n            }\n        }\n    };\n\n    /**\n     * Helper class representing various items which are stored within the TensorRT-LLM engines folder and\n     * can be retrieved at runtime\n     */\n    class backend_workspace_t {\n    private:\n        constexpr static auto as_json = [](const std::filesystem::path &path) -> json {\n            std::ifstream config_f(path);\n            return json::parse(config_f);\n        };\n\n        std::filesystem::path engines_folder_;\n        std::filesystem::path executor_worker_path_;\n        json config_;\n        generation_config_t generation_config_;\n\n    public:\n        backend_workspace_t(std::filesystem::path &engines_folder, std::filesystem::path &executor_worker_path) :\n                engines_folder_(engines_folder),\n                executor_worker_path_(executor_worker_path),\n                config_(as_json(engines_folder / \"config.json\")),\n                generation_config_(as_json(engines_folder / \"generation_config.json\")) {};\n\n        backend_workspace_t(std::filesystem::path &&engines_folder, std::filesystem::path &&executor_worker_path) :\n                engines_folder_(engines_folder),\n                executor_worker_path_(executor_worker_path),\n                config_(as_json(engines_folder / 
\"config.json\")),\n                generation_config_(as_json(engines_folder / \"generation_config.json\")) {};\n\n        /**\n         * Path to the folder containing the TensorRT-LLM engines\n         * @return local filesystem path to the folder\n         */\n        [[nodiscard]] constexpr std::filesystem::path engines_folder() const { return engines_folder_; }\n\n        /**\n         * Hugging Face transformers' generated `generation_config_t` mapping information stored in the\n         * `generation_config.json` holding default generation parameters.\n         * @return `generation_config_t`\n         */\n        [[nodiscard]] constexpr const generation_config_t &generation_config() const { return generation_config_; }\n\n        /**\n         * Factory method returning new `tensorrt_llm::executor::ParallelConfig` instance used\n         * to initialize `tensorrt_llm::executor::Executor` with multi-instance communication information\n         * @return `tensorrt_llm::executor::ParallelConfig` instance\n         */\n        [[nodiscard]] tle::ParallelConfig parallel_config() const;\n\n        /**\n         * Factory method returning new `tensorrt_llm::executor::ExecutorConfig` instance used\n         * to initialize `tensorrt_llm::executor::Executor`\n         * @return `tensorrt_llm::executor::ExecutorConfig` instance\n         */\n        [[nodiscard]] tle::ExecutorConfig executor_config() const;\n    };\n\n    /**\n     * Error raised by the underlying backend implementation\n     */\n    enum backend_error_t {\n        EXECUTOR_NOT_READY = 3,\n        EXECUTOR_SCHEDULING_FAILED = 4,\n    };\n\n\n    /**\n     * Actual TensorRT-LLM backend implementation interacting with TensorRT-LLM Executor service to\n     * - schedule new request\n     * - pull status of submitted request(s)\n     * - cancel submitted request(s)\n     */\n    class backend_t {\n    private:\n        backend_workspace_t workspace;\n        tle::Executor executor_;\n\n    public:\n     
   backend_t(std::filesystem::path &engines_folder, std::filesystem::path &executor_worker_path);\n\n        backend_t(std::filesystem::path &&engines_folder, std::filesystem::path &&executor_worker_path)\n                : backend_t(engines_folder, executor_worker_path) {};\n\n        /**\n         * Submit a new request to the executor\n         * @param token_ids\n         * @param generation_params\n         * @param sampling_params\n         * @return Either newly submitted request's id or the error why it failed to submit\n         */\n        [[nodiscard(\"Discarded executor request_id needs to be assigned\")]]\n        std::expected<request_id_t, backend_error_t>\n        submit(std::span<const token_id_t> token_ids, generation_params_t generation_params,\n               sampling_params_t sampling_params) noexcept;\n\n        /**\n         * Query the number of tokens available across all in-flight generations\n         * @return\n         */\n        [[nodiscard(\"Pulling out the number of tokens\")]]\n        size_t num_tokens_ready() const noexcept;\n\n        /**\n         * Pull out newly generated tokens from the executor\n         * @return\n         */\n        [[nodiscard(\"\")]]\n        std::vector<tle::Response> pull_tokens() noexcept;\n\n        /**\n         * Cancel the specified request on the executor' set\n         * @param request_id Request's Identifier to remove from the in-flight executor\n         */\n        void cancel(request_id_t) noexcept;\n    };\n\n    /**\n     * Create a TensorRT-LLM executor from a workspace\n     */\n    const auto executor_factory_initializer = [](const backend_workspace_t &workspace) -> tle::Executor {\n        return {workspace.engines_folder(), tensorrt_llm::executor::ModelType::kDECODER_ONLY,\n                workspace.executor_config()};\n    };\n}\n\n/**\n * Helper structures to define formatting strategies for various types in the backend\n */\ntemplate<>\nstruct 
fmt::formatter<huggingface::tgi::backends::trtllm::generation_params_t> : formatter<string_view> {\n    auto format(huggingface::tgi::backends::trtllm::generation_params_t const &c,\n                format_context &ctx) const -> format_context::iterator {\n        return fmt::format_to(ctx.out(), \"generation_params_t{{ max_new_tokens={:d} }}\", c.max_new_tokens);\n    }\n};\n\ntemplate<>\nstruct fmt::formatter<huggingface::tgi::backends::trtllm::sampling_params_t> : formatter<string_view> {\n    auto format(huggingface::tgi::backends::trtllm::sampling_params_t const &c,\n                format_context &ctx) const -> format_context::iterator {\n        return fmt::format_to(\n                ctx.out(),\n                \"sampling_params_t{{ top_k={:d}, top_p={:.3f}, repetition_penalty={:.3f}, frequency_penalty={:.3f}, temperature={:.3f}, seed={:d} }}\",\n                c.top_k, c.top_p, c.repetition_penalty, c.frequency_penalty, c.temperature, c.seed\n        );\n    }\n};\n\n#endif\n"
  },
  {
    "path": "backends/trtllm/csrc/ffi.hpp",
    "content": "#ifndef TGI_BACKEND_TRTLLM_FFI\n#define TGI_BACKEND_TRTLLM_FFI\n\n#include <memory>\n#include <thread>\n\n#include <nvml.h>\n#include <tensorrt_llm/common/tllmException.h>\n#include <tensorrt_llm/plugins/api/tllmPlugin.h>\n\n#include <spdlog/spdlog.h>\n\n#include <backend.hpp>\n#include <hardware.hpp>\n\nnamespace rust::behavior {\n    template<typename Try, typename Fail>\n    static void trycatch(Try &&func, Fail &&fail) noexcept try {\n        func();\n    } catch (tensorrt_llm::common::TllmException &e) {\n        fail(e.what());\n    }\n}\n\nnamespace huggingface::tgi::backends::trtllm {\n    class tensorrt_llm_backend_t;\n}\n\n#include \"backends/trtllm/src/lib.rs.h\"\n\n\nnamespace huggingface::tgi::backends::trtllm {\n    std::once_flag backend_initialized_flag;\n\n    constexpr finish_reason_t as_finish_reason_t(const tle::FinishReason reason) noexcept {\n        switch (reason) {\n            case tle::FinishReason::kNOT_FINISHED:\n                return finish_reason_t::kNOT_FINISHED;\n            case tle::FinishReason::kSTOP_WORDS:\n                return finish_reason_t::kSTOP_WORDS;\n            case tle::FinishReason::kEND_ID:\n                return finish_reason_t::kEND_ID;\n            case tle::FinishReason::kLENGTH:\n                return finish_reason_t::kLENGTH;\n            default:\n                std::unreachable();\n        }\n    }\n\n    static auto as_generation_step = [](const tle::Response &r) {\n        const auto reqId = r.getRequestId();\n        if (!r.hasError()) [[likely]] {\n            const auto result = r.getResult();\n            const auto logits = result.logProbs.value()[0];\n            return generation_step_t{\n                    reqId,\n                    static_cast<uint32_t>(result.outputTokenIds[0][0]),\n                    logits.back(),\n                    result.isFinal,\n                    as_finish_reason_t(result.finishReasons[0]),\n                    false,\n                    
std::string()\n            };\n        } else {\n            return generation_step_t{\n                    reqId,\n                    0,\n                    0.0,\n                    true,\n                    finish_reason_t::kNOT_FINISHED,\n                    true,\n                    std::move(r.getErrorMsg())\n            };\n        }\n    };\n\n\n    class tensorrt_llm_backend_t {\n    private:\n        backend_t inner_;\n\n    public:\n        tensorrt_llm_backend_t(std::filesystem::path &&engine_folder, std::filesystem::path &&executor_worker_path)\n                : inner_(engine_folder, executor_worker_path) {}\n\n        size_t num_tokens_ready() const noexcept { return inner_.num_tokens_ready(); }\n\n        request_id_t submit(\n                rust::Slice<const uint32_t> tokens,\n                uint32_t max_new_tokens,\n                uint32_t top_k,\n                float_t top_p,\n                float_t temperature,\n                float_t repetition_penalty,\n                float_t frequency_penalty,\n                uint64_t seed\n        ) {\n            // This is enabled only if using add_compile_definitions(SPDLOG_ACTIVE_LEVEL=SPDLOG_LEVEL_TRACE)\n            SPDLOG_TRACE(FMT_STRING(\"[FFI] Submitting {:d} prompt tokens to the executor\"));\n\n            // Submit the request to the executor and get back a potential request_id used to track request status\n            const auto signed_tokens = std::vector<int32_t>(tokens.begin(), tokens.end());\n            const auto maybe_request_id = inner_.submit(\n                    signed_tokens,\n                    {max_new_tokens},\n                    {top_k, top_p, repetition_penalty, frequency_penalty, temperature, seed}\n            );\n\n            // If we do have a value, let's return the request_id\n            if (maybe_request_id.has_value()) [[likely]] {\n                return *maybe_request_id;\n            } else {\n                SPDLOG_WARN(\"[FFI] Failed to submit 
request to the executor\");\n                return maybe_request_id.error();\n            }\n        }\n\n        std::unique_ptr<std::vector<generation_step_t>> pull_tokens() noexcept {\n            if (num_tokens_ready() > 0) [[likely]] {\n                const auto responses = inner_.pull_tokens();\n\n                SPDLOG_TRACE(\"[FFI] Successfully pulled out {:d} responses from executor\", responses.size());\n\n                // Transform tle::Response to generation_step_t\n#ifdef __cpp_lib_ranges_to_container\n                auto steps = responses | std::views::transform(as_generation_step) | std::ranges::to<std::vector>();\n#else\n                auto steps = std::vector<generation_step_t>();\n                steps.reserve(responses.size());\n                std::transform(responses.begin(), responses.end(), std::back_inserter(steps), as_generation_step);\n#endif\n                return std::make_unique<std::vector<generation_step_t>>(steps);\n\n            } else {\n                return std::make_unique<std::vector<generation_step_t>>();\n            }\n        }\n\n        void cancel(request_id_t request_id) noexcept {\n            SPDLOG_DEBUG(\"[FFI] cancelling request {:d}\", request_id);\n            inner_.cancel(request_id);\n        }\n    };\n\n    void initialize_logging() {\n#ifndef TGI_TRTLLM_BACKEND_DEBUG\n        if (const auto TRTLLM_LOG_LEVEL_CSTR = std::getenv(\"TRTLLM_LOG_LEVEL\")) {\n            std::string log_level(TRTLLM_LOG_LEVEL_CSTR);\n            std::transform(log_level.begin(), log_level.end(), log_level.begin(), [](unsigned char c) {\n                return std::tolower(c);\n            });\n\n            if (log_level == \"debug\")\n                spdlog::set_level(spdlog::level::debug);\n            else\n                spdlog::set_level(spdlog::level::info);\n        }\n#else\n        spdlog::set_level(spdlog::level::debug);\n#endif\n    }\n\n    void initialize_tensorrt_llm_backend() {\n        
SPDLOG_INFO(\"Initializing TGI - TensorRT-LLM Backend (v{})\", tle::version());\n\n        // Initialize everyone\n        initialize_logging();\n        nvmlInit_v2();\n        initTrtLlmPlugins();\n\n        const auto numGpus = huggingface::tgi::hardware::cuda::get_device_count();\n        if (numGpus.has_value()) {\n            SPDLOG_INFO(\"[FFI] Detected {:d} Nvidia GPU(s)\", *numGpus);\n        } else {\n            SPDLOG_WARN(\"[FFI] Failed to detect Nvidia GPU(s) on the system\");\n            // todo: throw\n        }\n    }\n\n    std::unique_ptr<tensorrt_llm_backend_t>\n    create_backend_from_engine_folder(const rust::Str engines_folder, const rust::Str executor_worker_path) {\n        std::call_once(backend_initialized_flag, initialize_tensorrt_llm_backend);\n        return std::make_unique<tensorrt_llm_backend_t>(\n                std::filesystem::path(std::string_view(engines_folder.begin(), engines_folder.end()),\n                                      std::filesystem::path::format::auto_format),\n                std::filesystem::path(std::string_view(executor_worker_path.begin(), executor_worker_path.end()),\n                                      std::filesystem::path::format::auto_format)\n        );\n    }\n}\n#endif\n"
  },
  {
    "path": "backends/trtllm/csrc/hardware.hpp",
    "content": "#ifndef TGI_HARDWARE_CUDA\n#define TGI_HARDWARE_CUDA\n#include <cstdint>\n#include <optional>\n\n#include <nvml.h>\n\nnamespace huggingface::tgi::hardware::cuda {\n    static constexpr auto VOLTA = std::make_tuple(7u, 0u);\n    static constexpr auto TURING = std::make_tuple(7u, 5u);\n    static constexpr auto AMPERE = std::make_tuple(8u, 0u);\n    static constexpr auto HOPPER = std::make_tuple(9u, 0u);\n    static constexpr auto ADA_LOVELACE = std::make_tuple(8u, 9u);\n\n    /**\n     * Get the number of GPUs on the local machine\n     * @return std::nullopt if no device is available, otherwise >= 1\n     */\n    inline std::optional<size_t> get_device_count() {\n        uint32_t numGpus = 0;\n        if (nvmlDeviceGetCount_v2(&numGpus) == NVML_SUCCESS) {\n            return numGpus;\n        }\n        return std::nullopt;\n    }\n\n    /**\n     * Store information about the version of the CUDA Compute Capabilities detected on the device\n     */\n    struct compute_capabilities_t {\n        int32_t major;\n        int32_t minor;\n\n        compute_capabilities_t(): compute_capabilities_t(0) {}\n        explicit compute_capabilities_t(size_t device_idx): major(-1), minor(-1) {\n            nvmlDevice_t device;\n            if (nvmlDeviceGetHandleByIndex_v2(device_idx, &device) == NVML_SUCCESS) {\n               nvmlDeviceGetCudaComputeCapability(device, &major, &minor);\n            }\n        };\n        compute_capabilities_t(int32_t major, int32_t minor): major(major), minor(minor) {}\n\n        /**\n         * Evaluate if the underlying capabilities is at least greater or equals to the provided 2-tuple (major, minor)\n         * @param sm Architecture version (major, minor)\n         * @return True if greater or equals to the underlying compute capabilities\n         */\n        [[nodiscard]] constexpr auto is_at_least(std::tuple<uint32_t, uint32_t> sm) const -> decltype(auto) { return std::tie(major, minor) >= sm; }\n\n        /**\n         * 
Check if the capabilities match at least Volta architecture (sm_70)\n         * @return true if at least Volta (>= sm_70), false otherwise\n         */\n        [[nodiscard]] constexpr bool is_at_least_volta() const { return is_at_least(VOLTA); }\n\n        /**\n         * Check if the capabilities match at least Turing architecture (sm_75)\n         * @return true if at least Turing (>= sm_75), false otherwise\n         */\n        [[nodiscard]] constexpr bool is_at_least_turing() const { return is_at_least(TURING); }\n\n        /**\n         * Check if the capabilities match at least Ampere architecture (sm_80)\n         * @return true if at least Ampere (>= sm_80), false otherwise\n         */\n        [[nodiscard]] constexpr bool is_at_least_ampere() const { return is_at_least(AMPERE); }\n\n        /**\n         * Check if the capabilities match at least Ada Lovelace architecture (sm_89)\n         * @return true if at least Ada Lovelace (>= sm_89), false otherwise\n         */\n        [[nodiscard]] constexpr bool is_at_least_ada_lovelace() const { return is_at_least(ADA_LOVELACE); }\n\n        /**\n         * Check if the capabilities match at least Hopper architecture (sm_90)\n         * @return true if at least Hopper (>= sm_90), false otherwise\n         */\n        [[nodiscard]] constexpr bool is_at_least_hopper() const { return is_at_least(HOPPER); }\n    };\n}\n#endif\n"
  },
  {
    "path": "backends/trtllm/scripts/install_tensorrt.sh",
    "content": "#!/bin/bash\n\nset -ex\n\nTRT_VER_BASE=\"10.8.0\"\nTRT_VER_FULL=\"${TRT_VER_BASE}.43\"\nCUDA_VER=\"12.8\"\nCUDNN_VER=\"9.7.0.66-1\"\nNCCL_VER=\"2.25.1-1+cuda${CUDA_VER}\"\nCUBLAS_VER=\"${CUDA_VER}.3.14-1\"\nNVRTC_VER=\"${CUDA_VER}.61-1\"\n\nfor i in \"$@\"; do\n    case $i in\n        --TRT_VER=?*) TRT_VER=\"${i#*=}\";;\n        --CUDA_VER=?*) CUDA_VER=\"${i#*=}\";;\n        --CUDNN_VER=?*) CUDNN_VER=\"${i#*=}\";;\n        --NCCL_VER=?*) NCCL_VER=\"${i#*=}\";;\n        --CUBLAS_VER=?*) CUBLAS_VER=\"${i#*=}\";;\n        *) ;;\n    esac\n    shift\ndone\n\nNVCC_VERSION_OUTPUT=$(nvcc --version)\nif [[ $(echo $NVCC_VERSION_OUTPUT | grep -oP \"\\d+\\.\\d+\" | head -n 1) != ${CUDA_VER} ]]; then\n  echo \"The version of pre-installed CUDA is not equal to ${CUDA_VER}.\"\n  exit 1\nfi\n\ninstall_ubuntu_requirements() {\n    apt-get update && apt-get install -y --no-install-recommends gnupg2 curl ca-certificates\n    ARCH=$(uname -m)\n    if [ \"$ARCH\" = \"amd64\" ];then ARCH=\"x86_64\";fi\n    if [ \"$ARCH\" = \"aarch64\" ];then ARCH=\"sbsa\";fi\n    curl -fsSLO https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/${ARCH}/cuda-keyring_1.1-1_all.deb\n    dpkg -i cuda-keyring_1.1-1_all.deb\n    rm /etc/apt/sources.list.d/cuda-ubuntu2404-x86_64.list\n\n    apt-get update\n    if [[ $(apt list --installed | grep libcudnn9) ]]; then\n      apt-get remove --purge -y --allow-change-held-packages libcudnn9*\n    fi\n    if [[ $(apt list --installed | grep libnccl) ]]; then\n      apt-get remove --purge -y --allow-change-held-packages libnccl*\n    fi\n    if [[ $(apt list --installed | grep libcublas) ]]; then\n      apt-get remove --purge -y --allow-change-held-packages libcublas*\n    fi\n    if [[ $(apt list --installed | grep cuda-nvrtc-dev) ]]; then\n      apt-get remove --purge -y --allow-change-held-packages cuda-nvrtc-dev*\n    fi\n    CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\\./-/g')\n    apt-get install -y --no-install-recommends 
libcudnn9-cuda-12=${CUDNN_VER} libcudnn9-dev-cuda-12=${CUDNN_VER}\n    apt-get install -y --no-install-recommends libnccl2=${NCCL_VER} libnccl-dev=${NCCL_VER}\n    apt-get install -y --no-install-recommends libcublas-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER} libcublas-dev-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER}\n    # NVRTC static library doesn't exist in NGC PyTorch container.\n    NVRTC_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\\./-/g')\n    apt-get install -y --no-install-recommends cuda-nvrtc-dev-${NVRTC_CUDA_VERSION}=${NVRTC_VER}\n    apt-get clean\n    rm -rf /var/lib/apt/lists/*\n}\n\ninstall_centos_requirements() {\n    CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\\./-/g')\n    yum -y update\n    yum -y install epel-release\n    yum remove -y libnccl* && yum -y install libnccl-${NCCL_VER} libnccl-devel-${NCCL_VER}\n    yum remove -y libcublas* && yum -y install libcublas-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER} libcublas-devel-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}\n    yum clean all\n}\n\ninstall_tensorrt() {\n    #PY_VERSION=$(python3 -c 'import sys; print(\".\".join(map(str, sys.version_info[0:2])))')\n    #PARSED_PY_VERSION=$(echo \"${PY_VERSION//./}\")\n    TRT_CUDA_VERSION=\"12.8\"\n\n    if [ -z \"$RELEASE_URL_TRT\" ];then\n        ARCH=${TRT_TARGETARCH}\n        if [ -z \"$ARCH\" ];then ARCH=$(uname -m);fi\n        if [ \"$ARCH\" = \"arm64\" ];then ARCH=\"aarch64\";fi\n        if [ \"$ARCH\" = \"amd64\" ];then ARCH=\"x86_64\";fi\n        if [ \"$ARCH\" = \"x86_64\" ];then DIR_NAME=\"x64-agnostic\"; else DIR_NAME=${ARCH};fi\n        if [ \"$ARCH\" = \"aarch64\" ];then OS1=\"Ubuntu22_04\" && OS2=\"Ubuntu-24.04\" && OS=\"ubuntu-24.04\"; else OS1=\"Linux\" && OS2=\"Linux\" && OS=\"linux\";fi\n        RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/${TRT_VER_BASE}/tars/TensorRT-${TRT_VER_FULL}.${OS2}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz\n    fi\n    wget --no-verbose ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar\n    tar -xf 
/tmp/TensorRT.tar -C /usr/local/\n    mv /usr/local/TensorRT-${TRT_VER_FULL} /usr/local/tensorrt\n    # pip3 install /usr/local/tensorrt/python/tensorrt-*-cp${PARSED_PY_VERSION}-*.whl\n    rm -rf /tmp/TensorRT.tar\n}\n\n# Install base packages depending on the base OS\nID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '\"')\ncase \"$ID\" in\n  debian)\n    install_ubuntu_requirements\n    install_tensorrt\n    ;;\n  ubuntu)\n    install_ubuntu_requirements\n    install_tensorrt\n    ;;\n  centos)\n    install_centos_requirements\n    install_tensorrt\n    ;;\n  *)\n    echo \"Unable to determine OS...\"\n    exit 1\n    ;;\nesac\n"
  },
  {
    "path": "backends/trtllm/scripts/setup_sccache.py",
    "content": "from argparse import ArgumentParser\n\nAWS_S3_CACHING_VARIABLES = {\n    \"AWS_ACCESS_KEY_ID\": \"aws_access_key_id\",\n    \"AWS_SECRET_ACCESS_KEY\": \"aws_secret_access_key\",\n    \"AWS_SESSION_TOKEN\": \"aws_session_token\",\n    \"SCCACHE_REGION\": \"s3_region\",\n    \"SCCACHE_BUCKET\": \"s3_bucket_name\",\n}\n\nALL_CACHING_STORAGE_VARIABLES = {\"AWS_S3_CACHING_VARIABLES\"}\n\n\ndef setup_sccache_locally():\n    from os import environ\n\n    print(\"Setting up Local Caching Layer\")\n    for target in ALL_CACHING_STORAGE_VARIABLES:\n        for envvar in globals()[target].keys():\n            if envvar in environ:\n                print(f\"Deleted {envvar} from environment variables\")\n                del environ[envvar]\n\n\ndef setup_sccache_for_s3():\n    from os import environ\n\n    print(\"Setting up AWS S3 Caching Layer\")\n    for envvar in AWS_S3_CACHING_VARIABLES.keys():\n        if envvar not in environ or not environ[envvar] or len(environ[envvar]) == 0:\n            print(f\"Missing definition for environment variable {envvar}\")\n\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser(\"TensorRT-LLM Build Caching Setup\")\n\n    parser.add_argument(\n        \"--is-gha-build\",\n        type=str,\n        default=\"FALSE\",\n        help=\"Indicate if the build is from Github Actions\",\n    )\n\n    # Parse args\n    args = parser.parse_args()\n    args.is_gha_build = args.is_gha_build.lower() in {\"on\", \"true\", \"1\"}\n\n    if args.is_gha_build:\n        setup_sccache_for_s3()\n    else:\n        setup_sccache_locally()\n"
  },
  {
    "path": "backends/trtllm/src/errors.rs",
    "content": "use std::path::PathBuf;\nuse thiserror::Error;\n\nuse text_generation_router::server;\n\n#[derive(Debug, Error)]\npub enum TensorRtLlmBackendError {\n    #[error(\"Provided engine folder {0} doesn't exist\")]\n    EngineFolderDoesntExists(PathBuf),\n    #[error(\"Provided executorWorker binary path {0} doesn't exist\")]\n    ExecutorWorkerNotFound(PathBuf),\n    #[error(\"TensorRT-LLM Runtime error: {0}\")]\n    Runtime(String),\n    #[error(\"Tokenizer error: {0}\")]\n    Tokenizer(String),\n    #[error(\"Argument validation error: {0}\")]\n    ArgumentValidation(String),\n    #[error(\"WebServer error: {0}\")]\n    WebServer(#[from] server::WebServerError),\n    #[error(\"Tokio runtime failed to start: {0}\")]\n    Tokio(#[from] std::io::Error),\n}\n"
  },
  {
    "path": "backends/trtllm/src/lib.rs",
    "content": "pub use looper::TensorRtLlmBackendV2;\n\npub mod errors;\nmod looper;\nmod utils;\n\n#[cxx::bridge(namespace = \"huggingface::tgi::backends::trtllm\")]\nmod ffi {\n    #[cxx_name = \"finish_reason_t\"]\n    #[derive(Debug, Clone, Copy)]\n    pub enum FinishReason {\n        /// The request is not finished.\n        #[cxx_name = \"kNOT_FINISHED\"]\n        NotFinished = 0u8,\n\n        /// The request finished because the end id was generated.\n        #[cxx_name = \"kEND_ID\"]\n        EndTokenId = 1u8,\n\n        /// The request finished because a stop word was generated.\n        #[cxx_name = \"kSTOP_WORDS\"]\n        StopWords = 2u8,\n\n        /// The request finished because the maximum number of tokens was reached.\n        #[cxx_name = \"kLENGTH\"]\n        MaxLength = 3u8,\n    }\n\n    /// Struct used as shared type between rust and C++ to represent the result\n    /// of a single decoding iteration\n    #[cxx_name = \"generation_step_t\"]\n    #[derive(Debug, Clone)]\n    pub struct GenerationStep {\n        request_id: u64,\n        token_id: u32,\n        log_prob: f32,\n        is_final: bool,\n        finish_reason: FinishReason,\n        has_error: bool,\n        error_msg: String,\n    }\n\n    unsafe extern \"C++\" {\n        include!(\"backends/trtllm/csrc/ffi.hpp\");\n\n        /// Represent an instance of the underlying TensorRT-LLM backend\n        #[cxx_name = \"tensorrt_llm_backend_t\"]\n        type TensorRtLlmBackendImpl;\n\n        /// Create an instance backed behind a std::unique_ptr to manage the lifespan of the backend\n        ///\n        /// # Arguments\n        ///\n        /// * `engine_folder`: Path to the folder containing all the TRTLLM engines\n        /// * `executor_worker`: Path to the TRTLLM executor worker\n        ///\n        /// returns: <unknown>\n        ///\n        /// # Examples\n        ///\n        /// ```\n        ///\n        /// ```\n        fn create_backend_from_engine_folder(\n            
engine_folder: &str,\n            executor_worker: &str,\n        ) -> Result<UniquePtr<TensorRtLlmBackendImpl>>;\n\n        fn num_tokens_ready(self: &TensorRtLlmBackendImpl) -> usize;\n\n        fn submit(\n            self: Pin<&mut TensorRtLlmBackendImpl>,\n            tokens: &[u32],\n            max_new_tokens: u32,\n            top_k: u32,\n            top_p: f32,\n            temperature: f32,\n            repetition_penalty: f32,\n            frequency_penalty: f32,\n            seed: u64,\n        ) -> Result<u64>;\n\n        fn pull_tokens(\n            self: Pin<&mut TensorRtLlmBackendImpl>,\n        ) -> Result<UniquePtr<CxxVector<GenerationStep>>>;\n\n        fn cancel(self: Pin<&mut TensorRtLlmBackendImpl>, request_id: u64);\n    }\n}\n\nuse ffi::FinishReason;\nuse text_generation_router::FinishReason as InferFinishReason;\n\nimpl From<FinishReason> for InferFinishReason {\n    fn from(reason: FinishReason) -> Self {\n        match reason {\n            FinishReason::StopWords => InferFinishReason::StopSequence,\n            FinishReason::MaxLength => InferFinishReason::Length,\n            FinishReason::EndTokenId => InferFinishReason::EndOfSequenceToken,\n            _ => panic!(\"Cannot convert {reason:?} to text_generation_router::FinishReason\"),\n        }\n    }\n}\n"
  },
  {
    "path": "backends/trtllm/src/looper.rs",
    "content": "use async_trait::async_trait;\nuse cxx::UniquePtr;\nuse hashbrown::HashMap;\nuse std::hint;\nuse std::ops::Deref;\nuse std::path::Path;\nuse tokenizers::Tokenizer;\nuse tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};\nuse tokio::sync::TryAcquireError;\nuse tokio::task::spawn_blocking;\nuse tokio::time::Instant;\nuse tokio_stream::wrappers::UnboundedReceiverStream;\nuse tracing::{debug, error, warn};\n\nuse text_generation_router::infer::InferError::{GenerationError, ValidationError};\nuse text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};\nuse text_generation_router::validation::ValidationError::{\n    EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality,\n};\nuse text_generation_router::validation::{Chunk, ValidGenerateRequest};\nuse text_generation_router::Token;\n\nuse crate::errors::TensorRtLlmBackendError;\nuse crate::ffi::{\n    create_backend_from_engine_folder, FinishReason, GenerationStep, TensorRtLlmBackendImpl,\n};\nuse crate::utils::first_line;\n\ntype InferResult<T> = Result<T, InferError>;\n\n/// Wrap the requests along with the channel used to stream back to the client the decoded tokens\nstruct GenerationContext {\n    request: ValidGenerateRequest,\n    streamer: UnboundedSender<InferResult<InferStreamResponse>>,\n    tokens: Vec<u32>,\n    start: Option<Instant>,\n    queued: Instant,\n}\n\n#[derive(Debug, Copy, Clone)]\nstruct DecodedToken {\n    id: u32,\n    log_prob: f32,\n    is_final: bool,\n    finish_reason: FinishReason,\n}\n\nimpl<'step> TryFrom<&'step GenerationStep> for DecodedToken {\n    type Error = InferError;\n\n    fn try_from(step: &'step GenerationStep) -> Result<Self, Self::Error> {\n        if !step.has_error {\n            Ok(Self {\n                id: step.token_id,\n                log_prob: step.log_prob,\n                is_final: step.is_final,\n                finish_reason: step.finish_reason,\n            })\n        } else 
{\n            Err(GenerationError(step.error_msg.clone()))\n        }\n    }\n}\n\nfn executor_status_looper(\n    max_inflight_requests: usize,\n    tokenizer: Tokenizer,\n    mut backend: UniquePtr<TensorRtLlmBackendImpl>,\n    mut backlog: UnboundedReceiver<GenerationContext>,\n) {\n    // Track the tuple (request_id, stream) for each request\n    let mut in_flights =\n        HashMap::<u64, GenerationContext>::with_capacity(max_inflight_requests * 2);\n\n    'scheduler: loop {\n        // Is there any request pending to be scheduled?\n        let awaiting_requests = backlog.len();\n        for _ in 0..awaiting_requests {\n            // Retrieve all the requests\n            if let Some(ctx) = backlog.blocking_recv() {\n                // Submit all the request to the executor and move the context to the in-flight tracker\n                let request = &ctx.request;\n                let generation_params = &request.parameters;\n                let stopping_params = &request.stopping_parameters;\n                let input_ids = request.input_ids.as_deref();\n\n                // Submit to the TensorRT-LLM executor for scheduling\n                match backend.pin_mut().submit(\n                    &input_ids.unwrap(), // This is checked beforehand in validate()\n                    stopping_params.max_new_tokens,\n                    generation_params.top_k,\n                    generation_params.top_p,\n                    generation_params.temperature,\n                    generation_params.repetition_penalty,\n                    generation_params.frequency_penalty,\n                    generation_params.seed,\n                ) {\n                    Ok(request_id) => {\n                        // Insert the context linked to the generated request id in the tracker\n                        debug!(\"[in-flight] Added {}\", request_id);\n                        in_flights.insert(request_id, ctx);\n                    }\n                    Err(e) => {\n       
                 // Return to the caller\n                        let what = e.to_string();\n                        error!(error = what.as_str(), \"Failed to schedule request\");\n\n                        let err = Err(InferError::Overloaded(TryAcquireError::NoPermits));\n                        if let Err(_) = ctx.streamer.send(err) {\n                            error!(\"Failed to send back error to the client\");\n                        }\n                    }\n                };\n            } else {\n                break 'scheduler;\n            }\n        }\n\n        if backend.num_tokens_ready() > 0 {\n            let mut backend = backend.pin_mut();\n            match backend.as_mut().pull_tokens() {\n                Ok(responses) => {\n                    // Iterate through all the decoded token\n                    for step in responses.deref() {\n                        if let Some(ctx) = in_flights.get_mut(&step.request_id) {\n                            // Update the starting timestamp if not set\n                            // This value might not be the actual real starting time of the request\n                            // on the executor side - Need to expose more info from the executor to\n                            // retrieve this value\n                            // TODO : Expose actual real starting time for a request on FFI layer\n                            if ctx.start.is_none() {\n                                ctx.start = Some(Instant::now());\n                            }\n\n                            // Try to map the generation step to a DecodedToken\n                            let response = match DecodedToken::try_from(step) {\n                                Ok(decoded_token) => {\n                                    post_process_decoded_token(&tokenizer, ctx, decoded_token)\n                                }\n                                Err(err) => Err(err),\n                            };\n\n                       
     // Attempt to send back the response to the client\n                            if let Err(_) = ctx.streamer.send(response) {\n                                // Client has dropped, remove from tracked requests\n                                debug!(\n                                    \"Client dropped - removing request {} from tracked requests\",\n                                    step.request_id\n                                );\n                                backend.as_mut().cancel(step.request_id);\n                                let _ = in_flights.remove(&step.request_id);\n                            }\n                        } else {\n                            warn!(\"Untracked request {}\", step.request_id,);\n                        }\n                    }\n                }\n                Err(ref err) => {\n                    error!(\"Failed to get responses from the executor: {}.\", err.what());\n                    break 'scheduler;\n                }\n            }\n        }\n\n        // Hint the CPU we are spin-locking\n        hint::spin_loop();\n    }\n}\n\nfn post_process_decoded_token(\n    tokenizer: &Tokenizer,\n    ctx: &mut GenerationContext,\n    decoded_token: DecodedToken,\n) -> InferResult<InferStreamResponse> {\n    match tokenizer.decode(&[decoded_token.id], false) {\n        Ok(text) => {\n            let is_special = tokenizer.get_added_vocabulary().is_special_token(&text);\n            let token = Token {\n                id: decoded_token.id,\n                text,\n                logprob: decoded_token.log_prob,\n                special: is_special,\n            };\n\n            // Append the token to the tracked generated tokens\n            ctx.tokens.push(token.id);\n\n            // Map the correct response depending on the step is final or not\n            let out = if !decoded_token.is_final {\n                InferStreamResponse::Intermediate {\n                    token,\n                    
top_tokens: vec![],\n                }\n            } else {\n                let text = tokenizer.decode(&ctx.tokens, true);\n                let generated_text = GeneratedText {\n                    text: text.unwrap(),\n                    generated_tokens: ctx.tokens.len() as u32,\n                    finish_reason: decoded_token.finish_reason.into(),\n                    seed: None,\n                };\n\n                InferStreamResponse::End {\n                    token,\n                    top_tokens: vec![],\n                    generated_text,\n                    start: ctx.start.unwrap(),\n                    queued: ctx.queued,\n                }\n            };\n\n            Ok(out)\n        }\n        Err(err) => Err(GenerationError(err.to_string())),\n    }\n}\n\nfn ensure_paths_exist<P: AsRef<Path>, PP: AsRef<Path>>(\n    engine_folder: P,\n    executor_worker_path: PP,\n) -> Result<(String, String), TensorRtLlmBackendError> {\n    // Retrieve paths as &str for the backend creation\n    let engine_folder = engine_folder.as_ref();\n    let executor_worker_path = executor_worker_path.as_ref();\n\n    // Ensure the engine folder exists\n    if !engine_folder.exists() {\n        let err = TensorRtLlmBackendError::EngineFolderDoesntExists(engine_folder.to_path_buf());\n\n        error!(\"Path validation failed: {}\", err,);\n        return Err(err);\n    }\n\n    // Ensure executor worker binary exists\n    if !executor_worker_path.exists() {\n        let err = TensorRtLlmBackendError::ExecutorWorkerNotFound(engine_folder.to_path_buf());\n\n        error!(\"Path validation failed: {}\", err,);\n        return Err(err);\n    }\n\n    let engine_folder = String::from(\n        engine_folder\n            .to_str()\n            .expect(\"Failed to convert engine_folder to valid UTF-8\"),\n    );\n\n    let executor_worker_path = String::from(\n        executor_worker_path\n            .to_str()\n            .expect(\"Failed to convert 
executor_worker_path to valid UTF-8\"),\n    );\n\n    Ok((engine_folder, executor_worker_path))\n}\n\nunsafe impl Send for TensorRtLlmBackendImpl {}\n\npub struct TensorRtLlmBackendV2(UnboundedSender<GenerationContext>);\n\nimpl TensorRtLlmBackendV2 {\n    pub fn new<P: AsRef<Path> + Send, PP: AsRef<Path> + Send>(\n        tokenizer: Tokenizer,\n        engine_folder: P,\n        executor_worker_path: PP,\n        max_inflight_requests: usize,\n    ) -> Result<Self, TensorRtLlmBackendError> {\n        let (engine_folder, executor_worker_path) =\n            ensure_paths_exist(engine_folder, executor_worker_path)?;\n\n        // Allocate the IPC layer to communicate with the backend\n        let (executor_sender, executor_receiver) = unbounded_channel();\n\n        // Create the FFI backend\n        let backend = create_backend_from_engine_folder(&engine_folder, &executor_worker_path)\n            .map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), \"Unknown error\")))?;\n\n        // Executor looper is responsible for scheduling and pulling requests state at regular interval\n        spawn_blocking(move || {\n            executor_status_looper(max_inflight_requests, tokenizer, backend, executor_receiver)\n        });\n\n        Ok(TensorRtLlmBackendV2(executor_sender))\n    }\n\n    fn validate(request: &ValidGenerateRequest) -> InferResult<()> {\n        if request.input_ids.is_none() {\n            return Err(ValidationError(UnsupportedModality(\"No token provided\")));\n        }\n\n        if request.top_n_tokens > 1 {\n            return Err(ValidationError(TopNTokensDisabled));\n        }\n\n        // TODO: Is it really needed? How can it be validated before?\n        if request.parameters.grammar.is_some() {\n            return Err(ValidationError(Grammar));\n        }\n\n        match request.inputs.len() {\n            0 => Err(ValidationError(EmptyInput)),\n            2.. 
=> Err(GenerationError(\n                \"TensorRT-LLM backend don't support multi-chunk\".into(),\n            )),\n            1 => match request.inputs.first().expect(\"Single item-chunk\") {\n                Chunk::Text(_) => Ok(()),\n                Chunk::Image(_) => Err(ValidationError(UnsupportedModality(\"image\"))),\n            },\n        }\n    }\n}\n\n#[async_trait]\nimpl Backend for TensorRtLlmBackendV2 {\n    fn schedule(\n        &self,\n        request: ValidGenerateRequest,\n    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {\n        Self::validate(&request)?;\n\n        // Open-up the stream to send tokens\n        let (streamer, receiver) = unbounded_channel::<InferResult<InferStreamResponse>>();\n\n        // Send the context to the executor for scheduling\n        let queued = Instant::now();\n        match self.0.send(GenerationContext {\n            request,\n            streamer,\n            tokens: Vec::with_capacity(256),\n            start: None,\n            queued,\n        }) {\n            Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),\n            Err(_) => Err(GenerationError(\n                \"Failed to submit request to the backend\".into(),\n            )),\n        }\n    }\n\n    async fn health(&self, _: bool) -> bool {\n        true\n    }\n\n    fn name(&self) -> &'static str {\n        \"TensorRT-LLM\"\n    }\n}\n"
  },
  {
    "path": "backends/trtllm/src/main.rs",
    "content": "use std::path::{Path, PathBuf};\n\nuse clap::Parser;\nuse hf_hub::api::tokio::{Api, ApiBuilder};\nuse hf_hub::{Cache, Repo, RepoType};\nuse tracing::info;\n\nuse text_generation_backends_trtllm::errors::TensorRtLlmBackendError;\nuse text_generation_backends_trtllm::TensorRtLlmBackendV2;\nuse text_generation_router::server::{\n    get_hub_model_info, legacy_tokenizer_handle, py_resolve_tokenizer,\n};\nuse text_generation_router::usage_stats::UsageStatsLevel;\nuse text_generation_router::{server, Tokenizer};\n\n/// App Configuration\n#[derive(Parser, Debug)]\n#[clap(author, version, about, long_about = None)]\nstruct Args {\n    #[clap(default_value = \"128\", long, env)]\n    max_concurrent_requests: usize,\n    #[clap(default_value = \"2\", long, env)]\n    max_best_of: usize,\n    #[clap(default_value = \"4\", long, env)]\n    max_stop_sequences: usize,\n    #[clap(default_value = \"5\", long, env)]\n    max_top_n_tokens: u32,\n    #[clap(default_value = \"1024\", long, env)]\n    max_input_tokens: usize,\n    #[clap(default_value = \"2048\", long, env)]\n    max_total_tokens: usize,\n    #[clap(default_value = \"4096\", long, env)]\n    max_batch_prefill_tokens: u32,\n    #[clap(long, env)]\n    max_batch_total_tokens: Option<u32>,\n    #[clap(default_value = \"0.0.0.0\", long, env)]\n    hostname: String,\n    #[clap(default_value = \"3000\", long, short, env)]\n    port: u16,\n    #[clap(default_value = \"9000\", long, short, env)]\n    prometheus_port: u16,\n    #[clap(long, env, required = true)]\n    tokenizer_name: String,\n    #[clap(long, env)]\n    tokenizer_config_path: Option<String>,\n    #[clap(long, env)]\n    revision: Option<String>,\n    #[clap(long, env)]\n    model_id: String,\n    #[clap(default_value = \"2\", long, env)]\n    validation_workers: usize,\n    #[clap(long, env)]\n    json_output: bool,\n    #[clap(long, env)]\n    otlp_endpoint: Option<String>,\n    #[clap(default_value = \"text-generation-inference.router\", 
long, env)]\n    otlp_service_name: String,\n    #[clap(long, env)]\n    cors_allow_origin: Option<Vec<String>>,\n    #[clap(default_value = \"4\", long, env)]\n    max_client_batch_size: usize,\n    #[clap(long, env)]\n    auth_token: Option<String>,\n    #[clap(long, env, help = \"Path to the TensorRT-LLM Orchestrator worker\")]\n    executor_worker: PathBuf,\n    #[clap(default_value = \"on\", long, env)]\n    usage_stats: UsageStatsLevel,\n    #[clap(default_value = \"2000000\", long, env)]\n    payload_limit: usize,\n    #[clap(default_value = \"1073741824\", long, env)]\n    max_image_fetch_size: usize,\n}\n\nasync fn get_tokenizer(tokenizer_name: &str, revision: Option<&str>) -> Option<Tokenizer> {\n    // Parse Huggingface hub token\n    let authorization_token = std::env::var(\"HF_TOKEN\")\n        .or_else(|_| std::env::var(\"HUGGING_FACE_HUB_TOKEN\"))\n        .ok();\n\n    // Tokenizer instance\n    let local_path = Path::new(tokenizer_name);\n\n    // Shared API builder initialization\n    let api_builder = || {\n        let mut builder = ApiBuilder::new()\n            .with_progress(false)\n            .with_token(authorization_token);\n\n        if let Ok(cache_dir) = std::env::var(\"HUGGINGFACE_HUB_CACHE\") {\n            builder = builder.with_cache_dir(cache_dir.into());\n        }\n\n        if let Ok(origin) = std::env::var(\"HF_HUB_USER_AGENT_ORIGIN\") {\n            builder = builder.with_user_agent(\"origin\", origin.as_str());\n        }\n\n        builder\n    };\n\n    // Decide if we need to use the API based on the revision and local path\n    let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();\n\n    // Initialize API if needed\n    #[derive(Clone)]\n    enum Type {\n        Api(Api),\n        Cache(Cache),\n        None,\n    }\n    let api = if use_api {\n        if std::env::var(\"HF_HUB_OFFLINE\") == Ok(\"1\".to_string()) {\n            let cache = std::env::var(\"HUGGINGFACE_HUB_CACHE\")\n              
  .map_err(|_| ())\n                .map(|cache_dir| Cache::new(cache_dir.into()))\n                .unwrap_or_else(|_| Cache::default());\n            tracing::warn!(\"Offline mode active using cache defaults\");\n            Type::Cache(cache)\n        } else {\n            tracing::info!(\"Using the Hugging Face API\");\n            match api_builder().build() {\n                Ok(api) => Type::Api(api),\n                Err(_) => {\n                    tracing::warn!(\"Unable to build the Hugging Face API\");\n                    Type::None\n                }\n            }\n        }\n    } else {\n        Type::None\n    };\n\n    // Load tokenizer and model info\n    let (\n        config_filename,\n        _tokenizer_config_filename,\n        _preprocessor_config_filename,\n        _processor_config_filename,\n        _model_info,\n    ) = match api {\n        Type::None => (\n            Some(local_path.join(\"config.json\")),\n            Some(local_path.join(\"tokenizer_config.json\")),\n            Some(local_path.join(\"preprocessor_config.json\")),\n            Some(local_path.join(\"processor_config.json\")),\n            None,\n        ),\n        Type::Api(api) => {\n            let api_repo = api.repo(Repo::with_revision(\n                tokenizer_name.to_string(),\n                RepoType::Model,\n                revision.unwrap_or_else(|| \"main\").to_string(),\n            ));\n\n            let config_filename = api_repo.get(\"config.json\").await.ok();\n            let tokenizer_config_filename = api_repo.get(\"tokenizer_config.json\").await.ok();\n            let preprocessor_config_filename = api_repo.get(\"preprocessor_config.json\").await.ok();\n            let processor_config_filename = api_repo.get(\"processor_config.json\").await.ok();\n\n            let model_info = if let Some(model_info) = get_hub_model_info(&api_repo).await {\n                Some(model_info)\n            } else {\n                tracing::warn!(\"Could not 
retrieve model info from the Hugging Face hub.\");\n                None\n            };\n            (\n                config_filename,\n                tokenizer_config_filename,\n                preprocessor_config_filename,\n                processor_config_filename,\n                model_info,\n            )\n        }\n        Type::Cache(cache) => {\n            let repo = cache.repo(Repo::with_revision(\n                tokenizer_name.to_string(),\n                RepoType::Model,\n                revision.clone().unwrap_or_else(|| \"main\").to_string(),\n            ));\n            (\n                repo.get(\"config.json\"),\n                repo.get(\"tokenizer_config.json\"),\n                repo.get(\"preprocessor_config.json\"),\n                repo.get(\"processor_config.json\"),\n                None,\n            )\n        }\n    };\n\n    let tokenizer: Tokenizer = {\n        use pyo3::prelude::*;\n        pyo3::Python::with_gil(|py| -> PyResult<()> {\n            py_resolve_tokenizer(py, &tokenizer_name, revision.as_deref(), false)?;\n            Ok(())\n        })\n        .inspect_err(|err| {\n            tracing::error!(\"Failed to import python tokenizer {err}\");\n        })\n        .or_else(|err| {\n            let out = legacy_tokenizer_handle(config_filename.as_ref());\n            out.ok_or(err)\n        })\n        .expect(\"We cannot load a tokenizer\");\n        let filename = \"out/tokenizer.json\";\n        if let Ok(tok) = tokenizers::Tokenizer::from_file(filename) {\n            Tokenizer::Rust(tok)\n        } else {\n            Tokenizer::Python {\n                tokenizer_name: tokenizer_name.to_string(),\n                revision: revision.map(|revision| revision.to_string()),\n                trust_remote_code: false,\n            }\n        }\n    };\n\n    Some(tokenizer)\n}\n\n#[tokio::main]\nasync fn main() -> Result<(), TensorRtLlmBackendError> {\n    // Get args\n    let args = Args::parse();\n    // Pattern 
match configuration\n    let Args {\n        max_concurrent_requests,\n        max_best_of,\n        max_stop_sequences,\n        max_top_n_tokens,\n        max_input_tokens,\n        max_total_tokens,\n        max_batch_prefill_tokens,\n        max_batch_total_tokens,\n        hostname,\n        port,\n        prometheus_port,\n        tokenizer_name,\n        tokenizer_config_path,\n        revision,\n        model_id,\n        validation_workers,\n        json_output,\n        otlp_endpoint,\n        otlp_service_name,\n        cors_allow_origin,\n        max_client_batch_size,\n        auth_token,\n        executor_worker,\n        usage_stats,\n        payload_limit,\n        max_image_fetch_size,\n    } = args;\n\n    // Launch Tokio runtime\n    text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);\n\n    // Validate args\n    if max_input_tokens >= max_total_tokens {\n        return Err(TensorRtLlmBackendError::ArgumentValidation(\n            \"`max_input_tokens` must be < `max_total_tokens`\".to_string(),\n        ));\n    }\n    if max_input_tokens as u32 > max_batch_prefill_tokens {\n        return Err(TensorRtLlmBackendError::ArgumentValidation(format!(\"`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}\")));\n    }\n\n    if validation_workers == 0 {\n        return Err(TensorRtLlmBackendError::ArgumentValidation(\n            \"`validation_workers` must be > 0\".to_string(),\n        ));\n    }\n\n    if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {\n        if max_batch_prefill_tokens > *max_batch_total_tokens {\n            return Err(TensorRtLlmBackendError::ArgumentValidation(format!(\"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. 
Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}\")));\n        }\n        if max_total_tokens as u32 > *max_batch_total_tokens {\n            return Err(TensorRtLlmBackendError::ArgumentValidation(format!(\"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}\")));\n        }\n    }\n\n    if !executor_worker.exists() {\n        return Err(TensorRtLlmBackendError::ArgumentValidation(format!(\n            \"`executor_work` specified path doesn't exists: {}\",\n            executor_worker.display()\n        )));\n    }\n\n    // Create the backend\n    match get_tokenizer(&tokenizer_name, revision.as_deref())\n        .await\n        .expect(\"Failed to retrieve tokenizer implementation\")\n    {\n        Tokenizer::Python { .. } => Err(TensorRtLlmBackendError::Tokenizer(\n            \"Failed to retrieve Rust based tokenizer\".to_string(),\n        )),\n        Tokenizer::Rust(tokenizer) => {\n            info!(\"Successfully retrieved tokenizer {}\", &tokenizer_name);\n            let backend = TensorRtLlmBackendV2::new(\n                tokenizer,\n                model_id,\n                executor_worker,\n                max_concurrent_requests,\n            )?;\n\n            info!(\"Successfully created backend\");\n\n            // Run server\n            server::run(\n                backend,\n                max_concurrent_requests,\n                max_best_of,\n                max_stop_sequences,\n                max_top_n_tokens,\n                max_input_tokens,\n                max_total_tokens,\n                validation_workers,\n                auth_token,\n                tokenizer_name,\n                tokenizer_config_path,\n                revision,\n                false,\n                hostname,\n                port,\n                cors_allow_origin,\n                false,\n                None,\n                None,\n                true,\n                
max_client_batch_size,\n                usage_stats,\n                payload_limit,\n                max_image_fetch_size,\n                prometheus_port,\n            )\n            .await?;\n            Ok(())\n        }\n    }\n}\n"
  },
  {
    "path": "backends/trtllm/src/utils.rs",
    "content": "///\n/// Extract the first line of the provided string reference.\n/// If there is no lines in the buffer, it returns a string\n/// which content is defined by the content of `fail`\n/// # Arguments\n///\n/// * `s`: The string buffer to extract the first-line from\n/// * `fail`: A string content which is returned if no lines are\n/// present in `s`\n///\n/// returns: String\n///\n/// # Examples\n///\n/// ```\n/// let s = \"My name is Morgan.\\n I'm working at Hugging Face.\";\n/// first_line(s, \"No line in string\");\n/// ```\n#[inline]\npub(crate) fn first_line(s: &str, fail: &str) -> String {\n    s.lines().next().unwrap_or(fail).to_string()\n}\n"
  },
  {
    "path": "backends/trtllm/tests/test_backend.cpp",
    "content": "//\n// Created by mfuntowicz on 12/3/24.\n//\n\n#include <catch2/catch_all.hpp>\n#include <nlohmann/json.hpp>\n#include <tensorrt_llm/executor/executor.h>\n\n#include \"backend.hpp\"\n\nusing namespace huggingface::tgi::backends::trtllm;\n\nTEST_CASE(\"parse generation_config.json all set\", \"[generation_config_t]\")\n{\n    const json config_j = {{\"temperature\",  0.6},\n                           {\"top_p\",        0.95},\n                           {\"eos_token_id\", {1, 2, 3}}};\n    const auto generation_config = generation_config_t(config_j);\n\n    REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(0.6, 1e-6));\n    REQUIRE_THAT(generation_config.top_p, Catch::Matchers::WithinAbs(0.95, 1e-6));\n\n    // Stop words\n    REQUIRE_FALSE(generation_config.stop_words.empty());\n    REQUIRE(generation_config.stop_words.size() == config_j[\"/eos_token_id\"_json_pointer].size());\n\n    for (auto [lhs, rhs]: std::views::zip(generation_config.stop_words, std::list<std::vector<int32_t>>{{1},\n                                                                                                        {2},\n                                                                                                        {3}})) {\n        // Currently we do not support multi-tokens stop words\n        REQUIRE(lhs.size() == 1);\n        REQUIRE(rhs.size() == 1);\n        REQUIRE_THAT(lhs, Catch::Matchers::UnorderedEquals(rhs));\n    }\n}\n\nTEST_CASE(\"parse generation_config.json default\", \"[generation_config_t]\")\n{\n    const json config_j = {{\"eos_token_id\", {1, 2, 3}}};\n    const auto generation_config = generation_config_t(config_j);\n\n    REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(1.0, 1e-6));\n    REQUIRE_THAT(generation_config.top_p, Catch::Matchers::WithinAbs(1.0, 1e-6));\n\n    REQUIRE_FALSE(generation_config.stop_words.empty());\n    REQUIRE(generation_config.stop_words.size() == 
config_j[\"/eos_token_id\"_json_pointer].size());\n\n    for (auto [lhs, rhs]: std::views::zip(generation_config.stop_words, std::list<std::vector<int32_t>>{{1},\n                                                                                                        {2},\n                                                                                                        {3}})) {\n        // Currently we do not support multi-tokens stop words\n        REQUIRE(lhs.size() == 1);\n        REQUIRE(rhs.size() == 1);\n        REQUIRE_THAT(lhs, Catch::Matchers::UnorderedEquals(rhs));\n    }\n}\n\nTEST_CASE(\"parse generation_config.json empty\", \"[generation_config_t]\")\n{\n    const json config_j = {{\"eos_token_id\", {}}};\n    const auto generation_config = generation_config_t(config_j);\n\n    REQUIRE_THAT(generation_config.temperature, Catch::Matchers::WithinAbs(1.0, 1e-6));\n    REQUIRE_THAT(generation_config.top_p, Catch::Matchers::WithinAbs(1.0, 1e-6));\n\n    REQUIRE(generation_config.stop_words.empty());\n\n    const json config_j2 = {};\n    const auto generation_config2 = generation_config_t(config_j);\n\n    REQUIRE_THAT(generation_config2.temperature, Catch::Matchers::WithinAbs(1.0, 1e-6));\n    REQUIRE_THAT(generation_config2.top_p, Catch::Matchers::WithinAbs(1.0, 1e-6));\n\n    REQUIRE(generation_config2.stop_words.empty());\n}\n\nTEST_CASE(\"parallel_config single\", \"[backend_workspace_t]\")\n{\n    // Generate temporary folder\n    const auto tmp_p = std::filesystem::temp_directory_path();\n    const auto config_p = tmp_p / \"config.json\";\n    const auto generation_config_p = tmp_p / \"generation_config.json\";\n\n    // Generate content\n    std::ofstream o_config(config_p);\n    o_config << R\"({\"pretrained_config\": {\"mapping\": {\"world_size\": 2}}})\"_json;\n    o_config.close();\n\n    std::ofstream o_generation_config(generation_config_p);\n    o_generation_config << R\"({\"eos_token_id\": []})\"_json;\n    
o_generation_config.close();\n\n    const auto workspace = backend_workspace_t(tmp_p.generic_string(), tmp_p.generic_string());\n    const auto parallel = workspace.parallel_config();\n    REQUIRE(parallel.getCommunicationMode() == tle::CommunicationMode::kORCHESTRATOR);\n    REQUIRE(parallel.getCommunicationType() == tle::CommunicationType::kMPI);\n\n    std::filesystem::remove(config_p);\n    std::filesystem::remove(generation_config_p);\n}\n\nTEST_CASE(\"parallel_config multi\", \"[backend_workspace_t]\")\n{\n    // Generate temporary folder\n    const auto tmp_p = std::filesystem::temp_directory_path();\n    const auto config_p = tmp_p / \"config.json\";\n    const auto generation_config_p = tmp_p / \"generation_config.json\";\n\n    // Generate content\n    std::ofstream o_config(config_p);\n    o_config << R\"({\"pretrained_config\": {\"mapping\": {\"world_size\": 1}}})\"_json;\n    o_config.close();\n\n    std::ofstream o_generation_config(generation_config_p);\n    o_generation_config << R\"({\"eos_token_id\": []})\"_json;\n    o_generation_config.close();\n\n    const auto workspace = backend_workspace_t(tmp_p.generic_string(), tmp_p.generic_string());\n    const auto parallel = workspace.parallel_config();\n    REQUIRE(parallel.getCommunicationMode() == tle::CommunicationMode::kLEADER);\n    REQUIRE(parallel.getCommunicationType() == tle::CommunicationType::kMPI);\n\n    std::filesystem::remove(config_p);\n    std::filesystem::remove(generation_config_p);\n}\n\nTEST_CASE(\"executor_config\", \"[backend_workspace_t]\")\n{\n\n}\n\nTEST_CASE(\"sampling_params_t to tle::SamplingConfig\", \"[backend_t]\")\n{\n    const sampling_params_t params = {40, 0.95, 0.9, 1.0, 0.6, 2014};\n    const auto config = static_cast<tle::SamplingConfig>(params);\n\n    REQUIRE(config.getTopK().has_value());\n    REQUIRE(config.getTopK().value() == params.top_k);\n\n    REQUIRE(config.getSeed().has_value());\n    REQUIRE(config.getSeed().value() == params.seed);\n\n    
REQUIRE(config.getTopP().has_value());\n    REQUIRE_THAT(*config.getTopP(), Catch::Matchers::WithinAbs(params.top_p, 1e-6f));\n\n    REQUIRE(config.getRepetitionPenalty().has_value());\n    REQUIRE_THAT(*config.getRepetitionPenalty(), Catch::Matchers::WithinAbs(params.repetition_penalty, 1e-6f));\n\n    REQUIRE(config.getFrequencyPenalty().has_value());\n    REQUIRE_THAT(*config.getFrequencyPenalty(), Catch::Matchers::WithinAbs(params.frequency_penalty, 1e-6f));\n\n    REQUIRE(config.getTemperature().has_value());\n    REQUIRE_THAT(*config.getTemperature(), Catch::Matchers::WithinAbs(params.temperature, 1e-6f));\n}\n"
  },
  {
    "path": "backends/trtllm/tests/test_hardware.cpp",
    "content": "//\n// Created by mfuntowicz on 11/16/24.\n//\n\n#include <catch2/catch_all.hpp>\n#include \"../csrc/hardware.hpp\"\n\nusing namespace huggingface::tgi::hardware::cuda;\n\nTEST_CASE(\"is_at_least_<arch>\") {\n    const static auto VOLTA_CAPABILITIES = compute_capabilities_t(7, 0);\n    REQUIRE(VOLTA_CAPABILITIES.is_at_least_volta());\n    REQUIRE_FALSE(VOLTA_CAPABILITIES.is_at_least_turing());\n    REQUIRE_FALSE(VOLTA_CAPABILITIES.is_at_least_ampere());\n    REQUIRE_FALSE(VOLTA_CAPABILITIES.is_at_least_ada_lovelace());\n    REQUIRE_FALSE(VOLTA_CAPABILITIES.is_at_least_hopper());\n\n    const static auto TURING_CAPABILITIES = compute_capabilities_t(7, 5);\n    REQUIRE(TURING_CAPABILITIES.is_at_least_volta());\n    REQUIRE(TURING_CAPABILITIES.is_at_least_turing());\n    REQUIRE_FALSE(TURING_CAPABILITIES.is_at_least_ampere());\n    REQUIRE_FALSE(TURING_CAPABILITIES.is_at_least_ada_lovelace());\n    REQUIRE_FALSE(TURING_CAPABILITIES.is_at_least_hopper());\n\n    const static auto AMPERE_CAPABILITIES = compute_capabilities_t(8, 0);\n    REQUIRE(AMPERE_CAPABILITIES.is_at_least_volta());\n    REQUIRE(AMPERE_CAPABILITIES.is_at_least_turing());\n    REQUIRE(AMPERE_CAPABILITIES.is_at_least_ampere());\n    REQUIRE_FALSE(AMPERE_CAPABILITIES.is_at_least_ada_lovelace());\n    REQUIRE_FALSE(AMPERE_CAPABILITIES.is_at_least_hopper());\n\n    const static auto ADA_LOVELACE_CAPABILITIES = compute_capabilities_t(8, 9);\n    REQUIRE(ADA_LOVELACE_CAPABILITIES.is_at_least_volta());\n    REQUIRE(ADA_LOVELACE_CAPABILITIES.is_at_least_turing());\n    REQUIRE(ADA_LOVELACE_CAPABILITIES.is_at_least_ampere());\n    REQUIRE(ADA_LOVELACE_CAPABILITIES.is_at_least_ada_lovelace());\n    REQUIRE_FALSE(ADA_LOVELACE_CAPABILITIES.is_at_least_hopper());\n\n    const static auto HOPPER_CAPABILITIES = compute_capabilities_t(9, 0);\n    REQUIRE(HOPPER_CAPABILITIES.is_at_least_volta());\n    REQUIRE(HOPPER_CAPABILITIES.is_at_least_turing());\n    
REQUIRE(HOPPER_CAPABILITIES.is_at_least_ampere());\n    REQUIRE(HOPPER_CAPABILITIES.is_at_least_ada_lovelace());\n    REQUIRE(HOPPER_CAPABILITIES.is_at_least_hopper());\n}\n\nTEST_CASE(\"is_at_least\") {\n    const static auto VOLTA_CAPABILITIES = compute_capabilities_t(7, 0);\n    REQUIRE(VOLTA_CAPABILITIES.is_at_least(VOLTA));\n    REQUIRE_FALSE(VOLTA_CAPABILITIES.is_at_least(TURING));\n    REQUIRE_FALSE(VOLTA_CAPABILITIES.is_at_least(AMPERE));\n    REQUIRE_FALSE(VOLTA_CAPABILITIES.is_at_least(ADA_LOVELACE));\n    REQUIRE_FALSE(VOLTA_CAPABILITIES.is_at_least(HOPPER));\n\n    const static auto TURING_CAPABILITIES = compute_capabilities_t(7, 5);\n    REQUIRE(TURING_CAPABILITIES.is_at_least(VOLTA));\n    REQUIRE(TURING_CAPABILITIES.is_at_least(TURING));\n    REQUIRE_FALSE(TURING_CAPABILITIES.is_at_least(AMPERE));\n    REQUIRE_FALSE(TURING_CAPABILITIES.is_at_least(ADA_LOVELACE));\n    REQUIRE_FALSE(TURING_CAPABILITIES.is_at_least(HOPPER));\n\n    const static auto AMPERE_CAPABILITIES = compute_capabilities_t(8, 0);\n    REQUIRE(AMPERE_CAPABILITIES.is_at_least(VOLTA));\n    REQUIRE(AMPERE_CAPABILITIES.is_at_least(TURING));\n    REQUIRE(AMPERE_CAPABILITIES.is_at_least(AMPERE));\n    REQUIRE_FALSE(AMPERE_CAPABILITIES.is_at_least(ADA_LOVELACE));\n    REQUIRE_FALSE(AMPERE_CAPABILITIES.is_at_least(HOPPER));\n\n    const static auto ADA_LOVELACE_CAPABILITIES = compute_capabilities_t(8, 9);\n    REQUIRE(ADA_LOVELACE_CAPABILITIES.is_at_least(VOLTA));\n    REQUIRE(ADA_LOVELACE_CAPABILITIES.is_at_least(TURING));\n    REQUIRE(ADA_LOVELACE_CAPABILITIES.is_at_least(AMPERE));\n    REQUIRE(ADA_LOVELACE_CAPABILITIES.is_at_least(ADA_LOVELACE));\n    REQUIRE_FALSE(ADA_LOVELACE_CAPABILITIES.is_at_least(HOPPER));\n\n    const static auto HOPPER_CAPABILITIES = compute_capabilities_t (9, 0);\n    REQUIRE(HOPPER_CAPABILITIES.is_at_least(VOLTA));\n    REQUIRE(HOPPER_CAPABILITIES.is_at_least(TURING));\n    REQUIRE(HOPPER_CAPABILITIES.is_at_least(AMPERE));\n    
REQUIRE(HOPPER_CAPABILITIES.is_at_least(ADA_LOVELACE));\n    REQUIRE(HOPPER_CAPABILITIES.is_at_least(HOPPER));\n}\n"
  },
  {
    "path": "backends/v2/Cargo.toml",
    "content": "[package]\nname = \"text-generation-router-v2\"\ndescription = \"Text Generation Webserver\"\nversion.workspace = true\nedition.workspace = true\nauthors.workspace = true\nhomepage.workspace = true\n\n[lib]\npath = \"src/lib.rs\"\n\n[[bin]]\nname = \"text-generation-router-v2\"\npath = \"src/main.rs\"\n\n[dependencies]\nasync-trait = \"0.1.74\"\nasync-stream = \"0.3.5\"\naxum = { version = \"0.7\", features = [\"json\"] }\naxum-tracing-opentelemetry = \"0.16\"\ntext-generation-router = { path = \"../../router\" }\nclap = { version = \"4.4.5\", features = [\"derive\", \"env\"] }\ngrpc-metadata = { path = \"../grpc-metadata\" }\nfutures = \"0.3.28\"\nhf-hub = { workspace = true }\njsonschema = { version = \"0.28.0\" }\nmetrics = { workspace = true }\nmetrics-exporter-prometheus = { workspace = true }\nnohash-hasher = \"0.2.0\"\nopentelemetry = { version = \"0.20.0\", features = [\"rt-tokio\"] }\nopentelemetry-otlp = \"0.13.0\"\nrand = \"0.8.5\"\nreqwest = { version = \"0.11.20\", features = [] }\nserde = \"1.0.188\"\nserde_json = \"1.0.107\"\nslotmap = \"1.0.7\"\nthiserror = \"1.0.48\"\ntokenizers = { workspace = true }\ntokio = { version = \"1.32.0\", features = [\n  \"rt\",\n  \"rt-multi-thread\",\n  \"parking_lot\",\n  \"signal\",\n  \"sync\",\n] }\ntokio-stream = \"0.1.14\"\ntower-http = { version = \"0.5.1\", features = [\"cors\"] }\ntracing = \"0.1.37\"\ntracing-opentelemetry = \"0.21.0\"\ntracing-subscriber = { version = \"0.3.17\", features = [\"json\", \"env-filter\"] }\nutoipa = { version = \"4.2.0\", features = [\"axum_extras\"] }\nutoipa-swagger-ui = { version = \"6.0.0\", features = [\"axum\"] }\ninit-tracing-opentelemetry = { version = \"0.14.1\", features = [\n  \"opentelemetry-otlp\",\n] }\nminijinja = { workspace = true }\nminijinja-contrib = { workspace = true }\nfutures-util = \"0.3.30\"\nregex = \"1.10.3\"\nonce_cell = \"1.19.0\"\nimage = \"0.25.1\"\nbase64 = { workspace = true }\nprost = \"^0.12\"\ntonic = \"^0.10\"\ntower = 
\"^0.4\"\n\n[build-dependencies]\ntonic-build = \"0.10.1\"\nprost-build = \"0.12.1\"\n\n[features]\ndefault = [\"ngrok\"]\nngrok = [\"text-generation-router/ngrok\"]\ngoogle = [\"text-generation-router/google\"]\nkserve = [\"text-generation-router/kserve\"]\n"
  },
  {
    "path": "backends/v2/build.rs",
    "content": "use std::fs;\n\nfn main() -> Result<(), Box<dyn std::error::Error>> {\n    println!(\"cargo:rerun-if-changed=../../proto/\");\n\n    fs::create_dir_all(\"src/client/pb\").unwrap_or(());\n    let mut config = prost_build::Config::new();\n    config.protoc_arg(\"--experimental_allow_proto3_optional\");\n\n    tonic_build::configure()\n        .build_client(true)\n        .build_server(false)\n        .out_dir(\"src/client/pb\")\n        .include_file(\"mod.rs\")\n        .compile_with_config(config, &[\"../../proto/generate.proto\"], &[\"../../proto\"])\n        .unwrap_or_else(|e| panic!(\"protobuf compilation failed: {e}\"));\n\n    Ok(())\n}\n"
  },
  {
    "path": "backends/v2/src/backend.rs",
    "content": "use crate::client::{Batch, CachedBatch, ClientError, Generation, Health, ShardedClient};\n/// Batching and inference logic\nuse crate::queue::{Entry, Queue};\nuse async_trait::async_trait;\nuse nohash_hasher::IntMap;\nuse std::sync::Arc;\nuse text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};\nuse text_generation_router::validation::ValidGenerateRequest;\nuse text_generation_router::{FinishReason, PrefillToken, Token};\nuse tokio::sync::mpsc::error::SendError;\nuse tokio::sync::{mpsc, Notify};\nuse tokio::time::Instant;\nuse tokio_stream::wrappers::UnboundedReceiverStream;\nuse tracing::{info_span, instrument, Instrument, Span};\n\npub struct BackendV2 {\n    /// Request queue\n    queue: Queue,\n    /// Notify batcher on queue appends\n    batching_task_notifier: Arc<Notify>,\n    /// Client clone, used for health checks to skip the queue\n    client: ShardedClient,\n}\n\nimpl BackendV2 {\n    #[allow(clippy::too_many_arguments)]\n    pub(crate) fn new(\n        client: ShardedClient,\n        waiting_served_ratio: f32,\n        max_batch_prefill_tokens: u32,\n        max_batch_total_tokens: u32,\n        max_waiting_tokens: usize,\n        max_batch_size: Option<usize>,\n        requires_padding: bool,\n        window_size: Option<u32>,\n        speculate: u32,\n    ) -> Self {\n        // Infer shared state\n        let attention = std::env::var(\"ATTENTION\").unwrap_or(\"paged\".to_string());\n        let block_size = match attention.as_str() {\n            \"flashinfer\" => 1,\n            \"flashdecoding\" => 256,\n            \"paged\" => 16,\n            _ => unreachable!(),\n        };\n\n        let queue = Queue::new(requires_padding, block_size, window_size, speculate);\n        let batching_task_notifier = Arc::new(Notify::new());\n\n        // Spawn batching background task that contains all the inference logic\n        tokio::spawn(batching_task(\n            client.clone(),\n            
waiting_served_ratio,\n            max_batch_prefill_tokens,\n            max_batch_total_tokens,\n            max_waiting_tokens,\n            max_batch_size,\n            queue.clone(),\n            batching_task_notifier.clone(),\n        ));\n\n        Self {\n            queue,\n            batching_task_notifier,\n            client,\n        }\n    }\n}\n\n#[async_trait]\nimpl Backend for BackendV2 {\n    #[instrument(skip_all)]\n    fn schedule(\n        &self,\n        request: ValidGenerateRequest,\n    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {\n        // MPSC channel to communicate with the background batching task\n        let (response_tx, response_rx) = mpsc::unbounded_channel();\n\n        // Append the request to the queue\n        self.queue.append(Entry {\n            request,\n            response_tx,\n            span: Span::current(),\n            temp_span: None,\n            queue_time: Instant::now(),\n            batch_time: None,\n        });\n\n        // Notify the background task that we have a new entry in the queue that needs\n        // to be batched\n        self.batching_task_notifier.notify_one();\n\n        // Return stream\n        Ok(UnboundedReceiverStream::new(response_rx))\n    }\n\n    async fn health(&self, current_health: bool) -> bool {\n        if current_health {\n            // Generation is healthy, we only check that the shards can allocate on device\n            self.client.device_health().await\n        } else {\n            self.client.model_health().await\n        }\n        .is_ok()\n    }\n\n    fn start_health(&self) -> bool {\n        true\n    }\n\n    fn name(&self) -> &'static str {\n        \"tgi-v2\"\n    }\n}\n\n/// Batching logic\n/// Will be launched in a background Tokio task\n///\n/// Batches requests and sends them to the inference server\n#[allow(clippy::too_many_arguments)]\npub(crate) async fn batching_task(\n    mut client: ShardedClient,\n    
waiting_served_ratio: f32,\n    max_batch_prefill_tokens: u32,\n    max_batch_total_tokens: u32,\n    max_waiting_tokens: usize,\n    max_batch_size: Option<usize>,\n    queue: Queue,\n    notifier: Arc<Notify>,\n) {\n    // Infinite loop\n    loop {\n        // Wait for a notification from the Infer struct\n        notifier.notified().await;\n\n        // Get the next batch from the queue\n        // This batch might be smaller than the maximum batch size if there are not enough requests\n        // waiting in the queue\n        while let Some((mut entries, batch, span)) = queue\n            .next_batch(\n                None,\n                max_batch_size,\n                max_batch_prefill_tokens,\n                max_batch_total_tokens,\n            )\n            .await\n        {\n            let mut cached_batch = prefill(&mut client, batch, &mut entries)\n                .instrument(span)\n                .await;\n            let mut waiting_tokens = 1;\n\n            // We loop until we do not receive any cached batch from the inference server (== until\n            // all requests have met their stopping criteria)\n            while let Some(batch) = cached_batch {\n                // Get current batch info\n                let batch_size = batch.size;\n                let batch_max_tokens = batch.max_tokens;\n                let mut batches = vec![batch];\n                metrics::gauge!(\"tgi_batch_current_size\").set(batch_size as f64);\n                metrics::gauge!(\"tgi_batch_current_max_tokens\").set(batch_max_tokens as f64);\n\n                let min_size = if waiting_tokens >= max_waiting_tokens {\n                    // If we didn't onboard any new requests since >= max_waiting_tokens, we try\n                    // to add a new batch even though its size might be small\n                    None\n                } else {\n                    // Minimum batch size\n                    Some((batch_size as f32 * waiting_served_ratio).floor() 
as usize)\n                };\n\n                let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);\n                let max_size =\n                    max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));\n                // Try to get a new batch\n                if let Some((mut new_entries, new_batch, span)) = queue\n                    .next_batch(min_size, max_size, max_batch_prefill_tokens, token_budget)\n                    .await\n                {\n                    // Tracking metrics\n                    if min_size.is_some() {\n                        metrics::counter!(\"tgi_batch_concat\", \"reason\" => \"backpressure\")\n                            .increment(1);\n                    } else {\n                        metrics::counter!(\"tgi_batch_concat\", \"reason\" => \"wait_exceeded\")\n                            .increment(1);\n                    }\n\n                    entries.iter_mut().for_each(|(_, entry)| {\n                        // Create a new span to add the info that this entry is waiting\n                        // because a new batch is being computed\n                        let entry_waiting_span = info_span!(parent: &entry.span, \"waiting\");\n                        // Add relationships\n                        span.follows_from(&entry_waiting_span);\n                        entry_waiting_span.follows_from(&span);\n                        // Update entry\n                        entry.temp_span = Some(entry_waiting_span);\n                    });\n\n                    // Generate one token for this new batch to have the attention past in cache\n                    let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)\n                        .instrument(span)\n                        .await;\n                    // Reset waiting counter\n                    waiting_tokens = 1;\n                    // Extend current batch with the new batch\n                 
   if let Some(new_cached_batch) = new_cached_batch {\n                        entries.extend(new_entries);\n                        batches.push(new_cached_batch);\n                    }\n                }\n\n                // Create span for this batch to add context to inference calls\n                let next_batch_size = entries.len();\n                let next_batch_span =\n                    info_span!(parent: None, \"batch\", batch_size = next_batch_size);\n                entries.iter_mut().for_each(|(_, entry)| {\n                    // Create a new span to link the batch back to this entry\n                    let entry_batch_span = info_span!(parent: &entry.span, \"infer\");\n                    // Add relationships\n                    next_batch_span.follows_from(&entry_batch_span);\n                    entry_batch_span.follows_from(&next_batch_span);\n                    // Update entry\n                    entry.temp_span = Some(entry_batch_span);\n                });\n\n                cached_batch = decode(&mut client, batches, &mut entries)\n                    .instrument(next_batch_span)\n                    .await;\n                waiting_tokens += 1;\n            }\n            metrics::gauge!(\"tgi_batch_current_size\").set(0.0);\n            metrics::gauge!(\"tgi_batch_current_max_tokens\").set(0.0);\n        }\n    }\n}\n\n#[instrument(skip_all)]\nasync fn prefill(\n    client: &mut ShardedClient,\n    batch: Batch,\n    entries: &mut IntMap<u64, Entry>,\n) -> Option<CachedBatch> {\n    let start_time = Instant::now();\n    let batch_id = batch.id;\n    metrics::counter!(\"tgi_batch_inference_count\", \"method\" => \"prefill\").increment(1);\n\n    match client.prefill(batch).await {\n        Ok((generations, next_batch, timings)) => {\n            let start_filtering_time = Instant::now();\n            // Send generated tokens and filter stopped entries\n            filter_send_generations(generations, entries);\n\n            // 
Filter next batch and remove requests that were stopped\n            let next_batch = filter_batch(client, next_batch, entries).await;\n\n            metrics::histogram!(\"tgi_batch_forward_duration\",\"method\" => \"prefill\")\n                .record(timings.forward.as_secs_f64());\n            metrics::histogram!(\"tgi_batch_decode_duration\", \"method\" => \"prefill\")\n                .record(timings.decode.as_secs_f64());\n            metrics::histogram!(\"tgi_batch_filter_duration\", \"method\" => \"prefill\")\n                .record(start_filtering_time.elapsed().as_secs_f64());\n            metrics::histogram!(\"tgi_batch_inference_duration\",\"method\" => \"prefill\")\n                .record(start_time.elapsed().as_secs_f64());\n            metrics::counter!(\"tgi_batch_inference_success\", \"method\" => \"prefill\").increment(1);\n            next_batch\n        }\n        // If we have an error, we discard the whole batch\n        Err(err) => {\n            let _ = client.clear_cache(Some(batch_id)).await;\n            send_errors(err, entries);\n            metrics::counter!(\"tgi_batch_inference_failure\", \"method\" => \"prefill\").increment(1);\n            None\n        }\n    }\n}\n\n#[instrument(skip_all)]\nasync fn decode(\n    client: &mut ShardedClient,\n    batches: Vec<CachedBatch>,\n    entries: &mut IntMap<u64, Entry>,\n) -> Option<CachedBatch> {\n    let start_time = Instant::now();\n    let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();\n    metrics::counter!(\"tgi_batch_inference_count\", \"method\" => \"decode\").increment(1);\n\n    match client.decode(batches).await {\n        Ok((generations, next_batch, timings)) => {\n            let start_filtering_time = Instant::now();\n            // Send generated tokens and filter stopped entries\n            filter_send_generations(generations, entries);\n\n            // Filter next batch and remove requests that were stopped\n            let next_batch = 
filter_batch(client, next_batch, entries).await;\n\n            if let Some(concat_duration) = timings.concat {\n                metrics::histogram!(\"tgi_batch_concat_duration\", \"method\" => \"decode\")\n                    .record(concat_duration.as_secs_f64());\n            }\n            metrics::histogram!(\"tgi_batch_forward_duration\", \"method\" => \"decode\")\n                .record(timings.forward.as_secs_f64());\n            metrics::histogram!(\"tgi_batch_decode_duration\", \"method\" => \"decode\")\n                .record(timings.decode.as_secs_f64());\n            metrics::histogram!(\"tgi_batch_filter_duration\", \"method\" => \"decode\")\n                .record(start_filtering_time.elapsed().as_secs_f64());\n            metrics::histogram!(\"tgi_batch_inference_duration\", \"method\" => \"decode\")\n                .record(start_time.elapsed().as_secs_f64());\n            metrics::counter!(\"tgi_batch_inference_success\", \"method\" => \"decode\").increment(1);\n            next_batch\n        }\n        // If we have an error, we discard the whole batch\n        Err(err) => {\n            for id in batch_ids {\n                let _ = client.clear_cache(Some(id)).await;\n            }\n            send_errors(err, entries);\n            metrics::counter!(\"tgi_batch_inference_failure\", \"method\" => \"decode\").increment(1);\n            None\n        }\n    }\n}\n\n/// Filter a `batch` and remove all requests not present in `entries`\n#[instrument(skip_all)]\nasync fn filter_batch(\n    client: &mut ShardedClient,\n    next_batch: Option<CachedBatch>,\n    entries: &IntMap<u64, Entry>,\n) -> Option<CachedBatch> {\n    let mut batch = next_batch?;\n\n    // No need to filter\n    if batch.size as usize == entries.len() {\n        return Some(batch);\n    }\n\n    let id = batch.id;\n\n    // Retain only requests that are still in entries\n    batch.request_ids.retain(|id| entries.contains_key(id));\n\n    if batch.request_ids.is_empty() {\n   
     // All requests have been filtered out\n        // Next batch is now empty\n        // Clear it from the Python shards cache\n        // We unwrap here as we need to panic since we cannot recover if this method fails\n        client.clear_cache(Some(id)).await.unwrap();\n        None\n    } else {\n        // Filter Python shard cache\n        // We unwrap here as we need to panic since we cannot recover if this method fails\n        client.filter_batch(id, batch.request_ids).await.unwrap()\n    }\n}\n\n/// Send one or multiple `InferStreamResponse` to Infer for all `entries`\n/// and filter entries\n#[instrument(skip_all)]\nfn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {\n    generations.into_iter().for_each(|generation| {\n        let id = generation.request_id;\n        // Get entry\n        // We can `expect` here as the request id should always be in the entries\n        let entry = entries\n            .get(&id)\n            .expect(\"ID not found in entries. This is a bug.\");\n\n        // Create and enter a span to link this function back to the entry\n        let _span = info_span!(parent: entry.temp_span.as_ref().expect(\"batch_span is None. This is a bug.\"), \"send_generation\", generation = ?generation).entered();\n        // Send generation responses back to the infer task\n        // If we receive an error from the response channel, it means that the client dropped the\n        // request and we need to stop generating hence why we unwrap_or(true)\n        let stopped = send_responses(generation, entry).inspect_err(|_err| {\n            tracing::error!(\"Entry response channel error.\");\n            metrics::counter!(\"tgi_request_failure\", \"err\" => \"dropped\").increment(1);\n        }).unwrap_or(true);\n        if stopped {\n            entries.remove(&id).expect(\"ID not found in entries. 
This is a bug.\");\n        }\n    });\n}\n\n/// Send responses through the `entry` response channel\nfn send_responses(\n    generation: Generation,\n    entry: &Entry,\n) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {\n    // Return directly if the channel is disconnected\n    if entry.response_tx.is_closed() {\n        metrics::counter!(\"tgi_request_failure\", \"err\" => \"dropped\").increment(1);\n        return Ok(true);\n    }\n\n    let mut stopped = false;\n\n    if let Some(prefill_tokens) = generation.prefill_tokens {\n        // Create Token objects\n        // We do that here instead of in the Python code as Rust for loops are faster\n        let prefill_tokens = prefill_tokens\n            .ids\n            .into_iter()\n            .zip(prefill_tokens.logprobs)\n            .zip(prefill_tokens.texts)\n            .map(|((id, logprob), text)| PrefillToken { id, text, logprob })\n            .collect();\n\n        // Send message\n        entry\n            .response_tx\n            .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;\n    }\n\n    // Create last Token\n    let tokens_ = generation.tokens.expect(\"Non empty tokens in generation\");\n    let n = tokens_.ids.len();\n    metrics::histogram!(\"tgi_request_skipped_tokens\").record((n - 1) as f64);\n    let mut iterator = tokens_\n        .ids\n        .into_iter()\n        .zip(tokens_.logprobs)\n        .zip(tokens_.texts)\n        .zip(tokens_.is_special)\n        .enumerate()\n        .peekable();\n    while let Some((i, (((id, logprob), text), special))) = iterator.next() {\n        let token = Token {\n            id,\n            text,\n            logprob,\n            special,\n        };\n        let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {\n            top_tokens_\n                .ids\n                .iter()\n                .zip(top_tokens_.logprobs.iter())\n                .zip(top_tokens_.texts.iter())\n           
     .zip(top_tokens_.is_special.iter())\n                .map(|(((&id, &logprob), text), &special)| Token {\n                    id,\n                    text: text.to_string(),\n                    logprob,\n                    special,\n                })\n                .collect()\n        } else {\n            vec![]\n        };\n        match (&generation.generated_text, iterator.peek()) {\n            (Some(generated_text), None) => {\n                // Generation has ended\n                stopped = true;\n                // Send message\n                entry.response_tx.send(Ok(InferStreamResponse::End {\n                    token,\n                    top_tokens,\n                    generated_text: GeneratedText::from(generated_text.clone()),\n                    queued: entry.queue_time,\n                    start: entry.batch_time.unwrap(),\n                }))?;\n            }\n            _ => {\n                // Send message\n                entry\n                    .response_tx\n                    .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;\n            }\n        }\n    }\n\n    Ok(stopped)\n}\n\n/// Send errors to Infer for all `entries`\n#[instrument(skip_all)]\nfn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {\n    entries.drain().for_each(|(_, entry)| {\n        // Create and enter a span to link this function back to the entry\n        let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect(\"batch_span is None. 
This is a bug.\"), \"send_error\").entered();\n        let err = InferError::GenerationError(error.to_string());\n        metrics::counter!(\"tgi_request_failure\", \"err\" => \"generation\").increment(1);\n        tracing::error!(\"{err}\");\n\n        // unwrap_or is valid here as we don't care if the receiver is gone.\n        entry\n            .response_tx\n            .send(Err(err))\n            .unwrap_or(());\n    });\n}\n\nimpl From<crate::client::GeneratedText> for GeneratedText {\n    fn from(value: crate::client::GeneratedText) -> Self {\n        let v2_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap();\n        let finish_reason = match v2_finish_reason {\n            crate::client::FinishReason::Length => FinishReason::Length,\n            crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,\n            crate::client::FinishReason::StopSequence => FinishReason::StopSequence,\n        };\n\n        Self {\n            text: value.text,\n            generated_tokens: value.generated_tokens,\n            finish_reason,\n            seed: value.seed,\n        }\n    }\n}\n"
  },
  {
    "path": "backends/v2/src/client/grpc_client.rs",
    "content": "/// Single shard Client\nuse crate::client::pb;\nuse crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};\nuse grpc_metadata::InjectTelemetryContext;\nuse pb::generate::v2::text_generation_service_client::TextGenerationServiceClient;\nuse pb::generate::v2::*;\nuse std::cmp::min;\nuse std::time::Duration;\nuse tonic::transport::{Channel, Uri};\nuse tracing::instrument;\n\n/// Text Generation Inference gRPC client\n#[derive(Debug, Clone)]\npub struct Client {\n    stub: TextGenerationServiceClient<Channel>,\n}\n\nimpl Client {\n    /// Returns a client connected to the given url\n    #[allow(dead_code)]\n    pub async fn connect(uri: Uri) -> Result<Self> {\n        let channel = Channel::builder(uri).connect().await?;\n\n        Ok(Self {\n            stub: TextGenerationServiceClient::new(channel),\n        })\n    }\n\n    /// Returns a client connected to the given unix socket\n    pub async fn connect_uds(path: String) -> Result<Self> {\n        let channel = Channel::from_shared(\"http://[::]:50051\".to_string())\n            .unwrap()\n            .connect_with_connector(tower::service_fn(move |_: Uri| {\n                tokio::net::UnixStream::connect(path.clone())\n            }))\n            .await?;\n\n        Ok(Self {\n            stub: TextGenerationServiceClient::new(channel),\n        })\n    }\n\n    /// Returns a list of uris or unix sockets of all shards\n    #[instrument(skip(self))]\n    pub async fn service_discovery(&mut self) -> Result<Vec<String>> {\n        let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();\n        let response = self.stub.service_discovery(request).await.map_err(|_| {\n            ClientError::Connection(\"Server does not support v2 interface\".to_string())\n        })?;\n        let urls = response\n            .into_inner()\n            .urls\n            .into_iter()\n            // Remove unix socket prefix\n            .map(|url| match url.strip_prefix(\"unix://\") 
{\n                None => url,\n                Some(stripped_url) => stripped_url.to_string(),\n            })\n            .collect();\n        Ok(urls)\n    }\n\n    /// Get model info\n    #[instrument(skip(self))]\n    pub async fn info(&mut self) -> Result<InfoResponse> {\n        let request = tonic::Request::new(InfoRequest {}).inject_context();\n        let response = self.stub.info(request).await?.into_inner();\n        Ok(response)\n    }\n\n    /// Get model health\n    #[instrument(skip(self))]\n    pub async fn health(&mut self) -> Result<HealthResponse> {\n        let request = tonic::Request::new(HealthRequest {}).inject_context();\n        let response = self.stub.health(request).await?.into_inner();\n        Ok(response)\n    }\n\n    /// Clear the past generations cache\n    #[instrument(skip(self))]\n    pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {\n        let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();\n        self.stub.clear_cache(request).await?;\n        Ok(())\n    }\n\n    /// Filter a cached batch\n    #[instrument(skip(self))]\n    pub async fn filter_batch(\n        &mut self,\n        batch_id: u64,\n        request_ids: Vec<u64>,\n    ) -> Result<Option<CachedBatch>> {\n        let request = tonic::Request::new(FilterBatchRequest {\n            batch_id,\n            request_ids,\n        })\n        .inject_context();\n        let filtered_batch = self.stub.filter_batch(request).await?.into_inner();\n        Ok(filtered_batch.batch)\n    }\n\n    /// Warmup on a max size batch\n    ///\n    /// Returns the maximum amount of tokens supported by the hardware\n    #[instrument(skip_all)]\n    pub async fn warmup(\n        &mut self,\n        max_input_length: u32,\n        max_prefill_tokens: u32,\n        max_total_tokens: u32,\n        max_batch_size: Option<usize>,\n    ) -> Result<Option<u32>> {\n        let mut n_tokens = 0;\n        let mut requests = 
Vec::new();\n        // Create requests\n        while n_tokens < max_prefill_tokens {\n            let truncate = min(max_input_length, max_prefill_tokens - n_tokens);\n\n            let mut inputs = String::new();\n            inputs.push_str(&\"_test \".to_string().repeat(max_input_length as usize));\n            if n_tokens == 0 {\n                // 1 request is enough to test vision heads.\n                // Sending images on other queries messes up easily with truncation.\n                inputs.push_str(&format!(\n                    \"![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})\",\n                ));\n            }\n\n            requests.push(Request {\n                id: 0,\n                inputs,\n                // We truncate the input on the server side to be sure that it has the correct size\n                truncate,\n                // Set sampling parameters to also take these ops into account in the max memory\n                parameters: Some(NextTokenChooserParameters {\n                    temperature: 0.9,\n                    top_k: 10,\n                    top_p: 0.9,\n                    typical_p: 0.9,\n                    do_sample: false,\n                    seed: 0,\n                    repetition_penalty: 1.2,\n                    frequency_penalty: 0.1,\n                    watermark: true,\n                    grammar: String::new(),\n                    grammar_type: GrammarType::None as i32,\n                }),\n                stopping_parameters: Some(StoppingCriteriaParameters {\n                    max_new_tokens: max_total_tokens - truncate,\n                    stop_sequences: vec![],\n                    ignore_eos_token: true,\n                }),\n                prefill_logprobs: true,\n                top_n_tokens: 20,\n            });\n            n_tokens += max_input_length;\n\n            // Check max_batch_size\n            if Some(requests.len()) == max_batch_size {\n                break;\n           
 }\n        }\n\n        let batch = Batch {\n            id: 0,\n            size: requests.len() as u32,\n            requests,\n            max_tokens: 0,\n        };\n\n        let request = tonic::Request::new(WarmupRequest {\n            batch: Some(batch),\n            max_input_length,\n            max_prefill_tokens,\n            max_total_tokens,\n        })\n        .inject_context();\n        let response = self.stub.warmup(request).await?.into_inner();\n        Ok(response.max_supported_total_tokens)\n    }\n\n    /// Generate one token for each request in the given batch\n    ///\n    /// Returns Generation for each request in batch\n    /// and the next cached batch\n    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]\n    pub async fn prefill(\n        &mut self,\n        batch: Batch,\n    ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {\n        let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();\n        let response = self.stub.prefill(request).await?.into_inner();\n        Ok((\n            response.generations,\n            response.batch,\n            PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),\n        ))\n    }\n\n    /// Generate one token for each request in the given cached batches\n    ///\n    /// Returns Generation for each request in batches\n    /// and the next cached batch\n    #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]\n    pub async fn decode(\n        &mut self,\n        batches: Vec<CachedBatch>,\n    ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {\n        let request = tonic::Request::new(DecodeRequest { batches }).inject_context();\n        let response = self.stub.decode(request).await?.into_inner();\n        Ok((\n            response.generations,\n            response.batch,\n            DecodeTimings::new(\n                
response.concat_ns,\n                response.forward_ns,\n                response.decode_ns,\n                response.total_ns,\n            ),\n        ))\n    }\n}\n\npub struct PrefillTimings {\n    pub forward: Duration,\n    pub decode: Duration,\n    pub total: Duration,\n}\n\nimpl PrefillTimings {\n    fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {\n        Self {\n            forward: Duration::from_nanos(forward_ns),\n            decode: Duration::from_nanos(decode_ns),\n            total: Duration::from_nanos(total_ns),\n        }\n    }\n}\n\npub struct DecodeTimings {\n    pub concat: Option<Duration>,\n    pub forward: Duration,\n    pub decode: Duration,\n    pub total: Duration,\n}\n\nimpl DecodeTimings {\n    fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {\n        Self {\n            concat: concat_ns.map(Duration::from_nanos),\n            forward: Duration::from_nanos(forward_ns),\n            decode: Duration::from_nanos(decode_ns),\n            total: Duration::from_nanos(total_ns),\n        }\n    }\n}\n"
  },
  {
    "path": "backends/v2/src/client/mod.rs",
    "content": "//! Text Generation gRPC client library\n\nuse async_trait::async_trait;\nuse thiserror::Error;\nuse tonic::transport;\nuse tonic::Status;\n\n#[allow(clippy::derive_partial_eq_without_eq)]\nmod pb;\n\nmod grpc_client;\nmod sharded_client;\n\npub use grpc_client::Client;\npub use pb::generate::v2::{\n    Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType, HealthResponse,\n    InfoResponse, NextTokenChooserParameters, Request, StoppingCriteriaParameters,\n};\npub use sharded_client::ShardedClient;\n\n#[async_trait]\npub trait Health {\n    /// Check if a generate server is healthy by asking it to allocate a tensor on device\n    async fn device_health(&self) -> Result<()>;\n\n    /// Check if a generate server is healthy by doing a forward pass.\n    /// EXPENSIVE\n    async fn model_health(&self) -> Result<()>;\n}\n\n#[derive(Debug)]\npub struct ShardInfo {\n    pub requires_padding: bool,\n    pub dtype: String,\n    pub device_type: String,\n    pub window_size: Option<u32>,\n    pub speculate: u32,\n}\n\n#[derive(Error, Debug, Clone)]\npub enum ClientError {\n    #[error(\"Could not connect to Text Generation server: {0}\")]\n    Connection(String),\n    #[error(\"Server error: {0}\")]\n    Generation(String),\n    #[error(\"Sharded results are empty\")]\n    EmptyResults,\n}\n\nimpl From<Status> for ClientError {\n    fn from(err: Status) -> Self {\n        let err = Self::Generation(err.message().to_string());\n        tracing::error!(\"{err}\");\n        err\n    }\n}\n\nimpl From<transport::Error> for ClientError {\n    fn from(err: transport::Error) -> Self {\n        let err = Self::Connection(err.to_string());\n        tracing::error!(\"{err}\");\n        err\n    }\n}\n\nstatic WARMUP_IMAGE_BASE64 :&str = 
\"iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=\";\n\npub type Result<T> = std::result::Result<T, ClientError>;\n"
  },
  {
    "path": "backends/v2/src/client/sharded_client.rs",
    "content": "/// Multi shard Client\nuse crate::client::{ClientError, Result};\nuse crate::client::{Health, ShardInfo};\n\nuse crate::client::grpc_client::{DecodeTimings, PrefillTimings};\nuse crate::client::InfoResponse;\nuse crate::client::{\n    Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,\n    NextTokenChooserParameters, Request, StoppingCriteriaParameters,\n};\nuse async_trait::async_trait;\nuse futures::future::join_all;\nuse tonic::transport::Uri;\nuse tracing::instrument;\n\n#[derive(Debug, Clone)]\n/// Text Generation Inference gRPC multi client\npub struct ShardedClient {\n    clients: Vec<Client>,\n}\n\nimpl ShardedClient {\n    fn new(clients: Vec<Client>) -> Self {\n        Self { clients }\n    }\n\n    /// Create a new ShardedClient from a master client. The master client will communicate with\n    /// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.\n    async fn from_master_client(mut master_client: Client) -> Result<Self> {\n        // Get all uris/unix sockets from the master client\n        let uris = master_client.service_discovery().await?;\n        let futures = uris.into_iter().map(Client::connect_uds);\n        let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();\n        Ok(Self::new(clients?))\n    }\n\n    /// Returns a client connected to the given uri\n    #[allow(dead_code)]\n    pub async fn connect(uri: Uri) -> Result<Self> {\n        let master_client = Client::connect(uri).await?;\n        Self::from_master_client(master_client).await\n    }\n\n    /// Returns a client connected to the given unix socket\n    pub async fn connect_uds(path: String) -> Result<Self> {\n        let master_client = Client::connect_uds(path).await?;\n        Self::from_master_client(master_client).await\n    }\n\n    /// Get the model info\n    #[instrument(skip(self))]\n    pub async fn info(&mut self) -> Result<ShardInfo> {\n        let futures: Vec<_> = 
self\n            .clients\n            .iter_mut()\n            .map(|client| client.info())\n            .collect();\n        join_all(futures).await.pop().unwrap().map(ShardInfo::from)\n    }\n\n    /// GRPC health check\n    #[instrument(skip(self))]\n    pub async fn health(&mut self) -> Result<HealthResponse> {\n        let futures: Vec<_> = self\n            .clients\n            .iter_mut()\n            .map(|client| client.health())\n            .collect();\n        join_all(futures).await.pop().unwrap()\n    }\n\n    /// Clear the past generations cache\n    #[instrument(skip(self))]\n    pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {\n        let futures: Vec<_> = self\n            .clients\n            .iter_mut()\n            .map(|client| client.clear_cache(batch_id))\n            .collect();\n        join_all(futures).await.into_iter().collect()\n    }\n\n    /// Filter a cached batch\n    #[instrument(skip(self))]\n    pub async fn filter_batch(\n        &mut self,\n        batch_id: u64,\n        request_ids: Vec<u64>,\n    ) -> Result<Option<CachedBatch>> {\n        let futures: Vec<_> = self\n            .clients\n            .iter_mut()\n            .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))\n            .collect();\n        // all shards return the same message\n        join_all(futures).await.pop().unwrap()\n    }\n\n    /// Warmup on a max size batch\n    ///\n    /// Returns the maximum amount of tokens supported by the hardware\n    #[instrument(skip(self))]\n    pub async fn warmup(\n        &mut self,\n        max_input_length: u32,\n        max_prefill_tokens: u32,\n        max_total_tokens: u32,\n        max_batch_size: Option<usize>,\n    ) -> Result<Option<u32>> {\n        let futures: Vec<_> = self\n            .clients\n            .iter_mut()\n            .map(|client| {\n                Box::pin(client.warmup(\n                    max_input_length,\n                    
max_prefill_tokens,\n                    max_total_tokens,\n                    max_batch_size,\n                ))\n            })\n            .collect();\n        // Take the minimum value\n        let results = join_all(futures)\n            .await\n            .into_iter()\n            .collect::<Result<Vec<Option<u32>>>>()?;\n        Ok(results.into_iter().flatten().min())\n    }\n\n    /// Generate one token for each request in the given batch\n    ///\n    /// Returns Generation for each request in batch\n    /// and the next cached batch\n    #[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]\n    pub async fn prefill(\n        &mut self,\n        batch: Batch,\n    ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {\n        let futures: Vec<_> = self\n            .clients\n            .iter_mut()\n            .map(|client| Box::pin(client.prefill(batch.clone())))\n            .collect();\n        #[allow(clippy::type_complexity)]\n        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =\n            join_all(futures).await.into_iter().collect();\n        let mut results = results?;\n\n        let (mut generations, next_batch, mut timings) =\n            results.pop().ok_or(ClientError::EmptyResults)?;\n\n        // Merge generations from different model shards\n        for (mut shard_generations, _, shard_timings) in results.into_iter() {\n            generations.append(&mut shard_generations);\n            // Return the timings of the slowest shard\n            if shard_timings.total > timings.total {\n                timings = shard_timings;\n            }\n        }\n        Ok((generations, next_batch, timings))\n    }\n\n    /// Generate one token for each request in the given cached batches\n    ///\n    /// Returns Generation for each request in batches\n    /// and the next cached batch\n    #[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< 
u32 > ()))]\n    pub async fn decode(\n        &mut self,\n        batches: Vec<CachedBatch>,\n    ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {\n        let futures: Vec<_> = self\n            .clients\n            .iter_mut()\n            .map(|client| Box::pin(client.decode(batches.clone())))\n            .collect();\n        #[allow(clippy::type_complexity)]\n        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =\n            join_all(futures).await.into_iter().collect();\n        let mut results = results?;\n\n        let (mut generations, next_batch, mut timings) =\n            results.pop().ok_or(ClientError::EmptyResults)?;\n\n        // Merge generations from different model shards\n        for (mut shard_generations, _, shard_timings) in results.into_iter() {\n            generations.append(&mut shard_generations);\n            // Return the timings of the slowest shard\n            if shard_timings.total > timings.total {\n                timings = shard_timings;\n            }\n        }\n        Ok((generations, next_batch, timings))\n    }\n}\n\nimpl From<InfoResponse> for ShardInfo {\n    fn from(value: InfoResponse) -> Self {\n        Self {\n            requires_padding: value.requires_padding,\n            dtype: value.dtype,\n            device_type: value.device_type,\n            window_size: value.window_size,\n            speculate: value.speculate,\n        }\n    }\n}\n\n#[async_trait]\nimpl Health for ShardedClient {\n    async fn device_health(&self) -> Result<()> {\n        self.clone().health().await?;\n        Ok(())\n    }\n\n    async fn model_health(&self) -> Result<()> {\n        // Dummy batch of 1 token and 1 generated token\n        let liveness_request = Request {\n            id: u64::MAX,\n            inputs: \"liveness\".to_string(),\n            truncate: 10,\n            prefill_logprobs: false,\n            parameters: Some(NextTokenChooserParameters {\n            
    temperature: 1.0,\n                top_k: 0,\n                top_p: 1.0,\n                typical_p: 1.0,\n                do_sample: false,\n                seed: 0,\n                repetition_penalty: 1.0,\n                frequency_penalty: 0.0,\n                watermark: false,\n                grammar: String::new(),\n                grammar_type: GrammarType::None as i32,\n            }),\n            stopping_parameters: Some(StoppingCriteriaParameters {\n                max_new_tokens: 1,\n                stop_sequences: vec![],\n                ignore_eos_token: false,\n            }),\n            top_n_tokens: 0,\n        };\n        let batch = Batch {\n            id: u64::MAX,\n            requests: vec![liveness_request],\n            size: 1,\n            max_tokens: 2,\n        };\n        self.clone().prefill(batch).await?;\n        Ok(())\n    }\n}\n"
  },
  {
    "path": "backends/v2/src/lib.rs",
    "content": "mod backend;\nmod client;\nmod queue;\n\nuse crate::client::{ClientError, ShardedClient};\npub(crate) use backend::BackendV2;\nuse serde::Serialize;\nuse thiserror::Error;\nuse utoipa::ToSchema;\n\n#[derive(Clone, Debug, Serialize, ToSchema)]\npub struct BackendInfo {\n    /// Mandatory\n    #[schema(example = \"cuda\")]\n    pub model_device_type: String,\n    #[schema(example = \"torch.float16\")]\n    pub model_dtype: String,\n\n    /// Backend parameters\n    #[schema(example = \"1\")]\n    pub speculate: usize,\n    #[schema(example = \"1.2\")]\n    pub waiting_served_ratio: f32,\n    #[schema(example = \"32000\")]\n    pub max_batch_total_tokens: u32,\n    #[schema(example = \"20\")]\n    pub max_waiting_tokens: usize,\n    #[schema(nullable = true, example = \"null\")]\n    pub max_batch_size: Option<usize>,\n}\n\n#[allow(clippy::too_many_arguments)]\npub async fn connect_backend(\n    max_input_tokens: usize,\n    max_total_tokens: usize,\n    master_shard_uds_path: String,\n    waiting_served_ratio: f32,\n    max_batch_prefill_tokens: u32,\n    max_batch_total_tokens: Option<u32>,\n    max_waiting_tokens: usize,\n    max_batch_size: Option<usize>,\n) -> Result<(BackendV2, BackendInfo), V2Error> {\n    // Helper function\n    let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {\n        match max_supported_batch_total_tokens {\n            // Older models do not support automatic max-batch-total-tokens\n            None => {\n                let max_batch_total_tokens = max_batch_total_tokens\n                    .unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));\n                tracing::warn!(\"Model does not support automatic max batch total tokens\");\n                Ok(max_batch_total_tokens)\n            }\n            // Flash attention models return their max supported total tokens\n            Some(max_supported_batch_total_tokens) => {\n                // Warn if user 
added his own max-batch-total-tokens as we will ignore it\n                if max_batch_total_tokens.is_some() {\n                    tracing::warn!(\n                        \"`--max-batch-total-tokens` is deprecated for Flash \\\n                        Attention models.\"\n                    );\n                    tracing::warn!(\n                        \"Inferred max batch total tokens: {max_supported_batch_total_tokens}\"\n                    );\n                }\n                if max_total_tokens as u32 > max_supported_batch_total_tokens {\n                    return Err(V2Error::NotEnoughMemory(max_total_tokens));\n                }\n\n                Ok(max_supported_batch_total_tokens)\n            }\n        }\n    };\n\n    let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)\n        .await\n        .map_err(V2Error::Connection)?;\n\n    // server is running on v2\n    // Clear the cache; useful if the webserver rebooted\n    sharded_client\n        .clear_cache(None)\n        .await\n        .map_err(V2Error::Cache)?;\n    // Get info from the shard\n    let shard_info = sharded_client.info().await.map_err(V2Error::Info)?;\n\n    // Warmup model\n    tracing::info!(\"Warming up model\");\n    let max_batch_total_tokens = check_max_batch_total_tokens(\n        sharded_client\n            .warmup(\n                max_input_tokens as u32,\n                max_batch_prefill_tokens,\n                max_total_tokens as u32,\n                max_batch_size,\n            )\n            .await\n            .map_err(V2Error::Warmup)?,\n    )?;\n    tracing::info!(\"Setting max batch total tokens to {max_batch_total_tokens}\");\n\n    let backend_info = BackendInfo {\n        waiting_served_ratio,\n        max_batch_total_tokens,\n        max_waiting_tokens,\n        max_batch_size,\n        model_device_type: shard_info.device_type.clone(),\n        model_dtype: shard_info.dtype.clone(),\n        speculate: shard_info.speculate as 
usize,\n    };\n\n    let backend = BackendV2::new(\n        sharded_client,\n        waiting_served_ratio,\n        max_batch_prefill_tokens,\n        max_batch_total_tokens,\n        max_waiting_tokens,\n        max_batch_size,\n        shard_info.requires_padding,\n        shard_info.window_size,\n        shard_info.speculate,\n    );\n\n    tracing::info!(\"Using backend V2\");\n\n    Ok((backend, backend_info))\n}\n\n#[derive(Debug, Error)]\npub enum V2Error {\n    #[error(\"Unable to clear the Python model shards cache: {0}\")]\n    Cache(ClientError),\n    #[error(\"Unable to connect to the Python model shards: {0}\")]\n    Connection(ClientError),\n    #[error(\"Unable to get the Python model shards info: {0}\")]\n    Info(ClientError),\n    #[error(\"Unable to warmup the Python model shards: {0}\")]\n    Warmup(ClientError),\n    #[error(\"Not enough memory to handle `max_total_tokens={0}`\")]\n    NotEnoughMemory(usize),\n}\n"
  },
  {
    "path": "backends/v2/src/main.rs",
    "content": "use clap::{Parser, Subcommand};\nuse text_generation_router::{server, usage_stats};\nuse text_generation_router_v2::{connect_backend, V2Error};\nuse thiserror::Error;\n\n/// App Configuration\n#[derive(Parser, Debug)]\n#[clap(author, version, about, long_about = None)]\nstruct Args {\n    #[command(subcommand)]\n    command: Option<Commands>,\n\n    #[clap(default_value = \"128\", long, env)]\n    max_concurrent_requests: usize,\n    #[clap(default_value = \"2\", long, env)]\n    max_best_of: usize,\n    #[clap(default_value = \"4\", long, env)]\n    max_stop_sequences: usize,\n    #[clap(default_value = \"5\", long, env)]\n    max_top_n_tokens: u32,\n    #[clap(default_value = \"1024\", long, env)]\n    max_input_tokens: usize,\n    #[clap(default_value = \"2048\", long, env)]\n    max_total_tokens: usize,\n    #[clap(default_value = \"1.2\", long, env)]\n    waiting_served_ratio: f32,\n    #[clap(default_value = \"4096\", long, env)]\n    max_batch_prefill_tokens: u32,\n    #[clap(long, env)]\n    max_batch_total_tokens: Option<u32>,\n    #[clap(default_value = \"20\", long, env)]\n    max_waiting_tokens: usize,\n    #[clap(long, env)]\n    max_batch_size: Option<usize>,\n    #[clap(default_value = \"0.0.0.0\", long, env)]\n    hostname: String,\n    #[clap(default_value = \"3000\", long, short, env)]\n    port: u16,\n    #[clap(default_value = \"9000\", long, short, env)]\n    prometheus_port: u16,\n    #[clap(default_value = \"/tmp/text-generation-server-0\", long, env)]\n    master_shard_uds_path: String,\n    #[clap(default_value = \"bigscience/bloom\", long, env)]\n    tokenizer_name: String,\n    #[clap(long, env)]\n    tokenizer_config_path: Option<String>,\n    #[clap(long, env)]\n    revision: Option<String>,\n    #[clap(long, env, value_enum)]\n    trust_remote_code: bool,\n    #[clap(default_value = \"2\", long, env)]\n    validation_workers: usize,\n    #[clap(long, env)]\n    api_key: Option<String>,\n    #[clap(long, env)]\n    
json_output: bool,\n    #[clap(long, env)]\n    otlp_endpoint: Option<String>,\n    #[clap(default_value = \"text-generation-inference.router\", long, env)]\n    otlp_service_name: String,\n    #[clap(long, env)]\n    cors_allow_origin: Option<Vec<String>>,\n    #[clap(long, env)]\n    ngrok: bool,\n    #[clap(long, env)]\n    ngrok_authtoken: Option<String>,\n    #[clap(long, env)]\n    ngrok_edge: Option<String>,\n    #[clap(long, env, default_value_t = false)]\n    disable_grammar_support: bool,\n    #[clap(default_value = \"4\", long, env)]\n    max_client_batch_size: usize,\n    #[clap(default_value = \"on\", long, env)]\n    usage_stats: usage_stats::UsageStatsLevel,\n    #[clap(default_value = \"2000000\", long, env)]\n    payload_limit: usize,\n    #[clap(default_value = \"1073741824\", long, env)]\n    max_image_fetch_size: usize,\n}\n\n#[derive(Debug, Subcommand)]\nenum Commands {\n    PrintSchema,\n}\n\n#[tokio::main]\nasync fn main() -> Result<(), RouterError> {\n    // Get args\n    let args = Args::parse();\n    // Pattern match configuration\n    let Args {\n        command,\n        max_concurrent_requests,\n        max_best_of,\n        max_stop_sequences,\n        max_top_n_tokens,\n        max_input_tokens,\n        max_total_tokens,\n        waiting_served_ratio,\n        max_batch_prefill_tokens,\n        max_batch_total_tokens,\n        max_waiting_tokens,\n        max_batch_size,\n        hostname,\n        port,\n        prometheus_port,\n        master_shard_uds_path,\n        tokenizer_name,\n        tokenizer_config_path,\n        revision,\n        trust_remote_code,\n        validation_workers,\n        api_key,\n        json_output,\n        otlp_endpoint,\n        otlp_service_name,\n        cors_allow_origin,\n        ngrok,\n        ngrok_authtoken,\n        ngrok_edge,\n        disable_grammar_support,\n        max_client_batch_size,\n        usage_stats,\n        payload_limit,\n        max_image_fetch_size,\n    } = args;\n\n    
if let Some(Commands::PrintSchema) = command {\n        use utoipa::OpenApi;\n        let api_doc = text_generation_router::server::ApiDoc::openapi();\n        let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();\n        println!(\"{}\", api_doc);\n        std::process::exit(0);\n    };\n    text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);\n\n    // Validate args\n    if max_input_tokens >= max_total_tokens {\n        return Err(RouterError::ArgumentValidation(\n            \"`max_input_tokens` must be < `max_total_tokens`\".to_string(),\n        ));\n    }\n    if max_input_tokens as u32 > max_batch_prefill_tokens {\n        return Err(RouterError::ArgumentValidation(format!(\"`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}\")));\n    }\n\n    if validation_workers == 0 {\n        return Err(RouterError::ArgumentValidation(\n            \"`validation_workers` must be > 0\".to_string(),\n        ));\n    }\n\n    if let Some(ref max_batch_total_tokens) = max_batch_total_tokens {\n        if max_batch_prefill_tokens > *max_batch_total_tokens {\n            return Err(RouterError::ArgumentValidation(format!(\"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}\")));\n        }\n        if max_total_tokens as u32 > *max_batch_total_tokens {\n            return Err(RouterError::ArgumentValidation(format!(\"`max_total_tokens` must be <= `max_batch_total_tokens`. 
Given: {max_total_tokens} and {max_batch_total_tokens}\")));\n        }\n    }\n\n    if let Some(max_batch_size) = max_batch_size {\n        if max_batch_size == 0 {\n            return Err(RouterError::ArgumentValidation(\n                \"`max_batch_size` must be > 0\".to_string(),\n            ));\n        }\n    }\n\n    let (backend, _backend_info) = connect_backend(\n        max_input_tokens,\n        max_total_tokens,\n        master_shard_uds_path,\n        waiting_served_ratio,\n        max_batch_prefill_tokens,\n        max_batch_total_tokens,\n        max_waiting_tokens,\n        max_batch_size,\n    )\n    .await?;\n\n    // Run server\n    server::run(\n        backend,\n        max_concurrent_requests,\n        max_best_of,\n        max_stop_sequences,\n        max_top_n_tokens,\n        max_input_tokens,\n        max_total_tokens,\n        validation_workers,\n        api_key,\n        tokenizer_name,\n        tokenizer_config_path,\n        revision,\n        trust_remote_code,\n        hostname,\n        port,\n        cors_allow_origin,\n        ngrok,\n        ngrok_authtoken,\n        ngrok_edge,\n        disable_grammar_support,\n        max_client_batch_size,\n        usage_stats,\n        payload_limit,\n        max_image_fetch_size,\n        prometheus_port,\n    )\n    .await?;\n    Ok(())\n}\n\n#[derive(Debug, Error)]\nenum RouterError {\n    #[error(\"Argument validation error: {0}\")]\n    ArgumentValidation(String),\n    #[error(\"Backend failed: {0}\")]\n    Backend(#[from] V2Error),\n    #[error(\"WebServer error: {0}\")]\n    WebServer(#[from] server::WebServerError),\n    #[error(\"Tokio runtime failed to start: {0}\")]\n    Tokio(#[from] std::io::Error),\n}\n"
  },
  {
    "path": "backends/v2/src/queue.rs",
    "content": "use crate::client::{\n    Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,\n};\nuse nohash_hasher::{BuildNoHashHasher, IntMap};\nuse std::cmp::min;\nuse std::collections::VecDeque;\nuse text_generation_router::infer::InferError;\nuse text_generation_router::infer::InferStreamResponse;\nuse text_generation_router::validation::{\n    ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters, ValidStoppingParameters,\n};\nuse tokio::sync::{mpsc, oneshot};\nuse tokio::time::Instant;\nuse tracing::{info_span, instrument, Span};\n\n/// Queue entry\n#[derive(Debug)]\npub(crate) struct Entry {\n    /// Request\n    pub request: ValidGenerateRequest,\n    /// Response sender to communicate between the Infer struct and the batching_task\n    pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>,\n    /// Span that will live as long as entry\n    pub span: Span,\n    /// Temporary span used as a guard when logging inference, wait times...\n    pub temp_span: Option<Span>,\n    /// Instant when this entry was queued\n    pub queue_time: Instant,\n    /// Instant when this entry was added to a batch\n    pub batch_time: Option<Instant>,\n}\n\n/// Request Queue\n#[derive(Debug, Clone)]\npub(crate) struct Queue {\n    /// Channel to communicate with the background queue task\n    queue_sender: mpsc::UnboundedSender<QueueCommand>,\n}\n\nimpl Queue {\n    pub(crate) fn new(\n        requires_padding: bool,\n        block_size: u32,\n        window_size: Option<u32>,\n        speculate: u32,\n    ) -> Self {\n        // Create channel\n        let (queue_sender, queue_receiver) = mpsc::unbounded_channel();\n\n        // Launch background queue task\n        tokio::spawn(queue_task(\n            requires_padding,\n            block_size,\n            window_size,\n            speculate,\n            queue_receiver,\n        ));\n\n        Self { queue_sender }\n    }\n\n    
#[instrument(skip_all)]\n    pub(crate) fn append(&self, entry: Entry) {\n        // Send append command to the background task managing the state\n        // Unwrap is safe here\n        self.queue_sender\n            .send(QueueCommand::Append(Box::new(entry), Span::current()))\n            .unwrap();\n    }\n\n    // Get the next batch\n    #[instrument(skip(self))]\n    pub(crate) async fn next_batch(\n        &self,\n        min_size: Option<usize>,\n        max_size: Option<usize>,\n        prefill_token_budget: u32,\n        token_budget: u32,\n    ) -> Option<NextBatch> {\n        // Create response channel\n        let (response_sender, response_receiver) = oneshot::channel();\n        // Send next batch command to the background task managing the state\n        // Unwrap is safe here\n        self.queue_sender\n            .send(QueueCommand::NextBatch {\n                min_size,\n                max_size,\n                prefill_token_budget,\n                token_budget,\n                response_sender,\n                span: Span::current(),\n            })\n            .unwrap();\n        // Await on response channel\n        // Unwrap is safe here\n        response_receiver.await.unwrap()\n    }\n}\n\n// Background task responsible of the queue state\nasync fn queue_task(\n    requires_padding: bool,\n    block_size: u32,\n    window_size: Option<u32>,\n    speculate: u32,\n    mut receiver: mpsc::UnboundedReceiver<QueueCommand>,\n) {\n    let mut state = State::new(requires_padding, block_size, window_size, speculate);\n\n    while let Some(cmd) = receiver.recv().await {\n        match cmd {\n            QueueCommand::Append(entry, span) => {\n                span.in_scope(|| state.append(*entry));\n                metrics::gauge!(\"tgi_queue_size\").increment(1.0);\n            }\n            QueueCommand::NextBatch {\n                min_size,\n                max_size,\n                prefill_token_budget,\n                token_budget,\n    
            response_sender,\n                span,\n            } => span.in_scope(|| {\n                let next_batch =\n                    state.next_batch(min_size, max_size, prefill_token_budget, token_budget);\n                response_sender.send(next_batch).unwrap();\n                metrics::gauge!(\"tgi_queue_size\").set(state.entries.len() as f64);\n            }),\n        }\n    }\n}\n\n/// Queue State\n#[derive(Debug)]\nstruct State {\n    /// Queue entries organized in a Vec\n    entries: VecDeque<(u64, Entry)>,\n\n    /// Id of the next entry\n    next_id: u64,\n\n    /// Id of the next batch\n    next_batch_id: u64,\n\n    /// Whether the model is using padding\n    requires_padding: bool,\n\n    /// Paged Attention block size\n    block_size: u32,\n\n    /// Sliding window\n    window_size: Option<u32>,\n\n    /// Speculation amount\n    speculate: u32,\n}\n\nimpl State {\n    fn new(\n        requires_padding: bool,\n        block_size: u32,\n        window_size: Option<u32>,\n        speculate: u32,\n    ) -> Self {\n        Self {\n            entries: VecDeque::with_capacity(128),\n            next_id: 0,\n            next_batch_id: 0,\n            requires_padding,\n            block_size,\n            window_size,\n            speculate,\n        }\n    }\n\n    /// Append an entry to the queue\n    fn append(&mut self, mut entry: Entry) {\n        // Create a span that will live as long as the entry is in the queue waiting to be batched\n        let queue_span = info_span!(parent: &entry.span, \"queued\");\n        entry.temp_span = Some(queue_span);\n\n        // Push entry in the queue\n        self.entries.push_back((self.next_id, entry));\n        self.next_id += 1;\n    }\n\n    // Get the next batch\n    fn next_batch(\n        &mut self,\n        min_size: Option<usize>,\n        max_size: Option<usize>,\n        prefill_token_budget: u32,\n        token_budget: u32,\n    ) -> Option<NextBatch> {\n        if self.entries.is_empty() 
{\n            tracing::debug!(\"No queue\");\n            return None;\n        }\n\n        // Check if we have enough entries\n        if let Some(min_size) = min_size {\n            if self.entries.len() < min_size {\n                tracing::debug!(\"Not enough entries\");\n                return None;\n            }\n        }\n\n        if let Some(max_size) = max_size {\n            if max_size == 0 {\n                tracing::debug!(\"No capacity\");\n                return None;\n            }\n        }\n\n        // Pad prefill_token_budget to be a multiple of block size\n        let prefill_token_budget = prefill_token_budget.div_ceil(self.block_size) * self.block_size;\n\n        // Create span for this batch to add context to inference calls\n        let next_batch_span = info_span!(parent: None, \"batch\", batch_size = tracing::field::Empty);\n        next_batch_span.follows_from(Span::current());\n\n        let mut batch_requests = Vec::with_capacity(self.entries.len());\n        let mut batch_entries =\n            IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());\n\n        let mut max_input_length = 0;\n        let mut prefill_tokens: u32 = 0;\n        let mut decode_tokens: u32 = 0;\n\n        // Pop entries starting from the front of the queue\n        while let Some((id, mut entry)) = self.entries.pop_front() {\n            // Filter entries where the response receiver was dropped (== entries where the request\n            // was dropped by the client)\n            if entry.response_tx.is_closed() {\n                metrics::counter!(\"tgi_request_failure\", \"err\" => \"dropped\").increment(1);\n                tracing::debug!(\"Dropping entry\");\n                continue;\n            }\n\n            if self.requires_padding {\n                // We pad to max input length in the Python shards\n                // We need to take these padding tokens into the equation\n                max_input_length = 
max_input_length.max(entry.request.input_length);\n                prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length\n            } else {\n                // pad to block size\n                prefill_tokens +=\n                    entry.request.input_length.div_ceil(self.block_size) * self.block_size;\n            }\n\n            if self.requires_padding {\n                decode_tokens += entry.request.stopping_parameters.max_new_tokens;\n            } else {\n                let max_new_tokens = match self.window_size {\n                    None => entry.request.stopping_parameters.max_new_tokens,\n                    Some(window_size) => min(\n                        window_size.saturating_sub(entry.request.input_length),\n                        entry.request.stopping_parameters.max_new_tokens,\n                    ),\n                };\n\n                // pad to block size\n                decode_tokens += max_new_tokens.div_ceil(self.block_size) * self.block_size;\n            }\n\n            if prefill_tokens > prefill_token_budget\n                || (prefill_tokens + decode_tokens + self.speculate) > token_budget\n            {\n                // Entry is over budget\n                // Add it back to the front\n                tracing::debug!(\"Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}\", self.speculate);\n                self.entries.push_front((id, entry));\n                break;\n            }\n\n            tracing::debug!(\"Accepting entry\");\n            // Create a new span to link the batch back to this entry\n            let entry_batch_span = info_span!(parent: &entry.span, \"infer\");\n            // Add relationships\n            next_batch_span.follows_from(&entry_batch_span);\n            entry_batch_span.follows_from(&next_batch_span);\n            // Update entry\n            entry.temp_span = Some(entry_batch_span);\n\n    
        batch_requests.push(Request {\n                id,\n                prefill_logprobs: entry.request.decoder_input_details,\n                inputs: entry.request.inputs.chunks_to_string(),\n                truncate: entry.request.truncate,\n                parameters: Some(NextTokenChooserParameters::from(\n                    entry.request.parameters.clone(),\n                )),\n                stopping_parameters: Some(StoppingCriteriaParameters::from(\n                    entry.request.stopping_parameters.clone(),\n                )),\n                top_n_tokens: entry.request.top_n_tokens,\n            });\n            // Set batch_time\n            entry.batch_time = Some(Instant::now());\n            // Insert in batch_entries IntMap\n            batch_entries.insert(id, entry);\n\n            // Check if max_size\n            if Some(batch_requests.len()) == max_size {\n                break;\n            }\n        }\n\n        // Empty batch\n        if batch_requests.is_empty() {\n            tracing::debug!(\"Filtered out all entries\");\n            return None;\n        }\n\n        // Check if our batch is big enough\n        if let Some(min_size) = min_size {\n            // Batch is too small\n            if batch_requests.len() < min_size {\n                // Add back entries to the queue in the correct order\n                for r in batch_requests.into_iter().rev() {\n                    let id = r.id;\n                    let entry = batch_entries.remove(&id).unwrap();\n                    self.entries.push_front((id, entry));\n                }\n\n                return None;\n            }\n        }\n\n        // Final batch size\n        let size = batch_requests.len() as u32;\n        next_batch_span.record(\"batch_size\", size);\n\n        let batch = Batch {\n            id: self.next_batch_id,\n            requests: batch_requests,\n            size,\n            max_tokens: (prefill_tokens + decode_tokens),\n        };\n    
    // Increment batch id\n        self.next_batch_id += 1;\n\n        metrics::histogram!(\"tgi_batch_next_size\").record(batch.size as f64);\n\n        Some((batch_entries, batch, next_batch_span))\n    }\n}\n\ntype NextBatch = (IntMap<u64, Entry>, Batch, Span);\n\n#[derive(Debug)]\nenum QueueCommand {\n    Append(Box<Entry>, Span),\n    NextBatch {\n        min_size: Option<usize>,\n        max_size: Option<usize>,\n        prefill_token_budget: u32,\n        token_budget: u32,\n        response_sender: oneshot::Sender<Option<NextBatch>>,\n        span: Span,\n    },\n}\n\nimpl From<ValidParameters> for NextTokenChooserParameters {\n    fn from(value: ValidParameters) -> Self {\n        let (grammar, grammar_type) = match value.grammar {\n            None => (String::new(), GrammarType::None),\n\n            Some(grammar) => match grammar {\n                ValidGrammar::Json(grammar_string) => (grammar_string, GrammarType::Json),\n                ValidGrammar::Regex(grammar_string) => (grammar_string, GrammarType::Regex),\n            },\n        };\n\n        Self {\n            temperature: value.temperature,\n            top_k: value.top_k,\n            top_p: value.top_p,\n            typical_p: value.typical_p,\n            do_sample: value.do_sample,\n            seed: value.seed,\n            repetition_penalty: value.repetition_penalty,\n            frequency_penalty: value.frequency_penalty,\n            watermark: value.watermark,\n            grammar,\n            grammar_type: grammar_type.into(),\n        }\n    }\n}\n\nimpl From<ValidStoppingParameters> for StoppingCriteriaParameters {\n    fn from(value: ValidStoppingParameters) -> Self {\n        Self {\n            max_new_tokens: value.max_new_tokens,\n            stop_sequences: value.stop_sequences,\n            ignore_eos_token: value.ignore_eos_token,\n        }\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use std::sync::Arc;\n    use tracing::info_span;\n\n    fn 
default_entry() -> (\n        Entry,\n        mpsc::UnboundedReceiver<Result<InferStreamResponse, InferError>>,\n    ) {\n        let (response_tx, receiver_tx) = mpsc::unbounded_channel();\n\n        let entry = Entry {\n            request: ValidGenerateRequest {\n                inputs: vec![],\n                input_ids: Some(Arc::new(vec![])),\n                input_length: 0,\n                add_special_tokens: true,\n                truncate: 0,\n                decoder_input_details: false,\n                parameters: ValidParameters {\n                    temperature: 0.0,\n                    top_k: 0,\n                    top_p: 0.0,\n                    typical_p: 0.0,\n                    do_sample: false,\n                    seed: 0,\n                    repetition_penalty: 0.0,\n                    frequency_penalty: 0.0,\n                    watermark: false,\n                    grammar: None,\n                },\n                stopping_parameters: ValidStoppingParameters {\n                    ignore_eos_token: false,\n                    max_new_tokens: 1,\n                    max_total_new_tokens: 1024,\n                    stop_sequences: vec![],\n                },\n                top_n_tokens: 0,\n                adapter_id: None,\n            },\n            response_tx,\n            span: info_span!(\"entry\"),\n            temp_span: None,\n            queue_time: Instant::now(),\n            batch_time: None,\n        };\n        (entry, receiver_tx)\n    }\n\n    #[test]\n    fn test_append() {\n        let mut state = State::new(false, 1, None, 0);\n        let (entry, _guard) = default_entry();\n\n        assert_eq!(state.next_id, 0);\n        assert_eq!(state.entries.len(), 0);\n\n        state.append(entry);\n\n        assert_eq!(state.next_id, 1);\n        assert_eq!(state.entries.len(), 1);\n        let (id, _) = state.entries.remove(0).unwrap();\n        assert_eq!(id, 0);\n    }\n\n    #[test]\n    fn 
test_next_batch_empty() {\n        let mut state = State::new(false, 1, None, 0);\n\n        assert!(state.next_batch(None, None, 1, 1).is_none());\n        assert!(state.next_batch(Some(1), None, 1, 1).is_none());\n    }\n\n    #[test]\n    fn test_next_batch_min_size() {\n        let mut state = State::new(false, 1, None, 0);\n        let (entry1, _guard1) = default_entry();\n        let (entry2, _guard2) = default_entry();\n        state.append(entry1);\n        state.append(entry2);\n\n        let (entries, batch, _) = state.next_batch(None, None, 2, 2).unwrap();\n        assert_eq!(entries.len(), 2);\n        assert!(entries.contains_key(&0));\n        assert!(entries.contains_key(&1));\n        assert!(entries.get(&0).unwrap().batch_time.is_some());\n        assert!(entries.get(&1).unwrap().batch_time.is_some());\n        assert_eq!(batch.id, 0);\n        assert_eq!(batch.size, 2);\n\n        assert_eq!(state.next_id, 2);\n        assert_eq!(state.entries.len(), 0);\n        assert_eq!(state.next_batch_id, 1);\n\n        let (entry3, _guard3) = default_entry();\n        state.append(entry3);\n\n        assert!(state.next_batch(Some(2), None, 2, 2).is_none());\n\n        assert_eq!(state.next_id, 3);\n        assert_eq!(state.entries.len(), 1);\n        let (id, _) = state.entries.remove(0).unwrap();\n        assert_eq!(id, 2);\n    }\n\n    #[test]\n    fn test_next_batch_max_size() {\n        let mut state = State::new(false, 1, None, 0);\n        let (entry1, _guard1) = default_entry();\n        let (entry2, _guard2) = default_entry();\n        state.append(entry1);\n        state.append(entry2);\n\n        let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).unwrap();\n        assert_eq!(entries.len(), 1);\n        assert!(entries.contains_key(&0));\n        assert!(entries.get(&0).unwrap().batch_time.is_some());\n        assert_eq!(batch.id, 0);\n        assert_eq!(batch.size, 1);\n\n        assert_eq!(state.next_id, 2);\n        
assert_eq!(state.entries.len(), 1);\n        assert_eq!(state.next_batch_id, 1);\n    }\n\n    #[test]\n    fn test_next_batch_token_budget() {\n        let mut state = State::new(false, 1, None, 0);\n        let (entry1, _guard1) = default_entry();\n        let (entry2, _guard2) = default_entry();\n        state.append(entry1);\n        state.append(entry2);\n\n        let (entries, batch, _) = state.next_batch(None, None, 1, 1).unwrap();\n        assert_eq!(entries.len(), 1);\n        assert!(entries.contains_key(&0));\n        assert_eq!(batch.id, 0);\n        assert_eq!(batch.size, 1);\n\n        assert_eq!(state.next_id, 2);\n        assert_eq!(state.entries.len(), 1);\n        assert_eq!(state.next_batch_id, 1);\n\n        let (entry3, _guard3) = default_entry();\n        state.append(entry3);\n\n        let (entries, batch, _) = state.next_batch(None, None, 3, 3).unwrap();\n        assert_eq!(entries.len(), 2);\n        assert!(entries.contains_key(&1));\n        assert!(entries.contains_key(&2));\n        assert_eq!(batch.id, 1);\n        assert_eq!(batch.size, 2);\n\n        assert_eq!(state.next_id, 3);\n        assert_eq!(state.entries.len(), 0);\n        assert_eq!(state.next_batch_id, 2);\n    }\n\n    #[tokio::test]\n    async fn test_queue_append() {\n        let queue = Queue::new(false, 1, None, 0);\n        let (entry, _guard) = default_entry();\n        queue.append(entry);\n    }\n\n    #[tokio::test]\n    async fn test_queue_next_batch_empty() {\n        let queue = Queue::new(false, 1, None, 0);\n\n        assert!(queue.next_batch(None, None, 1, 1).await.is_none());\n        assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());\n    }\n\n    #[tokio::test]\n    async fn test_queue_next_batch_min_size() {\n        let queue = Queue::new(false, 1, None, 0);\n        let (entry1, _guard1) = default_entry();\n        let (entry2, _guard2) = default_entry();\n        queue.append(entry1);\n        queue.append(entry2);\n\n        let 
(entries, batch, _) = queue.next_batch(None, None, 2, 2).await.unwrap();\n        assert_eq!(entries.len(), 2);\n        assert!(entries.contains_key(&0));\n        assert!(entries.contains_key(&1));\n        assert!(entries.get(&0).unwrap().batch_time.is_some());\n        assert!(entries.get(&1).unwrap().batch_time.is_some());\n        assert_eq!(batch.id, 0);\n        assert_eq!(batch.size, 2);\n\n        let (entry3, _guard3) = default_entry();\n        queue.append(entry3);\n\n        // Not enough requests pending\n        assert!(queue.next_batch(Some(2), None, 2, 2).await.is_none());\n        // Not enough token budget\n        assert!(queue.next_batch(Some(1), None, 0, 0).await.is_none());\n        // Ok\n        let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 2).await.unwrap();\n        assert_eq!(entries2.len(), 1);\n        assert!(entries2.contains_key(&2));\n        assert!(entries2.get(&2).unwrap().batch_time.is_some());\n        assert_eq!(batch2.id, 1);\n        assert_eq!(batch2.size, 1);\n    }\n\n    #[tokio::test]\n    async fn test_queue_next_batch_max_size() {\n        let queue = Queue::new(false, 1, None, 0);\n        let (entry1, _guard1) = default_entry();\n        let (entry2, _guard2) = default_entry();\n        queue.append(entry1);\n        queue.append(entry2);\n\n        let (entries, batch, _) = queue.next_batch(None, Some(1), 2, 2).await.unwrap();\n        assert_eq!(entries.len(), 1);\n        assert!(entries.contains_key(&0));\n        assert!(entries.get(&0).unwrap().batch_time.is_some());\n        assert_eq!(batch.id, 0);\n        assert_eq!(batch.size, 1);\n    }\n\n    #[tokio::test]\n    async fn test_queue_next_batch_token_budget() {\n        let queue = Queue::new(false, 1, None, 0);\n        let (entry1, _guard1) = default_entry();\n        let (entry2, _guard2) = default_entry();\n        queue.append(entry1);\n        queue.append(entry2);\n\n        let (entries, batch, _) = queue.next_batch(None, None, 
1, 1).await.unwrap();\n        assert_eq!(entries.len(), 1);\n        assert!(entries.contains_key(&0));\n        assert_eq!(batch.id, 0);\n        assert_eq!(batch.size, 1);\n\n        let (entry3, _guard3) = default_entry();\n        queue.append(entry3);\n\n        let (entries, batch, _) = queue.next_batch(None, None, 3, 3).await.unwrap();\n        assert_eq!(entries.len(), 2);\n        assert!(entries.contains_key(&1));\n        assert!(entries.contains_key(&2));\n        assert_eq!(batch.id, 1);\n        assert_eq!(batch.size, 2);\n    }\n\n    #[tokio::test]\n    async fn test_queue_next_batch_token_speculate() {\n        let queue = Queue::new(false, 1, None, 2);\n        let (entry1, _guard1) = default_entry();\n        let (entry2, _guard2) = default_entry();\n        queue.append(entry1);\n        queue.append(entry2);\n\n        // Budget of 1 is not enough\n        assert!(queue.next_batch(None, None, 1, 1).await.is_none());\n\n        let (entries, batch, _) = queue.next_batch(None, None, 6, 6).await.unwrap();\n        assert_eq!(entries.len(), 2);\n        assert!(entries.contains_key(&0));\n        assert!(entries.contains_key(&1));\n        assert_eq!(batch.id, 0);\n        assert_eq!(batch.size, 2);\n    }\n\n    #[tokio::test]\n    async fn test_queue_next_batch_dropped_receiver() {\n        let queue = Queue::new(false, 1, None, 0);\n        let (entry, _) = default_entry();\n        queue.append(entry);\n\n        assert!(queue.next_batch(None, None, 1, 1).await.is_none());\n    }\n}\n"
  },
  {
    "path": "backends/v3/Cargo.toml",
    "content": "[package]\nname = \"text-generation-router-v3\"\ndescription = \"Text Generation Webserver\"\nversion.workspace = true\nedition.workspace = true\nauthors.workspace = true\nhomepage.workspace = true\n\n[lib]\npath = \"src/lib.rs\"\n\n[[bin]]\nname = \"text-generation-router\"\npath = \"src/main.rs\"\n\n[dependencies]\nasync-trait = \"0.1.74\"\nasync-stream = \"0.3.5\"\naxum = { version = \"0.7\", features = [\"json\"] }\naxum-tracing-opentelemetry = \"0.16\"\ntext-generation-router = { path = \"../../router\" }\nclap = { version = \"4.4.5\", features = [\"derive\", \"env\"] }\ngrpc-metadata = { path = \"../grpc-metadata\" }\nfutures = \"0.3.28\"\nhf-hub = { workspace = true }\njsonschema = { version = \"0.28.0\" }\nmetrics = { workspace = true }\nmetrics-exporter-prometheus = { workspace = true }\nnohash-hasher = \"0.2.0\"\nopentelemetry = { version = \"0.20.0\", features = [\"rt-tokio\"] }\nopentelemetry-otlp = \"0.13.0\"\nrand = \"0.8.5\"\nreqwest = { version = \"0.11.20\", features = [] }\nserde = \"1.0.188\"\nserde_json = \"1.0.107\"\nslotmap = \"1.0.7\"\nthiserror = \"1.0.48\"\ntokenizers = { workspace = true }\ntokio = { version = \"1.32.0\", features = [\n  \"rt\",\n  \"rt-multi-thread\",\n  \"parking_lot\",\n  \"signal\",\n  \"sync\",\n] }\ntokio-stream = \"0.1.14\"\ntower-http = { version = \"0.5.1\", features = [\"cors\"] }\ntracing = \"0.1.37\"\ntracing-opentelemetry = \"0.21.0\"\ntracing-subscriber = { version = \"0.3.17\", features = [\"json\", \"env-filter\"] }\nutoipa = { version = \"4.2.0\", features = [\"axum_extras\"] }\nutoipa-swagger-ui = { version = \"6.0.0\", features = [\"axum\"] }\ninit-tracing-opentelemetry = { version = \"0.14.1\", features = [\n  \"opentelemetry-otlp\",\n] }\nminijinja = { workspace = true }\nminijinja-contrib = { workspace = true }\nfutures-util = \"0.3.30\"\nregex = \"1.10.3\"\nonce_cell = \"1.19.0\"\nimage = \"0.25.1\"\nbase64 = { workspace = true }\nprost = \"^0.12\"\ntonic = \"^0.10\"\ntower = 
\"^0.4\"\n\n[build-dependencies]\ntonic-build = \"0.10.1\"\nprost-build = \"0.12.1\"\n\n[dev-dependencies]\ncriterion = \"0.3\"\nitertools = \"0.13\"\nrustc-hash = \"2\"\n\n[features]\ndefault = [\"ngrok\"]\nngrok = [\"text-generation-router/ngrok\"]\ngoogle = [\"text-generation-router/google\"]\nkserve = [\"text-generation-router/kserve\"]\n\n[[bench]]\nname = \"prefix_cache\"\nharness = false\n"
  },
  {
    "path": "backends/v3/benches/prefix_cache.rs",
    "content": "use std::sync::Arc;\n\nuse criterion::{black_box, criterion_group, criterion_main, Criterion};\nuse rand::Rng;\n\nuse text_generation_router_v3::block_allocator::Allocator;\nuse text_generation_router_v3::radix::RadixAllocator;\n\nfn prefix_cache_benchmark(c: &mut Criterion) {\n    // let prefixes: Vec<Vec<u32>> = (0..8192)\n    //     .chunks(256)\n    //     .into_iter()\n    //     .map(|c| c.collect())\n    //     .collect();\n\n    let mut cache = RadixAllocator::new(1, 262144, None);\n\n    c.bench_function(\"Radix allocator\", |b| {\n        b.iter_batched(\n            || {\n                //prefixes\n                //    .choose_multiple(&mut rand::thread_rng(), 5)\n                //    .fold(Vec::new(), |mut v, s| {\n                //        v.extend(s);\n                //        v\n                //    })\n\n                (0..7936)\n                    .map(|_| rand::thread_rng().gen_range(0..1024))\n                    .collect::<Vec<u32>>()\n            },\n            |prefill| {\n                let alloc = cache.allocate(\n                    prefill.len() as u32 + 13,\n                    Some(Arc::new(black_box(prefill))),\n                );\n                if let Some(alloc) = alloc {\n                    cache.free(alloc.blocks.clone(), alloc.allocation_id);\n                }\n            },\n            criterion::BatchSize::SmallInput,\n        );\n    });\n}\n\ncriterion_group!(benches, prefix_cache_benchmark);\ncriterion_main!(benches);\n"
  },
  {
    "path": "backends/v3/build.rs",
    "content": "use std::fs;\n\nfn main() -> Result<(), Box<dyn std::error::Error>> {\n    println!(\"cargo:rerun-if-changed=../../proto/\");\n\n    fs::create_dir_all(\"src/client/pb\").unwrap_or(());\n    let mut config = prost_build::Config::new();\n    config.protoc_arg(\"--experimental_allow_proto3_optional\");\n\n    tonic_build::configure()\n        .build_client(true)\n        .build_server(false)\n        .out_dir(\"src/client/pb\")\n        .include_file(\"mod.rs\")\n        .compile_with_config(config, &[\"../../proto/v3/generate.proto\"], &[\"../../proto\"])\n        .unwrap_or_else(|e| panic!(\"protobuf compilation failed: {e}\"));\n\n    Ok(())\n}\n"
  },
  {
    "path": "backends/v3/src/backend.rs",
    "content": "/// Batching and inference logic\nuse crate::client::{\n    Batch, CachedBatch, ClientError, Generation, Health, InfoResponse, ShardedClient,\n};\nuse crate::queue::{Entry, Queue};\nuse async_trait::async_trait;\nuse nohash_hasher::IntMap;\nuse std::sync::Arc;\nuse text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};\nuse text_generation_router::validation::ValidGenerateRequest;\nuse text_generation_router::{FinishReason, PrefillToken, Token};\nuse tokio::sync::mpsc::error::SendError;\nuse tokio::sync::{mpsc, Notify};\nuse tokio::time::Instant;\nuse tokio_stream::wrappers::UnboundedReceiverStream;\nuse tracing::{info_span, instrument, Instrument, Span};\n\npub struct BackendV3 {\n    /// Request queue\n    queue: Queue,\n    /// Notify batcher on queue appends\n    batching_task_notifier: Arc<Notify>,\n    /// Client clone, used for health checks to skip the queue\n    client: ShardedClient,\n}\n\nimpl BackendV3 {\n    #[allow(clippy::too_many_arguments)]\n    pub(crate) fn new(\n        client: ShardedClient,\n        waiting_served_ratio: f32,\n        max_batch_prefill_tokens: u32,\n        max_batch_total_tokens: u32,\n        max_waiting_tokens: usize,\n        max_batch_size: Option<usize>,\n        shard_info: InfoResponse,\n    ) -> Self {\n        if shard_info.support_chunking {\n            tracing::warn!(\"Model supports prefill chunking. 
`waiting_served_ratio` and `max_waiting_tokens` will be ignored.\");\n        }\n\n        let block_size = shard_info.block_size;\n\n        let queue = Queue::new(\n            shard_info.requires_padding,\n            block_size,\n            shard_info.use_prefix_caching,\n            shard_info.window_size,\n            shard_info.speculate,\n            max_batch_total_tokens,\n            shard_info.support_chunking,\n        );\n        let batching_task_notifier = Arc::new(Notify::new());\n\n        // Spawn batching background task that contains all the inference logic\n        tokio::spawn(batching_task(\n            client.clone(),\n            waiting_served_ratio,\n            max_batch_prefill_tokens,\n            max_batch_total_tokens,\n            max_waiting_tokens,\n            max_batch_size,\n            shard_info.support_chunking,\n            queue.clone(),\n            batching_task_notifier.clone(),\n        ));\n\n        Self {\n            queue,\n            batching_task_notifier,\n            client,\n        }\n    }\n}\n\n#[async_trait]\nimpl Backend for BackendV3 {\n    #[instrument(skip_all)]\n    fn schedule(\n        &self,\n        request: ValidGenerateRequest,\n    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError> {\n        // MPSC channel to communicate with the background batching task\n        let (response_tx, response_rx) = mpsc::unbounded_channel();\n\n        // Append the request to the queue\n        self.queue.append(Entry {\n            request,\n            response_tx,\n            span: Span::current(),\n            temp_span: None,\n            queue_time: Instant::now(),\n            batch_time: None,\n            block_allocation: None,\n        });\n\n        // Notify the background task that we have a new entry in the queue that needs\n        // to be batched\n        self.batching_task_notifier.notify_one();\n\n        // Return stream\n        
Ok(UnboundedReceiverStream::new(response_rx))\n    }\n\n    async fn health(&self, current_health: bool) -> bool {\n        if current_health {\n            // Generation is healthy, we only check that the shards can allocate on device\n            self.client.device_health().await\n        } else {\n            self.client.model_health().await\n        }\n        .is_ok()\n    }\n\n    fn start_health(&self) -> bool {\n        true\n    }\n\n    fn name(&self) -> &'static str {\n        \"tgi-v3\"\n    }\n}\n\n/// Batching logic\n/// Will be launched in a background Tokio task\n///\n/// Batches requests and sends them to the inference server\n#[allow(clippy::too_many_arguments)]\npub(crate) async fn batching_task(\n    mut client: ShardedClient,\n    waiting_served_ratio: f32,\n    max_batch_prefill_tokens: u32,\n    max_batch_total_tokens: u32,\n    max_waiting_tokens: usize,\n    max_batch_size: Option<usize>,\n    support_chunking: bool,\n    queue: Queue,\n    notifier: Arc<Notify>,\n) {\n    // Infinite loop\n    loop {\n        // Wait for a notification from the Infer struct\n        notifier.notified().await;\n\n        // Get the next batch from the queue\n        // This batch might be smaller than the maximum batch size if there are not enough requests\n        // waiting in the queue\n        while let Some((mut entries, batch, span)) = queue\n            .next_batch(\n                None,\n                max_batch_size,\n                max_batch_prefill_tokens,\n                max_batch_total_tokens,\n            )\n            .await\n        {\n            let mut cached_batch = prefill(&mut client, batch, None, &mut entries)\n                .instrument(span)\n                .await;\n            let mut waiting_tokens = 1;\n\n            // We loop until we do not receive any cached batch from the inference server (== until\n            // all requests have met their stopping criteria)\n            while let Some(batch) = cached_batch {\n      
          // Get current batch info\n                let batch_size = batch.size;\n                let batch_max_tokens = batch.max_tokens;\n                let current_tokens = batch.current_tokens;\n                let mut batches = vec![batch];\n                metrics::gauge!(\"tgi_batch_current_size\").set(batch_size as f64);\n                metrics::gauge!(\"tgi_batch_current_max_tokens\").set(batch_max_tokens as f64);\n\n                let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);\n\n                let (min_size, max_size, prefill_token_budget) = if support_chunking {\n                    // Since the next batch will be concatenated with the current batch,\n                    // the current batch tokens must be subtracted from the prefill budget\n                    let prefill_token_budget =\n                        max_batch_prefill_tokens.saturating_sub(current_tokens);\n                    // We can ignore min_size and max_size\n                    // Models that rely on max_size cannot support chunking\n                    // Regarding min_size, chunking allows us to consistently run at the compute\n                    // bound, making min_size useless.\n                    (None, None, prefill_token_budget)\n                } else {\n                    let min_size = if waiting_tokens >= max_waiting_tokens {\n                        // If we didn't onboard any new requests since >= max_waiting_tokens, we try\n                        // to add a new batch even though its size might be small\n                        None\n                    } else {\n                        // Minimum batch size\n                        // TODO: temporarily disable to avoid incorrect deallocation +\n                        //       reallocation when using prefix caching.\n                        Some((batch_size as f32 * waiting_served_ratio).floor() as usize)\n                    };\n\n                    let max_size =\n                 
       max_batch_size.map(|max_size| max_size.saturating_sub(batch_size as usize));\n\n                    (min_size, max_size, max_batch_prefill_tokens)\n                };\n\n                // Try to get a new batch\n                if let Some((mut new_entries, new_batch, span)) = queue\n                    .next_batch(min_size, max_size, prefill_token_budget, token_budget)\n                    .await\n                {\n                    // Tracking metrics\n                    if min_size.is_some() {\n                        metrics::counter!(\"tgi_batch_concat\", \"reason\" => \"backpressure\")\n                            .increment(1);\n                    } else {\n                        let counter = if support_chunking {\n                            metrics::counter!(\"tgi_batch_concat\", \"reason\" => \"chunking\")\n                        } else {\n                            metrics::counter!(\"tgi_batch_concat\", \"reason\" => \"wait_exceeded\")\n                        };\n                        counter.increment(1);\n                    }\n\n                    let new_cached_batch = if support_chunking {\n                        // Get cached batch\n                        let cached_batch = batches.pop();\n                        // Extend entries with the new entries since the batch will be\n                        // concatenated during the prefill op server side\n                        entries.extend(new_entries);\n                        // Generate one token for both the cached batch and the new batch\n                        let new_cached_batch =\n                            prefill(&mut client, new_batch, cached_batch, &mut entries)\n                                .instrument(span)\n                                .await;\n                        if new_cached_batch.is_none() {\n                            // New cached batch is empty, no work left\n                            break;\n                        }\n                     
   new_cached_batch\n                    } else {\n                        // Requests are waiting because we cannot concatenate the batches if the\n                        // model/server does not support chunking\n                        entries.iter_mut().for_each(|(_, entry)| {\n                            // Create a new span to add the info that this entry is waiting\n                            // because a new batch is being computed\n                            let entry_waiting_span = info_span!(parent: &entry.span, \"waiting\");\n                            // Add relationships\n                            span.follows_from(&entry_waiting_span);\n                            entry_waiting_span.follows_from(&span);\n                            // Update entry\n                            entry.temp_span = Some(entry_waiting_span);\n                        });\n\n                        // Generate one token for this new batch to have the attention past in cache\n                        let new_cached_batch =\n                            prefill(&mut client, new_batch, None, &mut new_entries)\n                                .instrument(span)\n                                .await;\n                        if new_cached_batch.is_some() {\n                            // Extend entries\n                            entries.extend(new_entries);\n                        }\n                        new_cached_batch\n                    };\n\n                    // Reset waiting counter\n                    waiting_tokens = 1;\n                    // Extend current batch with the new batch\n                    if let Some(new_cached_batch) = new_cached_batch {\n                        batches.push(new_cached_batch);\n                    }\n                }\n\n                // Create span for this batch to add context to inference calls\n                let next_batch_size = entries.len();\n                let next_batch_span =\n                    
info_span!(parent: None, \"batch\", batch_size = next_batch_size);\n                entries.iter_mut().for_each(|(_, entry)| {\n                    // Create a new span to link the batch back to this entry\n                    let entry_batch_span = info_span!(parent: &entry.span, \"infer\");\n                    // Add relationships\n                    next_batch_span.follows_from(&entry_batch_span);\n                    entry_batch_span.follows_from(&next_batch_span);\n                    // Update entry\n                    entry.temp_span = Some(entry_batch_span);\n                });\n\n                cached_batch = decode(&mut client, batches, &mut entries)\n                    .instrument(next_batch_span)\n                    .await;\n                waiting_tokens += 1;\n            }\n            metrics::gauge!(\"tgi_batch_current_size\").set(0.0);\n            metrics::gauge!(\"tgi_batch_current_max_tokens\").set(0.0);\n        }\n    }\n}\n\n#[instrument(skip_all)]\nasync fn prefill(\n    client: &mut ShardedClient,\n    batch: Batch,\n    cached_batch: Option<CachedBatch>,\n    entries: &mut IntMap<u64, Entry>,\n) -> Option<CachedBatch> {\n    let start_time = Instant::now();\n    let batch_id = batch.id;\n    metrics::counter!(\"tgi_batch_inference_count\", \"method\" => \"prefill\").increment(1);\n\n    match client.prefill(batch, cached_batch).await {\n        Ok((generations, next_batch, timings)) => {\n            let start_filtering_time = Instant::now();\n            // Send generated tokens and filter stopped entries\n            filter_send_generations(generations, entries);\n\n            // Filter next batch and remove requests that were stopped\n            let next_batch = filter_batch(client, next_batch, entries).await;\n\n            if let Some(concat_duration) = timings.concat {\n                metrics::histogram!(\"tgi_batch_concat_duration\", \"method\" => \"decode\")\n                    .record(concat_duration.as_secs_f64());\n  
          }\n            metrics::histogram!(\"tgi_batch_forward_duration\", \"method\" => \"prefill\")\n                .record(timings.forward.as_secs_f64());\n            metrics::histogram!(\"tgi_batch_decode_duration\", \"method\" => \"prefill\")\n                .record(timings.decode.as_secs_f64());\n            metrics::histogram!(\"tgi_batch_filter_duration\", \"method\" => \"prefill\")\n                .record(start_filtering_time.elapsed().as_secs_f64());\n            metrics::histogram!(\"tgi_batch_inference_duration\", \"method\" => \"prefill\")\n                .record(start_time.elapsed().as_secs_f64());\n            metrics::counter!(\"tgi_batch_inference_success\", \"method\" => \"prefill\").increment(1);\n            next_batch\n        }\n        // If we have an error, we discard the whole batch\n        Err(err) => {\n            let _ = client.clear_cache(Some(batch_id)).await;\n            send_errors(err, entries);\n            metrics::counter!(\"tgi_batch_inference_failure\", \"method\" => \"prefill\").increment(1);\n            None\n        }\n    }\n}\n\n#[instrument(skip_all)]\nasync fn decode(\n    client: &mut ShardedClient,\n    batches: Vec<CachedBatch>,\n    entries: &mut IntMap<u64, Entry>,\n) -> Option<CachedBatch> {\n    let start_time = Instant::now();\n    let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();\n    metrics::counter!(\"tgi_batch_inference_count\", \"method\" => \"decode\").increment(1);\n\n    match client.decode(batches).await {\n        Ok((generations, next_batch, timings)) => {\n            let start_filtering_time = Instant::now();\n            // Send generated tokens and filter stopped entries\n            filter_send_generations(generations, entries);\n\n            // Filter next batch and remove requests that were stopped\n            let next_batch = filter_batch(client, next_batch, entries).await;\n\n            if let Some(concat_duration) = timings.concat {\n                
metrics::histogram!(\"tgi_batch_concat_duration\", \"method\" => \"decode\")\n                    .record(concat_duration.as_secs_f64());\n            }\n            metrics::histogram!(\"tgi_batch_forward_duration\", \"method\" => \"decode\")\n                .record(timings.forward.as_secs_f64());\n            metrics::histogram!(\"tgi_batch_decode_duration\", \"method\" => \"decode\")\n                .record(timings.decode.as_secs_f64());\n            metrics::histogram!(\"tgi_batch_filter_duration\", \"method\" => \"decode\")\n                .record(start_filtering_time.elapsed().as_secs_f64());\n            metrics::histogram!(\"tgi_batch_inference_duration\", \"method\" => \"decode\")\n                .record(start_time.elapsed().as_secs_f64());\n            metrics::counter!(\"tgi_batch_inference_success\", \"method\" => \"decode\").increment(1);\n            next_batch\n        }\n        // If we have an error, we discard the whole batch\n        Err(err) => {\n            for id in batch_ids {\n                let _ = client.clear_cache(Some(id)).await;\n            }\n            send_errors(err, entries);\n            metrics::counter!(\"tgi_batch_inference_failure\", \"method\" => \"decode\").increment(1);\n            None\n        }\n    }\n}\n\n/// Filter a `batch` and remove all requests not present in `entries`\n#[instrument(skip_all)]\nasync fn filter_batch(\n    client: &mut ShardedClient,\n    next_batch: Option<CachedBatch>,\n    entries: &IntMap<u64, Entry>,\n) -> Option<CachedBatch> {\n    let mut batch = next_batch?;\n\n    // No need to filter\n    if batch.size as usize == entries.len() {\n        return Some(batch);\n    }\n\n    let id = batch.id;\n\n    // Retain only requests that are still in entries\n    batch.request_ids.retain(|id| entries.contains_key(id));\n\n    if batch.request_ids.is_empty() {\n        // All requests have been filtered out\n        // Next batch is now empty\n        // Clear it from the Python shards 
cache\n        // We unwrap here as we need to panic since we cannot recover if this method fails\n        client.clear_cache(Some(id)).await.unwrap();\n        None\n    } else {\n        // Filter Python shard cache\n        // We unwrap here as we need to panic since we cannot recover if this method fails\n        client.filter_batch(id, batch.request_ids).await.unwrap()\n    }\n}\n\n/// Send one or multiple `InferStreamResponse` to Infer for all `entries`\n/// and filter entries\n#[instrument(skip_all)]\nfn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u64, Entry>) {\n    generations.into_iter().for_each(|generation| {\n        let id = generation.request_id;\n        // Get entry\n        // We can `expect` here as the request id should always be in the entries\n        let entry = entries\n            .get(&id)\n            .expect(\"ID not found in entries. This is a bug.\");\n\n        // Create and enter a span to link this function back to the entry\n        let _span = info_span!(parent: entry.temp_span.as_ref().expect(\"batch_span is None. This is a bug.\"), \"send_generation\", generation = ?generation).entered();\n        // Send generation responses back to the infer task\n        // If we receive an error from the Flume channel, it means that the client dropped the\n        // request and we need to stop generating hence why we unwrap_or(true)\n        let stopped = send_responses(generation, entry).inspect_err(|_err| {\n            tracing::error!(\"Entry response channel error.\");\n            metrics::counter!(\"tgi_request_failure\", \"err\" => \"dropped\").increment(1);\n        }).unwrap_or(true);\n        if stopped {\n            entries.remove(&id).expect(\"ID not found in entries. 
This is a bug.\");\n        }\n    });\n}\n\n/// Send responses through the `entry` response channel\nfn send_responses(\n    generation: Generation,\n    entry: &Entry,\n) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {\n    // Return directly if the channel is disconnected\n    if entry.response_tx.is_closed() {\n        metrics::counter!(\"tgi_request_failure\", \"err\" => \"dropped\").increment(1);\n        return Ok(true);\n    }\n\n    let mut stopped = false;\n\n    if let Some(prefill_tokens) = generation.prefill_tokens {\n        // Create Token objects\n        // We do that here instead of in the Python code as Rust for loops are faster\n        let prefill_tokens = prefill_tokens\n            .ids\n            .into_iter()\n            .zip(prefill_tokens.logprobs)\n            .zip(prefill_tokens.texts)\n            .map(|((id, logprob), text)| PrefillToken { id, text, logprob })\n            .collect();\n\n        // Send message\n        entry\n            .response_tx\n            .send(Ok(InferStreamResponse::Prefill(prefill_tokens)))?;\n    }\n\n    // Create last Token\n    let tokens_ = generation.tokens.expect(\"Non empty tokens in generation\");\n    let n = tokens_.ids.len();\n    metrics::histogram!(\"tgi_request_skipped_tokens\").record((n - 1) as f64);\n    let mut iterator = tokens_\n        .ids\n        .into_iter()\n        .zip(tokens_.logprobs)\n        .zip(tokens_.texts)\n        .zip(tokens_.is_special)\n        .enumerate()\n        .peekable();\n    while let Some((i, (((id, logprob), text), special))) = iterator.next() {\n        let token = Token {\n            id,\n            text,\n            logprob,\n            special,\n        };\n        let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {\n            top_tokens_\n                .ids\n                .iter()\n                .zip(top_tokens_.logprobs.iter())\n                .zip(top_tokens_.texts.iter())\n           
     .zip(top_tokens_.is_special.iter())\n                .map(|(((&id, &logprob), text), &special)| Token {\n                    id,\n                    text: text.to_string(),\n                    logprob,\n                    special,\n                })\n                .collect()\n        } else {\n            vec![]\n        };\n        match (&generation.generated_text, iterator.peek()) {\n            (Some(generated_text), None) => {\n                // Generation has ended\n                stopped = true;\n                // Send message\n                entry.response_tx.send(Ok(InferStreamResponse::End {\n                    token,\n                    top_tokens,\n                    generated_text: GeneratedText::from(generated_text.clone()),\n                    queued: entry.queue_time,\n                    start: entry.batch_time.unwrap(),\n                }))?;\n            }\n            _ => {\n                // Send message\n                entry\n                    .response_tx\n                    .send(Ok(InferStreamResponse::Intermediate { token, top_tokens }))?;\n            }\n        }\n    }\n\n    Ok(stopped)\n}\n\n/// Send errors to Infer for all `entries`\n#[instrument(skip_all)]\nfn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {\n    entries.drain().for_each(|(_, entry)| {\n        // Create and enter a span to link this function back to the entry\n        let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect(\"batch_span is None. 
This is a bug.\"), \"send_error\").entered();\n        let err = InferError::GenerationError(error.to_string());\n        metrics::counter!(\"tgi_request_failure\", \"err\" => \"generation\").increment(1);\n        tracing::error!(\"{err}\");\n\n        // unwrap_or is valid here as we don't care if the receiver is gone.\n        entry\n            .response_tx\n            .send(Err(err))\n            .unwrap_or(());\n    });\n}\n\nimpl From<crate::client::GeneratedText> for GeneratedText {\n    fn from(value: crate::client::GeneratedText) -> Self {\n        let v3_finish_reason = crate::client::FinishReason::try_from(value.finish_reason).unwrap();\n        let finish_reason = match v3_finish_reason {\n            crate::client::FinishReason::Length => FinishReason::Length,\n            crate::client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,\n            crate::client::FinishReason::StopSequence => FinishReason::StopSequence,\n        };\n\n        Self {\n            text: value.text,\n            generated_tokens: value.generated_tokens,\n            finish_reason,\n            seed: value.seed,\n        }\n    }\n}\n"
  },
  {
    "path": "backends/v3/src/block_allocator.rs",
    "content": "use std::sync::Arc;\nuse tokio::sync::{mpsc, oneshot};\n\nuse crate::radix::RadixAllocator;\nuse text_generation_router::usage_stats::Env;\n#[derive(Debug, Clone)]\npub struct BlockAllocation {\n    pub allocation_id: u64,\n    pub blocks: Vec<u32>,\n    pub slots: Vec<u32>,\n\n    /// Prefix that was cached and for which the KV does not have to\n    /// be recomputed.\n    pub prefix_len: u32,\n\n    pub(crate) block_allocator: Option<BlockAllocator>,\n}\n\nimpl Drop for BlockAllocation {\n    fn drop(&mut self) {\n        if let Some(block_allocator) = self.block_allocator.as_mut() {\n            block_allocator.free(self.blocks.clone(), self.allocation_id)\n        }\n    }\n}\n\n#[derive(Debug, Clone)]\npub struct BlockAllocator {\n    /// Channel to communicate with the background task\n    block_allocator: mpsc::UnboundedSender<BlockAllocatorCommand>,\n}\n\nimpl BlockAllocator {\n    pub(crate) fn new(\n        max_batch_total_tokens: u32,\n        block_size: u32,\n        prefix_caching: bool,\n        window_size: Option<u32>,\n    ) -> Self {\n        // Create channel\n        let (sender, receiver) = mpsc::unbounded_channel();\n\n        // Launch background queue task\n        tokio::spawn(block_allocator_task(\n            max_batch_total_tokens / block_size,\n            block_size,\n            prefix_caching,\n            window_size,\n            receiver,\n        ));\n\n        Self {\n            block_allocator: sender,\n        }\n    }\n\n    pub(crate) async fn allocate(\n        &self,\n        tokens: u32,\n        prefill_tokens: Option<Arc<Vec<u32>>>,\n    ) -> Option<BlockAllocation> {\n        let (response_sender, response_receiver) = oneshot::channel();\n        self.block_allocator\n            .send(BlockAllocatorCommand::Allocate {\n                tokens,\n                prefill_tokens,\n                response_sender,\n            })\n            .unwrap();\n\n        response_receiver.await.unwrap().map(|mut 
allocation| {\n            allocation.block_allocator = Some(self.clone());\n            allocation\n        })\n    }\n\n    pub(crate) fn free(&self, blocks: Vec<u32>, allocation_id: u64) {\n        self.block_allocator\n            .send(BlockAllocatorCommand::Free {\n                allocation_id,\n                blocks,\n            })\n            .unwrap();\n    }\n}\n\nasync fn block_allocator_task(\n    blocks: u32,\n    block_size: u32,\n    prefix_caching: bool,\n    window_size: Option<u32>,\n    mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,\n) {\n    let mut allocator: Box<dyn Allocator + Send> = if prefix_caching {\n        Box::new(RadixAllocator::new(block_size, blocks, window_size))\n    } else {\n        Box::new(SimpleAllocator::new(blocks, block_size, window_size))\n    };\n    while let Some(cmd) = receiver.recv().await {\n        match cmd {\n            BlockAllocatorCommand::Free {\n                blocks,\n                allocation_id,\n            } => allocator.free(blocks, allocation_id),\n            BlockAllocatorCommand::Allocate {\n                tokens,\n                prefill_tokens,\n                response_sender,\n            } => {\n                response_sender\n                    .send(allocator.allocate(tokens, prefill_tokens))\n                    .unwrap();\n            }\n        }\n    }\n}\n\n#[derive(Debug)]\nenum BlockAllocatorCommand {\n    Free {\n        blocks: Vec<u32>,\n        allocation_id: u64,\n    },\n    Allocate {\n        tokens: u32,\n        prefill_tokens: Option<Arc<Vec<u32>>>,\n        response_sender: oneshot::Sender<Option<BlockAllocation>>,\n    },\n}\n\npub trait Allocator {\n    fn allocate(\n        &mut self,\n        tokens: u32,\n        prefill_tokens: Option<Arc<Vec<u32>>>,\n    ) -> Option<BlockAllocation>;\n\n    fn free(&mut self, blocks: Vec<u32>, allocation_id: u64);\n}\npub struct SimpleAllocator {\n    free_blocks: Vec<u32>,\n    block_size: u32,\n    
window_size: Option<u32>,\n    is_hpu_device: bool,\n}\n\nimpl SimpleAllocator {\n    fn new(blocks: u32, block_size: u32, window_size: Option<u32>) -> Self {\n        SimpleAllocator {\n            block_size,\n            // Block 0 is reserved for health checks\n            free_blocks: (1..blocks).collect(),\n            window_size,\n            is_hpu_device: Env::new().is_hpu_device(),\n        }\n    }\n}\n\nimpl Allocator for SimpleAllocator {\n    fn allocate(\n        &mut self,\n        tokens: u32,\n        _prefill_tokens: Option<Arc<Vec<u32>>>,\n    ) -> Option<BlockAllocation> {\n        let mut tokens = tokens;\n        if self.is_hpu_device {\n            // need 1 slot for ping-pong optimization\n            tokens += 1;\n        }\n        // Apply window size\n        let (required_blocks, repeats) = {\n            let (tokens, repeats) = match self.window_size {\n                None => (tokens, 1),\n                Some(window_size) => {\n                    let repeats = tokens.div_ceil(window_size);\n                    let tokens = core::cmp::min(tokens, window_size);\n                    (tokens, repeats as usize)\n                }\n            };\n            // Pad to a multiple of block size\n            let required_blocks = tokens.div_ceil(self.block_size);\n            (required_blocks, repeats)\n        };\n        let tokens = tokens as usize;\n        if required_blocks > self.free_blocks.len() as u32 {\n            None\n        } else {\n            if self.is_hpu_device {\n                self.free_blocks.sort_by(|a, b| b.cmp(a));\n            }\n            let mut blocks = self\n                .free_blocks\n                .split_off(self.free_blocks.len() - required_blocks as usize);\n            if self.is_hpu_device {\n                blocks.sort();\n            }\n            let mut slots =\n                Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);\n\n            'slots: for 
block_id in blocks.repeat(repeats).iter() {\n                for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {\n                    slots.push(s);\n                    if slots.len() == tokens {\n                        break 'slots;\n                    }\n                }\n            }\n            Some(BlockAllocation {\n                allocation_id: 0,\n                blocks,\n                slots,\n                prefix_len: 0,\n                block_allocator: None,\n            })\n        }\n    }\n\n    fn free(&mut self, blocks: Vec<u32>, _allocation_id: u64) {\n        self.free_blocks.extend(blocks)\n    }\n}\n"
  },
  {
    "path": "backends/v3/src/client/grpc_client.rs",
    "content": "/// Single shard Client\nuse crate::client::{pb, Chunk};\nuse crate::client::{ClientError, Result, WARMUP_IMAGE_BASE64};\nuse base64::engine::general_purpose::STANDARD;\nuse base64::Engine;\nuse grpc_metadata::InjectTelemetryContext;\nuse pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;\nuse pb::generate::v3::*;\nuse std::cmp::min;\nuse std::time::Duration;\nuse tonic::transport::{Channel, Uri};\nuse tracing::instrument;\n\n/// Text Generation Inference gRPC client\n#[derive(Debug, Clone)]\npub struct Client {\n    stub: TextGenerationServiceClient<Channel>,\n}\n\nimpl Client {\n    /// Returns a client connected to the given url\n    #[allow(dead_code)]\n    pub async fn connect(uri: Uri) -> Result<Self> {\n        let channel = Channel::builder(uri).connect().await?;\n\n        Ok(Self {\n            stub: TextGenerationServiceClient::new(channel),\n        })\n    }\n\n    /// Returns a client connected to the given unix socket\n    pub async fn connect_uds(path: String) -> Result<Self> {\n        let channel = Channel::from_shared(\"http://[::]:50051\".to_string())\n            .unwrap()\n            .connect_with_connector(tower::service_fn(move |_: Uri| {\n                tokio::net::UnixStream::connect(path.clone())\n            }))\n            .await?;\n\n        Ok(Self {\n            stub: TextGenerationServiceClient::new(channel),\n        })\n    }\n\n    /// Returns a list of uris or unix sockets of all shards\n    #[instrument(skip(self))]\n    pub async fn service_discovery(&mut self) -> Result<Vec<String>> {\n        let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context();\n        let response = self.stub.service_discovery(request).await.map_err(|_| {\n            ClientError::Connection(\"Server does not support v3 interface\".to_string())\n        })?;\n        let urls = response\n            .into_inner()\n            .urls\n            .into_iter()\n            // Remove 
unix socket prefix\n            .map(|url| match url.strip_prefix(\"unix://\") {\n                None => url,\n                Some(stripped_url) => stripped_url.to_string(),\n            })\n            .collect();\n        Ok(urls)\n    }\n\n    /// Get model info\n    #[instrument(skip(self))]\n    pub async fn info(&mut self) -> Result<InfoResponse> {\n        let request = tonic::Request::new(InfoRequest {}).inject_context();\n        let response = self.stub.info(request).await?.into_inner();\n        Ok(response)\n    }\n\n    /// Get model health\n    #[instrument(skip(self))]\n    pub async fn health(&mut self) -> Result<HealthResponse> {\n        let request = tonic::Request::new(HealthRequest {}).inject_context();\n        let response = self.stub.health(request).await?.into_inner();\n        Ok(response)\n    }\n\n    /// Clear the past generations cache\n    #[instrument(skip(self))]\n    pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {\n        let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context();\n        self.stub.clear_cache(request).await?;\n        Ok(())\n    }\n\n    /// Filter a cached batch\n    #[instrument(skip(self))]\n    pub async fn filter_batch(\n        &mut self,\n        batch_id: u64,\n        request_ids: Vec<u64>,\n    ) -> Result<Option<CachedBatch>> {\n        let request = tonic::Request::new(FilterBatchRequest {\n            batch_id,\n            request_ids,\n        })\n        .inject_context();\n        let filtered_batch = self.stub.filter_batch(request).await?.into_inner();\n        Ok(filtered_batch.batch)\n    }\n\n    /// Warmup on a max size batch\n    ///\n    /// Returns the maximum amount of tokens supported by the hardware\n    #[instrument(skip_all)]\n    pub async fn warmup(\n        &mut self,\n        max_input_tokens: Option<u32>,\n        max_prefill_tokens: u32,\n        max_total_tokens: Option<u32>,\n        max_batch_size: Option<usize>,\n   
 ) -> Result<(Option<u32>, u32, u32)> {\n        let mut n_tokens = 0;\n        let mut requests = Vec::new();\n        // Create requests\n        while n_tokens < max_prefill_tokens {\n            let mut truncate = max_prefill_tokens - n_tokens;\n            if let Some(max_input_tokens) = max_input_tokens {\n                truncate = min(max_input_tokens, truncate);\n            }\n\n            let mut input_chunks = Vec::new();\n            input_chunks.push(Chunk::Text(\"_test \".to_string().repeat(truncate as usize)).into());\n            if n_tokens == 0 {\n                input_chunks.push(\n                    Chunk::Image(Image {\n                        // Safe unwrap, because we control the data.\n                        data: STANDARD.decode(WARMUP_IMAGE_BASE64).unwrap(),\n                        mimetype: \"image/jpeg;base64\".to_string(),\n                    })\n                    .into(),\n                );\n            }\n\n            // Send stringly-typed inputs for compatibility for backends that haven't\n            // been updated to support chunks.\n\n            let mut inputs = String::new();\n            inputs.push_str(&\"_test \".to_string().repeat(truncate as usize));\n            if n_tokens == 0 {\n                // 1 request is enough to test vision heads.\n                // Sending images on other queries messes up easily with truncation.\n                inputs.push_str(&format!(\n                    \"![](data:image/jpeg;base64,{WARMUP_IMAGE_BASE64})\",\n                ));\n            }\n\n            let max_new_tokens = if let Some(max_total_tokens) = max_total_tokens {\n                max_total_tokens - truncate\n            } else {\n                1\n            };\n\n            requests.push(Request {\n                id: 0,\n                inputs,\n                add_special_tokens: true,\n                input_chunks: Some(Input {\n                    chunks: input_chunks,\n                }),\n             
   // We truncate the input on the server side to be sure that it has the correct size\n                truncate,\n                // Blocks and slots will be set on the server side if we use paged attention\n                blocks: vec![],\n                slots: vec![],\n                cache_len: 0,\n                chunk_len: None,\n                // Set sampling parameters to also take these ops into account in the max memory\n                parameters: Some(NextTokenChooserParameters {\n                    temperature: 0.9,\n                    top_k: 10,\n                    top_p: 0.9,\n                    typical_p: 0.9,\n                    do_sample: false,\n                    seed: 0,\n                    repetition_penalty: 1.2,\n                    frequency_penalty: 0.1,\n                    watermark: true,\n                    grammar: String::new(),\n                    grammar_type: GrammarType::None as i32,\n                }),\n                stopping_parameters: Some(StoppingCriteriaParameters {\n                    max_new_tokens,\n                    stop_sequences: vec![],\n                    ignore_eos_token: true,\n                }),\n                prefill_logprobs: true,\n                top_n_tokens: 20,\n                adapter_id: None,\n            });\n            n_tokens += truncate;\n\n            // Check max_batch_size\n            if Some(requests.len()) == max_batch_size {\n                break;\n            }\n        }\n\n        let batch = Batch {\n            id: 0,\n            size: requests.len() as u32,\n            requests,\n            max_tokens: max_input_tokens.unwrap_or(0),\n            max_blocks: 0,\n        };\n\n        let request = tonic::Request::new(WarmupRequest {\n            batch: Some(batch),\n            max_input_tokens,\n            max_prefill_tokens,\n            max_total_tokens,\n        })\n        .inject_context();\n        let response = 
self.stub.warmup(request).await?.into_inner();\n        Ok((\n            response.max_supported_total_tokens,\n            response.max_input_tokens,\n            response.max_total_tokens,\n        ))\n    }\n\n    /// Generate one token for each request in the given batch\n    ///\n    /// Returns Generation for each request in batch\n    /// and the next cached batch\n    #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]\n    pub async fn prefill(\n        &mut self,\n        batch: Batch,\n        cached_batch: Option<CachedBatch>,\n    ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {\n        let request = tonic::Request::new(PrefillRequest {\n            batch: Some(batch),\n            cached_batch,\n        })\n        .inject_context();\n        let response = self.stub.prefill(request).await?.into_inner();\n        Ok((\n            response.generations,\n            response.batch,\n            PrefillTimings::new(\n                response.concat_ns,\n                response.forward_ns,\n                response.decode_ns,\n                response.total_ns,\n            ),\n        ))\n    }\n\n    /// Generate one token for each request in the given cached batches\n    ///\n    /// Returns Generation for each request in batches\n    /// and the next cached batch\n    #[instrument(skip_all, fields(size = batches.iter().map(|batch|{batch.size}).sum::<u32>()))]\n    pub async fn decode(\n        &mut self,\n        batches: Vec<CachedBatch>,\n    ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {\n        let request = tonic::Request::new(DecodeRequest { batches }).inject_context();\n        let response = self.stub.decode(request).await?.into_inner();\n        Ok((\n            response.generations,\n            response.batch,\n            DecodeTimings::new(\n                response.concat_ns,\n                response.forward_ns,\n                response.decode_ns,\n                
response.total_ns,\n            ),\n        ))\n    }\n}\n\npub struct PrefillTimings {\n    pub concat: Option<Duration>,\n    pub forward: Duration,\n    pub decode: Duration,\n    pub total: Duration,\n}\n\nimpl PrefillTimings {\n    fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {\n        Self {\n            concat: concat_ns.map(Duration::from_nanos),\n            forward: Duration::from_nanos(forward_ns),\n            decode: Duration::from_nanos(decode_ns),\n            total: Duration::from_nanos(total_ns),\n        }\n    }\n}\n\npub struct DecodeTimings {\n    pub concat: Option<Duration>,\n    pub forward: Duration,\n    pub decode: Duration,\n    pub total: Duration,\n}\n\nimpl DecodeTimings {\n    fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {\n        Self {\n            concat: concat_ns.map(Duration::from_nanos),\n            forward: Duration::from_nanos(forward_ns),\n            decode: Duration::from_nanos(decode_ns),\n            total: Duration::from_nanos(total_ns),\n        }\n    }\n}\n"
  },
  {
    "path": "backends/v3/src/client/mod.rs",
    "content": "//! Text Generation gRPC client library\n\nuse async_trait::async_trait;\nuse thiserror::Error;\nuse tonic::transport;\nuse tonic::Status;\n\n#[allow(clippy::derive_partial_eq_without_eq)]\nmod pb;\n\nmod grpc_client;\nmod sharded_client;\n\npub use grpc_client::Client;\npub use pb::generate::v3::{\n    input_chunk::Chunk, Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,\n    HealthResponse, Image, InfoResponse, Input, InputChunk, NextTokenChooserParameters, Request,\n    StoppingCriteriaParameters,\n};\npub use sharded_client::ShardedClient;\n\n#[async_trait]\npub trait Health {\n    /// Check if a generate server is healthy by asking it to allocate a tensor on device\n    async fn device_health(&self) -> Result<()>;\n\n    /// Check if a generate server is healthy by doing a forward pass.\n    /// EXPENSIVE\n    async fn model_health(&self) -> Result<()>;\n}\n\n#[derive(Error, Debug, Clone)]\npub enum ClientError {\n    #[error(\"Could not connect to Text Generation server: {0}\")]\n    Connection(String),\n    #[error(\"Server error: {0}\")]\n    Generation(String),\n    #[error(\"Sharded results are empty\")]\n    EmptyResults,\n}\n\nimpl From<Status> for ClientError {\n    fn from(err: Status) -> Self {\n        let err = Self::Generation(err.message().to_string());\n        tracing::error!(\"{err}\");\n        err\n    }\n}\n\nimpl From<transport::Error> for ClientError {\n    fn from(err: transport::Error) -> Self {\n        let err = Self::Connection(err.to_string());\n        tracing::error!(\"{err}\");\n        err\n    }\n}\n\n// Small convenience re-wrapping of `Chunk`.\nimpl From<Chunk> for InputChunk {\n    fn from(chunk: Chunk) -> Self {\n        InputChunk { chunk: Some(chunk) }\n    }\n}\n\nstatic WARMUP_IMAGE_BASE64 :&str = 
\"iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAABg2lDQ1BJQ0MgcHJvZmlsZQAAKJF9kT1Iw0AcxV/TSotUROxQxCFDdbKLijjWKhShQqgVWnUwufQLmrQkKS6OgmvBwY/FqoOLs64OroIg+AHi7OCk6CIl/i8ptIjx4Lgf7+497t4BQqvKNDOQADTdMjKppJjLr4rBVwQQwhAERGVm1uckKQ3P8XUPH1/v4jzL+9yfY0AtmAzwicQJVjcs4g3imU2rznmfOMLKskp8Tjxh0AWJH7muuPzGueSwwDMjRjYzTxwhFks9rPQwKxsa8TRxTNV0yhdyLquctzhr1Qbr3JO/MFzQV5a5TnMUKSxiCRJEKGiggiosxGnVSTGRof2kh3/E8UvkUshVASPHAmrQIDt+8D/43a1ZnJp0k8JJoO/Ftj/GgOAu0G7a9vexbbdPAP8zcKV3/bUWMPtJerOrxY6AwW3g4rqrKXvA5Q4QfarLhuxIfppCsQi8n9E35YHhW6B/ze2ts4/TByBLXaVvgINDYLxE2ese7w719vbvmU5/PycecohsjayNAAAACXBIWXMAAC4jAAAuIwF4pT92AAAAB3RJTUUH6AQIEQMnlTSSjwAAABl0RVh0Q29tbWVudABDcmVhdGVkIHdpdGggR0lNUFeBDhcAAAASSURBVDjLY2AYBaNgFIyCoQsABMQAAeRw1DoAAAAASUVORK5CYII=\";\n\npub type Result<T> = std::result::Result<T, ClientError>;\n"
  },
  {
    "path": "backends/v3/src/client/sharded_client.rs",
    "content": "use crate::client::Health;\n/// Multi shard Client\nuse crate::client::{ClientError, Result};\n\nuse crate::client::grpc_client::{DecodeTimings, PrefillTimings};\nuse crate::client::{\n    Batch, CachedBatch, Client, Generation, GrammarType, HealthResponse,\n    NextTokenChooserParameters, Request, StoppingCriteriaParameters,\n};\nuse crate::client::{Chunk, InfoResponse, Input};\nuse async_trait::async_trait;\nuse futures::future::join_all;\nuse tonic::transport::Uri;\nuse tracing::instrument;\n\n#[derive(Debug, Clone)]\n/// Text Generation Inference gRPC multi client\npub struct ShardedClient {\n    clients: Vec<Client>,\n}\n\nimpl ShardedClient {\n    fn new(clients: Vec<Client>) -> Self {\n        Self { clients }\n    }\n\n    /// Create a new ShardedClient from a master client. The master client will communicate with\n    /// the other shards and returns all uris/unix sockets with the `service_discovery` gRPC method.\n    async fn from_master_client(mut master_client: Client) -> Result<Self> {\n        // Get all uris/unix sockets from the master client\n        let uris = master_client.service_discovery().await?;\n        let futures = uris.into_iter().map(Client::connect_uds);\n        let clients: Result<Vec<Client>> = join_all(futures).await.into_iter().collect();\n        Ok(Self::new(clients?))\n    }\n\n    /// Returns a client connected to the given uri\n    #[allow(dead_code)]\n    pub async fn connect(uri: Uri) -> Result<Self> {\n        let master_client = Client::connect(uri).await?;\n        Self::from_master_client(master_client).await\n    }\n\n    /// Returns a client connected to the given unix socket\n    pub async fn connect_uds(path: String) -> Result<Self> {\n        let master_client = Client::connect_uds(path).await?;\n        Self::from_master_client(master_client).await\n    }\n\n    /// Get the model info\n    #[instrument(skip(self))]\n    pub async fn info(&mut self) -> Result<InfoResponse> {\n        let futures: 
Vec<_> = self\n            .clients\n            .iter_mut()\n            .map(|client| client.info())\n            .collect();\n        join_all(futures).await.pop().unwrap()\n    }\n\n    /// GRPC health check\n    #[instrument(skip(self))]\n    pub async fn health(&mut self) -> Result<HealthResponse> {\n        let futures: Vec<_> = self\n            .clients\n            .iter_mut()\n            .map(|client| client.health())\n            .collect();\n        join_all(futures).await.pop().unwrap()\n    }\n\n    /// Clear the past generations cache\n    #[instrument(skip(self))]\n    pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {\n        let futures: Vec<_> = self\n            .clients\n            .iter_mut()\n            .map(|client| client.clear_cache(batch_id))\n            .collect();\n        join_all(futures).await.into_iter().collect()\n    }\n\n    /// Filter a cached batch\n    #[instrument(skip(self))]\n    pub async fn filter_batch(\n        &mut self,\n        batch_id: u64,\n        request_ids: Vec<u64>,\n    ) -> Result<Option<CachedBatch>> {\n        let futures: Vec<_> = self\n            .clients\n            .iter_mut()\n            .map(|client| Box::pin(client.filter_batch(batch_id, request_ids.clone())))\n            .collect();\n        // all shards return the same message\n        join_all(futures).await.pop().unwrap()\n    }\n\n    /// Warmup on a max size batch\n    ///\n    /// Returns the maximum amount of tokens supported by the hardware\n    #[instrument(skip(self))]\n    pub async fn warmup(\n        &mut self,\n        max_input_length: Option<u32>,\n        max_prefill_tokens: u32,\n        max_total_tokens: Option<u32>,\n        max_batch_size: Option<usize>,\n    ) -> Result<(Option<u32>, u32, u32)> {\n        let futures: Vec<_> = self\n            .clients\n            .iter_mut()\n            .map(|client| {\n                Box::pin(client.warmup(\n                    max_input_length,\n      
              max_prefill_tokens,\n                    max_total_tokens,\n                    max_batch_size,\n                ))\n            })\n            .collect();\n        let results = join_all(futures)\n            .await\n            .into_iter()\n            .collect::<Result<Vec<(Option<u32>, u32, u32)>>>()?;\n\n        // Take the minimum value\n        // Different shards hold different parts of vocab, might yield\n        // different available block size.\n        let min = results\n            .iter()\n            .min()\n            .expect(\"Expect at least 1 warmup result\");\n        Ok(*min)\n    }\n\n    /// Generate one token for each request in the given batch\n    ///\n    /// Returns Generation for each request in batch\n    /// and the next cached batch\n    #[instrument(skip_all, fields(id = & batch.id, size = & batch.size))]\n    pub async fn prefill(\n        &mut self,\n        batch: Batch,\n        cached_batch: Option<CachedBatch>,\n    ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {\n        let futures: Vec<_> = self\n            .clients\n            .iter_mut()\n            .map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))\n            .collect();\n        #[allow(clippy::type_complexity)]\n        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =\n            join_all(futures).await.into_iter().collect();\n        let mut results = results?;\n\n        let (mut generations, next_batch, mut timings) =\n            results.pop().ok_or(ClientError::EmptyResults)?;\n\n        // Merge generations from different model shards\n        for (mut shard_generations, _, shard_timings) in results.into_iter() {\n            generations.append(&mut shard_generations);\n            // Return the timings of the slowest shard\n            if shard_timings.total > timings.total {\n                timings = shard_timings;\n            }\n        }\n        
Ok((generations, next_batch, timings))\n    }\n\n    /// Generate one token for each request in the given cached batches\n    ///\n    /// Returns Generation for each request in batches\n    /// and the next cached batch\n    #[instrument(skip_all, fields(size = batches.iter().map(| batch | {batch.size}).sum::< u32 > ()))]\n    pub async fn decode(\n        &mut self,\n        batches: Vec<CachedBatch>,\n    ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {\n        let futures: Vec<_> = self\n            .clients\n            .iter_mut()\n            .map(|client| Box::pin(client.decode(batches.clone())))\n            .collect();\n        #[allow(clippy::type_complexity)]\n        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =\n            join_all(futures).await.into_iter().collect();\n        let mut results = results?;\n\n        let (mut generations, next_batch, mut timings) =\n            results.pop().ok_or(ClientError::EmptyResults)?;\n\n        // Merge generations from different model shards\n        for (mut shard_generations, _, shard_timings) in results.into_iter() {\n            generations.append(&mut shard_generations);\n            // Return the timings of the slowest shard\n            if shard_timings.total > timings.total {\n                timings = shard_timings;\n            }\n        }\n        Ok((generations, next_batch, timings))\n    }\n}\n\n#[async_trait]\nimpl Health for ShardedClient {\n    async fn device_health(&self) -> Result<()> {\n        self.clone().health().await?;\n        Ok(())\n    }\n\n    async fn model_health(&self) -> Result<()> {\n        // Dummy batch of 1 token and 1 generated token\n        let liveness_request = Request {\n            id: u64::MAX,\n            inputs: \"liveness\".to_string(),\n            input_chunks: Some(Input {\n                chunks: vec![Chunk::Text(\"liveness\".into()).into()],\n            }),\n            truncate: 1,\n           
 add_special_tokens: false,\n            prefill_logprobs: false,\n            parameters: Some(NextTokenChooserParameters {\n                temperature: 1.0,\n                top_k: 0,\n                top_p: 1.0,\n                typical_p: 1.0,\n                do_sample: false,\n                seed: 0,\n                repetition_penalty: 1.0,\n                frequency_penalty: 0.0,\n                watermark: false,\n                grammar: String::new(),\n                grammar_type: GrammarType::None as i32,\n            }),\n            stopping_parameters: Some(StoppingCriteriaParameters {\n                max_new_tokens: 1,\n                stop_sequences: vec![],\n                ignore_eos_token: false,\n            }),\n            top_n_tokens: 0,\n            // Block 0 is reserved for health checks\n            blocks: vec![0],\n            slots: vec![0],\n            cache_len: 0,\n            adapter_id: None,\n            chunk_len: None,\n        };\n        let batch = Batch {\n            id: u64::MAX,\n            requests: vec![liveness_request],\n            size: 1,\n            max_tokens: 2,\n            max_blocks: 1,\n        };\n        self.clone().prefill(batch, None).await?;\n        Ok(())\n    }\n}\n"
  },
  {
    "path": "backends/v3/src/lib.rs",
    "content": "mod backend;\npub mod block_allocator;\nmod client;\nmod queue;\npub mod radix;\n\nuse crate::client::{ClientError, ShardedClient};\npub(crate) use backend::BackendV3;\nuse serde::Serialize;\nuse thiserror::Error;\nuse utoipa::ToSchema;\n\n#[derive(Clone, Debug, Serialize, ToSchema)]\npub struct BackendInfo {\n    /// Mandatory\n    #[schema(example = \"cuda\")]\n    pub model_device_type: String,\n    #[schema(example = \"torch.float16\")]\n    pub model_dtype: String,\n\n    /// Backend parameters\n    #[schema(example = \"1\")]\n    pub speculate: usize,\n    #[schema(example = \"1.2\")]\n    pub waiting_served_ratio: f32,\n    #[schema(example = \"32000\")]\n    pub max_batch_total_tokens: u32,\n    #[schema(example = \"20\")]\n    pub max_waiting_tokens: usize,\n    #[schema(nullable = true, example = \"null\")]\n    pub max_batch_size: Option<usize>,\n    #[schema(example = \"false\")]\n    pub support_chunking: bool,\n    #[schema(example = \"false\")]\n    pub prefix_caching: bool,\n    #[schema(example = \"flashinfer\")]\n    pub attention_impl: String,\n    #[schema(example = \"1\")]\n    pub block_size: u32,\n\n    #[schema(example = \"30000\")]\n    pub max_input_tokens: usize,\n    #[schema(example = \"32000\")]\n    pub max_total_tokens: usize,\n}\n\n#[allow(clippy::too_many_arguments)]\npub async fn connect_backend(\n    max_input_tokens: Option<usize>,\n    max_total_tokens: Option<usize>,\n    master_shard_uds_path: String,\n    waiting_served_ratio: f32,\n    max_batch_prefill_tokens: u32,\n    max_batch_total_tokens: Option<u32>,\n    max_waiting_tokens: usize,\n    max_batch_size: Option<usize>,\n) -> Result<(BackendV3, BackendInfo), V3Error> {\n    // Helper function\n    let check_max_batch_total_tokens = |(\n        max_supported_batch_total_tokens,\n        shard_max_input_tokens,\n        shard_max_total_tokens,\n    ): (Option<u32>, u32, u32)|\n     -> Result<(u32, usize, usize), V3Error> {\n        if let 
Some(max_input_tokens) = max_input_tokens {\n            assert_eq!(max_input_tokens as u32, shard_max_input_tokens);\n        }\n        if let Some(max_total_tokens) = max_total_tokens {\n            assert_eq!(max_total_tokens as u32, shard_max_total_tokens);\n        }\n        match max_supported_batch_total_tokens {\n            // Older models do not support automatic max-batch-total-tokens\n            None => {\n                let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(\n                    16000\n                        .max(shard_max_total_tokens)\n                        .max(max_batch_prefill_tokens),\n                );\n                tracing::warn!(\"Model does not support automatic max batch total tokens\");\n                Ok((\n                    max_batch_total_tokens,\n                    shard_max_input_tokens as usize,\n                    shard_max_total_tokens as usize,\n                ))\n            }\n            // Flash attention models return their max supported total tokens\n            Some(max_supported_batch_total_tokens) => {\n                // Warn if user added his own max-batch-total-tokens as we will ignore it\n                if max_batch_total_tokens.is_some() {\n                    tracing::warn!(\n                        \"`--max-batch-total-tokens` is deprecated for Flash \\\n                        Attention models.\"\n                    );\n                    tracing::warn!(\n                        \"Inferred max batch total tokens: {max_supported_batch_total_tokens}\"\n                    );\n                }\n                if shard_max_total_tokens > max_supported_batch_total_tokens {\n                    return Err(V3Error::NotEnoughMemory(shard_max_total_tokens as usize));\n                }\n\n                Ok((\n                    max_supported_batch_total_tokens,\n                    shard_max_input_tokens as usize,\n                    shard_max_total_tokens as usize,\n         
       ))\n            }\n        }\n    };\n\n    let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)\n        .await\n        .map_err(V3Error::Connection)?;\n\n    // server is running on v3\n    // Clear the cache; useful if the webserver rebooted\n    sharded_client\n        .clear_cache(None)\n        .await\n        .map_err(V3Error::Cache)?;\n    // Get info from the shard\n    let shard_info = sharded_client.info().await.map_err(V3Error::Info)?;\n\n    // Warmup model\n    tracing::info!(\"Warming up model\");\n    let answer = sharded_client\n        .warmup(\n            max_input_tokens.map(|p| p as u32),\n            max_batch_prefill_tokens,\n            max_total_tokens.map(|p| p as u32),\n            max_batch_size,\n        )\n        .await\n        .map_err(V3Error::Warmup)?;\n    let (max_batch_total_tokens, max_input_tokens, max_total_tokens) =\n        check_max_batch_total_tokens(answer)?;\n    tracing::info!(\"Setting max batch total tokens to {max_batch_total_tokens}\");\n    metrics::gauge!(\"tgi_batch_max_total_tokens\").set(max_batch_total_tokens);\n\n    let backend_info = BackendInfo {\n        waiting_served_ratio,\n        max_batch_total_tokens,\n        max_input_tokens,\n        max_total_tokens,\n        max_waiting_tokens,\n        max_batch_size,\n        model_device_type: shard_info.device_type.clone(),\n        model_dtype: shard_info.dtype.clone(),\n        speculate: shard_info.speculate as usize,\n        support_chunking: shard_info.support_chunking,\n        prefix_caching: shard_info.use_prefix_caching,\n        attention_impl: shard_info.attention_impl.clone(),\n        block_size: shard_info.block_size,\n    };\n\n    let backend = BackendV3::new(\n        sharded_client,\n        waiting_served_ratio,\n        max_batch_prefill_tokens,\n        max_batch_total_tokens,\n        max_waiting_tokens,\n        max_batch_size,\n        shard_info,\n    );\n\n    tracing::info!(\"Using backend 
V3\");\n\n    Ok((backend, backend_info))\n}\n\n#[derive(Debug, Error)]\npub enum V3Error {\n    #[error(\"Unable to clear the Python model shards cache: {0}\")]\n    Cache(ClientError),\n    #[error(\"Unable to connect to the Python model shards: {0}\")]\n    Connection(ClientError),\n    #[error(\"Unable to get the Python model shards info: {0}\")]\n    Info(ClientError),\n    #[error(\"Unable to warmup the Python model shards: {0}\")]\n    Warmup(ClientError),\n    #[error(\"Not enough memory to handle `max_total_tokens={0}`\")]\n    NotEnoughMemory(usize),\n}\n"
  },
  {
    "path": "backends/v3/src/main.rs",
    "content": "use clap::{Parser, Subcommand};\nuse text_generation_router::{server, usage_stats};\nuse text_generation_router_v3::{connect_backend, V3Error};\nuse thiserror::Error;\n\n/// App Configuration\n#[derive(Parser, Debug)]\n#[clap(author, version, about, long_about = None)]\nstruct Args {\n    #[command(subcommand)]\n    command: Option<Commands>,\n\n    #[clap(default_value = \"128\", long, env)]\n    max_concurrent_requests: usize,\n    #[clap(default_value = \"2\", long, env)]\n    max_best_of: usize,\n    #[clap(default_value = \"4\", long, env)]\n    max_stop_sequences: usize,\n    #[clap(default_value = \"5\", long, env)]\n    max_top_n_tokens: u32,\n    #[clap(long, env)]\n    max_input_tokens: Option<usize>,\n    #[clap(long, env)]\n    max_total_tokens: Option<usize>,\n    #[clap(default_value = \"1.2\", long, env)]\n    waiting_served_ratio: f32,\n    #[clap(default_value = \"4096\", long, env)]\n    max_batch_prefill_tokens: u32,\n    #[clap(long, env)]\n    max_batch_total_tokens: Option<u32>,\n    #[clap(default_value = \"20\", long, env)]\n    max_waiting_tokens: usize,\n    #[clap(long, env)]\n    max_batch_size: Option<usize>,\n    #[clap(default_value = \"0.0.0.0\", long, env)]\n    hostname: String,\n    #[clap(default_value = \"3000\", long, short, env)]\n    port: u16,\n    #[clap(default_value = \"9000\", long, short, env)]\n    prometheus_port: u16,\n    #[clap(default_value = \"/tmp/text-generation-server-0\", long, env)]\n    master_shard_uds_path: String,\n    #[clap(default_value = \"bigscience/bloom\", long, env)]\n    tokenizer_name: String,\n    #[clap(long, env)]\n    tokenizer_config_path: Option<String>,\n    #[clap(long, env)]\n    revision: Option<String>,\n    #[clap(long, env, value_enum)]\n    trust_remote_code: bool,\n    #[clap(default_value = \"2\", long, env)]\n    validation_workers: usize,\n    #[clap(long, env)]\n    api_key: Option<String>,\n    #[clap(long, env)]\n    json_output: bool,\n    #[clap(long, 
env)]\n    otlp_endpoint: Option<String>,\n    #[clap(default_value = \"text-generation-inference.router\", long, env)]\n    otlp_service_name: String,\n    #[clap(long, env)]\n    cors_allow_origin: Option<Vec<String>>,\n    #[clap(long, env)]\n    ngrok: bool,\n    #[clap(long, env)]\n    ngrok_authtoken: Option<String>,\n    #[clap(long, env)]\n    ngrok_edge: Option<String>,\n    #[clap(long, env, default_value_t = false)]\n    disable_grammar_support: bool,\n    #[clap(default_value = \"4\", long, env)]\n    max_client_batch_size: usize,\n    #[clap(default_value = \"on\", long, env)]\n    usage_stats: usage_stats::UsageStatsLevel,\n    #[clap(default_value = \"2000000\", long, env)]\n    payload_limit: usize,\n    #[clap(default_value = \"1073741824\", long, env)]\n    max_image_fetch_size: usize,\n}\n\n#[derive(Debug, Subcommand)]\nenum Commands {\n    PrintSchema,\n}\n\n#[tokio::main]\nasync fn main() -> Result<(), RouterError> {\n    // Get args\n    let args = Args::parse();\n    // Pattern match configuration\n    let Args {\n        command,\n        max_concurrent_requests,\n        max_best_of,\n        max_stop_sequences,\n        max_top_n_tokens,\n        max_input_tokens,\n        max_total_tokens,\n        waiting_served_ratio,\n        max_batch_prefill_tokens,\n        max_batch_total_tokens,\n        max_waiting_tokens,\n        max_batch_size,\n        hostname,\n        port,\n        prometheus_port,\n        master_shard_uds_path,\n        tokenizer_name,\n        tokenizer_config_path,\n        revision,\n        trust_remote_code,\n        validation_workers,\n        api_key,\n        json_output,\n        otlp_endpoint,\n        otlp_service_name,\n        cors_allow_origin,\n        ngrok,\n        ngrok_authtoken,\n        ngrok_edge,\n        disable_grammar_support,\n        max_client_batch_size,\n        usage_stats,\n        payload_limit,\n        max_image_fetch_size,\n    } = args;\n\n    if let Some(Commands::PrintSchema) = 
command {\n        use utoipa::OpenApi;\n        let api_doc = text_generation_router::server::ApiDoc::openapi();\n        let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();\n        println!(\"{}\", api_doc);\n        std::process::exit(0);\n    };\n    text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);\n\n    // Validate args\n    if validation_workers == 0 {\n        return Err(RouterError::ArgumentValidation(\n            \"`validation_workers` must be > 0\".to_string(),\n        ));\n    }\n    if let Some(max_batch_size) = max_batch_size {\n        if max_batch_size == 0 {\n            return Err(RouterError::ArgumentValidation(\n                \"`max_batch_size` must be > 0\".to_string(),\n            ));\n        }\n    }\n\n    let (backend, backend_info) = connect_backend(\n        max_input_tokens,\n        max_total_tokens,\n        master_shard_uds_path,\n        waiting_served_ratio,\n        max_batch_prefill_tokens,\n        max_batch_total_tokens,\n        max_waiting_tokens,\n        max_batch_size,\n    )\n    .await?;\n\n    // Validate remaining args now that the backend is known\n    let support_chunking = backend_info.support_chunking;\n    let max_batch_total_tokens = backend_info.max_batch_total_tokens;\n\n    if max_input_tokens.is_none() {\n        tracing::info!(\n            \"Maximum input tokens defaulted to {}\",\n            backend_info.max_input_tokens\n        );\n    }\n    if max_total_tokens.is_none() {\n        tracing::info!(\n            \"Maximum total tokens defaulted to {}\",\n            backend_info.max_total_tokens\n        );\n    }\n\n    let max_input_tokens = backend_info.max_input_tokens;\n    let max_total_tokens = backend_info.max_total_tokens;\n    if max_input_tokens >= max_total_tokens {\n        return Err(RouterError::ArgumentValidation(\n            \"`max_input_tokens` must be < `max_total_tokens`\".to_string(),\n        ));\n    }\n\n    if 
max_input_tokens as u32 > max_batch_prefill_tokens && !support_chunking {\n        return Err(RouterError::ArgumentValidation(format!(\"`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}\")));\n    }\n    if max_batch_prefill_tokens > max_batch_total_tokens {\n        return Err(RouterError::ArgumentValidation(format!(\"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}\")));\n    }\n    if max_total_tokens as u32 > max_batch_total_tokens {\n        return Err(RouterError::ArgumentValidation(format!(\"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}\")));\n    }\n\n    // Run server\n    server::run(\n        backend,\n        max_concurrent_requests,\n        max_best_of,\n        max_stop_sequences,\n        max_top_n_tokens,\n        max_input_tokens,\n        max_total_tokens,\n        validation_workers,\n        api_key,\n        tokenizer_name,\n        tokenizer_config_path,\n        revision,\n        trust_remote_code,\n        hostname,\n        port,\n        cors_allow_origin,\n        ngrok,\n        ngrok_authtoken,\n        ngrok_edge,\n        disable_grammar_support,\n        max_client_batch_size,\n        usage_stats,\n        payload_limit,\n        max_image_fetch_size,\n        prometheus_port,\n    )\n    .await?;\n    Ok(())\n}\n\n#[derive(Debug, Error)]\nenum RouterError {\n    #[error(\"Argument validation error: {0}\")]\n    ArgumentValidation(String),\n    #[error(\"Backend failed: {0}\")]\n    Backend(#[from] V3Error),\n    #[error(\"WebServer error: {0}\")]\n    WebServer(#[from] server::WebServerError),\n    #[error(\"Tokio runtime failed to start: {0}\")]\n    Tokio(#[from] std::io::Error),\n}\n"
  },
  {
    "path": "backends/v3/src/queue.rs",
    "content": "use crate::block_allocator::{BlockAllocation, BlockAllocator};\nuse crate::client;\nuse crate::client::{\n    Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,\n};\nuse nohash_hasher::{BuildNoHashHasher, IntMap};\nuse std::cmp::max;\nuse std::collections::VecDeque;\nuse text_generation_router::infer::InferError;\nuse text_generation_router::infer::InferStreamResponse;\nuse text_generation_router::usage_stats::Env;\nuse text_generation_router::validation::{\n    Chunk, ChunksToString, ValidGenerateRequest, ValidGrammar, ValidParameters,\n    ValidStoppingParameters,\n};\nuse tokio::sync::{mpsc, oneshot};\nuse tokio::time::Instant;\nuse tracing::{info_span, instrument, Instrument, Span};\n\n/// Queue entry\n#[derive(Debug)]\npub(crate) struct Entry {\n    /// Request\n    pub request: ValidGenerateRequest,\n    /// Response sender to communicate between the Infer struct and the batching_task\n    pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>,\n    /// Span that will live as long as entry\n    pub span: Span,\n    /// Temporary span used as a guard when logging inference, wait times...\n    pub temp_span: Option<Span>,\n    /// Instant when this entry was queued\n    pub queue_time: Instant,\n    /// Instant when this entry was added to a batch\n    pub batch_time: Option<Instant>,\n    /// Block Allocation\n    pub block_allocation: Option<BlockAllocation>,\n}\n\n/// Request Queue\n#[derive(Debug, Clone)]\npub(crate) struct Queue {\n    /// Channel to communicate with the background queue task\n    queue_sender: mpsc::UnboundedSender<QueueCommand>,\n}\n\nimpl Queue {\n    pub(crate) fn new(\n        requires_padding: bool,\n        block_size: u32,\n        prefix_caching: bool,\n        window_size: Option<u32>,\n        speculate: u32,\n        max_batch_total_tokens: u32,\n        support_chunking: bool,\n    ) -> Self {\n        // Create channel\n        let (queue_sender, 
queue_receiver) = mpsc::unbounded_channel();\n\n        // Launch background queue task\n        tokio::spawn(queue_task(\n            requires_padding,\n            block_size,\n            prefix_caching,\n            window_size,\n            speculate,\n            max_batch_total_tokens,\n            support_chunking,\n            queue_receiver,\n        ));\n\n        Self { queue_sender }\n    }\n\n    /// Append an entry to the queue\n    #[instrument(skip_all)]\n    pub(crate) fn append(&self, entry: Entry) {\n        // Send append command to the background task managing the state\n        // Unwrap is safe here\n        self.queue_sender\n            .send(QueueCommand::Append(Box::new(entry), Span::current()))\n            .unwrap();\n    }\n\n    // Get the next batch\n    #[instrument(skip(self))]\n    pub(crate) async fn next_batch(\n        &self,\n        min_size: Option<usize>,\n        max_size: Option<usize>,\n        prefill_token_budget: u32,\n        token_budget: u32,\n    ) -> Option<NextBatch> {\n        if prefill_token_budget == 0 || token_budget == 0 {\n            return None;\n        };\n\n        // Create response channel\n        let (response_sender, response_receiver) = oneshot::channel();\n        // Send next batch command to the background task managing the state\n        // Unwrap is safe here\n        self.queue_sender\n            .send(QueueCommand::NextBatch {\n                min_size,\n                max_size,\n                prefill_token_budget,\n                token_budget,\n                response_sender,\n                span: Span::current(),\n            })\n            .unwrap();\n        // Await on response channel\n        // Unwrap is safe here\n        response_receiver.await.unwrap()\n    }\n}\n\n// Background task responsible of the queue state\n#[allow(clippy::too_many_arguments)]\nasync fn queue_task(\n    requires_padding: bool,\n    block_size: u32,\n    prefix_caching: bool,\n    window_size: 
Option<u32>,\n    speculate: u32,\n    max_batch_total_tokens: u32,\n    support_chunking: bool,\n    mut receiver: mpsc::UnboundedReceiver<QueueCommand>,\n) {\n    let mut state = State::new(\n        requires_padding,\n        block_size,\n        prefix_caching,\n        window_size,\n        speculate,\n        max_batch_total_tokens,\n        support_chunking,\n    );\n\n    while let Some(cmd) = receiver.recv().await {\n        match cmd {\n            QueueCommand::Append(entry, span) => {\n                span.in_scope(|| state.append(*entry));\n                metrics::gauge!(\"tgi_queue_size\").increment(1.0);\n            }\n            QueueCommand::NextBatch {\n                min_size,\n                max_size,\n                prefill_token_budget,\n                token_budget,\n                response_sender,\n                span,\n            } => {\n                let next_batch = state\n                    .next_batch(min_size, max_size, prefill_token_budget, token_budget)\n                    .instrument(span)\n                    .await;\n                response_sender.send(next_batch).unwrap();\n                metrics::gauge!(\"tgi_queue_size\").set(state.entries.len() as f64);\n            }\n        }\n    }\n}\n\n/// Queue State\n#[derive(Debug)]\nstruct State {\n    /// Queue entries organized in a Vec\n    entries: VecDeque<(u64, Entry)>,\n\n    /// Id of the next entry\n    next_id: u64,\n\n    /// Id of the next batch\n    next_batch_id: u64,\n\n    /// Paged Attention block size\n    block_size: u32,\n\n    /// Speculation amount\n    speculate: u32,\n\n    /// Whether the model allow the prefill chunking\n    /// If it does, the last request in the batch will be split to exactly match the prefill\n    /// token budget\n    support_chunking: bool,\n\n    /// Paged Attention Block Allocation\n    block_allocator: Option<BlockAllocator>,\n\n    /// indicate if it's hpu device, the hpu device needs padding to generate first 
token.\n    is_hpu_device: bool,\n}\n\nimpl State {\n    fn new(\n        requires_padding: bool,\n        block_size: u32,\n        prefix_caching: bool,\n        window_size: Option<u32>,\n        speculate: u32,\n        max_batch_total_tokens: u32,\n        support_chunking: bool,\n    ) -> Self {\n        let block_allocator = (!requires_padding).then(|| {\n            BlockAllocator::new(\n                max_batch_total_tokens,\n                block_size,\n                prefix_caching,\n                window_size,\n            )\n        });\n\n        Self {\n            entries: VecDeque::with_capacity(128),\n            next_id: 0,\n            next_batch_id: 0,\n            block_size,\n            speculate,\n            support_chunking,\n            block_allocator,\n            is_hpu_device: Env::new().is_hpu_device(),\n        }\n    }\n\n    /// Append an entry to the queue\n    fn append(&mut self, mut entry: Entry) {\n        // Create a span that will live as long as the entry is in the queue waiting to be batched\n        let queue_span = info_span!(parent: &entry.span, \"queued\");\n        entry.temp_span = Some(queue_span);\n\n        // Push entry in the queue\n        self.entries.push_back((self.next_id, entry));\n        self.next_id += 1;\n    }\n\n    // Get the next batch\n    async fn next_batch(\n        &mut self,\n        min_size: Option<usize>,\n        max_size: Option<usize>,\n        prefill_token_budget: u32,\n        token_budget: u32,\n    ) -> Option<NextBatch> {\n        if self.entries.is_empty() {\n            tracing::debug!(\"No queue\");\n            return None;\n        }\n\n        // Check if we have enough entries\n        if let Some(min_size) = min_size {\n            if self.entries.len() < min_size {\n                tracing::debug!(\"Not enough entries\");\n                return None;\n            }\n        }\n\n        if let Some(max_size) = max_size {\n            if max_size == 0 {\n             
   tracing::debug!(\"No capacity\");\n                return None;\n            }\n        }\n\n        // Pad prefill_token_budget to be a multiple of block size\n        let prefill_token_budget = prefill_token_budget.div_ceil(self.block_size) * self.block_size;\n\n        // Create span for this batch to add context to inference calls\n        let next_batch_span = info_span!(parent: None, \"batch\", batch_size = tracing::field::Empty);\n        next_batch_span.follows_from(Span::current());\n\n        let mut batch = Vec::with_capacity(self.entries.len());\n        let mut max_input_length = 0;\n        let mut prefill_tokens: u32 = 0;\n        let mut decode_tokens: u32 = 0;\n        let mut max_blocks = 0;\n\n        // Pop entries starting from the front of the queue\n        'entry_loop: while let Some((id, entry)) = self.entries.pop_front() {\n            // Filter entries where the response receiver was dropped (== entries where the request\n            // was dropped by the client)\n            if entry.response_tx.is_closed() {\n                metrics::counter!(\"tgi_request_failure\", \"err\" => \"dropped\").increment(1);\n                tracing::debug!(\"Dropping entry\");\n                continue;\n            }\n\n            let block_allocation = match &self.block_allocator {\n                None => {\n                    // We pad to max input length in the Python shards\n                    // We need to take these padding tokens into the equation\n                    max_input_length = max_input_length.max(entry.request.input_length);\n                    prefill_tokens = (batch.len() + 1) as u32 * max_input_length;\n\n                    decode_tokens += entry.request.stopping_parameters.max_new_tokens;\n                    let total_tokens = prefill_tokens + decode_tokens + self.speculate;\n\n                    if prefill_tokens > prefill_token_budget || total_tokens > token_budget {\n                        // Entry is over budget\n     
                   // Add it back to the front\n                        tracing::debug!(\"Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}\", self.speculate);\n                        self.entries.push_front((id, entry));\n                        break 'entry_loop;\n                    }\n                    None\n                }\n                Some(block_allocator) => {\n                    // If users wants the prefill logprobs, we cannot reuse the cache.\n                    // So no input_ids for the radix tree.\n                    let input_ids = if entry.request.decoder_input_details {\n                        None\n                    } else {\n                        entry.request.input_ids.clone()\n                    };\n\n                    let tokens = entry.request.input_length\n                        + entry.request.stopping_parameters.max_new_tokens\n                        + self.speculate\n                        - 1;\n                    // tracing::debug!(\"Allocating {tokens} with {input_ids:?}\");\n\n                    let block_allocation = match block_allocator.allocate(tokens, input_ids).await {\n                        None => {\n                            // Entry is over budget\n                            // Add it back to the front\n                            tracing::debug!(\"Over budget: not enough free blocks\");\n                            self.entries.push_front((id, entry));\n                            break 'entry_loop;\n                        }\n                        Some(mut block_allocation) => {\n                            // tracing::debug!(\"Allocation: {block_allocation:?}\");\n                            max_blocks = max(max_blocks, block_allocation.blocks.len() as u32);\n\n                            if block_allocation.prefix_len == entry.request.input_length {\n                                // The whole request was found 
in the radix trie\n                                // However, for the transformer forward to work, we need to\n                                // have at least one token of postfix.\n                                block_allocation.prefix_len -= 1;\n                            }\n\n                            block_allocation\n                        }\n                    };\n\n                    let postfix_len = entry.request.input_length - block_allocation.prefix_len;\n\n                    if prefill_tokens + postfix_len > prefill_token_budget {\n                        // Entry is over budget\n                        if self.support_chunking {\n                            // We support chunking, just set postfix_len to exactly match prefill_token_budget\n                            let chunk_len = prefill_token_budget.saturating_sub(prefill_tokens);\n                            if chunk_len > 0 {\n                                // Push this entry inside the batch\n                                batch.push((id, entry, Some(block_allocation), Some(chunk_len)));\n                            } else {\n                                // We cannot prefill even one token for this entry\n                                // Add it back to the queue\n                                self.entries.push_front((id, entry));\n                            }\n                            tracing::debug!(\n                                \"Matched budget: prefill_tokens={} == {prefill_token_budget}\",\n                                prefill_tokens + postfix_len\n                            );\n                            break 'entry_loop;\n                        } else {\n                            // We don't support chunking, this entry needs to go back to the buffer\n                            // Add it back to the front\n                            tracing::debug!(\n                                \"Over budget: prefill_tokens={} > {prefill_token_budget}\",\n          
                      prefill_tokens + postfix_len\n                            );\n                            self.entries.push_front((id, entry));\n                            break 'entry_loop;\n                        }\n                    }\n\n                    if self.is_hpu_device {\n                        //HPU needs to pad for the prefill\n                        max_input_length = max_input_length.max(entry.request.input_length);\n                        let actual_prefill_tokens_for_hpu =\n                            (batch.len() + 1) as u32 * max_input_length;\n\n                        if actual_prefill_tokens_for_hpu > prefill_token_budget {\n                            // Entry is over budget\n                            // Add it back to the front\n                            tracing::debug!(\"Over budget: prefill_tokens={actual_prefill_tokens_for_hpu} > {prefill_token_budget}\");\n                            self.entries.push_front((id, entry));\n                            break 'entry_loop;\n                        }\n                    }\n\n                    prefill_tokens += postfix_len;\n\n                    Some(block_allocation)\n                }\n            };\n            batch.push((id, entry, block_allocation, None));\n            if Some(batch.len()) == max_size {\n                break;\n            }\n        }\n\n        // Empty batch\n        if batch.is_empty() {\n            tracing::debug!(\"Filterered out all entries\");\n            return None;\n        }\n\n        // XXX We haven't allocated yet, so we're allowed to ditch the results.\n        // Check if our batch is big enough\n        if let Some(min_size) = min_size {\n            // Batch is too small\n            if batch.len() < min_size {\n                // Add back entries to the queue in the correct order\n                for (id, entry, _, _) in batch.into_iter().rev() {\n                    self.entries.push_front((id, entry));\n                }\n   
             return None;\n            }\n        }\n\n        let mut batch_requests = Vec::with_capacity(self.entries.len());\n        let mut batch_entries =\n            IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());\n\n        for (id, mut entry, block_allocation, chunk_len) in batch {\n            // Create a new span to link the batch back to this entry\n            let entry_batch_span = info_span!(parent: &entry.span, \"infer\");\n            // Add relationships\n            next_batch_span.follows_from(&entry_batch_span);\n            entry_batch_span.follows_from(&next_batch_span);\n            // Update entry\n            entry.temp_span = Some(entry_batch_span);\n\n            let (blocks, slots, prefix_len) = match &block_allocation {\n                None => (Vec::new(), Vec::new(), 0),\n                Some(block_allocation) => (\n                    block_allocation.blocks.clone(),\n                    block_allocation.slots.clone(),\n                    block_allocation.prefix_len,\n                ),\n            };\n\n            entry.block_allocation = block_allocation;\n\n            batch_requests.push(Request {\n                id,\n                prefill_logprobs: entry.request.decoder_input_details,\n                input_chunks: Some(client::Input {\n                    chunks: entry\n                        .request\n                        .inputs\n                        .clone()\n                        .into_iter()\n                        .map(|c| client::InputChunk {\n                            chunk: Some(match c {\n                                Chunk::Text(text) => client::Chunk::Text(text),\n                                Chunk::Image(image) => client::Chunk::Image(client::Image {\n                                    data: image.data,\n                                    mimetype: image.mimetype,\n                                }),\n                            }),\n                    
    })\n                        .collect(),\n                }),\n                inputs: entry.request.inputs.chunks_to_string(),\n                truncate: entry.request.truncate,\n                add_special_tokens: entry.request.add_special_tokens,\n                parameters: Some(NextTokenChooserParameters::from(\n                    entry.request.parameters.clone(),\n                )),\n                stopping_parameters: Some(StoppingCriteriaParameters::from(\n                    entry.request.stopping_parameters.clone(),\n                )),\n                top_n_tokens: entry.request.top_n_tokens,\n                blocks,\n                slots,\n                cache_len: prefix_len,\n                adapter_id: entry.request.adapter_id.clone(),\n                chunk_len,\n            });\n            // Set batch_time\n            entry.batch_time = Some(Instant::now());\n            // Insert in batch_entries IntMap\n            batch_entries.insert(id, entry);\n        }\n\n        // Final batch size\n        let size = batch_requests.len() as u32;\n        next_batch_span.record(\"batch_size\", size);\n\n        let batch = Batch {\n            id: self.next_batch_id,\n            requests: batch_requests,\n            size,\n            max_tokens: (prefill_tokens + decode_tokens),\n            max_blocks,\n        };\n        // Increment batch id\n        self.next_batch_id += 1;\n\n        metrics::histogram!(\"tgi_batch_next_size\").record(batch.size as f64);\n\n        Some((batch_entries, batch, next_batch_span))\n    }\n}\n\ntype NextBatch = (IntMap<u64, Entry>, Batch, Span);\n\n#[derive(Debug)]\nenum QueueCommand {\n    Append(Box<Entry>, Span),\n    NextBatch {\n        min_size: Option<usize>,\n        max_size: Option<usize>,\n        prefill_token_budget: u32,\n        token_budget: u32,\n        response_sender: oneshot::Sender<Option<NextBatch>>,\n        span: Span,\n    },\n}\n\nimpl From<ValidParameters> for 
NextTokenChooserParameters {\n    fn from(value: ValidParameters) -> Self {\n        let (grammar, grammar_type) = match value.grammar {\n            None => (String::new(), GrammarType::None),\n\n            Some(grammar) => match grammar {\n                ValidGrammar::Json(grammar_string) => (grammar_string, GrammarType::Json),\n                ValidGrammar::Regex(grammar_string) => (grammar_string, GrammarType::Regex),\n            },\n        };\n\n        Self {\n            temperature: value.temperature,\n            top_k: value.top_k,\n            top_p: value.top_p,\n            typical_p: value.typical_p,\n            do_sample: value.do_sample,\n            seed: value.seed,\n            repetition_penalty: value.repetition_penalty,\n            frequency_penalty: value.frequency_penalty,\n            watermark: value.watermark,\n            grammar,\n            grammar_type: grammar_type.into(),\n        }\n    }\n}\n\nimpl From<ValidStoppingParameters> for StoppingCriteriaParameters {\n    fn from(value: ValidStoppingParameters) -> Self {\n        Self {\n            max_new_tokens: value.max_new_tokens,\n            stop_sequences: value.stop_sequences,\n            ignore_eos_token: value.ignore_eos_token,\n        }\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use std::sync::Arc;\n\n    use super::*;\n    use tracing::info_span;\n\n    fn default_entry() -> (\n        Entry,\n        mpsc::UnboundedReceiver<Result<InferStreamResponse, InferError>>,\n    ) {\n        let (response_tx, receiver_tx) = mpsc::unbounded_channel();\n\n        let entry = Entry {\n            request: ValidGenerateRequest {\n                inputs: vec![],\n                input_ids: Some(Arc::new(vec![])),\n                input_length: 1,\n                add_special_tokens: true,\n                truncate: 0,\n                decoder_input_details: false,\n                parameters: ValidParameters {\n                    temperature: 0.0,\n                    top_k: 
0,\n                    top_p: 0.0,\n                    typical_p: 0.0,\n                    do_sample: false,\n                    seed: 0,\n                    repetition_penalty: 0.0,\n                    frequency_penalty: 0.0,\n                    watermark: false,\n                    grammar: None,\n                },\n                stopping_parameters: ValidStoppingParameters {\n                    ignore_eos_token: false,\n                    max_new_tokens: 1,\n                    max_total_new_tokens: 1024,\n                    stop_sequences: vec![],\n                },\n                top_n_tokens: 0,\n                adapter_id: None,\n            },\n            response_tx,\n            span: info_span!(\"entry\"),\n            temp_span: None,\n            queue_time: Instant::now(),\n            batch_time: None,\n            block_allocation: None,\n        };\n        (entry, receiver_tx)\n    }\n\n    #[tokio::test]\n    async fn test_append() {\n        let mut state = State::new(false, 1, false, None, 0, 16, false);\n        let (entry, _guard) = default_entry();\n\n        assert_eq!(state.next_id, 0);\n        assert_eq!(state.entries.len(), 0);\n\n        state.append(entry);\n\n        assert_eq!(state.next_id, 1);\n        assert_eq!(state.entries.len(), 1);\n        let (id, _) = state.entries.remove(0).unwrap();\n        assert_eq!(id, 0);\n    }\n\n    #[tokio::test]\n    async fn test_next_batch_empty() {\n        let mut state = State::new(false, 1, false, None, 0, 16, false);\n\n        assert!(state.next_batch(None, None, 1, 1).await.is_none());\n        assert!(state.next_batch(Some(1), None, 1, 1).await.is_none());\n    }\n\n    #[tokio::test]\n    async fn test_next_batch_min_size() {\n        let mut state = State::new(false, 1, false, None, 0, 16, false);\n        let (entry1, _guard1) = default_entry();\n        let (entry2, _guard2) = default_entry();\n        state.append(entry1);\n        state.append(entry2);\n\n     
   let (entries, batch, _) = state.next_batch(None, None, 2, 2).await.unwrap();\n        assert_eq!(entries.len(), 2);\n        assert!(entries.contains_key(&0));\n        assert!(entries.contains_key(&1));\n        assert!(entries.get(&0).unwrap().batch_time.is_some());\n        assert!(entries.get(&1).unwrap().batch_time.is_some());\n        assert_eq!(batch.id, 0);\n        assert_eq!(batch.size, 2);\n\n        assert_eq!(state.next_id, 2);\n        assert_eq!(state.entries.len(), 0);\n        assert_eq!(state.next_batch_id, 1);\n\n        let (entry3, _guard3) = default_entry();\n        state.append(entry3);\n\n        assert!(state.next_batch(Some(2), None, 2, 2).await.is_none());\n\n        assert_eq!(state.next_id, 3);\n        assert_eq!(state.entries.len(), 1);\n        let (id, _) = state.entries.remove(0).unwrap();\n        assert_eq!(id, 2);\n    }\n\n    #[tokio::test]\n    async fn test_next_batch_max_size() {\n        let mut state = State::new(false, 1, false, None, 0, 16, false);\n        let (entry1, _guard1) = default_entry();\n        let (entry2, _guard2) = default_entry();\n        state.append(entry1);\n        state.append(entry2);\n\n        let (entries, batch, _) = state.next_batch(None, Some(1), 2, 2).await.unwrap();\n        assert_eq!(entries.len(), 1);\n        assert!(entries.contains_key(&0));\n        assert!(entries.get(&0).unwrap().batch_time.is_some());\n        assert_eq!(batch.id, 0);\n        assert_eq!(batch.size, 1);\n\n        assert_eq!(state.next_id, 2);\n        assert_eq!(state.entries.len(), 1);\n        assert_eq!(state.next_batch_id, 1);\n    }\n\n    #[tokio::test]\n    async fn test_next_batch_token_budget() {\n        let mut state = State::new(false, 1, false, None, 0, 16, false);\n        let (entry1, _guard1) = default_entry();\n        let (entry2, _guard2) = default_entry();\n        state.append(entry1);\n        state.append(entry2);\n\n        let (entries, batch, _) = state.next_batch(None, None, 1, 
1).await.unwrap();\n        assert_eq!(entries.len(), 1);\n        assert!(entries.contains_key(&0));\n        assert_eq!(batch.id, 0);\n        assert_eq!(batch.size, 1);\n\n        assert_eq!(state.next_id, 2);\n        assert_eq!(state.entries.len(), 1);\n        assert_eq!(state.next_batch_id, 1);\n\n        let (entry3, _guard3) = default_entry();\n        state.append(entry3);\n\n        let (entries, batch, _) = state.next_batch(None, None, 3, 3).await.unwrap();\n        assert_eq!(entries.len(), 2);\n        assert!(entries.contains_key(&1));\n        assert!(entries.contains_key(&2));\n        assert_eq!(batch.id, 1);\n        assert_eq!(batch.size, 2);\n\n        assert_eq!(state.next_id, 3);\n        assert_eq!(state.entries.len(), 0);\n        assert_eq!(state.next_batch_id, 2);\n    }\n\n    #[tokio::test]\n    async fn test_queue_append() {\n        let queue = Queue::new(false, 1, false, None, 0, 16, false);\n        let (entry, _guard) = default_entry();\n        queue.append(entry);\n    }\n\n    #[tokio::test]\n    async fn test_queue_next_batch_empty() {\n        let queue = Queue::new(false, 1, false, None, 0, 16, false);\n\n        assert!(queue.next_batch(None, None, 1, 1).await.is_none());\n        assert!(queue.next_batch(Some(1), None, 1, 1).await.is_none());\n    }\n\n    #[tokio::test]\n    async fn test_queue_next_batch_min_size() {\n        let queue = Queue::new(false, 1, false, None, 0, 16, false);\n        let (entry1, _guard1) = default_entry();\n        let (entry2, _guard2) = default_entry();\n        queue.append(entry1);\n        queue.append(entry2);\n\n        let (entries, batch, _) = queue.next_batch(None, None, 2, 2).await.unwrap();\n        assert_eq!(entries.len(), 2);\n        assert!(entries.contains_key(&0));\n        assert!(entries.contains_key(&1));\n        assert!(entries.get(&0).unwrap().batch_time.is_some());\n        assert!(entries.get(&1).unwrap().batch_time.is_some());\n        assert_eq!(batch.id, 0);\n     
   assert_eq!(batch.size, 2);\n\n        let (entry3, _guard3) = default_entry();\n        queue.append(entry3);\n\n        // Not enough requests pending\n        assert!(queue.next_batch(Some(2), None, 2, 2).await.is_none());\n        // Not enough token budget\n        assert!(queue.next_batch(Some(1), None, 0, 0).await.is_none());\n        // Ok\n        let (entries2, batch2, _) = queue.next_batch(Some(1), None, 2, 2).await.unwrap();\n        assert_eq!(entries2.len(), 1);\n        assert!(entries2.contains_key(&2));\n        assert!(entries2.get(&2).unwrap().batch_time.is_some());\n        assert_eq!(batch2.id, 1);\n        assert_eq!(batch2.size, 1);\n    }\n\n    #[tokio::test]\n    async fn test_queue_next_batch_max_size() {\n        let queue = Queue::new(false, 1, false, None, 0, 16, false);\n        let (entry1, _guard1) = default_entry();\n        let (entry2, _guard2) = default_entry();\n        queue.append(entry1);\n        queue.append(entry2);\n\n        let (entries, batch, _) = queue.next_batch(None, Some(1), 2, 2).await.unwrap();\n        assert_eq!(entries.len(), 1);\n        assert!(entries.contains_key(&0));\n        assert!(entries.get(&0).unwrap().batch_time.is_some());\n        assert_eq!(batch.id, 0);\n        assert_eq!(batch.size, 1);\n    }\n\n    #[tokio::test]\n    async fn test_queue_next_batch_token_budget() {\n        let queue = Queue::new(false, 1, false, None, 0, 16, false);\n        let (entry1, _guard1) = default_entry();\n        let (entry2, _guard2) = default_entry();\n        queue.append(entry1);\n        queue.append(entry2);\n\n        let (entries, batch, _) = queue.next_batch(None, None, 1, 1).await.unwrap();\n        assert_eq!(entries.len(), 1);\n        assert!(entries.contains_key(&0));\n        assert_eq!(batch.id, 0);\n        assert_eq!(batch.size, 1);\n\n        let (entry3, _guard3) = default_entry();\n        queue.append(entry3);\n\n        let (entries, batch, _) = queue.next_batch(None, None, 3, 
3).await.unwrap();\n        assert_eq!(entries.len(), 2);\n        assert!(entries.contains_key(&1));\n        assert!(entries.contains_key(&2));\n        assert_eq!(batch.id, 1);\n        assert_eq!(batch.size, 2);\n    }\n\n    #[tokio::test]\n    async fn test_queue_next_batch_token_speculate() {\n        let queue = Queue::new(true, 1, false, None, 2, 16, false);\n        let (entry1, _guard1) = default_entry();\n        let (entry2, _guard2) = default_entry();\n        queue.append(entry1);\n        queue.append(entry2);\n\n        // Budget of 1 is not enough\n        assert!(queue.next_batch(None, None, 1, 1).await.is_none());\n\n        let (entries, batch, _) = queue.next_batch(None, None, 6, 6).await.unwrap();\n        assert_eq!(entries.len(), 2);\n        assert!(entries.contains_key(&0));\n        assert!(entries.contains_key(&1));\n        assert_eq!(batch.id, 0);\n        assert_eq!(batch.size, 2);\n    }\n\n    #[tokio::test]\n    async fn test_queue_next_batch_dropped_receiver() {\n        let queue = Queue::new(false, 1, false, None, 0, 16, false);\n        let (entry, _) = default_entry();\n        queue.append(entry);\n\n        assert!(queue.next_batch(None, None, 1, 1).await.is_none());\n    }\n}\n"
  },
  {
    "path": "backends/v3/src/radix.rs",
    "content": "use crate::block_allocator::{Allocator, BlockAllocation};\nuse slotmap::{DefaultKey, SlotMap};\nuse std::hash::{Hash, Hasher};\nuse std::{\n    collections::{BTreeSet, HashMap},\n    sync::Arc,\n};\n\nfn hash(slice: &[u32]) -> u64 {\n    assert!(!slice.is_empty());\n    if slice.len() == 1 {\n        slice[0] as u64\n    } else {\n        let mut s = std::hash::DefaultHasher::new();\n        slice.hash(&mut s);\n        s.finish()\n    }\n}\n\npub struct RadixAllocator {\n    allocation_id: u64,\n\n    allocations: HashMap<u64, RadixAllocation>,\n\n    cache_blocks: RadixTrie,\n\n    /// Blocks that are immediately available for allocation.\n    free_blocks: Vec<u32>,\n\n    #[allow(dead_code)]\n    // This isn't used because the prefix need to match without the windowing\n    // mecanism. This at worst is overallocating, not necessarily being wrong.\n    window_size: Option<u32>,\n\n    block_size: u32,\n}\n\nimpl RadixAllocator {\n    pub fn new(block_size: u32, n_blocks: u32, window_size: Option<u32>) -> Self {\n        RadixAllocator {\n            allocation_id: 0,\n            allocations: HashMap::new(),\n            cache_blocks: RadixTrie::new(block_size as usize),\n\n            // Block 0 is reserved for health checks.\n            free_blocks: (1..n_blocks).collect(),\n            window_size,\n            block_size,\n        }\n    }\n\n    fn alloc_or_reclaim(&mut self, n_blocks_needed: usize) -> Option<Vec<u32>> {\n        if self.free_blocks.len() < n_blocks_needed {\n            // This is a bit annoying, we first extend the free list and then\n            // split it off again below. This is because we need to put it on\n            // the free list if we cannot allocate enough blocks. This is only\n            // temporary, the trie needs to be able to report whether it can\n            // allocate the requested amount. 
Just not implemented yet.\n            tracing::debug!(\n                \"Free blocks {}  need {n_blocks_needed}\",\n                self.free_blocks.len()\n            );\n            self.free_blocks.extend(\n                self.cache_blocks\n                    .evict(n_blocks_needed - self.free_blocks.len()),\n            );\n        }\n\n        if self.free_blocks.len() >= n_blocks_needed {\n            Some(\n                self.free_blocks\n                    .split_off(self.free_blocks.len() - n_blocks_needed),\n            )\n        } else {\n            None\n        }\n    }\n}\n\n// Allocator trait\nimpl Allocator for RadixAllocator {\n    fn allocate(\n        &mut self,\n        tokens: u32,\n        prefill_tokens: Option<Arc<Vec<u32>>>,\n    ) -> Option<BlockAllocation> {\n        let mut blocks = vec![];\n        let prefix_node = if let Some(prefill_tokens) = prefill_tokens.as_ref() {\n            let node_id = self\n                .cache_blocks\n                .find(prefill_tokens.as_slice(), &mut blocks);\n            node_id\n        } else {\n            self.cache_blocks.root_id()\n        };\n\n        // Even if this allocation fails below, we need to increase the\n        // refcount to ensure that the prefix that was found is not evicted.\n        self.cache_blocks\n            .incref(prefix_node)\n            .expect(\"Failed to increment refcount\");\n\n        let prefix_len = blocks.len() * self.block_size as usize;\n        let suffix_len = tokens - prefix_len as u32;\n\n        let suffix_blocks = suffix_len.div_ceil(self.block_size);\n\n        tracing::info!(\"Prefix {prefix_len} - Suffix {suffix_len}\");\n\n        match self.alloc_or_reclaim(suffix_blocks as usize) {\n            Some(suffix_blocks) => blocks.extend(suffix_blocks),\n            None => {\n                tracing::debug!(\"Cannot allocate {:?}\", self.cache_blocks);\n                tracing::debug!(\"Found {prefix_len} prefix tokens need {suffix_blocks} 
suffix blocks for {tokens} tokens\");\n                tracing::debug!(\"Block size {}\", self.block_size);\n                self.cache_blocks\n                    .decref(prefix_node)\n                    .expect(\"Failed to decrement refcount\");\n                return None;\n            }\n        }\n\n        // 1:1 mapping of blocks and slots.\n        let slots = if self.block_size == 1 {\n            blocks.clone()\n        } else {\n            let mut slots = Vec::with_capacity(blocks.len() * self.block_size as usize);\n            'slots: for block_id in &blocks {\n                for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) {\n                    slots.push(s);\n                    if slots.len() as u32 == tokens {\n                        break 'slots;\n                    }\n                }\n            }\n            slots\n        };\n\n        let allocation = RadixAllocation {\n            prefix_node,\n            cached_prefix_len: prefix_len,\n            prefill_tokens: prefill_tokens.clone(),\n        };\n\n        self.allocation_id += 1;\n        self.allocations.insert(self.allocation_id, allocation);\n\n        Some(BlockAllocation {\n            allocation_id: self.allocation_id,\n            block_allocator: None,\n            blocks,\n            slots,\n            prefix_len: prefix_len as u32,\n        })\n    }\n\n    fn free(&mut self, blocks: Vec<u32>, allocation_id: u64) {\n        let allocation = match self.allocations.remove(&allocation_id) {\n            Some(allocation) => allocation,\n            None => unreachable!(\"Tried to free an unknown allocation.\"),\n        };\n\n        self.cache_blocks\n            .decref(allocation.prefix_node)\n            .expect(\"Failed to decrement refcount\");\n\n        if let Some(prefill_tokens) = allocation.prefill_tokens {\n            let prefill_tokens = prefill_tokens.as_slice();\n\n            // If there are prefill tokens that did not come 
from the cache,\n            // add them to the cache.\n            if prefill_tokens.len() > allocation.cached_prefix_len {\n                let aligned =\n                    (prefill_tokens.len() / self.block_size as usize) * self.block_size as usize;\n                if aligned > 0 {\n                    let prefix_len = self\n                        .cache_blocks\n                        .insert(\n                            &prefill_tokens[..aligned],\n                            &blocks[..aligned / self.block_size as usize],\n                        )\n                        // Unwrap, failing is a programming error.\n                        .expect(\"Failed to store prefill tokens\");\n                    // We can have a prefill with the following structure:\n                    //\n                    // |---| From the prefix cache.\n                    // A B C D E F G\n                    //|--------| Found in the trie during insertion.\n                    //\n                    // This means that while processing this request there was a\n                    // partially overlapping request that had A..=E in its\n                    // prefill. 
In this case we need to free the blocks D E.\n                    if prefix_len > allocation.cached_prefix_len {\n                        self.free_blocks.extend(\n                            &blocks[allocation.cached_prefix_len / self.block_size as usize\n                                ..prefix_len / self.block_size as usize],\n                        );\n                    }\n                }\n            }\n\n            // Free non-prefill blocks.\n            self.free_blocks\n                .extend(&blocks[prefill_tokens.len() / self.block_size as usize..]);\n        } else {\n            self.free_blocks.extend(blocks);\n        }\n    }\n}\n\nstruct RadixAllocation {\n    prefix_node: NodeId,\n    cached_prefix_len: usize,\n    prefill_tokens: Option<Arc<Vec<u32>>>,\n}\n\n// Radix trie that is heavily inspired by radix attention from sglang.\n//\n// The trie is optimized for prefix caching:\n//\n// - A normal radix trie stores discrete values. In this radix trie,\n//   inserting *abc* with value *xyz* will also enable lookup for\n//   *a* (*x*) and *ab* (*xy*).\n// - As a result, every value is required to have the same length as\n//   the key.\n// - We store additional information in each node, such as last access\n//   time and a reference count.\n\n#[derive(Debug)]\npub enum TrieError {\n    InvalidNodeId,\n    RefCountUnderflow,\n}\n\npub type NodeId = DefaultKey;\n\n#[derive(Debug)]\npub struct RadixTrie {\n    /// Identifier of the root nod.\n    root: DefaultKey,\n\n    /// Leave node identifiers ordered by increasing recency.\n    leaves: BTreeSet<(u64, NodeId)>,\n\n    /// All trie nodes.\n    nodes: SlotMap<NodeId, TrieNode>,\n\n    /// Time as a monotonically increating counter to avoid the system\n    /// call that a real time lookup would require.\n    time: u64,\n\n    /// All blocks need to be aligned with this\n    block_size: usize,\n}\n\nimpl RadixTrie {\n    /// Construct a new radix trie.\n    pub fn new(block_size: usize) -> Self 
{\n        let root = TrieNode::new(vec![], vec![], 0, None);\n        let mut nodes = SlotMap::new();\n        let root = nodes.insert(root);\n        RadixTrie {\n            leaves: BTreeSet::new(),\n            nodes,\n            root,\n            time: 0,\n            block_size,\n        }\n    }\n\n    /// Find the prefix of the given tokens.\n    ///\n    /// The blocks corresponding to the part of the prefix that could be found\n    /// are written to `blocks`. The number of blocks is in `0..=tokens.len()`.\n    /// Returns the identifier of the trie node that contains the longest\n    /// prefix. The node identifier can be used by callers to e.g. increase its\n    /// reference count.\n    ///\n    /// Using this method will update the access time of the traversed nodes.\n    pub fn find(&mut self, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {\n        self.time += 1;\n        self.find_(self.root, key, blocks)\n    }\n\n    /// Find worker.\n    fn find_(&mut self, node_id: NodeId, key: &[u32], blocks: &mut Vec<u32>) -> NodeId {\n        let node = &self.nodes[node_id];\n\n        if key.len() >= self.block_size {\n            let node_key = hash(&key[..self.block_size]);\n            if let Some(&child_id) = node.children.get(&node_key) {\n                self.update_access_time(child_id);\n                let child = self.nodes.get(child_id).expect(\"Invalid child identifier\");\n                let shared_prefix_len = shared_prefix(&child.key, key, self.block_size);\n                assert_eq!(shared_prefix_len % self.block_size, 0);\n                blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]);\n\n                // A node represents the prefix of its children. 
So, only\n                // recurse when there is a full prefix match.\n                let key = &key[shared_prefix_len..];\n                if !key.is_empty() && shared_prefix_len == child.key.len() {\n                    return self.find_(child_id, key, blocks);\n                } else {\n                    return child_id;\n                }\n            }\n        }\n\n        node_id\n    }\n\n    /// Decrease the reference count of a node.\n    pub fn decref(&mut self, node_id: NodeId) -> Result<(), TrieError> {\n        // We don't care about refcounting for root, since it will never\n        // be evicted.\n        if node_id == self.root {\n            return Ok(());\n        }\n\n        let node = self\n            .nodes\n            .get_mut(node_id)\n            .ok_or(TrieError::InvalidNodeId)?;\n        if node.ref_count == 0 {\n            return Err(TrieError::RefCountUnderflow);\n        }\n\n        node.ref_count -= 1;\n        if node.ref_count == 0 {\n            assert!(\n                node.children.is_empty(),\n                \"Nodes with children must have refcount > 0\"\n            );\n\n            self.leaves.insert((node.last_accessed, node_id));\n        }\n\n        Ok(())\n    }\n\n    /// Increase the reference count of a node.\n    pub fn incref(&mut self, node_id: NodeId) -> Result<(), TrieError> {\n        if node_id == self.root {\n            return Ok(());\n        }\n\n        let node = self\n            .nodes\n            .get_mut(node_id)\n            .ok_or(TrieError::InvalidNodeId)?;\n        if node.ref_count == 0 {\n            self.leaves.remove(&(node.last_accessed, node_id));\n        }\n        node.ref_count += 1;\n\n        Ok(())\n    }\n\n    /// Evict `n_blocks` from the trie.\n    ///\n    /// Returns the evicted blocks. 
When the length is less than `n_blocks`,\n    /// not enough blocks could be evicted.\n    pub fn evict(&mut self, n_blocks: usize) -> Vec<u32> {\n        // NOTE: we don't return Result here. If any of the unwrapping fails,\n        // it's a programming error in the trie implementation, not a user\n        // error caused by e.g. an invalid argument.\n\n        // TODO: add some bookkeeping in the future to check whether we can\n        // evict n_blocks and return `None` if we can't. We are now needlessly\n        // evicting prefixes from the cache in such a case.\n        let mut evicted = Vec::new();\n        tracing::debug!(\"Evicting in search of {n_blocks}\");\n\n        while let Some((last_access, node_id)) = self.leaves.pop_first() {\n            let blocks_needed = n_blocks.saturating_sub(evicted.len());\n            tracing::debug!(\"Evicting node {node_id:?} \");\n\n            let node = self.nodes.get(node_id).expect(\"Leave does not exist\");\n            assert_eq!(\n                node.ref_count, 0,\n                \"Leaf must have refcount of 0, got {}\",\n                node.ref_count\n            );\n\n            if blocks_needed >= node.blocks.len() {\n                // We need to evict the whole node if we need more blocks than it has.\n                let node = self.remove_node(node_id);\n                evicted.extend(node.blocks);\n\n                if evicted.len() >= n_blocks {\n                    break;\n                }\n            } else {\n                // The node has more blocks than needed, so we'll just remove\n                // the required number of blocks and leave the remaining blocks\n                // untouched.\n                let node = self.nodes.get_mut(node_id).expect(\"Leave does not exist\");\n\n                let truncate_blocks = node.blocks.len() - blocks_needed;\n                let truncate_tokens = truncate_blocks * self.block_size;\n                node.key.truncate(truncate_tokens);\n         
       evicted.extend(node.blocks.split_off(truncate_blocks));\n                self.leaves.insert((last_access, node_id));\n                break;\n            }\n        }\n\n        evicted\n    }\n\n    /// Insert a prefill along with its blocks.\n    ///\n    /// This method returns the length of the prefix that was already\n    /// in the trie. E.g. if the length is 10, this means that for\n    /// the first 10 elements of the tree **the blocks are not updated**.\n    pub fn insert(&mut self, tokens: &[u32], blocks: &[u32]) -> Result<usize, TrieError> {\n        self.time += 1;\n        let common = self.insert_(self.root, tokens, blocks)?;\n        Ok(common)\n    }\n\n    /// Insertion worker.\n    fn insert_(\n        &mut self,\n        node_id: NodeId,\n        tokens: &[u32],\n        blocks: &[u32],\n    ) -> Result<usize, TrieError> {\n        // TODO: in the future we may want to check that the blocks match for\n        // the part of the prefix that is already in the trie to detect\n        // mismatches.\n\n        assert_eq!(tokens.len(), blocks.len() * self.block_size);\n\n        let node_key = hash(&tokens[..self.block_size]);\n        if let Some(&child_id) = self.nodes[node_id].children.get(&node_key) {\n            self.update_access_time(child_id);\n            let child = self\n                .nodes\n                .get_mut(child_id)\n                // Unwrap here, since failure is a bug.\n                .expect(\"Child node does not exist\");\n            let shared_prefix_len = shared_prefix(&child.key, tokens, self.block_size);\n\n            // We are done, the prefix is already in the trie.\n            if shared_prefix_len == tokens.len() || shared_prefix_len == 0 {\n                return Ok(shared_prefix_len);\n            }\n\n            // The node's prefix is a prefix of the insertion prefix.\n            if shared_prefix_len == child.key.len() {\n                return Ok(shared_prefix_len\n                    + 
self.insert_(\n                        child_id,\n                        &tokens[shared_prefix_len..],\n                        &blocks[shared_prefix_len / self.block_size..],\n                    )?);\n            }\n\n            // The node's prefix and the insertion prefix only match partially,\n            // split the node to just contain the matching part. Then insert the\n            // remainder of the prefix into the node again\n            let child_id = self.split_node(child_id, shared_prefix_len);\n            let key = &tokens[shared_prefix_len..];\n            let blocks = &blocks[shared_prefix_len / self.block_size..];\n            Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?)\n        } else {\n            self.add_node(node_id, tokens, blocks);\n            Ok(0)\n        }\n    }\n\n    fn split_node(&mut self, node_id: NodeId, prefix_len: usize) -> NodeId {\n        // We have to make the current node a child to ensure that its\n        // properties and node id stay the same.\n\n        // This funcion unwraps, an  invalid node_id is a programming error.\n\n        let node = self\n            .nodes\n            .get_mut(node_id)\n            .expect(\"Node to-be split does not exist\");\n        let mut parent_key = node.key.split_off(prefix_len);\n        let prefix_blocks = prefix_len / self.block_size;\n        let mut parent_blocks = node.blocks.split_off(prefix_blocks);\n\n        // Move first part of the prefix to the parent. 
We swap to avoid\n        // an allocation + copy for both splits of the key/blocks.\n        std::mem::swap(&mut node.key, &mut parent_key);\n        std::mem::swap(&mut node.blocks, &mut parent_blocks);\n\n        let node_key = hash(&node.key[..self.block_size]);\n\n        let grandparent_id = node.parent.expect(\"Node does not have a parent\");\n        let parent_id = self.add_node(grandparent_id, parent_key, parent_blocks);\n        self.add_node_to_parent(parent_id, node_key, node_id);\n\n        // Reborrow to make the borrow checker happy.\n        let node = self\n            .nodes\n            .get_mut(node_id)\n            .expect(\"Node to-be split does not exist\");\n        node.parent = Some(parent_id);\n\n        parent_id\n    }\n\n    /// Create a node and add it to the parent.\n    fn add_node(\n        &mut self,\n        parent_id: NodeId,\n        key: impl Into<Vec<u32>>,\n        blocks: impl Into<Vec<u32>>,\n    ) -> NodeId {\n        let key = key.into();\n        let blocks = blocks.into();\n        let first = hash(&key[..self.block_size]);\n\n        let child = TrieNode::new(key, blocks, self.time, Some(parent_id));\n        let child_id = self.nodes.insert(child);\n\n        self.add_node_to_parent(parent_id, first, child_id);\n        self.leaves.insert((self.time, child_id));\n\n        child_id\n    }\n\n    /// Add a node to the parent.\n    fn add_node_to_parent(&mut self, parent_id: NodeId, hash: u64, child_id: NodeId) {\n        // Unwrap here, passing in an unknown id is a programming error.\n        let parent = self.nodes.get_mut(parent_id).expect(\"Unknown parent node\");\n        if parent.children.insert(hash, child_id).is_none() {\n            // Only increase reference count if child does not replace another child.\n            self.incref(parent_id)\n                .expect(\"Failed to increase parent refcount\");\n        }\n    }\n\n    /// Remove a node from the trie.\n    fn remove_node(&mut self, node_id: 
NodeId) -> TrieNode {\n        // Unwrap here, passing in an unknown id is a programming error.\n        let node = self.nodes.remove(node_id).expect(\"Unknown node\");\n        assert!(\n            node.children.is_empty(),\n            \"Tried to remove a node with {} children\",\n            node.children.len()\n        );\n        let parent_id = node.parent.expect(\"Attempted to remove root node\");\n        let parent = self.nodes.get_mut(parent_id).expect(\"Unknown parent node\");\n\n        let node_key = hash(&node.key[..self.block_size]);\n        parent.children.remove(&node_key);\n        self.decref(parent_id)\n            .expect(\"Failed to decrease parent refcount\");\n        node\n    }\n\n    fn update_access_time(&mut self, node_id: NodeId) {\n        // Unwrap here, passing in an unknown id is a programming error.\n        let node = self.nodes.get_mut(node_id).expect(\"Unknown node\");\n\n        // Update the ordered leaves set if the node is a leave.\n        if self.leaves.remove(&(node.last_accessed, node_id)) {\n            self.leaves.insert((self.time, node_id));\n        }\n\n        node.last_accessed = self.time;\n    }\n\n    #[allow(dead_code)]\n    #[doc(hidden)]\n    /// Print debugging output for the trie.\n    ///\n    /// In contrast to `Debug` nicely formatted.\n    pub fn print_debug(&self) {\n        self.print_debug_(self.root, 0);\n    }\n\n    fn print_debug_(&self, node_id: NodeId, indent: usize) {\n        let node = &self.nodes[node_id];\n        eprintln!(\n            \"{}{:?}, key: {:?}, blocks: {:?}, ref_count: {}, last_accessed: {}, parent: {:?}, children: {:?}\",\n            \" \".repeat(indent),\n            node_id,\n            node.key,\n            node.blocks,\n            node.ref_count,\n            node.last_accessed,\n            node.parent,\n            node.children\n        );\n        for child_id in self.nodes[node_id].children.values() {\n            self.print_debug_(*child_id, indent + 2);\n 
       }\n    }\n\n    pub(crate) fn root_id(&self) -> DefaultKey {\n        self.root\n    }\n}\n\n/// Trie node.\n#[derive(Debug)]\nstruct TrieNode {\n    blocks: Vec<u32>,\n    children: HashMap<u64, NodeId>,\n    key: Vec<u32>,\n    last_accessed: u64,\n    parent: Option<NodeId>,\n    ref_count: usize,\n}\n\nimpl TrieNode {\n    fn new(key: Vec<u32>, blocks: Vec<u32>, last_accessed: u64, parent: Option<NodeId>) -> Self {\n        TrieNode {\n            children: HashMap::new(),\n            key,\n            blocks,\n            last_accessed,\n            parent,\n            ref_count: 0,\n        }\n    }\n}\n\nfn shared_prefix(left: &[u32], right: &[u32], block_size: usize) -> usize {\n    let full = left.iter().zip(right).take_while(|(a, b)| a == b).count();\n    // NOTE: this is the case because the child node was chosen based on\n    //       matching the first character of the key/prefix.\n    assert!(full > 0, \"Prefixes must at least share 1 token\");\n    (full / block_size) * block_size\n}\n\n#[cfg(test)]\nmod tests {\n    use std::sync::Arc;\n\n    use rand::{\n        distributions::Uniform, prelude::Distribution, rngs::SmallRng, seq::SliceRandom,\n        SeedableRng,\n    };\n    use rustc_hash::FxHashSet;\n\n    use super::*;\n\n    #[test]\n    fn allocator_block_size() {\n        let mut cache = RadixAllocator::new(2, 12, None);\n        let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();\n        assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);\n        assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);\n        assert_eq!(allocation.prefix_len, 0);\n        cache.free(allocation.blocks.clone(), allocation.allocation_id);\n\n        let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();\n        assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);\n        assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]);\n        assert_eq!(allocation.prefix_len, 4);\n   
 }\n\n    #[test]\n    fn allocator_block_size_non_aligned() {\n        let mut cache = RadixAllocator::new(2, 12, None);\n        let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();\n        assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);\n        assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);\n        assert_eq!(allocation.prefix_len, 0);\n        cache.free(allocation.blocks.clone(), allocation.allocation_id);\n\n        let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap();\n        assert_eq!(allocation.blocks, vec![8, 9, 10, 11]);\n        assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]);\n        assert_eq!(allocation.prefix_len, 2);\n    }\n\n    #[test]\n    fn allocator_reuses_prefixes() {\n        let mut cache = RadixAllocator::new(1, 12, None);\n        let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();\n        assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);\n        assert_eq!(allocation.blocks, allocation.slots);\n        assert_eq!(allocation.prefix_len, 0);\n        cache.free(allocation.blocks.clone(), allocation.allocation_id);\n\n        let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();\n        assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]);\n        assert_eq!(allocation.prefix_len, 4);\n    }\n\n    #[test]\n    fn allocator_collects_older_prefixes_first() {\n        let mut cache = RadixAllocator::new(1, 7, None);\n        let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();\n        assert_eq!(allocation1.blocks, vec![3, 4, 5, 6]);\n        assert_eq!(allocation1.prefix_len, 0);\n\n        let allocation2 = cache.allocate(2, Some(Arc::new(vec![4, 5]))).unwrap();\n        assert_eq!(allocation2.blocks, vec![1, 2]);\n        assert_eq!(allocation2.prefix_len, 0);\n\n        cache.free(allocation1.blocks.clone(), allocation1.allocation_id);\n        
cache.free(allocation2.blocks.clone(), allocation2.allocation_id);\n\n        // We should get the blocks of the first allocation, since they are more recent.\n        let allocation3 = cache.allocate(4, Some(Arc::new(vec![6, 7, 8, 9]))).unwrap();\n        assert_eq!(allocation3.blocks, vec![3, 4, 5, 6]);\n        assert_eq!(allocation3.prefix_len, 0);\n    }\n\n    #[test]\n    fn allocator_frees_fully_overlapping_prefills() {\n        let mut cache = RadixAllocator::new(1, 10, None);\n        let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();\n        let allocation2 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();\n\n        cache.free(allocation2.blocks.clone(), allocation2.allocation_id);\n        cache.free(allocation1.blocks.clone(), allocation1.allocation_id);\n\n        let allocation3 = cache.allocate(4, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap();\n        assert_eq!(allocation3.prefix_len, 4);\n\n        // 10 blocks, of which 1 reserved for health checks, 4 for the cached blocks.\n        assert_eq!(cache.free_blocks.len(), 5);\n    }\n\n    #[test]\n    fn allocator_frees_partially_overlapping_prefills() {\n        let mut cache = RadixAllocator::new(1, 20, None);\n        let allocation1 = cache.allocate(4, Some(Arc::new(vec![0, 1]))).unwrap();\n        assert_eq!(allocation1.blocks, vec![16, 17, 18, 19]);\n        assert_eq!(allocation1.prefix_len, 0);\n\n        cache.free(allocation1.blocks.clone(), allocation1.allocation_id);\n\n        let allocation2 = cache\n            .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 4, 5])))\n            .unwrap();\n        assert_eq!(allocation2.blocks, vec![16, 17, 12, 13, 14, 15, 18, 19]);\n        assert_eq!(allocation2.prefix_len, 2);\n\n        let allocation3 = cache\n            .allocate(8, Some(Arc::new(vec![0, 1, 2, 3, 6, 7])))\n            .unwrap();\n        assert_eq!(allocation3.blocks, vec![16, 17, 6, 7, 8, 9, 10, 11]);\n        
assert_eq!(allocation3.prefix_len, 2);\n\n        cache.free(allocation3.blocks.clone(), allocation3.allocation_id);\n        cache.free(allocation2.blocks.clone(), allocation2.allocation_id);\n\n        // 20 blocks, of which 1 reserved for health checks, 6 for allocation3, 2 for allocation2.\n        assert_eq!(cache.free_blocks.len(), 11);\n\n        let allocation4 = cache\n            .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 4, 5])))\n            .unwrap();\n        assert_eq!(allocation4.blocks, vec![16, 17, 6, 7, 14, 15]);\n        assert_eq!(allocation4.prefix_len, 6);\n        assert_eq!(cache.free_blocks.len(), 11);\n\n        let allocation5 = cache\n            .allocate(6, Some(Arc::new(vec![0, 1, 2, 3, 6, 7])))\n            .unwrap();\n        assert_eq!(allocation5.blocks, vec![16, 17, 6, 7, 8, 9]);\n        assert_eq!(allocation5.prefix_len, 6);\n        assert_eq!(cache.free_blocks.len(), 11);\n    }\n\n    #[test]\n    fn trie_insertions_have_correct_prefix_len() {\n        let mut trie = RadixTrie::new(1);\n\n        assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0);\n\n        // Already exists.\n        assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 3);\n\n        // Completely new at root-level\n        assert_eq!(trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(), 0);\n\n        // Contains full prefix, but longer.\n        assert_eq!(trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(), 3);\n\n        // Shares partial prefix, we need a split.\n        assert_eq!(\n            trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])\n                .unwrap(),\n            4\n        );\n    }\n\n    #[test]\n    fn trie_insertions_block_size() {\n        let mut trie = RadixTrie::new(2);\n\n        assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 0);\n\n        // Already exists.\n        // But needs to be block_size aligned\n        assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 4);\n\n        // 
Completely new at root-level\n        assert_eq!(trie.insert(&[1, 2, 3, 4], &[1, 2]).unwrap(), 0);\n\n        // Contains full prefix, but longer.\n        assert_eq!(trie.insert(&[0, 1, 2, 3, 4, 5], &[0, 1, 2]).unwrap(), 4);\n\n        // Shares partial prefix, we need a split.\n        assert_eq!(\n            trie.insert(&[0, 1, 3, 4, 5, 6, 7, 8], &[0, 1, 2, 3])\n                .unwrap(),\n            2\n        );\n    }\n\n    #[test]\n    fn trie_get_returns_correct_blocks() {\n        let mut trie = RadixTrie::new(1);\n        trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();\n        trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();\n        trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();\n        trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])\n            .unwrap();\n\n        let mut blocks = Vec::new();\n        trie.find(&[0], &mut blocks);\n        assert_eq!(blocks, vec![0]);\n\n        blocks.clear();\n        trie.find(&[0, 1, 2], &mut blocks);\n        assert_eq!(blocks, vec![0, 1, 2]);\n\n        blocks.clear();\n        trie.find(&[1, 2, 3], &mut blocks);\n        assert_eq!(blocks, vec![1, 2, 3]);\n\n        blocks.clear();\n        trie.find(&[0, 1, 2, 3], &mut blocks);\n        assert_eq!(blocks, vec![0, 1, 2, 3]);\n\n        blocks.clear();\n        trie.find(&[0, 1, 2, 3, 4], &mut blocks);\n        assert_eq!(blocks, vec![0, 1, 2, 3, 4]);\n\n        blocks.clear();\n        trie.find(&[0, 1, 2, 3, 5], &mut blocks);\n        assert_eq!(blocks, vec![0, 1, 2, 3, 5]);\n    }\n\n    #[test]\n    fn trie_evict_removes_correct_blocks() {\n        let mut trie = RadixTrie::new(1);\n        trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();\n        trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7])\n            .unwrap();\n        trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap();\n        trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap();\n\n        let mut blocks = Vec::new();\n\n        // Remove less than the leave 
blocks.\n        assert_eq!(trie.evict(1), vec![7]);\n        trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks);\n        assert_eq!(blocks, vec![0, 1, 2, 3, 5, 6]);\n\n        // Refresh other leaf.\n        trie.find(&[0, 1, 2, 3, 4], &mut blocks);\n        trie.find(&[1, 2, 3], &mut blocks);\n\n        // Remove the leave blocks exactly.\n        assert_eq!(trie.evict(2), vec![5, 6]);\n        blocks.clear();\n        trie.find(&[0, 1, 2, 3, 5, 6, 7], &mut blocks);\n        assert_eq!(blocks, vec![0, 1, 2, 3]);\n\n        trie.find(&[1, 2, 3], &mut blocks);\n\n        // Remove more than the leave blocks.\n        assert_eq!(trie.evict(3), vec![4, 3, 2]);\n        blocks.clear();\n        trie.find(&[0, 1, 2, 3, 4], &mut blocks);\n        assert_eq!(blocks, vec![0, 1]);\n\n        // Clear out the whole trie.\n        assert_eq!(trie.evict(10), vec![1, 2, 3, 0, 1]);\n    }\n\n    #[test]\n    fn full_match_returns_correct_node() {\n        let mut trie = RadixTrie::new(1);\n        trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();\n        let node_id = trie.find(&[0, 1, 2], &mut vec![]);\n        // At this point, there are only two nodes: the root and the node\n        // with tokens 0, 1, 2. 
Looking up the exact prefix must return\n        // the non-root node.\n        assert_ne!(node_id, trie.root);\n    }\n\n    #[test]\n    fn partial_match_does_not_recurse() {\n        let mut trie = RadixTrie::new(1);\n        trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap();\n        trie.insert(&[0, 1, 2, 3, 4, 5], &[0, 1, 2, 3, 4, 5])\n            .unwrap();\n        let mut blocks = Vec::new();\n        let node_id = trie.find(&[0, 1, 3, 4, 5], &mut blocks);\n        assert_eq!(blocks, vec![0, 1]);\n        assert_eq!(node_id, trie.find(&[0, 1], &mut blocks))\n    }\n\n    struct AllocationWithInfo {\n        allocation: BlockAllocation,\n        // We are doing a lot of set operations and `FxBuildHasher` is\n        // muc faster for a set of integers.\n        blockset: FxHashSet<u32>,\n        non_prefix_blocks: FxHashSet<u32>,\n    }\n\n    #[test]\n    fn invariants_hold_on_many_operations_remove_all() {\n        invariants_hold_on_many_insertions(true);\n    }\n\n    #[test]\n    fn invariants_hold_on_many_operations_remove_subset() {\n        invariants_hold_on_many_insertions(false);\n    }\n\n    fn invariants_hold_on_many_insertions(remove_all: bool) {\n        // Small vocabulary sizes lead to violations more quickly due to\n        // prefix sharing, etc.\n        const VOCAB_SIZE: u32 = 2;\n        const DATA_LEN: usize = 1_000;\n\n        const MAX_PREFILL_LEN: usize = 8;\n        const MAX_DECODE_LEN: usize = 8;\n\n        let vocab_range = Uniform::new(0, VOCAB_SIZE);\n        let data_range = Uniform::new(0, DATA_LEN);\n        let prefill_len_range = Uniform::new(0, MAX_PREFILL_LEN);\n        let decode_len_range = Uniform::new(0, MAX_DECODE_LEN);\n\n        let mut rng = SmallRng::seed_from_u64(64);\n        let data = (0..DATA_LEN)\n            .map(|_| vocab_range.sample(&mut rng))\n            .collect::<Vec<_>>();\n        let mut allocator = RadixAllocator::new(1, 100, None);\n\n        let mut allocations = Vec::new();\n\n        for i in 
0..100_000 {\n            // Allocate until all blocks are used.\n            'allocation: loop {\n                // Use offset 0 half of the times for prefix sharing.\n                let prefill_offset = data_range.sample(&mut rng);\n                let prefill_len = prefill_len_range.sample(&mut rng);\n                let decode_len = decode_len_range.sample(&mut rng);\n\n                let prefill =\n                    data[prefill_offset..data.len().min(prefill_offset + prefill_len)].to_vec();\n\n                let allocation = match allocator\n                    .allocate((prefill.len() + decode_len) as u32, Some(Arc::new(prefill)))\n                {\n                    Some(allocation) => allocation,\n                    None => break 'allocation,\n                };\n                let non_prefix_blocks = allocation.blocks[allocation.prefix_len as usize..]\n                    .iter()\n                    .copied()\n                    .collect::<FxHashSet<_>>();\n                let blockset = allocation.blocks.iter().copied().collect::<FxHashSet<_>>();\n\n                // No duplicate blocks in an allocation.\n                assert_eq!(\n                    allocation.blocks.len(),\n                    blockset.len(),\n                    \"Duplicate blocks in allocation\"\n                );\n\n                allocations.push(AllocationWithInfo {\n                    allocation,\n                    blockset,\n                    non_prefix_blocks,\n                });\n            }\n\n            // Check invariants. 
Skip first iteration, since there is no prefix sharing yet.\n            if i > 1 {\n                check_allocation_invariants(&allocations);\n            }\n\n            // Remove 20% of the allocations, randomly.\n            if remove_all {\n                allocations.into_iter().for_each(|allocation| {\n                    allocator.free(\n                        allocation.allocation.blocks.clone(),\n                        allocation.allocation.allocation_id,\n                    )\n                });\n                allocations = Vec::new();\n            } else {\n                allocations.shuffle(&mut rng);\n                let remove_index = (allocations.len() as f64 * 0.8) as usize;\n                for allocation in allocations.drain(remove_index..) {\n                    allocator.free(\n                        allocation.allocation.blocks.clone(),\n                        allocation.allocation.allocation_id,\n                    );\n                }\n            }\n        }\n    }\n\n    fn check_allocation_invariants(allocations: &[AllocationWithInfo]) {\n        for i in 0..allocations.len() {\n            let allocation = &allocations[i];\n\n            // 0 is used for health checks, must not be used.\n            assert!(\n                !allocation.blockset.contains(&0),\n                \"Block 0 must not be allocated\"\n            );\n\n            // No duplicate blocks in an allocation.\n            assert_eq!(\n                allocation.allocation.blocks.len(),\n                allocation.blockset.len(),\n                \"Duplicate blocks in allocation\"\n            );\n\n            for other_allocation in &allocations[i + 1..] {\n                assert!(\n                    other_allocation\n                        .non_prefix_blocks\n                        .is_disjoint(&allocation.non_prefix_blocks),\n                    \"Allocations share non-prefix blocks\"\n                )\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "benchmark/Cargo.toml",
    "content": "[package]\nname = \"text-generation-benchmark\"\ndescription = \"Text Generation Benchmarking tool\"\nversion.workspace = true\nedition.workspace = true\nauthors.workspace = true\nhomepage.workspace = true\n\n[lib]\npath = \"src/lib.rs\"\n\n[[bin]]\nname = \"text-generation-benchmark\"\npath = \"src/main.rs\"\n\n[dependencies]\naverage = \"0.14\"\nclap = { version = \"4.4.5\", features = [\"derive\", \"env\"] }\nfloat-ord = \"0.3.2\"\nserde = {version = \"1.0.188\", features = [\"derive\"]}\nserde_json = \"1.0\"\ntabled = \"0.14.0\"\ntext-generation-client = { path = \"../backends/client\" }\nthiserror = \"1.0.48\"\ntokenizers = { workspace = true }\ntokio = { version = \"1.32.0\", features = [\"rt\", \"rt-multi-thread\", \"parking_lot\", \"signal\", \"sync\", \"macros\"] }\nratatui = \"0.28.1\"\ntracing = \"0.1.37\"\ntracing-subscriber = { version = \"0.3.17\", features = [\"json\", \"env-filter\"] }\nhf-hub = { workspace = true }\n"
  },
  {
    "path": "benchmark/README.md",
    "content": "<div align=\"center\">\n\n# Text Generation Inference benchmarking tool\n\n![benchmark](../assets/benchmark.png)\n\n</div>\n\nA lightweight benchmarking tool based inspired by [oha](https://github.com/hatoo/oha)\nand powered by [Ratatui](https://github.com/ratatui/ratatui).\n\n## Install\n\n```shell\nmake install-benchmark\n```\n\n## Run\n\nFirst, start `text-generation-inference`:\n\n```shell\ntext-generation-launcher --model-id bigscience/bloom-560m\n```\n\nThen run the benchmarking tool:\n\n```shell\ntext-generation-benchmark --tokenizer-name bigscience/bloom-560m\n```\n"
  },
  {
    "path": "benchmark/src/app.rs",
    "content": "/// Inspired by https://github.com/hatoo/oha/blob/bb989ea3cd77727e7743e7daa60a19894bb5e901/src/monitor.rs\nuse crate::generation::{Decode, Message, Prefill};\nuse ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers};\nuse ratatui::layout::{Alignment, Constraint, Direction, Layout};\nuse ratatui::style::{Color, Modifier, Style};\nuse ratatui::text::{Line, Span};\nuse ratatui::widgets::{\n    Axis, BarChart, Block, Borders, Chart, Dataset, Gauge, GraphType, Paragraph, Tabs,\n};\nuse ratatui::{symbols, Frame};\nuse text_generation_client::ClientError;\nuse tokio::sync::mpsc;\n\n/// TUI powered App\npub(crate) struct App {\n    pub(crate) running: bool,\n    pub(crate) data: Data,\n    completed_runs: Vec<usize>,\n    completed_batch: usize,\n    current_batch: usize,\n    current_tab: usize,\n    touched_tab: bool,\n    zoom: bool,\n    is_error: bool,\n    tokenizer_name: String,\n    sequence_length: u32,\n    decode_length: u32,\n    n_run: usize,\n    receiver: mpsc::Receiver<Result<Message, ClientError>>,\n}\n\nimpl App {\n    pub(crate) fn new(\n        receiver: mpsc::Receiver<Result<Message, ClientError>>,\n        tokenizer_name: String,\n        sequence_length: u32,\n        decode_length: u32,\n        n_run: usize,\n        batch_size: Vec<u32>,\n    ) -> Self {\n        let current_tab = 0;\n\n        let completed_runs: Vec<usize> = (0..batch_size.len()).map(|_| 0).collect();\n        let completed_batch = 0;\n        let current_batch = 0;\n        let is_error = false;\n\n        let data = Data::new(n_run, batch_size);\n\n        Self {\n            running: true,\n            data,\n            completed_runs,\n            completed_batch,\n            current_batch,\n            current_tab,\n            touched_tab: false,\n            zoom: false,\n            is_error,\n            tokenizer_name,\n            sequence_length,\n            decode_length,\n            n_run,\n            receiver,\n        }\n    }\n\n    
/// Handle crossterm key events\n    pub(crate) fn handle_key_event(&mut self, key_event: KeyEvent) {\n        match key_event {\n            // Increase and wrap tab\n            KeyEvent {\n                code: KeyCode::Right,\n                ..\n            }\n            | KeyEvent {\n                code: KeyCode::Tab, ..\n            } => {\n                self.touched_tab = true;\n                self.current_tab = (self.current_tab + 1) % self.data.batch_size.len();\n            }\n            // Decrease and wrap tab\n            KeyEvent {\n                code: KeyCode::Left,\n                ..\n            } => {\n                self.touched_tab = true;\n                if self.current_tab > 0 {\n                    self.current_tab -= 1;\n                } else {\n                    self.current_tab = self.data.batch_size.len() - 1;\n                }\n            }\n            // Zoom on throughput/latency fig\n            KeyEvent {\n                code: KeyCode::Char('+'),\n                ..\n            } => {\n                self.zoom = true;\n            }\n            // Unzoom on throughput/latency fig\n            KeyEvent {\n                code: KeyCode::Char('-'),\n                ..\n            } => {\n                self.zoom = false;\n            }\n            // Quit\n            KeyEvent {\n                code: KeyCode::Char('q'),\n                ..\n            }\n            | KeyEvent {\n                code: KeyCode::Char('c'),\n                modifiers: KeyModifiers::CONTROL,\n                ..\n            } => {\n                self.running = false;\n            }\n            _ => (),\n        }\n    }\n\n    /// Get all pending messages from generation task\n    pub(crate) fn tick(&mut self) {\n        while let Ok(message) = self.receiver.try_recv() {\n            match message {\n                Ok(message) => match message {\n                    Message::Prefill(step) => self.data.push_prefill(step, 
self.current_batch),\n                    Message::Decode(step) => self.data.push_decode(step, self.current_batch),\n                    Message::EndRun => {\n                        self.completed_runs[self.current_batch] += 1;\n                    }\n                    Message::EndBatch => {\n                        self.data.end_batch(self.current_batch);\n                        self.completed_batch += 1;\n\n                        if self.current_batch < self.data.batch_size.len() - 1 {\n                            // Only go to next tab if the user never touched the tab keys\n                            if !self.touched_tab {\n                                self.current_tab += 1;\n                            }\n\n                            self.current_batch += 1;\n                        }\n                    }\n                    Message::Warmup => {}\n                },\n                Err(_) => self.is_error = true,\n            }\n        }\n    }\n\n    /// Render frame\n    pub fn render(&mut self, f: &mut Frame) {\n        let batch_progress =\n            (self.completed_batch as f64 / self.data.batch_size.len() as f64).clamp(0.0, 1.0);\n        let run_progress =\n            (self.completed_runs[self.current_batch] as f64 / self.n_run as f64).clamp(0.0, 1.0);\n\n        // Vertical layout\n        let row5 = Layout::default()\n            .direction(Direction::Vertical)\n            .constraints(\n                [\n                    Constraint::Length(1),\n                    Constraint::Length(3),\n                    Constraint::Length(3),\n                    Constraint::Length(13),\n                    Constraint::Min(10),\n                ]\n                .as_ref(),\n            )\n            .split(f.area());\n\n        // Top row horizontal layout\n        let top = Layout::default()\n            .direction(Direction::Horizontal)\n            .constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref())\n        
    .split(row5[2]);\n\n        // Mid row horizontal layout\n        let mid = Layout::default()\n            .direction(Direction::Horizontal)\n            .constraints(\n                [\n                    Constraint::Percentage(25),\n                    Constraint::Percentage(25),\n                    Constraint::Percentage(25),\n                    Constraint::Percentage(25),\n                ]\n                .as_ref(),\n            )\n            .split(row5[3]);\n\n        // Left mid row vertical layout\n        let prefill_text = Layout::default()\n            .direction(Direction::Vertical)\n            .constraints([Constraint::Length(8), Constraint::Length(5)].as_ref())\n            .split(mid[0]);\n\n        // Right mid row vertical layout\n        let decode_text = Layout::default()\n            .direction(Direction::Vertical)\n            .constraints([Constraint::Length(8), Constraint::Length(5)].as_ref())\n            .split(mid[2]);\n        let decode_text_latency = Layout::default()\n            .direction(Direction::Horizontal)\n            .constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref())\n            .split(decode_text[0]);\n\n        // Bottom row horizontal layout\n        let bottom = Layout::default()\n            .direction(Direction::Horizontal)\n            .constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref())\n            .split(row5[4]);\n\n        // Title\n        let title = Block::default()\n            .borders(Borders::NONE)\n            .title(format!(\n                \"Model: {} | Sequence Length: {} | Decode Length: {}\",\n                self.tokenizer_name, self.sequence_length, self.decode_length\n            ))\n            .style(\n                Style::default()\n                    .add_modifier(Modifier::BOLD)\n                    .fg(Color::White),\n            );\n        f.render_widget(title, row5[0]);\n\n        // Helper\n        let helper = 
Block::default()\n            .borders(Borders::NONE)\n            .title(\"<- | tab | ->: change batch tab | q / CTRL + c: quit | +/-: zoom\")\n            .title_alignment(Alignment::Right)\n            .style(Style::default().fg(Color::White));\n        f.render_widget(helper, row5[0]);\n\n        // Batch tabs\n        let titles: Vec<Line> = self\n            .data\n            .batch_size\n            .iter()\n            .map(|b| {\n                Line::from(vec![Span::styled(\n                    format!(\"Batch: {b}\"),\n                    Style::default().fg(Color::White),\n                )])\n            })\n            .collect();\n        let tabs = Tabs::new(titles)\n            .block(Block::default().borders(Borders::ALL).title(\"Tabs\"))\n            .select(self.current_tab)\n            .style(Style::default().fg(Color::LightCyan))\n            .highlight_style(\n                Style::default()\n                    .add_modifier(Modifier::BOLD)\n                    .bg(Color::Black),\n            );\n        f.render_widget(tabs, row5[1]);\n\n        // Total progress bar\n        let color = if self.is_error {\n            Color::Red\n        } else {\n            Color::LightGreen\n        };\n        let batch_gauge = progress_gauge(\n            \"Total Progress\",\n            format!(\"{} / {}\", self.completed_batch, self.data.batch_size.len()),\n            batch_progress,\n            color,\n        );\n        f.render_widget(batch_gauge, top[0]);\n\n        // Batch progress Bar\n        let color = if self.is_error {\n            Color::Red\n        } else {\n            Color::LightBlue\n        };\n        let run_gauge = progress_gauge(\n            \"Batch Progress\",\n            format!(\n                \"{} / {}\",\n                self.completed_runs[self.current_batch], self.n_run\n            ),\n            run_progress,\n            color,\n        );\n        f.render_widget(run_gauge, top[1]);\n\n        // Prefill 
text infos\n        let prefill_latency_block = latency_paragraph(\n            &mut self.data.prefill_latencies[self.current_tab],\n            \"Prefill\",\n        );\n        let prefill_throughput_block =\n            throughput_paragraph(&self.data.prefill_throughputs[self.current_tab], \"Prefill\");\n\n        f.render_widget(prefill_latency_block, prefill_text[0]);\n        f.render_widget(prefill_throughput_block, prefill_text[1]);\n\n        // Prefill latency histogram\n        let histo_width = 7;\n        let bins = if mid[1].width < 2 {\n            0\n        } else {\n            (mid[1].width as usize - 2) / (histo_width + 1)\n        }\n        .max(2);\n\n        let histo_data =\n            latency_histogram_data(&self.data.prefill_latencies[self.current_tab], bins);\n        let histo_data_str: Vec<(&str, u64)> =\n            histo_data.iter().map(|(l, v)| (l.as_str(), *v)).collect();\n        let prefill_histogram =\n            latency_histogram(&histo_data_str, \"Prefill\").bar_width(histo_width as u16);\n        f.render_widget(prefill_histogram, mid[1]);\n\n        // Decode text info\n        let decode_latency_block = latency_paragraph(\n            &mut self.data.decode_latencies[self.current_tab],\n            \"Decode Total\",\n        );\n        let decode_token_latency_block = latency_paragraph(\n            &mut self.data.decode_token_latencies[self.current_tab],\n            \"Decode Token\",\n        );\n        let decode_throughput_block =\n            throughput_paragraph(&self.data.decode_throughputs[self.current_tab], \"Decode\");\n        f.render_widget(decode_latency_block, decode_text_latency[0]);\n        f.render_widget(decode_token_latency_block, decode_text_latency[1]);\n        f.render_widget(decode_throughput_block, decode_text[1]);\n\n        // Decode latency histogram\n        let histo_data =\n            latency_histogram_data(&self.data.decode_latencies[self.current_tab], bins);\n        let 
histo_data_str: Vec<(&str, u64)> =\n            histo_data.iter().map(|(l, v)| (l.as_str(), *v)).collect();\n        let decode_histogram =\n            latency_histogram(&histo_data_str, \"Decode\").bar_width(histo_width as u16);\n        f.render_widget(decode_histogram, mid[3]);\n\n        // Prefill latency/throughput chart\n        let prefill_latency_throughput_chart = latency_throughput_chart(\n            &self.data.prefill_batch_latency_throughput,\n            &self.data.batch_size,\n            self.zoom,\n            \"Prefill\",\n        );\n        f.render_widget(prefill_latency_throughput_chart, bottom[0]);\n\n        // Decode latency/throughput chart\n        let decode_latency_throughput_chart = latency_throughput_chart(\n            &self.data.decode_batch_latency_throughput,\n            &self.data.batch_size,\n            self.zoom,\n            \"Decode\",\n        );\n        f.render_widget(decode_latency_throughput_chart, bottom[1]);\n    }\n}\n\n/// App internal data struct\npub(crate) struct Data {\n    pub(crate) batch_size: Vec<u32>,\n    pub(crate) prefill_latencies: Vec<Vec<f64>>,\n    pub(crate) prefill_throughputs: Vec<Vec<f64>>,\n    pub(crate) decode_latencies: Vec<Vec<f64>>,\n    pub(crate) decode_token_latencies: Vec<Vec<f64>>,\n    pub(crate) decode_throughputs: Vec<Vec<f64>>,\n    pub(crate) prefill_batch_latency_throughput: Vec<(f64, f64)>,\n    pub(crate) decode_batch_latency_throughput: Vec<(f64, f64)>,\n}\n\nimpl Data {\n    fn new(n_run: usize, batch_size: Vec<u32>) -> Self {\n        let prefill_latencies: Vec<Vec<f64>> = (0..batch_size.len())\n            .map(|_| Vec::with_capacity(n_run))\n            .collect();\n        let prefill_throughputs: Vec<Vec<f64>> = prefill_latencies.clone();\n\n        let decode_latencies: Vec<Vec<f64>> = prefill_latencies.clone();\n        let decode_token_latencies: Vec<Vec<f64>> = decode_latencies.clone();\n        let decode_throughputs: Vec<Vec<f64>> = 
prefill_throughputs.clone();\n\n        let prefill_batch_latency_throughput: Vec<(f64, f64)> =\n            Vec::with_capacity(batch_size.len());\n        let decode_batch_latency_throughput: Vec<(f64, f64)> =\n            prefill_batch_latency_throughput.clone();\n\n        Self {\n            batch_size,\n            prefill_latencies,\n            prefill_throughputs,\n            decode_latencies,\n            decode_token_latencies,\n            decode_throughputs,\n            prefill_batch_latency_throughput,\n            decode_batch_latency_throughput,\n        }\n    }\n\n    fn push_prefill(&mut self, prefill: Prefill, batch_idx: usize) {\n        let latency = prefill.latency.as_micros() as f64 / 1000.0;\n        self.prefill_latencies[batch_idx].push(latency);\n        self.prefill_throughputs[batch_idx].push(prefill.throughput);\n    }\n\n    fn push_decode(&mut self, decode: Decode, batch_idx: usize) {\n        let latency = decode.latency.as_micros() as f64 / 1000.0;\n        let token_latency = decode.token_latency.as_micros() as f64 / 1000.0;\n        self.decode_latencies[batch_idx].push(latency);\n        self.decode_token_latencies[batch_idx].push(token_latency);\n        self.decode_throughputs[batch_idx].push(decode.throughput);\n    }\n\n    fn end_batch(&mut self, batch_idx: usize) {\n        self.prefill_batch_latency_throughput.push((\n            self.prefill_latencies[batch_idx].iter().sum::<f64>()\n                / self.prefill_latencies[batch_idx].len() as f64,\n            self.prefill_throughputs[batch_idx].iter().sum::<f64>()\n                / self.prefill_throughputs[batch_idx].len() as f64,\n        ));\n        self.decode_batch_latency_throughput.push((\n            self.decode_latencies[batch_idx].iter().sum::<f64>()\n                / self.decode_latencies[batch_idx].len() as f64,\n            self.decode_throughputs[batch_idx].iter().sum::<f64>()\n                / self.decode_throughputs[batch_idx].len() as f64,\n        
));\n    }\n}\n\n/// Progress bar\nfn progress_gauge(title: &str, label: String, progress: f64, color: Color) -> Gauge<'_> {\n    Gauge::default()\n        .block(Block::default().title(title).borders(Borders::ALL))\n        .gauge_style(Style::default().fg(color))\n        .label(Span::raw(label))\n        .ratio(progress)\n}\n\n/// Throughput paragraph\nfn throughput_paragraph<'a>(throughput: &[f64], name: &'static str) -> Paragraph<'a> {\n    // Throughput average/high/low texts\n    let throughput_texts = statis_spans(throughput, \"tokens/secs\");\n\n    // Throughput block\n    Paragraph::new(throughput_texts).block(\n        Block::default()\n            .title(Span::raw(format!(\"{name} Throughput\")))\n            .borders(Borders::ALL),\n    )\n}\n\n/// Latency paragraph\nfn latency_paragraph<'a>(latency: &mut [f64], name: &'static str) -> Paragraph<'a> {\n    // Latency average/high/low texts\n    let mut latency_texts = statis_spans(latency, \"ms\");\n\n    // Sort latency for percentiles\n    float_ord::sort(latency);\n    let latency_percentiles = crate::utils::percentiles(latency, &[50, 90, 99]);\n\n    // Latency p50/p90/p99 texts\n    let colors = [Color::LightGreen, Color::LightYellow, Color::LightRed];\n    for (i, (name, value)) in latency_percentiles.iter().enumerate() {\n        let span = Line::from(vec![Span::styled(\n            format!(\"{name}:     {value:.2} ms\"),\n            Style::default().fg(colors[i]),\n        )]);\n        latency_texts.push(span);\n    }\n\n    Paragraph::new(latency_texts).block(\n        Block::default()\n            .title(Span::raw(format!(\"{name} Latency\")))\n            .borders(Borders::ALL),\n    )\n}\n\n/// Average/High/Low spans\nfn statis_spans<'a>(data: &[f64], unit: &'static str) -> Vec<Line<'a>> {\n    vec![\n        Line::from(vec![Span::styled(\n            format!(\n                \"Average: {:.2} {unit}\",\n                data.iter().sum::<f64>() / data.len() as f64\n            ),\n        
    Style::default().fg(Color::LightBlue),\n        )]),\n        Line::from(vec![Span::styled(\n            format!(\n                \"Lowest:  {:.2} {unit}\",\n                data.iter()\n                    .min_by(|a, b| a.total_cmp(b))\n                    .unwrap_or(&f64::NAN)\n            ),\n            Style::default().fg(Color::Reset),\n        )]),\n        Line::from(vec![Span::styled(\n            format!(\n                \"Highest: {:.2} {unit}\",\n                data.iter()\n                    .max_by(|a, b| a.total_cmp(b))\n                    .unwrap_or(&f64::NAN)\n            ),\n            Style::default().fg(Color::Reset),\n        )]),\n    ]\n}\n\n/// Latency histogram data\nfn latency_histogram_data(latency: &[f64], bins: usize) -> Vec<(String, u64)> {\n    let histo_data: Vec<(String, u64)> = {\n        let histo = crate::utils::histogram(latency, bins);\n        histo\n            .into_iter()\n            .map(|(label, v)| (format!(\"{label:.2}\"), v as u64))\n            .collect()\n    };\n\n    histo_data\n}\n\n/// Latency Histogram\nfn latency_histogram<'a>(\n    histo_data_str: &'a Vec<(&'a str, u64)>,\n    name: &'static str,\n) -> BarChart<'a> {\n    BarChart::default()\n        .block(\n            Block::default()\n                .title(format!(\"{name} latency histogram\"))\n                .style(Style::default().fg(Color::LightYellow).bg(Color::Reset))\n                .borders(Borders::ALL),\n        )\n        .data(histo_data_str.as_slice())\n}\n\n/// Latency/Throughput chart\nfn latency_throughput_chart<'a>(\n    latency_throughput: &'a [(f64, f64)],\n    batch_sizes: &'a [u32],\n    zoom: bool,\n    name: &'static str,\n) -> Chart<'a> {\n    let latency_iter = latency_throughput.iter().map(|(l, _)| l);\n    let throughput_iter = latency_throughput.iter().map(|(_, t)| t);\n\n    // Get extreme values\n    let min_latency: f64 = *latency_iter\n        .clone()\n        .min_by(|a, b| a.total_cmp(b))\n        
.unwrap_or(&f64::NAN);\n    let max_latency: f64 = *latency_iter\n        .max_by(|a, b| a.total_cmp(b))\n        .unwrap_or(&f64::NAN);\n    let min_throughput: f64 = *throughput_iter\n        .clone()\n        .min_by(|a, b| a.total_cmp(b))\n        .unwrap_or(&f64::NAN);\n    let max_throughput: f64 = *throughput_iter\n        .max_by(|a, b| a.total_cmp(b))\n        .unwrap_or(&f64::NAN);\n\n    // Char min max values\n    let min_x = if zoom {\n        ((min_latency - 0.05 * min_latency) / 100.0).floor() * 100.0\n    } else {\n        0.0\n    };\n    let max_x = ((max_latency + 0.05 * max_latency) / 100.0).ceil() * 100.0;\n    let step_x = (max_x - min_x) / 4.0;\n\n    // Chart min max values\n    let min_y = if zoom {\n        ((min_throughput - 0.05 * min_throughput) / 100.0).floor() * 100.0\n    } else {\n        0.0\n    };\n    let max_y = ((max_throughput + 0.05 * max_throughput) / 100.0).ceil() * 100.0;\n    let step_y = (max_y - min_y) / 4.0;\n\n    // Labels\n    let mut x_labels = vec![Span::styled(\n        format!(\"{min_x:.2}\"),\n        Style::default()\n            .add_modifier(Modifier::BOLD)\n            .fg(Color::Gray)\n            .bg(Color::Reset),\n    )];\n    for i in 0..3 {\n        x_labels.push(Span::styled(\n            format!(\"{:.2}\", min_x + ((i + 1) as f64 * step_x)),\n            Style::default().fg(Color::Gray).bg(Color::Reset),\n        ));\n    }\n    x_labels.push(Span::styled(\n        format!(\"{max_x:.2}\"),\n        Style::default()\n            .add_modifier(Modifier::BOLD)\n            .fg(Color::Gray)\n            .bg(Color::Reset),\n    ));\n\n    // Labels\n    let mut y_labels = vec![Span::styled(\n        format!(\"{min_y:.2}\"),\n        Style::default()\n            .add_modifier(Modifier::BOLD)\n            .fg(Color::Gray)\n            .bg(Color::Reset),\n    )];\n    for i in 0..3 {\n        y_labels.push(Span::styled(\n            format!(\"{:.2}\", min_y + ((i + 1) as f64 * step_y)),\n            
Style::default().fg(Color::Gray).bg(Color::Reset),\n        ));\n    }\n    y_labels.push(Span::styled(\n        format!(\"{max_y:.2}\"),\n        Style::default()\n            .add_modifier(Modifier::BOLD)\n            .fg(Color::Gray)\n            .bg(Color::Reset),\n    ));\n\n    // Chart dataset\n    let colors = color_vec();\n    let datasets: Vec<Dataset> = (0..latency_throughput.len())\n        .map(|i| {\n            let color_idx = i % colors.len();\n\n            Dataset::default()\n                .name(batch_sizes[i].to_string())\n                .marker(symbols::Marker::Block)\n                .style(Style::default().fg(colors[color_idx]))\n                .graph_type(GraphType::Scatter)\n                .data(&latency_throughput[i..(i + 1)])\n        })\n        .collect();\n\n    // Chart\n    Chart::new(datasets)\n        .style(Style::default().fg(Color::Cyan).bg(Color::Reset))\n        .block(\n            Block::default()\n                .title(Span::styled(\n                    format!(\"{name} throughput over latency\"),\n                    Style::default().fg(Color::Gray).bg(Color::Reset),\n                ))\n                .borders(Borders::ALL),\n        )\n        .x_axis(\n            Axis::default()\n                .title(\"ms\")\n                .style(Style::default().fg(Color::Gray).bg(Color::Reset))\n                .labels(x_labels)\n                .bounds([min_x, max_x]),\n        )\n        .y_axis(\n            Axis::default()\n                .title(\"tokens/secs\")\n                .style(Style::default().fg(Color::Gray).bg(Color::Reset))\n                .labels(y_labels)\n                .bounds([min_y, max_y]),\n        )\n}\n\n// Colors for latency/throughput chart\nfn color_vec() -> Vec<Color> {\n    vec![\n        Color::Red,\n        Color::Green,\n        Color::Yellow,\n        Color::Blue,\n        Color::Magenta,\n        Color::Cyan,\n        Color::Gray,\n        Color::DarkGray,\n        Color::LightRed,\n   
     Color::LightGreen,\n        Color::LightYellow,\n        Color::LightBlue,\n        Color::LightMagenta,\n        Color::LightCyan,\n    ]\n}\n"
  },
  {
    "path": "benchmark/src/event.rs",
    "content": "/// Inspired by https://github.com/orhun/rust-tui-template/blob/472aa515119d4c94903eac12d9784417281dc7f5/src/event.rs\nuse ratatui::crossterm::event;\nuse std::time::{Duration, Instant};\nuse tokio::sync::{broadcast, mpsc};\n\n/// Events\n#[derive(Debug)]\npub(crate) enum Event {\n    /// Terminal tick.\n    Tick,\n    /// Key press.\n    Key(event::KeyEvent),\n    /// Terminal resize.\n    Resize,\n}\n\npub(crate) async fn terminal_event_task(\n    fps: u32,\n    event_sender: mpsc::Sender<Event>,\n    mut shutdown_receiver: broadcast::Receiver<()>,\n    _shutdown_guard_sender: mpsc::Sender<()>,\n) {\n    // End task if a message is received on shutdown_receiver\n    // _shutdown_guard_sender will be dropped once the task is finished\n    tokio::select! {\n        _ = event_loop(fps, event_sender)  => {\n        },\n        _ = shutdown_receiver.recv() => {}\n    }\n}\n\n/// Main event loop\nasync fn event_loop(fps: u32, event_sender: mpsc::Sender<Event>) {\n    // Frame budget\n    let per_frame = Duration::from_secs(1) / fps;\n\n    // When was last frame executed\n    let mut last_frame = Instant::now();\n\n    loop {\n        // Sleep to avoid blocking the thread for too long\n        if let Some(sleep) = per_frame.checked_sub(last_frame.elapsed()) {\n            tokio::time::sleep(sleep).await;\n        }\n\n        // Get crossterm event and send a new one over the channel\n        if event::poll(Duration::from_secs(0)).expect(\"no events available\") {\n            match event::read().expect(\"unable to read event\") {\n                event::Event::Key(e) => event_sender.send(Event::Key(e)).await.unwrap_or(()),\n                event::Event::Resize(_w, _h) => {\n                    event_sender.send(Event::Resize).await.unwrap_or(())\n                }\n                _ => (),\n            }\n        }\n\n        // Frame budget exceeded\n        if last_frame.elapsed() >= per_frame {\n            // Send tick\n            
event_sender.send(Event::Tick).await.unwrap_or(());\n            // Reset last_frame time\n            last_frame = Instant::now();\n        }\n    }\n}\n"
  },
  {
    "path": "benchmark/src/generation.rs",
    "content": "use std::time::{Duration, Instant};\nuse text_generation_client::v3::{\n    Batch, CachedBatch, NextTokenChooserParameters, Request, ShardedClient,\n    StoppingCriteriaParameters,\n};\nuse text_generation_client::{Chunk, ClientError, Input};\nuse tokenizers::{Tokenizer, TruncationDirection};\nuse tokio::sync::{broadcast, mpsc};\n\nconst LOREM_IPSUM: &str = \"Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.\";\n\n#[derive(Debug, Clone)]\npub(crate) struct Prefill {\n    pub(crate) latency: Duration,\n    pub(crate) throughput: f64,\n}\n\n#[derive(Debug, Clone)]\npub(crate) struct Decode {\n    pub(crate) latency: Duration,\n    pub(crate) token_latency: Duration,\n    pub(crate) throughput: f64,\n}\n\n#[derive(Debug)]\npub(crate) enum Message {\n    Warmup,\n    Prefill(Prefill),\n    Decode(Decode),\n    EndRun,\n    EndBatch,\n}\n\n/// Benchmarking task\n#[allow(clippy::too_many_arguments)]\npub(crate) async fn generation_task(\n    tokenizer: Tokenizer,\n    batch_size: Vec<u32>,\n    sequence_length: u32,\n    decode_length: u32,\n    top_n_tokens: Option<u32>,\n    n_runs: usize,\n    warmups: usize,\n    parameters: NextTokenChooserParameters,\n    client: ShardedClient,\n    run_sender: mpsc::Sender<Result<Message, ClientError>>,\n    mut shutdown_receiver: broadcast::Receiver<()>,\n    _shutdown_guard_sender: mpsc::Sender<()>,\n) {\n    // End task if a message is received on shutdown_receiver\n    // _shutdown_guard_sender will be dropped once the task is finished\n    tokio::select! 
{\n        res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, top_n_tokens, n_runs, warmups, parameters, client, run_sender.clone())  => {\n            if let Err(err) = res {\n                run_sender.send(Err(err)).await.unwrap_or(());\n            }\n        },\n        _ = shutdown_receiver.recv() => {}\n    }\n}\n\n/// Benchmark prefill/decode\n#[allow(clippy::too_many_arguments)]\nasync fn generate_runs(\n    tokenizer: Tokenizer,\n    batch_size: Vec<u32>,\n    sequence_length: u32,\n    decode_length: u32,\n    top_n_tokens: Option<u32>,\n    n_runs: usize,\n    warmups: usize,\n    parameters: NextTokenChooserParameters,\n    mut client: ShardedClient,\n    run_sender: mpsc::Sender<Result<Message, ClientError>>,\n) -> Result<(), ClientError> {\n    // Create a dummy sequence\n    let sequence = create_sequence(sequence_length, tokenizer);\n\n    for b in batch_size {\n        // Warmups on batch size\n        for _ in 0..warmups {\n            let (_, decode_batch) = prefill(\n                sequence.clone(),\n                sequence_length,\n                b,\n                decode_length,\n                parameters.clone(),\n                top_n_tokens,\n                &mut client,\n            )\n            .await?;\n            let _ = decode(decode_batch, &mut client).await?;\n            // Send warmup message\n            run_sender.send(Ok(Message::Warmup)).await.unwrap_or(());\n        }\n\n        for _ in 0..n_runs {\n            let (prefill, decode_batch) = prefill(\n                sequence.clone(),\n                sequence_length,\n                b,\n                decode_length,\n                parameters.clone(),\n                top_n_tokens,\n                &mut client,\n            )\n            .await?;\n            // Send prefill message\n            run_sender\n                .send(Ok(Message::Prefill(prefill)))\n                .await\n                .unwrap_or(());\n\n            let 
decode = decode(decode_batch, &mut client).await?;\n\n            // Send decode message\n            run_sender\n                .send(Ok(Message::Decode(decode)))\n                .await\n                .unwrap_or(());\n\n            // Send run ended message\n            run_sender.send(Ok(Message::EndRun)).await.unwrap_or(());\n        }\n        // Batch ended\n        run_sender.send(Ok(Message::EndBatch)).await.unwrap_or(());\n    }\n    Ok(())\n}\n\n// Run a prefill step\nasync fn prefill(\n    sequence: String,\n    sequence_length: u32,\n    batch_size: u32,\n    decode_length: u32,\n    parameters: NextTokenChooserParameters,\n    top_n_tokens: Option<u32>,\n    client: &mut ShardedClient,\n) -> Result<(Prefill, CachedBatch), ClientError> {\n    // Create requests\n    let requests = (0..batch_size)\n        .map(|id| Request {\n            id: id.into(),\n            prefill_logprobs: false,\n            input_chunks: Some(Input {\n                chunks: vec![Chunk::Text(sequence.clone()).into()],\n            }),\n            inputs: sequence.clone(),\n            truncate: sequence_length,\n            add_special_tokens: true,\n            parameters: Some(parameters.clone()),\n            stopping_parameters: Some(StoppingCriteriaParameters {\n                max_new_tokens: decode_length,\n                stop_sequences: vec![],\n                ignore_eos_token: true, // Will not stop even if a eos token is generated\n            }),\n            top_n_tokens: top_n_tokens.unwrap_or(0),\n            blocks: vec![],\n            slots: vec![],\n            cache_len: 0,\n            chunk_len: None,\n            adapter_id: None,\n        })\n        .collect();\n\n    let batch = Batch {\n        id: 0,\n        requests,\n        size: batch_size,\n        max_tokens: batch_size * (sequence_length + decode_length),\n        max_blocks: 0,\n    };\n\n    // Run prefill\n    let start_time = Instant::now();\n    let (_, decode_batch, _) = 
client.prefill(batch.clone(), None).await?;\n\n    // Get latency\n    let latency = start_time.elapsed();\n\n    // Compute throughput from latency and batch size\n    let throughput = (batch_size * sequence_length) as f64 / latency.as_secs_f64();\n\n    // Decode batch cannot be empty\n    let decode_batch = decode_batch.expect(\"decode_batch is None. This is a bug.\");\n\n    let step = Prefill {\n        latency,\n        throughput,\n    };\n\n    Ok((step, decode_batch))\n}\n\n/// Run a full decode\nasync fn decode(batch: CachedBatch, client: &mut ShardedClient) -> Result<Decode, ClientError> {\n    let mut decode_length = 0;\n    let batch_size = batch.size;\n\n    let start_time = Instant::now();\n\n    // Full decode over decode length\n    let mut next_batch = Some(batch);\n    while let Some(batch) = next_batch {\n        let result = client.decode(vec![batch]).await?;\n        next_batch = result.1;\n        decode_length += 1;\n    }\n\n    // Get latency\n    let latency = start_time.elapsed();\n    let token_latency = latency / decode_length;\n\n    // Compute throughput from latency, batch size and decode length\n    let throughput = (batch_size * decode_length) as f64 / latency.as_secs_f64();\n\n    let step = Decode {\n        latency,\n        token_latency,\n        throughput,\n    };\n    Ok(step)\n}\n\n/// Create a dummy sequence of the correct length\nfn create_sequence(sequence_length: u32, tokenizer: Tokenizer) -> String {\n    let lorem_ipsum_length = tokenizer.encode(LOREM_IPSUM, true).unwrap().len();\n    // Repeat lorem ipsum to cover sequence length\n    let string_sequence =\n        LOREM_IPSUM.repeat((0..sequence_length).step_by(lorem_ipsum_length).len());\n    // Encode sequence\n    let mut encoding = tokenizer.encode(string_sequence, true).unwrap();\n    // Truncate to sequence_length\n    encoding.truncate(sequence_length as usize, 0, TruncationDirection::Left);\n    // Decode\n    tokenizer.decode(encoding.get_ids(), 
false).unwrap()\n}\n"
  },
  {
    "path": "benchmark/src/lib.rs",
    "content": "mod app;\nmod event;\nmod generation;\nmod table;\nmod utils;\n\nuse crate::app::App;\nuse crate::event::Event;\nuse ratatui::backend::CrosstermBackend;\nuse ratatui::crossterm::ExecutableCommand;\nuse ratatui::Terminal;\nuse std::io;\nuse text_generation_client::v3::{GrammarType, NextTokenChooserParameters, ShardedClient};\nuse tokenizers::Tokenizer;\nuse tokio::sync::{broadcast, mpsc};\n\n/// Run benchmarking app\n#[allow(clippy::too_many_arguments)]\npub async fn run(\n    tokenizer_name: String,\n    tokenizer: Tokenizer,\n    batch_size: Vec<u32>,\n    sequence_length: u32,\n    decode_length: u32,\n    top_n_tokens: Option<u32>,\n    n_runs: usize,\n    warmups: usize,\n    temperature: Option<f32>,\n    top_k: Option<u32>,\n    top_p: Option<f32>,\n    typical_p: Option<f32>,\n    repetition_penalty: Option<f32>,\n    frequency_penalty: Option<f32>,\n    watermark: bool,\n    do_sample: bool,\n    client: ShardedClient,\n) -> Result<(), std::io::Error> {\n    let parameters = NextTokenChooserParameters {\n        temperature: temperature.unwrap_or(1.0),\n        top_k: top_k.unwrap_or(0),\n        top_p: top_p.unwrap_or(1.0),\n        typical_p: typical_p.unwrap_or(1.0),\n        do_sample,\n        seed: 0,\n        repetition_penalty: repetition_penalty.unwrap_or(1.0),\n        frequency_penalty: frequency_penalty.unwrap_or(0.0),\n        watermark,\n        grammar: String::new(),\n        grammar_type: GrammarType::None as i32,\n    };\n\n    // Initialize terminal properties\n    ratatui::crossterm::terminal::enable_raw_mode()?;\n    io::stdout().execute(ratatui::crossterm::terminal::EnterAlternateScreen)?;\n    io::stdout().execute(ratatui::crossterm::cursor::Hide)?;\n\n    // Initialize terminal\n    let mut terminal = {\n        let backend = CrosstermBackend::new(io::stdout());\n        Terminal::new(backend)?\n    };\n\n    // Create message channel between generation_task and app\n    let (run_sender, run_receiver) = 
mpsc::channel(8);\n    // Crossterm event channel\n    let (event_sender, mut event_receiver) = mpsc::channel(8);\n    // Shutdown channel to terminate tasks\n    let (shutdown_sender, _) = broadcast::channel(1);\n    // Channel to check if tasks terminated\n    let (shutdown_guard_sender, mut shutdown_guard_receiver) = mpsc::channel(1);\n\n    // Create generation task\n    tokio::spawn(generation::generation_task(\n        tokenizer,\n        batch_size.clone(),\n        sequence_length,\n        decode_length,\n        top_n_tokens,\n        n_runs,\n        warmups,\n        parameters,\n        client,\n        run_sender,\n        shutdown_sender.subscribe(),\n        shutdown_guard_sender.clone(),\n    ));\n\n    // Create event task\n    tokio::spawn(event::terminal_event_task(\n        250,\n        event_sender,\n        shutdown_sender.subscribe(),\n        shutdown_guard_sender.clone(),\n    ));\n\n    // Drop our end of shutdown sender\n    drop(shutdown_guard_sender);\n\n    // Create App\n    let mut app = App::new(\n        run_receiver,\n        tokenizer_name.clone(),\n        sequence_length,\n        decode_length,\n        n_runs,\n        batch_size,\n    );\n\n    while app.running {\n        // Draw frame\n        terminal.draw(|frame| app.render(frame))?;\n\n        // Await a new event from event handling task\n        match event_receiver.recv().await {\n            None => break,\n            // Update app state\n            Some(event) => match event {\n                Event::Tick => app.tick(),\n                Event::Key(key_event) => app.handle_key_event(key_event),\n                _ => {}\n            },\n        }\n    }\n\n    // Ask tasks to shutdown\n    let _ = shutdown_sender.send(());\n    // Wait for tasks to shutdown\n    let _ = shutdown_guard_receiver.recv().await;\n\n    // Revert terminal to original view\n    io::stdout().execute(ratatui::crossterm::terminal::LeaveAlternateScreen)?;\n    
ratatui::crossterm::terminal::disable_raw_mode()?;\n    io::stdout().execute(ratatui::crossterm::cursor::Show)?;\n\n    let parameters_table = table::parameters_table(\n        tokenizer_name,\n        sequence_length,\n        decode_length,\n        top_n_tokens,\n        n_runs,\n        warmups,\n        temperature,\n        top_k,\n        top_p,\n        typical_p,\n        repetition_penalty,\n        frequency_penalty,\n        watermark,\n        do_sample,\n    );\n    println!(\"\\n{parameters_table}\\n\");\n\n    let latency_table = table::latency_table(&app.data);\n    println!(\"\\n{latency_table}\\n\");\n\n    let throughput_table = table::throughput_table(&app.data);\n    println!(\"\\n{throughput_table}\\n\");\n\n    Ok(())\n}\n"
  },
  {
    "path": "benchmark/src/main.rs",
    "content": "/// Text Generation Inference benchmarking tool\n///\n/// Inspired by the great Oha app: https://github.com/hatoo/oha\n/// and: https://github.com/orhun/rust-tui-template\nuse clap::Parser;\nuse std::path::Path;\nuse text_generation_client::v3::ShardedClient;\nuse tokenizers::{FromPretrainedParameters, Tokenizer};\nuse tracing_subscriber::layer::SubscriberExt;\nuse tracing_subscriber::util::SubscriberInitExt;\nuse tracing_subscriber::EnvFilter;\n\n/// App Configuration\n#[derive(Parser, Debug)]\n#[clap(author, version, about, long_about = None)]\nstruct Args {\n    /// The name of the tokenizer (as in model_id on the huggingface hub, or local path).\n    #[clap(short, long, env)]\n    tokenizer_name: String,\n\n    /// The revision to use for the tokenizer if on the hub.\n    #[clap(default_value = \"main\", long, env)]\n    revision: String,\n\n    /// The various batch sizes to benchmark for, the idea is to get enough\n    /// batching to start seeing increased latency, this usually means you're\n    /// moving from memory bound (usual at BS=1) to compute bound, and this is\n    /// a sweet spot for the maximum batch size for the model under test\n    #[clap(short, long)]\n    batch_size: Option<Vec<u32>>,\n\n    /// This is the length, in tokens, of the initial prompt sent to the\n    /// text-generation-server. Longer prompts will slow down the benchmark. Usually the\n    /// latency grows somewhat linearly with this for the prefill step.\n    ///\n    /// Most importantly, the prefill step is usually not the one dominating\n    /// your runtime, so it's ok to keep it short.\n    #[clap(default_value = \"10\", short, long, env)]\n    sequence_length: u32,\n\n    /// This is how many tokens will be generated by the server and averaged out\n    /// to give the `decode` latency. 
This is the *critical* number you want to optimize for, since\n    /// LLMs spend most of their time decoding.\n    ///\n    /// Decode latency is usually quite stable.\n    #[clap(default_value = \"8\", short, long, env)]\n    decode_length: u32,\n\n    /// How many runs should we average over\n    #[clap(default_value = \"10\", short, long, env)]\n    runs: usize,\n\n    /// Number of warmup cycles\n    #[clap(default_value = \"1\", short, long, env)]\n    warmups: usize,\n\n    /// The location of the gRPC socket. This benchmark tool bypasses the router\n    /// completely and directly talks to the gRPC processes\n    #[clap(default_value = \"/tmp/text-generation-server-0\", short, long, env)]\n    master_shard_uds_path: String,\n\n    /// Generation parameter in case you want to specifically test/debug particular\n    /// decoding strategies, for full doc refer to the `text-generation-server`\n    #[clap(long, env)]\n    temperature: Option<f32>,\n\n    /// Generation parameter in case you want to specifically test/debug particular\n    /// decoding strategies, for full doc refer to the `text-generation-server`\n    #[clap(long, env)]\n    top_k: Option<u32>,\n\n    /// Generation parameter in case you want to specifically test/debug particular\n    /// decoding strategies, for full doc refer to the `text-generation-server`\n    #[clap(long, env)]\n    top_p: Option<f32>,\n\n    /// Generation parameter in case you want to specifically test/debug particular\n    /// decoding strategies, for full doc refer to the `text-generation-server`\n    #[clap(long, env)]\n    typical_p: Option<f32>,\n\n    /// Generation parameter in case you want to specifically test/debug particular\n    /// decoding strategies, for full doc refer to the `text-generation-server`\n    #[clap(long, env)]\n    repetition_penalty: Option<f32>,\n\n    /// Generation parameter in case you want to specifically test/debug particular\n    /// decoding strategies, for full doc refer to the 
`text-generation-server`\n    #[clap(long, env)]\n    frequency_penalty: Option<f32>,\n\n    /// Generation parameter in case you want to specifically test/debug particular\n    /// decoding strategies, for full doc refer to the `text-generation-server`\n    #[clap(long, env)]\n    watermark: bool,\n\n    /// Generation parameter in case you want to specifically test/debug particular\n    /// decoding strategies, for full doc refer to the `text-generation-server`\n    #[clap(long, env)]\n    do_sample: bool,\n\n    /// Generation parameter in case you want to specifically test/debug particular\n    /// decoding strategies, for full doc refer to the `text-generation-server`\n    #[clap(long, env)]\n    top_n_tokens: Option<u32>,\n}\n\nfn main() -> Result<(), Box<dyn std::error::Error>> {\n    init_logging();\n\n    // Get args\n    let args = Args::parse();\n    // Pattern match configuration\n    let Args {\n        tokenizer_name,\n        revision,\n        batch_size,\n        sequence_length,\n        decode_length,\n        runs,\n        warmups,\n        temperature,\n        top_k,\n        top_p,\n        typical_p,\n        repetition_penalty,\n        frequency_penalty,\n        watermark,\n        do_sample,\n        master_shard_uds_path,\n        top_n_tokens,\n    } = args;\n\n    let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]);\n\n    // Tokenizer instance\n    // This will only be used to validate payloads\n    tracing::info!(\"Loading tokenizer\");\n    let local_path = Path::new(&tokenizer_name);\n    let tokenizer =\n        if local_path.exists() && local_path.is_dir() && local_path.join(\"tokenizer.json\").exists()\n        {\n            // Load local tokenizer\n            tracing::info!(\"Found local tokenizer\");\n            Tokenizer::from_file(local_path.join(\"tokenizer.json\")).unwrap()\n        } else {\n            tracing::info!(\"Downloading tokenizer\");\n\n            // Parse Huggingface hub token\n            
let token = std::env::var(\"HF_TOKEN\")\n                .or_else(|_| std::env::var(\"HUGGING_FACE_HUB_TOKEN\"))\n                .ok();\n\n            // Download and instantiate tokenizer\n            // We need to download it outside of the Tokio runtime\n            let params = FromPretrainedParameters {\n                revision,\n                token,\n                ..Default::default()\n            };\n            Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).unwrap()\n        };\n    tracing::info!(\"Tokenizer loaded\");\n\n    // Launch Tokio runtime\n    tokio::runtime::Builder::new_multi_thread()\n        .enable_all()\n        .build()\n        .unwrap()\n        .block_on(async {\n            // Instantiate sharded client from the master unix socket\n            tracing::info!(\"Connect to model server\");\n            let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)\n                .await\n                .expect(\"Could not connect to server\");\n            // Clear the cache; useful if the webserver rebooted\n            sharded_client\n                .clear_cache(None)\n                .await\n                .expect(\"Unable to clear cache\");\n\n            tracing::info!(\"Connected\");\n\n            // Run app\n            text_generation_benchmark::run(\n                tokenizer_name,\n                tokenizer,\n                batch_size,\n                sequence_length,\n                decode_length,\n                top_n_tokens,\n                runs,\n                warmups,\n                temperature,\n                top_k,\n                top_p,\n                typical_p,\n                repetition_penalty,\n                frequency_penalty,\n                watermark,\n                do_sample,\n                sharded_client,\n            )\n            .await\n            .unwrap();\n        });\n    Ok(())\n}\n\n/// Init logging using LOG_LEVEL\nfn init_logging() {\n  
  // STDOUT/STDERR layer\n    let fmt_layer = tracing_subscriber::fmt::layer()\n        .with_file(true)\n        .with_line_number(true);\n\n    // Filter events with LOG_LEVEL\n    let env_filter =\n        EnvFilter::try_from_env(\"LOG_LEVEL\").unwrap_or_else(|_| EnvFilter::new(\"info\"));\n\n    tracing_subscriber::registry()\n        .with(env_filter)\n        .with(fmt_layer)\n        .init();\n}\n"
  },
  {
    "path": "benchmark/src/table.rs",
    "content": "use crate::app::Data;\nuse tabled::settings::Merge;\nuse tabled::{builder::Builder, settings::Style, Table};\n\n#[allow(clippy::too_many_arguments)]\npub(crate) fn parameters_table(\n    tokenizer_name: String,\n    sequence_length: u32,\n    decode_length: u32,\n    top_n_tokens: Option<u32>,\n    n_runs: usize,\n    warmups: usize,\n    temperature: Option<f32>,\n    top_k: Option<u32>,\n    top_p: Option<f32>,\n    typical_p: Option<f32>,\n    repetition_penalty: Option<f32>,\n    frequency_penalty: Option<f32>,\n    watermark: bool,\n    do_sample: bool,\n) -> Table {\n    let mut builder = Builder::default();\n\n    builder.set_header([\"Parameter\", \"Value\"]);\n\n    builder.push_record([\"Model\", &tokenizer_name]);\n    builder.push_record([\"Sequence Length\", &sequence_length.to_string()]);\n    builder.push_record([\"Decode Length\", &decode_length.to_string()]);\n    builder.push_record([\"Top N Tokens\", &format!(\"{top_n_tokens:?}\")]);\n    builder.push_record([\"N Runs\", &n_runs.to_string()]);\n    builder.push_record([\"Warmups\", &warmups.to_string()]);\n    builder.push_record([\"Temperature\", &format!(\"{temperature:?}\")]);\n    builder.push_record([\"Top K\", &format!(\"{top_k:?}\")]);\n    builder.push_record([\"Top P\", &format!(\"{top_p:?}\")]);\n    builder.push_record([\"Typical P\", &format!(\"{typical_p:?}\")]);\n    builder.push_record([\"Repetition Penalty\", &format!(\"{repetition_penalty:?}\")]);\n    builder.push_record([\"Frequency Penalty\", &format!(\"{frequency_penalty:?}\")]);\n    builder.push_record([\"Watermark\", &watermark.to_string()]);\n    builder.push_record([\"Do Sample\", &do_sample.to_string()]);\n\n    let mut table = builder.build();\n    table.with(Style::markdown());\n    table\n}\n\npub(crate) fn latency_table(data: &Data) -> Table {\n    let mut builder = Builder::default();\n\n    builder.set_header([\n        \"Step\",\n        \"Batch Size\",\n        \"Average\",\n        \"Lowest\",\n 
       \"Highest\",\n        \"p50\",\n        \"p90\",\n        \"p99\",\n    ]);\n\n    add_latencies(\n        &mut builder,\n        \"Prefill\",\n        &data.batch_size,\n        &data.prefill_latencies,\n    );\n    add_latencies(\n        &mut builder,\n        \"Decode (token)\",\n        &data.batch_size,\n        &data.decode_token_latencies,\n    );\n    add_latencies(\n        &mut builder,\n        \"Decode (total)\",\n        &data.batch_size,\n        &data.decode_latencies,\n    );\n\n    let mut table = builder.build();\n    table.with(Style::markdown()).with(Merge::vertical());\n    table\n}\n\npub(crate) fn throughput_table(data: &Data) -> Table {\n    let mut builder = Builder::default();\n\n    builder.set_header([\"Step\", \"Batch Size\", \"Average\", \"Lowest\", \"Highest\"]);\n\n    add_throuhgputs(\n        &mut builder,\n        \"Prefill\",\n        &data.batch_size,\n        &data.prefill_throughputs,\n    );\n    add_throuhgputs(\n        &mut builder,\n        \"Decode\",\n        &data.batch_size,\n        &data.decode_throughputs,\n    );\n\n    let mut table = builder.build();\n    table.with(Style::markdown()).with(Merge::vertical());\n    table\n}\n\nfn add_latencies(\n    builder: &mut Builder,\n    step: &'static str,\n    batch_size: &[u32],\n    batch_latencies: &[Vec<f64>],\n) {\n    for (i, b) in batch_size.iter().enumerate() {\n        let latencies = &batch_latencies[i];\n        let (avg, min, max) = avg_min_max(latencies);\n\n        let row = [\n            step,\n            &b.to_string(),\n            &format_value(avg, \"ms\"),\n            &format_value(min, \"ms\"),\n            &format_value(max, \"ms\"),\n            &format_value(px(latencies, 50), \"ms\"),\n            &format_value(px(latencies, 90), \"ms\"),\n            &format_value(px(latencies, 99), \"ms\"),\n        ];\n\n        builder.push_record(row);\n    }\n}\n\nfn add_throuhgputs(\n    builder: &mut Builder,\n    step: &'static str,\n    
batch_size: &[u32],\n    batch_throughputs: &[Vec<f64>],\n) {\n    for (i, b) in batch_size.iter().enumerate() {\n        let throughputs = &batch_throughputs[i];\n        let (avg, min, max) = avg_min_max(throughputs);\n\n        let row = [\n            step,\n            &b.to_string(),\n            &format_value(avg, \"tokens/secs\"),\n            &format_value(min, \"tokens/secs\"),\n            &format_value(max, \"tokens/secs\"),\n        ];\n\n        builder.push_record(row);\n    }\n}\n\nfn avg_min_max(data: &[f64]) -> (f64, f64, f64) {\n    let average = data.iter().sum::<f64>() / data.len() as f64;\n    let min = data\n        .iter()\n        .min_by(|a, b| a.total_cmp(b))\n        .unwrap_or(&f64::NAN);\n    let max = data\n        .iter()\n        .max_by(|a, b| a.total_cmp(b))\n        .unwrap_or(&f64::NAN);\n    (average, *min, *max)\n}\n\nfn px(data: &[f64], p: u32) -> f64 {\n    let i = (f64::from(p) / 100.0 * data.len() as f64) as usize;\n    *data.get(i).unwrap_or(&f64::NAN)\n}\n\nfn format_value(value: f64, unit: &'static str) -> String {\n    format!(\"{:.2} {unit}\", value)\n}\n"
  },
  {
    "path": "benchmark/src/utils.rs",
    "content": "// MIT License\n//\n// Copyright (c) 2020 hatoo\n//\n// Permission is hereby granted, free of charge, to any person obtaining a copy\n// of this software and associated documentation files (the \"Software\"), to deal\n// in the Software without restriction, including without limitation the rights\n// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n// copies of the Software, and to permit persons to whom the Software is\n// furnished to do so, subject to the following conditions:\n//\n// The above copyright notice and this permission notice shall be included in all\n// copies or substantial portions of the Software.\nuse std::collections::BTreeMap;\n\npub(crate) fn histogram(values: &[f64], bins: usize) -> Vec<(f64, usize)> {\n    assert!(bins >= 2);\n    let mut bucket: Vec<usize> = vec![0; bins];\n    let min = values.iter().collect::<average::Min>().min();\n    let max = values.iter().collect::<average::Max>().max();\n    let step = (max - min) / (bins - 1) as f64;\n\n    for &v in values {\n        let i = std::cmp::min(((v - min) / step).ceil() as usize, bins - 1);\n        bucket[i] += 1;\n    }\n\n    bucket\n        .into_iter()\n        .enumerate()\n        .map(|(i, v)| (min + step * i as f64, v))\n        .collect()\n}\n\npub(crate) fn percentiles(values: &[f64], pecents: &[i32]) -> BTreeMap<String, f64> {\n    pecents\n        .iter()\n        .map(|&p| {\n            let i = (f64::from(p) / 100.0 * values.len() as f64) as usize;\n            (format!(\"p{p}\"), *values.get(i).unwrap_or(&f64::NAN))\n        })\n        .collect()\n}\n"
  },
  {
    "path": "clients/python/.gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\ntext_generation/__pycache__/\ntext_generation/pb/__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#   
https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#pdm.lock\n#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it\n#   in version control.\n#   https://pdm.fming.dev/#use-with-ide\n.pdm.toml\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\ntransformers\nsafetensors\n"
  },
  {
    "path": "clients/python/Makefile",
    "content": "unit-tests:\n\tpython -m pytest --cov=text_generation tests\n\ninstall:\n\tpip install pip --upgrade\n\tpip install -e .\n"
  },
  {
    "path": "clients/python/README.md",
    "content": "# Legacy warning ⚠️\nThe inference clients from [huggingface_hub](https://huggingface.co/docs/huggingface_hub/guides/inference) are recommended over `text_generation`.\n\n# Text Generation\n\nThe Hugging Face Text Generation Python library provides a convenient way of interfacing with a\n`text-generation-inference` instance running on\n[Hugging Face Inference Endpoints](https://huggingface.co/inference-endpoints) or on the Hugging Face Hub.\n\n## Get Started\n\n### Install\n\n```shell\npip install text-generation\n```\n\n### Inference API Usage\n\n```python\nfrom text_generation import InferenceAPIClient\n\nclient = InferenceAPIClient(\"bigscience/bloomz\")\ntext = client.generate(\"Why is the sky blue?\").generated_text\nprint(text)\n# ' Rayleigh scattering'\n\n# Token Streaming\ntext = \"\"\nfor response in client.generate_stream(\"Why is the sky blue?\"):\n    if not response.token.special:\n        text += response.token.text\n\nprint(text)\n# ' Rayleigh scattering'\n```\n\nor with the asynchronous client:\n\n```python\nfrom text_generation import InferenceAPIAsyncClient\n\nclient = InferenceAPIAsyncClient(\"bigscience/bloomz\")\nresponse = await client.generate(\"Why is the sky blue?\")\nprint(response.generated_text)\n# ' Rayleigh scattering'\n\n# Token Streaming\ntext = \"\"\nasync for response in client.generate_stream(\"Why is the sky blue?\"):\n    if not response.token.special:\n        text += response.token.text\n\nprint(text)\n# ' Rayleigh scattering'\n```\n\nCheck all currently deployed models on the Huggingface Inference API with `Text Generation` support:\n\n```python\nfrom text_generation.inference_api import deployed_models\n\nprint(deployed_models())\n```\n\n### Hugging Face Inference Endpoint usage\n\n```python\nfrom text_generation import Client\n\nendpoint_url = \"https://YOUR_ENDPOINT.endpoints.huggingface.cloud\"\n\nclient = Client(endpoint_url)\ntext = client.generate(\"Why is the sky blue?\").generated_text\nprint(text)\n# 
' Rayleigh scattering'\n\n# Token Streaming\ntext = \"\"\nfor response in client.generate_stream(\"Why is the sky blue?\"):\n    if not response.token.special:\n        text += response.token.text\n\nprint(text)\n# ' Rayleigh scattering'\n```\n\nor with the asynchronous client:\n\n```python\nfrom text_generation import AsyncClient\n\nendpoint_url = \"https://YOUR_ENDPOINT.endpoints.huggingface.cloud\"\n\nclient = AsyncClient(endpoint_url)\nresponse = await client.generate(\"Why is the sky blue?\")\nprint(response.generated_text)\n# ' Rayleigh scattering'\n\n# Token Streaming\ntext = \"\"\nasync for response in client.generate_stream(\"Why is the sky blue?\"):\n    if not response.token.special:\n        text += response.token.text\n\nprint(text)\n# ' Rayleigh scattering'\n```\n\n### Types\n\n```python\n# enum for grammar type\nclass GrammarType(Enum):\n    Json = \"json\"\n    Regex = \"regex\"\n\n\n# Grammar type and value\nclass Grammar:\n    # Grammar type\n    type: GrammarType\n    # Grammar value\n    value: Union[str, dict]\n\nclass Parameters:\n    # Activate logits sampling\n    do_sample: bool\n    # Maximum number of generated tokens\n    max_new_tokens: int\n    # The parameter for repetition penalty. 1.0 means no penalty.\n    # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.\n    repetition_penalty: Optional[float]\n    # The parameter for frequency penalty. 
0.0 means no penalty\n    # Penalize new tokens based on their existing frequency in the text so far,\n    # decreasing the model's likelihood to repeat the same line verbatim.\n    frequency_penalty: Optional[float]\n    # Whether to prepend the prompt to the generated text\n    return_full_text: bool\n    # Stop generating tokens if a member of `stop_sequences` is generated\n    stop: List[str]\n    # Random sampling seed\n    seed: Optional[int]\n    # The value used to modulate the logits distribution.\n    temperature: Optional[float]\n    # The number of highest probability vocabulary tokens to keep for top-k-filtering.\n    top_k: Optional[int]\n    # If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or\n    # higher are kept for generation.\n    top_p: Optional[float]\n    # Truncate input tokens to the given size\n    truncate: Optional[int]\n    # Typical Decoding mass\n    # See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information\n    typical_p: Optional[float]\n    # Generate best_of sequences and return the one with the highest token logprobs\n    best_of: Optional[int]\n    # Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)\n    watermark: bool\n    # Get generation details\n    details: bool\n    # Get decoder input token logprobs and ids\n    decoder_input_details: bool\n    # Return the N most likely tokens at each step\n    top_n_tokens: Optional[int]\n    # Grammar to use for generation\n    grammar: Optional[Grammar]\n\nclass Request:\n    # Prompt\n    inputs: str\n    # Generation parameters\n    parameters: Optional[Parameters]\n    # Whether to stream output tokens\n    stream: bool\n\n# Decoder input tokens\nclass InputToken:\n    # Token ID from the model tokenizer\n    id: int\n    # Token text\n    text: str\n    # Logprob\n    # Optional since the logprob of the first token cannot be 
computed\n    logprob: Optional[float]\n\n\n# Generated tokens\nclass Token:\n    # Token ID from the model tokenizer\n    id: int\n    # Token text\n    text: str\n    # Logprob\n    logprob: Optional[float]\n    # Is the token a special token\n    # Can be used to ignore tokens when concatenating\n    special: bool\n\n\n# Generation finish reason\nclass FinishReason(Enum):\n    # number of generated tokens == `max_new_tokens`\n    Length = \"length\"\n    # the model generated its end of sequence token\n    EndOfSequenceToken = \"eos_token\"\n    # the model generated a text included in `stop_sequences`\n    StopSequence = \"stop_sequence\"\n\n\n# Additional sequences when using the `best_of` parameter\nclass BestOfSequence:\n    # Generated text\n    generated_text: str\n    # Generation finish reason\n    finish_reason: FinishReason\n    # Number of generated tokens\n    generated_tokens: int\n    # Sampling seed if sampling was activated\n    seed: Optional[int]\n    # Decoder input tokens, empty if decoder_input_details is False\n    prefill: List[InputToken]\n    # Generated tokens\n    tokens: List[Token]\n    # Most likely tokens\n    top_tokens: Optional[List[List[Token]]]\n\n\n# `generate` details\nclass Details:\n    # Generation finish reason\n    finish_reason: FinishReason\n    # Number of generated tokens\n    generated_tokens: int\n    # Sampling seed if sampling was activated\n    seed: Optional[int]\n    # Decoder input tokens, empty if decoder_input_details is False\n    prefill: List[InputToken]\n    # Generated tokens\n    tokens: List[Token]\n    # Most likely tokens\n    top_tokens: Optional[List[List[Token]]]\n    # Additional sequences when using the `best_of` parameter\n    best_of_sequences: Optional[List[BestOfSequence]]\n\n\n# `generate` return value\nclass Response:\n    # Generated text\n    generated_text: str\n    # Generation details\n    details: Details\n\n\n# `generate_stream` details\nclass StreamDetails:\n    # Generation 
finish reason\n    finish_reason: FinishReason\n    # Number of generated tokens\n    generated_tokens: int\n    # Sampling seed if sampling was activated\n    seed: Optional[int]\n\n\n# `generate_stream` return value\nclass StreamResponse:\n    # Generated token\n    token: Token\n    # Most likely tokens\n    top_tokens: Optional[List[Token]]\n    # Complete generated text\n    # Only available when the generation is finished\n    generated_text: Optional[str]\n    # Generation details\n    # Only available when the generation is finished\n    details: Optional[StreamDetails]\n\n# Inference API currently deployed model\nclass DeployedModel:\n    model_id: str\n    sha: str\n```\n"
  },
  {
    "path": "clients/python/pyproject.toml",
    "content": "[tool.poetry]\nname = \"text-generation\"\nversion = \"0.7.0\"\ndescription = \"Hugging Face Text Generation Python Client\"\nlicense = \"Apache-2.0\"\nauthors = [\"Olivier Dehaene <olivier@huggingface.co>\"]\nmaintainers = [\"Olivier Dehaene <olivier@huggingface.co>\"]\nreadme = \"README.md\"\nhomepage = \"https://github.com/huggingface/text-generation-inference\"\nrepository = \"https://github.com/huggingface/text-generation-inference\"\n\n\n[tool.poetry.dependencies]\npython = \"^3.9\"\npydantic = \"> 2, < 3\"\naiohttp = \"^3.11\"\nhuggingface-hub = \">= 0.12, < 1.0\"\n\n[tool.poetry.group.dev.dependencies]\npytest = \"^8\"\npytest-asyncio = \"^0.26\"\npytest-cov = \"^6.0.0\"\n\n[tool.pytest.ini_options]\nasyncio_mode = \"auto\"\n\n[build-system]\nrequires = [\"poetry-core>=1.0.0\"]\nbuild-backend = \"poetry.core.masonry.api\"\n\n[tool.isort]\nprofile = \"black\"\n"
  },
  {
    "path": "clients/python/tests/conftest.py",
    "content": "import pytest\n\nfrom text_generation import __version__\nfrom huggingface_hub.utils import build_hf_headers\n\n\n@pytest.fixture\ndef flan_t5_xxl():\n    return \"google/flan-t5-xxl\"\n\n\n@pytest.fixture\ndef llama_7b():\n    return \"meta-llama/Llama-2-7b-chat-hf\"\n\n\n@pytest.fixture\ndef fake_model():\n    return \"fake/model\"\n\n\n@pytest.fixture\ndef unsupported_model():\n    return \"google-bert/bert-base-uncased\"\n\n\n@pytest.fixture\ndef base_url():\n    return \"https://api-inference.huggingface.co/models\"\n\n\n@pytest.fixture\ndef bloom_url(base_url, bloom_model):\n    return f\"{base_url}/{bloom_model}\"\n\n\n@pytest.fixture\ndef flan_t5_xxl_url(base_url, flan_t5_xxl):\n    return f\"{base_url}/{flan_t5_xxl}\"\n\n\n@pytest.fixture\ndef llama_7b_url(base_url, llama_7b):\n    return f\"{base_url}/{llama_7b}\"\n\n\n@pytest.fixture\ndef fake_url(base_url, fake_model):\n    return f\"{base_url}/{fake_model}\"\n\n\n@pytest.fixture\ndef unsupported_url(base_url, unsupported_model):\n    return f\"{base_url}/{unsupported_model}\"\n\n\n@pytest.fixture(scope=\"session\")\ndef hf_headers():\n    return build_hf_headers(\n        library_name=\"text-generation-tests\", library_version=__version__\n    )\n"
  },
  {
    "path": "clients/python/tests/test_client.py",
    "content": "import pytest\n\nfrom text_generation import Client, AsyncClient\nfrom text_generation.errors import NotFoundError, ValidationError\nfrom text_generation.types import FinishReason\n\n\ndef test_generate(llama_7b_url, hf_headers):\n    client = Client(llama_7b_url, hf_headers)\n    response = client.generate(\"test\", max_new_tokens=1, decoder_input_details=True)\n\n    assert response.generated_text == \"_\"\n    assert response.details.finish_reason == FinishReason.Length\n    assert response.details.generated_tokens == 1\n    assert response.details.seed is None\n    assert len(response.details.prefill) == 0\n    # assert response.details.prefill[0] == InputToken(id=1, text=\"<s>\", logprob=None)\n    assert len(response.details.tokens) == 1\n    assert response.details.tokens[0].id == 29918\n    assert response.details.tokens[0].text == \"_\"\n    assert not response.details.tokens[0].special\n\n\ndef test_generate_best_of(llama_7b_url, hf_headers):\n    client = Client(llama_7b_url, hf_headers)\n    response = client.generate(\n        \"test\", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True\n    )\n\n    assert response.details.seed is not None\n    assert response.details.best_of_sequences is not None\n    assert len(response.details.best_of_sequences) == 1\n    assert response.details.best_of_sequences[0].seed is not None\n\n\ndef test_generate_not_found(fake_url, hf_headers):\n    client = Client(fake_url, hf_headers)\n    with pytest.raises(NotFoundError):\n        client.generate(\"test\")\n\n\ndef test_generate_validation_error(llama_7b_url, hf_headers):\n    client = Client(llama_7b_url, hf_headers)\n    with pytest.raises(ValidationError):\n        client.generate(\"test\", max_new_tokens=10_000)\n\n\ndef test_generate_stream(llama_7b_url, hf_headers):\n    client = Client(llama_7b_url, hf_headers)\n    responses = [\n        response for response in client.generate_stream(\"test\", max_new_tokens=1)\n    ]\n\n  
  assert len(responses) == 1\n    response = responses[0]\n\n    assert response.generated_text == \"_\"\n    assert response.details.finish_reason == FinishReason.Length\n    assert response.details.generated_tokens == 1\n    assert response.details.seed is None\n\n\ndef test_generate_stream_not_found(fake_url, hf_headers):\n    client = Client(fake_url, hf_headers)\n    with pytest.raises(NotFoundError):\n        list(client.generate_stream(\"test\"))\n\n\ndef test_generate_stream_validation_error(llama_7b_url, hf_headers):\n    client = Client(llama_7b_url, hf_headers)\n    with pytest.raises(ValidationError):\n        list(client.generate_stream(\"test\", max_new_tokens=10_000))\n\n\n@pytest.mark.asyncio\nasync def test_generate_async(llama_7b_url, hf_headers):\n    client = AsyncClient(llama_7b_url, hf_headers)\n    response = await client.generate(\n        \"test\", max_new_tokens=1, decoder_input_details=True\n    )\n\n    assert response.generated_text == \"_\"\n    assert response.details.finish_reason == FinishReason.Length\n    assert response.details.generated_tokens == 1\n    assert response.details.seed is None\n    assert len(response.details.prefill) == 0\n    # assert response.details.prefill[0] == InputToken(id=1, text=\"<s>\", logprob=None)\n    # assert response.details.prefill[1] == InputToken(\n    #     id=1243, text=\"test\", logprob=-10.96875\n    # )\n    assert len(response.details.tokens) == 1\n    assert response.details.tokens[0].id == 29918\n    assert response.details.tokens[0].text == \"_\"\n    assert not response.details.tokens[0].special\n\n\n@pytest.mark.asyncio\nasync def test_generate_async_best_of(llama_7b_url, hf_headers):\n    client = AsyncClient(llama_7b_url, hf_headers)\n    response = await client.generate(\n        \"test\", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True\n    )\n\n    assert response.details.seed is not None\n    assert response.details.best_of_sequences is not None\n    
assert len(response.details.best_of_sequences) == 1\n    assert response.details.best_of_sequences[0].seed is not None\n\n\n@pytest.mark.asyncio\nasync def test_generate_async_not_found(fake_url, hf_headers):\n    client = AsyncClient(fake_url, hf_headers)\n    with pytest.raises(NotFoundError):\n        await client.generate(\"test\")\n\n\n@pytest.mark.asyncio\nasync def test_generate_async_validation_error(llama_7b_url, hf_headers):\n    client = AsyncClient(llama_7b_url, hf_headers)\n    with pytest.raises(ValidationError):\n        await client.generate(\"test\", max_new_tokens=10_000)\n\n\n@pytest.mark.asyncio\nasync def test_generate_stream_async(llama_7b_url, hf_headers):\n    client = AsyncClient(llama_7b_url, hf_headers)\n    responses = [\n        response async for response in client.generate_stream(\"test\", max_new_tokens=1)\n    ]\n\n    assert len(responses) == 1\n    response = responses[0]\n\n    assert response.generated_text == \"_\"\n    assert response.details.finish_reason == FinishReason.Length\n    assert response.details.generated_tokens == 1\n    assert response.details.seed is None\n\n\n@pytest.mark.asyncio\nasync def test_generate_stream_async_not_found(fake_url, hf_headers):\n    client = AsyncClient(fake_url, hf_headers)\n    with pytest.raises(NotFoundError):\n        async for _ in client.generate_stream(\"test\"):\n            pass\n\n\n@pytest.mark.asyncio\nasync def test_generate_stream_async_validation_error(llama_7b_url, hf_headers):\n    client = AsyncClient(llama_7b_url, hf_headers)\n    with pytest.raises(ValidationError):\n        async for _ in client.generate_stream(\"test\", max_new_tokens=10_000):\n            pass\n"
  },
  {
    "path": "clients/python/tests/test_errors.py",
    "content": "from text_generation.errors import (\n    parse_error,\n    GenerationError,\n    IncompleteGenerationError,\n    OverloadedError,\n    ValidationError,\n    BadRequestError,\n    ShardNotReadyError,\n    ShardTimeoutError,\n    NotFoundError,\n    RateLimitExceededError,\n    UnknownError,\n)\n\n\ndef test_generation_error():\n    payload = {\"error_type\": \"generation\", \"error\": \"test\"}\n    assert isinstance(parse_error(400, payload), GenerationError)\n\n\ndef test_incomplete_generation_error():\n    payload = {\"error_type\": \"incomplete_generation\", \"error\": \"test\"}\n    assert isinstance(parse_error(400, payload), IncompleteGenerationError)\n\n\ndef test_overloaded_error():\n    payload = {\"error_type\": \"overloaded\", \"error\": \"test\"}\n    assert isinstance(parse_error(400, payload), OverloadedError)\n\n\ndef test_validation_error():\n    payload = {\"error_type\": \"validation\", \"error\": \"test\"}\n    assert isinstance(parse_error(400, payload), ValidationError)\n\n\ndef test_bad_request_error():\n    payload = {\"error\": \"test\"}\n    assert isinstance(parse_error(400, payload), BadRequestError)\n\n\ndef test_shard_not_ready_error():\n    payload = {\"error\": \"test\"}\n    assert isinstance(parse_error(403, payload), ShardNotReadyError)\n    assert isinstance(parse_error(424, payload), ShardNotReadyError)\n\n\ndef test_shard_timeout_error():\n    payload = {\"error\": \"test\"}\n    assert isinstance(parse_error(504, payload), ShardTimeoutError)\n\n\ndef test_not_found_error():\n    payload = {\"error\": \"test\"}\n    assert isinstance(parse_error(404, payload), NotFoundError)\n\n\ndef test_rate_limit_exceeded_error():\n    payload = {\"error\": \"test\"}\n    assert isinstance(parse_error(429, payload), RateLimitExceededError)\n\n\ndef test_unknown_error():\n    payload = {\"error\": \"test\"}\n    assert isinstance(parse_error(500, payload), UnknownError)\n"
  },
  {
    "path": "clients/python/tests/test_inference_api.py",
    "content": "# import pytest\n#\n# from text_generation import (\n#     InferenceAPIClient,\n#     InferenceAPIAsyncClient,\n#     Client,\n#     AsyncClient,\n# )\n# from text_generation.errors import NotSupportedError, NotFoundError\n# from text_generation.inference_api import check_model_support, deployed_models\n#\n#\n# def test_check_model_support(flan_t5_xxl, unsupported_model, fake_model):\n#     assert check_model_support(flan_t5_xxl)\n#     assert not check_model_support(unsupported_model)\n#\n#     with pytest.raises(NotFoundError):\n#         check_model_support(fake_model)\n#\n#\n# def test_deployed_models():\n#     deployed_models()\n#\n#\n# def test_client(flan_t5_xxl):\n#     client = InferenceAPIClient(flan_t5_xxl)\n#     assert isinstance(client, Client)\n#\n#\n# def test_client_unsupported_model(unsupported_model):\n#     with pytest.raises(NotSupportedError):\n#         InferenceAPIClient(unsupported_model)\n#\n#\n# def test_async_client(flan_t5_xxl):\n#     client = InferenceAPIAsyncClient(flan_t5_xxl)\n#     assert isinstance(client, AsyncClient)\n#\n#\n# def test_async_client_unsupported_model(unsupported_model):\n#     with pytest.raises(NotSupportedError):\n#         InferenceAPIAsyncClient(unsupported_model)\n"
  },
  {
    "path": "clients/python/tests/test_types.py",
    "content": "import pytest\n\nfrom text_generation.types import Parameters, Request\nfrom text_generation.errors import ValidationError\n\n\ndef test_parameters_validation():\n    # Test best_of\n    Parameters(best_of=1)\n    with pytest.raises(ValidationError):\n        Parameters(best_of=0)\n    with pytest.raises(ValidationError):\n        Parameters(best_of=-1)\n    Parameters(best_of=2, do_sample=True)\n    with pytest.raises(ValidationError):\n        Parameters(best_of=2)\n    with pytest.raises(ValidationError):\n        Parameters(best_of=2, seed=1)\n\n    # Test repetition_penalty\n    Parameters(repetition_penalty=1)\n    with pytest.raises(ValidationError):\n        Parameters(repetition_penalty=0)\n    with pytest.raises(ValidationError):\n        Parameters(repetition_penalty=-1)\n\n    # Test seed\n    Parameters(seed=1)\n    with pytest.raises(ValidationError):\n        Parameters(seed=-1)\n\n    # Test temperature\n    Parameters(temperature=1)\n    with pytest.raises(ValidationError):\n        Parameters(temperature=0)\n    with pytest.raises(ValidationError):\n        Parameters(temperature=-1)\n\n    # Test top_k\n    Parameters(top_k=1)\n    with pytest.raises(ValidationError):\n        Parameters(top_k=0)\n    with pytest.raises(ValidationError):\n        Parameters(top_k=-1)\n\n    # Test top_p\n    Parameters(top_p=0.5)\n    with pytest.raises(ValidationError):\n        Parameters(top_p=0)\n    with pytest.raises(ValidationError):\n        Parameters(top_p=-1)\n    with pytest.raises(ValidationError):\n        Parameters(top_p=1)\n\n    # Test truncate\n    Parameters(truncate=1)\n    with pytest.raises(ValidationError):\n        Parameters(truncate=0)\n    with pytest.raises(ValidationError):\n        Parameters(truncate=-1)\n\n    # Test typical_p\n    Parameters(typical_p=0.5)\n    with pytest.raises(ValidationError):\n        Parameters(typical_p=0)\n    with pytest.raises(ValidationError):\n        Parameters(typical_p=-1)\n    with 
pytest.raises(ValidationError):\n        Parameters(typical_p=1)\n\n\ndef test_request_validation():\n    Request(inputs=\"test\")\n\n    with pytest.raises(ValidationError):\n        Request(inputs=\"\")\n\n    Request(inputs=\"test\", stream=True)\n    Request(inputs=\"test\", parameters=Parameters(best_of=2, do_sample=True))\n\n    with pytest.raises(ValidationError):\n        Request(\n            inputs=\"test\", parameters=Parameters(best_of=2, do_sample=True), stream=True\n        )\n"
  },
  {
    "path": "clients/python/text_generation/__init__.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n__version__ = \"0.7.0\"\n\nDEPRECATION_WARNING = (\n    \"`text_generation` clients are deprecated and will be removed in the near future. \"\n    \"Please use the `InferenceClient` from the `huggingface_hub` package instead.\"\n)\n\nfrom text_generation.client import Client, AsyncClient  # noqa E402\nfrom text_generation.inference_api import (  # noqa E402\n    InferenceAPIClient,\n    InferenceAPIAsyncClient,\n)\n\n__all__ = [\n    \"Client\",\n    \"AsyncClient\",\n    \"InferenceAPIClient\",\n    \"InferenceAPIAsyncClient\",\n]\n"
  },
  {
    "path": "clients/python/text_generation/client.py",
    "content": "import json\nimport requests\nimport warnings\n\nfrom aiohttp import ClientSession, ClientTimeout\nfrom pydantic import ValidationError\nfrom typing import Dict, Optional, List, AsyncIterator, Iterator, Union\n\nfrom text_generation import DEPRECATION_WARNING\nfrom text_generation.types import (\n    StreamResponse,\n    Response,\n    Request,\n    Parameters,\n    Grammar,\n    CompletionRequest,\n    Completion,\n    CompletionComplete,\n    ChatRequest,\n    ChatCompletionChunk,\n    ChatComplete,\n    Message,\n    Tool,\n)\nfrom text_generation.errors import parse_error\n\n# emit deprecation warnings\nwarnings.simplefilter(\"always\", DeprecationWarning)\n\n\nclass Client:\n    \"\"\"Client to make calls to a text-generation-inference instance\n\n     Example:\n\n     ```python\n     >>> from text_generation import Client\n\n     >>> client = Client(\"https://api-inference.huggingface.co/models/bigscience/bloomz\")\n     >>> client.generate(\"Why is the sky blue?\").generated_text\n     ' Rayleigh scattering'\n\n     >>> result = \"\"\n     >>> for response in client.generate_stream(\"Why is the sky blue?\"):\n     >>>     if not response.token.special:\n     >>>         result += response.token.text\n     >>> result\n    ' Rayleigh scattering'\n     ```\n    \"\"\"\n\n    def __init__(\n        self,\n        base_url: str,\n        headers: Optional[Dict[str, str]] = None,\n        cookies: Optional[Dict[str, str]] = None,\n        timeout: int = 10,\n    ):\n        \"\"\"\n        Args:\n            base_url (`str`):\n                text-generation-inference instance base url\n            headers (`Optional[Dict[str, str]]`):\n                Additional headers\n            cookies (`Optional[Dict[str, str]]`):\n                Cookies to include in the requests\n            timeout (`int`):\n                Timeout in seconds\n        \"\"\"\n        warnings.warn(DEPRECATION_WARNING, DeprecationWarning)\n        self.base_url = 
base_url\n        self.headers = headers\n        self.cookies = cookies\n        self.timeout = timeout\n\n    def completion(\n        self,\n        prompt: str,\n        frequency_penalty: Optional[float] = None,\n        max_tokens: Optional[int] = None,\n        repetition_penalty: Optional[float] = None,\n        seed: Optional[int] = None,\n        stream: bool = False,\n        temperature: Optional[float] = None,\n        top_p: Optional[float] = None,\n        stop: Optional[List[str]] = None,\n    ):\n        \"\"\"\n        Given a prompt, generate a response synchronously\n\n        Args:\n            prompt (`str`):\n                Prompt\n            frequency_penalty (`float`):\n                The parameter for frequency penalty. 0.0 means no penalty\n                Penalize new tokens based on their existing frequency in the text so far,\n                decreasing the model's likelihood to repeat the same line verbatim.\n            max_tokens (`int`):\n                Maximum number of generated tokens\n            repetition_penalty (`float`):\n                The parameter for repetition penalty. 0.0 means no penalty. 
See [this\n                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.\n            seed (`int`):\n                Random sampling seed\n            stream (`bool`):\n                Stream the response\n            temperature (`float`):\n                The value used to module the logits distribution.\n            top_p (`float`):\n                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or\n                higher are kept for generation\n            stop (`List[str]`):\n                Stop generating tokens if a member of `stop` is generated\n        \"\"\"\n        request = CompletionRequest(\n            model=\"tgi\",\n            prompt=prompt,\n            frequency_penalty=frequency_penalty,\n            max_tokens=max_tokens,\n            repetition_penalty=repetition_penalty,\n            seed=seed,\n            stream=stream,\n            temperature=temperature,\n            top_p=top_p,\n            stop=stop,\n        )\n        if not stream:\n            resp = requests.post(\n                f\"{self.base_url}/v1/completions\",\n                json=request.dict(),\n                headers=self.headers,\n                cookies=self.cookies,\n                timeout=self.timeout,\n            )\n            payload = resp.json()\n            if resp.status_code != 200:\n                raise parse_error(resp.status_code, payload)\n            return Completion(**payload)\n        else:\n            return self._completion_stream_response(request)\n\n    def _completion_stream_response(self, request):\n        resp = requests.post(\n            f\"{self.base_url}/v1/completions\",\n            json=request.dict(),\n            headers=self.headers,\n            cookies=self.cookies,\n            timeout=self.timeout,\n            stream=True,\n        )\n        # iterate and print stream\n        for byte_payload in resp.iter_lines():\n            if byte_payload 
== b\"\\n\":\n                continue\n            payload = byte_payload.decode(\"utf-8\")\n            if payload.startswith(\"data:\"):\n                json_payload = json.loads(payload.lstrip(\"data:\").rstrip(\"\\n\"))\n                try:\n                    response = CompletionComplete(**json_payload)\n                    yield response\n                except ValidationError:\n                    raise parse_error(resp.status, json_payload)\n\n    def chat(\n        self,\n        messages: List[Message],\n        repetition_penalty: Optional[float] = None,\n        frequency_penalty: Optional[float] = None,\n        logit_bias: Optional[List[float]] = None,\n        logprobs: Optional[bool] = None,\n        top_logprobs: Optional[int] = None,\n        max_tokens: Optional[int] = None,\n        n: Optional[int] = None,\n        presence_penalty: Optional[float] = None,\n        stream: bool = False,\n        seed: Optional[int] = None,\n        temperature: Optional[float] = None,\n        top_p: Optional[float] = None,\n        tools: Optional[List[Tool]] = None,\n        tool_prompt: Optional[str] = None,\n        tool_choice: Optional[str] = None,\n        stop: Optional[List[str]] = None,\n    ):\n        \"\"\"\n        Given a list of messages, generate a response synchronously\n\n        Args:\n            messages (`List[Message]`):\n                List of messages\n            repetition_penalty (`float`):\n                The parameter for repetition penalty. 0.0 means no penalty. See [this\n                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.\n            frequency_penalty (`float`):\n                The parameter for frequency penalty. 
0.0 means no penalty\n                Penalize new tokens based on their existing frequency in the text so far,\n                decreasing the model's likelihood to repeat the same line verbatim.\n            logit_bias (`List[float]`):\n                Adjust the likelihood of specified tokens\n            logprobs (`bool`):\n                Include log probabilities in the response\n            top_logprobs (`int`):\n                Include the `n` most likely tokens at each step\n            max_tokens (`int`):\n                Maximum number of generated tokens\n            n (`int`):\n                Generate `n` completions\n            presence_penalty (`float`):\n                The parameter for presence penalty. 0.0 means no penalty. See [this\n                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.\n            stream (`bool`):\n                Stream the response\n            seed (`int`):\n                Random sampling seed\n            temperature (`float`):\n                The value used to module the logits distribution.\n            top_p (`float`):\n                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or\n                higher are kept for generation\n            tools (`List[Tool]`):\n                List of tools to use\n            tool_prompt (`str`):\n                A prompt to be appended before the tools\n            tool_choice (`str`):\n                The tool to use\n            stop (`List[str]`):\n                Stop generating tokens if a member of `stop` is generated\n\n        \"\"\"\n        request = ChatRequest(\n            model=\"tgi\",\n            messages=messages,\n            repetition_penalty=repetition_penalty,\n            frequency_penalty=frequency_penalty,\n            logit_bias=logit_bias,\n            logprobs=logprobs,\n            top_logprobs=top_logprobs,\n            max_tokens=max_tokens,\n            n=n,\n    
        presence_penalty=presence_penalty,\n            stream=stream,\n            seed=seed,\n            temperature=temperature,\n            top_p=top_p,\n            tools=tools,\n            tool_prompt=tool_prompt,\n            tool_choice=tool_choice,\n            stop=stop,\n        )\n        if not stream:\n            resp = requests.post(\n                f\"{self.base_url}/v1/chat/completions\",\n                json=request.dict(),\n                headers=self.headers,\n                cookies=self.cookies,\n                timeout=self.timeout,\n            )\n            payload = resp.json()\n            if resp.status_code != 200:\n                raise parse_error(resp.status_code, payload)\n            return ChatComplete(**payload)\n        else:\n            return self._chat_stream_response(request)\n\n    def _chat_stream_response(self, request):\n        resp = requests.post(\n            f\"{self.base_url}/v1/chat/completions\",\n            json=request.dict(),\n            headers=self.headers,\n            cookies=self.cookies,\n            timeout=self.timeout,\n            stream=True,\n        )\n        # iterate and print stream\n        for byte_payload in resp.iter_lines():\n            if byte_payload == b\"\\n\":\n                continue\n            payload = byte_payload.decode(\"utf-8\")\n            if payload.startswith(\"data:\"):\n                json_payload = json.loads(payload.lstrip(\"data:\").rstrip(\"\\n\"))\n                try:\n                    response = ChatCompletionChunk(**json_payload)\n                    yield response\n                except ValidationError:\n                    raise parse_error(resp.status, json_payload)\n\n    def generate(\n        self,\n        prompt: str,\n        do_sample: bool = False,\n        max_new_tokens: int = 20,\n        best_of: Optional[int] = None,\n        repetition_penalty: Optional[float] = None,\n        frequency_penalty: Optional[float] = None,\n       
 return_full_text: bool = False,\n        seed: Optional[int] = None,\n        stop_sequences: Optional[List[str]] = None,\n        temperature: Optional[float] = None,\n        top_k: Optional[int] = None,\n        top_p: Optional[float] = None,\n        truncate: Optional[int] = None,\n        typical_p: Optional[float] = None,\n        watermark: bool = False,\n        decoder_input_details: bool = False,\n        top_n_tokens: Optional[int] = None,\n        grammar: Optional[Grammar] = None,\n    ) -> Response:\n        \"\"\"\n        Given a prompt, generate the following text\n\n        Args:\n            prompt (`str`):\n                Input text\n            do_sample (`bool`):\n                Activate logits sampling\n            max_new_tokens (`int`):\n                Maximum number of generated tokens\n            best_of (`int`):\n                Generate best_of sequences and return the one with the highest token logprobs\n            repetition_penalty (`float`):\n                The parameter for repetition penalty. 1.0 means no penalty. See [this\n                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.\n            frequency_penalty (`float`):\n                The parameter for frequency penalty. 
1.0 means no penalty\n                Penalize new tokens based on their existing frequency in the text so far,\n                decreasing the model's likelihood to repeat the same line verbatim.\n            return_full_text (`bool`):\n                Whether to prepend the prompt to the generated text\n            seed (`int`):\n                Random sampling seed\n            stop_sequences (`List[str]`):\n                Stop generating tokens if a member of `stop_sequences` is generated\n            temperature (`float`):\n                The value used to modulate the logits distribution.\n            top_k (`int`):\n                The number of highest probability vocabulary tokens to keep for top-k-filtering.\n            top_p (`float`):\n                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or\n                higher are kept for generation.\n            truncate (`int`):\n                Truncate input tokens to the given size\n            typical_p (`float`):\n                Typical Decoding mass\n                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information\n            watermark (`bool`):\n                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)\n            decoder_input_details (`bool`):\n                Return the decoder input token logprobs and ids\n            top_n_tokens (`int`):\n                Return the `n` most likely tokens at each step\n            grammar (`Grammar`):\n                Whether to use a grammar for the generation and the grammar to use. 
Grammars will constrain the generation\n                of the text to match a regular expression or JSON schema.\n\n        Returns:\n            Response: generated response\n        \"\"\"\n        # Validate parameters\n        parameters = Parameters(\n            best_of=best_of,\n            details=True,\n            do_sample=do_sample,\n            max_new_tokens=max_new_tokens,\n            repetition_penalty=repetition_penalty,\n            frequency_penalty=frequency_penalty,\n            return_full_text=return_full_text,\n            seed=seed,\n            stop=stop_sequences if stop_sequences is not None else [],\n            temperature=temperature,\n            top_k=top_k,\n            top_p=top_p,\n            truncate=truncate,\n            typical_p=typical_p,\n            watermark=watermark,\n            decoder_input_details=decoder_input_details,\n            top_n_tokens=top_n_tokens,\n            grammar=grammar,\n        )\n        request = Request(inputs=prompt, stream=False, parameters=parameters)\n\n        resp = requests.post(\n            self.base_url,\n            json=request.dict(),\n            headers=self.headers,\n            cookies=self.cookies,\n            timeout=self.timeout,\n        )\n        payload = resp.json()\n        if resp.status_code != 200:\n            raise parse_error(resp.status_code, payload)\n        return Response(**payload[0])\n\n    def generate_stream(\n        self,\n        prompt: str,\n        do_sample: bool = False,\n        max_new_tokens: int = 20,\n        repetition_penalty: Optional[float] = None,\n        frequency_penalty: Optional[float] = None,\n        return_full_text: bool = False,\n        seed: Optional[int] = None,\n        stop_sequences: Optional[List[str]] = None,\n        temperature: Optional[float] = None,\n        top_k: Optional[int] = None,\n        top_p: Optional[float] = None,\n        truncate: Optional[int] = None,\n        typical_p: Optional[float] = 
None,\n        watermark: bool = False,\n        top_n_tokens: Optional[int] = None,\n        grammar: Optional[Grammar] = None,\n    ) -> Iterator[StreamResponse]:\n        \"\"\"\n        Given a prompt, generate the following stream of tokens\n\n        Args:\n            prompt (`str`):\n                Input text\n            do_sample (`bool`):\n                Activate logits sampling\n            max_new_tokens (`int`):\n                Maximum number of generated tokens\n            repetition_penalty (`float`):\n                The parameter for repetition penalty. 1.0 means no penalty. See [this\n                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.\n            frequency_penalty (`float`):\n                The parameter for frequency penalty. 1.0 means no penalty\n                Penalize new tokens based on their existing frequency in the text so far,\n                decreasing the model's likelihood to repeat the same line verbatim.\n            return_full_text (`bool`):\n                Whether to prepend the prompt to the generated text\n            seed (`int`):\n                Random sampling seed\n            stop_sequences (`List[str]`):\n                Stop generating tokens if a member of `stop_sequences` is generated\n            temperature (`float`):\n                The value used to modulate the logits distribution.\n            top_k (`int`):\n                The number of highest probability vocabulary tokens to keep for top-k-filtering.\n            top_p (`float`):\n                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or\n                higher are kept for generation.\n            truncate (`int`):\n                Truncate input tokens to the given size\n            typical_p (`float`):\n                Typical Decoding mass\n                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more 
information\n            watermark (`bool`):\n                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)\n            top_n_tokens (`int`):\n                Return the `n` most likely tokens at each step\n            grammar (`Grammar`):\n                Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation\n                of the text to match a regular expression or JSON schema.\n\n        Returns:\n            Iterator[StreamResponse]: stream of generated tokens\n        \"\"\"\n        # Validate parameters\n        parameters = Parameters(\n            best_of=None,\n            details=True,\n            decoder_input_details=False,\n            do_sample=do_sample,\n            max_new_tokens=max_new_tokens,\n            repetition_penalty=repetition_penalty,\n            frequency_penalty=frequency_penalty,\n            return_full_text=return_full_text,\n            seed=seed,\n            stop=stop_sequences if stop_sequences is not None else [],\n            temperature=temperature,\n            top_k=top_k,\n            top_p=top_p,\n            truncate=truncate,\n            typical_p=typical_p,\n            watermark=watermark,\n            top_n_tokens=top_n_tokens,\n            grammar=grammar,\n        )\n        request = Request(inputs=prompt, stream=True, parameters=parameters)\n\n        resp = requests.post(\n            self.base_url,\n            json=request.dict(),\n            headers=self.headers,\n            cookies=self.cookies,\n            timeout=self.timeout,\n            stream=True,\n        )\n\n        if resp.status_code != 200:\n            raise parse_error(resp.status_code, resp.json())\n\n        # Parse ServerSentEvents\n        for byte_payload in resp.iter_lines():\n            # Skip line\n            if byte_payload == b\"\\n\":\n                continue\n\n            payload = byte_payload.decode(\"utf-8\")\n\n    
        # Event data\n            if payload.startswith(\"data:\"):\n                # Decode payload\n                json_payload = json.loads(payload.lstrip(\"data:\").rstrip(\"/n\"))\n                # Parse payload\n                try:\n                    response = StreamResponse(**json_payload)\n                except ValidationError:\n                    # If we failed to parse the payload, then it is an error payload\n                    raise parse_error(resp.status_code, json_payload)\n                yield response\n\n\nclass AsyncClient:\n    \"\"\"Asynchronous Client to make calls to a text-generation-inference instance\n\n     Example:\n\n     ```python\n     >>> from text_generation import AsyncClient\n\n     >>> client = AsyncClient(\"https://api-inference.huggingface.co/models/bigscience/bloomz\")\n     >>> response = await client.generate(\"Why is the sky blue?\")\n     >>> response.generated_text\n     ' Rayleigh scattering'\n\n     >>> result = \"\"\n     >>> async for response in client.generate_stream(\"Why is the sky blue?\"):\n     >>>     if not response.token.special:\n     >>>         result += response.token.text\n     >>> result\n    ' Rayleigh scattering'\n     ```\n    \"\"\"\n\n    def __init__(\n        self,\n        base_url: str,\n        headers: Optional[Dict[str, str]] = None,\n        cookies: Optional[Dict[str, str]] = None,\n        timeout: int = 10,\n    ):\n        \"\"\"\n        Args:\n            base_url (`str`):\n                text-generation-inference instance base url\n            headers (`Optional[Dict[str, str]]`):\n                Additional headers\n            cookies (`Optional[Dict[str, str]]`):\n                Cookies to include in the requests\n            timeout (`int`):\n                Timeout in seconds\n        \"\"\"\n        warnings.warn(DEPRECATION_WARNING, DeprecationWarning)\n        self.base_url = base_url\n        self.headers = headers\n        self.cookies = cookies\n        
self.timeout = ClientTimeout(timeout)\n\n    async def completion(\n        self,\n        prompt: str,\n        frequency_penalty: Optional[float] = None,\n        max_tokens: Optional[int] = None,\n        repetition_penalty: Optional[float] = None,\n        seed: Optional[int] = None,\n        stream: bool = False,\n        temperature: Optional[float] = None,\n        top_p: Optional[float] = None,\n        stop: Optional[List[str]] = None,\n    ) -> Union[Completion, AsyncIterator[CompletionComplete]]:\n        \"\"\"\n        Given a prompt, generate a response asynchronously\n\n        Args:\n            prompt (`str`):\n                Prompt\n            frequency_penalty (`float`):\n                The parameter for frequency penalty. 0.0 means no penalty\n                Penalize new tokens based on their existing frequency in the text so far,\n                decreasing the model's likelihood to repeat the same line verbatim.\n            max_tokens (`int`):\n                Maximum number of generated tokens\n            repetition_penalty (`float`):\n                The parameter for repetition penalty. 1.0 means no penalty. 
See [this\n                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.\n            seed (`int`):\n                Random sampling seed\n            stream (`bool`):\n                Stream the response\n            temperature (`float`):\n                The value used to module the logits distribution.\n            top_p (`float`):\n                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or\n                higher are kept for generation\n            stop (`List[str]`):\n                Stop generating tokens if a member of `stop` is generated\n        \"\"\"\n        request = CompletionRequest(\n            model=\"tgi\",\n            prompt=prompt,\n            frequency_penalty=frequency_penalty,\n            max_tokens=max_tokens,\n            repetition_penalty=repetition_penalty,\n            seed=seed,\n            stream=stream,\n            temperature=temperature,\n            top_p=top_p,\n            stop=stop,\n        )\n        if not stream:\n            return await self._completion_single_response(request)\n        else:\n            return self._completion_stream_response(request)\n\n    async def _completion_single_response(self, request):\n        async with ClientSession(\n            headers=self.headers, cookies=self.cookies, timeout=self.timeout\n        ) as session:\n            async with session.post(\n                f\"{self.base_url}/v1/completions\", json=request.dict()\n            ) as resp:\n                payload = await resp.json()\n                if resp.status != 200:\n                    raise parse_error(resp.status, payload)\n                return Completion(**payload)\n\n    async def _completion_stream_response(self, request):\n        async with ClientSession(\n            headers=self.headers, cookies=self.cookies, timeout=self.timeout\n        ) as session:\n            async with session.post(\n                
f\"{self.base_url}/v1/completions\", json=request.dict()\n            ) as resp:\n                async for byte_payload in resp.content:\n                    if byte_payload == b\"\\n\":\n                        continue\n                    payload = byte_payload.decode(\"utf-8\")\n                    if payload.startswith(\"data:\"):\n                        json_payload = json.loads(payload.lstrip(\"data:\").rstrip(\"\\n\"))\n                        try:\n                            response = CompletionComplete(**json_payload)\n                            yield response\n                        except ValidationError:\n                            raise parse_error(resp.status, json_payload)\n\n    async def chat(\n        self,\n        messages: List[Message],\n        repetition_penalty: Optional[float] = None,\n        frequency_penalty: Optional[float] = None,\n        logit_bias: Optional[List[float]] = None,\n        logprobs: Optional[bool] = None,\n        top_logprobs: Optional[int] = None,\n        max_tokens: Optional[int] = None,\n        n: Optional[int] = None,\n        presence_penalty: Optional[float] = None,\n        stream: bool = False,\n        seed: Optional[int] = None,\n        temperature: Optional[float] = None,\n        top_p: Optional[float] = None,\n        tools: Optional[List[Tool]] = None,\n        tool_prompt: Optional[str] = None,\n        tool_choice: Optional[str] = None,\n        stop: Optional[List[str]] = None,\n    ) -> Union[ChatComplete, AsyncIterator[ChatCompletionChunk]]:\n        \"\"\"\n        Given a list of messages, generate a response asynchronously\n\n        Args:\n            messages (`List[Message]`):\n                List of messages\n            repetition_penalty (`float`):\n                The parameter for repetition penalty. 1.0 means no penalty. 
See [this\n                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.\n            frequency_penalty (`float`):\n                The parameter for frequency penalty. 0.0 means no penalty\n                Penalize new tokens based on their existing frequency in the text so far,\n                decreasing the model's likelihood to repeat the same line verbatim.\n            logit_bias (`List[float]`):\n                Adjust the likelihood of specified tokens\n            logprobs (`bool`):\n                Include log probabilities in the response\n            top_logprobs (`int`):\n                Include the `n` most likely tokens at each step\n            max_tokens (`int`):\n                Maximum number of generated tokens\n            n (`int`):\n                Generate `n` completions\n            presence_penalty (`float`):\n                The parameter for presence penalty. 0.0 means no penalty. See [this\n                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.\n            stream (`bool`):\n                Stream the response\n            seed (`int`):\n                Random sampling seed\n            temperature (`float`):\n                The value used to module the logits distribution.\n            top_p (`float`):\n                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or\n                higher are kept for generation\n            tools (`List[Tool]`):\n                List of tools to use\n            tool_prompt (`str`):\n                A prompt to be appended before the tools\n            tool_choice (`str`):\n                The tool to use\n            stop (`List[str]`):\n                Stop generating tokens if a member of `stop` is generated\n\n        \"\"\"\n        request = ChatRequest(\n            model=\"tgi\",\n            messages=messages,\n            repetition_penalty=repetition_penalty,\n            
frequency_penalty=frequency_penalty,\n            logit_bias=logit_bias,\n            logprobs=logprobs,\n            top_logprobs=top_logprobs,\n            max_tokens=max_tokens,\n            n=n,\n            presence_penalty=presence_penalty,\n            stream=stream,\n            seed=seed,\n            temperature=temperature,\n            top_p=top_p,\n            tools=tools,\n            tool_prompt=tool_prompt,\n            tool_choice=tool_choice,\n            stop=stop,\n        )\n        if not stream:\n            return await self._chat_single_response(request)\n        else:\n            return self._chat_stream_response(request)\n\n    async def _chat_single_response(self, request):\n        async with ClientSession(\n            headers=self.headers, cookies=self.cookies, timeout=self.timeout\n        ) as session:\n            async with session.post(\n                f\"{self.base_url}/v1/chat/completions\", json=request.dict()\n            ) as resp:\n                payload = await resp.json()\n                if resp.status != 200:\n                    raise parse_error(resp.status, payload)\n                return ChatComplete(**payload)\n\n    async def _chat_stream_response(self, request):\n        async with ClientSession(\n            headers=self.headers, cookies=self.cookies, timeout=self.timeout\n        ) as session:\n            async with session.post(\n                f\"{self.base_url}/v1/chat/completions\", json=request.dict()\n            ) as resp:\n                async for byte_payload in resp.content:\n                    if byte_payload == b\"\\n\":\n                        continue\n                    payload = byte_payload.decode(\"utf-8\")\n                    if payload.startswith(\"data:\"):\n                        payload_data = (\n                            payload.lstrip(\"data:\").rstrip(\"\\n\").removeprefix(\" \")\n                        )\n                        if payload_data == \"[DONE]\":\n          
                  break\n                        json_payload = json.loads(payload_data)\n                        try:\n                            response = ChatCompletionChunk(**json_payload)\n                            yield response\n                        except ValidationError:\n                            raise parse_error(resp.status, json_payload)\n\n    async def generate(\n        self,\n        prompt: str,\n        do_sample: bool = False,\n        max_new_tokens: int = 20,\n        best_of: Optional[int] = None,\n        repetition_penalty: Optional[float] = None,\n        frequency_penalty: Optional[float] = None,\n        return_full_text: bool = False,\n        seed: Optional[int] = None,\n        stop_sequences: Optional[List[str]] = None,\n        temperature: Optional[float] = None,\n        top_k: Optional[int] = None,\n        top_p: Optional[float] = None,\n        truncate: Optional[int] = None,\n        typical_p: Optional[float] = None,\n        watermark: bool = False,\n        decoder_input_details: bool = False,\n        top_n_tokens: Optional[int] = None,\n        grammar: Optional[Grammar] = None,\n    ) -> Response:\n        \"\"\"\n        Given a prompt, generate the following text asynchronously\n\n        Args:\n            prompt (`str`):\n                Input text\n            do_sample (`bool`):\n                Activate logits sampling\n            max_new_tokens (`int`):\n                Maximum number of generated tokens\n            best_of (`int`):\n                Generate best_of sequences and return the one if the highest token logprobs\n            repetition_penalty (`float`):\n                The parameter for repetition penalty. 1.0 means no penalty. See [this\n                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.\n            frequency_penalty (`float`):\n                The parameter for frequency penalty. 
1.0 means no penalty\n                Penalize new tokens based on their existing frequency in the text so far,\n                decreasing the model's likelihood to repeat the same line verbatim.\n            return_full_text (`bool`):\n                Whether to prepend the prompt to the generated text\n            seed (`int`):\n                Random sampling seed\n            stop_sequences (`List[str]`):\n                Stop generating tokens if a member of `stop_sequences` is generated\n            temperature (`float`):\n                The value used to module the logits distribution.\n            top_k (`int`):\n                The number of highest probability vocabulary tokens to keep for top-k-filtering.\n            top_p (`float`):\n                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or\n                higher are kept for generation.\n            truncate (`int`):\n                Truncate inputs tokens to the given size\n            typical_p (`float`):\n                Typical Decoding mass\n                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information\n            watermark (`bool`):\n                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)\n            decoder_input_details (`bool`):\n                Return the decoder input token logprobs and ids\n            top_n_tokens (`int`):\n                Return the `n` most likely tokens at each step\n            grammar (`Grammar`):\n                Whether to use a grammar for the generation and the grammar to use. 
Grammars will constrain the generation\n                of the text to match a regular expression or JSON schema.\n\n        Returns:\n            Response: generated response\n        \"\"\"\n\n        # Validate parameters\n        parameters = Parameters(\n            best_of=best_of,\n            details=True,\n            decoder_input_details=decoder_input_details,\n            do_sample=do_sample,\n            max_new_tokens=max_new_tokens,\n            repetition_penalty=repetition_penalty,\n            frequency_penalty=frequency_penalty,\n            return_full_text=return_full_text,\n            seed=seed,\n            stop=stop_sequences if stop_sequences is not None else [],\n            temperature=temperature,\n            top_k=top_k,\n            top_p=top_p,\n            truncate=truncate,\n            typical_p=typical_p,\n            watermark=watermark,\n            top_n_tokens=top_n_tokens,\n            grammar=grammar,\n        )\n        request = Request(inputs=prompt, stream=False, parameters=parameters)\n\n        async with ClientSession(\n            headers=self.headers, cookies=self.cookies, timeout=self.timeout\n        ) as session:\n            async with session.post(self.base_url, json=request.model_dump()) as resp:\n                payload = await resp.json()\n\n                if resp.status != 200:\n                    raise parse_error(resp.status, payload)\n                return Response(**payload[0])\n\n    async def generate_stream(\n        self,\n        prompt: str,\n        do_sample: bool = False,\n        max_new_tokens: int = 20,\n        repetition_penalty: Optional[float] = None,\n        frequency_penalty: Optional[float] = None,\n        return_full_text: bool = False,\n        seed: Optional[int] = None,\n        stop_sequences: Optional[List[str]] = None,\n        temperature: Optional[float] = None,\n        top_k: Optional[int] = None,\n        top_p: Optional[float] = None,\n        truncate: 
Optional[int] = None,\n        typical_p: Optional[float] = None,\n        watermark: bool = False,\n        top_n_tokens: Optional[int] = None,\n        grammar: Optional[Grammar] = None,\n    ) -> AsyncIterator[StreamResponse]:\n        \"\"\"\n        Given a prompt, generate the following stream of tokens asynchronously\n\n        Args:\n            prompt (`str`):\n                Input text\n            do_sample (`bool`):\n                Activate logits sampling\n            max_new_tokens (`int`):\n                Maximum number of generated tokens\n            repetition_penalty (`float`):\n                The parameter for repetition penalty. 1.0 means no penalty. See [this\n                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.\n            frequency_penalty (`float`):\n                The parameter for frequency penalty. 1.0 means no penalty\n                Penalize new tokens based on their existing frequency in the text so far,\n                decreasing the model's likelihood to repeat the same line verbatim.\n            return_full_text (`bool`):\n                Whether to prepend the prompt to the generated text\n            seed (`int`):\n                Random sampling seed\n            stop_sequences (`List[str]`):\n                Stop generating tokens if a member of `stop_sequences` is generated\n            temperature (`float`):\n                The value used to module the logits distribution.\n            top_k (`int`):\n                The number of highest probability vocabulary tokens to keep for top-k-filtering.\n            top_p (`float`):\n                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or\n                higher are kept for generation.\n            truncate (`int`):\n                Truncate inputs tokens to the given size\n            typical_p (`float`):\n                Typical Decoding mass\n                See [Typical 
Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information\n            watermark (`bool`):\n                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)\n            top_n_tokens (`int`):\n                Return the `n` most likely tokens at each step\n            grammar (`Grammar`):\n                Whether to use a grammar for the generation and the grammar to use. Grammars will constrain the generation\n                of the text to match a regular expression or JSON schema.\n\n        Returns:\n            AsyncIterator[StreamResponse]: stream of generated tokens\n        \"\"\"\n        # Validate parameters\n        parameters = Parameters(\n            best_of=None,\n            details=True,\n            decoder_input_details=False,\n            do_sample=do_sample,\n            max_new_tokens=max_new_tokens,\n            repetition_penalty=repetition_penalty,\n            frequency_penalty=frequency_penalty,\n            return_full_text=return_full_text,\n            seed=seed,\n            stop=stop_sequences if stop_sequences is not None else [],\n            temperature=temperature,\n            top_k=top_k,\n            top_p=top_p,\n            truncate=truncate,\n            typical_p=typical_p,\n            watermark=watermark,\n            top_n_tokens=top_n_tokens,\n            grammar=grammar,\n        )\n        request = Request(inputs=prompt, stream=True, parameters=parameters)\n\n        async with ClientSession(\n            headers=self.headers, cookies=self.cookies, timeout=self.timeout\n        ) as session:\n            async with session.post(self.base_url, json=request.dict()) as resp:\n                if resp.status != 200:\n                    raise parse_error(resp.status, await resp.json())\n\n                # Parse ServerSentEvents\n                async for byte_payload in resp.content:\n                    # Skip line\n                    
if byte_payload == b\"\\n\":\n                        continue\n\n                    payload = byte_payload.decode(\"utf-8\")\n\n                    # Event data\n                    if payload.startswith(\"data:\"):\n                        # Decode payload\n                        json_payload = json.loads(payload.lstrip(\"data:\").rstrip(\"\\n\"))\n                        # Parse payload\n                        try:\n                            response = StreamResponse(**json_payload)\n                        except ValidationError:\n                            # If we failed to parse the payload, then it is an error payload\n                            raise parse_error(resp.status, json_payload)\n                        yield response\n"
  },
  {
    "path": "clients/python/text_generation/errors.py",
    "content": "from typing import Dict\n\n\n# Text Generation Inference Errors\nclass ValidationError(Exception):\n    def __init__(self, message: str):\n        super().__init__(message)\n\n\nclass GenerationError(Exception):\n    def __init__(self, message: str):\n        super().__init__(message)\n\n\nclass OverloadedError(Exception):\n    def __init__(self, message: str):\n        super().__init__(message)\n\n\nclass IncompleteGenerationError(Exception):\n    def __init__(self, message: str):\n        super().__init__(message)\n\n\n# API Inference Errors\nclass BadRequestError(Exception):\n    def __init__(self, message: str):\n        super().__init__(message)\n\n\nclass ShardNotReadyError(Exception):\n    def __init__(self, message: str):\n        super().__init__(message)\n\n\nclass ShardTimeoutError(Exception):\n    def __init__(self, message: str):\n        super().__init__(message)\n\n\nclass NotFoundError(Exception):\n    def __init__(self, message: str):\n        super().__init__(message)\n\n\nclass RateLimitExceededError(Exception):\n    def __init__(self, message: str):\n        super().__init__(message)\n\n\nclass NotSupportedError(Exception):\n    def __init__(self, model_id: str):\n        message = (\n            f\"Model `{model_id}` is not available for inference with this client. 
\\n\"\n            \"Use `huggingface_hub.inference_api.InferenceApi` instead.\"\n        )\n        super(NotSupportedError, self).__init__(message)\n\n\n# Unknown error\nclass UnknownError(Exception):\n    def __init__(self, message: str):\n        super().__init__(message)\n\n\ndef parse_error(status_code: int, payload: Dict[str, str]) -> Exception:\n    \"\"\"\n    Parse error given an HTTP status code and a json payload\n\n    Args:\n        status_code (`int`):\n            HTTP status code\n        payload (`Dict[str, str]`):\n            Json payload\n\n    Returns:\n        Exception: parsed exception\n\n    \"\"\"\n    # Try to parse a Text Generation Inference error\n    message = payload[\"error\"]\n    if \"error_type\" in payload:\n        error_type = payload[\"error_type\"]\n        if error_type == \"generation\":\n            return GenerationError(message)\n        if error_type == \"incomplete_generation\":\n            return IncompleteGenerationError(message)\n        if error_type == \"overloaded\":\n            return OverloadedError(message)\n        if error_type == \"validation\":\n            return ValidationError(message)\n\n    # Try to parse a APIInference error\n    if status_code == 400:\n        return BadRequestError(message)\n    if status_code == 403 or status_code == 424:\n        return ShardNotReadyError(message)\n    if status_code == 504:\n        return ShardTimeoutError(message)\n    if status_code == 404:\n        return NotFoundError(message)\n    if status_code == 429:\n        return RateLimitExceededError(message)\n\n    # Fallback to an unknown error\n    return UnknownError(message)\n"
  },
  {
    "path": "clients/python/text_generation/inference_api.py",
    "content": "import os\nimport requests\n\nfrom typing import Dict, Optional, List\nfrom huggingface_hub.utils import build_hf_headers\n\nfrom text_generation import Client, AsyncClient, __version__\nfrom text_generation.types import DeployedModel\nfrom text_generation.errors import NotSupportedError, parse_error\n\nINFERENCE_ENDPOINT = os.environ.get(\n    \"HF_INFERENCE_ENDPOINT\", \"https://api-inference.huggingface.co\"\n)\n\n\ndef deployed_models(headers: Optional[Dict] = None) -> List[DeployedModel]:\n    \"\"\"\n    Get all currently deployed models with text-generation-inference-support\n\n    Returns:\n        List[DeployedModel]: list of all currently deployed models\n    \"\"\"\n    resp = requests.get(\n        \"https://api-inference.huggingface.co/framework/text-generation-inference\",\n        headers=headers,\n        timeout=5,\n    )\n\n    payload = resp.json()\n    if resp.status_code != 200:\n        raise parse_error(resp.status_code, payload)\n\n    models = [DeployedModel(**raw_deployed_model) for raw_deployed_model in payload]\n    return models\n\n\ndef check_model_support(repo_id: str, headers: Optional[Dict] = None) -> bool:\n    \"\"\"\n    Check if a given model is supported by text-generation-inference\n\n    Returns:\n        bool: whether the model is supported by this client\n    \"\"\"\n    resp = requests.get(\n        f\"https://api-inference.huggingface.co/status/{repo_id}\",\n        headers=headers,\n        timeout=5,\n    )\n\n    payload = resp.json()\n    if resp.status_code != 200:\n        raise parse_error(resp.status_code, payload)\n\n    framework = payload[\"framework\"]\n    supported = framework == \"text-generation-inference\"\n    return supported\n\n\nclass InferenceAPIClient(Client):\n    \"\"\"Client to make calls to the HuggingFace Inference API.\n\n     Only supports a subset of the available text-generation or text2text-generation models that are served using\n     text-generation-inference\n\n     
Example:\n\n     ```python\n     >>> from text_generation import InferenceAPIClient\n\n     >>> client = InferenceAPIClient(\"bigscience/bloomz\")\n     >>> client.generate(\"Why is the sky blue?\").generated_text\n     ' Rayleigh scattering'\n\n     >>> result = \"\"\n     >>> for response in client.generate_stream(\"Why is the sky blue?\"):\n     >>>     if not response.token.special:\n     >>>         result += response.token.text\n     >>> result\n    ' Rayleigh scattering'\n     ```\n    \"\"\"\n\n    def __init__(self, repo_id: str, token: Optional[str] = None, timeout: int = 10):\n        \"\"\"\n        Init headers and API information\n\n        Args:\n            repo_id (`str`):\n                Id of repository (e.g. `bigscience/bloom`).\n            token (`str`, `optional`):\n                The API token to use as HTTP bearer authorization. This is not\n                the authentication token. You can find the token in\n                https://huggingface.co/settings/token. 
Alternatively, you can\n                find both your organizations and personal API tokens using\n                `HfApi().whoami(token)`.\n            timeout (`int`):\n                Timeout in seconds\n        \"\"\"\n\n        headers = build_hf_headers(\n            token=token, library_name=\"text-generation\", library_version=__version__\n        )\n\n        # Text Generation Inference client only supports a subset of the available hub models\n        if not check_model_support(repo_id, headers):\n            raise NotSupportedError(repo_id)\n\n        base_url = f\"{INFERENCE_ENDPOINT}/models/{repo_id}\"\n\n        super(InferenceAPIClient, self).__init__(\n            base_url, headers=headers, timeout=timeout\n        )\n\n\nclass InferenceAPIAsyncClient(AsyncClient):\n    \"\"\"Asynchronous Client to make calls to the HuggingFace Inference API.\n\n     Only supports a subset of the available text-generation or text2text-generation models that are served using\n     text-generation-inference\n\n     Example:\n\n     ```python\n     >>> from text_generation import InferenceAPIAsyncClient\n\n     >>> client = InferenceAPIAsyncClient(\"bigscience/bloomz\")\n     >>> response = await client.generate(\"Why is the sky blue?\")\n     >>> response.generated_text\n     ' Rayleigh scattering'\n\n     >>> result = \"\"\n     >>> async for response in client.generate_stream(\"Why is the sky blue?\"):\n     >>>     if not response.token.special:\n     >>>         result += response.token.text\n     >>> result\n    ' Rayleigh scattering'\n     ```\n    \"\"\"\n\n    def __init__(self, repo_id: str, token: Optional[str] = None, timeout: int = 10):\n        \"\"\"\n        Init headers and API information\n\n        Args:\n            repo_id (`str`):\n                Id of repository (e.g. `bigscience/bloom`).\n            token (`str`, `optional`):\n                The API token to use as HTTP bearer authorization. 
This is not\n                the authentication token. You can find the token in\n                https://huggingface.co/settings/token. Alternatively, you can\n                find both your organizations and personal API tokens using\n                `HfApi().whoami(token)`.\n            timeout (`int`):\n                Timeout in seconds\n        \"\"\"\n        headers = build_hf_headers(\n            token=token, library_name=\"text-generation\", library_version=__version__\n        )\n\n        # Text Generation Inference client only supports a subset of the available hub models\n        if not check_model_support(repo_id, headers):\n            raise NotSupportedError(repo_id)\n\n        base_url = f\"{INFERENCE_ENDPOINT}/models/{repo_id}\"\n\n        super(InferenceAPIAsyncClient, self).__init__(\n            base_url, headers=headers, timeout=timeout\n        )\n"
  },
  {
    "path": "clients/python/text_generation/types.py",
    "content": "from enum import Enum\nfrom pydantic import BaseModel, field_validator, ConfigDict\nfrom typing import Optional, List, Union, Any\n\nfrom text_generation.errors import ValidationError\n\n\n# enum for grammar type\nclass GrammarType(str, Enum):\n    Json = \"json\"\n    Regex = \"regex\"\n\n\n# Grammar type and value\nclass Grammar(BaseModel):\n    # Grammar type\n    type: GrammarType\n    # Grammar value\n    value: Union[str, dict]\n\n\nclass ToolCall(BaseModel):\n    # Id of the tool call\n    id: int\n    # Type of the tool call\n    type: str\n    # Function details of the tool call\n    function: dict\n\n\nclass Chunk(BaseModel):\n    type: str\n    text: Optional[str] = None\n    image_url: Any = None\n\n\nclass Message(BaseModel):\n    # Role of the message sender\n    role: str\n    # Content of the message\n    content: Optional[Union[str, List[Chunk]]] = None\n    # Optional name of the message sender\n    name: Optional[str] = None\n    # Tool calls associated with the chat completion\n    tool_calls: Optional[Any] = None\n\n\nclass Tool(BaseModel):\n    # Type of the tool\n    type: str\n    # Function details of the tool\n    function: dict\n\n\nclass Function(BaseModel):\n    name: Optional[str]\n    arguments: str\n\n\nclass ChoiceDeltaToolCall(BaseModel):\n    index: int\n    id: str\n    type: str\n    function: Function\n\n\nclass ChoiceDelta(BaseModel):\n    role: str\n    content: Optional[str] = None\n    tool_calls: Optional[List[ChoiceDeltaToolCall]] = None\n\n\nclass Choice(BaseModel):\n    index: int\n    delta: ChoiceDelta\n    logprobs: Optional[dict] = None\n    finish_reason: Optional[str] = None\n\n\nclass CompletionRequest(BaseModel):\n    # Model identifier\n    model: str\n    # Prompt\n    prompt: str\n    # The parameter for repetition penalty. 
1.0 means no penalty.\n    # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.\n    repetition_penalty: Optional[float] = None\n    # The parameter for frequency penalty. 1.0 means no penalty\n    # Penalize new tokens based on their existing frequency in the text so far,\n    # decreasing the model's likelihood to repeat the same line verbatim.\n    frequency_penalty: Optional[float] = None\n    # Maximum number of tokens to generate\n    max_tokens: Optional[int] = None\n    # Flag to indicate streaming response\n    stream: bool = False\n    # Random sampling seed\n    seed: Optional[int] = None\n    # Sampling temperature\n    temperature: Optional[float] = None\n    # Top-p value for nucleus sampling\n    top_p: Optional[float] = None\n    # Stop generating tokens if a member of `stop` is generated\n    stop: Optional[List[str]] = None\n\n\nclass CompletionComplete(BaseModel):\n    # Index of the chat completion\n    index: int\n    # Message associated with the chat completion\n    text: str\n    # Log probabilities for the chat completion\n    logprobs: Optional[Any]\n    # Reason for completion\n    finish_reason: str\n\n\nclass Completion(BaseModel):\n    # Completion details\n    id: str\n    object: str\n    created: int\n    model: str\n    system_fingerprint: str\n    choices: List[CompletionComplete]\n\n\nclass ChatRequest(BaseModel):\n    # Model identifier\n    model: str\n    # List of messages in the conversation\n    messages: List[Message]\n    # The parameter for repetition penalty. 1.0 means no penalty.\n    # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.\n    repetition_penalty: Optional[float] = None\n    # The parameter for frequency penalty. 
1.0 means no penalty\n    # Penalize new tokens based on their existing frequency in the text so far,\n    # decreasing the model's likelihood to repeat the same line verbatim.\n    frequency_penalty: Optional[float] = None\n    # Bias values for token selection\n    logit_bias: Optional[List[float]] = None\n    # Whether to return log probabilities\n    logprobs: Optional[bool] = None\n    # Number of most likely tokens to return at each position\n    top_logprobs: Optional[int] = None\n    # Maximum number of tokens to generate\n    max_tokens: Optional[int] = None\n    # Number of chat completion choices to generate\n    n: Optional[int] = None\n    # Penalty for presence of new tokens\n    presence_penalty: Optional[float] = None\n    # Flag to indicate streaming response\n    stream: bool = False\n    # Random sampling seed\n    seed: Optional[int] = None\n    # Sampling temperature\n    temperature: Optional[float] = None\n    # Top-p value for nucleus sampling\n    top_p: Optional[float] = None\n    # List of tools to be used\n    tools: Optional[List[Tool]] = None\n    # A prompt to be appended before the tools\n    tool_prompt: Optional[str] = None\n    # Choice of tool to be used\n    tool_choice: Optional[str] = None\n    # Stop generating tokens if a member of `stop` is generated\n    stop: Optional[List[str]] = None\n\n\nclass ChatCompletionComplete(BaseModel):\n    # Index of the chat completion\n    index: int\n    # Message associated with the chat completion\n    message: Message\n    # Log probabilities for the chat completion\n    logprobs: Optional[Any]\n    # Reason for completion\n    finish_reason: Optional[str]\n    # Usage details of the chat completion\n    usage: Optional[Any] = None\n\n\nclass ChatComplete(BaseModel):\n    # Chat completion details\n    id: str\n    object: str\n    created: int\n    model: str\n    system_fingerprint: str\n    choices: List[ChatCompletionComplete]\n    usage: Any\n\n\nclass 
ChatCompletionChunk(BaseModel):\n    id: str\n    object: str\n    created: int\n    model: str\n    system_fingerprint: str\n    choices: List[Choice]\n    usage: Optional[Any] = None\n\n\nclass Parameters(BaseModel):\n    # Activate logits sampling\n    do_sample: bool = False\n    # Maximum number of generated tokens\n    max_new_tokens: int = 20\n    # The parameter for repetition penalty. 1.0 means no penalty.\n    # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.\n    repetition_penalty: Optional[float] = None\n    # The parameter for frequency penalty. 1.0 means no penalty\n    # Penalize new tokens based on their existing frequency in the text so far,\n    # decreasing the model's likelihood to repeat the same line verbatim.\n    frequency_penalty: Optional[float] = None\n    # Whether to prepend the prompt to the generated text\n    return_full_text: bool = False\n    # Stop generating tokens if a member of `stop_sequences` is generated\n    stop: List[str] = []\n    # Random sampling seed\n    seed: Optional[int] = None\n    # The value used to modulate the logits distribution.\n    temperature: Optional[float] = None\n    # The number of highest probability vocabulary tokens to keep for top-k-filtering.\n    top_k: Optional[int] = None\n    # If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or\n    # higher are kept for generation.\n    top_p: Optional[float] = None\n    # truncate input tokens to the given size\n    truncate: Optional[int] = None\n    # Typical Decoding mass\n    # See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information\n    typical_p: Optional[float] = None\n    # Generate best_of sequences and return the one with the highest token logprobs\n    best_of: Optional[int] = None\n    # Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)\n    watermark: bool = False\n    # 
Get generation details\n    details: bool = False\n    # Get decoder input token logprobs and ids\n    decoder_input_details: bool = False\n    # Return the N most likely tokens at each step\n    top_n_tokens: Optional[int] = None\n    # grammar to use for generation\n    grammar: Optional[Grammar] = None\n\n    @field_validator(\"best_of\")\n    def valid_best_of(cls, field_value, values):\n        if field_value is not None:\n            if field_value <= 0:\n                raise ValidationError(\"`best_of` must be strictly positive\")\n            if field_value > 1 and values.data[\"seed\"] is not None:\n                raise ValidationError(\"`seed` must not be set when `best_of` is > 1\")\n            sampling = (\n                values.data[\"do_sample\"]\n                | (values.data[\"temperature\"] is not None)\n                | (values.data[\"top_k\"] is not None)\n                | (values.data[\"top_p\"] is not None)\n                | (values.data[\"typical_p\"] is not None)\n            )\n            if field_value > 1 and not sampling:\n                raise ValidationError(\"you must use sampling when `best_of` is > 1\")\n\n        return field_value\n\n    @field_validator(\"repetition_penalty\")\n    def valid_repetition_penalty(cls, v):\n        if v is not None and v <= 0:\n            raise ValidationError(\"`repetition_penalty` must be strictly positive\")\n        return v\n\n    @field_validator(\"frequency_penalty\")\n    def valid_frequency_penalty(cls, v):\n        if v is not None and v <= 0:\n            raise ValidationError(\"`frequency_penalty` must be strictly positive\")\n        return v\n\n    @field_validator(\"seed\")\n    def valid_seed(cls, v):\n        if v is not None and v < 0:\n            raise ValidationError(\"`seed` must be positive\")\n        return v\n\n    @field_validator(\"temperature\")\n    def valid_temp(cls, v):\n        if v is not None and v <= 0:\n            raise ValidationError(\"`temperature` 
must be strictly positive\")\n        return v\n\n    @field_validator(\"top_k\")\n    def valid_top_k(cls, v):\n        if v is not None and v <= 0:\n            raise ValidationError(\"`top_k` must be strictly positive\")\n        return v\n\n    @field_validator(\"top_p\")\n    def valid_top_p(cls, v):\n        if v is not None and (v <= 0 or v >= 1.0):\n            raise ValidationError(\"`top_p` must be > 0.0 and < 1.0\")\n        return v\n\n    @field_validator(\"truncate\")\n    def valid_truncate(cls, v):\n        if v is not None and v <= 0:\n            raise ValidationError(\"`truncate` must be strictly positive\")\n        return v\n\n    @field_validator(\"typical_p\")\n    def valid_typical_p(cls, v):\n        if v is not None and (v <= 0 or v >= 1.0):\n            raise ValidationError(\"`typical_p` must be > 0.0 and < 1.0\")\n        return v\n\n    @field_validator(\"top_n_tokens\")\n    def valid_top_n_tokens(cls, v):\n        if v is not None and v <= 0:\n            raise ValidationError(\"`top_n_tokens` must be strictly positive\")\n        return v\n\n    @field_validator(\"grammar\")\n    def valid_grammar(cls, v):\n        if v is not None:\n            if v.type == GrammarType.Regex and not v.value:\n                raise ValidationError(\"`value` cannot be empty for `regex` grammar\")\n            if v.type == GrammarType.Json and not v.value:\n                raise ValidationError(\"`value` cannot be empty for `json` grammar\")\n        return v\n\n\nclass Request(BaseModel):\n    # Prompt\n    inputs: str\n    # Generation parameters\n    parameters: Optional[Parameters] = None\n    # Whether to stream output tokens\n    stream: bool = False\n\n    @field_validator(\"inputs\")\n    def valid_input(cls, v):\n        if not v:\n            raise ValidationError(\"`inputs` cannot be empty\")\n        return v\n\n    @field_validator(\"stream\")\n    def valid_best_of_stream(cls, field_value, values):\n        parameters = 
values.data[\"parameters\"]\n        if (\n            parameters is not None\n            and parameters.best_of is not None\n            and parameters.best_of > 1\n            and field_value\n        ):\n            raise ValidationError(\n                \"`best_of` != 1 is not supported when `stream` == True\"\n            )\n        return field_value\n\n\n# Decoder input tokens\nclass InputToken(BaseModel):\n    # Token ID from the model tokenizer\n    id: int\n    # Token text\n    text: str\n    # Logprob\n    # Optional since the logprob of the first token cannot be computed\n    logprob: Optional[float] = None\n\n\n# Generated tokens\nclass Token(BaseModel):\n    # Token ID from the model tokenizer\n    id: int\n    # Token text\n    text: str\n    # Logprob\n    logprob: Optional[float] = None\n    # Is the token a special token\n    # Can be used to ignore tokens when concatenating\n    special: bool\n\n\n# Generation finish reason\nclass FinishReason(str, Enum):\n    # number of generated tokens == `max_new_tokens`\n    Length = \"length\"\n    # the model generated its end of sequence token\n    EndOfSequenceToken = \"eos_token\"\n    # the model generated a text included in `stop_sequences`\n    StopSequence = \"stop_sequence\"\n\n\n# Additional sequences when using the `best_of` parameter\nclass BestOfSequence(BaseModel):\n    # Generated text\n    generated_text: str\n    # Generation finish reason\n    finish_reason: FinishReason\n    # Number of generated tokens\n    generated_tokens: int\n    # Sampling seed if sampling was activated\n    seed: Optional[int] = None\n    # Decoder input tokens, empty if decoder_input_details is False\n    prefill: List[InputToken]\n    # Generated tokens\n    tokens: List[Token]\n    # Most likely tokens\n    top_tokens: Optional[List[List[Token]]] = None\n\n\n# `generate` details\nclass Details(BaseModel):\n    # Generation finish reason\n    finish_reason: FinishReason\n    # Number of generated tokens\n    
generated_tokens: int\n    # Sampling seed if sampling was activated\n    seed: Optional[int] = None\n    # Decoder input tokens, empty if decoder_input_details is False\n    prefill: List[InputToken]\n    # Generated tokens\n    tokens: List[Token]\n    # Most likely tokens\n    top_tokens: Optional[List[List[Token]]] = None\n    # Additional sequences when using the `best_of` parameter\n    best_of_sequences: Optional[List[BestOfSequence]] = None\n\n\n# `generate` return value\nclass Response(BaseModel):\n    # Generated text\n    generated_text: str\n    # Generation details\n    details: Details\n\n\n# `generate_stream` details\nclass StreamDetails(BaseModel):\n    # Generation finish reason\n    finish_reason: FinishReason\n    # Number of generated tokens\n    generated_tokens: int\n    # Sampling seed if sampling was activated\n    seed: Optional[int] = None\n\n\n# `generate_stream` return value\nclass StreamResponse(BaseModel):\n    # Generated token\n    token: Token\n    # Most likely tokens\n    top_tokens: Optional[List[Token]] = None\n    # Complete generated text\n    # Only available when the generation is finished\n    generated_text: Optional[str] = None\n    # Generation details\n    # Only available when the generation is finished\n    details: Optional[StreamDetails] = None\n\n\n# Inference API currently deployed model\nclass DeployedModel(BaseModel):\n    # Disable warning for use of `model_` prefix in `model_id`. Be mindful about adding members\n    # with model_ prefixes, since this disables guardrails for colliding fields:\n    # https://github.com/pydantic/pydantic/issues/9177\n    model_config = ConfigDict(protected_namespaces=())\n    model_id: str\n    sha: str\n"
  },
  {
    "path": "crate-hashes.json",
    "content": "{\n  \"git+https://github.com/dottxt-ai/outlines-core.git?rev=ba10c619fc9bf3c487e43f49bdecb95a24bb465c#outlines-core@0.1.0\": \"1j9dcd831b0bmmjk2n4aag3x47qnqmkpg4gqpvwwyic7744llbfm\"\n}"
  },
  {
    "path": "docs/README.md",
    "content": "Documentation available at: https://huggingface.co/docs/text-generation-inference\n\n## Release\n\nWhen making a release, please update the latest version in the documentation with:\n```\nexport OLD_VERSION=\"2\\.0\\.3\"\nexport NEW_VERSION=\"2\\.0\\.4\"\nfind . -name '*.md' -exec sed -i -e \"s/$OLD_VERSION/$NEW_VERSION/g\" {} \\;\n```\n"
  },
  {
    "path": "docs/index.html",
    "content": "<html>\n    <head>\n        <!-- Load the latest Swagger UI code and style from npm using unpkg.com -->\n        <script src=\"https://unpkg.com/swagger-ui-dist@3/swagger-ui-bundle.js\"></script>\n        <link rel=\"stylesheet\" type=\"text/css\" href=\"https://unpkg.com/swagger-ui-dist@3/swagger-ui.css\"/>\n        <title>Text Generation Inference API</title>\n    </head>\n    <body>\n        <div id=\"swagger-ui\"></div> <!-- Div to hold the UI component -->\n        <script>\n            window.onload = function () {\n                // Begin Swagger UI call region\n                const ui = SwaggerUIBundle({\n                    url: \"openapi.json\", //Location of Open API spec in the repo\n                    dom_id: '#swagger-ui',\n                    deepLinking: true,\n                    supportedSubmitMethods: [],\n                    presets: [\n                        SwaggerUIBundle.presets.apis,\n                        SwaggerUIBundle.SwaggerUIStandalonePreset\n                    ],\n                    plugins: [\n                        SwaggerUIBundle.plugins.DownloadUrl\n                    ],\n                })\n                window.ui = ui\n            }\n        </script>\n    </body>\n</html>\n"
  },
  {
    "path": "docs/openapi.json",
    "content": "{\n  \"openapi\": \"3.0.3\",\n  \"info\": {\n    \"title\": \"Text Generation Inference\",\n    \"description\": \"Text Generation Webserver\",\n    \"contact\": {\n      \"name\": \"Olivier Dehaene\"\n    },\n    \"license\": {\n      \"name\": \"Apache 2.0\",\n      \"url\": \"https://www.apache.org/licenses/LICENSE-2.0\"\n    },\n    \"version\": \"3.3.6-dev0\"\n  },\n  \"paths\": {\n    \"/\": {\n      \"post\": {\n        \"tags\": [\n          \"Text Generation Inference\"\n        ],\n        \"summary\": \"Generate tokens if `stream == false` or a stream of token if `stream == true`\",\n        \"operationId\": \"compat_generate\",\n        \"requestBody\": {\n          \"content\": {\n            \"application/json\": {\n              \"schema\": {\n                \"$ref\": \"#/components/schemas/CompatGenerateRequest\"\n              }\n            }\n          },\n          \"required\": true\n        },\n        \"responses\": {\n          \"200\": {\n            \"description\": \"Generated Text\",\n            \"content\": {\n              \"application/json\": {\n                \"schema\": {\n                  \"type\": \"array\",\n                  \"items\": {\n                    \"$ref\": \"#/components/schemas/GenerateResponse\"\n                  }\n                }\n              },\n              \"text/event-stream\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/StreamResponse\"\n                }\n              }\n            }\n          },\n          \"422\": {\n            \"description\": \"Input validation error\",\n            \"content\": {\n              \"application/json\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/ErrorResponse\"\n                },\n                \"example\": {\n                  \"error\": \"Input validation error\",\n                  \"error_type\": \"validation\"\n                }\n              }\n       
     }\n          },\n          \"424\": {\n            \"description\": \"Generation Error\",\n            \"content\": {\n              \"application/json\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/ErrorResponse\"\n                },\n                \"example\": {\n                  \"error\": \"Request failed during generation\",\n                  \"error_type\": \"generation\"\n                }\n              }\n            }\n          },\n          \"429\": {\n            \"description\": \"Model is overloaded\",\n            \"content\": {\n              \"application/json\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/ErrorResponse\"\n                },\n                \"example\": {\n                  \"error\": \"Model is overloaded\",\n                  \"error_type\": \"overloaded\"\n                }\n              }\n            }\n          },\n          \"500\": {\n            \"description\": \"Incomplete generation\",\n            \"content\": {\n              \"application/json\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/ErrorResponse\"\n                },\n                \"example\": {\n                  \"error\": \"Incomplete generation\",\n                  \"error_type\": \"incomplete_generation\"\n                }\n              }\n            }\n          }\n        }\n      }\n    },\n    \"/chat_tokenize\": {\n      \"post\": {\n        \"tags\": [\n          \"Text Generation Inference\"\n        ],\n        \"summary\": \"Template and tokenize ChatRequest\",\n        \"operationId\": \"get_chat_tokenize\",\n        \"requestBody\": {\n          \"content\": {\n            \"application/json\": {\n              \"schema\": {\n                \"$ref\": \"#/components/schemas/ChatRequest\"\n              }\n            }\n          },\n          \"required\": true\n        },\n        \"responses\": {\n   
       \"200\": {\n            \"description\": \"Templated and tokenized ChatRequest\",\n            \"content\": {\n              \"application/json\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/ChatTokenizeResponse\"\n                }\n              }\n            }\n          },\n          \"404\": {\n            \"description\": \"Failed to tokenize ChatRequest\",\n            \"content\": {\n              \"application/json\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/ErrorResponse\"\n                }\n              }\n            }\n          }\n        }\n      }\n    },\n    \"/generate\": {\n      \"post\": {\n        \"tags\": [\n          \"Text Generation Inference\"\n        ],\n        \"summary\": \"Generate tokens\",\n        \"operationId\": \"generate\",\n        \"requestBody\": {\n          \"content\": {\n            \"application/json\": {\n              \"schema\": {\n                \"$ref\": \"#/components/schemas/GenerateRequest\"\n              }\n            }\n          },\n          \"required\": true\n        },\n        \"responses\": {\n          \"200\": {\n            \"description\": \"Generated Text\",\n            \"content\": {\n              \"application/json\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/GenerateResponse\"\n                }\n              }\n            }\n          },\n          \"422\": {\n            \"description\": \"Input validation error\",\n            \"content\": {\n              \"application/json\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/ErrorResponse\"\n                },\n                \"example\": {\n                  \"error\": \"Input validation error\",\n                  \"error_type\": \"validation\"\n                }\n              }\n            }\n          },\n          \"424\": {\n            
\"description\": \"Generation Error\",\n            \"content\": {\n              \"application/json\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/ErrorResponse\"\n                },\n                \"example\": {\n                  \"error\": \"Request failed during generation\",\n                  \"error_type\": \"generation\"\n                }\n              }\n            }\n          },\n          \"429\": {\n            \"description\": \"Model is overloaded\",\n            \"content\": {\n              \"application/json\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/ErrorResponse\"\n                },\n                \"example\": {\n                  \"error\": \"Model is overloaded\",\n                  \"error_type\": \"overloaded\"\n                }\n              }\n            }\n          },\n          \"500\": {\n            \"description\": \"Incomplete generation\",\n            \"content\": {\n              \"application/json\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/ErrorResponse\"\n                },\n                \"example\": {\n                  \"error\": \"Incomplete generation\",\n                  \"error_type\": \"incomplete_generation\"\n                }\n              }\n            }\n          }\n        }\n      }\n    },\n    \"/generate_stream\": {\n      \"post\": {\n        \"tags\": [\n          \"Text Generation Inference\"\n        ],\n        \"summary\": \"Generate a stream of token using Server-Sent Events\",\n        \"operationId\": \"generate_stream\",\n        \"requestBody\": {\n          \"content\": {\n            \"application/json\": {\n              \"schema\": {\n                \"$ref\": \"#/components/schemas/GenerateRequest\"\n              }\n            }\n          },\n          \"required\": true\n        },\n        \"responses\": {\n          \"200\": {\n            
\"description\": \"Generated Text\",\n            \"content\": {\n              \"text/event-stream\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/StreamResponse\"\n                }\n              }\n            }\n          },\n          \"422\": {\n            \"description\": \"Input validation error\",\n            \"content\": {\n              \"text/event-stream\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/ErrorResponse\"\n                },\n                \"example\": {\n                  \"error\": \"Input validation error\",\n                  \"error_type\": \"validation\"\n                }\n              }\n            }\n          },\n          \"424\": {\n            \"description\": \"Generation Error\",\n            \"content\": {\n              \"text/event-stream\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/ErrorResponse\"\n                },\n                \"example\": {\n                  \"error\": \"Request failed during generation\",\n                  \"error_type\": \"generation\"\n                }\n              }\n            }\n          },\n          \"429\": {\n            \"description\": \"Model is overloaded\",\n            \"content\": {\n              \"text/event-stream\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/ErrorResponse\"\n                },\n                \"example\": {\n                  \"error\": \"Model is overloaded\",\n                  \"error_type\": \"overloaded\"\n                }\n              }\n            }\n          },\n          \"500\": {\n            \"description\": \"Incomplete generation\",\n            \"content\": {\n              \"text/event-stream\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/ErrorResponse\"\n                },\n                \"example\": {\n                 
 \"error\": \"Incomplete generation\",\n                  \"error_type\": \"incomplete_generation\"\n                }\n              }\n            }\n          }\n        }\n      }\n    },\n    \"/health\": {\n      \"get\": {\n        \"tags\": [\n          \"Text Generation Inference\"\n        ],\n        \"summary\": \"Health check method\",\n        \"operationId\": \"health\",\n        \"responses\": {\n          \"200\": {\n            \"description\": \"Everything is working fine\"\n          },\n          \"503\": {\n            \"description\": \"Text generation inference is down\",\n            \"content\": {\n              \"application/json\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/ErrorResponse\"\n                },\n                \"example\": {\n                  \"error\": \"unhealthy\",\n                  \"error_type\": \"healthcheck\"\n                }\n              }\n            }\n          }\n        }\n      }\n    },\n    \"/info\": {\n      \"get\": {\n        \"tags\": [\n          \"Text Generation Inference\"\n        ],\n        \"summary\": \"Text Generation Inference endpoint info\",\n        \"operationId\": \"get_model_info\",\n        \"responses\": {\n          \"200\": {\n            \"description\": \"Served model info\",\n            \"content\": {\n              \"application/json\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/Info\"\n                }\n              }\n            }\n          }\n        }\n      }\n    },\n    \"/invocations\": {\n      \"post\": {\n        \"tags\": [\n          \"Text Generation Inference\"\n        ],\n        \"summary\": \"Generate tokens from Sagemaker request\",\n        \"operationId\": \"sagemaker_compatibility\",\n        \"requestBody\": {\n          \"content\": {\n            \"application/json\": {\n              \"schema\": {\n                \"$ref\": 
\"#/components/schemas/SagemakerRequest\"\n              }\n            }\n          },\n          \"required\": true\n        },\n        \"responses\": {\n          \"200\": {\n            \"description\": \"Generated Chat Completion\",\n            \"content\": {\n              \"application/json\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/SagemakerResponse\"\n                }\n              },\n              \"text/event-stream\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/SagemakerStreamResponse\"\n                }\n              }\n            }\n          },\n          \"422\": {\n            \"description\": \"Input validation error\",\n            \"content\": {\n              \"application/json\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/ErrorResponse\"\n                },\n                \"example\": {\n                  \"error\": \"Input validation error\",\n                  \"error_type\": \"validation\"\n                }\n              }\n            }\n          },\n          \"424\": {\n            \"description\": \"Generation Error\",\n            \"content\": {\n              \"application/json\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/ErrorResponse\"\n                },\n                \"example\": {\n                  \"error\": \"Request failed during generation\",\n                  \"error_type\": \"generation\"\n                }\n              }\n            }\n          },\n          \"429\": {\n            \"description\": \"Model is overloaded\",\n            \"content\": {\n              \"application/json\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/ErrorResponse\"\n                },\n                \"example\": {\n                  \"error\": \"Model is overloaded\",\n                  \"error_type\": 
\"overloaded\"\n                }\n              }\n            }\n          },\n          \"500\": {\n            \"description\": \"Incomplete generation\",\n            \"content\": {\n              \"application/json\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/ErrorResponse\"\n                },\n                \"example\": {\n                  \"error\": \"Incomplete generation\",\n                  \"error_type\": \"incomplete_generation\"\n                }\n              }\n            }\n          }\n        }\n      }\n    },\n    \"/metrics\": {\n      \"get\": {\n        \"tags\": [\n          \"Text Generation Inference\"\n        ],\n        \"summary\": \"Prometheus metrics scrape endpoint\",\n        \"operationId\": \"metrics\",\n        \"responses\": {\n          \"200\": {\n            \"description\": \"Prometheus Metrics\",\n            \"content\": {\n              \"text/plain\": {\n                \"schema\": {\n                  \"type\": \"string\"\n                }\n              }\n            }\n          }\n        }\n      }\n    },\n    \"/tokenize\": {\n      \"post\": {\n        \"tags\": [\n          \"Text Generation Inference\"\n        ],\n        \"summary\": \"Tokenize inputs\",\n        \"operationId\": \"tokenize\",\n        \"requestBody\": {\n          \"content\": {\n            \"application/json\": {\n              \"schema\": {\n                \"$ref\": \"#/components/schemas/GenerateRequest\"\n              }\n            }\n          },\n          \"required\": true\n        },\n        \"responses\": {\n          \"200\": {\n            \"description\": \"Tokenized ids\",\n            \"content\": {\n              \"application/json\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/TokenizeResponse\"\n                }\n              }\n            }\n          },\n          \"404\": {\n            \"description\": \"No tokenizer 
found\",\n            \"content\": {\n              \"application/json\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/ErrorResponse\"\n                },\n                \"example\": {\n                  \"error\": \"No fast tokenizer available\"\n                }\n              }\n            }\n          }\n        }\n      }\n    },\n    \"/v1/chat/completions\": {\n      \"post\": {\n        \"tags\": [\n          \"Text Generation Inference\"\n        ],\n        \"summary\": \"Generate tokens\",\n        \"operationId\": \"chat_completions\",\n        \"requestBody\": {\n          \"content\": {\n            \"application/json\": {\n              \"schema\": {\n                \"$ref\": \"#/components/schemas/ChatRequest\"\n              }\n            }\n          },\n          \"required\": true\n        },\n        \"responses\": {\n          \"200\": {\n            \"description\": \"Generated Chat Completion\",\n            \"content\": {\n              \"application/json\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/ChatCompletion\"\n                }\n              },\n              \"text/event-stream\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/ChatCompletionChunk\"\n                }\n              }\n            }\n          },\n          \"422\": {\n            \"description\": \"Input validation error\",\n            \"content\": {\n              \"application/json\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/ErrorResponse\"\n                },\n                \"example\": {\n                  \"error\": \"Input validation error\",\n                  \"error_type\": \"validation\"\n                }\n              }\n            }\n          },\n          \"424\": {\n            \"description\": \"Generation Error\",\n            \"content\": {\n              
\"application/json\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/ErrorResponse\"\n                },\n                \"example\": {\n                  \"error\": \"Request failed during generation\",\n                  \"error_type\": \"generation\"\n                }\n              }\n            }\n          },\n          \"429\": {\n            \"description\": \"Model is overloaded\",\n            \"content\": {\n              \"application/json\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/ErrorResponse\"\n                },\n                \"example\": {\n                  \"error\": \"Model is overloaded\",\n                  \"error_type\": \"overloaded\"\n                }\n              }\n            }\n          },\n          \"500\": {\n            \"description\": \"Incomplete generation\",\n            \"content\": {\n              \"application/json\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/ErrorResponse\"\n                },\n                \"example\": {\n                  \"error\": \"Incomplete generation\",\n                  \"error_type\": \"incomplete_generation\"\n                }\n              }\n            }\n          }\n        }\n      }\n    },\n    \"/v1/completions\": {\n      \"post\": {\n        \"tags\": [\n          \"Text Generation Inference\"\n        ],\n        \"summary\": \"Generate tokens\",\n        \"operationId\": \"completions\",\n        \"requestBody\": {\n          \"content\": {\n            \"application/json\": {\n              \"schema\": {\n                \"$ref\": \"#/components/schemas/CompletionRequest\"\n              }\n            }\n          },\n          \"required\": true\n        },\n        \"responses\": {\n          \"200\": {\n            \"description\": \"Generated Chat Completion\",\n            \"content\": {\n              \"application/json\": {\n        
        \"schema\": {\n                  \"$ref\": \"#/components/schemas/CompletionFinal\"\n                }\n              },\n              \"text/event-stream\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/Chunk\"\n                }\n              }\n            }\n          },\n          \"422\": {\n            \"description\": \"Input validation error\",\n            \"content\": {\n              \"application/json\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/ErrorResponse\"\n                },\n                \"example\": {\n                  \"error\": \"Input validation error\",\n                  \"error_type\": \"validation\"\n                }\n              }\n            }\n          },\n          \"424\": {\n            \"description\": \"Generation Error\",\n            \"content\": {\n              \"application/json\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/ErrorResponse\"\n                },\n                \"example\": {\n                  \"error\": \"Request failed during generation\",\n                  \"error_type\": \"generation\"\n                }\n              }\n            }\n          },\n          \"429\": {\n            \"description\": \"Model is overloaded\",\n            \"content\": {\n              \"application/json\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/ErrorResponse\"\n                },\n                \"example\": {\n                  \"error\": \"Model is overloaded\",\n                  \"error_type\": \"overloaded\"\n                }\n              }\n            }\n          },\n          \"500\": {\n            \"description\": \"Incomplete generation\",\n            \"content\": {\n              \"application/json\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/ErrorResponse\"\n                
},\n                \"example\": {\n                  \"error\": \"Incomplete generation\",\n                  \"error_type\": \"incomplete_generation\"\n                }\n              }\n            }\n          }\n        }\n      }\n    },\n    \"/v1/models\": {\n      \"get\": {\n        \"tags\": [\n          \"Text Generation Inference\"\n        ],\n        \"summary\": \"Get model info\",\n        \"operationId\": \"openai_get_model_info\",\n        \"responses\": {\n          \"200\": {\n            \"description\": \"Served model info\",\n            \"content\": {\n              \"application/json\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/ModelInfo\"\n                }\n              }\n            }\n          },\n          \"404\": {\n            \"description\": \"Model not found\",\n            \"content\": {\n              \"application/json\": {\n                \"schema\": {\n                  \"$ref\": \"#/components/schemas/ErrorResponse\"\n                }\n              }\n            }\n          }\n        }\n      }\n    }\n  },\n  \"components\": {\n    \"schemas\": {\n      \"BestOfSequence\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"generated_text\",\n          \"finish_reason\",\n          \"generated_tokens\",\n          \"prefill\",\n          \"tokens\"\n        ],\n        \"properties\": {\n          \"finish_reason\": {\n            \"$ref\": \"#/components/schemas/FinishReason\"\n          },\n          \"generated_text\": {\n            \"type\": \"string\",\n            \"example\": \"test\"\n          },\n          \"generated_tokens\": {\n            \"type\": \"integer\",\n            \"format\": \"int32\",\n            \"example\": 1,\n            \"minimum\": 0\n          },\n          \"prefill\": {\n            \"type\": \"array\",\n            \"items\": {\n              \"$ref\": \"#/components/schemas/PrefillToken\"\n            }\n         
 },\n          \"seed\": {\n            \"type\": \"integer\",\n            \"format\": \"int64\",\n            \"example\": 42,\n            \"nullable\": true,\n            \"minimum\": 0\n          },\n          \"tokens\": {\n            \"type\": \"array\",\n            \"items\": {\n              \"$ref\": \"#/components/schemas/Token\"\n            }\n          },\n          \"top_tokens\": {\n            \"type\": \"array\",\n            \"items\": {\n              \"type\": \"array\",\n              \"items\": {\n                \"$ref\": \"#/components/schemas/Token\"\n              }\n            }\n          }\n        }\n      },\n      \"ChatCompletion\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"id\",\n          \"created\",\n          \"model\",\n          \"system_fingerprint\",\n          \"choices\",\n          \"usage\"\n        ],\n        \"properties\": {\n          \"choices\": {\n            \"type\": \"array\",\n            \"items\": {\n              \"$ref\": \"#/components/schemas/ChatCompletionComplete\"\n            }\n          },\n          \"created\": {\n            \"type\": \"integer\",\n            \"format\": \"int64\",\n            \"example\": \"1706270835\",\n            \"minimum\": 0\n          },\n          \"id\": {\n            \"type\": \"string\"\n          },\n          \"model\": {\n            \"type\": \"string\",\n            \"example\": \"mistralai/Mistral-7B-Instruct-v0.2\"\n          },\n          \"system_fingerprint\": {\n            \"type\": \"string\"\n          },\n          \"usage\": {\n            \"$ref\": \"#/components/schemas/Usage\"\n          }\n        }\n      },\n      \"ChatCompletionChoice\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"index\",\n          \"delta\"\n        ],\n        \"properties\": {\n          \"delta\": {\n            \"$ref\": \"#/components/schemas/ChatCompletionDelta\"\n          },\n          
\"finish_reason\": {\n            \"type\": \"string\",\n            \"nullable\": true\n          },\n          \"index\": {\n            \"type\": \"integer\",\n            \"format\": \"int32\",\n            \"minimum\": 0\n          },\n          \"logprobs\": {\n            \"allOf\": [\n              {\n                \"$ref\": \"#/components/schemas/ChatCompletionLogprobs\"\n              }\n            ],\n            \"nullable\": true\n          }\n        }\n      },\n      \"ChatCompletionChunk\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"id\",\n          \"created\",\n          \"model\",\n          \"system_fingerprint\",\n          \"choices\"\n        ],\n        \"properties\": {\n          \"choices\": {\n            \"type\": \"array\",\n            \"items\": {\n              \"$ref\": \"#/components/schemas/ChatCompletionChoice\"\n            }\n          },\n          \"created\": {\n            \"type\": \"integer\",\n            \"format\": \"int64\",\n            \"example\": \"1706270978\",\n            \"minimum\": 0\n          },\n          \"id\": {\n            \"type\": \"string\"\n          },\n          \"model\": {\n            \"type\": \"string\",\n            \"example\": \"mistralai/Mistral-7B-Instruct-v0.2\"\n          },\n          \"system_fingerprint\": {\n            \"type\": \"string\"\n          },\n          \"usage\": {\n            \"allOf\": [\n              {\n                \"$ref\": \"#/components/schemas/Usage\"\n              }\n            ],\n            \"nullable\": true\n          }\n        }\n      },\n      \"ChatCompletionComplete\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"index\",\n          \"message\",\n          \"finish_reason\"\n        ],\n        \"properties\": {\n          \"finish_reason\": {\n            \"type\": \"string\"\n          },\n          \"index\": {\n            \"type\": \"integer\",\n            \"format\": 
\"int32\",\n            \"minimum\": 0\n          },\n          \"logprobs\": {\n            \"allOf\": [\n              {\n                \"$ref\": \"#/components/schemas/ChatCompletionLogprobs\"\n              }\n            ],\n            \"nullable\": true\n          },\n          \"message\": {\n            \"$ref\": \"#/components/schemas/OutputMessage\"\n          }\n        }\n      },\n      \"ChatCompletionDelta\": {\n        \"oneOf\": [\n          {\n            \"$ref\": \"#/components/schemas/TextMessage\"\n          },\n          {\n            \"$ref\": \"#/components/schemas/ToolCallDelta\"\n          }\n        ]\n      },\n      \"ChatCompletionLogprob\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"token\",\n          \"logprob\",\n          \"top_logprobs\"\n        ],\n        \"properties\": {\n          \"logprob\": {\n            \"type\": \"number\",\n            \"format\": \"float\"\n          },\n          \"token\": {\n            \"type\": \"string\"\n          },\n          \"top_logprobs\": {\n            \"type\": \"array\",\n            \"items\": {\n              \"$ref\": \"#/components/schemas/ChatCompletionTopLogprob\"\n            }\n          }\n        }\n      },\n      \"ChatCompletionLogprobs\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"content\"\n        ],\n        \"properties\": {\n          \"content\": {\n            \"type\": \"array\",\n            \"items\": {\n              \"$ref\": \"#/components/schemas/ChatCompletionLogprob\"\n            }\n          }\n        }\n      },\n      \"ChatCompletionTopLogprob\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"token\",\n          \"logprob\"\n        ],\n        \"properties\": {\n          \"logprob\": {\n            \"type\": \"number\",\n            \"format\": \"float\"\n          },\n          \"token\": {\n            \"type\": \"string\"\n          }\n        }\n      },\n   
   \"ChatRequest\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"messages\"\n        ],\n        \"properties\": {\n          \"frequency_penalty\": {\n            \"type\": \"number\",\n            \"format\": \"float\",\n            \"description\": \"Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,\\ndecreasing the model's likelihood to repeat the same line verbatim.\",\n            \"example\": \"1.0\",\n            \"nullable\": true\n          },\n          \"logit_bias\": {\n            \"type\": \"array\",\n            \"items\": {\n              \"type\": \"number\",\n              \"format\": \"float\"\n            },\n            \"description\": \"UNUSED\\nModify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens\\n(specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,\\nthe bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,\\nbut values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should\\nresult in a ban or exclusive selection of the relevant token.\",\n            \"nullable\": true\n          },\n          \"logprobs\": {\n            \"type\": \"boolean\",\n            \"description\": \"Whether to return log probabilities of the output tokens or not. 
If true, returns the log probabilities of each\\noutput token returned in the content of message.\",\n            \"example\": \"false\",\n            \"nullable\": true\n          },\n          \"max_tokens\": {\n            \"type\": \"integer\",\n            \"format\": \"int32\",\n            \"description\": \"The maximum number of tokens that can be generated in the chat completion.\",\n            \"default\": \"1024\",\n            \"example\": \"32\",\n            \"nullable\": true,\n            \"minimum\": 0\n          },\n          \"messages\": {\n            \"type\": \"array\",\n            \"items\": {\n              \"$ref\": \"#/components/schemas/Message\"\n            },\n            \"description\": \"A list of messages comprising the conversation so far.\",\n            \"example\": \"[{\\\"role\\\": \\\"user\\\", \\\"content\\\": \\\"What is Deep Learning?\\\"}]\"\n          },\n          \"model\": {\n            \"type\": \"string\",\n            \"description\": \"[UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.\",\n            \"example\": \"mistralai/Mistral-7B-Instruct-v0.2\",\n            \"nullable\": true\n          },\n          \"n\": {\n            \"type\": \"integer\",\n            \"format\": \"int32\",\n            \"description\": \"UNUSED\\nHow many chat completion choices to generate for each input message. Note that you will be charged based on the\\nnumber of generated tokens across all of the choices. Keep n as 1 to minimize costs.\",\n            \"example\": \"2\",\n            \"nullable\": true,\n            \"minimum\": 0\n          },\n          \"presence_penalty\": {\n            \"type\": \"number\",\n            \"format\": \"float\",\n            \"description\": \"Number between -2.0 and 2.0. 
Positive values penalize new tokens based on whether they appear in the text so far,\\nincreasing the model's likelihood to talk about new topics\",\n            \"example\": 0.1,\n            \"nullable\": true\n          },\n          \"response_format\": {\n            \"allOf\": [\n              {\n                \"$ref\": \"#/components/schemas/GrammarType\"\n              }\n            ],\n            \"default\": \"null\",\n            \"nullable\": true\n          },\n          \"seed\": {\n            \"type\": \"integer\",\n            \"format\": \"int64\",\n            \"example\": 42,\n            \"nullable\": true,\n            \"minimum\": 0\n          },\n          \"stop\": {\n            \"type\": \"array\",\n            \"items\": {\n              \"type\": \"string\"\n            },\n            \"description\": \"Up to 4 sequences where the API will stop generating further tokens.\",\n            \"example\": \"null\",\n            \"nullable\": true\n          },\n          \"stream\": {\n            \"type\": \"boolean\"\n          },\n          \"stream_options\": {\n            \"allOf\": [\n              {\n                \"$ref\": \"#/components/schemas/StreamOptions\"\n              }\n            ],\n            \"nullable\": true\n          },\n          \"temperature\": {\n            \"type\": \"number\",\n            \"format\": \"float\",\n            \"description\": \"What sampling temperature to use, between 0 and 2. 
Higher values like 0.8 will make the output more random, while\\nlower values like 0.2 will make it more focused and deterministic.\\n\\nWe generally recommend altering this or `top_p` but not both.\",\n            \"example\": 1.0,\n            \"nullable\": true\n          },\n          \"tool_choice\": {\n            \"allOf\": [\n              {\n                \"$ref\": \"#/components/schemas/ToolChoice\"\n              }\n            ],\n            \"default\": \"auto\",\n            \"nullable\": true\n          },\n          \"tool_prompt\": {\n            \"type\": \"string\",\n            \"description\": \"A prompt to be appended before the tools\",\n            \"example\": \"Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\",\n            \"nullable\": true\n          },\n          \"tools\": {\n            \"type\": \"array\",\n            \"items\": {\n              \"$ref\": \"#/components/schemas/Tool\"\n            },\n            \"description\": \"A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of\\nfunctions the model may generate JSON inputs for.\",\n            \"example\": \"null\",\n            \"nullable\": true\n          },\n          \"top_logprobs\": {\n            \"type\": \"integer\",\n            \"format\": \"int32\",\n            \"description\": \"An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with\\nan associated log probability. 
logprobs must be set to true if this parameter is used.\",\n            \"example\": \"5\",\n            \"nullable\": true,\n            \"minimum\": 0\n          },\n          \"top_p\": {\n            \"type\": \"number\",\n            \"format\": \"float\",\n            \"description\": \"An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\",\n            \"example\": 0.95,\n            \"nullable\": true\n          }\n        }\n      },\n      \"ChatTokenizeResponse\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"tokenize_response\",\n          \"templated_text\"\n        ],\n        \"properties\": {\n          \"templated_text\": {\n            \"type\": \"string\"\n          },\n          \"tokenize_response\": {\n            \"$ref\": \"#/components/schemas/TokenizeResponse\"\n          }\n        }\n      },\n      \"Chunk\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"id\",\n          \"created\",\n          \"choices\",\n          \"model\",\n          \"system_fingerprint\"\n        ],\n        \"properties\": {\n          \"choices\": {\n            \"type\": \"array\",\n            \"items\": {\n              \"$ref\": \"#/components/schemas/CompletionComplete\"\n            }\n          },\n          \"created\": {\n            \"type\": \"integer\",\n            \"format\": \"int64\",\n            \"minimum\": 0\n          },\n          \"id\": {\n            \"type\": \"string\"\n          },\n          \"model\": {\n            \"type\": \"string\"\n          },\n          \"system_fingerprint\": {\n            \"type\": \"string\"\n          }\n        }\n      },\n      \"CompatGenerateRequest\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"inputs\"\n        ],\n        
\"properties\": {\n          \"inputs\": {\n            \"type\": \"string\",\n            \"example\": \"My name is Olivier and I\"\n          },\n          \"parameters\": {\n            \"$ref\": \"#/components/schemas/GenerateParameters\"\n          },\n          \"stream\": {\n            \"type\": \"boolean\",\n            \"default\": \"false\"\n          }\n        }\n      },\n      \"Completion\": {\n        \"oneOf\": [\n          {\n            \"allOf\": [\n              {\n                \"$ref\": \"#/components/schemas/Chunk\"\n              },\n              {\n                \"type\": \"object\",\n                \"required\": [\n                  \"object\"\n                ],\n                \"properties\": {\n                  \"object\": {\n                    \"type\": \"string\",\n                    \"enum\": [\n                      \"text_completion\"\n                    ]\n                  }\n                }\n              }\n            ]\n          },\n          {\n            \"allOf\": [\n              {\n                \"$ref\": \"#/components/schemas/CompletionFinal\"\n              },\n              {\n                \"type\": \"object\",\n                \"required\": [\n                  \"object\"\n                ],\n                \"properties\": {\n                  \"object\": {\n                    \"type\": \"string\",\n                    \"enum\": [\n                      \"text_completion\"\n                    ]\n                  }\n                }\n              }\n            ]\n          }\n        ],\n        \"discriminator\": {\n          \"propertyName\": \"object\"\n        }\n      },\n      \"CompletionComplete\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"index\",\n          \"text\",\n          \"finish_reason\"\n        ],\n        \"properties\": {\n          \"finish_reason\": {\n            \"type\": \"string\"\n          },\n          \"index\": {\n            
\"type\": \"integer\",\n            \"format\": \"int32\",\n            \"minimum\": 0\n          },\n          \"logprobs\": {\n            \"type\": \"array\",\n            \"items\": {\n              \"type\": \"number\",\n              \"format\": \"float\"\n            },\n            \"nullable\": true\n          },\n          \"text\": {\n            \"type\": \"string\"\n          }\n        }\n      },\n      \"CompletionFinal\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"id\",\n          \"created\",\n          \"model\",\n          \"system_fingerprint\",\n          \"choices\",\n          \"usage\"\n        ],\n        \"properties\": {\n          \"choices\": {\n            \"type\": \"array\",\n            \"items\": {\n              \"$ref\": \"#/components/schemas/CompletionComplete\"\n            }\n          },\n          \"created\": {\n            \"type\": \"integer\",\n            \"format\": \"int64\",\n            \"example\": \"1706270835\",\n            \"minimum\": 0\n          },\n          \"id\": {\n            \"type\": \"string\"\n          },\n          \"model\": {\n            \"type\": \"string\",\n            \"example\": \"mistralai/Mistral-7B-Instruct-v0.2\"\n          },\n          \"system_fingerprint\": {\n            \"type\": \"string\"\n          },\n          \"usage\": {\n            \"$ref\": \"#/components/schemas/Usage\"\n          }\n        }\n      },\n      \"CompletionRequest\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"prompt\"\n        ],\n        \"properties\": {\n          \"frequency_penalty\": {\n            \"type\": \"number\",\n            \"format\": \"float\",\n            \"description\": \"Number between -2.0 and 2.0. 
Positive values penalize new tokens based on their existing frequency in the text so far,\\ndecreasing the model's likelihood to repeat the same line verbatim.\",\n            \"example\": \"1.0\",\n            \"nullable\": true\n          },\n          \"max_tokens\": {\n            \"type\": \"integer\",\n            \"format\": \"int32\",\n            \"description\": \"The maximum number of tokens that can be generated in the chat completion.\",\n            \"default\": \"1024\",\n            \"example\": \"32\",\n            \"nullable\": true,\n            \"minimum\": 0\n          },\n          \"model\": {\n            \"type\": \"string\",\n            \"description\": \"UNUSED\\nID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.\",\n            \"example\": \"mistralai/Mistral-7B-Instruct-v0.2\",\n            \"nullable\": true\n          },\n          \"prompt\": {\n            \"$ref\": \"#/components/schemas/Prompt\"\n          },\n          \"repetition_penalty\": {\n            \"type\": \"number\",\n            \"format\": \"float\",\n            \"nullable\": true\n          },\n          \"seed\": {\n            \"type\": \"integer\",\n            \"format\": \"int64\",\n            \"example\": 42,\n            \"nullable\": true,\n            \"minimum\": 0\n          },\n          \"stop\": {\n            \"type\": \"array\",\n            \"items\": {\n              \"type\": \"string\"\n            },\n            \"description\": \"Up to 4 sequences where the API will stop generating further tokens.\",\n            \"example\": \"null\",\n            \"nullable\": true\n          },\n          \"stream\": {\n            \"type\": \"boolean\"\n          },\n          \"suffix\": {\n            \"type\": \"string\",\n            \"description\": \"The text to append to the prompt. 
This is useful for completing sentences or generating a paragraph of text.\\nPlease see the completion_template field in the model's tokenizer_config.json file for the completion template.\",\n            \"nullable\": true\n          },\n          \"temperature\": {\n            \"type\": \"number\",\n            \"format\": \"float\",\n            \"description\": \"What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while\\nlower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or `top_p` but not both.\",\n            \"example\": 1.0,\n            \"nullable\": true\n          },\n          \"top_p\": {\n            \"type\": \"number\",\n            \"format\": \"float\",\n            \"description\": \"An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\",\n            \"example\": 0.95,\n            \"nullable\": true\n          }\n        }\n      },\n      \"DeltaToolCall\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"index\",\n          \"id\",\n          \"type\",\n          \"function\"\n        ],\n        \"properties\": {\n          \"function\": {\n            \"$ref\": \"#/components/schemas/Function\"\n          },\n          \"id\": {\n            \"type\": \"string\"\n          },\n          \"index\": {\n            \"type\": \"integer\",\n            \"format\": \"int32\",\n            \"minimum\": 0\n          },\n          \"type\": {\n            \"type\": \"string\"\n          }\n        }\n      },\n      \"Details\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"finish_reason\",\n          \"generated_tokens\",\n          \"prefill\",\n          \"tokens\"\n        ],\n        \"properties\": {\n          
\"best_of_sequences\": {\n            \"type\": \"array\",\n            \"items\": {\n              \"$ref\": \"#/components/schemas/BestOfSequence\"\n            },\n            \"nullable\": true\n          },\n          \"finish_reason\": {\n            \"$ref\": \"#/components/schemas/FinishReason\"\n          },\n          \"generated_tokens\": {\n            \"type\": \"integer\",\n            \"format\": \"int32\",\n            \"example\": 1,\n            \"minimum\": 0\n          },\n          \"prefill\": {\n            \"type\": \"array\",\n            \"items\": {\n              \"$ref\": \"#/components/schemas/PrefillToken\"\n            }\n          },\n          \"seed\": {\n            \"type\": \"integer\",\n            \"format\": \"int64\",\n            \"example\": 42,\n            \"nullable\": true,\n            \"minimum\": 0\n          },\n          \"tokens\": {\n            \"type\": \"array\",\n            \"items\": {\n              \"$ref\": \"#/components/schemas/Token\"\n            }\n          },\n          \"top_tokens\": {\n            \"type\": \"array\",\n            \"items\": {\n              \"type\": \"array\",\n              \"items\": {\n                \"$ref\": \"#/components/schemas/Token\"\n              }\n            }\n          }\n        }\n      },\n      \"ErrorResponse\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"error\",\n          \"error_type\"\n        ],\n        \"properties\": {\n          \"error\": {\n            \"type\": \"string\"\n          },\n          \"error_type\": {\n            \"type\": \"string\"\n          }\n        }\n      },\n      \"FinishReason\": {\n        \"type\": \"string\",\n        \"enum\": [\n          \"length\",\n          \"eos_token\",\n          \"stop_sequence\"\n        ],\n        \"example\": \"Length\"\n      },\n      \"Function\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"arguments\"\n        ],\n        
\"properties\": {\n          \"arguments\": {\n            \"type\": \"string\"\n          },\n          \"name\": {\n            \"type\": \"string\",\n            \"nullable\": true\n          }\n        }\n      },\n      \"FunctionDefinition\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"name\",\n          \"arguments\"\n        ],\n        \"properties\": {\n          \"arguments\": {},\n          \"description\": {\n            \"type\": \"string\",\n            \"nullable\": true\n          },\n          \"name\": {\n            \"type\": \"string\"\n          }\n        }\n      },\n      \"FunctionName\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"name\"\n        ],\n        \"properties\": {\n          \"name\": {\n            \"type\": \"string\"\n          }\n        }\n      },\n      \"GenerateParameters\": {\n        \"type\": \"object\",\n        \"properties\": {\n          \"adapter_id\": {\n            \"type\": \"string\",\n            \"description\": \"Lora adapter id\",\n            \"default\": \"null\",\n            \"example\": \"null\",\n            \"nullable\": true\n          },\n          \"best_of\": {\n            \"type\": \"integer\",\n            \"description\": \"Generate best_of sequences and return the one with the highest token logprobs.\",\n            \"default\": \"null\",\n            \"example\": 1,\n            \"nullable\": true,\n            \"minimum\": 0,\n            \"exclusiveMinimum\": 0\n          },\n          \"decoder_input_details\": {\n            \"type\": \"boolean\",\n            \"description\": \"Whether to return decoder input token logprobs and ids.\",\n            \"default\": \"false\"\n          },\n          \"details\": {\n            \"type\": \"boolean\",\n            \"description\": \"Whether to return generation details.\",\n            \"default\": \"true\"\n          },\n          \"do_sample\": {\n            \"type\": \"boolean\",\n 
           \"description\": \"Activate logits sampling.\",\n            \"default\": \"false\",\n            \"example\": true\n          },\n          \"frequency_penalty\": {\n            \"type\": \"number\",\n            \"format\": \"float\",\n            \"description\": \"The parameter for frequency penalty. 1.0 means no penalty\\nPenalize new tokens based on their existing frequency in the text so far,\\ndecreasing the model's likelihood to repeat the same line verbatim.\",\n            \"default\": \"null\",\n            \"example\": 0.1,\n            \"nullable\": true,\n            \"exclusiveMinimum\": -2\n          },\n          \"grammar\": {\n            \"allOf\": [\n              {\n                \"$ref\": \"#/components/schemas/GrammarType\"\n              }\n            ],\n            \"default\": \"null\",\n            \"nullable\": true\n          },\n          \"max_new_tokens\": {\n            \"type\": \"integer\",\n            \"format\": \"int32\",\n            \"description\": \"Maximum number of tokens to generate.\",\n            \"default\": \"1024\",\n            \"example\": \"20\",\n            \"nullable\": true,\n            \"minimum\": 0\n          },\n          \"repetition_penalty\": {\n            \"type\": \"number\",\n            \"format\": \"float\",\n            \"description\": \"The parameter for repetition penalty. 
1.0 means no penalty.\\nSee [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.\",\n            \"default\": \"null\",\n            \"example\": 1.03,\n            \"nullable\": true,\n            \"exclusiveMinimum\": 0\n          },\n          \"return_full_text\": {\n            \"type\": \"boolean\",\n            \"description\": \"Whether to prepend the prompt to the generated text\",\n            \"default\": \"null\",\n            \"example\": false,\n            \"nullable\": true\n          },\n          \"seed\": {\n            \"type\": \"integer\",\n            \"format\": \"int64\",\n            \"description\": \"Random sampling seed.\",\n            \"default\": \"null\",\n            \"example\": \"null\",\n            \"nullable\": true,\n            \"minimum\": 0,\n            \"exclusiveMinimum\": 0\n          },\n          \"stop\": {\n            \"type\": \"array\",\n            \"items\": {\n              \"type\": \"string\"\n            },\n            \"description\": \"Stop generating tokens if a member of `stop` is generated.\",\n            \"example\": [\n              \"photographer\"\n            ],\n            \"maxItems\": 4\n          },\n          \"temperature\": {\n            \"type\": \"number\",\n            \"format\": \"float\",\n            \"description\": \"The value used to modulate the logits distribution.\",\n            \"default\": \"null\",\n            \"example\": 0.5,\n            \"nullable\": true,\n            \"exclusiveMinimum\": 0\n          },\n          \"top_k\": {\n            \"type\": \"integer\",\n            \"format\": \"int32\",\n            \"description\": \"The number of highest probability vocabulary tokens to keep for top-k-filtering.\",\n            \"default\": \"null\",\n            \"example\": 10,\n            \"nullable\": true,\n            \"exclusiveMinimum\": 0\n          },\n          \"top_n_tokens\": {\n            \"type\": \"integer\",\n            
\"format\": \"int32\",\n            \"description\": \"The number of highest probability vocabulary tokens to keep for top-n-filtering.\",\n            \"default\": \"null\",\n            \"example\": 5,\n            \"nullable\": true,\n            \"minimum\": 0,\n            \"exclusiveMinimum\": 0\n          },\n          \"top_p\": {\n            \"type\": \"number\",\n            \"format\": \"float\",\n            \"description\": \"Top-p value for nucleus sampling.\",\n            \"default\": \"null\",\n            \"example\": 0.95,\n            \"nullable\": true,\n            \"maximum\": 1,\n            \"exclusiveMinimum\": 0\n          },\n          \"truncate\": {\n            \"type\": \"integer\",\n            \"description\": \"Truncate inputs tokens to the given size.\",\n            \"default\": \"null\",\n            \"example\": \"null\",\n            \"nullable\": true,\n            \"minimum\": 0\n          },\n          \"typical_p\": {\n            \"type\": \"number\",\n            \"format\": \"float\",\n            \"description\": \"Typical Decoding mass\\nSee [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.\",\n            \"default\": \"null\",\n            \"example\": 0.95,\n            \"nullable\": true,\n            \"maximum\": 1,\n            \"exclusiveMinimum\": 0\n          },\n          \"watermark\": {\n            \"type\": \"boolean\",\n            \"description\": \"Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226).\",\n            \"default\": \"false\",\n            \"example\": true\n          }\n        }\n      },\n      \"GenerateRequest\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"inputs\"\n        ],\n        \"properties\": {\n          \"inputs\": {\n            \"type\": \"string\",\n            \"example\": \"My name is Olivier and I\"\n          },\n          \"parameters\": {\n 
           \"$ref\": \"#/components/schemas/GenerateParameters\"\n          }\n        }\n      },\n      \"GenerateResponse\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"generated_text\"\n        ],\n        \"properties\": {\n          \"details\": {\n            \"allOf\": [\n              {\n                \"$ref\": \"#/components/schemas/Details\"\n              }\n            ],\n            \"nullable\": true\n          },\n          \"generated_text\": {\n            \"type\": \"string\",\n            \"example\": \"test\"\n          }\n        }\n      },\n      \"GrammarType\": {\n        \"oneOf\": [\n          {\n            \"type\": \"object\",\n            \"required\": [\n              \"type\",\n              \"value\"\n            ],\n            \"properties\": {\n              \"type\": {\n                \"type\": \"string\",\n                \"enum\": [\n                  \"json\"\n                ]\n              },\n              \"value\": {\n                \"description\": \"A string that represents a [JSON Schema](https://json-schema.org/).\\n\\nJSON Schema is a declarative language that allows to annotate JSON documents\\nwith types and descriptions.\"\n              }\n            }\n          },\n          {\n            \"type\": \"object\",\n            \"required\": [\n              \"type\",\n              \"value\"\n            ],\n            \"properties\": {\n              \"type\": {\n                \"type\": \"string\",\n                \"enum\": [\n                  \"regex\"\n                ]\n              },\n              \"value\": {\n                \"type\": \"string\"\n              }\n            }\n          },\n          {\n            \"type\": \"object\",\n            \"required\": [\n              \"type\",\n              \"value\"\n            ],\n            \"properties\": {\n              \"type\": {\n                \"type\": \"string\",\n                \"enum\": [\n        
          \"json_schema\"\n                ]\n              },\n              \"value\": {\n                \"$ref\": \"#/components/schemas/JsonSchemaConfig\"\n              }\n            }\n          }\n        ],\n        \"discriminator\": {\n          \"propertyName\": \"type\"\n        }\n      },\n      \"Info\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"model_id\",\n          \"max_concurrent_requests\",\n          \"max_best_of\",\n          \"max_stop_sequences\",\n          \"max_input_tokens\",\n          \"max_total_tokens\",\n          \"validation_workers\",\n          \"max_client_batch_size\",\n          \"router\",\n          \"version\"\n        ],\n        \"properties\": {\n          \"docker_label\": {\n            \"type\": \"string\",\n            \"example\": \"null\",\n            \"nullable\": true\n          },\n          \"max_best_of\": {\n            \"type\": \"integer\",\n            \"example\": \"2\",\n            \"minimum\": 0\n          },\n          \"max_client_batch_size\": {\n            \"type\": \"integer\",\n            \"example\": \"32\",\n            \"minimum\": 0\n          },\n          \"max_concurrent_requests\": {\n            \"type\": \"integer\",\n            \"description\": \"Router Parameters\",\n            \"example\": \"128\",\n            \"minimum\": 0\n          },\n          \"max_input_tokens\": {\n            \"type\": \"integer\",\n            \"example\": \"1024\",\n            \"minimum\": 0\n          },\n          \"max_stop_sequences\": {\n            \"type\": \"integer\",\n            \"example\": \"4\",\n            \"minimum\": 0\n          },\n          \"max_total_tokens\": {\n            \"type\": \"integer\",\n            \"example\": \"2048\",\n            \"minimum\": 0\n          },\n          \"model_id\": {\n            \"type\": \"string\",\n            \"description\": \"Model info\",\n            \"example\": \"bigscience/bloom-560m\"\n          
},\n          \"model_pipeline_tag\": {\n            \"type\": \"string\",\n            \"example\": \"text-generation\",\n            \"nullable\": true\n          },\n          \"model_sha\": {\n            \"type\": \"string\",\n            \"example\": \"e985a63cdc139290c5f700ff1929f0b5942cced2\",\n            \"nullable\": true\n          },\n          \"router\": {\n            \"type\": \"string\",\n            \"description\": \"Router Info\",\n            \"example\": \"text-generation-router\"\n          },\n          \"sha\": {\n            \"type\": \"string\",\n            \"example\": \"null\",\n            \"nullable\": true\n          },\n          \"validation_workers\": {\n            \"type\": \"integer\",\n            \"example\": \"2\",\n            \"minimum\": 0\n          },\n          \"version\": {\n            \"type\": \"string\",\n            \"example\": \"0.5.0\"\n          }\n        }\n      },\n      \"JsonSchemaConfig\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"schema\"\n        ],\n        \"properties\": {\n          \"name\": {\n            \"type\": \"string\",\n            \"description\": \"Optional name identifier for the schema\",\n            \"nullable\": true\n          },\n          \"schema\": {\n            \"description\": \"The actual JSON schema definition\"\n          }\n        }\n      },\n      \"Message\": {\n        \"allOf\": [\n          {\n            \"$ref\": \"#/components/schemas/MessageBody\"\n          },\n          {\n            \"type\": \"object\",\n            \"required\": [\n              \"role\"\n            ],\n            \"properties\": {\n              \"name\": {\n                \"type\": \"string\",\n                \"example\": \"\\\"David\\\"\",\n                \"nullable\": true\n              },\n              \"role\": {\n                \"type\": \"string\",\n                \"example\": \"user\"\n              }\n            }\n          }\n      
  ]\n      },\n      \"MessageBody\": {\n        \"oneOf\": [\n          {\n            \"type\": \"object\",\n            \"required\": [\n              \"content\"\n            ],\n            \"properties\": {\n              \"content\": {\n                \"$ref\": \"#/components/schemas/MessageContent\"\n              }\n            }\n          },\n          {\n            \"type\": \"object\",\n            \"required\": [\n              \"tool_calls\"\n            ],\n            \"properties\": {\n              \"tool_calls\": {\n                \"type\": \"array\",\n                \"items\": {\n                  \"$ref\": \"#/components/schemas/ToolCall\"\n                }\n              }\n            }\n          }\n        ]\n      },\n      \"MessageChunk\": {\n        \"oneOf\": [\n          {\n            \"type\": \"object\",\n            \"required\": [\n              \"text\",\n              \"type\"\n            ],\n            \"properties\": {\n              \"text\": {\n                \"type\": \"string\"\n              },\n              \"type\": {\n                \"type\": \"string\",\n                \"enum\": [\n                  \"text\"\n                ]\n              }\n            }\n          },\n          {\n            \"type\": \"object\",\n            \"required\": [\n              \"image_url\",\n              \"type\"\n            ],\n            \"properties\": {\n              \"image_url\": {\n                \"$ref\": \"#/components/schemas/Url\"\n              },\n              \"type\": {\n                \"type\": \"string\",\n                \"enum\": [\n                  \"image_url\"\n                ]\n              }\n            }\n          }\n        ],\n        \"discriminator\": {\n          \"propertyName\": \"type\"\n        }\n      },\n      \"MessageContent\": {\n        \"oneOf\": [\n          {\n            \"type\": \"string\"\n          },\n          {\n            \"type\": \"array\",\n           
 \"items\": {\n              \"$ref\": \"#/components/schemas/MessageChunk\"\n            }\n          }\n        ]\n      },\n      \"ModelInfo\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"id\",\n          \"object\",\n          \"created\",\n          \"owned_by\"\n        ],\n        \"properties\": {\n          \"created\": {\n            \"type\": \"integer\",\n            \"format\": \"int64\",\n            \"example\": 1686935002,\n            \"minimum\": 0\n          },\n          \"id\": {\n            \"type\": \"string\",\n            \"example\": \"gpt2\"\n          },\n          \"object\": {\n            \"type\": \"string\",\n            \"example\": \"model\"\n          },\n          \"owned_by\": {\n            \"type\": \"string\",\n            \"example\": \"openai\"\n          }\n        }\n      },\n      \"OutputMessage\": {\n        \"oneOf\": [\n          {\n            \"$ref\": \"#/components/schemas/TextMessage\"\n          },\n          {\n            \"$ref\": \"#/components/schemas/ToolCallMessage\"\n          }\n        ]\n      },\n      \"PrefillToken\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"id\",\n          \"text\",\n          \"logprob\"\n        ],\n        \"properties\": {\n          \"id\": {\n            \"type\": \"integer\",\n            \"format\": \"int32\",\n            \"example\": 0,\n            \"minimum\": 0\n          },\n          \"logprob\": {\n            \"type\": \"number\",\n            \"format\": \"float\",\n            \"example\": -0.34,\n            \"nullable\": true\n          },\n          \"text\": {\n            \"type\": \"string\",\n            \"example\": \"test\"\n          }\n        }\n      },\n      \"Prompt\": {\n        \"type\": \"array\",\n        \"items\": {\n          \"type\": \"string\"\n        }\n      },\n      \"SagemakerRequest\": {\n        \"oneOf\": [\n          {\n            \"$ref\": 
\"#/components/schemas/CompatGenerateRequest\"\n          },\n          {\n            \"$ref\": \"#/components/schemas/ChatRequest\"\n          },\n          {\n            \"$ref\": \"#/components/schemas/CompletionRequest\"\n          }\n        ]\n      },\n      \"SagemakerResponse\": {\n        \"oneOf\": [\n          {\n            \"$ref\": \"#/components/schemas/GenerateResponse\"\n          },\n          {\n            \"$ref\": \"#/components/schemas/ChatCompletion\"\n          },\n          {\n            \"$ref\": \"#/components/schemas/CompletionFinal\"\n          }\n        ]\n      },\n      \"SagemakerStreamResponse\": {\n        \"oneOf\": [\n          {\n            \"$ref\": \"#/components/schemas/StreamResponse\"\n          },\n          {\n            \"$ref\": \"#/components/schemas/ChatCompletionChunk\"\n          },\n          {\n            \"$ref\": \"#/components/schemas/Chunk\"\n          }\n        ]\n      },\n      \"SimpleToken\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"id\",\n          \"text\",\n          \"start\",\n          \"stop\"\n        ],\n        \"properties\": {\n          \"id\": {\n            \"type\": \"integer\",\n            \"format\": \"int32\",\n            \"example\": 0,\n            \"minimum\": 0\n          },\n          \"start\": {\n            \"type\": \"integer\",\n            \"example\": 0,\n            \"minimum\": 0\n          },\n          \"stop\": {\n            \"type\": \"integer\",\n            \"example\": 2,\n            \"minimum\": 0\n          },\n          \"text\": {\n            \"type\": \"string\",\n            \"example\": \"test\"\n          }\n        }\n      },\n      \"StreamDetails\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"finish_reason\",\n          \"generated_tokens\",\n          \"input_length\"\n        ],\n        \"properties\": {\n          \"finish_reason\": {\n            \"$ref\": 
\"#/components/schemas/FinishReason\"\n          },\n          \"generated_tokens\": {\n            \"type\": \"integer\",\n            \"format\": \"int32\",\n            \"example\": 1,\n            \"minimum\": 0\n          },\n          \"input_length\": {\n            \"type\": \"integer\",\n            \"format\": \"int32\",\n            \"example\": 1,\n            \"minimum\": 0\n          },\n          \"seed\": {\n            \"type\": \"integer\",\n            \"format\": \"int64\",\n            \"example\": 42,\n            \"nullable\": true,\n            \"minimum\": 0\n          }\n        }\n      },\n      \"StreamOptions\": {\n        \"type\": \"object\",\n        \"properties\": {\n          \"include_usage\": {\n            \"type\": \"boolean\",\n            \"description\": \"If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. 
All other chunks will also include a usage field, but with a null value.\",\n            \"example\": \"true\"\n          }\n        }\n      },\n      \"StreamResponse\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"index\",\n          \"token\"\n        ],\n        \"properties\": {\n          \"details\": {\n            \"allOf\": [\n              {\n                \"$ref\": \"#/components/schemas/StreamDetails\"\n              }\n            ],\n            \"default\": \"null\",\n            \"nullable\": true\n          },\n          \"generated_text\": {\n            \"type\": \"string\",\n            \"default\": \"null\",\n            \"example\": \"test\",\n            \"nullable\": true\n          },\n          \"index\": {\n            \"type\": \"integer\",\n            \"format\": \"int32\",\n            \"minimum\": 0\n          },\n          \"token\": {\n            \"$ref\": \"#/components/schemas/Token\"\n          },\n          \"top_tokens\": {\n            \"type\": \"array\",\n            \"items\": {\n              \"$ref\": \"#/components/schemas/Token\"\n            }\n          }\n        }\n      },\n      \"TextMessage\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"role\",\n          \"content\"\n        ],\n        \"properties\": {\n          \"content\": {\n            \"type\": \"string\",\n            \"example\": \"My name is David and I\"\n          },\n          \"role\": {\n            \"type\": \"string\",\n            \"example\": \"user\"\n          },\n          \"tool_call_id\": {\n            \"type\": \"string\",\n            \"nullable\": true\n          }\n        }\n      },\n      \"Token\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"id\",\n          \"text\",\n          \"logprob\",\n          \"special\"\n        ],\n        \"properties\": {\n          \"id\": {\n            \"type\": \"integer\",\n            \"format\": 
\"int32\",\n            \"example\": 0,\n            \"minimum\": 0\n          },\n          \"logprob\": {\n            \"type\": \"number\",\n            \"format\": \"float\",\n            \"example\": -0.34,\n            \"nullable\": true\n          },\n          \"special\": {\n            \"type\": \"boolean\",\n            \"example\": \"false\"\n          },\n          \"text\": {\n            \"type\": \"string\",\n            \"example\": \"test\"\n          }\n        }\n      },\n      \"TokenizeResponse\": {\n        \"type\": \"array\",\n        \"items\": {\n          \"$ref\": \"#/components/schemas/SimpleToken\"\n        }\n      },\n      \"Tool\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"type\",\n          \"function\"\n        ],\n        \"properties\": {\n          \"function\": {\n            \"$ref\": \"#/components/schemas/FunctionDefinition\"\n          },\n          \"type\": {\n            \"type\": \"string\",\n            \"example\": \"function\"\n          }\n        }\n      },\n      \"ToolCall\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"id\",\n          \"type\",\n          \"function\"\n        ],\n        \"properties\": {\n          \"function\": {\n            \"$ref\": \"#/components/schemas/FunctionDefinition\"\n          },\n          \"id\": {\n            \"type\": \"string\"\n          },\n          \"type\": {\n            \"type\": \"string\"\n          }\n        }\n      },\n      \"ToolCallDelta\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"role\",\n          \"tool_calls\"\n        ],\n        \"properties\": {\n          \"role\": {\n            \"type\": \"string\",\n            \"example\": \"assistant\"\n          },\n          \"tool_calls\": {\n            \"type\": \"array\",\n            \"items\": {\n              \"$ref\": \"#/components/schemas/DeltaToolCall\"\n            }\n          }\n        }\n      },\n      
\"ToolCallMessage\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"role\",\n          \"tool_calls\"\n        ],\n        \"properties\": {\n          \"role\": {\n            \"type\": \"string\",\n            \"example\": \"assistant\"\n          },\n          \"tool_calls\": {\n            \"type\": \"array\",\n            \"items\": {\n              \"$ref\": \"#/components/schemas/ToolCall\"\n            }\n          }\n        }\n      },\n      \"ToolChoice\": {\n        \"oneOf\": [\n          {\n            \"type\": \"string\",\n            \"description\": \"Means the model can pick between generating a message or calling one or more tools.\",\n            \"enum\": [\n              \"auto\"\n            ]\n          },\n          {\n            \"type\": \"string\",\n            \"description\": \"Means the model will not call any tool and instead generates a message.\",\n            \"enum\": [\n              \"none\"\n            ]\n          },\n          {\n            \"type\": \"string\",\n            \"description\": \"Means the model must call one or more tools.\",\n            \"enum\": [\n              \"required\"\n            ]\n          },\n          {\n            \"type\": \"object\",\n            \"required\": [\n              \"function\"\n            ],\n            \"properties\": {\n              \"function\": {\n                \"$ref\": \"#/components/schemas/FunctionName\"\n              }\n            }\n          }\n        ],\n        \"description\": \"<https://platform.openai.com/docs/guides/function-calling/configuring-function-calling-behavior-using-the-tool_choice-parameter>\"\n      },\n      \"Url\": {\n        \"type\": \"object\",\n        \"required\": [\n          \"url\"\n        ],\n        \"properties\": {\n          \"url\": {\n            \"type\": \"string\"\n          }\n        }\n      },\n      \"Usage\": {\n        \"type\": \"object\",\n        \"required\": [\n          
\"prompt_tokens\",\n          \"completion_tokens\",\n          \"total_tokens\"\n        ],\n        \"properties\": {\n          \"completion_tokens\": {\n            \"type\": \"integer\",\n            \"format\": \"int32\",\n            \"minimum\": 0\n          },\n          \"prompt_tokens\": {\n            \"type\": \"integer\",\n            \"format\": \"int32\",\n            \"minimum\": 0\n          },\n          \"total_tokens\": {\n            \"type\": \"integer\",\n            \"format\": \"int32\",\n            \"minimum\": 0\n          }\n        }\n      }\n    }\n  },\n  \"tags\": [\n    {\n      \"name\": \"Text Generation Inference\",\n      \"description\": \"Hugging Face Text Generation Inference API\"\n    }\n  ]\n}\n"
  },
  {
    "path": "docs/source/_toctree.yml",
    "content": "- sections:\n  - local: index\n    title: Text Generation Inference\n  - local: quicktour\n    title: Quick Tour\n  - local: supported_models\n    title: Supported Models\n  - local: installation_nvidia\n    title: Using TGI with Nvidia GPUs\n  - local: installation_amd\n    title: Using TGI with AMD GPUs\n  - local: installation_gaudi\n    title: Using TGI with Intel Gaudi\n  - local: installation_inferentia\n    title: Using TGI with AWS Trainium and Inferentia\n  - local: installation_tpu\n    title: Using TGI with Google TPUs\n  - local: installation_intel\n    title: Using TGI with Intel GPUs\n  - local: installation\n    title: Installation from source\n  - local: multi_backend_support\n    title: Multi-backend support\n\n  - local: architecture\n    title: Internal Architecture\n  - local: usage_statistics\n    title: Usage Statistics\n  title: Getting started\n- sections:\n  - local: basic_tutorials/consuming_tgi\n    title: Consuming TGI\n  - local: basic_tutorials/preparing_model\n    title: Preparing Model for Serving\n  - local: basic_tutorials/gated_model_access\n    title: Serving Private & Gated Models\n  - local: basic_tutorials/using_cli\n    title: Using TGI CLI\n  - local: basic_tutorials/non_core_models\n    title: Non-core Model Serving\n  - local: basic_tutorials/safety\n    title: Safety\n  - local: basic_tutorials/using_guidance\n    title: Using Guidance, JSON, tools\n  - local: basic_tutorials/visual_language_models\n    title: Visual Language Models\n  - local: basic_tutorials/monitoring\n    title: Monitoring TGI with Prometheus and Grafana\n  - local: basic_tutorials/train_medusa\n    title: Train Medusa\n  title: Tutorials\n- sections:\n  - local: backends/neuron\n    title: Neuron\n  - local: backends/gaudi\n    title: Gaudi\n  - local: backends/trtllm\n    title: TensorRT-LLM\n  - local: backends/llamacpp\n    title: Llamacpp\n  title: Backends\n- sections:\n  - local: reference/launcher\n    title: All TGI CLI 
options\n  - local: reference/metrics\n    title: Exported Metrics\n  - local: reference/api_reference\n    title: API Reference\n  title: Reference\n- sections:\n  - local: conceptual/chunking\n    title: V3 update, caching and chunking\n  - local: conceptual/streaming\n    title: Streaming\n  - local: conceptual/quantization\n    title: Quantization\n  - local: conceptual/tensor_parallelism\n    title: Tensor Parallelism\n  - local: conceptual/paged_attention\n    title: PagedAttention\n  - local: conceptual/safetensors\n    title: Safetensors\n  - local: conceptual/flash_attention\n    title: Flash Attention\n  - local: conceptual/speculation\n    title: Speculation (Medusa, ngram)\n  - local: conceptual/guidance\n    title: How Guidance Works (via outlines)\n  - local: conceptual/lora\n    title: LoRA (Low-Rank Adaptation)\n  - local: conceptual/external\n    title: External Resources\n\n\n  title: Conceptual Guides\n"
  },
  {
    "path": "docs/source/architecture.md",
    "content": "# Text Generation Inference Architecture\n\nThis document aims at describing the architecture of Text Generation Inference (TGI), by describing the call flow between the separate components.\n\nA high-level architecture diagram can be seen here:\n\n![TGI architecture](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/TGI.png)\n\nThis diagram shows the following separate components:\n\n- **The router**, also named `webserver`, that receives the client requests, buffers them, creates some batches, and prepares gRPC calls to a model server.\n- **The launcher** is a helper that will be able to launch one or several model servers (if the model is sharded), and it launches the router with the compatible arguments.\n- **The model server**, responsible for receiving the gRPC requests and running inference on the model. If the model is sharded across multiple accelerators (e.g.: multiple GPUs), the model server shards might be synchronized via NCCL or equivalent.\n\nNote that for other backends (e.g. TRTLLM) the model server and launcher are specific to the backend.\n\nThe router and the model server can run on two different machines; they do not need to be deployed together.\n\n## The Router\n\nThis component is a Rust web server binary that accepts HTTP requests using the custom [HTTP API](https://huggingface.github.io/text-generation-inference/), as well as OpenAI's [Messages API](https://huggingface.co/docs/text-generation-inference/messages_api).\nThe router receives the API calls and handles the \"batches\" logic (an introduction to batching can be found [here](https://github.com/huggingface/text-generation-inference/blob/main/router/README.md)).\nIt uses different strategies to reduce latency between requests and responses, especially oriented to decoding latency. 
It uses queues, schedulers, and block allocators to achieve that and to produce batched requests that are then sent to the model server.\n\n### Router's command line\n\nThe router's command line is the way to pass parameters to it (it does not rely on a configuration file):\n\n```\nText Generation Webserver\n\nUsage: text-generation-router [OPTIONS]\n\nOptions:\n      --max-concurrent-requests <MAX_CONCURRENT_REQUESTS>\n          [env: MAX_CONCURRENT_REQUESTS=] [default: 128]\n      --max-best-of <MAX_BEST_OF>\n          [env: MAX_BEST_OF=] [default: 2]\n      --max-stop-sequences <MAX_STOP_SEQUENCES>\n          [env: MAX_STOP_SEQUENCES=] [default: 4]\n      --max-top-n-tokens <MAX_TOP_N_TOKENS>\n          [env: MAX_TOP_N_TOKENS=] [default: 5]\n      --max-input-tokens <MAX_INPUT_TOKENS>\n          [env: MAX_INPUT_TOKENS=] [default: 1024]\n      --max-total-tokens <MAX_TOTAL_TOKENS>\n          [env: MAX_TOTAL_TOKENS=] [default: 2048]\n      --waiting-served-ratio <WAITING_SERVED_RATIO>\n          [env: WAITING_SERVED_RATIO=] [default: 1.2]\n      --max-batch-prefill-tokens <MAX_BATCH_PREFILL_TOKENS>\n          [env: MAX_BATCH_PREFILL_TOKENS=] [default: 4096]\n      --max-batch-total-tokens <MAX_BATCH_TOTAL_TOKENS>\n          [env: MAX_BATCH_TOTAL_TOKENS=]\n      --max-waiting-tokens <MAX_WAITING_TOKENS>\n          [env: MAX_WAITING_TOKENS=] [default: 20]\n      --max-batch-size <MAX_BATCH_SIZE>\n          [env: MAX_BATCH_SIZE=]\n      --hostname <HOSTNAME>\n          [env: HOSTNAME=] [default: 0.0.0.0]\n  -p, --port <PORT>\n          [env: PORT=] [default: 3000]\n      --master-shard-uds-path <MASTER_SHARD_UDS_PATH>\n          [env: MASTER_SHARD_UDS_PATH=] [default: /tmp/text-generation-server-0]\n      --tokenizer-name <TOKENIZER_NAME>\n          [env: TOKENIZER_NAME=] [default: bigscience/bloom]\n      --tokenizer-config-path <TOKENIZER_CONFIG_PATH>\n          [env: TOKENIZER_CONFIG_PATH=]\n      --revision <REVISION>\n          [env: REVISION=]\n      
--validation-workers <VALIDATION_WORKERS>\n          [env: VALIDATION_WORKERS=] [default: 2]\n      --json-output\n          [env: JSON_OUTPUT=]\n      --otlp-endpoint <OTLP_ENDPOINT>\n          [env: OTLP_ENDPOINT=]\n      --otlp-service-name <OTLP_SERVICE_NAME>\n          [env: OTLP_SERVICE_NAME=]\n      --cors-allow-origin <CORS_ALLOW_ORIGIN>\n          [env: CORS_ALLOW_ORIGIN=]\n      --ngrok\n          [env: NGROK=]\n      --ngrok-authtoken <NGROK_AUTHTOKEN>\n          [env: NGROK_AUTHTOKEN=]\n      --ngrok-edge <NGROK_EDGE>\n          [env: NGROK_EDGE=]\n      --messages-api-enabled\n          [env: MESSAGES_API_ENABLED=]\n      --disable-grammar-support\n          [env: DISABLE_GRAMMAR_SUPPORT=]\n      --max-client-batch-size <MAX_CLIENT_BATCH_SIZE>\n          [env: MAX_CLIENT_BATCH_SIZE=] [default: 4]\n  -h, --help\n          Print help\n  -V, --version\n          Print version\n```\n\n## The Model Server\n\nThe model server is a python server, capable of starting a server waiting for gRPC requests, loads a given model, perform sharding to provide [tensor parallelism](https://huggingface.co/docs/text-generation-inference/conceptual/tensor_parallelism), and stays alive while waiting for new requests.\nThe model server supports models instantiated using Pytorch and optimized for inference mainly on CUDA/ROCM.\n\n### Model Server Variants\n\nSeveral variants of the model server exist that are actively supported by Hugging Face:\n\n- By default, the model server will attempt building [a server optimized for Nvidia GPUs with CUDA](https://huggingface.co/docs/text-generation-inference/installation_nvidia). The code for this version is hosted in the [main TGI repository](https://github.com/huggingface/text-generation-inference).\n- A [version optimized for AMD with ROCm](https://huggingface.co/docs/text-generation-inference/installation_amd) is hosted in the main TGI repository. 
Some model features differ.\n- A [version optimized for Intel GPUs](https://huggingface.co/docs/text-generation-inference/installation_intel) is hosted in the main TGI repository. Some model features differ.\n- The [version for Intel Gaudi](https://huggingface.co/docs/text-generation-inference/installation_gaudi) is maintained on a forked repository, often resynchronized with the main [TGI repository](https://github.com/huggingface/tgi-gaudi).\n- A [version for Neuron (AWS Inferentia2)](https://huggingface.co/docs/text-generation-inference/installation_inferentia) is maintained in the main TGI repository. Some model features differ.\n- A version for Google TPUs is maintained as part of [Optimum TPU](https://github.com/huggingface/optimum-tpu/tree/main/text-generation-inference).\n\nNot all variants provide the same features, as hardware and middleware capabilities do not provide the same optimizations.\n\n### Command Line Interface\n\nThe official command line interface (CLI) for the server supports three subcommands, `download-weights`, `quantize` and `serve`:\n\n- `download-weights` will download weights from the hub and, in some variants, it will convert weights to a format that is adapted to the given implementation;\n- `quantize` allows quantizing a model using the `gptq` package. 
This feature is not available nor supported on all variants;\n- `serve` will start the server that load a model (or a model shard), receives gRPC calls from the router, performs an inference and provides a formatted response to the given request.\n\nServe's command line parameters on the TGI repository are these:\n\n```\n Usage: cli.py serve [OPTIONS] MODEL_ID\n\n╭─ Arguments ──────────────────────────────────────────────────────────────────────────────────────────────╮\n│ *    model_id      TEXT  [default: None] [required]                                                      │\n╰──────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n╭─ Options ────────────────────────────────────────────────────────────────────────────────────────────────╮\n│ --revision                                       TEXT                        [default: None]             │\n│ --sharded              --no-sharded                                          [default: no-sharded]       │\n│ --quantize                                       [bitsandbytes|bitsandbytes  [default: None]             │\n│                                                  -nf4|bitsandbytes-fp4|gptq                              │\n│                                                  |awq|eetq|exl2|fp8]                                     │\n│ --speculate                                      INTEGER                     [default: None]             │\n│ --dtype                                          [float16|bfloat16]          [default: None]             │\n│ --trust-remote-code    --no-trust-remote-code                                [default:                   │\n│                                                                              no-trust-remote-code]       │\n│ --uds-path                                       PATH                        [default:                   │\n│                                                                              
/tmp/text-generation-serve… │\n│ --logger-level                                   TEXT                        [default: INFO]             │\n│ --json-output          --no-json-output                                      [default: no-json-output]   │\n│ --otlp-endpoint                                  TEXT                        [default: None]             │\n│ --otlp-service-name                              TEXT                        [default:                   │\n│                                                                              text-generation-inference...│\n│ --help                                                                       Show this message and exit. │\n╰──────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n```\n\nNote that some variants might support different parameters, and they could possibly accept more options that can be passed on using environment variables.\n\n## Call Flow\n\nOnce both components are initialized, weights downloaded and model server is up and running, router and model server exchange data and info through the gRPC call. There are currently two supported schemas, [v2](https://github.com/huggingface/text-generation-inference/blob/main/proto/generate.proto) and [v3](https://github.com/huggingface/text-generation-inference/blob/main/proto/v3/generate.proto). 
These two versions are almost identical, except for:\n\n- input chunks support, for text and image data,\n- paged attention support\n\nHere's a diagram that displays the exchanges that follow the router and model server startup.\n\n```mermaid\nsequenceDiagram\n\n    Router->>Model Server: service discovery\n    Model Server-->>Router: urls for other shards\n\n    Router->>Model Server: get model info\n    Model Server-->>Router: shard info\n\n    Router->>Model Server: health check\n    Model Server-->>Router: health OK\n\n    Router->>Model Server: warmup(max_input_tokens, max_batch_prefill_tokens, max_total_tokens, max_batch_size)\n    Model Server-->>Router: warmup result\n```\n\nAfter these are done, the router is ready to receive generate calls from multiple clients. Here's an example.\n\n```mermaid\nsequenceDiagram\n    participant Client 1\n    participant Client 2\n    participant Client 3\n    participant Router\n    participant Model Server\n\n    Client 1->>Router: generate_stream\n    Router->>Model Server: prefill(batch1)\n    Model Server-->>Router: generations, cached_batch1, timings\n    Router-->>Client 1: token 1\n\n    Router->>Model Server: decode(cached_batch1)\n    Model Server-->>Router: generations, cached_batch1, timings\n    Router-->>Client 1: token 2\n\n    Router->>Model Server: decode(cached_batch1)\n    Model Server-->>Router: generations, cached_batch1, timings\n    Router-->>Client 1: token 3\n\n    Client 2->>Router: generate_stream\n    Router->>Model Server: prefill(batch2)\n    Note right of Model Server: This stops previous batch, that is restarted\n    Model Server-->>Router: generations, cached_batch2, timings\n    Router-->>Client 2: token 1'\n\n    Router->>Model Server: decode(cached_batch1, cached_batch2)\n    Model Server-->>Router: generations, cached_batch1, timings\n    Router-->>Client 1: token 4\n    Router-->>Client 2: token 2'\n\n    Note left of Client 1: Client 1 leaves\n    Router->>Model Server: 
filter_batch(cached_batch1, request_ids_to_keep=batch2)\n    Model Server-->>Router: filtered batch\n\n    Router->>Model Server: decode(cached_batch2)\n    Model Server-->>Router: generations, cached_batch2, timings\n    Router-->>Client 2: token 3'\n\n    Client 3->>Router: generate_stream\n    Note right of Model Server: This stops previous batch, that is restarted\n    Router->>Model Server: prefill(batch3)\n    Note left of Client 1: Client 3 leaves without receiving any batch\n    Router->>Model Server: clear_cache(batch3)\n    Note right of Model Server: This stops previous batch, that is restarted\n\n    Router->>Model Server: decode(cached_batch3)\n    Note right of Model Server: Last token (stopping criteria)\n    Model Server-->>Router: generations, cached_batch3, timings\n    Router-->>Client 2: token 4'\n\n\n```\n"
  },
  {
    "path": "docs/source/backends/gaudi.mdx",
    "content": "# Gaudi Backend for Text Generation Inference\n\n## Overview\nText Generation Inference (TGI) has been optimized to run on Gaudi hardware via the Gaudi backend for TGI.\n\n## Supported Hardware\n- **Gaudi1**: Available on [AWS EC2 DL1 instances](https://aws.amazon.com/ec2/instance-types/dl1/)\n- **Gaudi2**: Available on [Intel Cloud](https://console.cloud.intel.com/docs/reference/ai_instances.html)\n- **Gaudi3**: Available on [Intel Cloud](https://console.cloud.intel.com/docs/reference/ai_instances.html)\n\n## Tutorial: Getting Started with TGI on Gaudi\n\n### Basic Usage\nThe easiest way to run TGI on Gaudi is to use the official Docker image:\n\n```bash\nmodel=meta-llama/Meta-Llama-3.1-8B-Instruct\nvolume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run\nhf_token=YOUR_HF_ACCESS_TOKEN\n\ndocker run --runtime=habana --cap-add=sys_nice --ipc=host \\\n    -p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \\\n    ghcr.io/huggingface/text-generation-inference:3.3.5-gaudi \\\n    --model-id $model\n```\n\nOnce you see the `connected` log, the server is ready to accept requests:\n> 2024-05-22T19:31:48.302239Z  INFO text_generation_router: router/src/main.rs:378: Connected\n\nYou can find your `YOUR_HF_ACCESS_TOKEN` at [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens). 
This is necessary to access gated models like llama3.1.\n\n### Making Your First Request\nYou can send a request from a separate terminal:\n\n```bash\ncurl 127.0.0.1:8080/generate \\\n    -X POST \\\n    -d '{\"inputs\":\"What is Deep Learning?\",\"parameters\":{\"max_new_tokens\":32}}' \\\n    -H 'Content-Type: application/json'\n```\n\n## How-to Guides\n\nYou can view the full list of supported models in the [Supported Models](https://huggingface.co/docs/text-generation-inference/backends/gaudi#supported-models) section.\n\nFor example, to run Llama3.1-8B, you can use the following command:\n\n```bash\nmodel=meta-llama/Meta-Llama-3.1-8B-Instruct\nvolume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run\nhf_token=YOUR_ACCESS_TOKEN\n\ndocker run --runtime=habana --cap-add=sys_nice --ipc=host \\\n    -p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \\\n    ghcr.io/huggingface/text-generation-inference:3.3.5-gaudi \\\n    --model-id $model\n    <text-generation-inference-launcher-arguments>\n```\n\nFor the full list of service parameters, refer to the [launcher-arguments page](https://huggingface.co/docs/text-generation-inference/reference/launcher).\n\nThe validated docker commands can be found in the [examples/docker_commands folder](https://github.com/huggingface/text-generation-inference/tree/main/backends/gaudi/examples/docker_commands).\n\n> Note: `--runtime=habana --cap-add=sys_nice --ipc=host ` is required to enable docker to use the Gaudi hardware (more details [here](https://docs.habana.ai/en/latest/Installation_Guide/Additional_Installation/Docker_Installation.html)).\n\n### How to Enable Multi-Card Inference (Sharding)\n\nTGI-Gaudi supports sharding for multi-card inference, allowing you to distribute the load across multiple Gaudi cards. 
This is recommended to run large models and to speed up inference.\n\nFor example, on a machine with 8 Gaudi cards, you can run:\n\n```bash\ndocker run --runtime=habana --ipc=host --cap-add=sys_nice \\\n    -p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \\\n    tgi-gaudi \\\n    --model-id $model --sharded true --num-shard 8\n```\n\n<Tip>\nWe recommend always using sharding when running on a multi-card machine.\n</Tip>\n\n### How to Use Different Precision Formats\n\n#### BF16 Precision (Default)\nBy default, all models run with BF16 precision on Gaudi hardware.\n\n#### FP8 Precision\nTGI-Gaudi supports FP8 precision inference, which can significantly reduce memory usage and improve performance for large models. We support model like W8A8 FP compressed-tensors parameters such as [RedHatAI/Mixtral-8x7B-Instruct-v0.1-FP8](https://huggingface.co/RedHatAI/Mixtral-8x7B-Instruct-v0.1-FP8) and AutoFP8 generated model[RedHatAI/Meta-Llama-3-8B-Instruct-FP8](https://huggingface.co/RedHatAI/Meta-Llama-3-8B-Instruct-FP8) .\nTGI-Gaudi supports FP8 precision inference with [Intel Neural Compressor (INC)](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html).\n\n\n### How to Run Vision-Language Models (VLMs)\n\nGaudi supports VLM inference.\n\nExample for Llava-v1.6-Mistral-7B on 1 card:\n\nStart the TGI server via the following command:\n```bash\nmodel=llava-hf/llava-v1.6-mistral-7b-hf\nvolume=$PWD/data   # share a volume with the Docker container to avoid downloading weights every run\n\ndocker run -p 8080:80 \\\n   --runtime=habana \\\n   --cap-add=sys_nice \\\n   --ipc=host \\\n   -v $volume:/data \\\n   ghcr.io/huggingface/text-generation-inference:3.3.5-gaudi \\\n   --model-id $model \\\n   --max-input-tokens 4096 --max-batch-prefill-tokens 16384 \\\n   --max-total-tokens 8192 --max-batch-size 4\n```\n\nYou can then send a request to the server via the following command:\n```bash\ncurl -N 127.0.0.1:8080/generate \\\n    -X POST \\\n    
-d '{\"inputs\":\"![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)What is this a picture of?\\n\\n\",\"parameters\":{\"max_new_tokens\":32}}' \\\n    -H 'Content-Type: application/json'\n```\n\n> Note: In Llava-v1.6-Mistral-7B, an image usually accounts for 2000 input tokens. For example, an image of size 512x512 is represented by 2800 tokens. Thus, `max-input-tokens` must be larger than the number of tokens associated with the image. Otherwise the image may be truncated. The value of `max-batch-prefill-tokens` is 16384, which is calculated as follows: `prefill_batch_size` = `max-batch-prefill-tokens` / `max-input-tokens`.\n\n### How to Benchmark Performance\n\nWe recommend using the [inference-benchmarker tool](https://github.com/huggingface/inference-benchmarker) to benchmark performance on Gaudi hardware.\n\nThis benchmark tool simulates user requests and measures the performance of the model on realistic scenarios.\n\nTo run it on the same machine, you can do the following:\n```bash\nMODEL=meta-llama/Llama-3.1-8B-Instruct\nHF_TOKEN=<your HF READ token>\n# run a benchmark to evaluate the performance of the model for chat use case\n# we mount results to the current directory\ndocker run \\\n    --rm \\\n    -it \\\n    --net host \\\n    -v $(pwd):/opt/inference-benchmarker/results \\\n    -e \"HF_TOKEN=$HF_TOKEN\" \\\n    ghcr.io/huggingface/inference-benchmarker:latest \\\n    inference-benchmarker \\\n    --tokenizer-name \"$MODEL\" \\\n    --url http://localhost:8080 \\\n    --profile chat\n```\n\nPlease refer to the [inference-benchmarker README](https://github.com/huggingface/inference-benchmarker) for more details.\n\n## Explanation: Understanding TGI on Gaudi\n\n### The Warmup Process\n\nIntel Gaudi accelerators perform best when operating on models with fixed tensor shapes. 
[Intel Gaudi Graph Compiler](https://docs.habana.ai/en/latest/Gaudi_Overview/Intel_Gaudi_Software_Suite.html#graph-compiler-and-runtime)\ngenerates optimized binary code that implements the given model topology on Gaudi. In its default configuration, the produced binary code may be highly dependent on input and output tensor shapes, requiring graph recompilation\nwhen encountering tensors with different shapes within the same topology. While these binaries efficiently utilize Gaudi, the compilation process itself can introduce noticeable overhead in end-to-end execution.\nIn dynamic inference serving scenarios, minimizing the number of graph compilations and reducing the risk of graph compilation occurring during server runtime is important.\n\nTo ensure optimal performance, warmup is performed at the beginning of each server run. This process creates queries with various input shapes based on provided parameters and runs basic TGI operations (prefill, decode).\n\nNote: Model warmup can take several minutes, especially for FP8 inference. For faster subsequent runs, refer to [Disk Caching Eviction Policy](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html#disk-caching-eviction-policy).\n\n### Understanding Parameter Tuning\n\n#### Sequence Length Parameters\n- `--max-input-tokens` is the maximum possible input prompt length. Default value is `4095`.\n- `--max-total-tokens` is the maximum possible total length of the sequence (input and output). Default value is `4096`.\n\n#### Batch Size Parameters\n- For prefill operation, please set `--max-batch-prefill-tokens` as `bs * max-input-tokens`, where `bs` is your expected maximum prefill batch size.\n- For decode operation, please set `--max-batch-size` as `bs`, where `bs` is your expected maximum decode batch size.\n- Please note that the batch size will always be padded to the nearest shape that has been warmed up. 
This is done to avoid out of memory issues and to ensure that the graphs are reused efficiently.\n\n\n## Reference\n\nThis section contains reference information about the Gaudi backend.\n\n### Supported Models\n\nText Generation Inference enables serving optimized models on Gaudi hardware. The following sections list which models (VLMs & LLMs) are supported on Gaudi.\n\n**Large Language Models (LLMs)**\n- [deepseek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1)\n- [deepseek-v2](https://huggingface.co/deepseek-ai/DeepSeek-V2)\n- [Llama2](https://huggingface.co/collections/meta-llama/llama-2-family-661da1f90a9d678b6f55773b)\n- [Llama3](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)\n- [CodeLlama](https://huggingface.co/codellama/CodeLlama-13b-hf)\n- [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)\n- [Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3)\n- [Qwen 2](https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f)\n- [Qwen 3](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f)\n- [Qwen 3 Moe](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f)\n- [Phi-1.5](https://huggingface.co/microsoft/phi-1_5)\n- [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)\n- [PhiMoe](https://huggingface.co/microsoft/Phi-3.5-MoE-instruct)\n- [Gemma](https://huggingface.co/google/gemma-7b-it)\n- [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315)\n- [Gemma3 Text](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d)\n- [Granite](https://huggingface.co/ibm-granite/granite-3.0-8b-instruct)\n- [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus)\n- [dbrx](https://huggingface.co/databricks/dbrx-instruct)\n- [Starcoder2](https://huggingface.co/bigcode/starcoder2-3b)\n- [Falcon](https://huggingface.co/tiiuae/falcon-7b-instruct)\n- 
[GPT-2](https://huggingface.co/openai-community/gpt2)\n- [gpt-j-6b](https://huggingface.co/EleutherAI/gpt-j-6b)\n- [gpt-bigcode](https://huggingface.co/bigcode/gpt_bigcode-santacoder)\n- [Baichuan](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat)\n\n\n**Vision-Language Models (VLMs)**\n- [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf)\n- [Mllama (Multimodal Llama from Meta)](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)\n- [idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b)\n- [idefics 3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3)\n- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)\n- [Llama4](https://huggingface.co/collections/meta-llama/llama-4-67f0c30d9fe03840bc9d0164)\n- [Gemma3](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d)\n- [Qwen 2.5 VL](https://huggingface.co/collections/Qwen/qwen25-vl-6795ffac22b334a837c0f9a5)\n- [Qwen 2 VL](https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d)\n\nIf you have an issue with a model, please open an issue on the [Gaudi backend repository](https://github.com/huggingface/text-generation-inference/issues).\n\n### Environment Variables\n\nThe following table contains the environment variables that can be used to configure the Gaudi backend:\n\n| Name                        | Value(s)   | Default          | Description                                                                                                                      | Usage                        |\n|-----------------------------| :--------- | :--------------- | :------------------------------------------------------------------------------------------------------------------------------- | :--------------------------- |\n| LIMIT_HPU_GRAPH             | True/False | True             | Skip HPU graph usage for prefill to save memory, set to `True` for large sequence/decoding lengths(e.g. 
300/212)                 | add -e in docker run command |\n| SKIP_TOKENIZER_IN_TGI       | True/False | False            | Skip tokenizer for input/output processing                                                                                       | add -e in docker run command |\n| VLLM_SKIP_WARMUP              | True/False | False             | Skip graph warmup during server initialization which is not recommended, but could be used for debug.                            | add -e in docker run command |\n\n\n## Contributing\n\nContributions to the TGI-Gaudi project are welcome. Please refer to the [contributing guide](https://github.com/huggingface/text-generation-inference/blob/main/CONTRIBUTING.md).\n\n**Guidelines for contributing to Gaudi on TGI:** All changes should be made within the `backends/gaudi` folder. In general, you should avoid modifying the router, launcher, or benchmark to accommodate Gaudi hardware, as all Gaudi-specific logic should be contained within the `backends/gaudi` folder.\n\n### Building the Docker Image from Source\n\nTo build the Docker image from source:\n\n```bash\nmake -C backends/gaudi image\n```\n\nThis builds the image and saves it as `tgi-gaudi`. You can then run TGI-Gaudi with this image:\n\n```bash\nmodel=meta-llama/Meta-Llama-3.1-8B-Instruct\nvolume=$PWD/data\nhf_token=YOUR_ACCESS_TOKEN\n\ndocker run --runtime=habana --ipc=host --cap-add=sys_nice \\\n    -p 8080:80 -v $volume:/data -e HF_TOKEN=$hf_token \\\n    tgi-gaudi \\\n    --model-id $model\n```\n\nFor more details, see the [README of the Gaudi backend](https://github.com/huggingface/text-generation-inference/blob/main/backends/gaudi/README.md) and the [Makefile of the Gaudi backend](https://github.com/huggingface/text-generation-inference/blob/main/backends/gaudi/Makefile).\n"
  },
  {
    "path": "docs/source/backends/llamacpp.md",
    "content": "# Llamacpp Backend\n\nThe llamacpp backend facilitates the deployment of large language models\n(LLMs) by integrating [llama.cpp][llama.cpp], an advanced inference engine\noptimized for both CPU and GPU computation. This backend is a component\nof Hugging Face’s **Text Generation Inference (TGI)** suite,\nspecifically designed to streamline the deployment of LLMs in production\nenvironments.\n\n## Key Capabilities\n\n- Full compatibility with GGUF format and all quantization formats\n  (GGUF-related constraints may be mitigated dynamically by on-the-fly\n  generation in future updates)\n- Optimized inference on CPU and GPU architectures\n- Containerized deployment, eliminating dependency complexity\n- Seamless interoperability with the Hugging Face ecosystem\n\n## Model Compatibility\n\nThis backend leverages models formatted in **GGUF**, providing an\noptimized balance between computational efficiency and model accuracy.\nYou will find the best models on [Hugging Face][GGUF].\n\n## Build Docker image\n\nFor optimal performance, the Docker image is compiled with native CPU\ninstructions by default. As a result, it is strongly recommended to run\nthe container on the same host architecture used during the build\nprocess. 
Efforts are ongoing to improve portability across different\nsystems while preserving high computational efficiency.\n\nTo build the Docker image, use the following command:\n\n```bash\ndocker build \\\n    -t tgi-llamacpp \\\n    https://github.com/huggingface/text-generation-inference.git \\\n    -f Dockerfile_llamacpp\n```\n\n### Build parameters\n\n| Parameter (with --build-arg)              | Description                      |\n| ----------------------------------------- | -------------------------------- |\n| `llamacpp_version=bXXXX`                  | Specific version of llama.cpp    |\n| `llamacpp_cuda=ON`                        | Enables CUDA acceleration        |\n| `llamacpp_native=OFF`                     | Disable automatic CPU detection  |\n| `llamacpp_cpu_arm_arch=ARCH[+FEATURE]...` | Specific ARM CPU and features    |\n| `cuda_arch=ARCH`                          | Defines target CUDA architecture |\n\nFor example, to target Graviton4 when building on another ARM\narchitecture:\n\n```bash\ndocker build \\\n    -t tgi-llamacpp \\\n    --build-arg llamacpp_native=OFF \\\n    --build-arg llamacpp_cpu_arm_arch=armv9-a+i8mm \\\n    https://github.com/huggingface/text-generation-inference.git \\\n    -f Dockerfile_llamacpp\n```\n\n## Run Docker image\n\n### CPU-based inference\n\n```bash\ndocker run \\\n    -p 3000:3000 \\\n    -e \"HF_TOKEN=$HF_TOKEN\" \\\n    -v \"$HOME/models:/app/models\" \\\n    tgi-llamacpp \\\n    --model-id \"Qwen/Qwen2.5-3B-Instruct\"\n```\n\n### GPU-Accelerated inference\n\n```bash\ndocker run \\\n    --gpus all \\\n    -p 3000:3000 \\\n    -e \"HF_TOKEN=$HF_TOKEN\" \\\n    -v \"$HOME/models:/app/models\" \\\n    tgi-llamacpp \\\n    --n-gpu-layers 99 \\\n    --model-id \"Qwen/Qwen2.5-3B-Instruct\"\n```\n\n## Using a custom GGUF\n\nGGUF files are optional as they will be automatically generated at\nstartup if not already present in the `models` directory. 
However, if\nthe default GGUF generation is not suitable for your use case, you can\nprovide your own GGUF file with `--model-gguf`, for example:\n\n```bash\ndocker run \\\n    -p 3000:3000 \\\n    -e \"HF_TOKEN=$HF_TOKEN\" \\\n    -v \"$HOME/models:/app/models\" \\\n    tgi-llamacpp \\\n    --model-id \"Qwen/Qwen2.5-3B-Instruct\" \\\n    --model-gguf \"models/qwen2.5-3b-instruct-q4_0.gguf\"\n```\n\nNote that `--model-id` is still required.\n\n## Advanced parameters\n\nA full listing of configurable parameters is available in the `--help`:\n\n```bash\ndocker run tgi-llamacpp --help\n\n```\n\nThe table below summarizes key options:\n\n| Parameter                           | Description                                                            |\n|-------------------------------------|------------------------------------------------------------------------|\n| `--n-threads`                       | Number of threads to use for generation                                |\n| `--n-threads-batch`                 | Number of threads to use for batch processing                          |\n| `--n-gpu-layers`                    | Number of layers to store in VRAM                                      |\n| `--split-mode`                      | Split the model across multiple GPUs                                   |\n| `--defrag-threshold`                | Defragment the KV cache if holes/size > threshold                      |\n| `--numa`                            | Enable NUMA optimizations                                              |\n| `--disable-mmap`                    | Disable memory mapping for the model                                   |\n| `--use-mlock`                       | Use memory locking to prevent swapping                                 |\n| `--disable-offload-kqv`             | Disable offloading of KQV operations to the GPU                        |\n| `--disable-flash-attention`         | Disable flash attention                                         
       |\n| `--type-k`                          | Data type used for K cache                                             |\n| `--type-v`                          | Data type used for V cache                                             |\n| `--validation-workers`              | Number of tokenizer workers used for payload validation and truncation |\n| `--max-concurrent-requests`         | Maximum number of concurrent requests                                  |\n| `--max-input-tokens`                | Maximum number of input tokens per request                             |\n| `--max-total-tokens`                | Maximum number of total tokens (input + output) per request            |\n| `--max-batch-total-tokens`          | Maximum number of tokens in a batch                                    |\n| `--max-physical-batch-total-tokens` | Maximum number of tokens in a physical batch                           |\n| `--max-batch-size`                  | Maximum number of requests per batch                                   |\n\n---\n[llama.cpp]: https://github.com/ggerganov/llama.cpp\n[GGUF]: https://huggingface.co/models?library=gguf&sort=trending\n"
  },
  {
    "path": "docs/source/backends/neuron.md",
    "content": "# Neuron backend for AWS Trainium and Inferentia\n\nThe Neuron backend allows the deployment of TGI on the AWS Trainium and Inferentia families of chips.\n\nThe following hardware targets are supported:\n- Trainium 1,\n- Inferentia 2.\n\n## Features\n\nThe basic TGI features are supported:\n\n- continuous batching,\n- token streaming,\n- greedy search and multinomial sampling using [transformers](https://huggingface.co/docs/transformers/generation_strategies#customize-text-generation).\n\n\n## Deploy the service from the Hugging Face hub\n\nThe simplest way to deploy the NeuronX TGI service for a specific model is to follow the\ndeployment instructions in the model card:\n\n- click on the \"Deploy\" button on the right,\n- select your deployment service (\"Inference Endpoints\" and \"SageMaker\" are supported),\n- select \"AWS Trainium & Inferentia\",\n- follow the instructions.\n\n\n## Deploy the service on a dedicated host\n\nThe service is launched simply by running the text-generation-inference container with two sets of parameters:\n\n```\ndocker run <system_parameters> ghcr.io/huggingface/text-generation-inference:3.3.5-neuron <service_parameters>\n```\n\n- system parameters are used to map ports, volumes and devices between the host and the service,\n- service parameters are forwarded to the `text-generation-launcher`.\n\nWhen deploying a service, you will need a pre-compiled Neuron model. 
The Neuron TGI backend supports two main modes of operation:\n\n- you can either deploy the service on a model that has already been exported to Neuron,\n- or alternatively you can take advantage of the Neuron Model Cache to export your own model.\n\n### Common system parameters\n\nWhenever you launch a TGI service, we highly recommend you to mount a shared volume mounted as `/data` in the container: this is where\nthe models will be cached to speed up further instantiations of the service.\n\nNote also that enough neuron devices should be made visible to the container, knowing that each neuron device has two cores (so when deploying on two cores you need to expose at least one device).\nThe recommended way to expose a device in a production environment is to use explicitly the `--device` option (e.g `--device /dev/neuron0`) repeated as many time as there are devices to be exposed.\n\nNote: alternatively, for a quick local test it is also possible to launch the service in `privileged` mode to get access to all neuron devices.\n\nFinally, you might want to export the `HF_TOKEN` if you want to access gated repositories.\n\nHere is an example of a service instantiation exposing only the first device:\n\n```\ndocker run -p 8080:80 \\\n       -v $(pwd)/data:/data \\\n       --device=/dev/neuron0 \\\n       -e HF_TOKEN=${HF_TOKEN} \\\n       ghcr.io/huggingface/text-generation-inference:<VERSION>-neuron \\\n       <service_parameters>\n```\n\n### Using a standard model from the 🤗 [HuggingFace Hub](https://huggingface.co/aws-neuron) (recommended)\n\nWe maintain a Neuron Model Cache of the most popular architecture and deployment parameters under [aws-neuron/optimum-neuron-cache](https://huggingface.co/aws-neuron/optimum-neuron-cache).\n\nIf you just want to try the service quickly using a model without exporting it to Neuron first, it is thus still possible, pending some conditions:\n- you must specify the export parameters when launching the service (or use default 
parameters),\n- the model configuration must be cached.\n\nThe snippet below shows how you can deploy a service from a hub standard model:\n\n```\nexport HF_TOKEN=<YOUR_TOKEN>\ndocker run -p 8080:80 \\\n       -v $(pwd)/data:/data \\\n       --device=/dev/neuron0 \\\n       --device=/dev/neuron1 \\\n       --device=/dev/neuron2 \\\n       --device=/dev/neuron3 \\\n       -e HF_TOKEN=${HF_TOKEN} \\\n       -e HF_AUTO_CAST_TYPE=\"fp16\" \\\n       -e HF_NUM_CORES=8 \\\n       ghcr.io/huggingface/text-generation-inference:<VERSION>-neuron \\\n       --model-id meta-llama/Meta-Llama-3-8B \\\n       --max-batch-size 1 \\\n       --max-input-length 3164 \\\n       --max-total-tokens 4096\n```\n\n### Using a model exported to a local path\n\nAlternatively, you can first [export the model to neuron format](https://huggingface.co/docs/optimum-neuron/main/en/guides/export_model#exporting-neuron-models-using-text-generation-inference) locally.\n\nYou can then deploy the service inside the shared volume:\n\n```\ndocker run -p 8080:80 \\\n       -v $(pwd)/data:/data \\\n       --device=/dev/neuron0 \\\n       --device=/dev/neuron1 \\\n       ghcr.io/huggingface/text-generation-inference:<VERSION>-neuron \\\n       --model-id /data/<neuron_model_path>\n```\n\nNote: You don't need to specify any service parameters, as they will all be deduced from the model export configuration. 
You must however expose enough devices to match the number of cores specified during the export phase.\n\n\n### Using a neuron model from the 🤗 [HuggingFace Hub](https://huggingface.co/)\n\nThe easiest way to share a neuron model inside your organization is to push it on the Hugging Face hub, so that it can be deployed directly without requiring an export.\n\nThe snippet below shows how you can deploy a service from a hub neuron model:\n\n```\ndocker run -p 8080:80 \\\n       -v $(pwd)/data:/data \\\n       --device=/dev/neuron0 \\\n       --device=/dev/neuron1 \\\n       -e HF_TOKEN=${HF_TOKEN} \\\n       ghcr.io/huggingface/text-generation-inference:<VERSION>-neuron \\\n       --model-id <organization>/<neuron-model>\n```\n\n### Choosing service parameters\n\nUse the following command to list the available service parameters:\n\n```\ndocker run ghcr.io/huggingface/text-generation-inference:<VERSION>-neuron --help\n```\n\nThe configuration of an inference endpoint is always a compromise between throughput and latency: serving more requests in parallel will allow a higher throughput, but it will increase the latency.\n\nThe neuron models have static input dimensions `[batch_size, max_length]`.\n\nThis adds several restrictions to the following parameters:\n\n- `--max-batch-size` must be set to `batch size`,\n- `--max-input-length` must be lower than `max_length`,\n- `--max-total-tokens` must be set to `max_length` (it is per-request).\n\nAlthough not strictly necessary, but important for efficient prefilling:\n\n- `--max-batch-prefill-tokens` should be set to `batch_size` * `max-input-length`.\n\n### Choosing the correct batch size\n\nAs seen in the previous paragraph, neuron model static batch size has a direct influence on the endpoint latency and throughput.\n\nPlease refer to [text-generation-inference](https://github.com/huggingface/text-generation-inference) for optimization hints.\n\nNote that the main constraint is to be able to fit the model for the 
specified `batch_size` within the total device memory available\non your instance (16GB per neuron core, with 2 cores per device).\n\n## Query the service\n\nYou can query the model using either the `/generate` or `/generate_stream` routes:\n\n```\ncurl 127.0.0.1:8080/generate \\\n    -X POST \\\n    -d '{\"inputs\":\"What is Deep Learning?\",\"parameters\":{\"max_new_tokens\":20}}' \\\n    -H 'Content-Type: application/json'\n```\n\n```\ncurl 127.0.0.1:8080/generate_stream \\\n    -X POST \\\n    -d '{\"inputs\":\"What is Deep Learning?\",\"parameters\":{\"max_new_tokens\":20}}' \\\n    -H 'Content-Type: application/json'\n```\n\nNote: replace 127.0.0.1:8080 with your actual IP address and port.\n"
  },
  {
    "path": "docs/source/backends/trtllm.md",
    "content": "# TensorRT-LLM backend\n\nThe NVIDIA TensorRT-LLM (TRTLLM) backend is a high-performance backend for LLMs\nthat uses NVIDIA's TensorRT library for inference acceleration.\nIt makes use of specific optimizations for NVIDIA GPUs, such as custom kernels.\n\nTo use the TRTLLM backend **you need to compile** `engines` for the models you want to use.\nEach `engine` must be compiled for a given set of:\n- GPU architecture that you will use for inference (e.g. A100, L40, etc.)\n- Maximum batch size\n- Maximum input length\n- Maximum output length\n- Maximum beams width\n\n## Supported models\n\nCheck the [support matrix](https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html) to see which models are\nsupported.\n\n## Compiling engines\n\nYou can use [Optimum-NVIDIA](https://github.com/huggingface/optimum-nvidia) to compile engines for the models you\nwant to use.\n\n```bash\nMODEL_NAME=\"meta-llama/Llama-3.1-8B-Instruct\"\nDESTINATION=\"/tmp/engines/$MODEL_NAME\"\nHF_TOKEN=\"hf_xxx\"\n# Compile the engine using Optimum-NVIDIA\n# This will create a compiled engine in the /tmp/engines/meta-llama/Llama-3.1-8B-Instruct\n# directory for 1 GPU\ndocker run \\\n  --rm \\\n  -it \\\n  --gpus=1 \\\n  --shm-size=1g \\\n  -v \"$DESTINATION\":/engine \\\n  -e HF_TOKEN=$HF_TOKEN \\\n  -e HF_HUB_ENABLE_HF_TRANSFER=1 \\\n  huggingface/optimum-nvidia:v0.1.0b9-py310 \\\n    bash -c \"optimum-cli export trtllm \\\n    --tp=1 \\\n    --pp=1 \\\n    --max-batch-size=64 \\\n    --max-input-length 4096 \\\n    --max-output-length 8192 \\\n    --max-beams-width=1 \\\n    --destination /tmp/engine \\\n    $MODEL_NAME && cp -rL /tmp/engine/* /engine/\"\n```\n\nYour compiled engine will be saved in the `/tmp/engines/$MODEL_NAME` directory, in a subfolder named after the GPU used to compile the model.\n\n## Using the TRTLLM backend\n\nRun TGI-TRTLLM Docker image with the compiled 
engine:\n\n```bash\nMODEL_NAME=\"meta-llama/Llama-3.1-8B-Instruct\"\nDESTINATION=\"/tmp/engines/$MODEL_NAME\"\nHF_TOKEN=\"hf_xxx\"\ndocker run \\\n  --gpus 1 \\\n  --shm-size=1g \\\n  -it \\\n  --rm \\\n  -p 3000:3000 \\\n  -e MODEL=$MODEL_NAME \\\n  -e PORT=3000 \\\n  -e HF_TOKEN=$HF_TOKEN \\\n  -v \"$DESTINATION\"/<YOUR_GPU_ARCHITECTURE>/engines:/data \\\n  ghcr.io/huggingface/text-generation-inference:latest-trtllm \\\n  --model-id /data/ \\\n  --tokenizer-name $MODEL_NAME\n```\n\n## Development\n\nTo develop TRTLLM backend, you can use [dev containers](https://containers.dev/) with the following `.devcontainer.json` file:\n```json\n{\n  \"name\": \"CUDA\",\n  \"build\": {\n    \"dockerfile\": \"Dockerfile_trtllm\",\n    \"context\": \"..\"\n  },\n  \"remoteEnv\": {\n    \"PATH\": \"${containerEnv:PATH}:/usr/local/cuda/bin\",\n    \"LD_LIBRARY_PATH\": \"$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64\",\n    \"XLA_FLAGS\": \"--xla_gpu_cuda_data_dir=/usr/local/cuda\"\n  },\n  \"customizations\" : {\n    \"jetbrains\" : {\n      \"backend\" : \"CLion\"\n    }\n  }\n}\n```\n\nand `Dockerfile_trtllm`:\n\n```Dockerfile\nARG cuda_arch_list=\"75-real;80-real;86-real;89-real;90-real\"\nARG build_type=release\nARG ompi_version=4.1.7\n\n# CUDA dependent dependencies resolver stage\nFROM nvidia/cuda:12.6.3-cudnn-devel-ubuntu24.04 AS cuda-builder\n\nRUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \\\n    build-essential \\\n    cmake \\\n    curl \\\n    gcc-14  \\\n    g++-14 \\\n    git \\\n    git-lfs \\\n    lld \\\n    libssl-dev \\\n    libucx-dev \\\n    libasan8 \\\n    libubsan1 \\\n    ninja-build \\\n    pkg-config \\\n    pipx \\\n    python3 \\\n    python3-dev \\\n    python3-setuptools \\\n    tar \\\n    wget --no-install-recommends && \\\n    pipx ensurepath\n\nENV TGI_INSTALL_PREFIX=/usr/local/tgi\nENV TENSORRT_INSTALL_PREFIX=/usr/local/tensorrt\n\n# Install OpenMPI\nFROM cuda-builder AS 
mpi-builder\nWORKDIR /opt/src/mpi\n\nARG ompi_version\nENV OMPI_VERSION=${ompi_version}\nENV OMPI_TARBALL_FILENAME=openmpi-${OMPI_VERSION}.tar.bz2\nADD --checksum=sha256:54a33cb7ad81ff0976f15a6cc8003c3922f0f3d8ceed14e1813ef3603f22cd34 \\\n    https://download.open-mpi.org/release/open-mpi/v4.1/${OMPI_TARBALL_FILENAME} .\n\nRUN tar --strip-components=1 -xf ${OMPI_TARBALL_FILENAME} &&\\\n    ./configure --prefix=/usr/local/mpi --with-cuda=/usr/local/cuda --with-slurm && \\\n    make -j all && \\\n    make install && \\\n    rm -rf ${OMPI_TARBALL_FILENAME}/..\n\n# Install TensorRT\nFROM cuda-builder AS trt-builder\nCOPY backends/trtllm/scripts/install_tensorrt.sh /opt/install_tensorrt.sh\nRUN chmod +x /opt/install_tensorrt.sh && \\\n    /opt/install_tensorrt.sh\n\n# Build Backend\nFROM cuda-builder AS tgi-builder\nWORKDIR /usr/src/text-generation-inference\n\n# Scoped global args reuse\nARG cuda_arch_list\nARG build_type\nARG sccache_gha_enabled\nARG actions_results_url\nARG actions_runtime_token\n\n# Install Rust\nENV PATH=\"/root/.cargo/bin:$PATH\"\nRUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | bash -s -- -y && \\\n    chmod -R a+w /root/.rustup && \\\n    chmod -R a+w /root/.cargo && \\\n    cargo install sccache --version \">=0.10.0\" --locked\n\nENV LD_LIBRARY_PATH=\"/usr/local/mpi/lib:$LD_LIBRARY_PATH\"\nENV PKG_CONFIG_PATH=\"/usr/local/mpi/lib/pkgconfig\"\nENV CMAKE_PREFIX_PATH=\"/usr/local/mpi:/usr/local/tensorrt\"\n\nENV USE_LLD_LINKER=ON\nENV CUDA_ARCH_LIST=${cuda_arch_list}\n```\n"
  },
  {
    "path": "docs/source/basic_tutorials/consuming_tgi.md",
    "content": "# Consuming Text Generation Inference\n\nThere are many ways to consume Text Generation Inference (TGI) server in your applications. After launching the server, you can use the [Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) `/v1/chat/completions` route and make a `POST` request to get results from the server. You can also pass `\"stream\": true` to the call if you want TGI to return a stream of tokens.\n\nFor more information on the API, consult the OpenAPI documentation of `text-generation-inference` available [here](https://huggingface.github.io/text-generation-inference).\n\nYou can make the requests using any tool of your preference, such as curl, Python, or TypeScript. For an end-to-end experience, we've open-sourced [ChatUI](https://github.com/huggingface/chat-ui), a chat interface for open-access models.\n\n## curl\n\nAfter a successful server launch, you can query the model using the `v1/chat/completions` route, to get responses that are compliant to the OpenAI Chat Completion spec:\n\n```bash\ncurl localhost:8080/v1/chat/completions \\\n    -X POST \\\n    -d '{\n  \"model\": \"tgi\",\n  \"messages\": [\n    {\n      \"role\": \"system\",\n      \"content\": \"You are a helpful assistant.\"\n    },\n    {\n      \"role\": \"user\",\n      \"content\": \"What is deep learning?\"\n    }\n  ],\n  \"stream\": true,\n  \"max_tokens\": 20\n}' \\\n    -H 'Content-Type: application/json'\n```\n\nFor non-chat use-cases, you can also use the `/generate` and `/generate_stream` routes.\n\n```bash\ncurl 127.0.0.1:8080/generate \\\n    -X POST \\\n    -d '{\n  \"inputs\":\"What is Deep Learning?\",\n  \"parameters\":{\n    \"max_new_tokens\":20\n  }\n}' \\\n    -H 'Content-Type: application/json'\n```\n\n## Python\n\n### Inference Client\n\n[`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/main/en/index) is a Python library to interact with the Hugging Face Hub, including its endpoints. 
It provides a high-level class, [`huggingface_hub.InferenceClient`](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient), which makes it easy to make calls to TGI's Messages API. `InferenceClient` also takes care of parameter validation and provides a simple-to-use interface.\n\nInstall `huggingface_hub` package via pip.\n\n```bash\npip install huggingface_hub\n```\n\nYou can now use `InferenceClient` the exact same way you would use `OpenAI` client in Python\n\n```python\nfrom huggingface_hub import InferenceClient\n\nclient = InferenceClient(\n    base_url=\"http://localhost:8080/v1/\",\n)\n\noutput = client.chat.completions.create(\n    model=\"tgi\",\n    messages=[\n        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n        {\"role\": \"user\", \"content\": \"Count to 10\"},\n    ],\n    stream=True,\n    max_tokens=1024,\n)\n\nfor chunk in output:\n    print(chunk.choices[0].delta.content)\n```\n\nYou can check out more details about OpenAI compatibility [here](https://huggingface.co/docs/huggingface_hub/en/guides/inference#openai-compatibility).\n\nThere is also an async version of the client, `AsyncInferenceClient`, based on `asyncio` and `aiohttp`. 
You can find docs for it [here](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.AsyncInferenceClient)\n\n### OpenAI Client\n\nYou can directly use the OpenAI [Python](https://github.com/openai/openai-python) or [JS](https://github.com/openai/openai-node) clients to interact with TGI.\n\nInstall the OpenAI Python package via pip.\n\n```bash\npip install openai\n```\n\n```python\nfrom openai import OpenAI\n\n# init the client but point it to TGI\nclient = OpenAI(\n    base_url=\"http://localhost:8080/v1/\",\n    api_key=\"-\"\n)\n\nchat_completion = client.chat.completions.create(\n    model=\"tgi\",\n    messages=[\n        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\" },\n        {\"role\": \"user\", \"content\": \"What is deep learning?\"}\n    ],\n    stream=True\n)\n\n# iterate and print stream\nfor message in chat_completion:\n    print(message)\n```\n\n## UI\n\n### Gradio\n\nGradio is a Python library that helps you build web applications for your machine learning models with a few lines of code. It has a `ChatInterface` wrapper that helps create neat UIs for chatbots. Let's take a look at how to create a chatbot with streaming mode using TGI and Gradio. 
Let's install Gradio and Hub Python library first.\n\n```bash\npip install huggingface-hub gradio\n```\n\nAssume you are serving your model on port 8080, we will query through [InferenceClient](consuming_tgi#inference-client).\n\n```python\nimport gradio as gr\nfrom huggingface_hub import InferenceClient\n\nclient = InferenceClient(base_url=\"http://127.0.0.1:8080\")\n\ndef inference(message, history):\n    partial_message = \"\"\n    output = client.chat.completions.create(\n        messages=[\n            {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n            {\"role\": \"user\", \"content\": message},\n        ],\n        stream=True,\n        max_tokens=1024,\n    )\n\n    for chunk in output:\n        partial_message += chunk.choices[0].delta.content\n        yield partial_message\n\ngr.ChatInterface(\n    inference,\n    type=\"messages\",\n    description=\"This is the demo for Gradio UI consuming TGI endpoint.\",\n    title=\"Gradio 🤝 TGI\",\n    examples=[\"Are tomatoes vegetables?\"],\n).queue().launch()\n```\n\nYou can check out the UI and try the demo directly here 👇\n\n<div class=\"block dark:hidden\">\n\t<iframe\n        src=\"https://merve-gradio-tgi-2.hf.space?__theme=light\"\n        width=\"850\"\n        height=\"750\"\n    ></iframe>\n</div>\n<div class=\"hidden dark:block\">\n    <iframe\n        src=\"https://merve-gradio-tgi-2.hf.space?__theme=dark\"\n        width=\"850\"\n        height=\"750\"\n    ></iframe>\n</div>\n\n\nYou can read more about how to customize a `ChatInterface` [here](https://www.gradio.app/guides/creating-a-chatbot-fast).\n\n### ChatUI\n\n[ChatUI](https://github.com/huggingface/chat-ui) is an open-source interface built for consuming LLMs. It offers many customization options, such as web search with SERP API and more. ChatUI can automatically consume the TGI server and even provides an option to switch between different TGI endpoints. 
You can try it out at [Hugging Chat](https://huggingface.co/chat/), or use the [ChatUI Docker Space](https://huggingface.co/new-space?template=huggingchat/chat-ui-template) to deploy your own Hugging Chat to Spaces.\n\nTo serve both ChatUI and TGI in the same environment, simply add your own endpoints to the `MODELS` variable in the `.env.local` file inside the `chat-ui` repository. Provide the endpoints pointing to where TGI is served.\n\n```\n{\n// rest of the model config here\n\"endpoints\": [{\"url\": \"https://HOST:PORT/generate_stream\"}]\n}\n```\n\n![ChatUI](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/chatui_screen.png)\n"
  },
  {
    "path": "docs/source/basic_tutorials/gated_model_access.md",
    "content": "# Serving Private & Gated Models\n\nIf the model you wish to serve is behind gated access or the model repository on Hugging Face Hub is private, and you have access to the model, you can provide your Hugging Face Hub access token. You can generate and copy a read token from [Hugging Face Hub tokens page](https://huggingface.co/settings/tokens)\n\nIf you're using the CLI, set the `HF_TOKEN` environment variable. For example:\n\n```\nexport HF_TOKEN=<YOUR READ TOKEN>\n```\n\nIf you would like to do it through Docker, you can provide your token by specifying `HF_TOKEN` as shown below.\n\n```bash\nmodel=meta-llama/Llama-2-7b-chat-hf\nvolume=$PWD/data\ntoken=<your READ token>\n\ndocker run --gpus all \\\n    --shm-size 1g \\\n    -e HF_TOKEN=$token \\\n    -p 8080:80 \\\n    -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.5 \\\n    --model-id $model\n```\n"
  },
  {
    "path": "docs/source/basic_tutorials/monitoring.md",
    "content": "# Monitoring TGI server with Prometheus and Grafana dashboard\n\nTGI server deployment can easily be monitored through a Grafana dashboard, consuming a Prometheus data collection. Example of inspectable metrics are statistics on the effective batch sizes used by TGI, prefill/decode latencies, number of generated tokens, etc.\n\nIn this tutorial, we look at how to set up a local Grafana dashboard to monitor TGI usage.\n\n![Grafana dashboard for TGI](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/grafana.png)\n\n## Setup on the server machine\n\nFirst, on your server machine, TGI needs to be launched as usual. TGI exposes [multiple](https://github.com/huggingface/text-generation-inference/discussions/1127#discussioncomment-7240527) metrics that can be collected by Prometheus monitoring server.\n\nIn the rest of this tutorial, we assume that TGI was launched through Docker with `--network host`.\n\nOn the server where TGI is hosted, a Prometheus server needs to be installed and launched. To do so, please follow [Prometheus installation instructions](https://prometheus.io/download/#prometheus). For example, at the time of writing on a Linux machine:\n\n```\nwget https://github.com/prometheus/prometheus/releases/download/v2.52.0/prometheus-2.52.0.linux-amd64.tar.gz\ntar -xvzf prometheus-2.52.0.linux-amd64.tar.gz\ncd prometheus\n```\n\nPrometheus needs to be configured to listen on TGI's port. 
To do so, in Prometheus configuration file `prometheus.yml`, one needs to edit the lines:\n```\n    static_configs:\n      - targets: [\"0.0.0.0:80\"]\n```\nto use the correct IP address and port.\n\nWe suggest trying `curl 0.0.0.0:80/generate -X POST -d '{\"inputs\":\"hey chatbot, how are\",\"parameters\":{\"max_new_tokens\":15}}' -H 'Content-Type: application/json'` on the server side to make sure to configure the correct IP and port.\n\nOnce Prometheus is configured, Prometheus server can be launched on the same machine where TGI is launched:\n```\n./prometheus --config.file=\"prometheus.yml\"\n```\n\nIn this guide, Prometheus monitoring data will be consumed on a local computer. Hence, we need to forward Prometheus port (by default 9090) to the local computer. To do so, we can for example:\n* Use ssh [local port forwarding](https://www.ssh.com/academy/ssh/tunneling-example)\n* Use ngrok port tunneling\n\nFor simplicity, we will use [Ngrok](https://ngrok.com/docs/) in this guide to tunnel Prometheus port from the TGI server to the outside world.\n\nFor that, you should follow the steps at https://dashboard.ngrok.com/get-started/setup/linux, and once Ngrok is installed, use:\n```bash\nngrok http http://0.0.0.0:9090\n```\n\nAs a sanity check, one can make sure that Prometheus server can be accessed at the URL given by Ngrok (in the style of https://d661-4-223-164-145.ngrok-free.app) from a local machine.\n\n## Setup on the monitoring machine\n\nMonitoring is typically done on a machine other than the server. We use a Grafana dashboard to monitor TGI's server usage.\n\nTwo options are available:\n* Use Grafana Cloud for a hosted dashboard solution (https://grafana.com/products/cloud/).\n* Self-host a grafana dashboard.\n\nIn this tutorial, for simplicity, we will self-host the dashboard. We recommend installing Grafana Open-source edition following [the official install instructions](https://grafana.com/grafana/download?platform=linux&edition=oss), using the available Linux binaries. For example:\n\n```bash\nwget https://dl.grafana.com/oss/release/grafana-11.0.0.linux-amd64.tar.gz\ntar -zxvf grafana-11.0.0.linux-amd64.tar.gz\ncd grafana-11.0.0\n./bin/grafana-server\n```\n\nOnce the Grafana server is launched, the Grafana interface is available at http://localhost:3000. One needs to log in with the `admin` username and `admin` password.\n\nOnce logged in, the Prometheus data source for Grafana needs to be configured, in the option `Add your first data source`. There, a Prometheus data source needs to be added with the Ngrok address we got earlier, that exposes Prometheus port (example: https://d661-4-223-164-145.ngrok-free.app).\n\nOnce Prometheus data source is configured, we can finally create our dashboard! From home, go to `Create your first dashboard` and then `Import dashboard`. There, we will use the recommended dashboard template [tgi_grafana.json](https://github.com/huggingface/text-generation-inference/blob/main/assets/tgi_grafana.json) for a dashboard ready to be used, but you may configure your own dashboard as you like.\n\nCommunity contributed dashboard templates are also available, for example [here](https://grafana.com/grafana/dashboards/19831-text-generation-inference-dashboard/) or [here](https://grafana.com/grafana/dashboards/20246-text-generation-inference/).\n\nLoad your dashboard configuration, and your TGI dashboard should be ready to go!\n"
  },
  {
    "path": "docs/source/basic_tutorials/non_core_models.md",
    "content": "# Non-core Model Serving\n\nTGI supports various LLM architectures (see full list [here](../supported_models)). If you wish to serve a model that is not one of the supported models, TGI will fall back to the `transformers` implementation of that model. This means you will be unable to use some of the features introduced by TGI, such as tensor-parallel sharding or flash attention. However, you can still get many benefits of TGI, such as continuous batching or streaming outputs.\n\nYou can serve these models using the same Docker command-line invocation as with fully supported models 👇\n\n```bash\ndocker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id gpt2\n```\n\nIf the model you wish to serve is a custom transformers model, and its weights and implementation are available in the Hub, you can still serve the model by passing the `--trust-remote-code` flag to the `docker run` command like below 👇\n\n```bash\ndocker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id <CUSTOM_MODEL_ID> --trust-remote-code\n```\n\nFinally, if the model is not on the Hugging Face Hub but on your local machine, you can pass the path to the folder that contains your model like below 👇\n\n```bash\n# Make sure your model is in the $volume directory\ndocker run --shm-size 1g -p 8080:80 -v $volume:/data  ghcr.io/huggingface/text-generation-inference:latest --model-id /data/<PATH-TO-FOLDER>\n```\n\nYou can refer to [transformers docs on custom models](https://huggingface.co/docs/transformers/main/en/custom_models) for more information.\n"
  },
  {
    "path": "docs/source/basic_tutorials/preparing_model.md",
    "content": "# Preparing the Model\n\nText Generation Inference improves the model in several aspects.\n\n## Quantization\n\nTGI supports [bits-and-bytes](https://github.com/TimDettmers/bitsandbytes#bitsandbytes), [GPT-Q](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [Marlin](https://github.com/IST-DASLab/marlin), [EETQ](https://github.com/NetEase-FuXi/EETQ), [EXL2](https://github.com/turboderp/exllamav2), and [fp8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) quantization. To speed up inference with quantization, simply set the `quantize` flag to `bitsandbytes`, `gptq`, `awq`, `marlin`, `exl2`, `eetq` or `fp8` depending on the quantization technique you wish to use. When using GPT-Q quantization, you need to point to one of the models [here](https://huggingface.co/models?search=gptq). Similarly, when using AWQ quantization, you need to point to one of [these models](https://huggingface.co/models?search=awq). To get more information about quantization, please refer to the [quantization guide](./../conceptual/quantization).\n\n\n## RoPE Scaling\n\nRoPE scaling can be used to increase the sequence length of the model during the inference time without necessarily fine-tuning it. To enable RoPE scaling, simply pass `--rope-scaling`, `--max-input-length` and `--rope-factor` flags when running through CLI. `--rope-scaling` can take the values `linear` or `dynamic`. If your model is not fine-tuned to a longer sequence length, use `dynamic`. `--rope-factor` is the ratio between the intended max sequence length and the model's original max sequence length. Make sure to pass `--max-input-length` to provide maximum input length for extension.\n\n<Tip>\n\nWe recommend using `dynamic` RoPE scaling.\n\n</Tip>\n\n## Safetensors\n\n[Safetensors](https://github.com/huggingface/safetensors) is a fast and safe persistence format for deep learning models, and is required for tensor parallelism. TGI supports `safetensors` model loading under the hood. By default, given a repository with `safetensors` and `pytorch` weights, TGI will always load `safetensors`. If there are only `pytorch` weights, TGI will convert them to `safetensors` format.\n"
  },
  {
    "path": "docs/source/basic_tutorials/safety.md",
    "content": "# Model safety.\n\n[Pytorch uses pickle](https://pytorch.org/docs/master/generated/torch.load.html) by default meaning that for quite a long while\n*every* model using that format is potentially executing unintended code while purely loading the model.\n\nThere is a big red warning on Python's page for pickle [link](https://docs.python.org/3/library/pickle.html) but for quite a while\nthis was ignored by the community. Now that AI/ML is getting used much more ubiquitously we need to switch away from this format.\n\nHuggingFace is leading the effort here by creating a new format which contains pure data ([safetensors](https://github.com/huggingface/safetensors))\nand moving slowly but surely all the libs to make use of it by default.\nThe move is intentionally slow so that breaking changes have as little impact as possible on users throughout.\n\n\n# TGI 2.0\n\nSince the release of TGI 2.0, we take the opportunity of this major version increase to break backward compatibility for these pytorch\nmodels (since they are a huge security risk for anyone deploying them).\n\n\nFrom now on, TGI will not convert automatically pickle files without having `--trust-remote-code` flag or `TRUST_REMOTE_CODE=true` in the environment variables.\nThis flag is already used for community defined inference code, and is therefore quite representative of the level of confidence you are giving the model providers.\n\n\nIf you want to use a model that uses pickle, but you still do not want to trust the authors entirely we recommend making a conversion on our space made for that.\n\nhttps://huggingface.co/spaces/safetensors/convert\n\nThis space will create a PR on the original model, which you can use directly regardless of merge status from the original authors. Just use\n```\ndocker run .... --revision refs/pr/#ID # Or use REVISION=refs/pr/#ID in the environment\n```\n"
  },
  {
    "path": "docs/source/basic_tutorials/train_medusa.md",
    "content": "# Train Medusa\n\nThis tutorial will show you how to train a Medusa model on a dataset of your choice. Please check out the [speculation documentation](../conceptual/speculation) for more information on how Medusa works and speculation in general.\n\n## What are the benefits of training a Medusa model?\n\nTraining Medusa heads can greatly improve the speed of generation. Medusa adds extra \"heads\" to LLMs to predict multiple future tokens simultaneously. When augmenting a model with Medusa, the original model stays untouched, and only the new heads are fine-tuned during training.\n\nOne of the most important things is to have a good dataset (with similar data to what will be used in production) because Medusa has a much higher hit-rate when the generation is in-domain.\n\nIf you train Medusa on a dataset that is very different from the one you will use in production then the model will not be able to predict the future tokens accurately and consequently the speedup will be minimal or non-existent.\n\n## Self-distillation (Generating data for training)\n\nThere are many methods for preparing data for training, but one of the easiest and most effective ways is to \"self-distill\" the data. 
This means that you can use the same model to generate the data that you will use to train the model.\n\nEssentially, you prompt the model with a similar input to what you will use in production and the model will generate the output.\n\nWe'll use this output to help train the medusa heads to predict the `n+1`, `n+2`, `n+3`, etc. tokens in the sequence.\n\n## Training\n\nThe original implementation of Medusa is available at [https://github.com/FasterDecoding/Medusa](https://github.com/FasterDecoding/Medusa) and we'll follow a very similar process to train the model as described on the original repository.\n\n### Getting Started\n\nThere are two methods for training the model:\n\n- `torchrun` that is a wrapper around `torch.distributed.launch`\n- a forked version of `axolotl` that supports Medusa\n\nIn this tutorial we'll use `torchrun` to train the model as it is the most straightforward way to train the model but similar steps can be followed to train the model using `axolotl` if you prefer.\n\n### Training with `torchrun`\n\n```bash\nmkdir medusa-training\ncd medusa-training\n\npyenv install 3.10\npyenv local 3.10\n\nuv venv -p 3.10\nsource .venv/bin/activate\n```\n\nNow let's clone the original `Medusa` repository and install the library.\n\n```bash\ngit clone https://github.com/FasterDecoding/Medusa.git\ncd Medusa\npip install -e .\n```\n\nNext we'll need some data to train on, we can use the `ShareGPT_Vicuna_unfiltered` dataset that is available on the Hugging Face Hub.\n\n```bash\napt install git-lfs\ngit lfs install\ngit clone https://huggingface.co/datasets/Aeala/ShareGPT_Vicuna_unfiltered\n```\n\nCurrently our directory structure looks like this:\n\n```bash\n.\n├── assets\n├── CITATION.cff\n├── create_data.py\n├── data_generation\n├── deepspeed.json\n├── last_run_prepared\n├── LICENSE\n├── llm_judge\n├── medusa\n├── medusa_llm.egg-info\n├── mistral.json\n├── notebooks\n├── pyproject.toml\n├── README.md\n├── ROADMAP.md\n├── scripts\n├── 
ShareGPT_Vicuna_unfiltered\n│   ├── README.md\n│   ├── ShareGPT_2023.05.04v0_Wasteland_Edition.json\n│   └── ShareGPT_V4.3_unfiltered_cleaned_split.json\n├── simple_gradio_interface.py\n├── tiny-llama.json\n└── vicuna_7b_qlora_stage1\n```\n\n## Start Training\n\nNow let's generate the data and start training the model. This process will take a while since we are generating data from the model.\n\nFirst make sure you have an instance of TGI running with the model you want to use for self-distillation.\n\n```bash\nmodel=HuggingFaceH4/zephyr-7b-beta\nvolume=/home/ubuntu/.cache/huggingface/hub/\n\ndocker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model\n```\n\nNow we can generate the data using the `create_data.py` script.\n\n```bash\npython create_data.py \\\n    --input-filename ShareGPT_Vicuna_unfiltered/ShareGPT_V4.3_unfiltered_cleaned_split.json \\\n    --output-filename zephyr_self_distill.json\n```\n\nAt this point our terminal should look like this:\n\n<div class=\"flex justify-center\">\n    <img\n        src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/medusa-train-large.gif\"\n        width=\"550\"\n    />\n</div>\n\n> Note: In the screenshot above we are only using the first 500 examples from the dataset to speed up the process, you should have a much larger dataset for training.\n\nNow we can finally get to the fun part and start training the model!\n\nUsing `torchrun` we can easily launch the `medusa` training script with the `zephyr_self_distill.json` configuration file.\n\n> NOTE: If you just self-distilled you may still have the model running, make sure to stop it before starting the training in order to allow all of the resources to be used for training.\n\n```bash\nWANDB_MODE=offline torchrun --nproc_per_node=4 medusa/train/train_legacy.py \\\n    --model_name_or_path HuggingFaceH4/zephyr-7b-beta \\\n    --data_path 
zephyr_self_distill.json \\\n    --bf16 True \\\n    --output_dir zephyr_out \\\n    --num_train_epochs 5 \\\n    --per_device_train_batch_size 4 \\\n    --per_device_eval_batch_size 4 \\\n    --gradient_accumulation_steps 4 \\\n    --evaluation_strategy \"no\" \\\n    --save_strategy \"no\" \\\n    --learning_rate 1e-3 \\\n    --weight_decay 0.0 \\\n    --warmup_ratio 0.1 \\\n    --lr_scheduler_type \"cosine\" \\\n    --logging_steps 1 \\\n    --tf32 True \\\n    --model_max_length 2048 \\\n    --lazy_preprocess True \\\n    --medusa_num_heads 3 \\\n    --medusa_num_layers 1 \\\n    --deepspeed deepspeed.json\n```\n\n<div class=\"flex justify-center\">\n    <img\n        src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/medusa-train-heads-large.gif\"\n        width=\"550\"\n    />\n</div>\n\nIf successful, you should see the similar output to the one below:\n\n```bash\nwandb: Run history:\nwandb:                    train/epoch ▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███\nwandb:              train/global_step ▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███\nwandb:            train/learning_rate ▅███▇▇▆▅▅▄▃▂▂▁▁▁\nwandb:                     train/loss ██▆▄▄▃▃▂▂▃▁▁▂▁▁▁\nwandb:             train/medusa0_loss ▆▆▇▆▆▅▄▅▃▃▃▃▂▂▂▂▂▃▂▂▂▁▁▁▂▁▁▁▁▁█▁▁▁▂▁▁▁▁▁\nwandb:             train/medusa0_top1 ▁▁▁▁▁▁▁▁▃▂▃▃▄▄▄▃▄▃▄▄▅▅▆▅▆▆▇▅▇▇▄▇█▇▅▇█▆▇▇\nwandb:             train/medusa1_loss ▇▇█▇▇▆▅▅▃▄▃▃▃▃▃▃▃▃▃▃▂▁▂▂▂▁▁▂▁▁▇▁▁▁▂▁▁▁▁▁\nwandb:             train/medusa1_top1 ▁▁▁▁▁▁▁▁▃▂▃▃▃▄▄▃▃▂▃▃▅▅▆▄█▆▇▅▇▇▅█▇▇▅▇█▆▆▇\nwandb:             train/medusa2_loss ▃▃▄▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁█▁▁▁▂▁▁▁▁▁\nwandb:             train/medusa2_top1 ▁▁▁▂▁▁▁▁▂▂▃▃▃▄▄▃▃▂▃▃▅▆▅▄█▆▆▅▆▆▄█▇▇▄▇█▆▆▇\nwandb:               train/total_flos ▁\nwandb:               train/train_loss ▁\nwandb:            train/train_runtime ▁\nwandb: train/train_samples_per_second ▁\nwandb:   train/train_steps_per_second ▁\nwandb:\nwandb: Run summary:\nwandb:                    train/epoch 2.0\nwandb:              
train/global_step 16\nwandb:            train/learning_rate 0.0\nwandb:                     train/loss 14.8906\nwandb:             train/medusa0_loss 4.25\nwandb:             train/medusa0_top1 0.28809\nwandb:             train/medusa1_loss 4.8125\nwandb:             train/medusa1_top1 0.22727\nwandb:             train/medusa2_loss 5.5\nwandb:             train/medusa2_top1 0.17293\nwandb:               train/total_flos 0.0\nwandb:               train/train_loss 23.98242\nwandb:            train/train_runtime 396.9266\nwandb: train/train_samples_per_second 2.519\nwandb:   train/train_steps_per_second 0.04\n```\n\nLast but most importantly, don't forget to push this model to the Hugging Face Hub so you can use it in your projects.\n\n```bash\npython -m medusa.hf_utils \\\n    --folder zephyr_out_medusa_mlp_zephyr-7b-beta_medusa_3_lr_0.001_layers_1 \\\n    --repo drbh/zephyr_medusa_demo\n```\n\nWoo, we've successfully trained a Medusa model and pushed it to the Hugging Face Hub! 🎉\n"
  },
  {
    "path": "docs/source/basic_tutorials/using_cli.md",
    "content": "# Using TGI CLI\n\nYou can use TGI command-line interface (CLI) to download weights, serve and quantize models, or get information on serving parameters. To install the CLI, please refer to [the installation section](../installation#install-cli).\n\n`text-generation-server` lets you download the model with `download-weights` command like below 👇\n\n```bash\ntext-generation-server download-weights MODEL_HUB_ID\n```\n\nYou can also use it to quantize models like below 👇\n\n```bash\ntext-generation-server quantize MODEL_HUB_ID OUTPUT_DIR\n```\n\nYou can use `text-generation-launcher` to serve models.\n\n```bash\ntext-generation-launcher --model-id MODEL_HUB_ID --port 8080\n```\n\nThere are many options and parameters you can pass to `text-generation-launcher`. The documentation for CLI is kept minimal and intended to rely on self-generating documentation, which can be found by running\n\n```bash\ntext-generation-launcher --help\n```\n\nYou can also find it hosted in this [Swagger UI](https://huggingface.github.io/text-generation-inference/).\n\nSame documentation can be found for `text-generation-server`.\n\n```bash\ntext-generation-server --help\n```\n"
  },
  {
    "path": "docs/source/basic_tutorials/using_guidance.md",
    "content": "# Guidance\n\nText Generation Inference (TGI) now supports [JSON and regex grammars](#grammar-and-constraints) and [tools and functions](#tools-and-functions) to help developers guide LLM responses to fit their needs.\n\nThese feature are available starting from version `1.4.3`. They are accessible via the [`huggingface_hub`](https://pypi.org/project/huggingface-hub/) library. The tool support is compatible with OpenAI's client libraries. The following guide will walk you through the new features and how to use them!\n\n_note: guidance is supported as grammar in the `/generate` endpoint and as tools in the `v1/chat/completions` endpoint._\n\n## How it works\n\nTGI leverages the [outlines](https://github.com/outlines-dev/outlines) library to efficiently parse and compile the grammatical structures and tools specified by users. This integration transforms the defined grammars into an intermediate representation that acts as a framework to guide and constrain content generation, ensuring that outputs adhere to the specified grammatical rules.\n\nIf you are interested in the technical details on how outlines is used in TGI, you can check out the [conceptual guidance documentation](../conceptual/guidance).\n\n## Table of Contents 📚\n\n### Grammar and Constraints\n\n- [The Grammar Parameter](#the-grammar-parameter): Shape your AI's responses with precision.\n- [Constrain with Pydantic](#constrain-with-pydantic): Define a grammar using Pydantic models.\n- [JSON Schema Integration](#json-schema-integration): Fine-grained control over your requests via JSON schema.\n- [Using the client](#using-the-client): Use TGI's client libraries to shape the AI's responses.\n\n### Tools and Functions\n\n- [The Tools Parameter](#the-tools-parameter): Enhance the AI's capabilities with predefined functions.\n- [Via the client](#text-generation-inference-client): Use TGI's client libraries to interact with the Messages API and Tool functions.\n- [OpenAI 
integration](#openai-integration): Use OpenAI's client libraries to interact with TGI's Messages API and Tool functions.\n\n## Grammar and Constraints 🛣️\n\n### The Grammar Parameter\n\nIn TGI `1.4.3`, we've introduced the grammar parameter, which allows you to specify the format of the response you want from the LLM.\n\nUsing curl, you can make a request to TGI's Messages API with the grammar parameter. This is the most primitive way to interact with the API and using [Pydantic](#constrain-with-pydantic) is recommended for ease of use and readability.\n\n```json\ncurl localhost:3000/generate \\\n    -X POST \\\n    -H 'Content-Type: application/json' \\\n    -d '{\n    \"inputs\": \"I saw a puppy a cat and a raccoon during my bike ride in the park\",\n    \"parameters\": {\n        \"repetition_penalty\": 1.3,\n        \"grammar\": {\n            \"type\": \"json\",\n            \"value\": {\n                \"properties\": {\n                    \"location\": {\n                        \"type\": \"string\"\n                    },\n                    \"activity\": {\n                        \"type\": \"string\"\n                    },\n                    \"animals_seen\": {\n                        \"type\": \"integer\",\n                        \"minimum\": 1,\n                        \"maximum\": 5\n                    },\n                    \"animals\": {\n                        \"type\": \"array\",\n                        \"items\": {\n                            \"type\": \"string\"\n                        }\n                    }\n                },\n                \"required\": [\"location\", \"activity\", \"animals_seen\", \"animals\"]\n            }\n        }\n    }\n}'\n// {\"generated_text\":\"{ \\n\\n\\\"activity\\\": \\\"biking\\\",\\n\\\"animals\\\": [\\\"puppy\\\",\\\"cat\\\",\\\"raccoon\\\"],\\n\\\"animals_seen\\\": 3,\\n\\\"location\\\": \\\"park\\\"\\n}\"}\n\n```\n\n### Hugging Face Hub Python Library\n\nThe Hugging Face Hub Python 
library provides a client that makes it easy to interact with the Messages API. Here's an example of how to use the client to send a request with a grammar parameter.\n\n```python\nfrom huggingface_hub import InferenceClient\n\nclient = InferenceClient(\"http://localhost:3000\")\n\nschema = {\n    \"properties\": {\n        \"location\": {\"title\": \"Location\", \"type\": \"string\"},\n        \"activity\": {\"title\": \"Activity\", \"type\": \"string\"},\n        \"animals_seen\": {\n            \"maximum\": 5,\n            \"minimum\": 1,\n            \"title\": \"Animals Seen\",\n            \"type\": \"integer\",\n        },\n        \"animals\": {\"items\": {\"type\": \"string\"}, \"title\": \"Animals\", \"type\": \"array\"},\n    },\n    \"required\": [\"location\", \"activity\", \"animals_seen\", \"animals\"],\n    \"title\": \"Animals\",\n    \"type\": \"object\",\n}\n\nuser_input = \"I saw a puppy a cat and a raccoon during my bike ride in the park\"\nresp = client.text_generation(\n    f\"convert to JSON: 'f{user_input}'. please use the following schema: {schema}\",\n    max_new_tokens=100,\n    seed=42,\n    grammar={\"type\": \"json\", \"value\": schema},\n)\n\nprint(resp)\n# { \"activity\": \"bike ride\", \"animals\": [\"puppy\", \"cat\", \"raccoon\"], \"animals_seen\": 3, \"location\": \"park\" }\n\n```\n\nA grammar can be defined using Pydantic models, JSON schemas, or regular expressions. The LLM will then generate a response that conforms to the specified grammar.\n\n> Note: A grammar must compile to an intermediate representation to constrain the output. Grammar compilation is computationally expensive and may take a few seconds to complete on the first request. 
Subsequent requests will use the cached grammar and will be much faster.\n\n### Constrain with Pydantic\n\nUsing Pydantic models we can define a similar grammar as the previous example in a shorter and more readable way.\n\n```python\nfrom huggingface_hub import InferenceClient\nfrom pydantic import BaseModel, conint\nfrom typing import List\n\n\nclass Animals(BaseModel):\n    location: str\n    activity: str\n    animals_seen: conint(ge=1, le=5)  # Constrained integer type\n    animals: List[str]\n\n\nclient = InferenceClient(\"http://localhost:3000\")\n\nuser_input = \"I saw a puppy a cat and a raccoon during my bike ride in the park\"\nresp = client.text_generation(\n    f\"convert to JSON: 'f{user_input}'. please use the following schema: {Animals.model_json_schema()}\",\n    max_new_tokens=100,\n    seed=42,\n    grammar={\"type\": \"json\", \"value\": Animals.model_json_schema()},\n)\n\nprint(resp)\n# { \"activity\": \"bike ride\", \"animals\": [\"puppy\", \"cat\", \"raccoon\"], \"animals_seen\": 3, \"location\": \"park\" }\n\n\n```\n\ndefining a grammar as regular expressions\n\n```python\nfrom huggingface_hub import InferenceClient\n\nclient = InferenceClient(\"http://localhost:3000\")\n\nsection_regex = \"(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\"\nregexp = f\"HELLO\\.{section_regex}\\.WORLD\\.{section_regex}\"\n\n# This is a more realistic example of an ip address regex\n# regexp = f\"{section_regex}\\.{section_regex}\\.{section_regex}\\.{section_regex}\"\n\n\nresp = client.text_generation(\n    f\"Whats Googles DNS? 
Please use the following regex: {regexp}\",\n    seed=42,\n    grammar={\n        \"type\": \"regex\",\n        \"value\": regexp,\n    },\n)\n\n\nprint(resp)\n# HELLO.255.WORLD.255\n\n```\n\n## Tools and Functions 🛠️\n\n### The Tools Parameter\n\nIn addition to the grammar parameter, we've also introduced a set of tools and functions to help you get the most out of the Messages API.\n\nTools are a set of user defined functions that can be used in tandem with the chat functionality to enhance the LLM's capabilities. Functions, similar to grammar are defined as JSON schema and can be passed as part of the parameters to the Messages API.\n\n```json\ncurl localhost:3000/v1/chat/completions \\\n    -X POST \\\n    -H 'Content-Type: application/json' \\\n    -d '{\n    \"model\": \"tgi\",\n    \"messages\": [\n        {\n            \"role\": \"user\",\n            \"content\": \"What is the weather like in New York?\"\n        }\n    ],\n    \"tools\": [\n        {\n            \"type\": \"function\",\n            \"function\": {\n                \"name\": \"get_current_weather\",\n                \"description\": \"Get the current weather\",\n                \"parameters\": {\n                    \"type\": \"object\",\n                    \"properties\": {\n                        \"location\": {\n                            \"type\": \"string\",\n                            \"description\": \"The city and state, e.g. San Francisco, CA\"\n                        },\n                        \"format\": {\n                            \"type\": \"string\",\n                            \"enum\": [\"celsius\", \"fahrenheit\"],\n                            \"description\": \"The temperature unit to use. 
Infer this from the users location.\"\n                        }\n                    },\n                    \"required\": [\"location\", \"format\"]\n                }\n            }\n        }\n    ],\n    \"tool_choice\": \"get_current_weather\"\n}'\n// {\"id\":\"\",\"object\":\"text_completion\",\"created\":1709051640,\"model\":\"HuggingFaceH4/zephyr-7b-beta\",\"system_fingerprint\":\"1.4.3-native\",\"choices\":[{\"index\":0,\"message\":{\"role\":\"assistant\",\"tool_calls\":{\"id\":0,\"type\":\"function\",\"function\":{\"description\":null,\"name\":\"tools\",\"parameters\":{\"format\":\"celsius\",\"location\":\"New York\"}}}},\"logprobs\":null,\"finish_reason\":\"eos_token\"}],\"usage\":{\"prompt_tokens\":157,\"completion_tokens\":19,\"total_tokens\":176}}\n```\n\n### Chat Completion with Tools\n\nGrammars are supported in the `/generate` endpoint, while tools are supported in the `/chat/completions` endpoint. Here's an example of how to use the client to send a request with a tool parameter.\n\n```python\nfrom huggingface_hub import InferenceClient\n\nclient = InferenceClient(\"http://localhost:3000\")\n\ntools = [\n    {\n        \"type\": \"function\",\n        \"function\": {\n            \"name\": \"get_current_weather\",\n            \"description\": \"Get the current weather\",\n            \"parameters\": {\n                \"type\": \"object\",\n                \"properties\": {\n                    \"location\": {\n                        \"type\": \"string\",\n                        \"description\": \"The city and state, e.g. San Francisco, CA\",\n                    },\n                    \"format\": {\n                        \"type\": \"string\",\n                        \"enum\": [\"celsius\", \"fahrenheit\"],\n                        \"description\": \"The temperature unit to use. 
Infer this from the users location.\",\n                    },\n                },\n                \"required\": [\"location\", \"format\"],\n            },\n        },\n    },\n    {\n        \"type\": \"function\",\n        \"function\": {\n            \"name\": \"get_n_day_weather_forecast\",\n            \"description\": \"Get an N-day weather forecast\",\n            \"parameters\": {\n                \"type\": \"object\",\n                \"properties\": {\n                    \"location\": {\n                        \"type\": \"string\",\n                        \"description\": \"The city and state, e.g. San Francisco, CA\",\n                    },\n                    \"format\": {\n                        \"type\": \"string\",\n                        \"enum\": [\"celsius\", \"fahrenheit\"],\n                        \"description\": \"The temperature unit to use. Infer this from the users location.\",\n                    },\n                    \"num_days\": {\n                        \"type\": \"integer\",\n                        \"description\": \"The number of days to forecast\",\n                    },\n                },\n                \"required\": [\"location\", \"format\", \"num_days\"],\n            },\n        },\n    },\n]\n\nchat = client.chat_completion(\n    messages=[\n        {\n            \"role\": \"system\",\n            \"content\": \"You're a helpful assistant! 
Answer the users question best you can.\",\n        },\n        {\n            \"role\": \"user\",\n            \"content\": \"What is the weather like in Brooklyn, New York?\",\n        },\n    ],\n    tools=tools,\n    seed=42,\n    max_tokens=100,\n)\n\nprint(chat.choices[0].message.tool_calls)\n# [ChatCompletionOutputToolCall(function=ChatCompletionOutputFunctionDefinition(arguments={'format': 'fahrenheit', 'location': 'Brooklyn, New York', 'num_days': 7}, name='get_n_day_weather_forecast', description=None), id=0, type='function')]\n\n```\n\n### OpenAI integration\n\nTGI exposes an OpenAI-compatible API, which means you can use OpenAI's client libraries to interact with TGI's Messages API and Tool functions.\n\n```python\nfrom openai import OpenAI\n\n# Initialize the client, pointing it to one of the available models\nclient = OpenAI(\n    base_url=\"http://localhost:3000/v1\",\n    api_key=\"_\",\n)\n\n# NOTE: tools defined above and removed for brevity\n\nchat_completion = client.chat.completions.create(\n    model=\"tgi\",\n    messages=[\n        {\n            \"role\": \"system\",\n            \"content\": \"Don't make assumptions about what values to plug into functions. 
Ask for clarification if a user request is ambiguous.\",\n        },\n        {\n            \"role\": \"user\",\n            \"content\": \"What's the weather like the next 3 days in San Francisco, CA?\",\n        },\n    ],\n    tools=tools,\n    tool_choice=\"auto\",  # tool selected by model\n    max_tokens=500,\n)\n\n\ncalled = chat_completion.choices[0].message.tool_calls\nprint(called)\n# {\n#     \"id\": 0,\n#     \"type\": \"function\",\n#     \"function\": {\n#         \"description\": None,\n#         \"name\": \"tools\",\n#         \"parameters\": {\n#             \"format\": \"celsius\",\n#             \"location\": \"San Francisco, CA\",\n#             \"num_days\": 3,\n#         },\n#     },\n# }\n```\n\n### Tool Choice Configuration\n\nWhen configuring how the model interacts with tools during a chat completion, there are several options for determining if or how a tool should be called. These options are controlled by the `tool_choice` parameter, which specifies the behavior of the model in relation to tool usage. The following modes are supported:\n\n1. **`auto`**:\n\n   - The model decides whether to call a tool or generate a response message based on the user's input.\n   - If tools are provided, this is the default mode.\n   - Example usage:\n     ```python\n     tool_choice=\"auto\"\n     ```\n\n2. **`none`**:\n\n   - The model will never call any tools and will only generate a response message.\n   - If no tools are provided, this is the default mode.\n   - Example usage:\n     ```python\n     tool_choice=\"none\"\n     ```\n\n3. **`required`**:\n\n   - The model must call one or more tools and will not generate a response message on its own.\n   - Example usage:\n     ```python\n     tool_choice=\"required\"\n     ```\n\n4. **Specific Tool Call by Function Name**:\n   - You can force the model to call a specific tool either by specifying the tool function directly or by using an object definition.\n   - Two ways to do this:\n     1. 
Provide the function name as a string:\n        ```python\n        tool_choice=\"get_current_weather\"\n        ```\n     2. Use the function object format:\n        ```python\n        tool_choice={\n          \"type\": \"function\",\n          \"function\": {\n              \"name\": \"get_current_weather\"\n          }\n        }\n        ```\n\nThese options allow flexibility when integrating tools with the chat completions endpoint. You can configure the model to either rely on tools automatically or force it to follow a predefined behavior, based on the needs of the task at hand.\n\n---\n\n| **Tool Choice Option**                | **Description**                                                                                                                 | **When to Use**                                                                        |\n| ------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------- |\n| `auto`                                | The model decides whether to call a tool or generate a message. This is the default if tools are provided.                      | Use when you want the model to decide when a tool is necessary.                        |\n| `none`                                | The model generates a message without calling any tools. This is the default if no tools are provided.                          | Use when you do not want the model to call any tools.                                  |\n| `required`                            | The model must call one or more tools and will not generate a message on its own.                                               | Use when a tool call is mandatory, and you do not want a regular message generated.    
|\n| Specific Tool Call (`name` or object) | Force the model to call a specific tool either by specifying its name (`tool_choice=\"get_current_weather\"`) or using an object. | Use when you want to restrict the model to calling a particular tool for the response. |\n"
  },
  {
    "path": "docs/source/basic_tutorials/visual_language_models.md",
    "content": "# Vision Language Model Inference in TGI\n\nVisual Language Model (VLM) are models that consume both image and text inputs to generate text.\n\nVLM's are trained on a combination of image and text data and can handle a wide range of tasks, such as image captioning, visual question answering, and visual dialog.\n\n> What distinguishes VLMs from other text and image models is their ability to handle long context and generate text that is coherent and relevant to the image even after multiple turns or in some cases, multiple images.\n\nBelow are couple of common use cases for vision language models:\n\n- **Image Captioning**: Given an image, generate a caption that describes the image.\n- **Visual Question Answering (VQA)**: Given an image and a question about the image, generate an answer to the question.\n- **Mulimodal Dialog**: Generate response to multiple turns of images and conversations.\n- **Image Information Retrieval**: Given an image, retrieve information from the image.\n\n## How to Use a Vision Language Model?\n\n### Hugging Face Hub Python Library\n\nTo infer with vision language models through Python, you can use the [`huggingface_hub`](https://pypi.org/project/huggingface-hub/) library. The `InferenceClient` class provides a simple way to interact with the [Inference API](https://huggingface.co/docs/api-inference/index). Images can be passed as URLs or base64-encoded strings. 
The `InferenceClient` will automatically detect the image format.\n\n```python\nfrom huggingface_hub import InferenceClient\n\nclient = InferenceClient(base_url=\"http://127.0.0.1:3000\")\nimage = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png\"\nprompt = f\"![]({image})What is this a picture of?\\n\\n\"\nfor token in client.text_generation(prompt, max_new_tokens=16, stream=True):\n    print(token)\n\n# This is a picture of an anthropomorphic rabbit in a space suit.\n```\n\n```python\nfrom huggingface_hub import InferenceClient\nimport base64\nimport requests\nimport io\n\nclient = InferenceClient(base_url=\"http://127.0.0.1:3000\")\n\n# read image from local file\nimage_path = \"rabbit.png\"\nwith open(image_path, \"rb\") as f:\n    image = base64.b64encode(f.read()).decode(\"utf-8\")\n\nimage = f\"data:image/png;base64,{image}\"\nprompt = f\"![]({image})What is this a picture of?\\n\\n\"\n\nfor token in client.text_generation(prompt, max_new_tokens=10, stream=True):\n    print(token)\n\n# This is a picture of an anthropomorphic rabbit in a space suit.\n```\n\nor via the `chat_completion` endpoint:\n\n```python\nfrom huggingface_hub import InferenceClient\n\nclient = InferenceClient(base_url=\"http://127.0.0.1:3000\")\n\nchat = client.chat_completion(\n    messages=[\n        {\n            \"role\": \"user\",\n            \"content\": [\n                {\"type\": \"text\", \"text\": \"Whats in this image?\"},\n                {\n                    \"type\": \"image_url\",\n                    \"image_url\": {\n                        \"url\": \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png\"\n                    },\n                },\n            ],\n        },\n    ],\n    seed=42,\n    max_tokens=100,\n)\n\nprint(chat)\n# ChatCompletionOutput(choices=[ChatCompletionOutputComplete(finish_reason='length', index=0, 
message=ChatCompletionOutputMessage(role='assistant', content=\" The image you've provided features an anthropomorphic rabbit in spacesuit attire. This rabbit is depicted with human-like posture and movement, standing on a rocky terrain with a vast, reddish-brown landscape in the background. The spacesuit is detailed with mission patches, circuitry, and a helmet that covers the rabbit's face and ear, with an illuminated red light on the chest area.\\n\\nThe artwork style is that of a\", name=None, tool_calls=None), logprobs=None)], created=1714589614, id='', model='llava-hf/llava-v1.6-mistral-7b-hf', object='text_completion', system_fingerprint='2.0.2-native', usage=ChatCompletionOutputUsage(completion_tokens=100, prompt_tokens=2943, total_tokens=3043))\n\n```\n\nor with OpenAI's [client library](https://github.com/openai/openai-python):\n\n```python\nfrom openai import OpenAI\n\n# init the client but point it to TGI\nclient = OpenAI(base_url=\"http://localhost:3000/v1\", api_key=\"-\")\n\nchat_completion = client.chat.completions.create(\n    model=\"tgi\",\n    messages=[\n        {\n            \"role\": \"user\",\n            \"content\": [\n                {\"type\": \"text\", \"text\": \"Whats in this image?\"},\n                {\n                    \"type\": \"image_url\",\n                    \"image_url\": {\n                        \"url\": \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png\"\n                    },\n                },\n            ],\n        },\n    ],\n    stream=False,\n)\n\nprint(chat_completion)\n# ChatCompletion(id='', choices=[Choice(finish_reason='eos_token', index=0, logprobs=None, message=ChatCompletionMessage(content=' The image depicts an anthropomorphic rabbit dressed in a space suit with gear that resembles NASA attire. The setting appears to be a solar eclipse with dramatic mountain peaks and a partial celestial body in the sky. 
The artwork is detailed and vivid, with a warm color palette and a sense of an adventurous bunny exploring or preparing for a journey beyond Earth. ', role='assistant', function_call=None, tool_calls=None))], created=1714589732, model='llava-hf/llava-v1.6-mistral-7b-hf', object='text_completion', system_fingerprint='2.0.2-native', usage=CompletionUsage(completion_tokens=84, prompt_tokens=2943, total_tokens=3027))\n```\n\n### Inference Through Sending `cURL` Requests\n\nTo use the `generate_stream` endpoint with curl, you can add the `-N` flag. This flag disables curl default buffering and shows data as it arrives from the server.\n\n```bash\ncurl -N 127.0.0.1:3000/generate_stream \\\n    -X POST \\\n    -d '{\"inputs\":\"![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)What is this a picture of?\\n\\n\",\"parameters\":{\"max_new_tokens\":16, \"seed\": 42}}' \\\n    -H 'Content-Type: application/json'\n\n# ...\n# data:{\"index\":16,\"token\":{\"id\":28723,\"text\":\".\",\"logprob\":-0.6196289,\"special\":false},\"generated_text\":\"This is a picture of an anthropomorphic rabbit in a space suit.\",\"details\":null}\n```\n\n### Inference Through JavaScript\n\nFirst, we need to install the `@huggingface/inference` library.\n\n```bash\nnpm install @huggingface/inference\n```\n\nWhether you use Inference Providers (our serverless API), or Inference Endpoints, you can call `InferenceClient`.\n\nWe can create a `InferenceClient` providing our endpoint URL and [Hugging Face access token](https://huggingface.co/settings/tokens).\n\n```js\nimport { InferenceClient } from \"@huggingface/inference\";\n\nconst client = new InferenceClient('hf_YOUR_TOKEN', { endpointUrl: 'https://YOUR_ENDPOINT.endpoints.huggingface.cloud' });\n\nconst prompt =\n  \"![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)What is this a picture of?\\n\\n\";\n\nconst stream = 
client.textGenerationStream({\n  inputs: prompt,\n  parameters: { max_new_tokens: 16, seed: 42 },\n});\nfor await (const r of stream) {\n  // yield the generated token\n  process.stdout.write(r.token.text);\n}\n\n// This is a picture of an anthropomorphic rabbit in a space suit.\n```\n\n## Combining Vision Language Models with Other Features\n\nVLMs in TGI have several advantages, for example these models can be used in tandem with other features for more complex tasks. For example, you can use VLMs with [Guided Generation](/docs/conceptual/guided-generation) to generate specific JSON data from an image.\n\n<div class=\"flex justify-center\">\n    <img\n        src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png\"\n        width=\"400\"\n    />\n</div>\n\nFor example we can extract information from the rabbit image and generate a JSON object with the location, activity, number of animals seen, and the animals seen. That would look like this:\n\n```json\n{\n  \"activity\": \"Standing\",\n  \"animals\": [\"Rabbit\"],\n  \"animals_seen\": 1,\n  \"location\": \"Rocky surface with mountains in the background and a red light on the rabbit's chest\"\n}\n```\n\nAll we need to do is provide a JSON schema to the VLM model and it will generate the JSON object for us.\n\n```bash\ncurl localhost:3000/generate \\\n    -X POST \\\n    -H 'Content-Type: application/json' \\\n    -d '{\n    \"inputs\":\"![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)What is this a picture of?\\n\\n\",\n    \"parameters\": {\n        \"max_new_tokens\": 100,\n        \"seed\": 42,\n        \"grammar\": {\n            \"type\": \"json\",\n            \"value\": {\n                \"properties\": {\n                    \"location\": {\n                        \"type\": \"string\"\n                    },\n                    \"activity\": {\n                        \"type\": \"string\"\n   
                 },\n                    \"animals_seen\": {\n                        \"type\": \"integer\",\n                        \"minimum\": 1,\n                        \"maximum\": 5\n                    },\n                    \"animals\": {\n                        \"type\": \"array\",\n                        \"items\": {\n                            \"type\": \"string\"\n                        }\n                    }\n                },\n                \"required\": [\"location\", \"activity\", \"animals_seen\", \"animals\"]\n            }\n        }\n    }\n}'\n\n# {\n#   \"generated_text\": \"{ \\\"activity\\\": \\\"Standing\\\", \\\"animals\\\": [ \\\"Rabbit\\\" ], \\\"animals_seen\\\": 1, \\\"location\\\": \\\"Rocky surface with mountains in the background and a red light on the rabbit's chest\\\" }\"\n# }\n```\n\nWant to learn more about how Vision Language Models work? Check out the [awesome blog post on the topic](https://huggingface.co/blog/vlms).\n"
  },
  {
    "path": "docs/source/conceptual/chunking.md",
    "content": "# TGI v3 overview\n## Summary\n\n\nPerformance leap: TGI processes 3x more tokens, 13x faster than vLLM on long prompts. Zero config!\n\n### 3x more tokens.\nBy reducing our memory footprint, we’re able to ingest many more tokens and more dynamically than before. A single L4 (24GB) can handle 30k tokens on llama 3.1-8B, while vLLM gets barely 10k. A lot of work went into reducing the footprint of the runtime and its effects are best seen on smaller constrained environments.\n\n### 13x faster\nOn long prompts (200k+ tokens) conversation replies take 27.5s in vLLM, while they take only 2s in TGI. How so? We keep the initial conversation around, so when a new reply comes in, we can answer almost instantly. The overhead of the lookup is ~5us. Thanks @Daniël de Kok for the beast data structure.\n\n### Zero config\nThat’s it. Remove all the flags you are using and you’re likely to get the best performance. By evaluating the hardware and model, TGI carefully selects automatic values to give the best performance. In production, we don’t have any flags anymore in our deployments. We kept all existing flags around, as they may come in handy in niche scenarios.\n\n\n\n## Benchmarks\n\n### Methodology\n\nTo ensure accurate and reliable results, we employed a robust benchmarking protocol that addresses common pitfalls in performance evaluation. Specifically:\n\n1.  **Consistent Code**: We used the same codebase to run against different engines, ensuring that any performance differences are attributable to the LLM itself, rather than variations in the testing framework.\n2.  **Request-Based Measurement**: Instead of measuring Requests Per Second (RPS) by sending as many requests as possible, we opted for a more consistent approach, sending a fixed number of requests and measuring the time it takes for the server to complete all of them. This method avoids boundary effects and provides a more accurate representation of performance.\n3.  
**Realistic Combinations**: We selected realistic combinations of LLMs and hardware configurations, so we used 8xH100 for a 70B, not an 8B, which would be a waste of money.\n4.  **Realistic scenarios**: We benchmarked engines with prefix caching on, so we are reporting the results of the 2nd run, not the first one.\nDuring the first run of a benchmark, every request is new, so prefix caching is not working, masking the real-world benefits of using it.\n\nNote: Boundary effect is when the benchmarks are flaky because their results depend on fine details of the engine being benchmarked.\nFor instance, consider a system ingesting a constant 10RPS that receives a single final request 0.1s before the end of the benchmark, and that single request takes a full 10s to process. Then a benchmark taking 30s would measure 7.5RPS instead of the expected 10, because that single query isn't being parallelized with others. Another very slightly slower engine would receive that request at +0.1s which would get discarded by the benchmark and therefore measure the slower system as being faster.\n\nFor more details on benchmarking in general we recommend the documentation of k6: https://grafana.com/docs/k6/latest/.\n\n### Scenarios\n\nWe selected a handful of scenarios to simplify the picture; they seem to accurately reflect a larger trend.\n\n1. **Small scenario**: This scenario consists of the first 200 requests from the orca datasets being prompted to the model. The 200 requests total 8k tokens together and are representative of conversation starters. Prefix caching has very limited impact in that scenario and we feel it's a relatively balanced benchmark for simple use cases.\n2. **Long scenario**: This scenario consists of 20 requests totalling 200k prompt tokens which are essentially asking for summaries of large chunks of text. 
In practical scenarios this is really useful when you are feeding large chunks of code, large chunks of business data or documents repeatedly and asking simple questions about them (summarization, classification, or where to find some data). This scenario is the one closest to what a lot of professional use cases seem to be doing by including a lot of information in the prompt itself. Those very long conversations are the ones that benefit the most from our recent changes since we are enabling ever larger prompts and ever faster caching.\n\n   ### Hardware\n\n   1. `L4`: This is a single L4 (24GB) which represents small or even home compute capabilities. We tested `meta-llama/Meta-Llama-3.1-8B-Instruct` on it.\n   2. `4xL4`: This is a more beefy deployment usually used for either very large requests deployments for 8B models (the ones under test) or it can also easily handle all 30GB models. For this benchmark we tested `meta-llama/Meta-Llama-3.1-8B-Instruct`\n   3. `8xH100`: This is one of the beefiest deployments possible. We tested  `meta-llama/Meta-Llama-3.1-70B-Instruct` as it's the most representative model of this size. Llama 3.3 wasn't released at the time of benchmarking (it's the exact same model so it doesn't make any difference).\n\n\n### Replicating the results\n\n\n\nThe commands to run the benchmarks are as follows:\n\n1. Prepare the datasets:\n\n```bash\ncd text-generation-inference/load_tests\nmake prepare_orca\npython long.py\n```\n\n2. Launch the engine:\n\nTGI: `text-generation-launcher --model-id $MODEL_ID --num-shard $N --port 8000` (or docker variant)\nvLLM: `vllm serve $MODEL_ID --tensor-parallel $N --enable-prefix-caching` (or docker variant)\n\n3. 
Start scenario:\nSmall: `MODEL_ID=$MODEL_ID  HOST=localhost:8000 k6 run load_tests/common.js`\nLong:  `MODEL_ID=$MODEL_ID  HOST=localhost:8000 k6 run load_tests/long.js`\n\n\n### Results\n\n![benchmarks_v3](https://raw.githubusercontent.com/huggingface/text-generation-inference/refs/heads/main/assets/v3_benchmarks.png)\n\nOur benchmarking results show significant performance gains, with a 13x speedup over vLLM with prefix caching, and up to 30x speedup without prefix caching. These results are consistent with our production data and demonstrate the effectiveness of our optimized LLM architecture.\n\nRaw results\n\n|   |   |   |   |   |\n|---|---|---|---|---|\n|2nd run ||**TGI v3** (time in s)|**vLLM** (s)|**Amount of req**|\n|**Llama 3.1 8b**|Small test - L4 - 8B|17.5|19.9|200|\n|**Llama 3.1 8b**|Long test* - L4 - 8B|53|57|10|\n|**Llama 3.1 8b**|Small test - 4xL4 - 8B|4.8|6|200|\n|**Llama 3.1 8b**|Long test - 4xL4 - 8B|3.2|12.5|20|\n|**Llama 3.1 70b**|Small test - 8XH100 - 70B|6.2|7.4|200|\n|**Llama 3.1 70b**|Long test - 8H100 - 70B|2|27.5|20|\n||||||\n|1st run ||TGI (s)|vLLM (s)|Amount of req|\n|**Llama 3.1 8b**|Small test - L4|19.9|19.9|200|\n|**Llama 3.1 8b**|Long test (10) - L4|49.8|55|10|\n|**Llama 3.1 8b**|Small test - 4xL4|13|12.6|200|\n|**Llama 3.1 8b**|Long test - 4xL4|47|50.3|20|\n|**Llama 3.1 70b**|Small test - 8XH100|7.5|7.6|200|\n|**Llama 3.1 70b**|Long test - 8H100|12.1|28.3|20|\n\n\n### Caveats and Limitations\n\nWhile our results are promising, there are some caveats to consider:\n\n1. **Constrained kv-cache**: If a deployment lacks kv-cache space, that means that many queries will require the same slots of kv-cache, leading to contention in the kv-cache. You can limit that effect by limiting `--max-total-tokens` to reduce individual queries impact. You can also use more GPUs or larger GPUs in order to increase the size of the kv-cache.\n2.  
**Replication**: In scenarios where multiple replicas are behind a single endpoint, there's no reason for every query from a particular user to hit the same replica, therefore the cache will not be present, meaning no speed benefit. You can use sticky sessions load balancing to force every user to send their requests on the same replica. Do not apply this blindly, it's possible this may not be necessary at all.\n\n## Technical Insights\n\nOur performance gains can be attributed to several key factors:\n\n1.  **New Kernels**: Our custom kernels, including `flashinfer` and `flashdecoding`, offer improved performance at large prompt lengths and enable more efficient scheduling.\n2.  **Prefix Caching**: Our optimized prefix caching structure allows for fast query matching, even for long prompts. The overhead is roughly 6us.\n3.  **Chunking Code**: Our chunking code enables finer control over compute resources, ensuring optimal performance and reduced VRAM usage.\n4.  **Kernel Optimizations**: We've implemented various other kernel optimizations, including better kernel selection. Notably we've implemented several small kernels involved in the queries bookkeeping which are particularly efficient on small models. Every kernel launch has an overhead of several milliseconds so fusing them together increases a lot performance when this bookkeeping is important relative to the raw model calculations. This happens typically on oversized compute for a particular model and particularly small models.\n5. **VRAM efficiency**: In the realm of very large requests (100k+ tokens) there are a lot of places which start becoming big memory consumers. We've hunted the biggest ones and found ways to reduce/reuse or delete them. The biggest culprit probably is `logits` calculation. Logits for llama 3.1-8b take 25.6GB (=100k tokens * 128k vocabulary * 2(f16)) which is more than the full model which is 16GB. 
The thing is that in general we do not need the logits of every prompt token, so we simply removed them and, by default, no longer allow users to request them. We think this is ok since they are mostly used by researchers. You can enable your deployments to have them again by using the `--enable-prefill-logprobs` flag, but you will experience a reduced maximum prompt size.\n\n## Future Directions\n\nWhile we've made significant progress, there are still opportunities for improvement:\n\n1.  **Special models**: All LLMs come with the aforementioned improvements. Some specific sets of features might not (some quantizations, speculation or VLMs for instance are harder to optimize for with the same level of detail).\n2.  **KV-Cache Long-Term Retention**: Addressing KV-cache long-term retention is a challenge. There are several solutions envisioned, like shared KV-cache (like redis or memcached) solutions or innovative storage approaches. It is an area of ongoing research of ours.\n3.  **Multimodal models**: We are also investigating quite a lot of other kinds of models, like audio-to-audio, image/video generation, and other hybrids, where we see a lot of potential in applying the same principles we've applied in TGI to maximize performance.\n\nBy sharing our benchmarking methodology, results, and technical insights, we aim to contribute to the ongoing development of more efficient and effective LLMs.\n"
  },
  {
    "path": "docs/source/conceptual/external.md",
    "content": "# External Resources\n\n- Adyen wrote a detailed article about the interplay between TGI's main components: router and server.\n[LLM inference at scale with TGI (Martin Iglesias Goyanes - Adyen, 2024)](https://www.adyen.com/knowledge-hub/llm-inference-at-scale-with-tgi)\n"
  },
  {
    "path": "docs/source/conceptual/flash_attention.md",
    "content": "# Flash Attention\n\nScaling the transformer architecture is heavily bottlenecked by the self-attention mechanism, which has quadratic time and memory complexity. Recent developments in accelerator hardware mainly focus on enhancing compute capacities and not memory and transferring data between hardware. This results in attention operation having a memory bottleneck. **Flash Attention** is an attention algorithm used to reduce this problem and scale transformer-based models more efficiently, enabling faster training and inference.\n\nStandard attention mechanism uses High Bandwidth Memory (HBM) to store, read and write keys, queries and values. HBM is large in memory, but slow in processing, meanwhile SRAM is smaller in memory, but faster in operations. In the standard attention implementation, the cost of loading and writing keys, queries, and values from HBM is high. It loads keys, queries, and values from HBM to GPU on-chip SRAM, performs a single step of the attention mechanism, writes it back to HBM, and repeats this for every single attention step. Instead, Flash Attention loads keys, queries, and values once, fuses the operations of the attention mechanism, and writes them back.\n\n![Flash Attention](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/flash-attn.png)\n\nIt is implemented for supported models. You can check out the complete list of models that support Flash Attention [here](https://github.com/huggingface/text-generation-inference/tree/main/server/text_generation_server/models), for models with flash prefix.\n\nYou can learn more about Flash Attention by reading the paper in this [link](https://arxiv.org/abs/2205.14135).\n"
  },
  {
    "path": "docs/source/conceptual/guidance.md",
    "content": "# Guidance\n\n## What is Guidance?\n\nGuidance is a feature that allows users to constrain the generation of a large language model with a specified grammar. This feature is particularly useful when you want to generate text that follows a specific structure or uses a specific set of words or produce output in a specific format. A prominent example is JSON grammar, where the model is forced to output valid JSON.\n\n## How is it used?\n\nGuidance can be implemented in many ways and the community is always finding new ways to use it. Here are some examples of how you can use guidance:\n\nTechnically, guidance can be used to generate:\n\n- a specific JSON object\n- a function signature\n- typed output like a list of integers\n\nHowever these use cases can span a wide range of applications, such as:\n\n- extracting structured data from unstructured text\n- summarizing text into a specific format\n- limit output to specific classes of words (act as a LLM powered classifier)\n- generate the input to specific APIs or services\n- provide reliable and consistent output for downstream tasks\n- extract data from multimodal inputs\n\n## How it works?\n\nDiving into the details, guidance is enabled by including a grammar with a generation request that is compiled, and used to modify the chosen tokens.\n\nThis process can be broken down into the following steps:\n\n1. A request is sent to the backend, it is processed and placed in batch. Processing includes compiling the grammar into a finite state machine and a grammar state.\n\n<div class=\"flex justify-center\">\n    <img\n        class=\"block dark:hidden\"\n        src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/request-to-batch.gif\"\n    />\n    <img\n        class=\"hidden dark:block\"\n        src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/request-to-batch-dark.gif\"\n    />\n</div>\n\n2. 
The model does a forward pass over the batch. This returns probabilities for each token in the vocabulary for each request in the batch.\n\n3. The process of choosing one of those tokens is called `sampling`. The model samples from the distribution of probabilities to choose the next token. In TGI all of the steps before sampling are called `processor`. Grammars are applied as a processor that masks out tokens that are not allowed by the grammar.\n\n<div class=\"flex justify-center\">\n    <img\n        class=\"block dark:hidden\"\n        src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/logit-grammar-mask.gif\"\n    />\n    <img\n        class=\"hidden dark:block\"\n        src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/logit-grammar-mask-dark.gif\"\n    />\n</div>\n\n4. The grammar mask is applied and the model samples from the remaining tokens. Once a token is chosen, we update the grammar state with the new token, to prepare it for the next pass.\n\n<div class=\"flex justify-center\">\n    <img\n        class=\"block dark:hidden\"\n        src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/sample-logits.gif\"\n    />\n    <img\n        class=\"hidden dark:block\"\n        src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/sample-logits-dark.gif\"\n    />\n</div>\n\n## How to use Guidance?\n\nThere are two main ways to use guidance; you can either use the `/generate` endpoint with a grammar or use the `/chat/completion` endpoint with tools.\n\nUnder the hood tools are a special case of grammars that allows the model to choose one or none of the provided tools.\n\nPlease refer to [using guidance](../basic_tutorials/using_guidance) for more examples and details on how to use guidance in Python, JavaScript, and cURL.\n\n### Getting the most out of guidance\n\nDepending on how you are using guidance, you may want 
to make use of different features. Here are some tips to get the most out of guidance:\n\n- If you are using the `/generate` endpoint with a `grammar`, it is recommended to include the grammar in the prompt prefixed by something like `Please use the following JSON schema to generate the output:`. This will help the model understand the context of the grammar and generate the output accordingly.\n- If you are getting a response with many repeated tokens, please use the `frequency_penalty` or `repetition_penalty` to reduce the number of repeated tokens in the output.\n"
  },
  {
    "path": "docs/source/conceptual/lora.md",
    "content": "# LoRA (Low-Rank Adaptation)\n\n## What is LoRA?\n\nLoRA is a technique that allows for efficient fine-tuning of a model while only updating a small portion of the model's weights. This is useful when you have a large model that has been pre-trained on a large dataset, but you want to fine-tune it on a smaller dataset or for a specific task.\n\nLoRA works by adding a small number of additional weights to the model, which are used to adapt the model to the new dataset or task. These additional weights are learned during the fine-tuning process, while the rest of the model's weights are kept fixed.\n\n## How is it used?\n\nLoRA can be used in many ways and the community is always finding new ways to use it. Here are some examples of how you can use LoRA:\n\nTechnically, LoRA can be used to fine-tune a large language model on a small dataset. However, these use cases can span a wide range of applications, such as:\n\n- fine-tuning a language model on a small dataset\n- fine-tuning a language model on a domain-specific dataset\n- fine-tuning a language model on a dataset with limited labels\n\n## Optimizing Inference with LoRA\n\nLoRAs can be used during inference by multiplying the adapter weights with the model weights at each specified layer. This process can be computationally expensive, but due to awesome work by [punica-ai](https://github.com/punica-ai/punica) and the [lorax](https://github.com/predibase/lorax) team, optimized kernels and frameworks have been developed to make this process more efficient. TGI leverages these optimizations in order to provide fast and efficient inference with multiple LoRA models.\n\n## Serving multiple LoRA adapters with TGI\n\nOnce a LoRA model has been trained, it can be used to generate text or perform other tasks just like a regular language model. 
However, because the model has been fine-tuned on a specific dataset, it may perform better on that dataset than a model that has not been fine-tuned.\n\nIn practice it's often useful to have multiple LoRA models, each fine-tuned on a different dataset or for a different task. This allows you to use the model that is best suited for a particular task or dataset.\n\nText Generation Inference (TGI) now supports loading multiple LoRA models at startup that can be used in generation requests. This feature is available starting from version `~2.0.6` and is compatible with LoRA models trained using the `peft` library.\n\n### Specifying LoRA models\n\nTo use LoRA in TGI, when starting the server, you can specify the list of LoRA models to load using the `LORA_ADAPTERS` environment variable. For example:\n\n```bash\nLORA_ADAPTERS=predibase/customer_support,predibase/dbpedia\n```\n\nTo specify a model revision, use `adapter_id@revision`, as follows:\n\n```bash\nLORA_ADAPTERS=predibase/customer_support@main,predibase/dbpedia@rev2\n```\n\nTo use a locally stored LoRA adapter, use `adapter-name=/path/to/adapter`, as seen below. When you want to use this adapter, set `\"parameters\": {\"adapter_id\": \"adapter-name\"}`\n\n```bash\nLORA_ADAPTERS=myadapter=/some/path/to/adapter,myadapter2=/another/path/to/adapter\n```\n\nNote that it's possible to mix adapter_ids with adapter_id=adapter_path, e.g.\n\n```bash\nLORA_ADAPTERS=predibase/dbpedia,myadapter=/path/to/dir/\n```\n\nIn the server logs, you will see the following message:\n\n```txt\nLoading adapter weights into model: predibase/customer_support\nLoading adapter weights into model: predibase/dbpedia\n```\n\n## Generate text\n\nYou can then use these models in generation requests by specifying the `adapter_id` parameter in the request payload. 
For example:\n\n```bash\ncurl 127.0.0.1:3000/generate \\\n    -X POST \\\n    -H 'Content-Type: application/json' \\\n    -d '{\n  \"inputs\": \"Hello who are you?\",\n  \"parameters\": {\n    \"max_new_tokens\": 40,\n    \"adapter_id\": \"predibase/customer_support\"\n  }\n}'\n```\n\nIf you are using a LoRA adapter stored locally that was set in the following manner: `LORA_ADAPTERS=myadapter=/some/path/to/adapter`, here is an example payload:\n\n```bash\ncurl 127.0.0.1:3000/generate \\\n    -X POST \\\n    -H 'Content-Type: application/json' \\\n    -d '{\n  \"inputs\": \"Hello who are you?\",\n  \"parameters\": {\n    \"max_new_tokens\": 40,\n    \"adapter_id\": \"myadapter\"\n  }\n}'\n```\n\n\n> **Note:** The LoRA feature is new and still being improved. If you encounter any issues or have any feedback, please let us know by opening an issue on the [GitHub repository](https://github.com/huggingface/text-generation-inference/issues/new/choose). Additionally, documentation and an improved client library will be published soon.\n\nAn updated tutorial with detailed examples will be published soon. Stay tuned!\n"
  },
  {
    "path": "docs/source/conceptual/paged_attention.md",
    "content": "# PagedAttention\n\nLLMs struggle with memory limitations during generation. In the decoding part of generation, all the attention keys and values generated for previous tokens are stored in GPU memory for reuse. This is called _KV cache_, and it may take up a large amount of memory for large models and long sequences.\n\nPagedAttention attempts to optimize memory use by partitioning the KV cache into blocks that are accessed through a lookup table. Thus, the KV cache does not need to be stored in contiguous memory, and blocks are allocated as needed. The memory efficiency can increase GPU utilization on memory-bound workloads, so more inference batches can be supported.\n\nThe use of a lookup table to access the memory blocks can also help with KV sharing across multiple generations. This is helpful for techniques such as _parallel sampling_, where multiple outputs are generated simultaneously for the same prompt. In this case, the cached KV blocks can be shared among the generations.\n\nTGI's PagedAttention implementation leverages the custom cuda kernels developed by the [vLLM Project](https://github.com/vllm-project/vllm). You can learn more about this technique in the [project's page](https://vllm.ai/).\n"
  },
  {
    "path": "docs/source/conceptual/quantization.md",
    "content": "# Quantization\n\nTGI offers many quantization schemes to run LLMs effectively and fast based on your use-case. TGI supports GPTQ, AWQ, bits-and-bytes, EETQ, Marlin, EXL2 and fp8 quantization.\n\nTo leverage GPTQ, AWQ, Marlin and EXL2 quants, you must provide pre-quantized weights. Whereas for bits-and-bytes, EETQ and fp8, weights are quantized by TGI on the fly.\n\nWe recommend using the official quantization scripts for creating your quants:\n1. [AWQ](https://github.com/casper-hansen/AutoAWQ/blob/main/examples/quantize.py)\n2. [GPTQ/ Marlin](https://github.com/AutoGPTQ/AutoGPTQ/blob/main/examples/quantization/basic_usage.py)\n3. [EXL2](https://github.com/turboderp/exllamav2/blob/master/doc/convert.md)\n\nFor on-the-fly quantization you simply need to pass one of the supported quantization types and TGI takes care of the rest.\n\n## Quantization with bitsandbytes, EETQ & fp8\n\nbitsandbytes is a library used to apply 8-bit and 4-bit quantization to models. Unlike GPTQ quantization, bitsandbytes doesn't require a calibration dataset or any post-processing – weights are automatically quantized on load. However, inference with bitsandbytes is slower than GPTQ or FP16 precision.\n\n8-bit quantization enables multi-billion parameter scale models to fit in smaller hardware without degrading performance too much.\nIn TGI, you can use 8-bit quantization by adding `--quantize bitsandbytes` like below 👇\n\n```bash\ndocker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.5 --model-id $model --quantize bitsandbytes\n```\n\n4-bit quantization is also possible with bitsandbytes. You can choose one of the following 4-bit data types: 4-bit float (`fp4`), or 4-bit `NormalFloat` (`nf4`). 
These data types were introduced in the context of parameter-efficient fine-tuning, but you can apply them for inference by automatically converting the model weights on load.\n\nIn TGI, you can use 4-bit quantization by adding `--quantize bitsandbytes-nf4` or `--quantize bitsandbytes-fp4` like below 👇\n\n```bash\ndocker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.5 --model-id $model --quantize bitsandbytes-nf4\n```\n\nYou can get more information about 8-bit quantization by reading this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration), and 4-bit quantization by reading [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes).\n\nSimilarly, you can pass `--quantize eetq` or `--quantize fp8` for the respective quantization schemes.\n\nIn addition to this, TGI allows creating GPTQ quants directly by passing the model weights and a calibration dataset.\n\n## Quantization with GPTQ\n\nGPTQ is a post-training quantization method to make the model smaller. It quantizes each layer by finding a compressed version of its weights that yields a minimum mean squared error, as shown below 👇\n\nGiven a layer \\\\(l\\\\) with weight matrix \\\\(W_{l}\\\\) and layer input \\\\(X_{l}\\\\), find quantized weight \\\\(\\\\hat{W}_{l}\\\\):\n\n$$({\\hat{W}_{l}}^{*} = argmin_{\\hat{W}_{l}} ||W_{l}X_{l}-\\hat{W}_{l}X_{l}||^{2}_{2})$$\n\n\nTGI allows you to either run an already GPTQ quantized model (see available models [here](https://huggingface.co/models?search=gptq)) or quantize a model of your choice using the quantization script. You can run a quantized model by simply passing `--quantize` like below 👇\n\n```bash\ndocker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:3.3.5 --model-id $model --quantize gptq\n```\n\nNote that TGI's GPTQ implementation doesn't use [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) under the hood. 
However, models quantized using AutoGPTQ or Optimum can still be served by TGI.\n\nTo quantize a given model using GPTQ with a calibration dataset, simply run\n\n```bash\ntext-generation-server quantize tiiuae/falcon-40b /data/falcon-40b-gptq\n# Add --upload-to-model-id MYUSERNAME/falcon-40b to push the created model to the hub directly\n```\n\nThis will create a new directory with the quantized files which you can use with,\n\n```bash\ntext-generation-launcher --model-id /data/falcon-40b-gptq/ --sharded true --num-shard 2 --quantize gptq\n```\n\nYou can learn more about the quantization options by running `text-generation-server quantize --help`.\n\nIf you wish to do more with GPTQ models (e.g. train an adapter on top), you can read about transformers GPTQ integration [here](https://huggingface.co/blog/gptq-integration).\nYou can learn more about GPTQ from the [paper](https://arxiv.org/pdf/2210.17323.pdf).\n"
  },
  {
    "path": "docs/source/conceptual/safetensors.md",
    "content": "# Safetensors\n\nSafetensors is a model serialization format for deep learning models. It is [faster](https://huggingface.co/docs/safetensors/speed) and safer compared to other serialization formats like pickle (which is used under the hood in many deep learning libraries).\n\nTGI depends on safetensors format mainly to enable [tensor parallelism sharding](./tensor_parallelism). For a given model repository during serving, TGI looks for safetensors weights. If there are no safetensors weights, TGI converts the PyTorch weights to safetensors format.\n\nYou can learn more about safetensors by reading the [safetensors documentation](https://huggingface.co/docs/safetensors/index).\n"
  },
  {
    "path": "docs/source/conceptual/speculation.md",
    "content": "## Speculation\n\n\nSpeculative decoding, assisted generation, Medusa, and others are a few different names for the same idea.\nThe idea is to generate tokens *before* the large model actually runs, and only *check* if those tokens were valid.\n\nSo you are making *more* computations on your LLM, but if you are correct you produce 1, 2, 3 etc.. tokens on a single LLM pass. Since LLMs are usually memory bound (and not compute bound), provided your guesses are correct enough, this is a 2-3x faster inference (It can be much more for code oriented tasks for instance).\n\nYou can check a more [detailed explanation](https://huggingface.co/blog/assisted-generation).\n\nText-generation inference supports 2 main speculative methods:\n\n- Medusa\n- N-gram\n\n\n### Medusa\n\n\nMedusa is a [simple method](https://arxiv.org/abs/2401.10774) to create many tokens in a single pass using fine-tuned LM heads in addition to your existing models.\n\n\nYou can check a few existing  fine-tunes for popular models:\n\n- [text-generation-inference/gemma-7b-it-medusa](https://huggingface.co/text-generation-inference/gemma-7b-it-medusa)\n- [text-generation-inference/Mixtral-8x7B-Instruct-v0.1-medusa](https://huggingface.co/text-generation-inference/Mixtral-8x7B-Instruct-v0.1-medusa)\n- [text-generation-inference/Mistral-7B-Instruct-v0.2-medusa](https://huggingface.co/text-generation-inference/Mistral-7B-Instruct-v0.2-medusa)\n\n\nIn order to create your own medusa heads for your own finetune, you should check out the original medusa repo. 
Read more in [Train Medusa](../basic_tutorials/train_medusa#training).\n\n\nIn order to use medusa models in TGI, simply point to a medusa enabled model, and everything will load automatically.\n\n\n### N-gram\n\n\nIf you don't have a medusa model, or don't have the resources to fine-tune, you can try to use `n-gram`.\nN-gram works by trying to find matching tokens in the previous sequence, and use those as speculation for generating new tokens. For example, if the tokens \"np.mean\" appear multiple times in the sequence, the model can speculate that the next continuation of the tokens \"np.\" is probably also \"mean\".\n\nThis is an extremely simple method, which works best for code, or highly repetitive text. This might not be beneficial, if the speculation misses too much.\n\n\nIn order to enable n-gram speculation simply use\n\n`--speculate 2` in your flags.\n\n[Details about the flag](https://huggingface.co/docs/text-generation-inference/basic_tutorials/launcher#speculate)\n"
  },
  {
    "path": "docs/source/conceptual/streaming.md",
    "content": "# Streaming\n\n\n## What is Streaming?\n\nToken streaming is the mode in which the server returns the tokens one by one as the model generates them. This enables showing progressive generations to the user rather than waiting for the whole generation. Streaming is an essential aspect of the end-user experience as it reduces latency, one of the most critical aspects of a smooth experience.\n\n<div class=\"flex justify-center\">\n    <img\n        class=\"block dark:hidden\"\n        src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/streaming-generation-visual_360.gif\"\n    />\n    <img\n        class=\"hidden dark:block\"\n        src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/streaming-generation-visual-dark_360.gif\"\n    />\n</div>\n\nWith token streaming, the server can start returning the tokens one by one before having to generate the whole response. Users can have a sense of the generation's quality before the end of the generation. This has different positive effects:\n\n* Users can get results orders of magnitude earlier for extremely long queries.\n* Seeing something in progress allows users to stop the generation if it's not going in the direction they expect.\n* Perceived latency is lower when results are shown in the early stages.\n* When used in conversational UIs, the experience feels more natural.\n\nFor example, a system can generate 100 tokens per second. If the system generates 1000 tokens, with the non-streaming setup, users need to wait 10 seconds to get results. On the other hand, with the streaming setup, users get initial results immediately, and although end-to-end latency will be the same, they can see half of the generation after five seconds. Below you can see an interactive demo that shows non-streaming vs streaming side-by-side. 
Click **generate** below.\n\n<div class=\"block dark:hidden\">\n\t<iframe\n        src=\"https://huggingface-streaming-vs-non-streaming.hf.space?__theme=light\"\n        width=\"850\"\n        height=\"350\"\n    ></iframe>\n</div>\n<div class=\"hidden dark:block\">\n    <iframe\n        src=\"https://huggingface-streaming-vs-non-streaming.hf.space?__theme=dark\"\n        width=\"850\"\n        height=\"350\"\n    ></iframe>\n</div>\n\n## How to use Streaming?\n\n### Streaming with Python\n\nTo stream tokens with `InferenceClient`, simply pass `stream=True` and iterate over the response.\n\n```python\nfrom huggingface_hub import InferenceClient\n\nclient = InferenceClient(base_url=\"http://127.0.0.1:8080\")\noutput = client.chat.completions.create(\n    messages=[\n        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n        {\"role\": \"user\", \"content\": \"Count to 10\"},\n    ],\n    stream=True,\n    max_tokens=1024,\n)\n\nfor chunk in output:\n    print(chunk.choices[0].delta.content)\n\n# 1\n# 2\n# 3\n# 4\n# 5\n# 6\n# 7\n# 8\n# 9\n# 10\n```\n\nThe `huggingface_hub` library also comes with an `AsyncInferenceClient` in case you need to handle the requests concurrently.\n\n```python\nfrom huggingface_hub import AsyncInferenceClient\n\nclient = AsyncInferenceClient(base_url=\"http://127.0.0.1:8080\")\nasync def main():\n    stream = await client.chat.completions.create(\n        messages=[{\"role\": \"user\", \"content\": \"Say this is a test\"}],\n        stream=True,\n    )\n    async for chunk in stream:\n        print(chunk.choices[0].delta.content or \"\", end=\"\")\n\nasyncio.run(main())\n\n# This\n# is\n# a\n# test\n#.\n```\n\n### Streaming with cURL\n\nTo use the OpenAI Chat Completions compatible Messages API `v1/chat/completions` endpoint with curl, you can add the `-N` flag, which disables curl default buffering and shows data as it arrives from the server\n\n```curl\ncurl localhost:8080/v1/chat/completions \\\n    -X POST 
\\\n    -d '{\n  \"model\": \"tgi\",\n  \"messages\": [\n    {\n      \"role\": \"system\",\n      \"content\": \"You are a helpful assistant.\"\n    },\n    {\n      \"role\": \"user\",\n      \"content\": \"What is deep learning?\"\n    }\n  ],\n  \"stream\": true,\n  \"max_tokens\": 20\n}' \\\n    -H 'Content-Type: application/json'\n```\n\n### Streaming with JavaScript\n\nFirst, we need to install the `@huggingface/inference` library.\n\n```bash\nnpm install @huggingface/inference\n```\n\nWhether you use Inference Providers (our serverless API), or Inference Endpoints, you can call `InferenceClient`.\n\n\n```js\nimport { InferenceClient } from '@huggingface/inference';\n\nconst client = new InferenceClient('hf_YOUR_TOKEN', { endpointUrl: 'https://YOUR_ENDPOINT.endpoints.huggingface.cloud' });\n\n// prompt\nconst prompt = 'What can you do in Nuremberg, Germany? Give me 3 Tips';\n\nconst stream = client.textGenerationStream({ inputs: prompt });\nfor await (const r of stream) {\n  // yield the generated token\n  process.stdout.write(r.token.text);\n}\n```\n\n## How does Streaming work under the hood?\n\nUnder the hood, TGI uses Server-Sent Events (SSE). In an SSE Setup, a client sends a request with the data, opening an HTTP connection and subscribing to updates. Afterward, the server sends data to the client. There is no need for further requests; the server will keep sending the data. SSEs are unidirectional, meaning the client does not send other requests to the server. SSE sends data over HTTP, making it easy to use.\n\nSSEs are different than:\n* Polling: where the client keeps calling the server to get data. This means that the server might return empty responses and cause overhead.\n* WebSockets: where there is a bi-directional connection. The server can send information to the client, but the client can also send data to the server after the first request. 
WebSockets are more complex to operate as they don’t only use HTTP.\n\nIf there are too many requests at the same time, TGI returns an HTTP Error with an `overloaded` error type (`huggingface_hub` returns `OverloadedError`). This allows the client to manage the overloaded server (e.g., it could display a busy error to the user or retry with a new request). To configure the maximum number of concurrent requests, you can specify `--max-concurrent-requests`, allowing clients to handle backpressure.\n"
  },
  {
    "path": "docs/source/conceptual/tensor_parallelism.md",
    "content": "# Tensor Parallelism\n\nTensor parallelism is a technique used to fit a large model in multiple GPUs. For example, when multiplying the input tensors with the first weight tensor, the matrix multiplication is equivalent to splitting the weight tensor column-wise, multiplying each column with the input separately, and then concatenating the separate outputs. These outputs are then transferred from the GPUs and concatenated together to get the final result, like below 👇\n\n![Image courtesy of Anton Lozkhov](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/tgi/TP.png)\n\n\n<Tip warning={true}>\n\nTensor Parallelism only works for [models officially supported](../supported_models), it will not work when falling back to `transformers`. You can get more information about unsupported models [here](../basic_tutorials/non_core_models).\n\n</Tip>\n\nYou can learn a lot more details about tensor-parallelism from [the `transformers` docs](https://huggingface.co/docs/transformers/main/en/perf_train_gpu_many#tensor-parallelism).\n"
  },
  {
    "path": "docs/source/index.md",
    "content": "# Text Generation Inference\n\n\n> [!CAUTION]\n> text-generation-inference is now in maintenance mode. Going forward, we will accept pull requests for minor bug fixes, documentation improvements and lightweight maintenance tasks.\n>\n> TGI has initiated the movement for optimized inference engines to rely on `transformers` model architectures. This approach is now adopted by downstream inference engines, which we contribute to and recommend using going forward: [vllm](https://github.com/vllm-project/vllm), [SGLang](https://github.com/sgl-project/sglang), as well as local engines with inter-compatibility such as llama.cpp or MLX.\n\n\nText Generation Inference (TGI) is a toolkit for deploying and serving Large Language Models (LLMs). TGI enables high-performance text generation for the most popular open-source LLMs, including Llama, Falcon, StarCoder, BLOOM, GPT-NeoX, and T5.\n\n![Text Generation Inference](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/TGI.png)\n\nText Generation Inference implements many optimizations and features, such as:\n\n- Simple launcher to serve most popular LLMs\n- Production ready (distributed tracing with Open Telemetry, Prometheus metrics)\n- Tensor Parallelism for faster inference on multiple GPUs\n- Token streaming using Server-Sent Events (SSE)\n- Continuous batching of incoming requests for increased total throughput\n- Optimized transformers code for inference using [Flash Attention](https://github.com/HazyResearch/flash-attention) and [Paged Attention](https://github.com/vllm-project/vllm) on the most popular architectures\n- Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and [GPT-Q](https://arxiv.org/abs/2210.17323)\n- [Safetensors](https://github.com/huggingface/safetensors) weight loading\n- Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)\n- Logits warper (temperature scaling, top-p, top-k, repetition 
penalty)\n- Stop sequences\n- Log probabilities\n- Fine-tuning Support: Utilize fine-tuned models for specific tasks to achieve higher accuracy and performance.\n- [Guidance](conceptual/guidance): Enable function calling and tool-use by forcing the model to generate structured outputs based on your own predefined output schemas.\n\nText Generation Inference is used in production by multiple projects, such as:\n\n- [Hugging Chat](https://github.com/huggingface/chat-ui), an open-source interface for open-access models, such as Open Assistant and Llama\n- [OpenAssistant](https://open-assistant.io/), an open-source community effort to train LLMs in the open\n- [nat.dev](http://nat.dev/), a playground to explore and compare LLMs.\n"
  },
  {
    "path": "docs/source/installation.md",
    "content": "# Installation from source\n\n<Tip warning={true}>\n\nInstalling TGI from source is not the recommended usage. We strongly recommend to use TGI through Docker, check the [Quick Tour](./quicktour), [Installation for Nvidia GPUs](./installation_nvidia) and [Installation for AMD GPUs](./installation_amd) to learn how to use TGI with Docker.\n\n</Tip>\n\n## Install CLI\n\nYou can use TGI command-line interface (CLI) to download weights, serve and quantize models, or get information on serving parameters.\n\nTo install the CLI, you need to first clone the TGI repository and then run `make`.\n\n```bash\ngit clone https://github.com/huggingface/text-generation-inference.git && cd text-generation-inference\nmake install\n```\n\nIf you would like to serve models with custom kernels, run\n\n```bash\nBUILD_EXTENSIONS=True make install\n```\n\n## Local Installation from Source\n\nBefore you start, you will need to setup your environment, and install Text Generation Inference. Text Generation Inference is tested on **Python 3.9+**.\n\nText Generation Inference is available on pypi, conda and GitHub.\n\nTo install and launch locally, first [install Rust](https://rustup.rs/) and create a Python virtual environment with at least\nPython 3.9, e.g. 
using conda:\n\n```bash\ncurl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh\n\nconda create -n text-generation-inference python=3.9\nconda activate text-generation-inference\n```\n\nYou may also need to install Protoc.\n\nOn Linux:\n\n```bash\nPROTOC_ZIP=protoc-21.12-linux-x86_64.zip\ncurl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP\nsudo unzip -o $PROTOC_ZIP -d /usr/local bin/protoc\nsudo unzip -o $PROTOC_ZIP -d /usr/local 'include/*'\nrm -f $PROTOC_ZIP\n```\n\nOn MacOS, using Homebrew:\n\n```bash\nbrew install protobuf\n```\n\nThen run to install Text Generation Inference:\n\n```bash\ngit clone https://github.com/huggingface/text-generation-inference.git && cd text-generation-inference\nBUILD_EXTENSIONS=True make install\n```\n\n<Tip warning={true}>\n\nOn some machines, you may also need the OpenSSL libraries and gcc. On Linux machines, run:\n\n```bash\nsudo apt-get install libssl-dev gcc -y\n```\n\n</Tip>\n\nOnce installation is done, simply run:\n\n```bash\nmake run-falcon-7b-instruct\n```\n\nThis will serve Falcon 7B Instruct model from the port 8080, which we can query.\n"
  },
  {
    "path": "docs/source/installation_amd.md",
    "content": "# Using TGI with AMD GPUs\n\nTGI is supported and tested on [AMD Instinct MI210](https://www.amd.com/en/products/accelerators/instinct/mi200/mi210.html), [MI250](https://www.amd.com/en/products/accelerators/instinct/mi200/mi250.html) and [MI300](https://www.amd.com/en/products/accelerators/instinct/mi300.html) GPUs. The support may be extended in the future. The recommended usage is through Docker. Make sure to check the [AMD documentation](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/how-to/docker.html) on how to use Docker with AMD GPUs.\n\nOn a server powered by AMD GPUs, TGI can be launched with the following command:\n\n```bash\nmodel=teknium/OpenHermes-2.5-Mistral-7B\nvolume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run\n\ndocker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \\\n    --device=/dev/kfd --device=/dev/dri --group-add video \\\n    --ipc=host --shm-size 256g --net host -v $volume:/data \\\n    ghcr.io/huggingface/text-generation-inference:3.3.5-rocm \\\n    --model-id $model\n```\n\nThe launched TGI server can then be queried from clients, make sure to check out the [Consuming TGI](./basic_tutorials/consuming_tgi) guide.\n\n## TunableOp\n\nTGI's docker image for AMD GPUs integrates [PyTorch's TunableOp](https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable), which allows to do an additional warmup to select the best performing matrix multiplication (GEMM) kernel from rocBLAS or hipBLASLt.\n\nExperimentally, on MI300X, we noticed a 6-8% latency improvement when using TunableOp on top of ROCm 6.1 and PyTorch 2.3.\n\nTunableOp is enabled by default, the warmup may take 1-2 minutes. 
In case you would like to disable TunableOp, please pass `--env PYTORCH_TUNABLEOP_ENABLED=\"0\"` when launching TGI's docker container.\n\n## Flash attention implementation\n\nTwo implementations of Flash Attention are available for ROCm, the first is [ROCm/flash-attention](https://github.com/ROCm/flash-attention) based on a [Composable Kernel](https://github.com/ROCm/composable_kernel) (CK) implementation, and the second is a [Triton implementation](https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/layers/attention/flash_attn_triton.py).\n\nBy default, the Composable Kernel implementation is used. However, the Triton implementation has slightly lower latency on MI250 and MI300, but requires a warmup which can be prohibitive as it needs to be done again for each new prompt length. If needed, the FA Triton implementation can be enabled with `--env ROCM_USE_FLASH_ATTN_V2_TRITON=\"0\"` when launching TGI's docker container.\n\n## Custom PagedAttention\n\nFor better performance on ROCm, a custom Paged Attention kernel is available and is enabled by default. To disable it and fall back to the PagedAttention v2 kernel, set the environment variable `ROCM_USE_CUSTOM_PAGED_ATTN=0`.\n\nThe custom kernel supports bf16 and fp16 data types, block size of 16, head size of 128, a maximum context length of 16k, and GQA ratios between 1 and 16. For other configurations, we use the PagedAttention v2 kernel.\n\n## Unsupported features\n\nThe following features are currently not supported in the ROCm version of TGI, and the support may be extended in the future:\n* Loading [AWQ](https://huggingface.co/docs/transformers/quantization#awq) checkpoints.\n* Kernel for sliding window attention (Mistral)\n"
  },
  {
    "path": "docs/source/installation_gaudi.md",
    "content": "# Using TGI with Intel Gaudi\n\nYou can use TGI on Intel Gaudi using the [TGI gaudi backend](https://huggingface.co/docs/text-generation-inference/backends/gaudi).\n"
  },
  {
    "path": "docs/source/installation_inferentia.md",
    "content": "# Using TGI with Inferentia\n\nYou can use TGI on AWS Trainium and Inferentia platforms using the [TGI neuron backend](https://huggingface.co/docs/text-generation-inference/backends/neuron).\n"
  },
  {
    "path": "docs/source/installation_intel.md",
    "content": "# Using TGI with Intel GPUs\n\nTGI optimized models are supported on Intel Data Center GPU [Max1100](https://www.intel.com/content/www/us/en/products/sku/232876/intel-data-center-gpu-max-1100/specifications.html), [Max1550](https://www.intel.com/content/www/us/en/products/sku/232873/intel-data-center-gpu-max-1550/specifications.html). The recommended usage is through Docker.\n\n\nOn a server powered by Intel GPUs, TGI can be launched with the following command:\n\n```bash\nmodel=teknium/OpenHermes-2.5-Mistral-7B\nvolume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run\n\ndocker run --rm --privileged --cap-add=sys_nice \\\n    --device=/dev/dri \\\n    --ipc=host --shm-size 1g --net host -v $volume:/data \\\n    ghcr.io/huggingface/text-generation-inference:3.3.5-intel-xpu \\\n    --model-id $model --cuda-graphs 0\n```\n\n# Using TGI with Intel CPUs\n\nIntel® Extension for PyTorch (IPEX) also provides further optimizations for Intel CPUs. The IPEX provides optimization operations such as flash attention, paged attention, Add + LayerNorm, ROPE and more.\n\nOn a server powered by Intel CPU, TGI can be launched with the following command:\n\n```bash\nmodel=teknium/OpenHermes-2.5-Mistral-7B\nvolume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run\n\ndocker run --rm --privileged --cap-add=sys_nice \\\n    --device=/dev/dri \\\n    --ipc=host --shm-size 1g --net host -v $volume:/data \\\n    ghcr.io/huggingface/text-generation-inference:3.3.5-intel-cpu \\\n    --model-id $model --cuda-graphs 0\n```\n\nThe launched TGI server can then be queried from clients, make sure to check out the [Consuming TGI](./basic_tutorials/consuming_tgi) guide.\n"
  },
  {
    "path": "docs/source/installation_nvidia.md",
    "content": "# Using TGI with Nvidia GPUs\n\nTGI optimized models are supported on NVIDIA [H100](https://www.nvidia.com/en-us/data-center/h100/), [A100](https://www.nvidia.com/en-us/data-center/a100/), [A10G](https://www.nvidia.com/en-us/data-center/products/a10-gpu/) and [T4](https://www.nvidia.com/en-us/data-center/tesla-t4/) GPUs with CUDA 12.2+. Note that you have to install [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) to use it.\n\nFor other NVIDIA GPUs, continuous batching will still apply, but some operations like flash attention and paged attention will not be executed.\n\nTGI can be used on NVIDIA GPUs through its official docker image:\n\n```bash\nmodel=teknium/OpenHermes-2.5-Mistral-7B\nvolume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run\n\ndocker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \\\n    ghcr.io/huggingface/text-generation-inference:3.3.5 \\\n    --model-id $model\n```\n\nThe launched TGI server can then be queried from clients, make sure to check out the [Consuming TGI](./basic_tutorials/consuming_tgi) guide.\n"
  },
  {
    "path": "docs/source/installation_tpu.md",
    "content": "# Using TGI with Google TPUs\n\nCheck out this [guide](https://huggingface.co/docs/optimum-tpu) on how to serve models with TGI on TPUs.\n"
  },
  {
    "path": "docs/source/multi_backend_support.md",
    "content": "# Multi-backend support\n\nTGI (Text Generation Inference) offers flexibility by supporting multiple backends for serving large language models (LLMs).\nWith multi-backend support, you can choose the backend that best suits your needs,\nwhether you prioritize performance, ease of use, or compatibility with specific hardware. API interaction with\nTGI remains consistent across backends, allowing you to switch between them seamlessly.\n\n**Supported backends:**\n* **TGI CUDA backend**: This high-performance backend is optimized for NVIDIA GPUs and serves as the default option\n  within TGI. Developed in-house, it boasts numerous optimizations and is used in production by various projects, including those by Hugging Face.\n* **[TGI TRTLLM backend](./backends/trtllm)**: This backend leverages NVIDIA's TensorRT library to accelerate LLM inference.\n  It utilizes specialized optimizations and custom kernels for enhanced performance.\n  However, it requires a model-specific compilation step for each GPU architecture.\n* **[TGI Llamacpp backend](./backends/llamacpp)**: This backend facilitates the deployment of large language models\n  (LLMs) by integrating [llama.cpp](https://github.com/ggerganov/llama.cpp), an advanced inference engine optimized for both CPU and GPU computation.\n* **[TGI Neuron backend](./backends/neuron)**: This backend leverages the [AWS Neuron SDK](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/) to allow the deployment of large language models (LLMs) on [AWS Trainium and Inferentia chips](https://aws.amazon.com/ai/machine-learning/trainium/).\n"
  },
  {
    "path": "docs/source/quicktour.md",
    "content": "# Quick Tour\n\nThe easiest way of getting started is using the official Docker container. Install Docker following [their installation instructions](https://docs.docker.com/get-docker/).\n\n## Launching TGI\n\nLet's say you want to deploy [teknium/OpenHermes-2.5-Mistral-7B](https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B) model with TGI on an Nvidia GPU. Here is an example on how to do that:\n\n```bash\nmodel=teknium/OpenHermes-2.5-Mistral-7B\nvolume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run\n\ndocker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \\\n    ghcr.io/huggingface/text-generation-inference:3.3.5 \\\n    --model-id $model\n```\n\n<Tip>\n\nIf you want to serve gated or private models, please refer to\n[this guide](https://huggingface.co/docs/text-generation-inference/en/basic_tutorials/gated_model_access)\nfor detailed instructions.\n\n</Tip>\n\n### Supported hardware\n\nTGI supports various hardware. Make sure to check the [Using TGI with Nvidia GPUs](./installation_nvidia), [Using TGI with AMD GPUs](./installation_amd), [Using TGI with Intel GPUs](./installation_intel), [Using TGI with Gaudi](./installation_gaudi), [Using TGI with Inferentia](./installation_inferentia) guides depending on which hardware you would like to deploy TGI on.\n\n## Consuming TGI\n\nOnce TGI is running, you can use the `generate` endpoint or the Open AI Chat Completion API compatible [Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) by doing requests. To learn more about how to query the endpoints, check the [Consuming TGI](./basic_tutorials/consuming_tgi) section, where we show examples with utility libraries and UIs. 
Below you can see a simple snippet to query the endpoint.\n\n<inferencesnippet>\n<python>\n\n```python\nimport requests\n\nheaders = {\n    \"Content-Type\": \"application/json\",\n}\n\ndata = {\n    'inputs': 'What is Deep Learning?',\n    'parameters': {\n        'max_new_tokens': 20,\n    },\n}\n\nresponse = requests.post('http://127.0.0.1:8080/generate', headers=headers, json=data)\nprint(response.json())\n# {'generated_text': '\\n\\nDeep Learning is a subset of Machine Learning that is concerned with the development of algorithms that can'}\n```\n</python>\n<js>\n\n```js\nasync function query() {\n    const response = await fetch(\n        'http://127.0.0.1:8080/generate',\n        {\n            method: 'POST',\n            headers: { 'Content-Type': 'application/json'},\n            body: JSON.stringify({\n                'inputs': 'What is Deep Learning?',\n                'parameters': {\n                    'max_new_tokens': 20\n                }\n            })\n        }\n    );\n}\n\nquery().then((response) => {\n    console.log(JSON.stringify(response));\n});\n/// {\"generated_text\":\"\\n\\nDeep Learning is a subset of Machine Learning that is concerned with the development of algorithms that can\"}\n```\n\n</js>\n<curl>\n\n```curl\ncurl 127.0.0.1:8080/generate \\\n    -X POST \\\n    -d '{\"inputs\":\"What is Deep Learning?\",\"parameters\":{\"max_new_tokens\":20}}' \\\n    -H 'Content-Type: application/json'\n```\n\n</curl>\n</inferencesnippet>\n\n<Tip>\n\nTo see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.\n\n```bash\ndocker run ghcr.io/huggingface/text-generation-inference:3.3.5 --help\n```\n\n</Tip>\n"
  },
  {
    "path": "docs/source/reference/api_reference.md",
    "content": "# HTTP API Reference\n\n#### Table of Contents\n\n- [Text Generation Inference custom API](#text-generation-inference-custom-api)\n- [OpenAI Messages API](#openai-messages-api)\n  - [Making a Request](#making-a-request)\n  - [Streaming](#streaming)\n  - [Synchronous](#synchronous)\n  - [Hugging Face Inference Endpoints](#hugging-face-inference-endpoints)\n  - [Cloud Providers](#cloud-providers)\n      - [Amazon SageMaker](#amazon-sagemaker)\n\nThe HTTP API is a RESTful API that allows you to interact with the text-generation-inference component. Two endpoints are available:\n* Text Generation Inference [custom API](https://huggingface.github.io/text-generation-inference/)\n* OpenAI's [Messages API](#openai-messages-api)\n\n\n## Text Generation Inference custom API\n\nCheck the [API documentation](https://huggingface.github.io/text-generation-inference/) for more information on how to interact with the Text Generation Inference API.\n\n## OpenAI Messages API\n\nText Generation Inference (TGI) now supports the Messages API, which is fully compatible with the OpenAI Chat Completion API. This feature is available starting from version 1.4.0. You can use OpenAI's client libraries or third-party libraries expecting OpenAI schema to interact with TGI's Messages API. Below are some examples of how to utilize this compatibility.\n\n> **Note:** The Messages API is supported from TGI version 1.4.0 and above. Ensure you are using a compatible version to access this feature.\n\n## Making a Request\n\nYou can make a request to TGI's Messages API using `curl`. 
Here's an example:\n\n```bash\ncurl localhost:3000/v1/chat/completions \\\n    -X POST \\\n    -d '{\n  \"model\": \"tgi\",\n  \"messages\": [\n    {\n      \"role\": \"system\",\n      \"content\": \"You are a helpful assistant.\"\n    },\n    {\n      \"role\": \"user\",\n      \"content\": \"What is deep learning?\"\n    }\n  ],\n  \"stream\": true,\n  \"max_tokens\": 20\n}' \\\n    -H 'Content-Type: application/json'\n```\n\n## Streaming\n\nYou can also use OpenAI's Python client library to make a streaming request. Here's how:\n\n```python\nfrom openai import OpenAI\n\n# init the client but point it to TGI\nclient = OpenAI(\n    base_url=\"http://localhost:3000/v1\",\n    api_key=\"-\"\n)\n\nchat_completion = client.chat.completions.create(\n    model=\"tgi\",\n    messages=[\n        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\" },\n        {\"role\": \"user\", \"content\": \"What is deep learning?\"}\n    ],\n    stream=True\n)\n\n# iterate and print stream\nfor message in chat_completion:\n    print(message)\n```\n\n## Synchronous\n\nIf you prefer to make a synchronous request, you can do so like this:\n\n```python\nfrom openai import OpenAI\n\n# init the client but point it to TGI\nclient = OpenAI(\n    base_url=\"http://localhost:3000/v1\",\n    api_key=\"-\"\n)\n\nchat_completion = client.chat.completions.create(\n    model=\"tgi\",\n    messages=[\n        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\" },\n        {\"role\": \"user\", \"content\": \"What is deep learning?\"}\n    ],\n    stream=False\n)\n\nprint(chat_completion)\n```\n\n## Hugging Face Inference Endpoints\n\nThe Messages API is integrated with [Inference Endpoints](https://huggingface.co/inference-endpoints/dedicated).\nEvery endpoint that uses \"Text Generation Inference\" with an LLM, which has a chat template can now be used. 
Below is an example of how to use IE with TGI using OpenAI's Python client library:\n\n> **Note:** Make sure to replace `base_url` with your endpoint URL and to include `v1/` at the end of the URL. The `api_key` should be replaced with your Hugging Face API key.\n\n```python\nfrom openai import OpenAI\n\n# init the client but point it to TGI\nclient = OpenAI(\n    # replace with your endpoint url, make sure to include \"v1/\" at the end\n    base_url=\"https://vlzz10eq3fol3429.us-east-1.aws.endpoints.huggingface.cloud/v1/\",\n    # replace with your API key\n    api_key=\"hf_XXX\"\n)\n\nchat_completion = client.chat.completions.create(\n    model=\"tgi\",\n    messages=[\n        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\" },\n        {\"role\": \"user\", \"content\": \"What is deep learning?\"}\n    ],\n    stream=True\n)\n\n# iterate and print stream\nfor message in chat_completion:\n    print(message.choices[0].delta.content, end=\"\")\n```\n\n## Cloud Providers\n\nTGI can be deployed on various cloud providers for scalable and robust text generation. One such provider is Amazon SageMaker, which has recently added support for TGI. Here's how you can deploy TGI on Amazon SageMaker:\n\n## Amazon SageMaker\n\nAmazon Sagemaker natively supports the message API:\n\n```python\nimport json\nimport sagemaker\nimport boto3\nfrom sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri\n\ntry:\n role = sagemaker.get_execution_role()\nexcept ValueError:\n iam = boto3.client('iam')\n role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']\n\n# Hub Model configuration. 
https://huggingface.co/models\nhub = {\n 'HF_MODEL_ID':'HuggingFaceH4/zephyr-7b-beta',\n 'SM_NUM_GPUS': json.dumps(1),\n}\n\n# create Hugging Face Model Class\nhuggingface_model = HuggingFaceModel(\n image_uri=get_huggingface_llm_image_uri(\"huggingface\",version=\"3.3.5\"),\n env=hub,\n role=role,\n)\n\n# deploy model to SageMaker Inference\npredictor = huggingface_model.deploy(\n initial_instance_count=1,\n instance_type=\"ml.g5.2xlarge\",\n container_startup_health_check_timeout=300,\n  )\n\n# send request\npredictor.predict({\n\"messages\": [\n        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\" },\n        {\"role\": \"user\", \"content\": \"What is deep learning?\"}\n    ]\n})\n```\n"
  },
  {
    "path": "docs/source/reference/launcher.md",
    "content": "# Text-generation-launcher arguments\n\n<!-- WRAP CODE BLOCKS -->\n\n```shell\nText Generation Launcher\n\nUsage: text-generation-launcher [OPTIONS]\n\nOptions:\n```\n## MODEL_ID\n```shell\n      --model-id <MODEL_ID>\n          The name of the model to load. Can be a MODEL_ID as listed on <https://hf.co/models> like `gpt2` or `OpenAssistant/oasst-sft-1-pythia-12b`. Or it can be a local directory containing the necessary files as saved by `save_pretrained(...)` methods of transformers\n          \n          [env: MODEL_ID=]\n          [default: bigscience/bloom-560m]\n\n```\n## REVISION\n```shell\n      --revision <REVISION>\n          The actual revision of the model if you're referring to a model on the hub. You can use a specific commit id or a branch like `refs/pr/2`\n          \n          [env: REVISION=]\n\n```\n## VALIDATION_WORKERS\n```shell\n      --validation-workers <VALIDATION_WORKERS>\n          The number of tokenizer workers used for payload validation and truncation inside the router\n          \n          [env: VALIDATION_WORKERS=]\n          [default: 2]\n\n```\n## SHARDED\n```shell\n      --sharded <SHARDED>\n          Whether to shard the model across multiple GPUs By default text-generation-inference will use all available GPUs to run the model. Setting it to `false` deactivates `num_shard`\n          \n          [env: SHARDED=]\n          [possible values: true, false]\n\n```\n## NUM_SHARD\n```shell\n      --num-shard <NUM_SHARD>\n          The number of shards to use if you don't want to use all GPUs on a given machine. You can use `CUDA_VISIBLE_DEVICES=0,1 text-generation-launcher... --num_shard 2` and `CUDA_VISIBLE_DEVICES=2,3 text-generation-launcher... --num_shard 2` to launch 2 copies with 2 shard each on a given machine with 4 GPUs for instance\n          \n          [env: NUM_SHARD=]\n\n```\n## QUANTIZE\n```shell\n      --quantize <QUANTIZE>\n          Quantization method to use for the model. 
It is not necessary to specify this option for pre-quantized models, since the quantization method is read from the model configuration.\n          \n          Marlin kernels will be used automatically for GPTQ/AWQ models.\n\n          Possible values:\n          - awq:                4 bit quantization. Requires a specific AWQ quantized model: <https://hf.co/models?search=awq>. Should replace GPTQ models wherever possible because of the better latency\n          - compressed-tensors: Compressed tensors, which can be a mixture of different quantization methods\n          - eetq:               8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>\n          - exl2:               Variable bit quantization. Requires a specific EXL2 quantized model: <https://hf.co/models?search=exl2>. Requires exllama2 kernels and does not support tensor parallelism (num_shard > 1)\n          - gptq:               4 bit quantization. Requires a specific GPTQ quantized model: <https://hf.co/models?search=gptq>. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels\n          - marlin:             4 bit quantization. Requires a specific Marlin quantized model: <https://hf.co/models?search=marlin>\n          - bitsandbytes:       Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16\n          - bitsandbytes-nf4:   Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16\n          - bitsandbytes-fp4:   Bitsandbytes 4bit. 
nf4 should be preferred in most cases but maybe this one has better perplexity performance for your model\n          - fp8:                [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above This dtype has native ops should be the fastest if available. This is currently not the fastest because of local unpacking + padding to satisfy matrix multiplication limitations\n          \n          [env: QUANTIZE=]\n\n```\n## SPECULATE\n```shell\n      --speculate <SPECULATE>\n          The number of input_ids to speculate on If using a medusa model, the heads will be picked up automatically Otherwise, it will use n-gram speculation which is relatively free in terms of compute, but the speedup heavily depends on the task\n          \n          [env: SPECULATE=]\n\n```\n## DTYPE\n```shell\n      --dtype <DTYPE>\n          The dtype to be forced upon the model. This option cannot be used with `--quantize`\n          \n          [env: DTYPE=]\n          [possible values: float16, bfloat16]\n\n```\n## KV_CACHE_DTYPE\n```shell\n      --kv-cache-dtype <KV_CACHE_DTYPE>\n          Specify the dtype for the key-value cache. When this option is not provided, the dtype of the model is used (typically `float16` or `bfloat16`). Currently the only supported values are `fp8_e4m3fn` and `fp8_e5m2` on CUDA\n          \n          [env: KV_CACHE_DTYPE=]\n          [possible values: fp8_e4m3fn, fp8_e5m2]\n\n```\n## TRUST_REMOTE_CODE\n```shell\n      --trust-remote-code\n          Whether you want to execute hub modelling code. 
Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision\n          \n          [env: TRUST_REMOTE_CODE=]\n\n```\n## MAX_CONCURRENT_REQUESTS\n```shell\n      --max-concurrent-requests <MAX_CONCURRENT_REQUESTS>\n          The maximum amount of concurrent requests for this particular deployment. Having a low limit will refuse clients requests instead of having them wait for too long and is usually good to handle backpressure correctly\n          \n          [env: MAX_CONCURRENT_REQUESTS=]\n          [default: 128]\n\n```\n## MAX_BEST_OF\n```shell\n      --max-best-of <MAX_BEST_OF>\n          This is the maximum allowed value for clients to set `best_of`. Best of makes `n` generations at the same time, and returns the best in terms of overall log probability over the entire generated sequence\n          \n          [env: MAX_BEST_OF=]\n          [default: 2]\n\n```\n## MAX_STOP_SEQUENCES\n```shell\n      --max-stop-sequences <MAX_STOP_SEQUENCES>\n          This is the maximum allowed value for clients to set `stop_sequences`. Stop sequences are used to allow the model to stop on more than just the EOS token, and enable more complex \"prompting\" where users can preprompt the model in a specific way and define their \"own\" stop token aligned with their prompt\n          \n          [env: MAX_STOP_SEQUENCES=]\n          [default: 4]\n\n```\n## MAX_TOP_N_TOKENS\n```shell\n      --max-top-n-tokens <MAX_TOP_N_TOKENS>\n          This is the maximum allowed value for clients to set `top_n_tokens`. `top_n_tokens` is used to return information about the `n` most likely tokens at each generation step, instead of just the sampled token. 
This information can be used for downstream tasks like for classification or ranking\n          \n          [env: MAX_TOP_N_TOKENS=]\n          [default: 5]\n\n```\n## MAX_INPUT_TOKENS\n```shell\n      --max-input-tokens <MAX_INPUT_TOKENS>\n          This is the maximum allowed input length (expressed in number of tokens) for users. The larger this value, the longer prompt users can send which can impact the overall memory required to handle the load. Please note that some models have a finite range of sequence they can handle. Default to min(max_allocatable, max_position_embeddings) - 1\n          \n          [env: MAX_INPUT_TOKENS=]\n\n```\n## MAX_INPUT_LENGTH\n```shell\n      --max-input-length <MAX_INPUT_LENGTH>\n          Legacy version of [`Args::max_input_tokens`]\n          \n          [env: MAX_INPUT_LENGTH=]\n\n```\n## MAX_TOTAL_TOKENS\n```shell\n      --max-total-tokens <MAX_TOTAL_TOKENS>\n          This is the most important value to set as it defines the \"memory budget\" of running clients requests. Clients will send input sequences and ask to generate `max_new_tokens` on top. with a value of `1512` users can send either a prompt of `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for `1511` max_new_tokens. The larger this value, the larger amount each request will be in your RAM and the less effective batching can be. Default to min(max_allocatable, max_position_embeddings)\n          \n          [env: MAX_TOTAL_TOKENS=]\n\n```\n## WAITING_SERVED_RATIO\n```shell\n      --waiting-served-ratio <WAITING_SERVED_RATIO>\n          This represents the ratio of waiting queries vs running queries where you want to start considering pausing the running queries to include the waiting ones into the same batch. 
`waiting_served_ratio=1.2` means when 12 queries are waiting and there's only 10 queries left in the current batch we check if we can fit those 12 waiting queries into the batching strategy, and if yes, then batching happens delaying the 10 running queries by a `prefill` run.\n          \n          This setting is only applied if there is room in the batch as defined by `max_batch_total_tokens`.\n          \n          [env: WAITING_SERVED_RATIO=]\n          [default: 0.3]\n\n```\n## MAX_BATCH_PREFILL_TOKENS\n```shell\n      --max-batch-prefill-tokens <MAX_BATCH_PREFILL_TOKENS>\n          Limits the number of tokens for the prefill operation. Since this operation takes the most memory and is compute bound, it is interesting to limit the number of requests that can be sent. Default to `max_input_tokens + 50` to give a bit of room\n          \n          [env: MAX_BATCH_PREFILL_TOKENS=]\n\n```\n## MAX_BATCH_TOTAL_TOKENS\n```shell\n      --max-batch-total-tokens <MAX_BATCH_TOTAL_TOKENS>\n          **IMPORTANT** This is one critical control to allow maximum usage of the available hardware.\n          \n          This represents the total amount of potential tokens within a batch. When using padding (not recommended) this would be equivalent of `batch_size` * `max_total_tokens`.\n          \n          However in the non-padded (flash attention) version this can be much finer.\n          \n          For `max_batch_total_tokens=1000`, you could fit `10` queries of `total_tokens=100` or a single query of `1000` tokens.\n          \n          Overall this number should be the largest possible amount that fits the remaining memory (after the model is loaded). 
Since the actual memory overhead depends on other parameters like if you're using quantization, flash attention or the model implementation, text-generation-inference infers this number automatically if not provided ensuring that the value is as large as possible.\n          \n          [env: MAX_BATCH_TOTAL_TOKENS=]\n\n```\n## MAX_WAITING_TOKENS\n```shell\n      --max-waiting-tokens <MAX_WAITING_TOKENS>\n          This setting defines how many tokens can be passed before forcing the waiting queries to be put on the batch (if the size of the batch allows for it). New queries require 1 `prefill` forward, which is different from `decode` and therefore you need to pause the running batch in order to run `prefill` to create the correct values for the waiting queries to be able to join the batch.\n          \n          With a value too small, queries will always \"steal\" the compute to run `prefill` and running queries will be delayed by a lot.\n          \n          With a value too big, waiting queries could wait for a very long time before being allowed a slot in the running batch. If your server is busy that means that requests that could run in ~2s on an empty server could end up running in ~20s because the query had to wait for 18s.\n          \n          This number is expressed in number of tokens to make it a bit more \"model\" agnostic, but what should really matter is the overall latency for end users.\n          \n          [env: MAX_WAITING_TOKENS=]\n          [default: 20]\n\n```\n## MAX_BATCH_SIZE\n```shell\n      --max-batch-size <MAX_BATCH_SIZE>\n          Enforce a maximum number of requests per batch Specific flag for hardware targets that do not support unpadded inference\n          \n          [env: MAX_BATCH_SIZE=]\n\n```\n## CUDA_GRAPHS\n```shell\n      --cuda-graphs <CUDA_GRAPHS>\n          Specify the batch sizes to compute cuda graphs for. Use \"0\" to disable. 
Default = \"1,2,4,8,16,32\"\n          \n          [env: CUDA_GRAPHS=]\n\n```\n## HOSTNAME\n```shell\n      --hostname <HOSTNAME>\n          The IP address to listen on\n          \n          [env: HOSTNAME=]\n          [default: 0.0.0.0]\n\n```\n## PORT\n```shell\n  -p, --port <PORT>\n          The port to listen on\n          \n          [env: PORT=]\n          [default: 3000]\n\n```\n## PROMETHEUS_PORT\n```shell\n  -p, --prometheus-port <PROMETHEUS_PORT>\n          The Prometheus port to listen on\n          \n          [env: PROMETHEUS_PORT=]\n          [default: 9000]\n\n```\n## SHARD_UDS_PATH\n```shell\n      --shard-uds-path <SHARD_UDS_PATH>\n          The name of the socket for gRPC communication between the webserver and the shards\n          \n          [env: SHARD_UDS_PATH=]\n          [default: /tmp/text-generation-server]\n\n```\n## MASTER_ADDR\n```shell\n      --master-addr <MASTER_ADDR>\n          The address the master shard will listen on. (setting used by torch distributed)\n          \n          [env: MASTER_ADDR=]\n          [default: localhost]\n\n```\n## MASTER_PORT\n```shell\n      --master-port <MASTER_PORT>\n          The port the master shard will listen on. (setting used by torch distributed)\n          \n          [env: MASTER_PORT=]\n          [default: 29500]\n\n```\n## HUGGINGFACE_HUB_CACHE\n```shell\n      --huggingface-hub-cache <HUGGINGFACE_HUB_CACHE>\n          The location of the huggingface hub cache. Used to override the location if you want to provide a mounted disk for instance\n          \n          [env: HUGGINGFACE_HUB_CACHE=]\n\n```\n## WEIGHTS_CACHE_OVERRIDE\n```shell\n      --weights-cache-override <WEIGHTS_CACHE_OVERRIDE>\n          The location of the huggingface hub cache. 
Used to override the location if you want to provide a mounted disk for instance\n          \n          [env: WEIGHTS_CACHE_OVERRIDE=]\n\n```\n## DISABLE_CUSTOM_KERNELS\n```shell\n      --disable-custom-kernels\n          For some models (like bloom), text-generation-inference implemented custom cuda kernels to speed up inference. Those kernels were only tested on A100. Use this flag to disable them if you're running on different hardware and encounter issues\n          \n          [env: DISABLE_CUSTOM_KERNELS=]\n\n```\n## CUDA_MEMORY_FRACTION\n```shell\n      --cuda-memory-fraction <CUDA_MEMORY_FRACTION>\n          Limit the CUDA available memory. The allowed value equals the total visible memory multiplied by cuda-memory-fraction\n          \n          [env: CUDA_MEMORY_FRACTION=]\n          [default: 1.0]\n\n```\n## ROPE_SCALING\n```shell\n      --rope-scaling <ROPE_SCALING>\n          Rope scaling will only be used for RoPE models and allows rescaling the position rotary to accommodate larger prompts.\n          \n          Goes together with `rope_factor`.\n          \n          `--rope-factor 2.0` gives linear scaling with a factor of 2.0 `--rope-scaling dynamic` gives dynamic scaling with a factor of 1.0 `--rope-scaling linear` gives linear scaling with a factor of 1.0 (Nothing will be changed basically)\n          \n          `--rope-scaling linear --rope-factor` fully describes the scaling you want\n          \n          [env: ROPE_SCALING=]\n          [possible values: linear, dynamic]\n\n```\n## ROPE_FACTOR\n```shell\n      --rope-factor <ROPE_FACTOR>\n          Rope scaling will only be used for RoPE models See `rope_scaling`\n          \n          [env: ROPE_FACTOR=]\n\n```\n## JSON_OUTPUT\n```shell\n      --json-output\n          Outputs the logs in JSON format (useful for telemetry)\n          \n          [env: JSON_OUTPUT=]\n\n```\n## OTLP_ENDPOINT\n```shell\n      --otlp-endpoint <OTLP_ENDPOINT>\n          [env: OTLP_ENDPOINT=]\n\n```\n## 
OTLP_SERVICE_NAME\n```shell\n      --otlp-service-name <OTLP_SERVICE_NAME>\n          [env: OTLP_SERVICE_NAME=]\n          [default: text-generation-inference.router]\n\n```\n## CORS_ALLOW_ORIGIN\n```shell\n      --cors-allow-origin <CORS_ALLOW_ORIGIN>\n          [env: CORS_ALLOW_ORIGIN=]\n\n```\n## API_KEY\n```shell\n      --api-key <API_KEY>\n          [env: API_KEY=]\n\n```\n## WATERMARK_GAMMA\n```shell\n      --watermark-gamma <WATERMARK_GAMMA>\n          [env: WATERMARK_GAMMA=]\n\n```\n## WATERMARK_DELTA\n```shell\n      --watermark-delta <WATERMARK_DELTA>\n          [env: WATERMARK_DELTA=]\n\n```\n## NGROK\n```shell\n      --ngrok\n          Enable ngrok tunneling\n          \n          [env: NGROK=]\n\n```\n## NGROK_AUTHTOKEN\n```shell\n      --ngrok-authtoken <NGROK_AUTHTOKEN>\n          ngrok authentication token\n          \n          [env: NGROK_AUTHTOKEN=]\n\n```\n## NGROK_EDGE\n```shell\n      --ngrok-edge <NGROK_EDGE>\n          ngrok edge\n          \n          [env: NGROK_EDGE=]\n\n```\n## TOKENIZER_CONFIG_PATH\n```shell\n      --tokenizer-config-path <TOKENIZER_CONFIG_PATH>\n          The path to the tokenizer config file. This path is used to load the tokenizer configuration which may include a `chat_template`. If not provided, the default config will be used from the model hub\n          \n          [env: TOKENIZER_CONFIG_PATH=]\n\n```\n## DISABLE_GRAMMAR_SUPPORT\n```shell\n      --disable-grammar-support\n          Disable outlines grammar constrained generation. 
This is a feature that allows you to generate text that follows a specific grammar\n          \n          [env: DISABLE_GRAMMAR_SUPPORT=]\n\n```\n## ENV\n```shell\n  -e, --env\n          Display a lot of information about your runtime environment\n\n```\n## MAX_CLIENT_BATCH_SIZE\n```shell\n      --max-client-batch-size <MAX_CLIENT_BATCH_SIZE>\n          Control the maximum number of inputs that a client can send in a single request\n          \n          [env: MAX_CLIENT_BATCH_SIZE=]\n          [default: 4]\n\n```\n## LORA_ADAPTERS\n```shell\n      --lora-adapters <LORA_ADAPTERS>\n          Lora Adapters a list of adapter ids i.e. `repo/adapter1,repo/adapter2` to load during startup that will be available to callers via the `adapter_id` field in a request\n          \n          [env: LORA_ADAPTERS=]\n\n```\n## USAGE_STATS\n```shell\n      --usage-stats <USAGE_STATS>\n          Control if anonymous usage stats are collected. Options are \"on\", \"off\" and \"no-stack\". Default is on\n\n          Possible values:\n          - on:       Default option, usage statistics are collected anonymously\n          - off:      Disables all collection of usage statistics\n          - no-stack: Doesn't send the error stack trace or error type, but allows sending a crash event\n          \n          [env: USAGE_STATS=]\n          [default: on]\n\n```\n## PAYLOAD_LIMIT\n```shell\n      --payload-limit <PAYLOAD_LIMIT>\n          Payload size limit in bytes\n          \n          Default is 2MB\n          \n          [env: PAYLOAD_LIMIT=]\n          [default: 2000000]\n\n```\n## ENABLE_PREFILL_LOGPROBS\n```shell\n      --enable-prefill-logprobs\n          Enables prefill logprobs\n          \n          Logprobs in the prompt are deactivated by default because they consume a large amount of VRAM (especially for long prompts). 
Using this flag reallows users to ask for them.\n          \n          [env: ENABLE_PREFILL_LOGPROBS=]\n\n```\n## GRACEFUL_TERMINATION_TIMEOUT\n```shell\n  -g, --graceful-termination-timeout <GRACEFUL_TERMINATION_TIMEOUT>\n          Change timeout of graceful termination of the TGI server\n          \n          [env: GRACEFUL_TERMINATION_TIMEOUT=]\n          [default: 90]\n\n```\n## HELP\n```shell\n  -h, --help\n          Print help (see a summary with '-h')\n\n```\n## VERSION\n```shell\n  -V, --version\n          Print version\n\n```\n"
  },
  {
    "path": "docs/source/reference/metrics.md",
    "content": "# Metrics\n\nTGI exposes multiple metrics that can be collected via the `/metrics` Prometheus endpoint.\nThese metrics can be used to monitor the performance of TGI, autoscale deployment and to help identify bottlenecks.\n\nThe following metrics are exposed:\n\n| Metric Name                                | Description                                                                              | Type      | Unit    |\n|--------------------------------------------|------------------------------------------------------------------------------------------|-----------|---------|\n| `tgi_batch_current_max_tokens`             | Maximum tokens for the current batch                                                     | Gauge     | Count   |\n| `tgi_batch_current_size`                   | Current batch size                                                                       | Gauge     | Count   |\n| `tgi_batch_decode_duration`                | Time spent decoding a batch per method (prefill or decode)                               | Histogram | Seconds |\n| `tgi_batch_filter_duration`                | Time spent filtering batches and sending generated tokens per method (prefill or decode) | Histogram | Seconds |\n| `tgi_batch_forward_duration`               | Batch forward duration per method (prefill or decode)                                    | Histogram | Seconds |\n| `tgi_batch_inference_count`                | Inference calls per method (prefill or decode)                                           | Counter   | Count   |\n| `tgi_batch_inference_duration`             | Batch inference duration                                                                 | Histogram | Seconds |\n| `tgi_batch_inference_success`              | Number of successful inference calls per method (prefill or decode)                      | Counter   | Count   |\n| `tgi_batch_next_size`                      | Batch size of the next batch                                     
                        | Histogram | Count   |\n| `tgi_queue_size`                           | Current queue size                                                                       | Gauge     | Count   |\n| `tgi_request_count`                        | Total number of requests                                                                 | Counter   | Count   |\n| `tgi_request_duration`                     | Total time spent processing the request (e2e latency)                                    | Histogram | Seconds |\n| `tgi_request_generated_tokens`             | Generated tokens per request                                                             | Histogram | Count   |\n| `tgi_request_inference_duration`           | Request inference duration                                                               | Histogram | Seconds |\n| `tgi_request_input_length`                 | Input token length per request                                                           | Histogram | Count   |\n| `tgi_request_max_new_tokens`               | Maximum new tokens per request                                                           | Histogram | Count   |\n| `tgi_request_mean_time_per_token_duration` | Mean time per token per request (inter-token latency)                                    | Histogram | Seconds |\n| `tgi_request_queue_duration`               | Time spent in the queue per request                                                      | Histogram | Seconds |\n| `tgi_request_skipped_tokens`               | Speculated tokens per request                                                            | Histogram | Count   |\n| `tgi_request_success`                      | Number of successful requests                                                            | Counter   |         |\n| `tgi_request_validation_duration`          | Time spent validating the request                                                        | Histogram | Seconds |\n"
  },
  {
    "path": "docs/source/supported_models.md",
    "content": "\n# Supported Models\n\nText Generation Inference enables serving optimized models. The following sections list which models (VLMs & LLMs) are supported.\n\n- [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2)\n- [Deepseek V3](https://huggingface.co/deepseek-ai/DeepSeek-V3)\n- [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal)\n- [Idefics 3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3) (Multimodal)\n- [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)\n- [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)\n- [Llama4](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)\n- [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)\n- [Granite](https://huggingface.co/ibm-granite/granite-3.0-8b-instruct)\n- [Gemma](https://huggingface.co/google/gemma-7b)\n- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)\n- [Gemma2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315)\n- [Gemma3](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d)\n- [Gemma3 Text](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d)\n- [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus)\n- [Dbrx](https://huggingface.co/databricks/dbrx-instruct)\n- [Mamba](https://huggingface.co/state-spaces/mamba-2.8b-slimpj)\n- [Mistral](https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407)\n- [Mixtral](https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1)\n- [Gpt Bigcode](https://huggingface.co/bigcode/gpt_bigcode-santacoder)\n- [Phi](https://huggingface.co/microsoft/phi-1_5)\n- [PhiMoe](https://huggingface.co/microsoft/Phi-3.5-MoE-instruct)\n- [Baichuan](https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat)\n- [Falcon](https://huggingface.co/tiiuae/falcon-7b-instruct)\n- [StarCoder 
2](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1)\n- [Qwen 2](https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f)\n- [Qwen 2 VL](https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d)\n- [Qwen 2.5 VL](https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e)\n- [Opt](https://huggingface.co/facebook/opt-6.7b)\n- [T5](https://huggingface.co/google/flan-t5-xxl)\n- [Galactica](https://huggingface.co/facebook/galactica-120b)\n- [SantaCoder](https://huggingface.co/bigcode/santacoder)\n- [Bloom](https://huggingface.co/bigscience/bloom-560m)\n- [Mpt](https://huggingface.co/mosaicml/mpt-7b-instruct)\n- [Gpt2](https://huggingface.co/openai-community/gpt2)\n- [Gpt Neox](https://huggingface.co/EleutherAI/gpt-neox-20b)\n- [Gptj](https://huggingface.co/EleutherAI/gpt-j-6b)\n- [Idefics](https://huggingface.co/HuggingFaceM4/idefics-9b) (Multimodal)\n- [Mllama](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) (Multimodal)\n\n\n\nIf the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models:\n\n```python\n# for causal LMs/text-generation models\nAutoModelForCausalLM.from_pretrained(<model>, device_map=\"auto\")\n# or, for text-to-text generation models\nAutoModelForSeq2SeqLM.from_pretrained(<model>, device_map=\"auto\")\n```\n\nIf you wish to serve a supported model that already exists on a local folder, just point to the local folder.\n\n```bash\ntext-generation-launcher --model-id <PATH-TO-LOCAL-BLOOM>\n```\n"
  },
  {
    "path": "docs/source/usage_statistics.md",
    "content": "\n# Collection of Usage Statistics\n\nText Generation Inference collects anonymous usage statistics to help us improve the service. The collected data is used to improve TGI and to understand what causes failures. The data is collected transparently and any sensitive information is omitted.\n\nUsage statistics are collected only when TGI is running in a Docker container. This prevents data collection when TGI is run directly on the host machine. The collected data includes startup and shutdown events, as well as a heartbeat signal sent every 15 minutes.\n\n## What data is collected\n\nThe code that collects the data is available [here](https://github.com/huggingface/text-generation-inference/blob/main/router/src/usage_stats.rs).\nAs of release 2.1.2 this is an example of the data collected:\n\n- From the TGI configuration:\n```json\n{\n  \"event_type\": \"start\",\n  \"disable_grammar_support\": false,\n  \"max_batch_prefill_tokens\": 4096,\n  \"max_batch_size\": null,\n  \"max_batch_total_tokens\": null,\n  \"max_best_of\": 2,\n  \"max_client_batch_size\": 4,\n  \"max_concurrent_requests\": 128,\n  \"max_input_tokens\": 1024,\n  \"max_stop_sequences\": 4,\n  \"max_top_n_tokens\": 5,\n  \"max_total_tokens\": 2048,\n  \"max_waiting_tokens\": 20,\n  \"model_config\": {\n    \"model_type\": \"Bloom\"\n  },\n  \"revision\": null,\n  \"tokenizer_class\": \"BloomTokenizerFast\",\n  \"validation_workers\": 2,\n  \"waiting_served_ratio\": 1.2,\n  \"docker_label\": \"latest\",\n  \"git_sha\": \"cfc118704880453d29bcbe4fbbd91dda501cf5fe\",\n  \"nvidia_env\": {\n    \"name\": \"NVIDIA A10G\",\n    \"pci_bus_id\": \"00000000:00:1E.0\",\n    \"driver_version\": \"535.183.01\",\n    \"pstate\": \"P8\",\n    \"pcie_link_gen_max\": \"4\",\n    \"pcie_link_gen_current\": \"1\",\n    \"temperature_gpu\": \"31\",\n    \"utilization_gpu\": \"0 %\",\n    \"utilization_memory\": \"0 %\",\n    \"memory_total\": \"23028 MiB\",\n    \"memory_free\": \"22515 MiB\",\n    
\"memory_used\": \"0 MiB\",\n    \"reset_status_reset_required\": \"No\",\n    \"reset_status_drain_and_reset_recommended\": \"No\",\n    \"compute_cap\": \"8.6\",\n    \"ecc_errors_corrected_volatile_total\": \"0\",\n    \"mig_mode_current\": \"[N/A]\",\n    \"power_draw_instant\": \"10.86 W\",\n    \"power_limit\": \"300.00 W\"\n  },\n  \"system_env\": {\n    \"cpu_count\": 16,\n    \"cpu_type\": \"AMD EPYC 7R32\",\n    \"total_memory\": 66681196544,\n    \"architecture\": \"x86_64\",\n    \"platform\": \"linux-unix-x86_64\"\n  }\n}\n\n```\n\n## How to opt-out\n\nBy passing the `--usage-stats` to the text-generation-launcher you can control how much usage statistics are being collected.\n`--usage-stats=no-stack` will not emit the stack traces from errors and the error types, but will continue to send start and stop events\n`--usage-stats=off` will completely disable everything\n"
  },
  {
    "path": "flake.nix",
    "content": "{\n  inputs = {\n    crate2nix = {\n      url = \"github:nix-community/crate2nix\";\n      inputs.nixpkgs.follows = \"hf-nix/nixpkgs\";\n    };\n    nix-filter.url = \"github:numtide/nix-filter\";\n    hf-nix.url = \"github:huggingface/hf-nix\";\n    nixpkgs.follows = \"hf-nix/nixpkgs\";\n    flake-utils.url = \"github:numtide/flake-utils\";\n    rust-overlay = {\n      url = \"github:oxalica/rust-overlay\";\n      inputs.nixpkgs.follows = \"hf-nix/nixpkgs\";\n    };\n  };\n  outputs =\n    {\n      self,\n      crate2nix,\n      nix-filter,\n      nixpkgs,\n      flake-utils,\n      rust-overlay,\n      hf-nix,\n    }:\n    flake-utils.lib.eachDefaultSystem (\n      system:\n      let\n        cargoNix = crate2nix.tools.${system}.appliedCargoNix {\n          name = \"tgi\";\n          src = ./.;\n          additionalCargoNixArgs = [ \"--all-features\" ];\n        };\n        pkgs = import nixpkgs {\n          inherit system;\n          inherit (hf-nix.lib) config;\n          overlays = [\n            rust-overlay.overlays.default\n            hf-nix.overlays.default\n            (import nix/overlay.nix)\n          ];\n        };\n        crateOverrides = import ./nix/crate-overrides.nix { inherit pkgs nix-filter; };\n        benchmark = cargoNix.workspaceMembers.text-generation-benchmark.build.override {\n          inherit crateOverrides;\n        };\n        launcher =\n          let\n            launcherUnwrapped = cargoNix.workspaceMembers.text-generation-launcher.build.override {\n              inherit crateOverrides;\n            };\n            packagePath =\n              with pkgs.python3.pkgs;\n              makePythonPath [\n                torch\n              ];\n          in\n          pkgs.writeShellApplication {\n            name = \"text-generation-launcher\";\n            text = ''\n              PYTHONPATH=\"${packagePath}\" ${launcherUnwrapped}/bin/text-generation-launcher \"$@\"\n            '';\n          };\n\n        router 
=\n          let\n            routerUnwrapped = cargoNix.workspaceMembers.text-generation-router-v3.build.override {\n              inherit crateOverrides;\n            };\n            packagePath =\n              with pkgs.python3.pkgs;\n              makePythonPath [\n                protobuf\n                sentencepiece\n                torch\n                transformers\n              ];\n          in\n          pkgs.writeShellApplication {\n            name = \"text-generation-router\";\n            text = ''\n              PYTHONPATH=\"${packagePath}\" ${routerUnwrapped}/bin/text-generation-router \"$@\"\n            '';\n          };\n        server = pkgs.python3.pkgs.callPackage ./nix/server.nix { inherit nix-filter; };\n        client = pkgs.python3.pkgs.callPackage ./nix/client.nix { };\n      in\n      {\n        checks = {\n          rust =\n            with pkgs;\n            rustPlatform.buildRustPackage {\n              name = \"rust-checks\";\n              src = ./.;\n              cargoLock = {\n                lockFile = ./Cargo.lock;\n              };\n              buildInputs = [ openssl.dev ];\n              nativeBuildInputs = [\n                clippy\n                pkg-config\n                protobuf\n                python3\n                rustfmt\n              ];\n              buildPhase = ''\n                cargo check\n              '';\n              checkPhase = ''\n                cargo fmt -- --check\n                cargo test -j $NIX_BUILD_CORES\n                cargo clippy\n              '';\n              installPhase = \"touch $out\";\n            };\n        };\n        formatter = pkgs.nixfmt-rfc-style;\n        devShells = with pkgs; rec {\n          default = pure;\n\n          pure = mkShell {\n            buildInputs = [\n              benchmark\n              launcher\n              router\n              server\n            ];\n          };\n          test = mkShell {\n            buildInputs =\n             
 [\n                benchmark\n                launcher\n                router\n                server\n                client\n                openssl.dev\n                pkg-config\n                cargo\n                rustfmt\n                clippy\n              ]\n              ++ (with python3.pkgs; [\n                docker\n                pytest\n                pytest-asyncio\n                syrupy\n                pre-commit\n                ruff\n              ]);\n          };\n\n          impure = callPackage ./nix/impure-shell.nix { inherit server; };\n\n          impureWithCuda = callPackage ./nix/impure-shell.nix {\n            inherit server;\n            withCuda = true;\n          };\n\n          impure-flash-attn-v1 = callPackage ./nix/impure-shell.nix {\n            server = server.override { flash-attn = python3.pkgs.flash-attn-v1; };\n          };\n        };\n\n        packages = rec {\n          inherit server;\n\n          default = pkgs.writeShellApplication {\n            name = \"text-generation-inference\";\n            runtimeInputs = [\n              server\n              router\n            ];\n            text = ''\n              ${launcher}/bin/text-generation-launcher \"$@\"\n            '';\n          };\n\n          # Use plain nixpkgs without overlays for dockerTools. dockerTools\n          # uses a Python package for computing the layers from the transitive\n          # closure. However, this needs a lot of rebuilds due to our overlay.\n\n          dockerImage = nixpkgs.legacyPackages.${system}.callPackage nix/docker.nix {\n            text-generation-inference = default;\n          };\n\n          dockerImageStreamed = nixpkgs.legacyPackages.${system}.callPackage nix/docker.nix {\n            text-generation-inference = default;\n            stream = true;\n          };\n        };\n      }\n    );\n}\n"
  },
  {
    "path": "integration-tests/conftest.py",
    "content": "pytest_plugins = [\n    \"fixtures.neuron.service\",\n    \"fixtures.neuron.export_models\",\n    \"fixtures.gaudi.service\",\n]\n# ruff: noqa: E402\nfrom _pytest.fixtures import SubRequest\nfrom huggingface_hub.inference._generated.types.chat_completion import (\n    ChatCompletionStreamOutput,\n    ChatCompletionOutput,\n)\nfrom openai.types.chat.chat_completion_chunk import (\n    ChatCompletionChunk as OAIChatCompletionChunk,\n)\nfrom openai.types.completion import Completion as OAICompletion\nimport requests\n\n\nclass SessionTimeoutFix(requests.Session):\n    def request(self, *args, **kwargs):\n        timeout = kwargs.pop(\"timeout\", 120)\n        return super().request(*args, **kwargs, timeout=timeout)\n\n\nrequests.sessions.Session = SessionTimeoutFix\n\nimport warnings\nimport asyncio\nimport contextlib\nimport json\nimport math\nimport os\nimport random\nimport subprocess\nimport sys\nimport tempfile\nimport time\nimport docker\nimport pytest\nimport base64\n\nfrom pathlib import Path\nfrom typing import Dict, List, Optional\nfrom aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError\nfrom docker.errors import NotFound\nfrom syrupy.extensions.json import JSONSnapshotExtension\nfrom text_generation import AsyncClient\nfrom text_generation.types import (\n    BestOfSequence,\n    Message,\n    ChatComplete,\n    ChatCompletionChunk,\n    ChatCompletionComplete,\n    Completion,\n    Details,\n    Grammar,\n    InputToken,\n    Response,\n    Token,\n)\n\nDOCKER_IMAGE = os.getenv(\"DOCKER_IMAGE\", None)\nHF_TOKEN = os.getenv(\"HF_TOKEN\", None)\nDOCKER_VOLUME = os.getenv(\"DOCKER_VOLUME\", \"/data\")\nDOCKER_DEVICES = os.getenv(\"DOCKER_DEVICES\")\n\n\ndef pytest_addoption(parser):\n    parser.addoption(\n        \"--release\", action=\"store_true\", default=False, help=\"run release tests\"\n    )\n    parser.addoption(\n        \"--neuron\", action=\"store_true\", default=False, help=\"run neuron tests\"\n    )\n    
parser.addoption(\n        \"--gaudi\", action=\"store_true\", default=False, help=\"run gaudi tests\"\n    )\n    parser.addoption(\n        \"--gaudi-all-models\",\n        action=\"store_true\",\n        default=False,\n        help=\"Run tests for all models instead of just the default subset\",\n    )\n\n\ndef pytest_configure(config):\n    config.addinivalue_line(\"markers\", \"release: mark test as a release-only test\")\n    config.addinivalue_line(\"markers\", \"neuron: mark test as a neuron test\")\n\n\ndef pytest_collection_modifyitems(config, items):\n    selectors = []\n    if not config.getoption(\"--release\"):\n        # --release not given in cli: skip release tests\n        def skip_release(item):\n            if \"release\" in item.keywords:\n                item.add_marker(pytest.mark.skip(reason=\"need --release option to run\"))\n\n        selectors.append(skip_release)\n\n    if config.getoption(\"--gaudi\"):\n\n        def skip_not_gaudi(item):\n            if \"gaudi\" not in item.keywords:\n                item.add_marker(pytest.mark.skip(reason=\"requires --gaudi to run\"))\n\n        selectors.append(skip_not_gaudi)\n    else:\n\n        def skip_gaudi(item):\n            if \"gaudi\" in item.keywords:\n                item.add_marker(pytest.mark.skip(reason=\"requires --gaudi to run\"))\n\n        selectors.append(skip_gaudi)\n\n    if config.getoption(\"--neuron\"):\n\n        def skip_not_neuron(item):\n            if \"neuron\" not in item.keywords:\n                item.add_marker(\n                    pytest.mark.skip(reason=\"incompatible with --neuron option\")\n                )\n\n        selectors.append(skip_not_neuron)\n    else:\n\n        def skip_neuron(item):\n            if \"neuron\" in item.keywords:\n                item.add_marker(pytest.mark.skip(reason=\"requires --neuron to run\"))\n\n        selectors.append(skip_neuron)\n\n    for item in items:\n        for selector in selectors:\n            
selector(item)\n\n\n@pytest.fixture(autouse=True, scope=\"module\")\ndef container_log(request: SubRequest):\n    error_log = request.getfixturevalue(\"error_log\")\n    assert error_log is not None\n    yield\n    if request.session.testsfailed:\n        error_log.seek(0)\n        print(error_log.read(), file=sys.stderr)\n    else:\n        error_log.truncate(0)\n        error_log.seek(0)\n\n\nclass ResponseComparator(JSONSnapshotExtension):\n    rtol = 0.2\n    ignore_logprob = False\n\n    def _serialize(\n        self,\n        data,\n    ):\n        if (\n            isinstance(data, Response)\n            or isinstance(data, ChatComplete)\n            or isinstance(data, ChatCompletionChunk)\n            or isinstance(data, ChatCompletionComplete)\n            or isinstance(data, Completion)\n            or isinstance(data, OAIChatCompletionChunk)\n            or isinstance(data, OAICompletion)\n        ):\n            data = data.model_dump()\n        elif isinstance(data, ChatCompletionStreamOutput) or isinstance(\n            data, ChatCompletionOutput\n        ):\n            data = dict(data)\n        elif isinstance(data, List):\n            data = [self._serialize(d) for d in data]\n        elif isinstance(data, dict):\n            return data\n        else:\n            raise RuntimeError(f\"Unexpected data {type(data)} : {data}\")\n        return data\n\n    def serialize(\n        self,\n        data,\n        *,\n        include=None,\n        exclude=None,\n        matcher=None,\n    ):\n        data = self._serialize(data)\n        data = self._filter(\n            data=data,\n            depth=0,\n            path=(),\n            exclude=exclude,\n            include=include,\n            matcher=matcher,\n        )\n        data = json.dumps(data, indent=2, ensure_ascii=False, sort_keys=False) + \"\\n\"\n        return data\n\n    def matches(\n        self,\n        *,\n        serialized_data,\n        snapshot_data,\n    ) -> bool:\n        
def convert_data(data):\n            data = json.loads(data)\n            return _convert_data(data)\n\n        def _convert_data(data):\n            if isinstance(data, Dict):\n                if \"choices\" in data:\n                    data[\"choices\"] = list(\n                        sorted(data[\"choices\"], key=lambda x: int(x[\"index\"]))\n                    )\n                    choices = data[\"choices\"]\n                    if isinstance(choices, List) and len(choices) >= 1:\n                        if \"delta\" in choices[0]:\n                            return ChatCompletionChunk(**data)\n                        if \"text\" in choices[0]:\n                            return Completion(**data)\n                    return ChatComplete(**data)\n                else:\n                    return Response(**data)\n            if isinstance(data, List):\n                return [_convert_data(d) for d in data]\n            raise NotImplementedError(f\"Data: {data}\")\n\n        def eq_token(token: Token, other: Token) -> bool:\n            return (\n                token.id == other.id\n                and token.text == other.text\n                and (\n                    self.ignore_logprob\n                    or (token.logprob == other.logprob and token.logprob is None)\n                    or math.isclose(token.logprob, other.logprob, rel_tol=self.rtol)\n                )\n                and token.special == other.special\n            )\n\n        def eq_prefill_token(prefill_token: InputToken, other: InputToken) -> bool:\n            try:\n                return (\n                    prefill_token.id == other.id\n                    and prefill_token.text == other.text\n                    and (\n                        self.ignore_logprob\n                        or math.isclose(\n                            prefill_token.logprob,\n                            other.logprob,\n                            rel_tol=self.rtol,\n                        
)\n                        if prefill_token.logprob is not None\n                        else prefill_token.logprob == other.logprob\n                    )\n                )\n            except TypeError:\n                return False\n\n        def eq_best_of(details: BestOfSequence, other: BestOfSequence) -> bool:\n            return (\n                details.finish_reason == other.finish_reason\n                and details.generated_tokens == other.generated_tokens\n                and details.seed == other.seed\n                and len(details.prefill) == len(other.prefill)\n                and all(\n                    [\n                        eq_prefill_token(d, o)\n                        for d, o in zip(details.prefill, other.prefill)\n                    ]\n                )\n                and len(details.tokens) == len(other.tokens)\n                and all([eq_token(d, o) for d, o in zip(details.tokens, other.tokens)])\n            )\n\n        def eq_details(details: Details, other: Details) -> bool:\n            return (\n                details.finish_reason == other.finish_reason\n                and details.generated_tokens == other.generated_tokens\n                and details.seed == other.seed\n                and len(details.prefill) == len(other.prefill)\n                and all(\n                    [\n                        eq_prefill_token(d, o)\n                        for d, o in zip(details.prefill, other.prefill)\n                    ]\n                )\n                and len(details.tokens) == len(other.tokens)\n                and all([eq_token(d, o) for d, o in zip(details.tokens, other.tokens)])\n                and (\n                    len(details.best_of_sequences)\n                    if details.best_of_sequences is not None\n                    else 0\n                )\n                == (\n                    len(other.best_of_sequences)\n                    if other.best_of_sequences is not None\n                  
  else 0\n                )\n                and (\n                    all(\n                        [\n                            eq_best_of(d, o)\n                            for d, o in zip(\n                                details.best_of_sequences, other.best_of_sequences\n                            )\n                        ]\n                    )\n                    if details.best_of_sequences is not None\n                    else details.best_of_sequences == other.best_of_sequences\n                )\n            )\n\n        def eq_completion(response: Completion, other: Completion) -> bool:\n            return response.choices[0].text == other.choices[0].text\n\n        def eq_chat_complete(response: ChatComplete, other: ChatComplete) -> bool:\n            return (\n                response.choices[0].message.content == other.choices[0].message.content\n            )\n\n        def eq_chat_complete_chunk(\n            response: ChatCompletionChunk, other: ChatCompletionChunk\n        ) -> bool:\n            if response.choices:\n                if response.choices[0].delta.content is not None:\n                    return (\n                        response.choices[0].delta.content\n                        == other.choices[0].delta.content\n                    )\n                elif response.choices[0].delta.tool_calls is not None:\n                    return (\n                        response.choices[0].delta.tool_calls\n                        == other.choices[0].delta.tool_calls\n                    )\n                else:\n                    raise RuntimeError(\n                        f\"Invalid empty chat chunk {response} vs {other}\"\n                    )\n            elif response.usage is not None:\n                return response.usage == other.usage\n            else:\n                raise RuntimeError(f\"Invalid empty chat {response} vs {other}\")\n\n        def eq_response(response: Response, other: Response) -> bool:\n            
return response.generated_text == other.generated_text and eq_details(\n                response.details, other.details\n            )\n\n        serialized_data = convert_data(serialized_data)\n        snapshot_data = convert_data(snapshot_data)\n\n        if not isinstance(serialized_data, List):\n            serialized_data = [serialized_data]\n        if not isinstance(snapshot_data, List):\n            snapshot_data = [snapshot_data]\n\n        if len(serialized_data) == 0:\n            return len(snapshot_data) == len(serialized_data)\n\n        if isinstance(serialized_data[0], Completion):\n            return len(snapshot_data) == len(serialized_data) and all(\n                [eq_completion(r, o) for r, o in zip(serialized_data, snapshot_data)]\n            )\n\n        if isinstance(serialized_data[0], ChatComplete):\n            return len(snapshot_data) == len(serialized_data) and all(\n                [eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)]\n            )\n\n        if isinstance(serialized_data[0], ChatCompletionChunk):\n            return len(snapshot_data) == len(serialized_data) and all(\n                [\n                    eq_chat_complete_chunk(r, o)\n                    for r, o in zip(serialized_data, snapshot_data)\n                ]\n            )\n\n        return len(snapshot_data) == len(serialized_data) and all(\n            [eq_response(r, o) for r, o in zip(serialized_data, snapshot_data)]\n        )\n\n\nclass GenerousResponseComparator(ResponseComparator):\n    # Needed for GPTQ with exllama which has serious numerical fluctuations.\n    rtol = 0.75\n\n\nclass IgnoreLogProbResponseComparator(ResponseComparator):\n    ignore_logprob = True\n\n\nclass LauncherHandle:\n    def __init__(self, port: int, error_log):\n        with warnings.catch_warnings(action=\"ignore\"):\n            self.client = AsyncClient(f\"http://localhost:{port}\", timeout=30)\n        self.error_log = error_log\n\n    def 
_inner_health(self):\n        raise NotImplementedError\n\n    async def health(self, timeout: int = 60):\n        assert timeout > 0\n        for _ in range(timeout):\n            if not self._inner_health():\n                self.error_log.seek(0)\n                print(self.error_log.read(), file=sys.stderr)\n                raise RuntimeError(\"Launcher crashed\")\n\n            try:\n                await self.client.generate(\"test\")\n                return\n            except (ClientConnectorError, ClientOSError, ServerDisconnectedError):\n                time.sleep(1)\n        self.error_log.seek(0)\n        print(self.error_log.read(), file=sys.stderr)\n        raise RuntimeError(\"Health check failed\")\n\n\nclass ContainerLauncherHandle(LauncherHandle):\n    def __init__(self, docker_client, container_name, port: int, error_log):\n        super().__init__(port, error_log)\n        self.docker_client = docker_client\n        self.container_name = container_name\n\n    def _inner_health(self) -> bool:\n        container = self.docker_client.containers.get(self.container_name)\n        return container.status in [\"running\", \"created\"]\n\n\nclass ProcessLauncherHandle(LauncherHandle):\n    def __init__(self, process, port: int, error_log):\n        super().__init__(port, error_log)\n        self.process = process\n\n    def _inner_health(self) -> bool:\n        return self.process.poll() is None\n\n\n@pytest.fixture\ndef response_snapshot(snapshot):\n    return snapshot.use_extension(ResponseComparator)\n\n\n@pytest.fixture\ndef generous_response_snapshot(snapshot):\n    return snapshot.use_extension(GenerousResponseComparator)\n\n\n@pytest.fixture\ndef ignore_logprob_response_snapshot(snapshot):\n    return snapshot.use_extension(IgnoreLogProbResponseComparator)\n\n\n@pytest.fixture(scope=\"session\")\ndef error_log():\n    with tempfile.TemporaryFile(\"w+\") as tmp:\n        yield tmp\n\n\n@pytest.fixture(scope=\"session\")\nasync def 
launcher(error_log):\n    @contextlib.contextmanager\n    def local_launcher(\n        model_id: str,\n        num_shard: Optional[int] = None,\n        quantize: Optional[str] = None,\n        trust_remote_code: bool = False,\n        use_flash_attention: bool = True,\n        disable_grammar_support: bool = False,\n        dtype: Optional[str] = None,\n        kv_cache_dtype: Optional[str] = None,\n        revision: Optional[str] = None,\n        max_input_length: Optional[int] = None,\n        max_input_tokens: Optional[int] = None,\n        max_batch_prefill_tokens: Optional[int] = None,\n        max_total_tokens: Optional[int] = None,\n        lora_adapters: Optional[List[str]] = None,\n        cuda_graphs: Optional[List[int]] = None,\n        attention: Optional[str] = None,\n    ):\n        port = random.randint(8000, 10_000)\n        master_port = random.randint(10_000, 20_000)\n\n        shard_uds_path = (\n            f\"/tmp/tgi-tests-{model_id.split('/')[-1]}-{num_shard}-{quantize}-server\"\n        )\n\n        args = [\n            \"text-generation-launcher\",\n            \"--model-id\",\n            model_id,\n            \"--port\",\n            str(port),\n            \"--master-port\",\n            str(master_port),\n            \"--shard-uds-path\",\n            shard_uds_path,\n        ]\n\n        env = os.environ\n\n        if disable_grammar_support:\n            args.append(\"--disable-grammar-support\")\n        if num_shard is not None:\n            args.extend([\"--num-shard\", str(num_shard)])\n        if quantize is not None:\n            args.append(\"--quantize\")\n            args.append(quantize)\n        if dtype is not None:\n            args.append(\"--dtype\")\n            args.append(dtype)\n        if kv_cache_dtype is not None:\n            args.append(\"--kv-cache-dtype\")\n            args.append(kv_cache_dtype)\n        if revision is not None:\n            args.append(\"--revision\")\n            args.append(revision)\n 
       if trust_remote_code:\n            args.append(\"--trust-remote-code\")\n        if max_input_length:\n            args.append(\"--max-input-length\")\n            args.append(str(max_input_length))\n        if max_input_tokens:\n            args.append(\"--max-input-tokens\")\n            args.append(str(max_input_tokens))\n        if max_batch_prefill_tokens:\n            args.append(\"--max-batch-prefill-tokens\")\n            args.append(str(max_batch_prefill_tokens))\n        if max_total_tokens:\n            args.append(\"--max-total-tokens\")\n            args.append(str(max_total_tokens))\n        if lora_adapters:\n            args.append(\"--lora-adapters\")\n            args.append(\",\".join(lora_adapters))\n        if cuda_graphs:\n            args.append(\"--cuda-graphs\")\n            args.append(\",\".join(map(str, cuda_graphs)))\n\n        print(\" \".join(args), file=sys.stderr)\n\n        env[\"LOG_LEVEL\"] = \"info,text_generation_router=debug\"\n        env[\"PREFILL_CHUNKING\"] = \"1\"\n\n        if not use_flash_attention:\n            env[\"USE_FLASH_ATTENTION\"] = \"false\"\n        if attention is not None:\n            env[\"ATTENTION\"] = attention\n\n            # with tempfile.TemporaryFile(\"w+\") as tmp:\n            # We'll output stdout/stderr to a temporary file. 
Using a pipe\n            # causes the process to block until stdout is read.\n        with subprocess.Popen(\n            args,\n            stdout=error_log,\n            stderr=subprocess.STDOUT,\n            env=env,\n        ) as process:\n            yield ProcessLauncherHandle(process, port, error_log=error_log)\n\n            process.terminate()\n            process.wait(60)\n\n        if not use_flash_attention:\n            del env[\"USE_FLASH_ATTENTION\"]\n\n    @contextlib.contextmanager\n    def docker_launcher(\n        model_id: str,\n        num_shard: Optional[int] = None,\n        quantize: Optional[str] = None,\n        trust_remote_code: bool = False,\n        use_flash_attention: bool = True,\n        disable_grammar_support: bool = False,\n        dtype: Optional[str] = None,\n        kv_cache_dtype: Optional[str] = None,\n        revision: Optional[str] = None,\n        max_input_length: Optional[int] = None,\n        max_batch_prefill_tokens: Optional[int] = None,\n        max_total_tokens: Optional[int] = None,\n        lora_adapters: Optional[List[str]] = None,\n        cuda_graphs: Optional[List[int]] = None,\n        attention: Optional[str] = None,\n    ):\n        port = random.randint(8000, 10_000)\n\n        args = [\"--model-id\", model_id, \"--env\"]\n\n        if disable_grammar_support:\n            args.append(\"--disable-grammar-support\")\n        if num_shard is not None:\n            args.extend([\"--num-shard\", str(num_shard)])\n        if quantize is not None:\n            args.append(\"--quantize\")\n            args.append(quantize)\n        if dtype is not None:\n            args.append(\"--dtype\")\n            args.append(dtype)\n        if kv_cache_dtype is not None:\n            args.append(\"--kv-cache-dtype\")\n            args.append(kv_cache_dtype)\n        if revision is not None:\n            args.append(\"--revision\")\n            args.append(revision)\n        if trust_remote_code:\n            
args.append(\"--trust-remote-code\")\n        if max_input_length:\n            args.append(\"--max-input-length\")\n            args.append(str(max_input_length))\n        if max_batch_prefill_tokens:\n            args.append(\"--max-batch-prefill-tokens\")\n            args.append(str(max_batch_prefill_tokens))\n        if max_total_tokens:\n            args.append(\"--max-total-tokens\")\n            args.append(str(max_total_tokens))\n        if lora_adapters:\n            args.append(\"--lora-adapters\")\n            args.append(\",\".join(lora_adapters))\n        if cuda_graphs:\n            args.append(\"--cuda-graphs\")\n            args.append(\",\".join(map(str, cuda_graphs)))\n\n        client = docker.from_env()\n\n        container_name = f\"tgi-tests-{model_id.split('/')[-1]}-{num_shard}-{quantize}\"\n\n        try:\n            container = client.containers.get(container_name)\n            container.stop()\n            container.remove()\n            container.wait()\n        except NotFound:\n            pass\n\n        gpu_count = num_shard if num_shard is not None else 1\n\n        env = {\n            \"LOG_LEVEL\": \"info,text_generation_router=debug\",\n            \"PREFILL_CHUNKING\": \"1\",\n        }\n        if not use_flash_attention:\n            env[\"USE_FLASH_ATTENTION\"] = \"false\"\n        if attention is not None:\n            env[\"ATTENTION\"] = attention\n\n        if HF_TOKEN is not None:\n            env[\"HF_TOKEN\"] = HF_TOKEN\n\n        volumes = []\n        if DOCKER_VOLUME:\n            volumes = [f\"{DOCKER_VOLUME}:/data\"]\n\n        if DOCKER_DEVICES:\n            if DOCKER_DEVICES.lower() == \"none\":\n                devices = []\n            else:\n                devices = DOCKER_DEVICES.strip().split(\",\")\n            visible = os.getenv(\"ROCR_VISIBLE_DEVICES\")\n            if visible:\n                env[\"ROCR_VISIBLE_DEVICES\"] = visible\n            device_requests = []\n            if not devices:\n     
           devices = None\n            elif devices == [\"nvidia.com/gpu=all\"]:\n                devices = None\n                device_requests = [\n                    docker.types.DeviceRequest(\n                        driver=\"cdi\",\n                        # count=gpu_count,\n                        device_ids=[f\"nvidia.com/gpu={i}\"],\n                    )\n                    for i in range(gpu_count)\n                ]\n        else:\n            devices = None\n            device_requests = [\n                docker.types.DeviceRequest(count=gpu_count, capabilities=[[\"gpu\"]])\n            ]\n\n        client.api.timeout = 1000\n        container = client.containers.run(\n            DOCKER_IMAGE,\n            command=args,\n            name=container_name,\n            environment=env,\n            auto_remove=False,\n            detach=True,\n            device_requests=device_requests,\n            devices=devices,\n            volumes=volumes,\n            ports={\"80/tcp\": port},\n            healthcheck={\"timeout\": int(180 * 1e9), \"retries\": 2},  # 180s\n            shm_size=\"1G\",\n        )\n\n        def pipe():\n            for log in container.logs(stream=True):\n                log = log.decode(\"utf-8\")\n                error_log.write(log)\n\n        # Start looping to pipe the logs\n        import threading\n\n        t = threading.Thread(target=pipe, args=())\n        t.start()\n\n        try:\n            yield ContainerLauncherHandle(\n                client, container.name, port, error_log=error_log\n            )\n\n            if not use_flash_attention:\n                del env[\"USE_FLASH_ATTENTION\"]\n\n            try:\n                container.stop()\n                container.wait()\n            except NotFound:\n                pass\n\n        finally:\n            try:\n                container.remove()\n            except Exception:\n                pass\n\n    if DOCKER_IMAGE is not None:\n        return 
docker_launcher\n    return local_launcher\n\n\n@pytest.fixture(scope=\"module\")\ndef generate_load():\n    async def generate_load_inner(\n        client: AsyncClient,\n        prompt: str,\n        max_new_tokens: int,\n        n: int,\n        seed: Optional[int] = None,\n        grammar: Optional[Grammar] = None,\n        stop_sequences: Optional[List[str]] = None,\n    ) -> List[Response]:\n        futures = [\n            client.generate(\n                prompt,\n                max_new_tokens=max_new_tokens,\n                decoder_input_details=True,\n                seed=seed,\n                grammar=grammar,\n                stop_sequences=stop_sequences,\n            )\n            for _ in range(n)\n        ]\n\n        return await asyncio.gather(*futures)\n\n    return generate_load_inner\n\n\n@pytest.fixture(scope=\"module\")\ndef generate_multi():\n    async def generate_load_inner(\n        client: AsyncClient,\n        prompts: List[str],\n        max_new_tokens: int,\n        seed: Optional[int] = None,\n    ) -> List[Response]:\n        import numpy as np\n\n        arange = np.arange(len(prompts))\n        perm = np.random.permutation(arange)\n        rperm = [-1] * len(perm)\n        for i, p in enumerate(perm):\n            rperm[p] = i\n\n        shuffled_prompts = [prompts[p] for p in perm]\n        futures = [\n            client.chat(\n                messages=[Message(role=\"user\", content=prompt)],\n                max_tokens=max_new_tokens,\n                temperature=0,\n                seed=seed,\n            )\n            for prompt in shuffled_prompts\n        ]\n\n        shuffled_responses = await asyncio.gather(*futures)\n        responses = [shuffled_responses[p] for p in rperm]\n        return responses\n\n    return generate_load_inner\n\n\n# TODO fix the server parser to count inline image tokens correctly\n@pytest.fixture\ndef chicken():\n    path = Path(__file__).parent / \"images\" / \"chicken_on_money.png\"\n\n   
 with open(path, \"rb\") as image_file:\n        encoded_string = base64.b64encode(image_file.read())\n    return f\"data:image/png;base64,{encoded_string.decode('utf-8')}\"\n\n\n@pytest.fixture\ndef cow_beach():\n    path = Path(__file__).parent / \"images\" / \"cow_beach.png\"\n\n    with open(path, \"rb\") as image_file:\n        encoded_string = base64.b64encode(image_file.read())\n    return f\"data:image/png;base64,{encoded_string.decode('utf-8')}\"\n"
  },
  {
    "path": "integration-tests/fixtures/gaudi/service.py",
    "content": "import asyncio\nimport contextlib\nimport os\nimport shlex\nimport subprocess\nimport sys\nimport threading\nimport time\nfrom tempfile import TemporaryDirectory\nfrom typing import List\nimport socket\n\nimport docker\nimport pytest\nfrom aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError\nfrom docker.errors import NotFound\nimport logging\nfrom huggingface_hub import AsyncInferenceClient, TextGenerationOutput\nimport huggingface_hub\n\nlogging.basicConfig(\n    level=logging.INFO,\n    format=\"<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>\",\n    stream=sys.stdout,\n)\nlogger = logging.getLogger(__file__)\n\n# Use the latest image from the local docker build\nDOCKER_IMAGE = os.getenv(\"DOCKER_IMAGE\", \"tgi-gaudi\")\nDOCKER_VOLUME = os.getenv(\"DOCKER_VOLUME\", None)\nHF_TOKEN = huggingface_hub.get_token()\n\nassert (\n    HF_TOKEN is not None\n), \"HF_TOKEN is not set, please set it as some models are gated and thus the test will fail without it\"\n\nif DOCKER_VOLUME is None:\n    logger.warning(\n        \"DOCKER_VOLUME is not set, this will lead to the tests redownloading the models on each run, consider setting it to speed up testing\"\n    )\n\nLOG_LEVEL = os.getenv(\"LOG_LEVEL\", \"info\")\n\nBASE_ENV = {\n    \"HF_HUB_ENABLE_HF_TRANSFER\": \"1\",\n    \"LOG_LEVEL\": LOG_LEVEL,\n    \"HF_TOKEN\": os.getenv(\"HF_TOKEN\", None),\n}\n\n\nHABANA_RUN_ARGS = {\n    \"runtime\": \"habana\",\n    \"ipc_mode\": \"host\",\n    \"cap_add\": [\"sys_nice\"],\n}\n\n\ndef stream_container_logs(container, test_name):\n    \"\"\"Stream container logs in a separate thread.\"\"\"\n    try:\n        for log in container.logs(stream=True, follow=True):\n            print(\n                f\"[TGI Server Logs - {test_name}] {log.decode('utf-8')}\",\n                end=\"\",\n                file=sys.stderr,\n             
   flush=True,\n            )\n    except Exception as e:\n        logger.error(f\"Error streaming container logs: {str(e)}\")\n\n\nclass TestClient(AsyncInferenceClient):\n    def __init__(self, service_name: str, base_url: str):\n        super().__init__(model=base_url)\n        self.service_name = service_name\n\n\nclass LauncherHandle:\n    def __init__(self, service_name: str, port: int):\n        self.client = TestClient(service_name, f\"http://localhost:{port}\")\n\n    def _inner_health(self):\n        raise NotImplementedError\n\n    async def health(self, timeout: int = 60):\n        assert timeout > 0\n        start_time = time.time()\n        logger.info(f\"Starting health check with timeout of {timeout}s\")\n\n        for attempt in range(timeout):\n            if not self._inner_health():\n                logger.error(\"Launcher crashed during health check\")\n                raise RuntimeError(\"Launcher crashed\")\n\n            try:\n                await self.client.text_generation(\"test\", max_new_tokens=1)\n                elapsed = time.time() - start_time\n                logger.info(f\"Health check passed after {elapsed:.1f}s\")\n                return\n            except (ClientConnectorError, ClientOSError, ServerDisconnectedError) as e:\n                if attempt == timeout - 1:\n                    logger.error(f\"Health check failed after {timeout}s: {str(e)}\")\n                    raise RuntimeError(f\"Health check failed: {str(e)}\")\n                if attempt % 10 == 0 and attempt != 0:  # Only log every 10th attempt\n                    logger.debug(\n                        f\"Connection attempt {attempt}/{timeout} failed: {str(e)}\"\n                    )\n                time.sleep(1)\n            except Exception as e:\n                logger.error(f\"Unexpected error during health check: {str(e)}\")\n                # Get full traceback for debugging\n                import traceback\n\n                logger.error(f\"Full 
traceback:\\n{traceback.format_exc()}\")\n                raise\n\n\nclass ContainerLauncherHandle(LauncherHandle):\n    def __init__(self, docker_client, container_name, port: int):\n        service_name = container_name  # Use container name as service name\n        super(ContainerLauncherHandle, self).__init__(service_name, port)\n        self.docker_client = docker_client\n        self.container_name = container_name\n\n    def _inner_health(self) -> bool:\n        try:\n            container = self.docker_client.containers.get(self.container_name)\n            status = container.status\n            if status not in [\"running\", \"created\"]:\n                logger.warning(f\"Container status is {status}\")\n                # Get container logs for debugging\n                logs = container.logs().decode(\"utf-8\")\n                logger.debug(f\"Container logs:\\n{logs}\")\n            return status in [\"running\", \"created\"]\n        except Exception as e:\n            logger.error(f\"Error checking container health: {str(e)}\")\n            return False\n\n\nclass ProcessLauncherHandle(LauncherHandle):\n    def __init__(self, process, port: int):\n        service_name = \"process\"  # Use generic name for process launcher\n        super(ProcessLauncherHandle, self).__init__(service_name, port)\n        self.process = process\n\n    def _inner_health(self) -> bool:\n        return self.process.poll() is None\n\n\n@pytest.fixture(scope=\"module\")\ndef data_volume():\n    tmpdir = TemporaryDirectory()\n    yield tmpdir.name\n    try:\n        # Cleanup the temporary directory using sudo as it contains root files created by the container\n        subprocess.run(shlex.split(f\"sudo rm -rf {tmpdir.name}\"), check=True)\n    except subprocess.CalledProcessError as e:\n        logger.error(f\"Error cleaning up temporary directory: {str(e)}\")\n\n\n@pytest.fixture(scope=\"module\")\ndef gaudi_launcher():\n    @contextlib.contextmanager\n    def 
docker_launcher(\n        model_id: str,\n        test_name: str,\n        tgi_args: List[str] = None,\n        env_config: dict = None,\n    ):\n        logger.info(\n            f\"Starting docker launcher for model {model_id} and test {test_name}\"\n        )\n\n        # Get a random available port\n        def get_free_port():\n            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n                s.bind((\"\", 0))\n                s.listen(1)\n                port = s.getsockname()[1]\n            return port\n\n        port = get_free_port()\n        logger.debug(f\"Using port {port}\")\n\n        client = docker.from_env()\n\n        container_name = f\"tgi-gaudi-test-{test_name.replace('/', '-')}\"\n\n        try:\n            container = client.containers.get(container_name)\n            logger.info(\n                f\"Stopping existing container {container_name} for test {test_name}\"\n            )\n            container.stop()\n            container.wait()\n            container.remove()\n            logger.info(f\"Removed existing container {container_name}\")\n        except NotFound:\n            pass\n        except Exception as e:\n            logger.error(f\"Error handling existing container: {str(e)}\")\n\n        if tgi_args is None:\n            tgi_args = []\n        else:\n            tgi_args = tgi_args.copy()\n\n        env = BASE_ENV.copy()\n\n        # Add model_id to env\n        env[\"MODEL_ID\"] = model_id\n\n        # Add env config that is defined in the fixture parameter\n        if env_config is not None:\n            env.update(env_config.copy())\n\n        volumes = []\n        if DOCKER_VOLUME:\n            volumes = [f\"{DOCKER_VOLUME}:/data\"]\n        logger.debug(f\"Using volume {volumes}\")\n\n        try:\n            logger.debug(f\"Using command {tgi_args}\")\n            logger.info(f\"Creating container with name {container_name}\")\n\n            logger.debug(f\"Using environment {env}\")\n        
    logger.debug(f\"Using volumes {volumes}\")\n            logger.debug(f\"HABANA_RUN_ARGS {HABANA_RUN_ARGS}\")\n\n            # Launch the container in detached mode with the Habana runtime arguments\n            container = client.containers.run(\n                DOCKER_IMAGE,\n                command=tgi_args,\n                name=container_name,\n                environment=env,\n                detach=True,\n                volumes=volumes,\n                ports={\"80/tcp\": port},\n                **HABANA_RUN_ARGS,\n            )\n\n            logger.info(f\"Container {container_name} started successfully\")\n\n            # Start log streaming in a background thread\n            log_thread = threading.Thread(\n                target=stream_container_logs,\n                args=(container, test_name),\n                daemon=True,  # This ensures the thread will be killed when the main program exits\n            )\n            log_thread.start()\n\n            # Add a small delay to allow container to initialize\n            time.sleep(2)\n\n            # Check container status after creation\n            status = container.status\n            logger.debug(f\"Initial container status: {status}\")\n            if status not in [\"running\", \"created\"]:\n                logs = container.logs().decode(\"utf-8\")\n                logger.error(f\"Container failed to start properly. 
Logs:\\n{logs}\")\n\n            yield ContainerLauncherHandle(client, container.name, port)\n\n        except Exception as e:\n            logger.error(f\"Error starting container: {str(e)}\")\n            # Get full traceback for debugging\n            import traceback\n\n            logger.error(f\"Full traceback:\\n{traceback.format_exc()}\")\n            raise\n        finally:\n            try:\n                container = client.containers.get(container_name)\n                logger.info(f\"Stopping container {container_name}\")\n                container.stop()\n                container.wait()\n\n                container_output = container.logs().decode(\"utf-8\")\n                print(container_output, file=sys.stderr)\n\n                container.remove()\n                logger.info(f\"Container {container_name} removed successfully\")\n            except NotFound:\n                pass\n            except Exception as e:\n                logger.warning(f\"Error cleaning up container: {str(e)}\")\n\n    return docker_launcher\n\n\n@pytest.fixture(scope=\"module\")\ndef gaudi_generate_load():\n    async def generate_load_inner(\n        client: AsyncInferenceClient, prompt: str, max_new_tokens: int, n: int\n    ) -> List[TextGenerationOutput]:\n        try:\n            futures = [\n                client.text_generation(\n                    prompt,\n                    max_new_tokens=max_new_tokens,\n                    details=True,\n                    decoder_input_details=True,\n                )\n                for _ in range(n)\n            ]\n            return await asyncio.gather(*futures)\n        except Exception as e:\n            logger.error(f\"Error generating load: {str(e)}\")\n            raise\n\n    return generate_load_inner\n"
  },
  {
    "path": "integration-tests/fixtures/neuron/export_models.py",
    "content": "import copy\nimport logging\nimport sys\nfrom tempfile import TemporaryDirectory\n\nimport huggingface_hub\nimport pytest\nimport docker\nimport hashlib\nimport os\nimport tempfile\n\nfrom docker.errors import NotFound\n\n\nTEST_ORGANIZATION = \"optimum-internal-testing\"\nTEST_CACHE_REPO_ID = f\"{TEST_ORGANIZATION}/neuron-testing-cache\"\nHF_TOKEN = huggingface_hub.get_token()\n\n\nlogging.basicConfig(\n    level=logging.INFO,\n    format=\"[%(asctime)s] %(levelname)s [%(filename)s.%(funcName)s:%(lineno)d] %(message)s\",\n    stream=sys.stdout,\n)\nlogger = logging.getLogger(__file__)\n\n\n# All model configurations below will be added to the neuron_model_config fixture\nMODEL_CONFIGURATIONS = {\n    \"llama\": {\n        \"model_id\": \"unsloth/Llama-3.2-1B-Instruct\",\n        \"export_kwargs\": {\n            \"batch_size\": 4,\n            \"sequence_length\": 2048,\n            \"num_cores\": 2,\n            \"auto_cast_type\": \"fp16\",\n        },\n    },\n    \"qwen2\": {\n        \"model_id\": \"Qwen/Qwen2.5-0.5B\",\n        \"export_kwargs\": {\n            \"batch_size\": 4,\n            \"sequence_length\": 4096,\n            \"num_cores\": 2,\n            \"auto_cast_type\": \"fp16\",\n        },\n    },\n    \"qwen3\": {\n        \"model_id\": \"Qwen/Qwen3-1.7B\",\n        \"export_kwargs\": {\n            \"batch_size\": 4,\n            \"sequence_length\": 4096,\n            \"num_cores\": 2,\n            \"auto_cast_type\": \"bf16\",\n        },\n    },\n    \"granite\": {\n        \"model_id\": \"ibm-granite/granite-3.1-2b-instruct\",\n        \"export_kwargs\": {\n            \"batch_size\": 4,\n            \"sequence_length\": 4096,\n            \"num_cores\": 2,\n            \"auto_cast_type\": \"bf16\",\n        },\n    },\n    \"phi3\": {\n        \"model_id\": \"microsoft/Phi-3-mini-4k-instruct\",\n        \"export_kwargs\": {\n            \"batch_size\": 4,\n            \"sequence_length\": 4096,\n            \"num_cores\": 
2,\n            \"auto_cast_type\": \"bf16\",\n        },\n    },\n}\n\n\ndef get_neuron_backend_hash():\n    import subprocess\n\n    res = subprocess.run(\n        [\"git\", \"rev-parse\", \"--show-toplevel\"], capture_output=True, text=True\n    )\n    root_dir = res.stdout.split(\"\\n\")[0]\n\n    def get_sha(path):\n        res = subprocess.run(\n            [\"git\", \"ls-tree\", \"HEAD\", f\"{root_dir}/{path}\"],\n            capture_output=True,\n            text=True,\n        )\n        # Output of the command is in the form '040000 tree|blob <SHA>\\t<path>\\n'\n        sha = res.stdout.split(\"\\t\")[0].split(\" \")[-1]\n        return sha.encode()\n\n    # We hash both the neuron backends directory and Dockerfile and create a smaller hash out of that\n    m = hashlib.sha256()\n    m.update(get_sha(\"backends/neuron\"))\n    m.update(get_sha(\"Dockerfile.neuron\"))\n    return m.hexdigest()[:10]\n\n\ndef get_neuron_model_name(config_name: str):\n    return f\"neuron-tgi-testing-{config_name}-{get_neuron_backend_hash()}\"\n\n\ndef get_tgi_docker_image():\n    docker_image = os.getenv(\"DOCKER_IMAGE\", None)\n    if docker_image is None:\n        client = docker.from_env()\n        images = client.images.list(filters={\"reference\": \"text-generation-inference\"})\n        if not images:\n            raise ValueError(\n                \"No text-generation-inference image found on this host to run tests.\"\n            )\n        docker_image = images[0].tags[0]\n    return docker_image\n\n\ndef maybe_export_model(config_name, model_config):\n    \"\"\"Export a neuron model for the specified test configuration.\n\n    If the neuron model has not already been compiled and pushed to the hub, it is\n    exported by a custom image built on the fly from the base TGI image.\n    This makes sure the exported model and image are aligned and avoids introducing\n    neuron specific imports in the test suite.\n\n    Args:\n        config_name (`str`):\n            
Used to identify test configurations\n        model_config (`str`):\n            The model configuration for export (includes the original model id)\n    \"\"\"\n    neuron_model_name = get_neuron_model_name(config_name)\n    neuron_model_id = f\"{TEST_ORGANIZATION}/{neuron_model_name}\"\n    hub = huggingface_hub.HfApi()\n    if hub.repo_exists(neuron_model_id):\n        logger.info(\n            f\"Skipping model export for config {config_name} as {neuron_model_id} already exists\"\n        )\n        return neuron_model_id\n\n    client = docker.from_env()\n\n    env = {\"LOG_LEVEL\": \"info\", \"CUSTOM_CACHE_REPO\": TEST_CACHE_REPO_ID}\n    if HF_TOKEN is not None:\n        env[\"HUGGING_FACE_HUB_TOKEN\"] = HF_TOKEN\n        env[\"HF_TOKEN\"] = HF_TOKEN\n\n    # Create a sub-image to export the model to workaround docker dind issues preventing\n    # to share a volume from the container running tests\n    model_id = model_config[\"model_id\"]\n    export_kwargs = model_config[\"export_kwargs\"]\n    base_image = get_tgi_docker_image()\n    export_image = f\"neuron-tgi-tests-{config_name}-export-img\"\n    logger.info(f\"Building temporary image {export_image} from {base_image}\")\n    with tempfile.TemporaryDirectory() as context_dir:\n        # Create entrypoint\n        model_path = \"/data/neuron_model\"\n        export_command = (\n            f\"optimum-cli export neuron -m {model_id} --task text-generation\"\n        )\n        for kwarg, value in export_kwargs.items():\n            export_command += f\" --{kwarg} {str(value)}\"\n        export_command += f\" {model_path}\"\n        entrypoint_content = f\"\"\"#!/bin/sh\n        {export_command}\n        huggingface-cli repo create --organization {TEST_ORGANIZATION} {neuron_model_name}\n        huggingface-cli upload {TEST_ORGANIZATION}/{neuron_model_name} {model_path} --exclude *.bin *.safetensors\n        optimum-cli neuron cache synchronize --repo_id {TEST_CACHE_REPO_ID}\n        \"\"\"\n        with 
open(os.path.join(context_dir, \"entrypoint.sh\"), \"wb\") as f:\n            f.write(entrypoint_content.encode(\"utf-8\"))\n            f.flush()\n        # Create Dockerfile\n        docker_content = f\"\"\"\n        FROM {base_image}\n        COPY entrypoint.sh /export-entrypoint.sh\n        RUN chmod +x /export-entrypoint.sh\n        ENTRYPOINT [\"/export-entrypoint.sh\"]\n        \"\"\"\n        with open(os.path.join(context_dir, \"Dockerfile\"), \"wb\") as f:\n            f.write(docker_content.encode(\"utf-8\"))\n            f.flush()\n        image, logs = client.images.build(\n            path=context_dir, dockerfile=f.name, tag=export_image\n        )\n        logger.info(\"Successfully built image %s\", image.id)\n        logger.debug(\"Build logs %s\", logs)\n\n    try:\n        client.containers.run(\n            export_image,\n            environment=env,\n            auto_remove=True,\n            detach=False,\n            devices=[\"/dev/neuron0\"],\n            shm_size=\"1G\",\n        )\n        logger.info(f\"Successfully exported model for config {config_name}\")\n    except Exception as e:\n        logger.exception(f\"An exception occurred while running container: {e}.\")\n        pass\n    finally:\n        # Cleanup the export image\n        logger.info(\"Cleaning image %s\", image.id)\n        try:\n            image.remove(force=True)\n        except NotFound:\n            pass\n        except Exception as e:\n            logger.error(\"Error while removing image %s, skipping\", image.id)\n            logger.exception(e)\n    return neuron_model_id\n\n\ndef maybe_export_models():\n    for config_name, model_config in MODEL_CONFIGURATIONS.items():\n        maybe_export_model(config_name, model_config)\n\n\n@pytest.fixture(scope=\"session\", params=MODEL_CONFIGURATIONS.keys())\ndef neuron_model_config(request):\n    \"\"\"Expose a pre-trained neuron model\n\n    The fixture first makes sure the following model artifacts are present on the 
hub:\n    - exported neuron model under optimum-internal-testing/neuron-testing-<name>-<version>,\n    - cached artifacts under optimum-internal-testing/neuron-testing-cache.\n    If not, it will export the model and push it to the hub.\n\n    It then fetches the model locally and returns a dictionary containing:\n    - a configuration name,\n    - the original model id,\n    - the export parameters,\n    - the neuron model id,\n    - the neuron model local path.\n\n    For each exposed model, the local directory is maintained for the duration of the\n    test session and cleaned up afterwards.\n    The hub model artifacts are never cleaned up and persist across sessions.\n    They must be cleaned up manually when the optimum-neuron version changes.\n\n    \"\"\"\n    config_name = request.param\n    model_config = copy.deepcopy(MODEL_CONFIGURATIONS[request.param])\n    # Export the model first (only if needed)\n    neuron_model_id = maybe_export_model(config_name, model_config)\n    with TemporaryDirectory() as neuron_model_path:\n        logger.info(f\"Fetching {neuron_model_id} from the HuggingFace hub\")\n        hub = huggingface_hub.HfApi()\n        hub.snapshot_download(\n            neuron_model_id, etag_timeout=30, local_dir=neuron_model_path\n        )\n        # Add dynamic parameters to the model configuration\n        model_config[\"neuron_model_path\"] = neuron_model_path\n        model_config[\"neuron_model_id\"] = neuron_model_id\n        # Also add model configuration name to allow tests to adapt their expectations\n        model_config[\"name\"] = config_name\n        # Yield instead of returning to keep a reference to the temporary directory.\n        # It will go out of scope and be released only once all tests needing the fixture\n        # have been completed.\n        logger.info(f\"{config_name} ready for testing ...\")\n        yield model_config\n        logger.info(f\"Done with {config_name}\")\n\n\n@pytest.fixture(scope=\"module\")\ndef 
neuron_model_path(neuron_model_config):\n    yield neuron_model_config[\"neuron_model_path\"]\n\n\nif __name__ == \"__main__\":\n    maybe_export_models()\n"
  },
  {
    "path": "integration-tests/fixtures/neuron/service.py",
    "content": "import asyncio\nimport contextlib\nimport logging\nimport os\nimport random\nimport shutil\nimport sys\nimport tempfile\nimport time\nfrom typing import List\n\nimport docker\nimport huggingface_hub\nimport pytest\nfrom aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError\nfrom docker.errors import NotFound\nfrom huggingface_hub import AsyncInferenceClient, TextGenerationOutput\n\n\nOPTIMUM_CACHE_REPO_ID = \"optimum-internal-testing/neuron-testing-cache\"\nHF_TOKEN = huggingface_hub.get_token()\n\n\ndef get_tgi_docker_image():\n    docker_image = os.getenv(\"DOCKER_IMAGE\", None)\n    if docker_image is None:\n        client = docker.from_env()\n        images = client.images.list(filters={\"reference\": \"text-generation-inference\"})\n        if not images:\n            raise ValueError(\n                \"No text-generation-inference image found on this host to run tests.\"\n            )\n        docker_image = images[0].tags[0]\n    return docker_image\n\n\nlogging.basicConfig(\n    level=logging.INFO,\n    format=\"[%(asctime)s] %(levelname)s [%(filename)s.%(funcName)s:%(lineno)d] %(message)s\",\n    stream=sys.stdout,\n)\nlogger = logging.getLogger(__file__)\n\n\nclass TestClient(AsyncInferenceClient):\n    def __init__(self, service_name: str, base_url: str):\n        super().__init__(model=base_url)\n        self.service_name = service_name\n\n\nclass LauncherHandle:\n    def __init__(self, service_name: str, port: int):\n        self.client = TestClient(service_name, f\"http://localhost:{port}\")\n\n    def _inner_health(self):\n        raise NotImplementedError\n\n    async def health(self, timeout: int = 60):\n        assert timeout > 0\n        for i in range(timeout):\n            if not self._inner_health():\n                raise RuntimeError(f\"Service crashed after {i} seconds.\")\n\n            try:\n                await self.client.text_generation(\"test\", max_new_tokens=1)\n                
logger.info(f\"Service started after {i} seconds\")\n                return\n            except (ClientConnectorError, ClientOSError, ServerDisconnectedError):\n                time.sleep(1)\n            except Exception:\n                raise RuntimeError(\"Basic generation failed with: {e}\")\n        raise RuntimeError(f\"Service failed to start after {i} seconds.\")\n\n\nclass ContainerLauncherHandle(LauncherHandle):\n    def __init__(self, service_name, docker_client, container_name, port: int):\n        super(ContainerLauncherHandle, self).__init__(service_name, port)\n        self.docker_client = docker_client\n        self.container_name = container_name\n        self._log_since = time.time()\n\n    def _inner_health(self) -> bool:\n        container = self.docker_client.containers.get(self.container_name)\n        container_output = container.logs(since=self._log_since).decode(\"utf-8\")\n        self._log_since = time.time()\n        if container_output != \"\":\n            print(container_output, end=\"\")\n        return container.status in [\"running\", \"created\"]\n\n\n@pytest.fixture(scope=\"module\")\ndef event_loop():\n    loop = asyncio.get_event_loop()\n    yield loop\n    loop.close()\n\n\n@pytest.fixture(scope=\"module\")\ndef neuron_launcher(event_loop):\n    \"\"\"Utility fixture to expose a TGI service.\n\n    The fixture uses a single event loop for each module, but it can create multiple\n    docker services with different parameters using the parametrized inner context.\n\n    Args:\n        service_name (`str`):\n            Used to identify test configurations and adjust test expectations,\n        model_name_or_path (`str`):\n            The model to use (can be a hub model or a path)\n        trust_remote_code (`bool`):\n            Must be set to True for gated models.\n\n    Returns:\n        A `ContainerLauncherHandle` containing both a TGI server and client.\n    \"\"\"\n\n    @contextlib.contextmanager\n    def 
docker_launcher(\n        service_name: str,\n        model_name_or_path: str,\n        trust_remote_code: bool = False,\n    ):\n        port = random.randint(8000, 10_000)\n\n        client = docker.from_env()\n\n        container_name = f\"tgi-tests-{service_name}-{port}\"\n\n        try:\n            container = client.containers.get(container_name)\n            container.stop()\n            container.wait()\n        except NotFound:\n            pass\n\n        env = {\n            \"LOG_LEVEL\": \"info,text_generation_router=debug\",\n            \"CUSTOM_CACHE_REPO\": OPTIMUM_CACHE_REPO_ID,\n        }\n\n        if HF_TOKEN is not None:\n            env[\"HUGGING_FACE_HUB_TOKEN\"] = HF_TOKEN\n            env[\"HF_TOKEN\"] = HF_TOKEN\n\n        for var in [\n            \"MAX_BATCH_SIZE\",\n            \"MAX_TOTAL_TOKENS\",\n            \"HF_AUTO_CAST_TYPE\",\n            \"HF_NUM_CORES\",\n        ]:\n            if var in os.environ:\n                env[var] = os.environ[var]\n\n        base_image = get_tgi_docker_image()\n        if os.path.isdir(model_name_or_path):\n            # Create a sub-image containing the model to workaround docker dind issues preventing\n            # to share a volume from the container running tests\n\n            test_image = f\"{container_name}-img\"\n            logger.info(\n                \"Building image on the flight derivated from %s, tagged with %s\",\n                base_image,\n                test_image,\n            )\n            with tempfile.TemporaryDirectory() as context_dir:\n                # Copy model directory to build context\n                model_path = os.path.join(context_dir, \"model\")\n                shutil.copytree(model_name_or_path, model_path)\n                # Create Dockerfile\n                container_model_id = f\"/data/{model_name_or_path}\"\n                docker_content = f\"\"\"\n                FROM {base_image}\n                COPY model {container_model_id}\n                
\"\"\"\n                with open(os.path.join(context_dir, \"Dockerfile\"), \"wb\") as f:\n                    f.write(docker_content.encode(\"utf-8\"))\n                    f.flush()\n                image, logs = client.images.build(\n                    path=context_dir, dockerfile=f.name, tag=test_image\n                )\n            logger.info(\"Successfully built image %s\", image.id)\n            logger.debug(\"Build logs %s\", logs)\n        else:\n            test_image = base_image\n            image = None\n            container_model_id = model_name_or_path\n\n        args = [\"--model-id\", container_model_id, \"--env\"]\n\n        if trust_remote_code:\n            args.append(\"--trust-remote-code\")\n\n        container = client.containers.run(\n            test_image,\n            command=args,\n            name=container_name,\n            environment=env,\n            auto_remove=False,\n            detach=True,\n            devices=[\"/dev/neuron0\"],\n            ports={\"80/tcp\": port},\n            shm_size=\"1G\",\n        )\n\n        logger.info(f\"Starting {container_name} container\")\n        yield ContainerLauncherHandle(service_name, client, container.name, port)\n\n        try:\n            container.stop(timeout=60)\n            container.wait(timeout=60)\n        except Exception as e:\n            logger.exception(f\"Ignoring exception while stopping container: {e}.\")\n            pass\n        finally:\n            logger.info(\"Removing container %s\", container_name)\n            try:\n                container.remove(force=True)\n            except Exception as e:\n                logger.error(\n                    \"Error while removing container %s, skipping\", container_name\n                )\n                logger.exception(e)\n\n            # Cleanup the build image\n            if image:\n                logger.info(\"Cleaning image %s\", image.id)\n                try:\n                    
image.remove(force=True)\n                except NotFound:\n                    pass\n                except Exception as e:\n                    logger.error(\"Error while removing image %s, skipping\", image.id)\n                    logger.exception(e)\n\n    return docker_launcher\n\n\n@pytest.fixture(scope=\"module\")\ndef neuron_generate_load():\n    \"\"\"A utility fixture to launch multiple asynchronous TGI requests in parallel\n\n    Args:\n        client (`AsyncClient`):\n            An async client\n        prompt (`str`):\n            The prompt to use (identical for all requests)\n        max_new_tokens (`int`):\n            The number of tokens to generate for each request.\n        n (`int`):\n            The number of requests\n\n    Returns:\n        A list of `huggingface_hub.TextGenerationOutput`.\n    \"\"\"\n\n    async def generate_load_inner(\n        client: AsyncInferenceClient, prompt: str, max_new_tokens: int, n: int\n    ) -> List[TextGenerationOutput]:\n        futures = [\n            client.text_generation(\n                prompt,\n                max_new_tokens=max_new_tokens,\n                details=True,\n                decoder_input_details=True,\n            )\n            for _ in range(n)\n        ]\n\n        return await asyncio.gather(*futures)\n\n    return generate_load_inner\n"
  },
  {
    "path": "integration-tests/gaudi/capture_expected_outputs.py",
    "content": "import json\nimport os\nfrom typing import Dict, Any, Generator\n\nimport pytest\nfrom test_gaudi_generate import TEST_CONFIGS\n\nUNKNOWN_CONFIGS = {\n    name: config\n    for name, config in TEST_CONFIGS.items()\n    if config[\"expected_greedy_output\"] == \"unknown\"\n    or config[\"expected_batch_output\"] == \"unknown\"\n}\n\n\n@pytest.fixture(scope=\"module\", params=UNKNOWN_CONFIGS.keys())\ndef test_config(request) -> Dict[str, Any]:\n    \"\"\"Fixture that provides model configurations for testing.\"\"\"\n    test_config = UNKNOWN_CONFIGS[request.param]\n    test_config[\"test_name\"] = request.param\n    return test_config\n\n\n@pytest.fixture(scope=\"module\")\ndef test_name(test_config):\n    yield test_config[\"test_name\"]\n\n\n@pytest.fixture(scope=\"module\")\ndef tgi_service(launcher, test_config, test_name) -> Generator:\n    \"\"\"Fixture that provides a TGI service for testing.\"\"\"\n    with launcher(test_config[\"model_id\"], test_name) as service:\n        yield service\n\n\n@pytest.mark.asyncio\nasync def test_capture_expected_outputs(tgi_service, test_config, test_name):\n    \"\"\"Test that captures expected outputs for models with unknown outputs.\"\"\"\n    print(f\"Testing {test_name} with {test_config['model_id']}\")\n\n    # Wait for service to be ready\n    await tgi_service.health(1000)\n    client = tgi_service.client\n\n    # Test single request (greedy)\n    print(\"Testing single request...\")\n    response = await client.generate(\n        test_config[\"input\"],\n        max_new_tokens=32,\n    )\n    greedy_output = response.generated_text\n\n    # Test multiple requests (batch)\n    print(\"Testing batch requests...\")\n    responses = []\n    for _ in range(4):\n        response = await client.generate(\n            test_config[\"input\"],\n            max_new_tokens=32,\n        )\n        responses.append(response.generated_text)\n\n    # Store results in a JSON file\n    output_file = 
\"server/integration-tests/expected_outputs.json\"\n    results = {}\n\n    # Try to load existing results if file exists\n    if os.path.exists(output_file):\n        with open(output_file, \"r\") as f:\n            results = json.load(f)\n\n    # Update results for this model\n    results[test_name] = {\n        \"model_id\": test_config[\"model_id\"],\n        \"input\": test_config[\"input\"],\n        \"greedy_output\": greedy_output,\n        \"batch_outputs\": responses,\n        \"args\": test_config[\"args\"],\n    }\n\n    # Save updated results\n    with open(output_file, \"w\") as f:\n        json.dump(results, f, indent=2)\n\n    print(f\"\\nResults for {test_name} saved to {output_file}\")\n"
  },
  {
    "path": "integration-tests/gaudi/test_gaudi_generate.py",
    "content": "from typing import Any, Dict, Generator\nfrom _pytest.fixtures import SubRequest\nfrom huggingface_hub import AsyncInferenceClient\nimport pytest\n\n\ndef pytest_configure(config):\n    config.addinivalue_line(\n        \"markers\", \"gaudi_all_models: mark test to run with all models\"\n    )\n\n\n# The \"args\" values in TEST_CONFIGS are not optimized for speed but only check that the inference is working for the different models architectures.\nTEST_CONFIGS = {\n    \"meta-llama/Llama-3.1-8B-Instruct-sharded\": {\n        \"model_id\": \"meta-llama/Llama-3.1-8B-Instruct\",\n        \"input\": \"What is Deep Learning?\",\n        \"expected_greedy_output\": \" A Beginner’s Guide\\nDeep learning is a subset of machine learning that involves the use of artificial neural networks to analyze and interpret data. It is a type of\",\n        \"expected_batch_output\": \" A Beginner’s Guide\\nDeep learning is a subset of machine learning that involves the use of artificial neural networks to analyze and interpret data. It is a type of\",\n        \"args\": [\n            \"--sharded\",\n            \"true\",\n            \"--num-shard\",\n            \"2\",\n            \"--max-input-tokens\",\n            \"512\",\n            \"--max-total-tokens\",\n            \"1024\",\n            \"--max-batch-size\",\n            \"4\",\n            \"--max-batch-prefill-tokens\",\n            \"2048\",\n        ],\n        \"run_by_default\": True,\n    },\n    \"meta-llama/Llama-3.1-8B-Instruct\": {\n        \"model_id\": \"meta-llama/Llama-3.1-8B-Instruct\",\n        \"input\": \"What is Deep Learning?\",\n        \"expected_greedy_output\": \" A Beginner’s Guide\\nDeep learning is a subset of machine learning that involves the use of artificial neural networks to analyze and interpret data. 
It is a type of\",\n        \"expected_batch_output\": \" A Beginner’s Guide\\nDeep learning is a subset of machine learning that involves the use of artificial neural networks to analyze and interpret data. It is a type of\",\n        \"env_config\": {},\n        \"args\": [\n            \"--max-input-tokens\",\n            \"512\",\n            \"--max-total-tokens\",\n            \"1024\",\n            \"--max-batch-size\",\n            \"4\",\n            \"--max-batch-prefill-tokens\",\n            \"2048\",\n        ],\n        \"run_by_default\": True,\n    },\n    \"meta-llama/Llama-2-7b-chat-hf\": {\n        \"model_id\": \"meta-llama/Llama-2-7b-chat-hf\",\n        \"input\": \"What is Deep Learning?\",\n        \"expected_greedy_output\": \"\\n\\nDeep learning (also known as deep structured learning) is part of a broader family of machine learning techniques based on artificial neural networks\\u2014specific\",\n        \"expected_batch_output\": \"\\n\\nDeep learning (also known as deep structured learning) is part of a broader family of machine learning techniques based on artificial neural networks\\u2014specific\",\n        \"args\": [\n            \"--max-input-tokens\",\n            \"512\",\n            \"--max-total-tokens\",\n            \"1024\",\n            \"--max-batch-size\",\n            \"4\",\n            \"--max-batch-prefill-tokens\",\n            \"2048\",\n        ],\n    },\n    \"mistralai/Mistral-7B-Instruct-v0.3\": {\n        \"model_id\": \"mistralai/Mistral-7B-Instruct-v0.3\",\n        \"input\": \"What is Deep Learning?\",\n        \"expected_greedy_output\": \"\\n\\nDeep learning is a subset of machine learning in artificial intelligence (AI) that has networks capable of learning unsupervised from data that is unstructured\",\n        \"expected_batch_output\": \"\\n\\nDeep learning is a subset of machine learning in artificial intelligence (AI) that has networks capable of learning unsupervised from data that is 
unstructured\",\n        \"args\": [\n            \"--max-input-tokens\",\n            \"512\",\n            \"--max-total-tokens\",\n            \"1024\",\n            \"--max-batch-size\",\n            \"4\",\n            \"--max-batch-prefill-tokens\",\n            \"2048\",\n        ],\n    },\n    \"bigcode/starcoder2-3b\": {\n        \"model_id\": \"bigcode/starcoder2-3b\",\n        \"input\": \"What is Deep Learning?\",\n        \"expected_greedy_output\": \"\\n\\nDeep learning is a subset of machine learning that uses artificial neural networks to perform tasks.\\n\\nNeural networks are a type of machine learning algorithm that\",\n        \"expected_batch_output\": \"\\n\\nDeep learning is a subset of machine learning that uses artificial neural networks to perform tasks.\\n\\nNeural networks are a type of machine learning algorithm that\",\n        \"args\": [\n            \"--max-input-tokens\",\n            \"512\",\n            \"--max-total-tokens\",\n            \"1024\",\n            \"--max-batch-size\",\n            \"4\",\n            \"--max-batch-prefill-tokens\",\n            \"2048\",\n        ],\n    },\n    \"google/gemma-7b-it\": {\n        \"model_id\": \"google/gemma-7b-it\",\n        \"input\": \"What is Deep Learning?\",\n        \"expected_greedy_output\": \"\\n\\nDeep learning is a subset of machine learning that uses artificial neural networks to learn from large amounts of data. Deep learning is a powerful tool for many tasks,\",\n        \"expected_batch_output\": \"\\n\\nDeep learning is a subset of machine learning that uses artificial neural networks to learn from large amounts of data. 
Deep learning is a powerful tool for many tasks,\",\n        \"args\": [\n            \"--max-input-tokens\",\n            \"512\",\n            \"--max-total-tokens\",\n            \"1024\",\n            \"--max-batch-size\",\n            \"4\",\n            \"--max-batch-prefill-tokens\",\n            \"2048\",\n        ],\n    },\n    \"Qwen/Qwen2-0.5B-Instruct\": {\n        \"model_id\": \"Qwen/Qwen2-0.5B-Instruct\",\n        \"input\": \"What is Deep Learning?\",\n        \"expected_greedy_output\": \" Deep Learning is a type of machine learning that is based on the principles of artificial neural networks. It is a type of machine learning that is used to train models\",\n        \"expected_batch_output\": \" Deep Learning is a type of machine learning that is based on the principles of artificial neural networks. It is a type of machine learning that is used to train models\",\n        \"args\": [\n            \"--max-input-tokens\",\n            \"512\",\n            \"--max-total-tokens\",\n            \"1024\",\n            \"--max-batch-size\",\n            \"4\",\n            \"--max-batch-prefill-tokens\",\n            \"2048\",\n        ],\n    },\n    \"tiiuae/falcon-7b-instruct\": {\n        \"model_id\": \"tiiuae/falcon-7b-instruct\",\n        \"input\": \"What is Deep Learning?\",\n        \"expected_greedy_output\": \"\\nDeep learning is a branch of machine learning that uses artificial neural networks to learn and make decisions. It is based on the concept of hierarchical learning, where a\",\n        \"expected_batch_output\": \"\\nDeep learning is a branch of machine learning that uses artificial neural networks to learn and make decisions. 
It is based on the concept of hierarchical learning, where a\",\n        \"args\": [\n            \"--max-input-tokens\",\n            \"512\",\n            \"--max-total-tokens\",\n            \"1024\",\n            \"--max-batch-size\",\n            \"4\",\n        ],\n    },\n    \"microsoft/phi-1_5\": {\n        \"model_id\": \"microsoft/phi-1_5\",\n        \"input\": \"What is Deep Learning?\",\n        \"expected_greedy_output\": \"\\n\\nDeep Learning is a subfield of Machine Learning that focuses on building neural networks with multiple layers of interconnected nodes. These networks are designed to learn from large\",\n        \"expected_batch_output\": \"\\n\\nDeep Learning is a subfield of Machine Learning that focuses on building neural networks with multiple layers of interconnected nodes. These networks are designed to learn from large\",\n        \"args\": [\n            \"--max-input-tokens\",\n            \"512\",\n            \"--max-total-tokens\",\n            \"1024\",\n            \"--max-batch-size\",\n            \"4\",\n        ],\n    },\n    \"openai-community/gpt2\": {\n        \"model_id\": \"openai-community/gpt2\",\n        \"input\": \"What is Deep Learning?\",\n        \"expected_greedy_output\": \"\\n\\nDeep learning is a subset of machine learning that is based on artificial neural networks. It is a type of machine learning that is based on the idea of\",\n        \"expected_batch_output\": \"\\n\\nDeep learning is a subset of machine learning that is based on artificial neural networks. 
It is a type of machine learning that is based on the idea of\",\n        \"args\": [\n            \"--max-input-tokens\",\n            \"512\",\n            \"--max-total-tokens\",\n            \"1024\",\n            \"--max-batch-size\",\n            \"4\",\n        ],\n    },\n    \"EleutherAI/gpt-j-6b\": {\n        \"model_id\": \"EleutherAI/gpt-j-6b\",\n        \"input\": \"What is Deep Learning?\",\n        \"expected_greedy_output\": \"\\n\\nDeep learning is a subset of machine learning that is based on the idea of neural networks. Neural networks are a type of artificial intelligence that is inspired by\",\n        \"expected_batch_output\": \"\\n\\nDeep learning is a subset of machine learning that is based on the idea of neural networks. Neural networks are a type of artificial intelligence that is inspired by\",\n        \"args\": [\n            \"--max-input-tokens\",\n            \"512\",\n            \"--max-total-tokens\",\n            \"1024\",\n            \"--max-batch-size\",\n            \"4\",\n        ],\n    },\n}\n\n\ndef pytest_generate_tests(metafunc):\n    if \"test_config\" in metafunc.fixturenames:\n        if metafunc.config.getoption(\"--gaudi-all-models\"):\n            models = list(TEST_CONFIGS.keys())\n        else:\n            models = [\n                name\n                for name, config in TEST_CONFIGS.items()\n                if config.get(\"run_by_default\", False)\n            ]\n        print(f\"Testing {len(models)} models\")\n        metafunc.parametrize(\"test_config\", models, indirect=True)\n\n\n@pytest.fixture(scope=\"module\")\ndef test_config(request: SubRequest) -> Dict[str, Any]:\n    \"\"\"Fixture that provides model configurations for testing.\"\"\"\n    model_name = request.param\n    test_config = TEST_CONFIGS[model_name]\n    test_config[\"test_name\"] = model_name\n    return test_config\n\n\n@pytest.fixture(scope=\"module\")\ndef model_id(test_config: Dict[str, Any]) -> Generator[str, None, None]:\n    
yield test_config[\"model_id\"]\n\n\n@pytest.fixture(scope=\"module\")\ndef test_name(test_config: Dict[str, Any]) -> Generator[str, None, None]:\n    yield test_config[\"test_name\"]\n\n\n@pytest.fixture(scope=\"module\")\ndef expected_outputs(test_config: Dict[str, Any]) -> Dict[str, str]:\n    return {\n        \"greedy\": test_config[\"expected_greedy_output\"],\n        \"batch\": test_config[\"expected_batch_output\"],\n    }\n\n\n@pytest.fixture(scope=\"module\")\ndef input(test_config: Dict[str, Any]) -> str:\n    return test_config[\"input\"]\n\n\n@pytest.fixture(scope=\"module\")\ndef tgi_service(\n    gaudi_launcher, model_id: str, test_name: str, test_config: Dict[str, Any]\n):\n    with gaudi_launcher(\n        model_id,\n        test_name,\n        tgi_args=test_config.get(\"args\", []),\n        env_config=test_config.get(\"env_config\", {}),\n    ) as tgi_service:\n        yield tgi_service\n\n\n@pytest.fixture(scope=\"module\")\nasync def tgi_client(tgi_service) -> AsyncInferenceClient:\n    await tgi_service.health(1000)\n    return tgi_service.client\n\n\n@pytest.mark.asyncio\n@pytest.mark.all_models\nasync def test_model_single_request(\n    tgi_client: AsyncInferenceClient, expected_outputs: Dict[str, str], input: str\n):\n    # Bounded greedy decoding without input\n    response = await tgi_client.text_generation(\n        input,\n        max_new_tokens=32,\n        details=True,\n        decoder_input_details=True,\n    )\n    assert response.details.generated_tokens == 32\n    assert response.generated_text == expected_outputs[\"greedy\"]\n\n\n@pytest.mark.asyncio\n@pytest.mark.all_models\nasync def test_model_multiple_requests(\n    tgi_client: AsyncInferenceClient,\n    gaudi_generate_load,\n    expected_outputs: Dict[str, str],\n    input: str,\n):\n    num_requests = 4\n    responses = await gaudi_generate_load(\n        tgi_client,\n        input,\n        max_new_tokens=32,\n        n=num_requests,\n    )\n\n    assert len(responses) 
== 4\n    expected = expected_outputs[\"batch\"]\n    for r in responses:\n        assert r.details.generated_tokens == 32\n        assert r.generated_text == expected\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test.py",
    "content": "import os\nimport json\n\n\nfor root, dirs, files in os.walk(\".\"):\n    for filename in files:\n        if filename.endswith(\".json\"):\n            with open(os.path.join(root, filename), \"r\") as f:\n                data = json.load(f)\n\n            print(os.path.join(root, filename))\n            try:\n                if filename.endswith(\"_load.json\"):\n                    for i in range(len(data)):\n                        data[i][\"details\"][\"prefill\"] = []\n                else:\n                    data[\"details\"][\"prefill\"] = []\n            except Exception:\n                pass\n\n            with open(os.path.join(root, filename), \"w\") as f:\n                json.dump(data, f, indent=2, ensure_ascii=False)\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [\n      {\n        \"id\": 17934,\n        \"logprob\": null,\n        \"text\": \"Pour\"\n      },\n      {\n        \"id\": 49833,\n        \"logprob\": -10.5703125,\n        \"text\": \" dég\"\n      },\n      {\n        \"id\": 21543,\n        \"logprob\": -0.14746094,\n        \"text\": \"uster\"\n      },\n      {\n        \"id\": 447,\n        \"logprob\": -1.9277344,\n        \"text\": \" un\"\n      },\n      {\n        \"id\": 46341,\n        \"logprob\": -15.421875,\n        \"text\": \" ort\"\n      },\n      {\n        \"id\": 35567,\n        \"logprob\": -7.5820312,\n        \"text\": \"olan\"\n      },\n      {\n        \"id\": 15,\n        \"logprob\": -1.4013672,\n        \"text\": \",\"\n      },\n      {\n        \"id\": 1669,\n        \"logprob\": -1.5595703,\n        \"text\": \" il\"\n      },\n      {\n        \"id\": 11580,\n        \"logprob\": -0.9428711,\n        \"text\": \" faut\"\n      },\n      {\n        \"id\": 3913,\n        \"logprob\": -3.703125,\n        \"text\": \" tout\"\n      },\n      {\n        \"id\": 39261,\n        \"logprob\": -1.7763672,\n        \"text\": \" d'abord\"\n      }\n    ],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 578,\n        \"logprob\": -1.7822266,\n        \"special\": false,\n        \"text\": \" le\"\n      },\n      {\n        \"id\": 5608,\n        \"logprob\": -2.4882812,\n        \"special\": false,\n        \"text\": \" faire\"\n      },\n      {\n        \"id\": 7735,\n        \"logprob\": -2.4199219,\n        \"special\": false,\n        \"text\": \" fond\"\n      },\n      {\n        \"id\": 289,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"re\"\n      },\n      {\n        \"id\": 693,\n        \"logprob\": -2.4628906,\n        \"special\": false,\n        \"text\": \" à\"\n      
},\n      {\n        \"id\": 366,\n        \"logprob\": -1.1308594,\n        \"special\": false,\n        \"text\": \" la\"\n      },\n      {\n        \"id\": 48844,\n        \"logprob\": -1.7900391,\n        \"special\": false,\n        \"text\": \" cass\"\n      },\n      {\n        \"id\": 1744,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"ero\"\n      },\n      {\n        \"id\": 327,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"le\"\n      },\n      {\n        \"id\": 2940,\n        \"logprob\": -1.9306641,\n        \"special\": false,\n        \"text\": \" avec\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \" le faire fondre à la casserole avec\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m_all_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [\n      {\n        \"id\": 15,\n        \"logprob\": null,\n        \"text\": \",\"\n      },\n      {\n        \"id\": 1669,\n        \"logprob\": -5.4453125,\n        \"text\": \" il\"\n      },\n      {\n        \"id\": 11580,\n        \"logprob\": -2.3378906,\n        \"text\": \" faut\"\n      },\n      {\n        \"id\": 3913,\n        \"logprob\": -4.3320312,\n        \"text\": \" tout\"\n      },\n      {\n        \"id\": 39261,\n        \"logprob\": -2.9160156,\n        \"text\": \" d'abord\"\n      }\n    ],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 408,\n        \"logprob\": -0.16687012,\n        \"special\": false,\n        \"text\": \" que\"\n      },\n      {\n        \"id\": 366,\n        \"logprob\": -1.5517578,\n        \"special\": false,\n        \"text\": \" la\"\n      },\n      {\n        \"id\": 8769,\n        \"logprob\": -0.16687012,\n        \"special\": false,\n        \"text\": \" personne\"\n      },\n      {\n        \"id\": 1479,\n        \"logprob\": -2.1035156,\n        \"special\": false,\n        \"text\": \" qui\"\n      },\n      {\n        \"id\": 143926,\n        \"logprob\": -2.8671875,\n        \"special\": false,\n        \"text\": \" réalise\"\n      },\n      {\n        \"id\": 578,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" le\"\n      },\n      {\n        \"id\": 8138,\n        \"logprob\": -0.66748047,\n        \"special\": false,\n        \"text\": \" projet\"\n      },\n      {\n        \"id\": 795,\n        \"logprob\": -1.6279297,\n        \"special\": false,\n        \"text\": \" ne\"\n      },\n      {\n        \"id\": 9802,\n        \"logprob\": -0.47875977,\n        \"special\": false,\n        \"text\": \" soit\"\n      },\n      {\n        \"id\": 1230,\n        \"logprob\": 0.0,\n        \"special\": 
false,\n        \"text\": \" pas\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"Pour déguster un ortolan, il faut tout d'abord que la personne qui réalise le projet ne soit pas\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_bloom_560m/test_bloom_560m_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [\n        {\n          \"id\": 17934,\n          \"logprob\": null,\n          \"text\": \"Pour\"\n        },\n        {\n          \"id\": 49833,\n          \"logprob\": -10.5625,\n          \"text\": \" dég\"\n        },\n        {\n          \"id\": 21543,\n          \"logprob\": -0.14770508,\n          \"text\": \"uster\"\n        },\n        {\n          \"id\": 447,\n          \"logprob\": -1.9287109,\n          \"text\": \" un\"\n        },\n        {\n          \"id\": 46341,\n          \"logprob\": -15.4609375,\n          \"text\": \" ort\"\n        },\n        {\n          \"id\": 35567,\n          \"logprob\": -7.5585938,\n          \"text\": \"olan\"\n        },\n        {\n          \"id\": 15,\n          \"logprob\": -1.4003906,\n          \"text\": \",\"\n        },\n        {\n          \"id\": 1669,\n          \"logprob\": -1.5673828,\n          \"text\": \" il\"\n        },\n        {\n          \"id\": 11580,\n          \"logprob\": -0.94628906,\n          \"text\": \" faut\"\n        },\n        {\n          \"id\": 3913,\n          \"logprob\": -3.703125,\n          \"text\": \" tout\"\n        },\n        {\n          \"id\": 39261,\n          \"logprob\": -1.5732422,\n          \"text\": \" d'abord\"\n        }\n      ],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 578,\n          \"logprob\": -1.7646484,\n          \"special\": false,\n          \"text\": \" le\"\n        },\n        {\n          \"id\": 5608,\n          \"logprob\": -2.6113281,\n          \"special\": false,\n          \"text\": \" faire\"\n        },\n        {\n          \"id\": 1767,\n          \"logprob\": -1.5263672,\n          \"special\": false,\n          \"text\": \" cu\"\n        },\n        {\n          \"id\": 1273,\n          \"logprob\": -0.00010049343,\n          
\"special\": false,\n          \"text\": \"ire\"\n        },\n        {\n          \"id\": 1486,\n          \"logprob\": -1.4707031,\n          \"special\": false,\n          \"text\": \" dans\"\n        },\n        {\n          \"id\": 283,\n          \"logprob\": -1.2119141,\n          \"special\": false,\n          \"text\": \" de\"\n        },\n        {\n          \"id\": 40410,\n          \"logprob\": -0.11883545,\n          \"special\": false,\n          \"text\": \" l'eau\"\n        },\n        {\n          \"id\": 20226,\n          \"logprob\": -0.40844727,\n          \"special\": false,\n          \"text\": \" bou\"\n        },\n        {\n          \"id\": 172483,\n          \"logprob\": -0.0037841797,\n          \"special\": false,\n          \"text\": \"illante\"\n        },\n        {\n          \"id\": 2805,\n          \"logprob\": -1.0195312,\n          \"special\": false,\n          \"text\": \" sal\"\n        }\n      ]\n    },\n    \"generated_text\": \" le faire cuire dans de l'eau bouillante sal\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [\n        {\n          \"id\": 17934,\n          \"logprob\": null,\n          \"text\": \"Pour\"\n        },\n        {\n          \"id\": 49833,\n          \"logprob\": -10.53125,\n          \"text\": \" dég\"\n        },\n        {\n          \"id\": 21543,\n          \"logprob\": -0.14770508,\n          \"text\": \"uster\"\n        },\n        {\n          \"id\": 447,\n          \"logprob\": -1.9287109,\n          \"text\": \" un\"\n        },\n        {\n          \"id\": 46341,\n          \"logprob\": -15.4140625,\n          \"text\": \" ort\"\n        },\n        {\n          \"id\": 35567,\n          \"logprob\": -7.5234375,\n          \"text\": \"olan\"\n        },\n        {\n          \"id\": 15,\n          \"logprob\": -1.3613281,\n          \"text\": \",\"\n        },\n        {\n  
        \"id\": 1669,\n          \"logprob\": -1.5458984,\n          \"text\": \" il\"\n        },\n        {\n          \"id\": 11580,\n          \"logprob\": -0.94189453,\n          \"text\": \" faut\"\n        },\n        {\n          \"id\": 3913,\n          \"logprob\": -3.7011719,\n          \"text\": \" tout\"\n        },\n        {\n          \"id\": 39261,\n          \"logprob\": -1.5732422,\n          \"text\": \" d'abord\"\n        }\n      ],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 578,\n          \"logprob\": -1.7548828,\n          \"special\": false,\n          \"text\": \" le\"\n        },\n        {\n          \"id\": 5608,\n          \"logprob\": -2.578125,\n          \"special\": false,\n          \"text\": \" faire\"\n        },\n        {\n          \"id\": 1767,\n          \"logprob\": -1.5117188,\n          \"special\": false,\n          \"text\": \" cu\"\n        },\n        {\n          \"id\": 1273,\n          \"logprob\": -0.00010049343,\n          \"special\": false,\n          \"text\": \"ire\"\n        },\n        {\n          \"id\": 1486,\n          \"logprob\": -1.4707031,\n          \"special\": false,\n          \"text\": \" dans\"\n        },\n        {\n          \"id\": 283,\n          \"logprob\": -1.1982422,\n          \"special\": false,\n          \"text\": \" de\"\n        },\n        {\n          \"id\": 40410,\n          \"logprob\": -0.11004639,\n          \"special\": false,\n          \"text\": \" l'eau\"\n        },\n        {\n          \"id\": 20226,\n          \"logprob\": -0.4506836,\n          \"special\": false,\n          \"text\": \" bou\"\n        },\n        {\n          \"id\": 172483,\n          \"logprob\": -0.003047943,\n          \"special\": false,\n          \"text\": \"illante\"\n        },\n        {\n          \"id\": 2805,\n          \"logprob\": -1.0185547,\n          \"special\": false,\n          \"text\": \" sal\"\n        }\n      ]\n    },\n    
\"generated_text\": \" le faire cuire dans de l'eau bouillante sal\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [\n        {\n          \"id\": 17934,\n          \"logprob\": null,\n          \"text\": \"Pour\"\n        },\n        {\n          \"id\": 49833,\n          \"logprob\": -10.53125,\n          \"text\": \" dég\"\n        },\n        {\n          \"id\": 21543,\n          \"logprob\": -0.14770508,\n          \"text\": \"uster\"\n        },\n        {\n          \"id\": 447,\n          \"logprob\": -1.9287109,\n          \"text\": \" un\"\n        },\n        {\n          \"id\": 46341,\n          \"logprob\": -15.4140625,\n          \"text\": \" ort\"\n        },\n        {\n          \"id\": 35567,\n          \"logprob\": -7.5234375,\n          \"text\": \"olan\"\n        },\n        {\n          \"id\": 15,\n          \"logprob\": -1.3613281,\n          \"text\": \",\"\n        },\n        {\n          \"id\": 1669,\n          \"logprob\": -1.5458984,\n          \"text\": \" il\"\n        },\n        {\n          \"id\": 11580,\n          \"logprob\": -0.94189453,\n          \"text\": \" faut\"\n        },\n        {\n          \"id\": 3913,\n          \"logprob\": -3.7011719,\n          \"text\": \" tout\"\n        },\n        {\n          \"id\": 39261,\n          \"logprob\": -1.5732422,\n          \"text\": \" d'abord\"\n        }\n      ],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 578,\n          \"logprob\": -1.7548828,\n          \"special\": false,\n          \"text\": \" le\"\n        },\n        {\n          \"id\": 5608,\n          \"logprob\": -2.578125,\n          \"special\": false,\n          \"text\": \" faire\"\n        },\n        {\n          \"id\": 1767,\n          \"logprob\": -1.5117188,\n          \"special\": false,\n          \"text\": \" cu\"\n        },\n        {\n          
\"id\": 1273,\n          \"logprob\": -0.00010049343,\n          \"special\": false,\n          \"text\": \"ire\"\n        },\n        {\n          \"id\": 1486,\n          \"logprob\": -1.4707031,\n          \"special\": false,\n          \"text\": \" dans\"\n        },\n        {\n          \"id\": 283,\n          \"logprob\": -1.1982422,\n          \"special\": false,\n          \"text\": \" de\"\n        },\n        {\n          \"id\": 40410,\n          \"logprob\": -0.11004639,\n          \"special\": false,\n          \"text\": \" l'eau\"\n        },\n        {\n          \"id\": 20226,\n          \"logprob\": -0.4506836,\n          \"special\": false,\n          \"text\": \" bou\"\n        },\n        {\n          \"id\": 172483,\n          \"logprob\": -0.003047943,\n          \"special\": false,\n          \"text\": \"illante\"\n        },\n        {\n          \"id\": 2805,\n          \"logprob\": -1.0185547,\n          \"special\": false,\n          \"text\": \" sal\"\n        }\n      ]\n    },\n    \"generated_text\": \" le faire cuire dans de l'eau bouillante sal\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [\n        {\n          \"id\": 17934,\n          \"logprob\": null,\n          \"text\": \"Pour\"\n        },\n        {\n          \"id\": 49833,\n          \"logprob\": -10.53125,\n          \"text\": \" dég\"\n        },\n        {\n          \"id\": 21543,\n          \"logprob\": -0.14770508,\n          \"text\": \"uster\"\n        },\n        {\n          \"id\": 447,\n          \"logprob\": -1.9287109,\n          \"text\": \" un\"\n        },\n        {\n          \"id\": 46341,\n          \"logprob\": -15.4140625,\n          \"text\": \" ort\"\n        },\n        {\n          \"id\": 35567,\n          \"logprob\": -7.5234375,\n          \"text\": \"olan\"\n        },\n        {\n          \"id\": 15,\n          \"logprob\": 
-1.3613281,\n          \"text\": \",\"\n        },\n        {\n          \"id\": 1669,\n          \"logprob\": -1.5458984,\n          \"text\": \" il\"\n        },\n        {\n          \"id\": 11580,\n          \"logprob\": -0.94189453,\n          \"text\": \" faut\"\n        },\n        {\n          \"id\": 3913,\n          \"logprob\": -3.7011719,\n          \"text\": \" tout\"\n        },\n        {\n          \"id\": 39261,\n          \"logprob\": -1.5732422,\n          \"text\": \" d'abord\"\n        }\n      ],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 578,\n          \"logprob\": -1.7548828,\n          \"special\": false,\n          \"text\": \" le\"\n        },\n        {\n          \"id\": 5608,\n          \"logprob\": -2.578125,\n          \"special\": false,\n          \"text\": \" faire\"\n        },\n        {\n          \"id\": 1767,\n          \"logprob\": -1.5117188,\n          \"special\": false,\n          \"text\": \" cu\"\n        },\n        {\n          \"id\": 1273,\n          \"logprob\": -0.00010049343,\n          \"special\": false,\n          \"text\": \"ire\"\n        },\n        {\n          \"id\": 1486,\n          \"logprob\": -1.4707031,\n          \"special\": false,\n          \"text\": \" dans\"\n        },\n        {\n          \"id\": 283,\n          \"logprob\": -1.1982422,\n          \"special\": false,\n          \"text\": \" de\"\n        },\n        {\n          \"id\": 40410,\n          \"logprob\": -0.11004639,\n          \"special\": false,\n          \"text\": \" l'eau\"\n        },\n        {\n          \"id\": 20226,\n          \"logprob\": -0.4506836,\n          \"special\": false,\n          \"text\": \" bou\"\n        },\n        {\n          \"id\": 172483,\n          \"logprob\": -0.003047943,\n          \"special\": false,\n          \"text\": \"illante\"\n        },\n        {\n          \"id\": 2805,\n          \"logprob\": -1.0185547,\n          \"special\": false,\n          
\"text\": \" sal\"\n        }\n      ]\n    },\n    \"generated_text\": \" le faire cuire dans de l'eau bouillante sal\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_bloom_560m_sharded/test_bloom_560m_sharded.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [\n      {\n        \"id\": 17934,\n        \"logprob\": null,\n        \"text\": \"Pour\"\n      },\n      {\n        \"id\": 49833,\n        \"logprob\": -10.546875,\n        \"text\": \" dég\"\n      },\n      {\n        \"id\": 21543,\n        \"logprob\": -0.14819336,\n        \"text\": \"uster\"\n      },\n      {\n        \"id\": 447,\n        \"logprob\": -1.9257812,\n        \"text\": \" un\"\n      },\n      {\n        \"id\": 46341,\n        \"logprob\": -15.4296875,\n        \"text\": \" ort\"\n      },\n      {\n        \"id\": 35567,\n        \"logprob\": -7.5625,\n        \"text\": \"olan\"\n      },\n      {\n        \"id\": 15,\n        \"logprob\": -1.4199219,\n        \"text\": \",\"\n      },\n      {\n        \"id\": 1669,\n        \"logprob\": -1.5634766,\n        \"text\": \" il\"\n      },\n      {\n        \"id\": 11580,\n        \"logprob\": -0.9458008,\n        \"text\": \" faut\"\n      },\n      {\n        \"id\": 3913,\n        \"logprob\": -3.6816406,\n        \"text\": \" tout\"\n      },\n      {\n        \"id\": 39261,\n        \"logprob\": -1.7753906,\n        \"text\": \" d'abord\"\n      }\n    ],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 578,\n        \"logprob\": -1.828125,\n        \"special\": false,\n        \"text\": \" le\"\n      },\n      {\n        \"id\": 5608,\n        \"logprob\": -2.5546875,\n        \"special\": false,\n        \"text\": \" faire\"\n      },\n      {\n        \"id\": 7735,\n        \"logprob\": -2.4277344,\n        \"special\": false,\n        \"text\": \" fond\"\n      },\n      {\n        \"id\": 289,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"re\"\n      },\n      {\n        \"id\": 693,\n        \"logprob\": -2.4472656,\n        \"special\": false,\n        \"text\": \" à\"\n      },\n   
   {\n        \"id\": 366,\n        \"logprob\": -1.1494141,\n        \"special\": false,\n        \"text\": \" la\"\n      },\n      {\n        \"id\": 48844,\n        \"logprob\": -1.7939453,\n        \"special\": false,\n        \"text\": \" cass\"\n      },\n      {\n        \"id\": 1744,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"ero\"\n      },\n      {\n        \"id\": 327,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"le\"\n      },\n      {\n        \"id\": 2940,\n        \"logprob\": -1.9013672,\n        \"special\": false,\n        \"text\": \" avec\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \" le faire fondre à la casserole avec\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_bloom_560m_sharded/test_bloom_560m_sharded_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [\n        {\n          \"id\": 17934,\n          \"logprob\": null,\n          \"text\": \"Pour\"\n        },\n        {\n          \"id\": 49833,\n          \"logprob\": -10.5390625,\n          \"text\": \" dég\"\n        },\n        {\n          \"id\": 21543,\n          \"logprob\": -0.14758301,\n          \"text\": \"uster\"\n        },\n        {\n          \"id\": 447,\n          \"logprob\": -1.9296875,\n          \"text\": \" un\"\n        },\n        {\n          \"id\": 46341,\n          \"logprob\": -15.4453125,\n          \"text\": \" ort\"\n        },\n        {\n          \"id\": 35567,\n          \"logprob\": -7.59375,\n          \"text\": \"olan\"\n        },\n        {\n          \"id\": 15,\n          \"logprob\": -1.3994141,\n          \"text\": \",\"\n        },\n        {\n          \"id\": 1669,\n          \"logprob\": -1.578125,\n          \"text\": \" il\"\n        },\n        {\n          \"id\": 11580,\n          \"logprob\": -0.9453125,\n          \"text\": \" faut\"\n        },\n        {\n          \"id\": 3913,\n          \"logprob\": -3.7011719,\n          \"text\": \" tout\"\n        },\n        {\n          \"id\": 39261,\n          \"logprob\": -1.5732422,\n          \"text\": \" d'abord\"\n        }\n      ],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 578,\n          \"logprob\": -1.7529297,\n          \"special\": false,\n          \"text\": \" le\"\n        },\n        {\n          \"id\": 5608,\n          \"logprob\": -2.6054688,\n          \"special\": false,\n          \"text\": \" faire\"\n        },\n        {\n          \"id\": 1767,\n          \"logprob\": -1.5283203,\n          \"special\": false,\n          \"text\": \" cu\"\n        },\n        {\n          \"id\": 1273,\n          \"logprob\": -0.00010049343,\n          
\"special\": false,\n          \"text\": \"ire\"\n        },\n        {\n          \"id\": 1486,\n          \"logprob\": -1.4716797,\n          \"special\": false,\n          \"text\": \" dans\"\n        },\n        {\n          \"id\": 283,\n          \"logprob\": -1.1982422,\n          \"special\": false,\n          \"text\": \" de\"\n        },\n        {\n          \"id\": 40410,\n          \"logprob\": -0.11853027,\n          \"special\": false,\n          \"text\": \" l'eau\"\n        },\n        {\n          \"id\": 20226,\n          \"logprob\": -0.41210938,\n          \"special\": false,\n          \"text\": \" bou\"\n        },\n        {\n          \"id\": 172483,\n          \"logprob\": -0.0037765503,\n          \"special\": false,\n          \"text\": \"illante\"\n        },\n        {\n          \"id\": 2805,\n          \"logprob\": -1.0166016,\n          \"special\": false,\n          \"text\": \" sal\"\n        }\n      ]\n    },\n    \"generated_text\": \" le faire cuire dans de l'eau bouillante sal\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [\n        {\n          \"id\": 17934,\n          \"logprob\": null,\n          \"text\": \"Pour\"\n        },\n        {\n          \"id\": 49833,\n          \"logprob\": -10.515625,\n          \"text\": \" dég\"\n        },\n        {\n          \"id\": 21543,\n          \"logprob\": -0.1484375,\n          \"text\": \"uster\"\n        },\n        {\n          \"id\": 447,\n          \"logprob\": -1.9287109,\n          \"text\": \" un\"\n        },\n        {\n          \"id\": 46341,\n          \"logprob\": -15.34375,\n          \"text\": \" ort\"\n        },\n        {\n          \"id\": 35567,\n          \"logprob\": -7.515625,\n          \"text\": \"olan\"\n        },\n        {\n          \"id\": 15,\n          \"logprob\": -1.4199219,\n          \"text\": \",\"\n        },\n        {\n     
     \"id\": 1669,\n          \"logprob\": -1.5664062,\n          \"text\": \" il\"\n        },\n        {\n          \"id\": 11580,\n          \"logprob\": -0.94091797,\n          \"text\": \" faut\"\n        },\n        {\n          \"id\": 3913,\n          \"logprob\": -3.6660156,\n          \"text\": \" tout\"\n        },\n        {\n          \"id\": 39261,\n          \"logprob\": -1.7753906,\n          \"text\": \" d'abord\"\n        }\n      ],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 578,\n          \"logprob\": -1.7626953,\n          \"special\": false,\n          \"text\": \" le\"\n        },\n        {\n          \"id\": 5608,\n          \"logprob\": -2.5820312,\n          \"special\": false,\n          \"text\": \" faire\"\n        },\n        {\n          \"id\": 1767,\n          \"logprob\": -1.5097656,\n          \"special\": false,\n          \"text\": \" cu\"\n        },\n        {\n          \"id\": 1273,\n          \"logprob\": -9.393692e-05,\n          \"special\": false,\n          \"text\": \"ire\"\n        },\n        {\n          \"id\": 1486,\n          \"logprob\": -1.5175781,\n          \"special\": false,\n          \"text\": \" dans\"\n        },\n        {\n          \"id\": 283,\n          \"logprob\": -1.1982422,\n          \"special\": false,\n          \"text\": \" de\"\n        },\n        {\n          \"id\": 40410,\n          \"logprob\": -0.11883545,\n          \"special\": false,\n          \"text\": \" l'eau\"\n        },\n        {\n          \"id\": 20226,\n          \"logprob\": -0.4909668,\n          \"special\": false,\n          \"text\": \" bou\"\n        },\n        {\n          \"id\": 172483,\n          \"logprob\": -0.003047943,\n          \"special\": false,\n          \"text\": \"illante\"\n        },\n        {\n          \"id\": 2805,\n          \"logprob\": -1.0185547,\n          \"special\": false,\n          \"text\": \" sal\"\n        }\n      ]\n    },\n    
\"generated_text\": \" le faire cuire dans de l'eau bouillante sal\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [\n        {\n          \"id\": 17934,\n          \"logprob\": null,\n          \"text\": \"Pour\"\n        },\n        {\n          \"id\": 49833,\n          \"logprob\": -10.515625,\n          \"text\": \" dég\"\n        },\n        {\n          \"id\": 21543,\n          \"logprob\": -0.1484375,\n          \"text\": \"uster\"\n        },\n        {\n          \"id\": 447,\n          \"logprob\": -1.9287109,\n          \"text\": \" un\"\n        },\n        {\n          \"id\": 46341,\n          \"logprob\": -15.34375,\n          \"text\": \" ort\"\n        },\n        {\n          \"id\": 35567,\n          \"logprob\": -7.515625,\n          \"text\": \"olan\"\n        },\n        {\n          \"id\": 15,\n          \"logprob\": -1.4199219,\n          \"text\": \",\"\n        },\n        {\n          \"id\": 1669,\n          \"logprob\": -1.5664062,\n          \"text\": \" il\"\n        },\n        {\n          \"id\": 11580,\n          \"logprob\": -0.94091797,\n          \"text\": \" faut\"\n        },\n        {\n          \"id\": 3913,\n          \"logprob\": -3.6660156,\n          \"text\": \" tout\"\n        },\n        {\n          \"id\": 39261,\n          \"logprob\": -1.7753906,\n          \"text\": \" d'abord\"\n        }\n      ],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 578,\n          \"logprob\": -1.7626953,\n          \"special\": false,\n          \"text\": \" le\"\n        },\n        {\n          \"id\": 5608,\n          \"logprob\": -2.5820312,\n          \"special\": false,\n          \"text\": \" faire\"\n        },\n        {\n          \"id\": 1767,\n          \"logprob\": -1.5097656,\n          \"special\": false,\n          \"text\": \" cu\"\n        },\n        {\n          \"id\": 
1273,\n          \"logprob\": -9.393692e-05,\n          \"special\": false,\n          \"text\": \"ire\"\n        },\n        {\n          \"id\": 1486,\n          \"logprob\": -1.5175781,\n          \"special\": false,\n          \"text\": \" dans\"\n        },\n        {\n          \"id\": 283,\n          \"logprob\": -1.1982422,\n          \"special\": false,\n          \"text\": \" de\"\n        },\n        {\n          \"id\": 40410,\n          \"logprob\": -0.11883545,\n          \"special\": false,\n          \"text\": \" l'eau\"\n        },\n        {\n          \"id\": 20226,\n          \"logprob\": -0.4909668,\n          \"special\": false,\n          \"text\": \" bou\"\n        },\n        {\n          \"id\": 172483,\n          \"logprob\": -0.003047943,\n          \"special\": false,\n          \"text\": \"illante\"\n        },\n        {\n          \"id\": 2805,\n          \"logprob\": -1.0185547,\n          \"special\": false,\n          \"text\": \" sal\"\n        }\n      ]\n    },\n    \"generated_text\": \" le faire cuire dans de l'eau bouillante sal\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [\n        {\n          \"id\": 17934,\n          \"logprob\": null,\n          \"text\": \"Pour\"\n        },\n        {\n          \"id\": 49833,\n          \"logprob\": -10.515625,\n          \"text\": \" dég\"\n        },\n        {\n          \"id\": 21543,\n          \"logprob\": -0.1484375,\n          \"text\": \"uster\"\n        },\n        {\n          \"id\": 447,\n          \"logprob\": -1.9287109,\n          \"text\": \" un\"\n        },\n        {\n          \"id\": 46341,\n          \"logprob\": -15.34375,\n          \"text\": \" ort\"\n        },\n        {\n          \"id\": 35567,\n          \"logprob\": -7.515625,\n          \"text\": \"olan\"\n        },\n        {\n          \"id\": 15,\n          \"logprob\": -1.4199219,\n 
         \"text\": \",\"\n        },\n        {\n          \"id\": 1669,\n          \"logprob\": -1.5664062,\n          \"text\": \" il\"\n        },\n        {\n          \"id\": 11580,\n          \"logprob\": -0.94091797,\n          \"text\": \" faut\"\n        },\n        {\n          \"id\": 3913,\n          \"logprob\": -3.6660156,\n          \"text\": \" tout\"\n        },\n        {\n          \"id\": 39261,\n          \"logprob\": -1.7753906,\n          \"text\": \" d'abord\"\n        }\n      ],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 578,\n          \"logprob\": -1.7626953,\n          \"special\": false,\n          \"text\": \" le\"\n        },\n        {\n          \"id\": 5608,\n          \"logprob\": -2.5820312,\n          \"special\": false,\n          \"text\": \" faire\"\n        },\n        {\n          \"id\": 1767,\n          \"logprob\": -1.5097656,\n          \"special\": false,\n          \"text\": \" cu\"\n        },\n        {\n          \"id\": 1273,\n          \"logprob\": -9.393692e-05,\n          \"special\": false,\n          \"text\": \"ire\"\n        },\n        {\n          \"id\": 1486,\n          \"logprob\": -1.5175781,\n          \"special\": false,\n          \"text\": \" dans\"\n        },\n        {\n          \"id\": 283,\n          \"logprob\": -1.1982422,\n          \"special\": false,\n          \"text\": \" de\"\n        },\n        {\n          \"id\": 40410,\n          \"logprob\": -0.11883545,\n          \"special\": false,\n          \"text\": \" l'eau\"\n        },\n        {\n          \"id\": 20226,\n          \"logprob\": -0.4909668,\n          \"special\": false,\n          \"text\": \" bou\"\n        },\n        {\n          \"id\": 172483,\n          \"logprob\": -0.003047943,\n          \"special\": false,\n          \"text\": \"illante\"\n        },\n        {\n          \"id\": 2805,\n          \"logprob\": -1.0185547,\n          \"special\": false,\n          \"text\": \" 
sal\"\n        }\n      ]\n    },\n    \"generated_text\": \" le faire cuire dans de l'eau bouillante sal\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_chat_llama/test_flash_llama_simple.json",
    "content": "{\n  \"choices\": [\n    {\n      \"finish_reason\": \"length\",\n      \"index\": 0,\n      \"logprobs\": null,\n      \"message\": {\n        \"content\": \"As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to appreciate nature.\\n\\nIn terms of temperature, the warmest times of the year are from June to August, when average high temperatures typically range from around 73°F or 23°C\",\n        \"name\": null,\n        \"role\": \"assistant\",\n        \"tool_calls\": null\n      },\n      \"usage\": null\n    }\n  ],\n  \"created\": 1724792495,\n  \"id\": \"\",\n  \"model\": \"TinyLlama/TinyLlama-1.1B-Chat-v1.0\",\n  \"object\": \"chat.completion\",\n  \"system_fingerprint\": \"2.2.1-dev0-native\",\n  \"usage\": {\n    \"completion_tokens\": 100,\n    \"prompt_tokens\": 61,\n    \"total_tokens\": 161\n  }\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_completion_prompts/test_chat_hfhub_nousage.json",
    "content": "[\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"OK\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741265520,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"!\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741265520,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": \"stop\",\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741265520,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_completion_prompts/test_chat_hfhub_usage.json",
    "content": "[\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"OK\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741266005,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"!\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741266005,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": \"stop\",\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741266005,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [],\n    \"created\": 1741266005,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 3,\n      \"prompt_tokens\": 39,\n      \"total_tokens\": 42\n    }\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_completion_prompts/test_chat_openai_nousage.json",
    "content": "[\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"OK\",\n          \"function_call\": null,\n          \"refusal\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741265134,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"service_tier\": null,\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"!\",\n          \"function_call\": null,\n          \"refusal\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741265134,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"service_tier\": null,\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"\",\n          \"function_call\": null,\n          \"refusal\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": \"stop\",\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741265134,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"service_tier\": null,\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_completion_prompts/test_chat_openai_usage.json",
    "content": "[\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"OK\",\n          \"function_call\": null,\n          \"refusal\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741265133,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"service_tier\": null,\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"!\",\n          \"function_call\": null,\n          \"refusal\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741265133,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"service_tier\": null,\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"\",\n          \"function_call\": null,\n          \"refusal\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": \"stop\",\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741265133,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"service_tier\": null,\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [],\n    \"created\": 1741265133,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"service_tier\": null,\n    
\"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 3,\n      \"completion_tokens_details\": null,\n      \"prompt_tokens\": 39,\n      \"prompt_tokens_details\": null,\n      \"total_tokens\": 42\n    }\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts.json",
    "content": "{\n  \"choices\": [\n    {\n      \"finish_reason\": \"length\",\n      \"index\": 1,\n      \"logprobs\": null,\n      \"text\": \" This is a question that has puzzled many people for\"\n    },\n    {\n      \"finish_reason\": \"length\",\n      \"index\": 0,\n      \"logprobs\": null,\n      \"text\": \" A Beginner’s Guide\\nDeep learning is a subset\"\n    },\n    {\n      \"finish_reason\": \"length\",\n      \"index\": 3,\n      \"logprobs\": null,\n      \"text\": \"usculas_minusculas(s):\\n    \\\"\\\"\\\"\\n\"\n    },\n    {\n      \"finish_reason\": \"length\",\n      \"index\": 2,\n      \"logprobs\": null,\n      \"text\": \" Paris\\nWhat is the capital of France?\\nThe\"\n    }\n  ],\n  \"created\": 1741264813,\n  \"id\": \"\",\n  \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n  \"object\": \"text_completion\",\n  \"system_fingerprint\": \"3.1.2-dev0-native\",\n  \"usage\": {\n    \"completion_tokens\": 40,\n    \"prompt_tokens\": 22,\n    \"total_tokens\": 62\n  }\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts_stream.json",
    "content": "[\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"text\": \" A\"\n      }\n    ],\n    \"created\": 1741340006,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 1,\n        \"logprobs\": null,\n        \"text\": \" This\"\n      }\n    ],\n    \"created\": 1741340006,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 2,\n        \"logprobs\": null,\n        \"text\": \" Paris\"\n      }\n    ],\n    \"created\": 1741340006,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 3,\n        \"logprobs\": null,\n        \"text\": \"us\"\n      }\n    ],\n    \"created\": 1741340006,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"text\": \" Beginner\"\n      }\n    ],\n    \"created\": 1741340006,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        
\"finish_reason\": \"\",\n        \"index\": 1,\n        \"logprobs\": null,\n        \"text\": \" is\"\n      }\n    ],\n    \"created\": 1741340006,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 2,\n        \"logprobs\": null,\n        \"text\": \"\\n\"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 3,\n        \"logprobs\": null,\n        \"text\": \"cul\"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"text\": \"’s\"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 1,\n        \"logprobs\": null,\n        \"text\": \" a\"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 2,\n        \"logprobs\": null,\n    
    \"text\": \"What\"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 3,\n        \"logprobs\": null,\n        \"text\": \"as\"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"text\": \" Guide\"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 1,\n        \"logprobs\": null,\n        \"text\": \" question\"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 2,\n        \"logprobs\": null,\n        \"text\": \" is\"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 3,\n        \"logprobs\": null,\n        \"text\": \"_minus\"\n      }\n    ],\n    \"created\": 
1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"text\": \"\\n\"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 1,\n        \"logprobs\": null,\n        \"text\": \" that\"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 2,\n        \"logprobs\": null,\n        \"text\": \" the\"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 3,\n        \"logprobs\": null,\n        \"text\": \"cul\"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"text\": \"Deep\"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": 
\"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 1,\n        \"logprobs\": null,\n        \"text\": \" has\"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 2,\n        \"logprobs\": null,\n        \"text\": \" capital\"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 3,\n        \"logprobs\": null,\n        \"text\": \"as\"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"text\": \" learning\"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 1,\n        \"logprobs\": null,\n        \"text\": \" puzzled\"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": 
\"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 2,\n        \"logprobs\": null,\n        \"text\": \" of\"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 3,\n        \"logprobs\": null,\n        \"text\": \"(s\"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"text\": \" is\"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 1,\n        \"logprobs\": null,\n        \"text\": \" many\"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 2,\n        \"logprobs\": null,\n        \"text\": \" France\"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n 
   \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 3,\n        \"logprobs\": null,\n        \"text\": \"):\\n\"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"text\": \" a\"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 1,\n        \"logprobs\": null,\n        \"text\": \" people\"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 2,\n        \"logprobs\": null,\n        \"text\": \"?\\n\"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"\",\n        \"index\": 3,\n        \"logprobs\": null,\n        \"text\": \"   \"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        
\"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"text\": \" subset\"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"completion_tokens_details\": null,\n      \"prompt_tokens\": 6,\n      \"prompt_tokens_details\": null,\n      \"total_tokens\": 16\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 1,\n        \"logprobs\": null,\n        \"text\": \" for\"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"completion_tokens_details\": null,\n      \"prompt_tokens\": 5,\n      \"prompt_tokens_details\": null,\n      \"total_tokens\": 15\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 2,\n        \"logprobs\": null,\n        \"text\": \"The\"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"completion_tokens_details\": null,\n      \"prompt_tokens\": 8,\n      \"prompt_tokens_details\": null,\n      \"total_tokens\": 18\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 3,\n        \"logprobs\": null,\n        \"text\": \" \\\"\\\"\\\"\\n\"\n      }\n    ],\n    \"created\": 1741340007,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"text_completion\",\n    
\"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"completion_tokens_details\": null,\n      \"prompt_tokens\": 3,\n      \"prompt_tokens_details\": null,\n      \"total_tokens\": 13\n    }\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_single_prompt.json",
    "content": "{\n  \"choices\": [\n    {\n      \"finish_reason\": \"length\",\n      \"index\": 0,\n      \"logprobs\": null,\n      \"text\": \" A Beginner’s Guide\\nDeep learning is a subset\"\n    }\n  ],\n  \"created\": 1741264812,\n  \"id\": \"\",\n  \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n  \"object\": \"text_completion\",\n  \"system_fingerprint\": \"3.1.2-dev0-native\",\n  \"usage\": {\n    \"completion_tokens\": 10,\n    \"prompt_tokens\": 6,\n    \"total_tokens\": 16\n  }\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_stream_usage.json",
    "content": "[\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"**\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741373593,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"Deep\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741373593,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" Learning\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741373593,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \":\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741373594,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": 
null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" An\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741373594,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" Overview\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741373594,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"**\\n\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741373594,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"================================\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741373594,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": 
\"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"=====\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741373594,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"\\n\\n\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741373594,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [],\n    \"created\": 1741373594,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 40,\n      \"total_tokens\": 50\n    }\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int/test_compressed_tensors_w8a8_int.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 323,\n        \"logprob\": -1.1171875,\n        \"special\": false,\n        \"text\": \" and\"\n      },\n      {\n        \"id\": 1268,\n        \"logprob\": -0.9477539,\n        \"special\": false,\n        \"text\": \" how\"\n      },\n      {\n        \"id\": 1587,\n        \"logprob\": -0.51464844,\n        \"special\": false,\n        \"text\": \" does\"\n      },\n      {\n        \"id\": 433,\n        \"logprob\": -0.043182373,\n        \"special\": false,\n        \"text\": \" it\"\n      },\n      {\n        \"id\": 1782,\n        \"logprob\": -1.0810547,\n        \"special\": false,\n        \"text\": \" differ\"\n      },\n      {\n        \"id\": 505,\n        \"logprob\": -0.005054474,\n        \"special\": false,\n        \"text\": \" from\"\n      },\n      {\n        \"id\": 8776,\n        \"logprob\": -0.47485352,\n        \"special\": false,\n        \"text\": \" traditional\"\n      },\n      {\n        \"id\": 5780,\n        \"logprob\": -0.15112305,\n        \"special\": false,\n        \"text\": \" machine\"\n      },\n      {\n        \"id\": 6975,\n        \"logprob\": -0.0011291504,\n        \"special\": false,\n        \"text\": \" learning\"\n      },\n      {\n        \"id\": 5380,\n        \"logprob\": -0.31323242,\n        \"special\": false,\n        \"text\": \"?\\n\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \" and how does it differ from traditional machine learning?\\n\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int/test_compressed_tensors_w8a8_int_all_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 5380,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"?\\n\"\n      },\n      {\n        \"id\": 34564,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"Deep\"\n      },\n      {\n        \"id\": 6975,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" learning\"\n      },\n      {\n        \"id\": 11,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \",\"\n      },\n      {\n        \"id\": 1101,\n        \"logprob\": -1.0136719,\n        \"special\": false,\n        \"text\": \" also\"\n      },\n      {\n        \"id\": 3967,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" known\"\n      },\n      {\n        \"id\": 439,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" as\"\n      },\n      {\n        \"id\": 30828,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" neural\"\n      },\n      {\n        \"id\": 4009,\n        \"logprob\": -0.21923828,\n        \"special\": false,\n        \"text\": \" network\"\n      },\n      {\n        \"id\": 477,\n        \"logprob\": -1.4824219,\n        \"special\": false,\n        \"text\": \" or\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"What is deep learning?\\nDeep learning, also known as neural network or\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int/test_compressed_tensors_w8a8_int_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 323,\n          \"logprob\": -1.1171875,\n          \"special\": false,\n          \"text\": \" and\"\n        },\n        {\n          \"id\": 1268,\n          \"logprob\": -0.9477539,\n          \"special\": false,\n          \"text\": \" how\"\n        },\n        {\n          \"id\": 1587,\n          \"logprob\": -0.51464844,\n          \"special\": false,\n          \"text\": \" does\"\n        },\n        {\n          \"id\": 433,\n          \"logprob\": -0.043182373,\n          \"special\": false,\n          \"text\": \" it\"\n        },\n        {\n          \"id\": 1782,\n          \"logprob\": -1.0810547,\n          \"special\": false,\n          \"text\": \" differ\"\n        },\n        {\n          \"id\": 505,\n          \"logprob\": -0.005054474,\n          \"special\": false,\n          \"text\": \" from\"\n        },\n        {\n          \"id\": 8776,\n          \"logprob\": -0.47485352,\n          \"special\": false,\n          \"text\": \" traditional\"\n        },\n        {\n          \"id\": 5780,\n          \"logprob\": -0.15112305,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 6975,\n          \"logprob\": -0.0011291504,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 5380,\n          \"logprob\": -0.3173828,\n          \"special\": false,\n          \"text\": \"?\\n\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" and how does it differ from traditional machine learning?\\n\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n     
 \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 323,\n          \"logprob\": -1.1171875,\n          \"special\": false,\n          \"text\": \" and\"\n        },\n        {\n          \"id\": 1268,\n          \"logprob\": -0.9477539,\n          \"special\": false,\n          \"text\": \" how\"\n        },\n        {\n          \"id\": 1587,\n          \"logprob\": -0.51464844,\n          \"special\": false,\n          \"text\": \" does\"\n        },\n        {\n          \"id\": 433,\n          \"logprob\": -0.043182373,\n          \"special\": false,\n          \"text\": \" it\"\n        },\n        {\n          \"id\": 1782,\n          \"logprob\": -1.0810547,\n          \"special\": false,\n          \"text\": \" differ\"\n        },\n        {\n          \"id\": 505,\n          \"logprob\": -0.005054474,\n          \"special\": false,\n          \"text\": \" from\"\n        },\n        {\n          \"id\": 8776,\n          \"logprob\": -0.47485352,\n          \"special\": false,\n          \"text\": \" traditional\"\n        },\n        {\n          \"id\": 5780,\n          \"logprob\": -0.15112305,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 6975,\n          \"logprob\": -0.0011291504,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 5380,\n          \"logprob\": -0.3173828,\n          \"special\": false,\n          \"text\": \"?\\n\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" and how does it differ from traditional machine learning?\\n\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 323,\n          \"logprob\": -1.1171875,\n          \"special\": false,\n          \"text\": \" and\"\n        
},\n        {\n          \"id\": 1268,\n          \"logprob\": -0.9477539,\n          \"special\": false,\n          \"text\": \" how\"\n        },\n        {\n          \"id\": 1587,\n          \"logprob\": -0.51464844,\n          \"special\": false,\n          \"text\": \" does\"\n        },\n        {\n          \"id\": 433,\n          \"logprob\": -0.043182373,\n          \"special\": false,\n          \"text\": \" it\"\n        },\n        {\n          \"id\": 1782,\n          \"logprob\": -1.0810547,\n          \"special\": false,\n          \"text\": \" differ\"\n        },\n        {\n          \"id\": 505,\n          \"logprob\": -0.005054474,\n          \"special\": false,\n          \"text\": \" from\"\n        },\n        {\n          \"id\": 8776,\n          \"logprob\": -0.47485352,\n          \"special\": false,\n          \"text\": \" traditional\"\n        },\n        {\n          \"id\": 5780,\n          \"logprob\": -0.15112305,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 6975,\n          \"logprob\": -0.0011291504,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 5380,\n          \"logprob\": -0.3173828,\n          \"special\": false,\n          \"text\": \"?\\n\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" and how does it differ from traditional machine learning?\\n\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 323,\n          \"logprob\": -1.1171875,\n          \"special\": false,\n          \"text\": \" and\"\n        },\n        {\n          \"id\": 1268,\n          \"logprob\": -0.9477539,\n          \"special\": false,\n          \"text\": \" how\"\n        },\n        {\n          \"id\": 
1587,\n          \"logprob\": -0.51464844,\n          \"special\": false,\n          \"text\": \" does\"\n        },\n        {\n          \"id\": 433,\n          \"logprob\": -0.043182373,\n          \"special\": false,\n          \"text\": \" it\"\n        },\n        {\n          \"id\": 1782,\n          \"logprob\": -1.0810547,\n          \"special\": false,\n          \"text\": \" differ\"\n        },\n        {\n          \"id\": 505,\n          \"logprob\": -0.005054474,\n          \"special\": false,\n          \"text\": \" from\"\n        },\n        {\n          \"id\": 8776,\n          \"logprob\": -0.47485352,\n          \"special\": false,\n          \"text\": \" traditional\"\n        },\n        {\n          \"id\": 5780,\n          \"logprob\": -0.15112305,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 6975,\n          \"logprob\": -0.0011291504,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 5380,\n          \"logprob\": -0.3173828,\n          \"special\": false,\n          \"text\": \"?\\n\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" and how does it differ from traditional machine learning?\\n\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"eos_token\",\n    \"generated_tokens\": 76,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 18183,\n        \"logprob\": -1.5195312,\n        \"special\": false,\n        \"text\": \" Deep\"\n      },\n      {\n        \"id\": 6832,\n        \"logprob\": -0.06817627,\n        \"special\": false,\n        \"text\": \" learning\"\n      },\n      {\n        \"id\": 374,\n        \"logprob\": -0.13122559,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 264,\n        \"logprob\": -0.13415527,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 25993,\n        \"logprob\": -0.8769531,\n        \"special\": false,\n        \"text\": \" subset\"\n      },\n      {\n        \"id\": 315,\n        \"logprob\": -0.0011396408,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 5662,\n        \"logprob\": -0.16442871,\n        \"special\": false,\n        \"text\": \" machine\"\n      },\n      {\n        \"id\": 6832,\n        \"logprob\": -0.0026416779,\n        \"special\": false,\n        \"text\": \" learning\"\n      },\n      {\n        \"id\": 429,\n        \"logprob\": -0.48754883,\n        \"special\": false,\n        \"text\": \" that\"\n      },\n      {\n        \"id\": 5711,\n        \"logprob\": -1.2294922,\n        \"special\": false,\n        \"text\": \" uses\"\n      },\n      {\n        \"id\": 29728,\n        \"logprob\": -0.66503906,\n        \"special\": false,\n        \"text\": \" neural\"\n      },\n      {\n        \"id\": 14155,\n        \"logprob\": -0.02960205,\n        \"special\": false,\n        \"text\": \" networks\"\n      },\n      {\n        \"id\": 311,\n        \"logprob\": -0.7236328,\n        \"special\": false,\n        \"text\": \" to\"\n      },\n      {\n        \"id\": 3960,\n 
       \"logprob\": -1.1914062,\n        \"special\": false,\n        \"text\": \" learn\"\n      },\n      {\n        \"id\": 504,\n        \"logprob\": -0.7089844,\n        \"special\": false,\n        \"text\": \" from\"\n      },\n      {\n        \"id\": 821,\n        \"logprob\": -0.7729492,\n        \"special\": false,\n        \"text\": \" data\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.7836914,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 1084,\n        \"logprob\": -0.9941406,\n        \"special\": false,\n        \"text\": \" It\"\n      },\n      {\n        \"id\": 374,\n        \"logprob\": -0.52441406,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 264,\n        \"logprob\": -0.9511719,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 943,\n        \"logprob\": -0.8642578,\n        \"special\": false,\n        \"text\": \" type\"\n      },\n      {\n        \"id\": 315,\n        \"logprob\": -0.00030231476,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 20443,\n        \"logprob\": -0.14416504,\n        \"special\": false,\n        \"text\": \" artificial\"\n      },\n      {\n        \"id\": 11229,\n        \"logprob\": -0.013824463,\n        \"special\": false,\n        \"text\": \" intelligence\"\n      },\n      {\n        \"id\": 429,\n        \"logprob\": -0.18762207,\n        \"special\": false,\n        \"text\": \" that\"\n      },\n      {\n        \"id\": 646,\n        \"logprob\": -1.0087891,\n        \"special\": false,\n        \"text\": \" can\"\n      },\n      {\n        \"id\": 3960,\n        \"logprob\": -0.90234375,\n        \"special\": false,\n        \"text\": \" learn\"\n      },\n      {\n        \"id\": 504,\n        \"logprob\": -0.54345703,\n        \"special\": false,\n        \"text\": \" from\"\n      },\n      {\n        
\"id\": 323,\n        \"logprob\": -1.0400391,\n        \"special\": false,\n        \"text\": \" and\"\n      },\n      {\n        \"id\": 1281,\n        \"logprob\": -0.072509766,\n        \"special\": false,\n        \"text\": \" make\"\n      },\n      {\n        \"id\": 19898,\n        \"logprob\": -0.16516113,\n        \"special\": false,\n        \"text\": \" predictions\"\n      },\n      {\n        \"id\": 389,\n        \"logprob\": -0.4416504,\n        \"special\": false,\n        \"text\": \" on\"\n      },\n      {\n        \"id\": 3460,\n        \"logprob\": -0.5385742,\n        \"special\": false,\n        \"text\": \" large\"\n      },\n      {\n        \"id\": 14713,\n        \"logprob\": -0.4387207,\n        \"special\": false,\n        \"text\": \" amounts\"\n      },\n      {\n        \"id\": 315,\n        \"logprob\": -0.00015091896,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 821,\n        \"logprob\": -0.061431885,\n        \"special\": false,\n        \"text\": \" data\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.71875,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 18183,\n        \"logprob\": -0.23632812,\n        \"special\": false,\n        \"text\": \" Deep\"\n      },\n      {\n        \"id\": 6832,\n        \"logprob\": -0.0017204285,\n        \"special\": false,\n        \"text\": \" learning\"\n      },\n      {\n        \"id\": 374,\n        \"logprob\": -1.1738281,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 1483,\n        \"logprob\": -0.61083984,\n        \"special\": false,\n        \"text\": \" used\"\n      },\n      {\n        \"id\": 304,\n        \"logprob\": -0.035003662,\n        \"special\": false,\n        \"text\": \" in\"\n      },\n      {\n        \"id\": 264,\n        \"logprob\": -0.118652344,\n        \"special\": false,\n        \"text\": \" a\"\n      
},\n      {\n        \"id\": 8045,\n        \"logprob\": -0.42016602,\n        \"special\": false,\n        \"text\": \" variety\"\n      },\n      {\n        \"id\": 315,\n        \"logprob\": -1.6212463e-05,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 8357,\n        \"logprob\": -0.1315918,\n        \"special\": false,\n        \"text\": \" applications\"\n      },\n      {\n        \"id\": 11,\n        \"logprob\": -0.12915039,\n        \"special\": false,\n        \"text\": \",\"\n      },\n      {\n        \"id\": 2670,\n        \"logprob\": -0.12463379,\n        \"special\": false,\n        \"text\": \" including\"\n      },\n      {\n        \"id\": 2168,\n        \"logprob\": -0.37402344,\n        \"special\": false,\n        \"text\": \" image\"\n      },\n      {\n        \"id\": 323,\n        \"logprob\": -0.1451416,\n        \"special\": false,\n        \"text\": \" and\"\n      },\n      {\n        \"id\": 8806,\n        \"logprob\": -0.028869629,\n        \"special\": false,\n        \"text\": \" speech\"\n      },\n      {\n        \"id\": 17843,\n        \"logprob\": -0.00024068356,\n        \"special\": false,\n        \"text\": \" recognition\"\n      },\n      {\n        \"id\": 11,\n        \"logprob\": -0.00031018257,\n        \"special\": false,\n        \"text\": \",\"\n      },\n      {\n        \"id\": 5810,\n        \"logprob\": -0.019821167,\n        \"special\": false,\n        \"text\": \" natural\"\n      },\n      {\n        \"id\": 4128,\n        \"logprob\": -0.00012528896,\n        \"special\": false,\n        \"text\": \" language\"\n      },\n      {\n        \"id\": 8692,\n        \"logprob\": -0.00089263916,\n        \"special\": false,\n        \"text\": \" processing\"\n      },\n      {\n        \"id\": 11,\n        \"logprob\": -0.00073862076,\n        \"special\": false,\n        \"text\": \",\"\n      },\n      {\n        \"id\": 323,\n        \"logprob\": -0.040161133,\n    
    \"special\": false,\n        \"text\": \" and\"\n      },\n      {\n        \"id\": 38193,\n        \"logprob\": -0.4519043,\n        \"special\": false,\n        \"text\": \" autonomous\"\n      },\n      {\n        \"id\": 11474,\n        \"logprob\": -0.39941406,\n        \"special\": false,\n        \"text\": \" vehicles\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.21166992,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 1084,\n        \"logprob\": -0.9082031,\n        \"special\": false,\n        \"text\": \" It\"\n      },\n      {\n        \"id\": 374,\n        \"logprob\": -0.44213867,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 264,\n        \"logprob\": -1.2177734,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 18512,\n        \"logprob\": -0.5205078,\n        \"special\": false,\n        \"text\": \" rapidly\"\n      },\n      {\n        \"id\": 7826,\n        \"logprob\": -0.15332031,\n        \"special\": false,\n        \"text\": \" growing\"\n      },\n      {\n        \"id\": 2070,\n        \"logprob\": -0.0039978027,\n        \"special\": false,\n        \"text\": \" field\"\n      },\n      {\n        \"id\": 448,\n        \"logprob\": -0.9091797,\n        \"special\": false,\n        \"text\": \" with\"\n      },\n      {\n        \"id\": 1657,\n        \"logprob\": -0.17114258,\n        \"special\": false,\n        \"text\": \" many\"\n      },\n      {\n        \"id\": 4650,\n        \"logprob\": -0.70703125,\n        \"special\": false,\n        \"text\": \" potential\"\n      },\n      {\n        \"id\": 8357,\n        \"logprob\": -0.025131226,\n        \"special\": false,\n        \"text\": \" applications\"\n      },\n      {\n        \"id\": 304,\n        \"logprob\": -0.6699219,\n        \"special\": false,\n        \"text\": \" in\"\n      },\n      {\n        \"id\": 279,\n     
   \"logprob\": -0.35205078,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 3853,\n        \"logprob\": -0.049194336,\n        \"special\": false,\n        \"text\": \" future\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.21972656,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 151643,\n        \"logprob\": -2.0019531,\n        \"special\": true,\n        \"text\": \"<|endoftext|>\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \" Deep learning is a subset of machine learning that uses neural networks to learn from data. It is a type of artificial intelligence that can learn from and make predictions on large amounts of data. Deep learning is used in a variety of applications, including image and speech recognition, natural language processing, and autonomous vehicles. It is a rapidly growing field with many potential applications in the future.\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_all_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 5267,\n        \"logprob\": -1.1464844,\n        \"special\": false,\n        \"text\": \"?\\n\"\n      },\n      {\n        \"id\": 33464,\n        \"logprob\": -0.83203125,\n        \"special\": false,\n        \"text\": \"Deep\"\n      },\n      {\n        \"id\": 20909,\n        \"logprob\": -0.5625,\n        \"special\": false,\n        \"text\": \" Learning\"\n      },\n      {\n        \"id\": 320,\n        \"logprob\": -2.1464844,\n        \"special\": false,\n        \"text\": \" (\"\n      },\n      {\n        \"id\": 16524,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"DL\"\n      },\n      {\n        \"id\": 701,\n        \"logprob\": -2.2089844,\n        \"special\": false,\n        \"text\": \"),\"\n      },\n      {\n        \"id\": 476,\n        \"logprob\": -0.27368164,\n        \"special\": false,\n        \"text\": \" or\"\n      },\n      {\n        \"id\": 20443,\n        \"logprob\": -0.09442139,\n        \"special\": false,\n        \"text\": \" artificial\"\n      },\n      {\n        \"id\": 29728,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" neural\"\n      },\n      {\n        \"id\": 14155,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" networks\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"What is deep learning?\\nDeep Learning (DL), or artificial neural networks\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_compressed_tensors_w8a8_int_dynamic_weight/test_compressed_tensors_w8a8_int_dynamic_weight_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 18183,\n          \"logprob\": -1.5195312,\n          \"special\": false,\n          \"text\": \" Deep\"\n        },\n        {\n          \"id\": 6832,\n          \"logprob\": -0.06817627,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 374,\n          \"logprob\": -0.13122559,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.13415527,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 25993,\n          \"logprob\": -0.87353516,\n          \"special\": false,\n          \"text\": \" subset\"\n        },\n        {\n          \"id\": 315,\n          \"logprob\": -0.0011396408,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 5662,\n          \"logprob\": -0.16442871,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 6832,\n          \"logprob\": -0.0026416779,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 429,\n          \"logprob\": -0.48754883,\n          \"special\": false,\n          \"text\": \" that\"\n        },\n        {\n          \"id\": 5711,\n          \"logprob\": -1.2294922,\n          \"special\": false,\n          \"text\": \" uses\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" Deep learning is a subset of machine learning that uses\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      
\"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 18183,\n          \"logprob\": -1.5195312,\n          \"special\": false,\n          \"text\": \" Deep\"\n        },\n        {\n          \"id\": 6832,\n          \"logprob\": -0.06817627,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 374,\n          \"logprob\": -0.13122559,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.13415527,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 25993,\n          \"logprob\": -0.87353516,\n          \"special\": false,\n          \"text\": \" subset\"\n        },\n        {\n          \"id\": 315,\n          \"logprob\": -0.0011396408,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 5662,\n          \"logprob\": -0.16442871,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 6832,\n          \"logprob\": -0.0026416779,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 429,\n          \"logprob\": -0.48754883,\n          \"special\": false,\n          \"text\": \" that\"\n        },\n        {\n          \"id\": 5711,\n          \"logprob\": -1.2294922,\n          \"special\": false,\n          \"text\": \" uses\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" Deep learning is a subset of machine learning that uses\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 18183,\n          \"logprob\": -1.5195312,\n          \"special\": false,\n          \"text\": \" Deep\"\n        },\n  
      {\n          \"id\": 6832,\n          \"logprob\": -0.06817627,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 374,\n          \"logprob\": -0.13122559,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.13415527,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 25993,\n          \"logprob\": -0.87353516,\n          \"special\": false,\n          \"text\": \" subset\"\n        },\n        {\n          \"id\": 315,\n          \"logprob\": -0.0011396408,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 5662,\n          \"logprob\": -0.16442871,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 6832,\n          \"logprob\": -0.0026416779,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 429,\n          \"logprob\": -0.48754883,\n          \"special\": false,\n          \"text\": \" that\"\n        },\n        {\n          \"id\": 5711,\n          \"logprob\": -1.2294922,\n          \"special\": false,\n          \"text\": \" uses\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" Deep learning is a subset of machine learning that uses\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 18183,\n          \"logprob\": -1.5195312,\n          \"special\": false,\n          \"text\": \" Deep\"\n        },\n        {\n          \"id\": 6832,\n          \"logprob\": -0.06817627,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 374,\n   
       \"logprob\": -0.13122559,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.13415527,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 25993,\n          \"logprob\": -0.87353516,\n          \"special\": false,\n          \"text\": \" subset\"\n        },\n        {\n          \"id\": 315,\n          \"logprob\": -0.0011396408,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 5662,\n          \"logprob\": -0.16442871,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 6832,\n          \"logprob\": -0.0026416779,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 429,\n          \"logprob\": -0.48754883,\n          \"special\": false,\n          \"text\": \" that\"\n        },\n        {\n          \"id\": 5711,\n          \"logprob\": -1.2294922,\n          \"special\": false,\n          \"text\": \" uses\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" Deep learning is a subset of machine learning that uses\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_compressed_tensors_w8an_fp/test_compressed_tensors_w8an.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 18682,\n        \"logprob\": -0.8769531,\n        \"special\": false,\n        \"text\": \" Deep\"\n      },\n      {\n        \"id\": 6975,\n        \"logprob\": -0.0076942444,\n        \"special\": false,\n        \"text\": \" learning\"\n      },\n      {\n        \"id\": 374,\n        \"logprob\": -0.25073242,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 264,\n        \"logprob\": -0.097595215,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 955,\n        \"logprob\": -0.921875,\n        \"special\": false,\n        \"text\": \" type\"\n      },\n      {\n        \"id\": 315,\n        \"logprob\": -0.00027918816,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 21075,\n        \"logprob\": -0.5527344,\n        \"special\": false,\n        \"text\": \" artificial\"\n      },\n      {\n        \"id\": 11478,\n        \"logprob\": -0.042541504,\n        \"special\": false,\n        \"text\": \" intelligence\"\n      },\n      {\n        \"id\": 320,\n        \"logprob\": -0.38891602,\n        \"special\": false,\n        \"text\": \" (\"\n      },\n      {\n        \"id\": 15836,\n        \"logprob\": -0.0011043549,\n        \"special\": false,\n        \"text\": \"AI\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \" Deep learning is a type of artificial intelligence (AI\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_compressed_tensors_w8an_fp/test_compressed_tensors_w8an_all_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 5380,\n        \"logprob\": -0.23840332,\n        \"special\": false,\n        \"text\": \"?\\n\"\n      },\n      {\n        \"id\": 34564,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"Deep\"\n      },\n      {\n        \"id\": 6975,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" learning\"\n      },\n      {\n        \"id\": 11,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \",\"\n      },\n      {\n        \"id\": 1101,\n        \"logprob\": -1.2011719,\n        \"special\": false,\n        \"text\": \" also\"\n      },\n      {\n        \"id\": 3967,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" known\"\n      },\n      {\n        \"id\": 439,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" as\"\n      },\n      {\n        \"id\": 30828,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" neural\"\n      },\n      {\n        \"id\": 4009,\n        \"logprob\": -0.6777344,\n        \"special\": false,\n        \"text\": \" network\"\n      },\n      {\n        \"id\": 477,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" or\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"What is deep learning?\\nDeep learning, also known as neural network or\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_compressed_tensors_w8an_fp/test_compressed_tensors_w8an_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 18682,\n          \"logprob\": -0.8769531,\n          \"special\": false,\n          \"text\": \" Deep\"\n        },\n        {\n          \"id\": 6975,\n          \"logprob\": -0.0076942444,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 374,\n          \"logprob\": -0.25146484,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.097595215,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 955,\n          \"logprob\": -0.9248047,\n          \"special\": false,\n          \"text\": \" type\"\n        },\n        {\n          \"id\": 315,\n          \"logprob\": -0.00027513504,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 21075,\n          \"logprob\": -0.5527344,\n          \"special\": false,\n          \"text\": \" artificial\"\n        },\n        {\n          \"id\": 11478,\n          \"logprob\": -0.043151855,\n          \"special\": false,\n          \"text\": \" intelligence\"\n        },\n        {\n          \"id\": 320,\n          \"logprob\": -0.3840332,\n          \"special\": false,\n          \"text\": \" (\"\n        },\n        {\n          \"id\": 15836,\n          \"logprob\": -0.0011043549,\n          \"special\": false,\n          \"text\": \"AI\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" Deep learning is a type of artificial intelligence (AI\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      
\"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 18682,\n          \"logprob\": -0.875,\n          \"special\": false,\n          \"text\": \" Deep\"\n        },\n        {\n          \"id\": 6975,\n          \"logprob\": -0.007698059,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 374,\n          \"logprob\": -0.25268555,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.09753418,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 955,\n          \"logprob\": -0.92529297,\n          \"special\": false,\n          \"text\": \" type\"\n        },\n        {\n          \"id\": 315,\n          \"logprob\": -0.00027942657,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 21075,\n          \"logprob\": -0.5527344,\n          \"special\": false,\n          \"text\": \" artificial\"\n        },\n        {\n          \"id\": 11478,\n          \"logprob\": -0.042541504,\n          \"special\": false,\n          \"text\": \" intelligence\"\n        },\n        {\n          \"id\": 320,\n          \"logprob\": -0.3840332,\n          \"special\": false,\n          \"text\": \" (\"\n        },\n        {\n          \"id\": 15836,\n          \"logprob\": -0.0011053085,\n          \"special\": false,\n          \"text\": \"AI\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" Deep learning is a type of artificial intelligence (AI\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 18682,\n          \"logprob\": -0.875,\n          \"special\": false,\n          \"text\": \" Deep\"\n        },\n        
{\n          \"id\": 6975,\n          \"logprob\": -0.007698059,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 374,\n          \"logprob\": -0.25268555,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.09753418,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 955,\n          \"logprob\": -0.92529297,\n          \"special\": false,\n          \"text\": \" type\"\n        },\n        {\n          \"id\": 315,\n          \"logprob\": -0.00027942657,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 21075,\n          \"logprob\": -0.5527344,\n          \"special\": false,\n          \"text\": \" artificial\"\n        },\n        {\n          \"id\": 11478,\n          \"logprob\": -0.042541504,\n          \"special\": false,\n          \"text\": \" intelligence\"\n        },\n        {\n          \"id\": 320,\n          \"logprob\": -0.3840332,\n          \"special\": false,\n          \"text\": \" (\"\n        },\n        {\n          \"id\": 15836,\n          \"logprob\": -0.0011053085,\n          \"special\": false,\n          \"text\": \"AI\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" Deep learning is a type of artificial intelligence (AI\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 18682,\n          \"logprob\": -0.875,\n          \"special\": false,\n          \"text\": \" Deep\"\n        },\n        {\n          \"id\": 6975,\n          \"logprob\": -0.007698059,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 374,\n          
\"logprob\": -0.25268555,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.09753418,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 955,\n          \"logprob\": -0.92529297,\n          \"special\": false,\n          \"text\": \" type\"\n        },\n        {\n          \"id\": 315,\n          \"logprob\": -0.00027942657,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 21075,\n          \"logprob\": -0.5527344,\n          \"special\": false,\n          \"text\": \" artificial\"\n        },\n        {\n          \"id\": 11478,\n          \"logprob\": -0.042541504,\n          \"special\": false,\n          \"text\": \" intelligence\"\n        },\n        {\n          \"id\": 320,\n          \"logprob\": -0.3840332,\n          \"special\": false,\n          \"text\": \" (\"\n        },\n        {\n          \"id\": 15836,\n          \"logprob\": -0.0011053085,\n          \"special\": false,\n          \"text\": \"AI\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" Deep learning is a type of artificial intelligence (AI\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_compressed_tensors_wna16_int/test_compressed_tensors_wna16.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 109,\n        \"logprob\": -0.24707031,\n        \"special\": false,\n        \"text\": \"\\n\\n\"\n      },\n      {\n        \"id\": 26843,\n        \"logprob\": -0.14550781,\n        \"special\": false,\n        \"text\": \"Deep\"\n      },\n      {\n        \"id\": 6044,\n        \"logprob\": -0.038330078,\n        \"special\": false,\n        \"text\": \" learning\"\n      },\n      {\n        \"id\": 603,\n        \"logprob\": -0.029907227,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 476,\n        \"logprob\": -0.020996094,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 38397,\n        \"logprob\": -0.828125,\n        \"special\": false,\n        \"text\": \" subset\"\n      },\n      {\n        \"id\": 576,\n        \"logprob\": -0.00049209595,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 6479,\n        \"logprob\": -0.057373047,\n        \"special\": false,\n        \"text\": \" machine\"\n      },\n      {\n        \"id\": 6044,\n        \"logprob\": -0.000207901,\n        \"special\": false,\n        \"text\": \" learning\"\n      },\n      {\n        \"id\": 674,\n        \"logprob\": -0.15429688,\n        \"special\": false,\n        \"text\": \" that\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"\\n\\nDeep learning is a subset of machine learning that\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_compressed_tensors_wna16_int/test_compressed_tensors_wna16_all_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 235336,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"?\"\n      },\n      {\n        \"id\": 109,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"\\n\\n\"\n      },\n      {\n        \"id\": 26843,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"Deep\"\n      },\n      {\n        \"id\": 14715,\n        \"logprob\": -0.38671875,\n        \"special\": false,\n        \"text\": \" Learning\"\n      },\n      {\n        \"id\": 603,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 476,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 38397,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" subset\"\n      },\n      {\n        \"id\": 576,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 6479,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" machine\"\n      },\n      {\n        \"id\": 6044,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" learning\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"What is deep learning?\\n\\nDeep Learning is a subset of machine learning\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_compressed_tensors_wna16_int/test_compressed_tensors_wna16_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 109,\n          \"logprob\": -0.24707031,\n          \"special\": false,\n          \"text\": \"\\n\\n\"\n        },\n        {\n          \"id\": 26843,\n          \"logprob\": -0.14550781,\n          \"special\": false,\n          \"text\": \"Deep\"\n        },\n        {\n          \"id\": 6044,\n          \"logprob\": -0.03857422,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 603,\n          \"logprob\": -0.030883789,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 476,\n          \"logprob\": -0.020996094,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 38397,\n          \"logprob\": -0.828125,\n          \"special\": false,\n          \"text\": \" subset\"\n        },\n        {\n          \"id\": 576,\n          \"logprob\": -0.00051498413,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 6479,\n          \"logprob\": -0.05883789,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 6044,\n          \"logprob\": -0.00020694733,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 674,\n          \"logprob\": -0.15820312,\n          \"special\": false,\n          \"text\": \" that\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n\\nDeep learning is a subset of machine learning that\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      
\"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 109,\n          \"logprob\": -0.23828125,\n          \"special\": false,\n          \"text\": \"\\n\\n\"\n        },\n        {\n          \"id\": 26843,\n          \"logprob\": -0.14550781,\n          \"special\": false,\n          \"text\": \"Deep\"\n        },\n        {\n          \"id\": 6044,\n          \"logprob\": -0.038330078,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 603,\n          \"logprob\": -0.030883789,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 476,\n          \"logprob\": -0.020996094,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 38397,\n          \"logprob\": -0.80859375,\n          \"special\": false,\n          \"text\": \" subset\"\n        },\n        {\n          \"id\": 576,\n          \"logprob\": -0.0005455017,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 6479,\n          \"logprob\": -0.05908203,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 6044,\n          \"logprob\": -0.00020599365,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 674,\n          \"logprob\": -0.17285156,\n          \"special\": false,\n          \"text\": \" that\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n\\nDeep learning is a subset of machine learning that\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 109,\n          \"logprob\": -0.23828125,\n          \"special\": false,\n          \"text\": \"\\n\\n\"\n        
},\n        {\n          \"id\": 26843,\n          \"logprob\": -0.14550781,\n          \"special\": false,\n          \"text\": \"Deep\"\n        },\n        {\n          \"id\": 6044,\n          \"logprob\": -0.038330078,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 603,\n          \"logprob\": -0.030883789,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 476,\n          \"logprob\": -0.020996094,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 38397,\n          \"logprob\": -0.80859375,\n          \"special\": false,\n          \"text\": \" subset\"\n        },\n        {\n          \"id\": 576,\n          \"logprob\": -0.0005455017,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 6479,\n          \"logprob\": -0.05908203,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 6044,\n          \"logprob\": -0.00020599365,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 674,\n          \"logprob\": -0.17285156,\n          \"special\": false,\n          \"text\": \" that\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n\\nDeep learning is a subset of machine learning that\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 109,\n          \"logprob\": -0.23828125,\n          \"special\": false,\n          \"text\": \"\\n\\n\"\n        },\n        {\n          \"id\": 26843,\n          \"logprob\": -0.14550781,\n          \"special\": false,\n          \"text\": \"Deep\"\n        },\n        {\n          \"id\": 
6044,\n          \"logprob\": -0.038330078,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 603,\n          \"logprob\": -0.030883789,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 476,\n          \"logprob\": -0.020996094,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 38397,\n          \"logprob\": -0.80859375,\n          \"special\": false,\n          \"text\": \" subset\"\n        },\n        {\n          \"id\": 576,\n          \"logprob\": -0.0005455017,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 6479,\n          \"logprob\": -0.05908203,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 6044,\n          \"logprob\": -0.00020599365,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 674,\n          \"logprob\": -0.17285156,\n          \"special\": false,\n          \"text\": \" that\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n\\nDeep learning is a subset of machine learning that\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_compressed_tensors_wna16_int_24/test_compressed_tensors_wna16_int_24.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 34564,\n        \"logprob\": -1.765625,\n        \"special\": false,\n        \"text\": \"Deep\"\n      },\n      {\n        \"id\": 6975,\n        \"logprob\": -0.023864746,\n        \"special\": false,\n        \"text\": \" learning\"\n      },\n      {\n        \"id\": 374,\n        \"logprob\": -0.1060791,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 264,\n        \"logprob\": -0.1940918,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 27084,\n        \"logprob\": -0.79785156,\n        \"special\": false,\n        \"text\": \" subset\"\n      },\n      {\n        \"id\": 315,\n        \"logprob\": -0.008262634,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 5780,\n        \"logprob\": -0.046569824,\n        \"special\": false,\n        \"text\": \" machine\"\n      },\n      {\n        \"id\": 6975,\n        \"logprob\": -0.0023479462,\n        \"special\": false,\n        \"text\": \" learning\"\n      },\n      {\n        \"id\": 430,\n        \"logprob\": -0.7626953,\n        \"special\": false,\n        \"text\": \" that\"\n      },\n      {\n        \"id\": 5829,\n        \"logprob\": -1.0107422,\n        \"special\": false,\n        \"text\": \" uses\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"Deep learning is a subset of machine learning that uses\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_compressed_tensors_wna16_int_24/test_compressed_tensors_wna16_int_24_all_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 5380,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"?\\n\"\n      },\n      {\n        \"id\": 34564,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"Deep\"\n      },\n      {\n        \"id\": 6975,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" learning\"\n      },\n      {\n        \"id\": 320,\n        \"logprob\": -0.19580078,\n        \"special\": false,\n        \"text\": \" (\"\n      },\n      {\n        \"id\": 16931,\n        \"logprob\": -1.7783203,\n        \"special\": false,\n        \"text\": \"DL\"\n      },\n      {\n        \"id\": 8,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \")\"\n      },\n      {\n        \"id\": 374,\n        \"logprob\": -1.4287109,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 264,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 27084,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" subset\"\n      },\n      {\n        \"id\": 315,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" of\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"What is deep learning?\\nDeep learning (DL) is a subset of\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_compressed_tensors_wna16_int_24/test_compressed_tensors_wna16_int_24_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 34564,\n          \"logprob\": -1.765625,\n          \"special\": false,\n          \"text\": \"Deep\"\n        },\n        {\n          \"id\": 6975,\n          \"logprob\": -0.024002075,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 374,\n          \"logprob\": -0.10760498,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.19580078,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 27084,\n          \"logprob\": -0.7993164,\n          \"special\": false,\n          \"text\": \" subset\"\n        },\n        {\n          \"id\": 315,\n          \"logprob\": -0.008300781,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 5780,\n          \"logprob\": -0.046295166,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 6975,\n          \"logprob\": -0.002374649,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 430,\n          \"logprob\": -0.7651367,\n          \"special\": false,\n          \"text\": \" that\"\n        },\n        {\n          \"id\": 5829,\n          \"logprob\": -1.0107422,\n          \"special\": false,\n          \"text\": \" uses\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"Deep learning is a subset of machine learning that uses\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": 
null,\n      \"tokens\": [\n        {\n          \"id\": 34564,\n          \"logprob\": -1.7597656,\n          \"special\": false,\n          \"text\": \"Deep\"\n        },\n        {\n          \"id\": 6975,\n          \"logprob\": -0.024032593,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 374,\n          \"logprob\": -0.10748291,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.19592285,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 27084,\n          \"logprob\": -0.7988281,\n          \"special\": false,\n          \"text\": \" subset\"\n        },\n        {\n          \"id\": 315,\n          \"logprob\": -0.008354187,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 5780,\n          \"logprob\": -0.046569824,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 6975,\n          \"logprob\": -0.0023517609,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 430,\n          \"logprob\": -0.7661133,\n          \"special\": false,\n          \"text\": \" that\"\n        },\n        {\n          \"id\": 5829,\n          \"logprob\": -1.0107422,\n          \"special\": false,\n          \"text\": \" uses\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"Deep learning is a subset of machine learning that uses\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 34564,\n          \"logprob\": -1.7597656,\n          \"special\": false,\n          \"text\": \"Deep\"\n        },\n        {\n     
     \"id\": 6975,\n          \"logprob\": -0.024032593,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 374,\n          \"logprob\": -0.10748291,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.19592285,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 27084,\n          \"logprob\": -0.7988281,\n          \"special\": false,\n          \"text\": \" subset\"\n        },\n        {\n          \"id\": 315,\n          \"logprob\": -0.008354187,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 5780,\n          \"logprob\": -0.046569824,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 6975,\n          \"logprob\": -0.0023517609,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 430,\n          \"logprob\": -0.7661133,\n          \"special\": false,\n          \"text\": \" that\"\n        },\n        {\n          \"id\": 5829,\n          \"logprob\": -1.0107422,\n          \"special\": false,\n          \"text\": \" uses\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"Deep learning is a subset of machine learning that uses\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 34564,\n          \"logprob\": -1.7597656,\n          \"special\": false,\n          \"text\": \"Deep\"\n        },\n        {\n          \"id\": 6975,\n          \"logprob\": -0.024032593,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 374,\n          
\"logprob\": -0.10748291,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.19592285,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 27084,\n          \"logprob\": -0.7988281,\n          \"special\": false,\n          \"text\": \" subset\"\n        },\n        {\n          \"id\": 315,\n          \"logprob\": -0.008354187,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 5780,\n          \"logprob\": -0.046569824,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 6975,\n          \"logprob\": -0.0023517609,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 430,\n          \"logprob\": -0.7661133,\n          \"special\": false,\n          \"text\": \" that\"\n        },\n        {\n          \"id\": 5829,\n          \"logprob\": -1.0107422,\n          \"special\": false,\n          \"text\": \" uses\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"Deep learning is a subset of machine learning that uses\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt.json",
    "content": "{\n  \"choices\": [\n    {\n      \"finish_reason\": \"length\",\n      \"index\": 0,\n      \"logprobs\": null,\n      \"message\": {\n        \"content\": \"Both an elephant and a mouse are mammals. However, the differences between elephants and mice are:\\n\\n1\",\n        \"role\": \"assistant\"\n      }\n    }\n  ],\n  \"created\": 1732541189,\n  \"id\": \"\",\n  \"model\": \"TinyLlama/TinyLlama-1.1B-Chat-v1.0\",\n  \"object\": \"chat.completion\",\n  \"system_fingerprint\": \"2.4.1-dev0-native\",\n  \"usage\": {\n    \"completion_tokens\": 30,\n    \"prompt_tokens\": 49,\n    \"total_tokens\": 79\n  }\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt_continue.json",
    "content": "{\n  \"choices\": [\n    {\n      \"finish_reason\": \"length\",\n      \"index\": 0,\n      \"logprobs\": null,\n      \"message\": {\n        \"content\": \" the royal mouse? It is a little more slender and only weighs around 1.5 pounds for males and 1.3 pounds\",\n        \"role\": \"assistant\"\n      }\n    }\n  ],\n  \"created\": 1732541190,\n  \"id\": \"\",\n  \"model\": \"TinyLlama/TinyLlama-1.1B-Chat-v1.0\",\n  \"object\": \"chat.completion\",\n  \"system_fingerprint\": \"2.4.1-dev0-native\",\n  \"usage\": {\n    \"completion_tokens\": 30,\n    \"prompt_tokens\": 73,\n    \"total_tokens\": 103\n  }\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_awq/test_flash_llama_awq.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 13,\n        \"logprob\": -1.9306641,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 5618,\n        \"logprob\": -2.4550781,\n        \"special\": false,\n        \"text\": \"What\"\n      },\n      {\n        \"id\": 338,\n        \"logprob\": -0.5732422,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 278,\n        \"logprob\": -1.5761719,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 4328,\n        \"logprob\": -1.5888672,\n        \"special\": false,\n        \"text\": \" difference\"\n      },\n      {\n        \"id\": 1546,\n        \"logprob\": -0.026504517,\n        \"special\": false,\n        \"text\": \" between\"\n      },\n      {\n        \"id\": 21784,\n        \"logprob\": -1.4287109,\n        \"special\": false,\n        \"text\": \" Deep\"\n      },\n      {\n        \"id\": 29257,\n        \"logprob\": -0.15856934,\n        \"special\": false,\n        \"text\": \" Learning\"\n      },\n      {\n        \"id\": 322,\n        \"logprob\": -0.17456055,\n        \"special\": false,\n        \"text\": \" and\"\n      },\n      {\n        \"id\": 6189,\n        \"logprob\": -0.62646484,\n        \"special\": false,\n        \"text\": \" Machine\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"\\nWhat is the difference between Deep Learning and Machine\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_awq/test_flash_llama_awq_all_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 13,\n        \"logprob\": -0.19958496,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 4013,\n        \"logprob\": -2.203125,\n        \"special\": false,\n        \"text\": \"This\"\n      },\n      {\n        \"id\": 1139,\n        \"logprob\": -0.23693848,\n        \"special\": false,\n        \"text\": \" question\"\n      },\n      {\n        \"id\": 756,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" has\"\n      },\n      {\n        \"id\": 1063,\n        \"logprob\": -0.076538086,\n        \"special\": false,\n        \"text\": \" been\"\n      },\n      {\n        \"id\": 4433,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" asked\"\n      },\n      {\n        \"id\": 1784,\n        \"logprob\": -1.1367188,\n        \"special\": false,\n        \"text\": \" many\"\n      },\n      {\n        \"id\": 3064,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" times\"\n      },\n      {\n        \"id\": 322,\n        \"logprob\": -1.7460938,\n        \"special\": false,\n        \"text\": \" and\"\n      },\n      {\n        \"id\": 306,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" I\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"What is Deep Learning?\\nThis question has been asked many times and I\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_awq/test_flash_llama_awq_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 13,\n          \"logprob\": -1.9306641,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 5618,\n          \"logprob\": -2.4550781,\n          \"special\": false,\n          \"text\": \"What\"\n        },\n        {\n          \"id\": 338,\n          \"logprob\": -0.5732422,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 278,\n          \"logprob\": -1.5761719,\n          \"special\": false,\n          \"text\": \" the\"\n        },\n        {\n          \"id\": 4328,\n          \"logprob\": -1.5888672,\n          \"special\": false,\n          \"text\": \" difference\"\n        },\n        {\n          \"id\": 1546,\n          \"logprob\": -0.026504517,\n          \"special\": false,\n          \"text\": \" between\"\n        },\n        {\n          \"id\": 21784,\n          \"logprob\": -1.4287109,\n          \"special\": false,\n          \"text\": \" Deep\"\n        },\n        {\n          \"id\": 29257,\n          \"logprob\": -0.15856934,\n          \"special\": false,\n          \"text\": \" Learning\"\n        },\n        {\n          \"id\": 322,\n          \"logprob\": -0.17456055,\n          \"special\": false,\n          \"text\": \" and\"\n        },\n        {\n          \"id\": 6189,\n          \"logprob\": -0.62646484,\n          \"special\": false,\n          \"text\": \" Machine\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\nWhat is the difference between Deep Learning and Machine\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      
\"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 13,\n          \"logprob\": -1.9306641,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 5618,\n          \"logprob\": -2.4550781,\n          \"special\": false,\n          \"text\": \"What\"\n        },\n        {\n          \"id\": 338,\n          \"logprob\": -0.5732422,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 278,\n          \"logprob\": -1.5761719,\n          \"special\": false,\n          \"text\": \" the\"\n        },\n        {\n          \"id\": 4328,\n          \"logprob\": -1.5888672,\n          \"special\": false,\n          \"text\": \" difference\"\n        },\n        {\n          \"id\": 1546,\n          \"logprob\": -0.026504517,\n          \"special\": false,\n          \"text\": \" between\"\n        },\n        {\n          \"id\": 21784,\n          \"logprob\": -1.4287109,\n          \"special\": false,\n          \"text\": \" Deep\"\n        },\n        {\n          \"id\": 29257,\n          \"logprob\": -0.15856934,\n          \"special\": false,\n          \"text\": \" Learning\"\n        },\n        {\n          \"id\": 322,\n          \"logprob\": -0.17456055,\n          \"special\": false,\n          \"text\": \" and\"\n        },\n        {\n          \"id\": 6189,\n          \"logprob\": -0.62646484,\n          \"special\": false,\n          \"text\": \" Machine\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\nWhat is the difference between Deep Learning and Machine\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 13,\n          \"logprob\": -1.9306641,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        
{\n          \"id\": 5618,\n          \"logprob\": -2.4550781,\n          \"special\": false,\n          \"text\": \"What\"\n        },\n        {\n          \"id\": 338,\n          \"logprob\": -0.5732422,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 278,\n          \"logprob\": -1.5761719,\n          \"special\": false,\n          \"text\": \" the\"\n        },\n        {\n          \"id\": 4328,\n          \"logprob\": -1.5888672,\n          \"special\": false,\n          \"text\": \" difference\"\n        },\n        {\n          \"id\": 1546,\n          \"logprob\": -0.026504517,\n          \"special\": false,\n          \"text\": \" between\"\n        },\n        {\n          \"id\": 21784,\n          \"logprob\": -1.4287109,\n          \"special\": false,\n          \"text\": \" Deep\"\n        },\n        {\n          \"id\": 29257,\n          \"logprob\": -0.15856934,\n          \"special\": false,\n          \"text\": \" Learning\"\n        },\n        {\n          \"id\": 322,\n          \"logprob\": -0.17456055,\n          \"special\": false,\n          \"text\": \" and\"\n        },\n        {\n          \"id\": 6189,\n          \"logprob\": -0.62646484,\n          \"special\": false,\n          \"text\": \" Machine\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\nWhat is the difference between Deep Learning and Machine\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 13,\n          \"logprob\": -1.9306641,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 5618,\n          \"logprob\": -2.4550781,\n          \"special\": false,\n          \"text\": \"What\"\n        },\n        {\n          \"id\": 338,\n          
\"logprob\": -0.5732422,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 278,\n          \"logprob\": -1.5761719,\n          \"special\": false,\n          \"text\": \" the\"\n        },\n        {\n          \"id\": 4328,\n          \"logprob\": -1.5888672,\n          \"special\": false,\n          \"text\": \" difference\"\n        },\n        {\n          \"id\": 1546,\n          \"logprob\": -0.026504517,\n          \"special\": false,\n          \"text\": \" between\"\n        },\n        {\n          \"id\": 21784,\n          \"logprob\": -1.4287109,\n          \"special\": false,\n          \"text\": \" Deep\"\n        },\n        {\n          \"id\": 29257,\n          \"logprob\": -0.15856934,\n          \"special\": false,\n          \"text\": \" Learning\"\n        },\n        {\n          \"id\": 322,\n          \"logprob\": -0.17456055,\n          \"special\": false,\n          \"text\": \" and\"\n        },\n        {\n          \"id\": 6189,\n          \"logprob\": -0.62646484,\n          \"special\": false,\n          \"text\": \" Machine\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\nWhat is the difference between Deep Learning and Machine\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_awq_sharded/test_flash_llama_awq_load_sharded.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 13,\n          \"logprob\": -1.9228516,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 5618,\n          \"logprob\": -2.4609375,\n          \"special\": false,\n          \"text\": \"What\"\n        },\n        {\n          \"id\": 338,\n          \"logprob\": -0.57177734,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 278,\n          \"logprob\": -1.5722656,\n          \"special\": false,\n          \"text\": \" the\"\n        },\n        {\n          \"id\": 4328,\n          \"logprob\": -1.5859375,\n          \"special\": false,\n          \"text\": \" difference\"\n        },\n        {\n          \"id\": 1546,\n          \"logprob\": -0.02633667,\n          \"special\": false,\n          \"text\": \" between\"\n        },\n        {\n          \"id\": 21784,\n          \"logprob\": -1.4335938,\n          \"special\": false,\n          \"text\": \" Deep\"\n        },\n        {\n          \"id\": 29257,\n          \"logprob\": -0.15991211,\n          \"special\": false,\n          \"text\": \" Learning\"\n        },\n        {\n          \"id\": 322,\n          \"logprob\": -0.17456055,\n          \"special\": false,\n          \"text\": \" and\"\n        },\n        {\n          \"id\": 6189,\n          \"logprob\": -0.62060547,\n          \"special\": false,\n          \"text\": \" Machine\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\nWhat is the difference between Deep Learning and Machine\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      
\"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 13,\n          \"logprob\": -1.9228516,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 5618,\n          \"logprob\": -2.4609375,\n          \"special\": false,\n          \"text\": \"What\"\n        },\n        {\n          \"id\": 338,\n          \"logprob\": -0.57177734,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 278,\n          \"logprob\": -1.5722656,\n          \"special\": false,\n          \"text\": \" the\"\n        },\n        {\n          \"id\": 4328,\n          \"logprob\": -1.5859375,\n          \"special\": false,\n          \"text\": \" difference\"\n        },\n        {\n          \"id\": 1546,\n          \"logprob\": -0.02633667,\n          \"special\": false,\n          \"text\": \" between\"\n        },\n        {\n          \"id\": 21784,\n          \"logprob\": -1.4335938,\n          \"special\": false,\n          \"text\": \" Deep\"\n        },\n        {\n          \"id\": 29257,\n          \"logprob\": -0.15991211,\n          \"special\": false,\n          \"text\": \" Learning\"\n        },\n        {\n          \"id\": 322,\n          \"logprob\": -0.17456055,\n          \"special\": false,\n          \"text\": \" and\"\n        },\n        {\n          \"id\": 6189,\n          \"logprob\": -0.62060547,\n          \"special\": false,\n          \"text\": \" Machine\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\nWhat is the difference between Deep Learning and Machine\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 13,\n          \"logprob\": -1.9228516,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        
{\n          \"id\": 5618,\n          \"logprob\": -2.4609375,\n          \"special\": false,\n          \"text\": \"What\"\n        },\n        {\n          \"id\": 338,\n          \"logprob\": -0.57177734,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 278,\n          \"logprob\": -1.5722656,\n          \"special\": false,\n          \"text\": \" the\"\n        },\n        {\n          \"id\": 4328,\n          \"logprob\": -1.5859375,\n          \"special\": false,\n          \"text\": \" difference\"\n        },\n        {\n          \"id\": 1546,\n          \"logprob\": -0.02633667,\n          \"special\": false,\n          \"text\": \" between\"\n        },\n        {\n          \"id\": 21784,\n          \"logprob\": -1.4335938,\n          \"special\": false,\n          \"text\": \" Deep\"\n        },\n        {\n          \"id\": 29257,\n          \"logprob\": -0.15991211,\n          \"special\": false,\n          \"text\": \" Learning\"\n        },\n        {\n          \"id\": 322,\n          \"logprob\": -0.17456055,\n          \"special\": false,\n          \"text\": \" and\"\n        },\n        {\n          \"id\": 6189,\n          \"logprob\": -0.62060547,\n          \"special\": false,\n          \"text\": \" Machine\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\nWhat is the difference between Deep Learning and Machine\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 13,\n          \"logprob\": -1.9228516,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 5618,\n          \"logprob\": -2.4609375,\n          \"special\": false,\n          \"text\": \"What\"\n        },\n        {\n          \"id\": 338,\n          
\"logprob\": -0.57177734,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 278,\n          \"logprob\": -1.5722656,\n          \"special\": false,\n          \"text\": \" the\"\n        },\n        {\n          \"id\": 4328,\n          \"logprob\": -1.5859375,\n          \"special\": false,\n          \"text\": \" difference\"\n        },\n        {\n          \"id\": 1546,\n          \"logprob\": -0.02633667,\n          \"special\": false,\n          \"text\": \" between\"\n        },\n        {\n          \"id\": 21784,\n          \"logprob\": -1.4335938,\n          \"special\": false,\n          \"text\": \" Deep\"\n        },\n        {\n          \"id\": 29257,\n          \"logprob\": -0.15991211,\n          \"special\": false,\n          \"text\": \" Learning\"\n        },\n        {\n          \"id\": 322,\n          \"logprob\": -0.17456055,\n          \"special\": false,\n          \"text\": \" and\"\n        },\n        {\n          \"id\": 6189,\n          \"logprob\": -0.62060547,\n          \"special\": false,\n          \"text\": \" Machine\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\nWhat is the difference between Deep Learning and Machine\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_awq_sharded/test_flash_llama_awq_sharded.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 13,\n        \"logprob\": -1.9228516,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 5618,\n        \"logprob\": -2.4609375,\n        \"special\": false,\n        \"text\": \"What\"\n      },\n      {\n        \"id\": 338,\n        \"logprob\": -0.57177734,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 278,\n        \"logprob\": -1.5722656,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 4328,\n        \"logprob\": -1.5927734,\n        \"special\": false,\n        \"text\": \" difference\"\n      },\n      {\n        \"id\": 1546,\n        \"logprob\": -0.026428223,\n        \"special\": false,\n        \"text\": \" between\"\n      },\n      {\n        \"id\": 21784,\n        \"logprob\": -1.4267578,\n        \"special\": false,\n        \"text\": \" Deep\"\n      },\n      {\n        \"id\": 29257,\n        \"logprob\": -0.16015625,\n        \"special\": false,\n        \"text\": \" Learning\"\n      },\n      {\n        \"id\": 322,\n        \"logprob\": -0.17382812,\n        \"special\": false,\n        \"text\": \" and\"\n      },\n      {\n        \"id\": 6189,\n        \"logprob\": -0.62060547,\n        \"special\": false,\n        \"text\": \" Machine\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"\\nWhat is the difference between Deep Learning and Machine\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 185,\n        \"logprob\": -1.546875,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 549,\n        \"logprob\": -2.859375,\n        \"special\": false,\n        \"text\": \"The\"\n      },\n      {\n        \"id\": 1727,\n        \"logprob\": -2.484375,\n        \"special\": false,\n        \"text\": \" test\"\n      },\n      {\n        \"id\": 3102,\n        \"logprob\": -0.83203125,\n        \"special\": false,\n        \"text\": \" request\"\n      },\n      {\n        \"id\": 317,\n        \"logprob\": -1.1484375,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 245,\n        \"logprob\": -1.578125,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 3412,\n        \"logprob\": -2.578125,\n        \"special\": false,\n        \"text\": \" document\"\n      },\n      {\n        \"id\": 344,\n        \"logprob\": -1.125,\n        \"special\": false,\n        \"text\": \" that\"\n      },\n      {\n        \"id\": 317,\n        \"logprob\": -1.6953125,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 1222,\n        \"logprob\": -1.71875,\n        \"special\": false,\n        \"text\": \" used\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"\\nThe test request is a document that is used\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_all_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"eos_token\",\n    \"generated_tokens\": 4,\n    \"prefill\": [],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 2143,\n        \"logprob\": -1.828125,\n        \"special\": false,\n        \"text\": \" sent\"\n      },\n      {\n        \"id\": 10081,\n        \"logprob\": -0.41210938,\n        \"special\": false,\n        \"text\": \" successfully\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 100001,\n        \"logprob\": -0.16015625,\n        \"special\": true,\n        \"text\": \"<｜end▁of▁sentence｜>\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"Test request sent successfully.\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 185,\n          \"logprob\": -1.546875,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 549,\n          \"logprob\": -2.859375,\n          \"special\": false,\n          \"text\": \"The\"\n        },\n        {\n          \"id\": 1727,\n          \"logprob\": -2.4375,\n          \"special\": false,\n          \"text\": \" test\"\n        },\n        {\n          \"id\": 3102,\n          \"logprob\": -0.83984375,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 317,\n          \"logprob\": -1.1328125,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 254,\n          \"logprob\": -1.515625,\n          \"special\": false,\n          \"text\": \" the\"\n        },\n        {\n          \"id\": 1022,\n          \"logprob\": -1.15625,\n          \"special\": false,\n          \"text\": \" first\"\n        },\n        {\n          \"id\": 3458,\n          \"logprob\": -0.3671875,\n          \"special\": false,\n          \"text\": \" step\"\n        },\n        {\n          \"id\": 279,\n          \"logprob\": -0.88671875,\n          \"special\": false,\n          \"text\": \" in\"\n        },\n        {\n          \"id\": 254,\n          \"logprob\": -0.69140625,\n          \"special\": false,\n          \"text\": \" the\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\nThe test request is the first step in the\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n  
        \"id\": 185,\n          \"logprob\": -1.546875,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 549,\n          \"logprob\": -2.859375,\n          \"special\": false,\n          \"text\": \"The\"\n        },\n        {\n          \"id\": 1727,\n          \"logprob\": -2.4375,\n          \"special\": false,\n          \"text\": \" test\"\n        },\n        {\n          \"id\": 3102,\n          \"logprob\": -0.83984375,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 317,\n          \"logprob\": -1.1328125,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 254,\n          \"logprob\": -1.515625,\n          \"special\": false,\n          \"text\": \" the\"\n        },\n        {\n          \"id\": 1022,\n          \"logprob\": -1.15625,\n          \"special\": false,\n          \"text\": \" first\"\n        },\n        {\n          \"id\": 3458,\n          \"logprob\": -0.3671875,\n          \"special\": false,\n          \"text\": \" step\"\n        },\n        {\n          \"id\": 279,\n          \"logprob\": -0.88671875,\n          \"special\": false,\n          \"text\": \" in\"\n        },\n        {\n          \"id\": 254,\n          \"logprob\": -0.69140625,\n          \"special\": false,\n          \"text\": \" the\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\nThe test request is the first step in the\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 185,\n          \"logprob\": -1.546875,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 549,\n          \"logprob\": -2.859375,\n          \"special\": false,\n  
        \"text\": \"The\"\n        },\n        {\n          \"id\": 1727,\n          \"logprob\": -2.4375,\n          \"special\": false,\n          \"text\": \" test\"\n        },\n        {\n          \"id\": 3102,\n          \"logprob\": -0.83984375,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 317,\n          \"logprob\": -1.1328125,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 254,\n          \"logprob\": -1.515625,\n          \"special\": false,\n          \"text\": \" the\"\n        },\n        {\n          \"id\": 1022,\n          \"logprob\": -1.15625,\n          \"special\": false,\n          \"text\": \" first\"\n        },\n        {\n          \"id\": 3458,\n          \"logprob\": -0.3671875,\n          \"special\": false,\n          \"text\": \" step\"\n        },\n        {\n          \"id\": 279,\n          \"logprob\": -0.88671875,\n          \"special\": false,\n          \"text\": \" in\"\n        },\n        {\n          \"id\": 254,\n          \"logprob\": -0.69140625,\n          \"special\": false,\n          \"text\": \" the\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\nThe test request is the first step in the\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 185,\n          \"logprob\": -1.546875,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 549,\n          \"logprob\": -2.859375,\n          \"special\": false,\n          \"text\": \"The\"\n        },\n        {\n          \"id\": 1727,\n          \"logprob\": -2.4375,\n          \"special\": false,\n          \"text\": \" test\"\n        },\n        {\n          \"id\": 3102,\n          
\"logprob\": -0.83984375,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 317,\n          \"logprob\": -1.1328125,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 254,\n          \"logprob\": -1.515625,\n          \"special\": false,\n          \"text\": \" the\"\n        },\n        {\n          \"id\": 1022,\n          \"logprob\": -1.15625,\n          \"special\": false,\n          \"text\": \" first\"\n        },\n        {\n          \"id\": 3458,\n          \"logprob\": -0.3671875,\n          \"special\": false,\n          \"text\": \" step\"\n        },\n        {\n          \"id\": 279,\n          \"logprob\": -0.88671875,\n          \"special\": false,\n          \"text\": \" in\"\n        },\n        {\n          \"id\": 254,\n          \"logprob\": -0.69140625,\n          \"special\": false,\n          \"text\": \" the\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\nThe test request is the first step in the\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_falcon/test_flash_falcon.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 23090,\n        \"logprob\": -1.8251953,\n        \"special\": false,\n        \"text\": \" Hello\"\n      },\n      {\n        \"id\": 23,\n        \"logprob\": -0.3173828,\n        \"special\": false,\n        \"text\": \",\"\n      },\n      {\n        \"id\": 8156,\n        \"logprob\": -0.23803711,\n        \"special\": false,\n        \"text\": \" Daniel\"\n      },\n      {\n        \"id\": 12,\n        \"logprob\": -0.56933594,\n        \"special\": false,\n        \"text\": \"!\"\n      },\n      {\n        \"id\": 193,\n        \"logprob\": -0.61279297,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 23626,\n        \"logprob\": -0.41967773,\n        \"special\": false,\n        \"text\": \"Daniel\"\n      },\n      {\n        \"id\": 37,\n        \"logprob\": -0.0023403168,\n        \"special\": false,\n        \"text\": \":\"\n      },\n      {\n        \"id\": 1634,\n        \"logprob\": -2.0605469,\n        \"special\": false,\n        \"text\": \" What\"\n      },\n      {\n        \"id\": 18,\n        \"logprob\": -1.5292969,\n        \"special\": false,\n        \"text\": \"'\"\n      },\n      {\n        \"id\": 94,\n        \"logprob\": -0.007904053,\n        \"special\": false,\n        \"text\": \"s\"\n      }\n    ]\n  },\n  \"generated_text\": \" Hello, Daniel!\\nDaniel: What's\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_falcon/test_flash_falcon_all_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 836,\n        \"logprob\": -1.265625,\n        \"special\": false,\n        \"text\": \" i\"\n      },\n      {\n        \"id\": 18,\n        \"logprob\": -0.119628906,\n        \"special\": false,\n        \"text\": \"'\"\n      },\n      {\n        \"id\": 298,\n        \"logprob\": -2.265625,\n        \"special\": false,\n        \"text\": \"ve\"\n      },\n      {\n        \"id\": 650,\n        \"logprob\": -0.49804688,\n        \"special\": false,\n        \"text\": \" been\"\n      },\n      {\n        \"id\": 1241,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" using\"\n      },\n      {\n        \"id\": 334,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" it\"\n      },\n      {\n        \"id\": 312,\n        \"logprob\": -1.2421875,\n        \"special\": false,\n        \"text\": \" for\"\n      },\n      {\n        \"id\": 909,\n        \"logprob\": -0.99609375,\n        \"special\": false,\n        \"text\": \" years\"\n      },\n      {\n        \"id\": 193,\n        \"logprob\": -0.30273438,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 807,\n        \"logprob\": -1.078125,\n        \"special\": false,\n        \"text\": \"ik\"\n      }\n    ]\n  },\n  \"generated_text\": \"Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\\nDaniel: Hello, Girafatron!\\nGirafatron: i've been using it for years\\nik\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_falcon/test_flash_falcon_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 23090,\n          \"logprob\": -1.828125,\n          \"special\": false,\n          \"text\": \" Hello\"\n        },\n        {\n          \"id\": 23,\n          \"logprob\": -0.3178711,\n          \"special\": false,\n          \"text\": \",\"\n        },\n        {\n          \"id\": 8156,\n          \"logprob\": -0.23925781,\n          \"special\": false,\n          \"text\": \" Daniel\"\n        },\n        {\n          \"id\": 12,\n          \"logprob\": -0.5698242,\n          \"special\": false,\n          \"text\": \"!\"\n        },\n        {\n          \"id\": 193,\n          \"logprob\": -0.61279297,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 23626,\n          \"logprob\": -0.4177246,\n          \"special\": false,\n          \"text\": \"Daniel\"\n        },\n        {\n          \"id\": 37,\n          \"logprob\": -0.0023345947,\n          \"special\": false,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 1634,\n          \"logprob\": -2.0605469,\n          \"special\": false,\n          \"text\": \" What\"\n        },\n        {\n          \"id\": 18,\n          \"logprob\": -1.5283203,\n          \"special\": false,\n          \"text\": \"'\"\n        },\n        {\n          \"id\": 94,\n          \"logprob\": -0.007965088,\n          \"special\": false,\n          \"text\": \"s\"\n        }\n      ]\n    },\n    \"generated_text\": \" Hello, Daniel!\\nDaniel: What's\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 23090,\n          
\"logprob\": -1.8251953,\n          \"special\": false,\n          \"text\": \" Hello\"\n        },\n        {\n          \"id\": 23,\n          \"logprob\": -0.31762695,\n          \"special\": false,\n          \"text\": \",\"\n        },\n        {\n          \"id\": 8156,\n          \"logprob\": -0.2388916,\n          \"special\": false,\n          \"text\": \" Daniel\"\n        },\n        {\n          \"id\": 12,\n          \"logprob\": -0.5698242,\n          \"special\": false,\n          \"text\": \"!\"\n        },\n        {\n          \"id\": 193,\n          \"logprob\": -0.6152344,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 23626,\n          \"logprob\": -0.42211914,\n          \"special\": false,\n          \"text\": \"Daniel\"\n        },\n        {\n          \"id\": 37,\n          \"logprob\": -0.002336502,\n          \"special\": false,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 1634,\n          \"logprob\": -2.0605469,\n          \"special\": false,\n          \"text\": \" What\"\n        },\n        {\n          \"id\": 18,\n          \"logprob\": -1.5292969,\n          \"special\": false,\n          \"text\": \"'\"\n        },\n        {\n          \"id\": 94,\n          \"logprob\": -0.007926941,\n          \"special\": false,\n          \"text\": \"s\"\n        }\n      ]\n    },\n    \"generated_text\": \" Hello, Daniel!\\nDaniel: What's\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 23090,\n          \"logprob\": -1.8251953,\n          \"special\": false,\n          \"text\": \" Hello\"\n        },\n        {\n          \"id\": 23,\n          \"logprob\": -0.31762695,\n          \"special\": false,\n          \"text\": \",\"\n        },\n        {\n          \"id\": 
8156,\n          \"logprob\": -0.2388916,\n          \"special\": false,\n          \"text\": \" Daniel\"\n        },\n        {\n          \"id\": 12,\n          \"logprob\": -0.5698242,\n          \"special\": false,\n          \"text\": \"!\"\n        },\n        {\n          \"id\": 193,\n          \"logprob\": -0.6152344,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 23626,\n          \"logprob\": -0.42211914,\n          \"special\": false,\n          \"text\": \"Daniel\"\n        },\n        {\n          \"id\": 37,\n          \"logprob\": -0.002336502,\n          \"special\": false,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 1634,\n          \"logprob\": -2.0605469,\n          \"special\": false,\n          \"text\": \" What\"\n        },\n        {\n          \"id\": 18,\n          \"logprob\": -1.5292969,\n          \"special\": false,\n          \"text\": \"'\"\n        },\n        {\n          \"id\": 94,\n          \"logprob\": -0.007926941,\n          \"special\": false,\n          \"text\": \"s\"\n        }\n      ]\n    },\n    \"generated_text\": \" Hello, Daniel!\\nDaniel: What's\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 23090,\n          \"logprob\": -1.8251953,\n          \"special\": false,\n          \"text\": \" Hello\"\n        },\n        {\n          \"id\": 23,\n          \"logprob\": -0.31762695,\n          \"special\": false,\n          \"text\": \",\"\n        },\n        {\n          \"id\": 8156,\n          \"logprob\": -0.2388916,\n          \"special\": false,\n          \"text\": \" Daniel\"\n        },\n        {\n          \"id\": 12,\n          \"logprob\": -0.5698242,\n          \"special\": false,\n          \"text\": \"!\"\n        },\n        {\n       
   \"id\": 193,\n          \"logprob\": -0.6152344,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 23626,\n          \"logprob\": -0.42211914,\n          \"special\": false,\n          \"text\": \"Daniel\"\n        },\n        {\n          \"id\": 37,\n          \"logprob\": -0.002336502,\n          \"special\": false,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 1634,\n          \"logprob\": -2.0605469,\n          \"special\": false,\n          \"text\": \" What\"\n        },\n        {\n          \"id\": 18,\n          \"logprob\": -1.5292969,\n          \"special\": false,\n          \"text\": \"'\"\n        },\n        {\n          \"id\": 94,\n          \"logprob\": -0.007926941,\n          \"special\": false,\n          \"text\": \"s\"\n        }\n      ]\n    },\n    \"generated_text\": \" Hello, Daniel!\\nDaniel: What's\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_all_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 7539,\n        \"logprob\": -0.609375,\n        \"special\": false,\n        \"text\": \" forms\"\n      },\n      {\n        \"id\": 708,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" are\"\n      },\n      {\n        \"id\": 671,\n        \"logprob\": -1.5546875,\n        \"special\": false,\n        \"text\": \" an\"\n      },\n      {\n        \"id\": 8727,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" essential\"\n      },\n      {\n        \"id\": 1702,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" part\"\n      },\n      {\n        \"id\": 576,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 573,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 11859,\n        \"logprob\": -1.953125,\n        \"special\": false,\n        \"text\": \" lab\"\n      },\n      {\n        \"id\": 2185,\n        \"logprob\": -1.7734375,\n        \"special\": false,\n        \"text\": \" process\"\n      },\n      {\n        \"id\": 235265,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \".\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"Test request forms are an essential part of the lab process.\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 1736,\n          \"logprob\": -2.09375,\n          \"special\": false,\n          \"text\": \" form\"\n        },\n        {\n          \"id\": 109,\n          \"logprob\": -1.9140625,\n          \"special\": false,\n          \"text\": \"\\n\\n\"\n        },\n        {\n          \"id\": 651,\n          \"logprob\": -2.453125,\n          \"special\": false,\n          \"text\": \"The\"\n        },\n        {\n          \"id\": 2121,\n          \"logprob\": -1.8984375,\n          \"special\": false,\n          \"text\": \" test\"\n        },\n        {\n          \"id\": 3853,\n          \"logprob\": -0.23535156,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 1736,\n          \"logprob\": -0.091308594,\n          \"special\": false,\n          \"text\": \" form\"\n        },\n        {\n          \"id\": 603,\n          \"logprob\": -0.96875,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 1671,\n          \"logprob\": -1.6484375,\n          \"special\": false,\n          \"text\": \" used\"\n        },\n        {\n          \"id\": 577,\n          \"logprob\": -0.43164062,\n          \"special\": false,\n          \"text\": \" to\"\n        },\n        {\n          \"id\": 3853,\n          \"logprob\": -1.2421875,\n          \"special\": false,\n          \"text\": \" request\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" form\\n\\nThe test request form is used to request\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      
\"tokens\": [\n        {\n          \"id\": 1736,\n          \"logprob\": -2.09375,\n          \"special\": false,\n          \"text\": \" form\"\n        },\n        {\n          \"id\": 109,\n          \"logprob\": -1.9140625,\n          \"special\": false,\n          \"text\": \"\\n\\n\"\n        },\n        {\n          \"id\": 651,\n          \"logprob\": -2.453125,\n          \"special\": false,\n          \"text\": \"The\"\n        },\n        {\n          \"id\": 2121,\n          \"logprob\": -1.8984375,\n          \"special\": false,\n          \"text\": \" test\"\n        },\n        {\n          \"id\": 3853,\n          \"logprob\": -0.23535156,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 1736,\n          \"logprob\": -0.091308594,\n          \"special\": false,\n          \"text\": \" form\"\n        },\n        {\n          \"id\": 603,\n          \"logprob\": -0.96875,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 1671,\n          \"logprob\": -1.6484375,\n          \"special\": false,\n          \"text\": \" used\"\n        },\n        {\n          \"id\": 577,\n          \"logprob\": -0.43164062,\n          \"special\": false,\n          \"text\": \" to\"\n        },\n        {\n          \"id\": 3853,\n          \"logprob\": -1.2421875,\n          \"special\": false,\n          \"text\": \" request\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" form\\n\\nThe test request form is used to request\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 1736,\n          \"logprob\": -2.09375,\n          \"special\": false,\n          \"text\": \" form\"\n        },\n        {\n          \"id\": 109,\n          
\"logprob\": -1.9140625,\n          \"special\": false,\n          \"text\": \"\\n\\n\"\n        },\n        {\n          \"id\": 651,\n          \"logprob\": -2.453125,\n          \"special\": false,\n          \"text\": \"The\"\n        },\n        {\n          \"id\": 2121,\n          \"logprob\": -1.8984375,\n          \"special\": false,\n          \"text\": \" test\"\n        },\n        {\n          \"id\": 3853,\n          \"logprob\": -0.23535156,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 1736,\n          \"logprob\": -0.091308594,\n          \"special\": false,\n          \"text\": \" form\"\n        },\n        {\n          \"id\": 603,\n          \"logprob\": -0.96875,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 1671,\n          \"logprob\": -1.6484375,\n          \"special\": false,\n          \"text\": \" used\"\n        },\n        {\n          \"id\": 577,\n          \"logprob\": -0.43164062,\n          \"special\": false,\n          \"text\": \" to\"\n        },\n        {\n          \"id\": 3853,\n          \"logprob\": -1.2421875,\n          \"special\": false,\n          \"text\": \" request\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" form\\n\\nThe test request form is used to request\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 1736,\n          \"logprob\": -2.09375,\n          \"special\": false,\n          \"text\": \" form\"\n        },\n        {\n          \"id\": 109,\n          \"logprob\": -1.9140625,\n          \"special\": false,\n          \"text\": \"\\n\\n\"\n        },\n        {\n          \"id\": 651,\n          \"logprob\": -2.453125,\n          \"special\": false,\n          
\"text\": \"The\"\n        },\n        {\n          \"id\": 2121,\n          \"logprob\": -1.8984375,\n          \"special\": false,\n          \"text\": \" test\"\n        },\n        {\n          \"id\": 3853,\n          \"logprob\": -0.23535156,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 1736,\n          \"logprob\": -0.091308594,\n          \"special\": false,\n          \"text\": \" form\"\n        },\n        {\n          \"id\": 603,\n          \"logprob\": -0.96875,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 1671,\n          \"logprob\": -1.6484375,\n          \"special\": false,\n          \"text\": \" used\"\n        },\n        {\n          \"id\": 577,\n          \"logprob\": -0.43164062,\n          \"special\": false,\n          \"text\": \" to\"\n        },\n        {\n          \"id\": 3853,\n          \"logprob\": -1.2421875,\n          \"special\": false,\n          \"text\": \" request\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" form\\n\\nThe test request form is used to request\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_simple.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 1736,\n        \"logprob\": -2.109375,\n        \"special\": false,\n        \"text\": \" form\"\n      },\n      {\n        \"id\": 109,\n        \"logprob\": -1.90625,\n        \"special\": false,\n        \"text\": \"\\n\\n\"\n      },\n      {\n        \"id\": 651,\n        \"logprob\": -2.4375,\n        \"special\": false,\n        \"text\": \"The\"\n      },\n      {\n        \"id\": 2121,\n        \"logprob\": -1.796875,\n        \"special\": false,\n        \"text\": \" test\"\n      },\n      {\n        \"id\": 3853,\n        \"logprob\": -0.24511719,\n        \"special\": false,\n        \"text\": \" request\"\n      },\n      {\n        \"id\": 1736,\n        \"logprob\": -0.09326172,\n        \"special\": false,\n        \"text\": \" form\"\n      },\n      {\n        \"id\": 603,\n        \"logprob\": -0.95703125,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 1671,\n        \"logprob\": -1.5859375,\n        \"special\": false,\n        \"text\": \" used\"\n      },\n      {\n        \"id\": 577,\n        \"logprob\": -0.39257812,\n        \"special\": false,\n        \"text\": \" to\"\n      },\n      {\n        \"id\": 3853,\n        \"logprob\": -1.25,\n        \"special\": false,\n        \"text\": \" request\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \" form\\n\\nThe test request form is used to request\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_gemma2/test_flash_gemma2.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 688,\n        \"logprob\": -0.546875,\n        \"special\": false,\n        \"text\": \"**\"\n      },\n      {\n        \"id\": 103889,\n        \"logprob\": -0.49023438,\n        \"special\": false,\n        \"text\": \"Hydrogen\"\n      },\n      {\n        \"id\": 190213,\n        \"logprob\": -0.48632812,\n        \"special\": false,\n        \"text\": \"**,\"\n      },\n      {\n        \"id\": 2611,\n        \"logprob\": -0.58203125,\n        \"special\": false,\n        \"text\": \" light\"\n      },\n      {\n        \"id\": 578,\n        \"logprob\": -0.099121094,\n        \"special\": false,\n        \"text\": \" and\"\n      },\n      {\n        \"id\": 2223,\n        \"logprob\": -1.078125,\n        \"special\": false,\n        \"text\": \" free\"\n      },\n      {\n        \"id\": 235269,\n        \"logprob\": -0.025756836,\n        \"special\": false,\n        \"text\": \",\"\n      },\n      {\n        \"id\": 108,\n        \"logprob\": -0.29101562,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 688,\n        \"logprob\": -0.0035858154,\n        \"special\": false,\n        \"text\": \"**\"\n      },\n      {\n        \"id\": 1949,\n        \"logprob\": -4.1007996e-05,\n        \"special\": false,\n        \"text\": \"He\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"**Hydrogen**, light and free,\\n**He\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_gemma2/test_flash_gemma2_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 688,\n          \"logprob\": -0.546875,\n          \"special\": false,\n          \"text\": \"**\"\n        },\n        {\n          \"id\": 103889,\n          \"logprob\": -0.49023438,\n          \"special\": false,\n          \"text\": \"Hydrogen\"\n        },\n        {\n          \"id\": 190213,\n          \"logprob\": -0.48632812,\n          \"special\": false,\n          \"text\": \"**,\"\n        },\n        {\n          \"id\": 2611,\n          \"logprob\": -0.58203125,\n          \"special\": false,\n          \"text\": \" light\"\n        },\n        {\n          \"id\": 578,\n          \"logprob\": -0.08886719,\n          \"special\": false,\n          \"text\": \" and\"\n        },\n        {\n          \"id\": 2223,\n          \"logprob\": -1.09375,\n          \"special\": false,\n          \"text\": \" free\"\n        },\n        {\n          \"id\": 235269,\n          \"logprob\": -0.024291992,\n          \"special\": false,\n          \"text\": \",\"\n        },\n        {\n          \"id\": 108,\n          \"logprob\": -0.30664062,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 688,\n          \"logprob\": -0.0035552979,\n          \"special\": false,\n          \"text\": \"**\"\n        },\n        {\n          \"id\": 1949,\n          \"logprob\": -4.220009e-05,\n          \"special\": false,\n          \"text\": \"He\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"**Hydrogen**, light and free,\\n**He\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n      
  {\n          \"id\": 688,\n          \"logprob\": -0.546875,\n          \"special\": false,\n          \"text\": \"**\"\n        },\n        {\n          \"id\": 103889,\n          \"logprob\": -0.49023438,\n          \"special\": false,\n          \"text\": \"Hydrogen\"\n        },\n        {\n          \"id\": 190213,\n          \"logprob\": -0.48632812,\n          \"special\": false,\n          \"text\": \"**,\"\n        },\n        {\n          \"id\": 2611,\n          \"logprob\": -0.58203125,\n          \"special\": false,\n          \"text\": \" light\"\n        },\n        {\n          \"id\": 578,\n          \"logprob\": -0.08886719,\n          \"special\": false,\n          \"text\": \" and\"\n        },\n        {\n          \"id\": 2223,\n          \"logprob\": -1.09375,\n          \"special\": false,\n          \"text\": \" free\"\n        },\n        {\n          \"id\": 235269,\n          \"logprob\": -0.024291992,\n          \"special\": false,\n          \"text\": \",\"\n        },\n        {\n          \"id\": 108,\n          \"logprob\": -0.30664062,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 688,\n          \"logprob\": -0.0035552979,\n          \"special\": false,\n          \"text\": \"**\"\n        },\n        {\n          \"id\": 1949,\n          \"logprob\": -4.220009e-05,\n          \"special\": false,\n          \"text\": \"He\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"**Hydrogen**, light and free,\\n**He\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 688,\n          \"logprob\": -0.546875,\n          \"special\": false,\n          \"text\": \"**\"\n        },\n        {\n          \"id\": 103889,\n          \"logprob\": -0.49023438,\n          
\"special\": false,\n          \"text\": \"Hydrogen\"\n        },\n        {\n          \"id\": 190213,\n          \"logprob\": -0.48632812,\n          \"special\": false,\n          \"text\": \"**,\"\n        },\n        {\n          \"id\": 2611,\n          \"logprob\": -0.58203125,\n          \"special\": false,\n          \"text\": \" light\"\n        },\n        {\n          \"id\": 578,\n          \"logprob\": -0.08984375,\n          \"special\": false,\n          \"text\": \" and\"\n        },\n        {\n          \"id\": 2223,\n          \"logprob\": -1.1015625,\n          \"special\": false,\n          \"text\": \" free\"\n        },\n        {\n          \"id\": 235269,\n          \"logprob\": -0.024291992,\n          \"special\": false,\n          \"text\": \",\"\n        },\n        {\n          \"id\": 108,\n          \"logprob\": -0.30664062,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 688,\n          \"logprob\": -0.0038452148,\n          \"special\": false,\n          \"text\": \"**\"\n        },\n        {\n          \"id\": 1949,\n          \"logprob\": -4.1484833e-05,\n          \"special\": false,\n          \"text\": \"He\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"**Hydrogen**, light and free,\\n**He\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 688,\n          \"logprob\": -0.546875,\n          \"special\": false,\n          \"text\": \"**\"\n        },\n        {\n          \"id\": 103889,\n          \"logprob\": -0.49023438,\n          \"special\": false,\n          \"text\": \"Hydrogen\"\n        },\n        {\n          \"id\": 190213,\n          \"logprob\": -0.48632812,\n          \"special\": false,\n          \"text\": \"**,\"\n        },\n        
{\n          \"id\": 2611,\n          \"logprob\": -0.58203125,\n          \"special\": false,\n          \"text\": \" light\"\n        },\n        {\n          \"id\": 578,\n          \"logprob\": -0.08886719,\n          \"special\": false,\n          \"text\": \" and\"\n        },\n        {\n          \"id\": 2223,\n          \"logprob\": -1.09375,\n          \"special\": false,\n          \"text\": \" free\"\n        },\n        {\n          \"id\": 235269,\n          \"logprob\": -0.024291992,\n          \"special\": false,\n          \"text\": \",\"\n        },\n        {\n          \"id\": 108,\n          \"logprob\": -0.30664062,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 688,\n          \"logprob\": -0.0035552979,\n          \"special\": false,\n          \"text\": \"**\"\n        },\n        {\n          \"id\": 1949,\n          \"logprob\": -4.220009e-05,\n          \"special\": false,\n          \"text\": \"He\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"**Hydrogen**, light and free,\\n**He\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_gemma3/test_exceed_window.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"eos_token\",\n    \"generated_tokens\": 16,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 506,\n        \"logprob\": -1.3984375,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 1331,\n        \"logprob\": -1.6953125,\n        \"special\": false,\n        \"text\": \" people\"\n      },\n      {\n        \"id\": 236764,\n        \"logprob\": -0.23535156,\n        \"special\": false,\n        \"text\": \",\"\n      },\n      {\n        \"id\": 532,\n        \"logprob\": -0.24316406,\n        \"special\": false,\n        \"text\": \" and\"\n      },\n      {\n        \"id\": 506,\n        \"logprob\": -0.12109375,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 2780,\n        \"logprob\": -1.1640625,\n        \"special\": false,\n        \"text\": \" food\"\n      },\n      {\n        \"id\": 236761,\n        \"logprob\": -0.21386719,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 108,\n        \"logprob\": -0.64453125,\n        \"special\": false,\n        \"text\": \"\\n\\n\"\n      },\n      {\n        \"id\": 2094,\n        \"logprob\": -0.77734375,\n        \"special\": false,\n        \"text\": \"This\"\n      },\n      {\n        \"id\": 563,\n        \"logprob\": -0.040283203,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 496,\n        \"logprob\": -0.03125,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 6290,\n        \"logprob\": -0.03515625,\n        \"special\": false,\n        \"text\": \" nice\"\n      },\n      {\n        \"id\": 1977,\n        \"logprob\": -0.0020751953,\n        \"special\": false,\n        \"text\": \" place\"\n      },\n      {\n        \"id\": 236761,\n        \"logprob\": 
-0.0079956055,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 107,\n        \"logprob\": -0.9921875,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 106,\n        \"logprob\": -0.45507812,\n        \"special\": true,\n        \"text\": \"<end_of_turn>\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \" the people, and the food.\\n\\nThis is a nice place.\\n\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 100,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 1331,\n        \"logprob\": -0.31835938,\n        \"special\": false,\n        \"text\": \" people\"\n      },\n      {\n        \"id\": 8390,\n        \"logprob\": -0.1484375,\n        \"special\": false,\n        \"text\": \" died\"\n      },\n      {\n        \"id\": 528,\n        \"logprob\": -1.1171875,\n        \"special\": false,\n        \"text\": \" in\"\n      },\n      {\n        \"id\": 506,\n        \"logprob\": -0.45898438,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 3640,\n        \"logprob\": -0.55859375,\n        \"special\": false,\n        \"text\": \" United\"\n      },\n      {\n        \"id\": 4184,\n        \"logprob\": -0.0026397705,\n        \"special\": false,\n        \"text\": \" States\"\n      },\n      {\n        \"id\": 236761,\n        \"logprob\": -0.38085938,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 108,\n        \"logprob\": -0.07421875,\n        \"special\": false,\n        \"text\": \"\\n\\n\"\n      },\n      {\n        \"id\": 818,\n        \"logprob\": -1.0859375,\n        \"special\": false,\n        \"text\": \"The\"\n      },\n      {\n        \"id\": 6816,\n        \"logprob\": -1.75,\n        \"special\": false,\n        \"text\": \" generally\"\n      },\n      {\n        \"id\": 10951,\n        \"logprob\": -0.14648438,\n        \"special\": false,\n        \"text\": \" accepted\"\n      },\n      {\n        \"id\": 10967,\n        \"logprob\": -0.9609375,\n        \"special\": false,\n        \"text\": \" estimate\"\n      },\n      {\n        \"id\": 563,\n        \"logprob\": -0.49414062,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 600,\n        
\"logprob\": -0.703125,\n        \"special\": false,\n        \"text\": \" that\"\n      },\n      {\n        \"id\": 236743,\n        \"logprob\": -1.171875,\n        \"special\": false,\n        \"text\": \" \"\n      },\n      {\n        \"id\": 236825,\n        \"logprob\": -0.0009918213,\n        \"special\": false,\n        \"text\": \"6\"\n      },\n      {\n        \"id\": 236832,\n        \"logprob\": -6.389618e-05,\n        \"special\": false,\n        \"text\": \"7\"\n      },\n      {\n        \"id\": 236810,\n        \"logprob\": -4.7445297e-05,\n        \"special\": false,\n        \"text\": \"5\"\n      },\n      {\n        \"id\": 236764,\n        \"logprob\": -0.00017929077,\n        \"special\": false,\n        \"text\": \",\"\n      },\n      {\n        \"id\": 236771,\n        \"logprob\": -1.4901161e-05,\n        \"special\": false,\n        \"text\": \"0\"\n      },\n      {\n        \"id\": 236771,\n        \"logprob\": -1.7881393e-06,\n        \"special\": false,\n        \"text\": \"0\"\n      },\n      {\n        \"id\": 236771,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"0\"\n      },\n      {\n        \"id\": 1331,\n        \"logprob\": -0.45898438,\n        \"special\": false,\n        \"text\": \" people\"\n      },\n      {\n        \"id\": 8390,\n        \"logprob\": -0.011474609,\n        \"special\": false,\n        \"text\": \" died\"\n      },\n      {\n        \"id\": 528,\n        \"logprob\": -0.084472656,\n        \"special\": false,\n        \"text\": \" in\"\n      },\n      {\n        \"id\": 506,\n        \"logprob\": -0.00032615662,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 3640,\n        \"logprob\": -0.029785156,\n        \"special\": false,\n        \"text\": \" United\"\n      },\n      {\n        \"id\": 4184,\n        \"logprob\": -0.00012302399,\n        \"special\": false,\n        \"text\": \" States\"\n      },\n      {\n        
\"id\": 236761,\n        \"logprob\": -1.1796875,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 3153,\n        \"logprob\": -0.09667969,\n        \"special\": false,\n        \"text\": \" However\"\n      },\n      {\n        \"id\": 236764,\n        \"logprob\": -0.009094238,\n        \"special\": false,\n        \"text\": \",\"\n      },\n      {\n        \"id\": 1070,\n        \"logprob\": -0.91015625,\n        \"special\": false,\n        \"text\": \" some\"\n      },\n      {\n        \"id\": 61806,\n        \"logprob\": -0.859375,\n        \"special\": false,\n        \"text\": \" historians\"\n      },\n      {\n        \"id\": 4646,\n        \"logprob\": -1.3828125,\n        \"special\": false,\n        \"text\": \" believe\"\n      },\n      {\n        \"id\": 506,\n        \"logprob\": -0.65234375,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 5396,\n        \"logprob\": -0.765625,\n        \"special\": false,\n        \"text\": \" actual\"\n      },\n      {\n        \"id\": 1548,\n        \"logprob\": -0.048339844,\n        \"special\": false,\n        \"text\": \" number\"\n      },\n      {\n        \"id\": 1451,\n        \"logprob\": -0.65625,\n        \"special\": false,\n        \"text\": \" could\"\n      },\n      {\n        \"id\": 577,\n        \"logprob\": -0.09082031,\n        \"special\": false,\n        \"text\": \" be\"\n      },\n      {\n        \"id\": 618,\n        \"logprob\": -0.625,\n        \"special\": false,\n        \"text\": \" as\"\n      },\n      {\n        \"id\": 1494,\n        \"logprob\": -0.00037193298,\n        \"special\": false,\n        \"text\": \" high\"\n      },\n      {\n        \"id\": 618,\n        \"logprob\": -0.0001296997,\n        \"special\": false,\n        \"text\": \" as\"\n      },\n      {\n        \"id\": 236743,\n        \"logprob\": -0.00093460083,\n        \"special\": false,\n        \"text\": \" \"\n     
 },\n      {\n        \"id\": 236770,\n        \"logprob\": -0.21289062,\n        \"special\": false,\n        \"text\": \"1\"\n      },\n      {\n        \"id\": 236771,\n        \"logprob\": -0.16796875,\n        \"special\": false,\n        \"text\": \"0\"\n      },\n      {\n        \"id\": 3625,\n        \"logprob\": -0.0126953125,\n        \"special\": false,\n        \"text\": \" million\"\n      },\n      {\n        \"id\": 236761,\n        \"logprob\": -0.22460938,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 108,\n        \"logprob\": -0.3984375,\n        \"special\": false,\n        \"text\": \"\\n\\n\"\n      },\n      {\n        \"id\": 236777,\n        \"logprob\": -1.078125,\n        \"special\": false,\n        \"text\": \"I\"\n      },\n      {\n        \"id\": 1006,\n        \"logprob\": -1.359375,\n        \"special\": false,\n        \"text\": \" am\"\n      },\n      {\n        \"id\": 3182,\n        \"logprob\": -1.0859375,\n        \"special\": false,\n        \"text\": \" looking\"\n      },\n      {\n        \"id\": 573,\n        \"logprob\": -0.035888672,\n        \"special\": false,\n        \"text\": \" for\"\n      },\n      {\n        \"id\": 919,\n        \"logprob\": -1.2578125,\n        \"special\": false,\n        \"text\": \" more\"\n      },\n      {\n        \"id\": 1938,\n        \"logprob\": -1.3046875,\n        \"special\": false,\n        \"text\": \" information\"\n      },\n      {\n        \"id\": 580,\n        \"logprob\": -0.7421875,\n        \"special\": false,\n        \"text\": \" on\"\n      },\n      {\n        \"id\": 672,\n        \"logprob\": -0.78125,\n        \"special\": false,\n        \"text\": \" this\"\n      },\n      {\n        \"id\": 59725,\n        \"logprob\": -0.7109375,\n        \"special\": false,\n        \"text\": \" discrepancy\"\n      },\n      {\n        \"id\": 532,\n        \"logprob\": -0.8046875,\n        \"special\": false,\n        
\"text\": \" and\"\n      },\n      {\n        \"id\": 506,\n        \"logprob\": -0.71484375,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 5872,\n        \"logprob\": -1.1640625,\n        \"special\": false,\n        \"text\": \" factors\"\n      },\n      {\n        \"id\": 600,\n        \"logprob\": -0.20410156,\n        \"special\": false,\n        \"text\": \" that\"\n      },\n      {\n        \"id\": 19263,\n        \"logprob\": -1.1484375,\n        \"special\": false,\n        \"text\": \" contributed\"\n      },\n      {\n        \"id\": 531,\n        \"logprob\": -0.000957489,\n        \"special\": false,\n        \"text\": \" to\"\n      },\n      {\n        \"id\": 506,\n        \"logprob\": -0.19921875,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 5777,\n        \"logprob\": -1.171875,\n        \"special\": false,\n        \"text\": \" wide\"\n      },\n      {\n        \"id\": 2644,\n        \"logprob\": -0.020141602,\n        \"special\": false,\n        \"text\": \" range\"\n      },\n      {\n        \"id\": 529,\n        \"logprob\": -0.14550781,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 14287,\n        \"logprob\": -0.03564453,\n        \"special\": false,\n        \"text\": \" estimates\"\n      },\n      {\n        \"id\": 236761,\n        \"logprob\": -0.010620117,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 108,\n        \"logprob\": -0.060302734,\n        \"special\": false,\n        \"text\": \"\\n\\n\"\n      },\n      {\n        \"id\": 8291,\n        \"logprob\": -0.7421875,\n        \"special\": false,\n        \"text\": \"Here\"\n      },\n      {\n        \"id\": 236789,\n        \"logprob\": -0.24023438,\n        \"special\": false,\n        \"text\": \"'\"\n      },\n      {\n        \"id\": 236751,\n        \"logprob\": -1.0728836e-06,\n       
 \"special\": false,\n        \"text\": \"s\"\n      },\n      {\n        \"id\": 496,\n        \"logprob\": -0.16992188,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 25890,\n        \"logprob\": -0.06933594,\n        \"special\": false,\n        \"text\": \" breakdown\"\n      },\n      {\n        \"id\": 529,\n        \"logprob\": -0.002243042,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 506,\n        \"logprob\": -0.18554688,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 5872,\n        \"logprob\": -0.9921875,\n        \"special\": false,\n        \"text\": \" factors\"\n      },\n      {\n        \"id\": 20894,\n        \"logprob\": -0.25976562,\n        \"special\": false,\n        \"text\": \" contributing\"\n      },\n      {\n        \"id\": 531,\n        \"logprob\": -8.440018e-05,\n        \"special\": false,\n        \"text\": \" to\"\n      },\n      {\n        \"id\": 506,\n        \"logprob\": -0.009765625,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 5777,\n        \"logprob\": -0.67578125,\n        \"special\": false,\n        \"text\": \" wide\"\n      },\n      {\n        \"id\": 2644,\n        \"logprob\": -0.0023956299,\n        \"special\": false,\n        \"text\": \" range\"\n      },\n      {\n        \"id\": 529,\n        \"logprob\": -0.014831543,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 14287,\n        \"logprob\": -0.012329102,\n        \"special\": false,\n        \"text\": \" estimates\"\n      },\n      {\n        \"id\": 573,\n        \"logprob\": -0.3125,\n        \"special\": false,\n        \"text\": \" for\"\n      },\n      {\n        \"id\": 506,\n        \"logprob\": -0.21484375,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 236743,\n        
\"logprob\": -0.43359375,\n        \"special\": false,\n        \"text\": \" \"\n      },\n      {\n        \"id\": 236770,\n        \"logprob\": -3.5762787e-07,\n        \"special\": false,\n        \"text\": \"1\"\n      },\n      {\n        \"id\": 236819,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"9\"\n      },\n      {\n        \"id\": 236770,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"1\"\n      },\n      {\n        \"id\": 236828,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"8\"\n      },\n      {\n        \"id\": 7745,\n        \"logprob\": -0.703125,\n        \"special\": false,\n        \"text\": \" flu\"\n      },\n      {\n        \"id\": 10248,\n        \"logprob\": -0.013427734,\n        \"special\": false,\n        \"text\": \" pandemic\"\n      },\n      {\n        \"id\": 4355,\n        \"logprob\": -0.6953125,\n        \"special\": false,\n        \"text\": \" death\"\n      },\n      {\n        \"id\": 25363,\n        \"logprob\": -6.771088e-05,\n        \"special\": false,\n        \"text\": \" toll\"\n      },\n      {\n        \"id\": 528,\n        \"logprob\": -0.076171875,\n        \"special\": false,\n        \"text\": \" in\"\n      },\n      {\n        \"id\": 506,\n        \"logprob\": -7.2717667e-06,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 3640,\n        \"logprob\": -0.0052490234,\n        \"special\": false,\n        \"text\": \" United\"\n      },\n      {\n        \"id\": 4184,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" States\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \" people died in the United States.\\n\\nThe generally accepted estimate is that 675,000 people died in the United States. 
However, some historians believe the actual number could be as high as 10 million.\\n\\nI am looking for more information on this discrepancy and the factors that contributed to the wide range of estimates.\\n\\nHere's a breakdown of the factors contributing to the wide range of estimates for the 1918 flu pandemic death toll in the United States\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_jpg.json",
    "content": "{\n  \"choices\": [\n    {\n      \"finish_reason\": \"stop\",\n      \"index\": 0,\n      \"logprobs\": null,\n      \"message\": {\n        \"content\": \"Okay, let's analyze the image.\\n\\nThe image is a solid, bright white color. There is nothing else visible within it. \\n\\nIt's essentially a blank white square or rectangle.\",\n        \"name\": null,\n        \"role\": \"assistant\",\n        \"tool_calls\": null\n      },\n      \"usage\": null\n    }\n  ],\n  \"created\": 1747062956,\n  \"id\": \"\",\n  \"model\": \"google/gemma-3-4b-it\",\n  \"object\": \"chat.completion\",\n  \"system_fingerprint\": \"3.3.6-dev0-native\",\n  \"usage\": {\n    \"completion_tokens\": 42,\n    \"prompt_tokens\": 277,\n    \"total_tokens\": 319\n  }\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgb_png.json",
    "content": "{\n  \"choices\": [\n    {\n      \"finish_reason\": \"stop\",\n      \"index\": 0,\n      \"logprobs\": null,\n      \"message\": {\n        \"content\": \"Okay, let's analyze the image. \\n\\nThe image is a very plain, solid white square. That's it! \\n\\nIt's essentially a blank canvas. \\n\\nDo you want me to describe it in more detail, or are you interested in something else regarding this image?\",\n        \"name\": null,\n        \"role\": \"assistant\",\n        \"tool_calls\": null\n      },\n      \"usage\": null\n    }\n  ],\n  \"created\": 1747062955,\n  \"id\": \"\",\n  \"model\": \"google/gemma-3-4b-it\",\n  \"object\": \"chat.completion\",\n  \"system_fingerprint\": \"3.3.6-dev0-native\",\n  \"usage\": {\n    \"completion_tokens\": 62,\n    \"prompt_tokens\": 277,\n    \"total_tokens\": 339\n  }\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgba.json",
    "content": "{\n  \"choices\": [\n    {\n      \"finish_reason\": \"stop\",\n      \"index\": 0,\n      \"logprobs\": null,\n      \"message\": {\n        \"content\": \"Okay, let's analyze the image. \\n\\nThe transparent image reveals a stylized depiction of **a human head**. It's a minimalist, geometric representation, showing the basic shapes of the skull, eye sockets, and head outline. \\n\\nDo you want me to describe any specific element of the image in more detail?\",\n        \"name\": null,\n        \"role\": \"assistant\",\n        \"tool_calls\": null\n      },\n      \"usage\": null\n    }\n  ],\n  \"created\": 1747062952,\n  \"id\": \"\",\n  \"model\": \"google/gemma-3-4b-it\",\n  \"object\": \"chat.completion\",\n  \"system_fingerprint\": \"3.3.6-dev0-native\",\n  \"usage\": {\n    \"completion_tokens\": 67,\n    \"prompt_tokens\": 277,\n    \"total_tokens\": 344\n  }\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow.json",
    "content": "{\n  \"choices\": [\n    {\n      \"finish_reason\": \"stop\",\n      \"index\": 0,\n      \"logprobs\": null,\n      \"message\": {\n        \"content\": \"Here's a description of what's shown in the image:\\n\\nThe image depicts a brown cow standing on a sandy beach. The beach has turquoise water and a distant island visible in the background. The sky is bright blue with some white clouds. \\n\\nIt's a quite a humorous and unusual scene – a cow enjoying a beach day!\",\n        \"name\": null,\n        \"role\": \"assistant\",\n        \"tool_calls\": null\n      },\n      \"usage\": null\n    }\n  ],\n  \"created\": 1747216083,\n  \"id\": \"\",\n  \"model\": \"google/gemma-3-4b-it\",\n  \"object\": \"chat.completion\",\n  \"system_fingerprint\": \"3.3.6-dev0-native\",\n  \"usage\": {\n    \"completion_tokens\": 72,\n    \"prompt_tokens\": 275,\n    \"total_tokens\": 347\n  }\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow_dog.json",
    "content": "{\n  \"choices\": [\n    {\n      \"finish_reason\": \"stop\",\n      \"index\": 0,\n      \"logprobs\": null,\n      \"message\": {\n        \"content\": \"That's a fantastic question! However, the image doesn't show a dog. It shows a **Brown Swiss cow** standing on a beach. \\n\\nBrown Swiss cows are known for their beautiful reddish-brown coats and distinctive white markings. \\n\\nIf you'd like, you can send me another image, and I'll do my best to identify the animal in it!\",\n        \"name\": null,\n        \"role\": \"assistant\",\n        \"tool_calls\": null\n      },\n      \"usage\": null\n    }\n  ],\n  \"created\": 1747216080,\n  \"id\": \"\",\n  \"model\": \"google/gemma-3-4b-it\",\n  \"object\": \"chat.completion\",\n  \"system_fingerprint\": \"3.3.6-dev0-native\",\n  \"usage\": {\n    \"completion_tokens\": 80,\n    \"prompt_tokens\": 279,\n    \"total_tokens\": 359\n  }\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 604,\n        \"logprob\": -2.4296875,\n        \"special\": false,\n        \"text\": \" for\"\n      },\n      {\n        \"id\": 573,\n        \"logprob\": -2.4453125,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 2412,\n        \"logprob\": -2.8632812,\n        \"special\": false,\n        \"text\": \" following\"\n      },\n      {\n        \"id\": 235292,\n        \"logprob\": -2.1328125,\n        \"special\": false,\n        \"text\": \":\"\n      },\n      {\n        \"id\": 109,\n        \"logprob\": -0.76660156,\n        \"special\": false,\n        \"text\": \"\\n\\n\"\n      },\n      {\n        \"id\": 235287,\n        \"logprob\": -1.3837891,\n        \"special\": false,\n        \"text\": \"*\"\n      },\n      {\n        \"id\": 235248,\n        \"logprob\": -1.9746094,\n        \"special\": false,\n        \"text\": \" \"\n      },\n      {\n        \"id\": 199,\n        \"logprob\": -1.4189453,\n        \"special\": false,\n        \"text\": \"<strong>\"\n      },\n      {\n        \"id\": 1232,\n        \"logprob\": -4.34375,\n        \"special\": false,\n        \"text\": \"Name\"\n      },\n      {\n        \"id\": 208,\n        \"logprob\": -0.8852539,\n        \"special\": false,\n        \"text\": \"</strong>\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \" for the following:\\n\\n* <strong>Name</strong>\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 604,\n        \"logprob\": -0.28271484,\n        \"special\": false,\n        \"text\": \" for\"\n      },\n      {\n        \"id\": 573,\n        \"logprob\": -0.19030762,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 16819,\n        \"logprob\": -1.4863281,\n        \"special\": false,\n        \"text\": \" detection\"\n      },\n      {\n        \"id\": 576,\n        \"logprob\": -0.7089844,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 573,\n        \"logprob\": -2.0410156,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 8566,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" presence\"\n      },\n      {\n        \"id\": 689,\n        \"logprob\": -0.16491699,\n        \"special\": false,\n        \"text\": \" or\"\n      },\n      {\n        \"id\": 14862,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" absence\"\n      },\n      {\n        \"id\": 576,\n        \"logprob\": -0.9970703,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 671,\n        \"logprob\": -0.5292969,\n        \"special\": false,\n        \"text\": \" an\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"Test request for the detection of the presence or absence of an\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 604,\n          \"logprob\": -2.4277344,\n          \"special\": false,\n          \"text\": \" for\"\n        },\n        {\n          \"id\": 573,\n          \"logprob\": -2.4394531,\n          \"special\": false,\n          \"text\": \" the\"\n        },\n        {\n          \"id\": 2412,\n          \"logprob\": -2.8613281,\n          \"special\": false,\n          \"text\": \" following\"\n        },\n        {\n          \"id\": 235292,\n          \"logprob\": -2.1523438,\n          \"special\": false,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 109,\n          \"logprob\": -0.76220703,\n          \"special\": false,\n          \"text\": \"\\n\\n\"\n        },\n        {\n          \"id\": 235287,\n          \"logprob\": -1.3642578,\n          \"special\": false,\n          \"text\": \"*\"\n        },\n        {\n          \"id\": 235248,\n          \"logprob\": -2.0175781,\n          \"special\": false,\n          \"text\": \" \"\n        },\n        {\n          \"id\": 199,\n          \"logprob\": -1.4238281,\n          \"special\": false,\n          \"text\": \"<strong>\"\n        },\n        {\n          \"id\": 1232,\n          \"logprob\": -4.328125,\n          \"special\": false,\n          \"text\": \"Name\"\n        },\n        {\n          \"id\": 208,\n          \"logprob\": -0.8881836,\n          \"special\": false,\n          \"text\": \"</strong>\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" for the following:\\n\\n* <strong>Name</strong>\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      
\"tokens\": [\n        {\n          \"id\": 604,\n          \"logprob\": -2.4238281,\n          \"special\": false,\n          \"text\": \" for\"\n        },\n        {\n          \"id\": 573,\n          \"logprob\": -2.4453125,\n          \"special\": false,\n          \"text\": \" the\"\n        },\n        {\n          \"id\": 2412,\n          \"logprob\": -2.859375,\n          \"special\": false,\n          \"text\": \" following\"\n        },\n        {\n          \"id\": 235292,\n          \"logprob\": -2.1445312,\n          \"special\": false,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 109,\n          \"logprob\": -0.7631836,\n          \"special\": false,\n          \"text\": \"\\n\\n\"\n        },\n        {\n          \"id\": 235287,\n          \"logprob\": -1.3642578,\n          \"special\": false,\n          \"text\": \"*\"\n        },\n        {\n          \"id\": 235248,\n          \"logprob\": -1.9960938,\n          \"special\": false,\n          \"text\": \" \"\n        },\n        {\n          \"id\": 199,\n          \"logprob\": -1.4179688,\n          \"special\": false,\n          \"text\": \"<strong>\"\n        },\n        {\n          \"id\": 1232,\n          \"logprob\": -4.3359375,\n          \"special\": false,\n          \"text\": \"Name\"\n        },\n        {\n          \"id\": 208,\n          \"logprob\": -0.8847656,\n          \"special\": false,\n          \"text\": \"</strong>\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" for the following:\\n\\n* <strong>Name</strong>\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 604,\n          \"logprob\": -2.4257812,\n          \"special\": false,\n          \"text\": \" for\"\n        },\n        {\n          \"id\": 573,\n          
\"logprob\": -2.4453125,\n          \"special\": false,\n          \"text\": \" the\"\n        },\n        {\n          \"id\": 2412,\n          \"logprob\": -2.8789062,\n          \"special\": false,\n          \"text\": \" following\"\n        },\n        {\n          \"id\": 235292,\n          \"logprob\": -2.1367188,\n          \"special\": false,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 109,\n          \"logprob\": -0.76171875,\n          \"special\": false,\n          \"text\": \"\\n\\n\"\n        },\n        {\n          \"id\": 235287,\n          \"logprob\": -1.3515625,\n          \"special\": false,\n          \"text\": \"*\"\n        },\n        {\n          \"id\": 235248,\n          \"logprob\": -1.9873047,\n          \"special\": false,\n          \"text\": \" \"\n        },\n        {\n          \"id\": 199,\n          \"logprob\": -1.4169922,\n          \"special\": false,\n          \"text\": \"<strong>\"\n        },\n        {\n          \"id\": 1232,\n          \"logprob\": -4.3320312,\n          \"special\": false,\n          \"text\": \"Name\"\n        },\n        {\n          \"id\": 208,\n          \"logprob\": -0.8930664,\n          \"special\": false,\n          \"text\": \"</strong>\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" for the following:\\n\\n* <strong>Name</strong>\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 604,\n          \"logprob\": -2.4179688,\n          \"special\": false,\n          \"text\": \" for\"\n        },\n        {\n          \"id\": 573,\n          \"logprob\": -2.4492188,\n          \"special\": false,\n          \"text\": \" the\"\n        },\n        {\n          \"id\": 2412,\n          \"logprob\": -2.8574219,\n          \"special\": false,\n         
 \"text\": \" following\"\n        },\n        {\n          \"id\": 235292,\n          \"logprob\": -2.1445312,\n          \"special\": false,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 109,\n          \"logprob\": -0.7519531,\n          \"special\": false,\n          \"text\": \"\\n\\n\"\n        },\n        {\n          \"id\": 235287,\n          \"logprob\": -1.3623047,\n          \"special\": false,\n          \"text\": \"*\"\n        },\n        {\n          \"id\": 235248,\n          \"logprob\": -1.9707031,\n          \"special\": false,\n          \"text\": \" \"\n        },\n        {\n          \"id\": 199,\n          \"logprob\": -1.4267578,\n          \"special\": false,\n          \"text\": \"<strong>\"\n        },\n        {\n          \"id\": 1232,\n          \"logprob\": -4.3359375,\n          \"special\": false,\n          \"text\": \"Name\"\n        },\n        {\n          \"id\": 208,\n          \"logprob\": -0.88427734,\n          \"special\": false,\n          \"text\": \"</strong>\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" for the following:\\n\\n* <strong>Name</strong>\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_gpt2/test_flash_gpt2.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 198,\n        \"logprob\": -0.68603516,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 198,\n        \"logprob\": -0.005393982,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 29744,\n        \"logprob\": -0.31079102,\n        \"special\": false,\n        \"text\": \"Deep\"\n      },\n      {\n        \"id\": 4673,\n        \"logprob\": -0.08300781,\n        \"special\": false,\n        \"text\": \" learning\"\n      },\n      {\n        \"id\": 318,\n        \"logprob\": -0.58984375,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 257,\n        \"logprob\": -0.953125,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 649,\n        \"logprob\": -2.0957031,\n        \"special\": false,\n        \"text\": \" new\"\n      },\n      {\n        \"id\": 2214,\n        \"logprob\": -1.8095703,\n        \"special\": false,\n        \"text\": \" field\"\n      },\n      {\n        \"id\": 286,\n        \"logprob\": -1.0673828,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 2267,\n        \"logprob\": -0.9375,\n        \"special\": false,\n        \"text\": \" research\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"\\n\\nDeep learning is a new field of research\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_gpt2/test_flash_gpt2_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 198,\n          \"logprob\": -0.68603516,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 198,\n          \"logprob\": -0.005672455,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 29744,\n          \"logprob\": -0.3251953,\n          \"special\": false,\n          \"text\": \"Deep\"\n        },\n        {\n          \"id\": 4673,\n          \"logprob\": -0.08294678,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 318,\n          \"logprob\": -0.5854492,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 257,\n          \"logprob\": -0.9423828,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 649,\n          \"logprob\": -2.0800781,\n          \"special\": false,\n          \"text\": \" new\"\n        },\n        {\n          \"id\": 2214,\n          \"logprob\": -1.8369141,\n          \"special\": false,\n          \"text\": \" field\"\n        },\n        {\n          \"id\": 286,\n          \"logprob\": -1.0683594,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 2267,\n          \"logprob\": -0.9711914,\n          \"special\": false,\n          \"text\": \" research\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n\\nDeep learning is a new field of research\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": 
[\n        {\n          \"id\": 198,\n          \"logprob\": -0.7089844,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 198,\n          \"logprob\": -0.0054779053,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 29744,\n          \"logprob\": -0.3190918,\n          \"special\": false,\n          \"text\": \"Deep\"\n        },\n        {\n          \"id\": 4673,\n          \"logprob\": -0.08319092,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 318,\n          \"logprob\": -0.5839844,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 257,\n          \"logprob\": -0.9506836,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 649,\n          \"logprob\": -2.0878906,\n          \"special\": false,\n          \"text\": \" new\"\n        },\n        {\n          \"id\": 2214,\n          \"logprob\": -1.8496094,\n          \"special\": false,\n          \"text\": \" field\"\n        },\n        {\n          \"id\": 286,\n          \"logprob\": -1.0673828,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 2267,\n          \"logprob\": -0.9370117,\n          \"special\": false,\n          \"text\": \" research\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n\\nDeep learning is a new field of research\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 198,\n          \"logprob\": -0.7089844,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 198,\n          \"logprob\": 
-0.0054779053,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 29744,\n          \"logprob\": -0.3190918,\n          \"special\": false,\n          \"text\": \"Deep\"\n        },\n        {\n          \"id\": 4673,\n          \"logprob\": -0.08319092,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 318,\n          \"logprob\": -0.5839844,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 257,\n          \"logprob\": -0.9506836,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 649,\n          \"logprob\": -2.0878906,\n          \"special\": false,\n          \"text\": \" new\"\n        },\n        {\n          \"id\": 2214,\n          \"logprob\": -1.8496094,\n          \"special\": false,\n          \"text\": \" field\"\n        },\n        {\n          \"id\": 286,\n          \"logprob\": -1.0673828,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 2267,\n          \"logprob\": -0.9370117,\n          \"special\": false,\n          \"text\": \" research\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n\\nDeep learning is a new field of research\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 198,\n          \"logprob\": -0.7089844,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 198,\n          \"logprob\": -0.0054779053,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 29744,\n          \"logprob\": -0.3190918,\n          \"special\": false,\n          \"text\": 
\"Deep\"\n        },\n        {\n          \"id\": 4673,\n          \"logprob\": -0.08319092,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 318,\n          \"logprob\": -0.5839844,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 257,\n          \"logprob\": -0.9506836,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 649,\n          \"logprob\": -2.0878906,\n          \"special\": false,\n          \"text\": \" new\"\n        },\n        {\n          \"id\": 2214,\n          \"logprob\": -1.8496094,\n          \"special\": false,\n          \"text\": \" field\"\n        },\n        {\n          \"id\": 286,\n          \"logprob\": -1.0673828,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 2267,\n          \"logprob\": -0.9370117,\n          \"special\": false,\n          \"text\": \" research\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n\\nDeep learning is a new field of research\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_grammar_llama/test_flash_llama_grammar.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 13,\n        \"logprob\": -2.0566406,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -1.5253906,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 29902,\n        \"logprob\": -2.7578125,\n        \"special\": false,\n        \"text\": \"I\"\n      },\n      {\n        \"id\": 4966,\n        \"logprob\": -1.9033203,\n        \"special\": false,\n        \"text\": \" hope\"\n      },\n      {\n        \"id\": 445,\n        \"logprob\": -0.5019531,\n        \"special\": false,\n        \"text\": \" this\"\n      },\n      {\n        \"id\": 6911,\n        \"logprob\": -0.21264648,\n        \"special\": false,\n        \"text\": \" helps\"\n      },\n      {\n        \"id\": 29991,\n        \"logprob\": -0.5991211,\n        \"special\": false,\n        \"text\": \"!\"\n      },\n      {\n        \"id\": 2803,\n        \"logprob\": -0.37475586,\n        \"special\": false,\n        \"text\": \" Let\"\n      },\n      {\n        \"id\": 592,\n        \"logprob\": -0.018463135,\n        \"special\": false,\n        \"text\": \" me\"\n      },\n      {\n        \"id\": 1073,\n        \"logprob\": -0.0008597374,\n        \"special\": false,\n        \"text\": \" know\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"\\n\\nI hope this helps! Let me know\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_grammar_llama/test_flash_llama_grammar_json.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"eos_token\",\n    \"generated_tokens\": 30,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 6377,\n        \"logprob\": -0.14916992,\n        \"special\": false,\n        \"text\": \"{\\\"\"\n      },\n      {\n        \"id\": 29888,\n        \"logprob\": -0.13598633,\n        \"special\": false,\n        \"text\": \"f\"\n      },\n      {\n        \"id\": 12935,\n        \"logprob\": -0.017669678,\n        \"special\": false,\n        \"text\": \"irs\"\n      },\n      {\n        \"id\": 29873,\n        \"logprob\": -0.00085639954,\n        \"special\": false,\n        \"text\": \"t\"\n      },\n      {\n        \"id\": 1170,\n        \"logprob\": -0.0054016113,\n        \"special\": false,\n        \"text\": \"Name\"\n      },\n      {\n        \"id\": 4710,\n        \"logprob\": -0.13549805,\n        \"special\": false,\n        \"text\": \"\\\":\\\"\"\n      },\n      {\n        \"id\": 19504,\n        \"logprob\": -0.8852539,\n        \"special\": false,\n        \"text\": \"David\"\n      },\n      {\n        \"id\": 3284,\n        \"logprob\": -0.16394043,\n        \"special\": false,\n        \"text\": \"\\\",\\\"\"\n      },\n      {\n        \"id\": 29882,\n        \"logprob\": -0.08862305,\n        \"special\": false,\n        \"text\": \"h\"\n      },\n      {\n        \"id\": 711,\n        \"logprob\": -0.66259766,\n        \"special\": false,\n        \"text\": \"ob\"\n      },\n      {\n        \"id\": 1609,\n        \"logprob\": -5.51939e-05,\n        \"special\": false,\n        \"text\": \"by\"\n      },\n      {\n        \"id\": 4710,\n        \"logprob\": -0.23120117,\n        \"special\": false,\n        \"text\": \"\\\":\\\"\"\n      },\n      {\n        \"id\": 29911,\n        \"logprob\": -2.3730469,\n        \"special\": false,\n        \"text\": \"T\"\n      },\n      {\n        \"id\": 11003,\n        
\"logprob\": -0.032104492,\n        \"special\": false,\n        \"text\": \"rees\"\n      },\n      {\n        \"id\": 3284,\n        \"logprob\": -0.22021484,\n        \"special\": false,\n        \"text\": \"\\\",\\\"\"\n      },\n      {\n        \"id\": 4230,\n        \"logprob\": -0.06726074,\n        \"special\": false,\n        \"text\": \"last\"\n      },\n      {\n        \"id\": 1170,\n        \"logprob\": -0.003501892,\n        \"special\": false,\n        \"text\": \"Name\"\n      },\n      {\n        \"id\": 4710,\n        \"logprob\": -0.0045661926,\n        \"special\": false,\n        \"text\": \"\\\":\\\"\"\n      },\n      {\n        \"id\": 29950,\n        \"logprob\": -0.12512207,\n        \"special\": false,\n        \"text\": \"H\"\n      },\n      {\n        \"id\": 14339,\n        \"logprob\": -0.009552002,\n        \"special\": false,\n        \"text\": \"olt\"\n      },\n      {\n        \"id\": 29920,\n        \"logprob\": -0.00042438507,\n        \"special\": false,\n        \"text\": \"z\"\n      },\n      {\n        \"id\": 3284,\n        \"logprob\": -0.11651611,\n        \"special\": false,\n        \"text\": \"\\\",\\\"\"\n      },\n      {\n        \"id\": 29876,\n        \"logprob\": -0.29736328,\n        \"special\": false,\n        \"text\": \"n\"\n      },\n      {\n        \"id\": 398,\n        \"logprob\": -0.003030777,\n        \"special\": false,\n        \"text\": \"um\"\n      },\n      {\n        \"id\": 29907,\n        \"logprob\": -0.3774414,\n        \"special\": false,\n        \"text\": \"C\"\n      },\n      {\n        \"id\": 1446,\n        \"logprob\": -0.0003130436,\n        \"special\": false,\n        \"text\": \"ats\"\n      },\n      {\n        \"id\": 1115,\n        \"logprob\": -0.0021514893,\n        \"special\": false,\n        \"text\": \"\\\":\"\n      },\n      {\n        \"id\": 29906,\n        \"logprob\": -0.071899414,\n        \"special\": false,\n        \"text\": \"2\"\n      },\n      {\n      
  \"id\": 29913,\n        \"logprob\": -0.018997192,\n        \"special\": false,\n        \"text\": \"}\"\n      },\n      {\n        \"id\": 2,\n        \"logprob\": 0.0,\n        \"special\": true,\n        \"text\": \"</s>\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"{\\\"firstName\\\":\\\"David\\\",\\\"hobby\\\":\\\"Trees\\\",\\\"lastName\\\":\\\"Holtz\\\",\\\"numCats\\\":2}\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_grammar_llama/test_flash_llama_grammar_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 29896,\n          \"logprob\": -0.7709961,\n          \"special\": false,\n          \"text\": \"1\"\n        },\n        {\n          \"id\": 29906,\n          \"logprob\": -0.33740234,\n          \"special\": false,\n          \"text\": \"2\"\n        },\n        {\n          \"id\": 29941,\n          \"logprob\": -0.00995636,\n          \"special\": false,\n          \"text\": \"3\"\n        },\n        {\n          \"id\": 29946,\n          \"logprob\": -0.64208984,\n          \"special\": false,\n          \"text\": \"4\"\n        },\n        {\n          \"id\": 29945,\n          \"logprob\": -0.4970703,\n          \"special\": false,\n          \"text\": \"5\"\n        },\n        {\n          \"id\": 29953,\n          \"logprob\": -0.46533203,\n          \"special\": false,\n          \"text\": \"6\"\n        },\n        {\n          \"id\": 29992,\n          \"logprob\": -0.5336914,\n          \"special\": false,\n          \"text\": \"@\"\n        },\n        {\n          \"id\": 21980,\n          \"logprob\": -0.5361328,\n          \"special\": false,\n          \"text\": \"gmail\"\n        },\n        {\n          \"id\": 29889,\n          \"logprob\": -0.00088739395,\n          \"special\": false,\n          \"text\": \".\"\n        },\n        {\n          \"id\": 510,\n          \"logprob\": -0.0022735596,\n          \"special\": false,\n          \"text\": \"com\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"123456@gmail.com\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 29896,\n  
        \"logprob\": -0.7685547,\n          \"special\": false,\n          \"text\": \"1\"\n        },\n        {\n          \"id\": 29906,\n          \"logprob\": -0.33666992,\n          \"special\": false,\n          \"text\": \"2\"\n        },\n        {\n          \"id\": 29941,\n          \"logprob\": -0.01008606,\n          \"special\": false,\n          \"text\": \"3\"\n        },\n        {\n          \"id\": 29946,\n          \"logprob\": -0.64160156,\n          \"special\": false,\n          \"text\": \"4\"\n        },\n        {\n          \"id\": 29945,\n          \"logprob\": -0.5,\n          \"special\": false,\n          \"text\": \"5\"\n        },\n        {\n          \"id\": 29953,\n          \"logprob\": -0.46557617,\n          \"special\": false,\n          \"text\": \"6\"\n        },\n        {\n          \"id\": 29992,\n          \"logprob\": -0.5341797,\n          \"special\": false,\n          \"text\": \"@\"\n        },\n        {\n          \"id\": 21980,\n          \"logprob\": -0.5361328,\n          \"special\": false,\n          \"text\": \"gmail\"\n        },\n        {\n          \"id\": 29889,\n          \"logprob\": -0.00088739395,\n          \"special\": false,\n          \"text\": \".\"\n        },\n        {\n          \"id\": 510,\n          \"logprob\": -0.0022907257,\n          \"special\": false,\n          \"text\": \"com\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"123456@gmail.com\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 29896,\n          \"logprob\": -0.7709961,\n          \"special\": false,\n          \"text\": \"1\"\n        },\n        {\n          \"id\": 29906,\n          \"logprob\": -0.33740234,\n          \"special\": false,\n          \"text\": \"2\"\n        },\n        {\n     
     \"id\": 29941,\n          \"logprob\": -0.00995636,\n          \"special\": false,\n          \"text\": \"3\"\n        },\n        {\n          \"id\": 29946,\n          \"logprob\": -0.64208984,\n          \"special\": false,\n          \"text\": \"4\"\n        },\n        {\n          \"id\": 29945,\n          \"logprob\": -0.4970703,\n          \"special\": false,\n          \"text\": \"5\"\n        },\n        {\n          \"id\": 29953,\n          \"logprob\": -0.46533203,\n          \"special\": false,\n          \"text\": \"6\"\n        },\n        {\n          \"id\": 29992,\n          \"logprob\": -0.5336914,\n          \"special\": false,\n          \"text\": \"@\"\n        },\n        {\n          \"id\": 21980,\n          \"logprob\": -0.5361328,\n          \"special\": false,\n          \"text\": \"gmail\"\n        },\n        {\n          \"id\": 29889,\n          \"logprob\": -0.00088739395,\n          \"special\": false,\n          \"text\": \".\"\n        },\n        {\n          \"id\": 510,\n          \"logprob\": -0.0022735596,\n          \"special\": false,\n          \"text\": \"com\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"123456@gmail.com\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 29896,\n          \"logprob\": -0.7709961,\n          \"special\": false,\n          \"text\": \"1\"\n        },\n        {\n          \"id\": 29906,\n          \"logprob\": -0.33740234,\n          \"special\": false,\n          \"text\": \"2\"\n        },\n        {\n          \"id\": 29941,\n          \"logprob\": -0.00995636,\n          \"special\": false,\n          \"text\": \"3\"\n        },\n        {\n          \"id\": 29946,\n          \"logprob\": -0.64208984,\n          \"special\": false,\n          \"text\": 
\"4\"\n        },\n        {\n          \"id\": 29945,\n          \"logprob\": -0.4970703,\n          \"special\": false,\n          \"text\": \"5\"\n        },\n        {\n          \"id\": 29953,\n          \"logprob\": -0.46533203,\n          \"special\": false,\n          \"text\": \"6\"\n        },\n        {\n          \"id\": 29992,\n          \"logprob\": -0.5336914,\n          \"special\": false,\n          \"text\": \"@\"\n        },\n        {\n          \"id\": 21980,\n          \"logprob\": -0.5361328,\n          \"special\": false,\n          \"text\": \"gmail\"\n        },\n        {\n          \"id\": 29889,\n          \"logprob\": -0.00088739395,\n          \"special\": false,\n          \"text\": \".\"\n        },\n        {\n          \"id\": 510,\n          \"logprob\": -0.0022735596,\n          \"special\": false,\n          \"text\": \"com\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"123456@gmail.com\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_grammar_llama/test_flash_llama_grammar_regex.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 29946,\n        \"logprob\": -1.4765625,\n        \"special\": false,\n        \"text\": \"4\"\n      },\n      {\n        \"id\": 29906,\n        \"logprob\": -0.9199219,\n        \"special\": false,\n        \"text\": \"2\"\n      },\n      {\n        \"id\": 29889,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 29896,\n        \"logprob\": -1.1367188,\n        \"special\": false,\n        \"text\": \"1\"\n      },\n      {\n        \"id\": 29889,\n        \"logprob\": -1.4648438,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 29896,\n        \"logprob\": -0.40722656,\n        \"special\": false,\n        \"text\": \"1\"\n      },\n      {\n        \"id\": 29889,\n        \"logprob\": -0.17419434,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 29896,\n        \"logprob\": -0.20251465,\n        \"special\": false,\n        \"text\": \"1\"\n      },\n      {\n        \"id\": 29900,\n        \"logprob\": -1.5527344,\n        \"special\": false,\n        \"text\": \"0\"\n      },\n      {\n        \"id\": 29896,\n        \"logprob\": -1.3710938,\n        \"special\": false,\n        \"text\": \"1\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"42.1.1.101\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_grammar_llama/test_flash_llama_grammar_single_load_instance.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 29896,\n        \"logprob\": -0.7685547,\n        \"special\": false,\n        \"text\": \"1\"\n      },\n      {\n        \"id\": 29906,\n        \"logprob\": -0.33666992,\n        \"special\": false,\n        \"text\": \"2\"\n      },\n      {\n        \"id\": 29941,\n        \"logprob\": -0.009979248,\n        \"special\": false,\n        \"text\": \"3\"\n      },\n      {\n        \"id\": 29946,\n        \"logprob\": -0.64208984,\n        \"special\": false,\n        \"text\": \"4\"\n      },\n      {\n        \"id\": 29945,\n        \"logprob\": -0.4970703,\n        \"special\": false,\n        \"text\": \"5\"\n      },\n      {\n        \"id\": 29953,\n        \"logprob\": -0.46533203,\n        \"special\": false,\n        \"text\": \"6\"\n      },\n      {\n        \"id\": 29992,\n        \"logprob\": -0.5336914,\n        \"special\": false,\n        \"text\": \"@\"\n      },\n      {\n        \"id\": 21980,\n        \"logprob\": -0.53759766,\n        \"special\": false,\n        \"text\": \"gmail\"\n      },\n      {\n        \"id\": 29889,\n        \"logprob\": -0.0008878708,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 510,\n        \"logprob\": -0.002275467,\n        \"special\": false,\n        \"text\": \"com\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"123456@gmail.com\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_all_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"stop_sequence\",\n    \"generated_tokens\": 5,\n    \"prefill\": [],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 5229,\n        \"logprob\": -2.5839844,\n        \"special\": false,\n        \"text\": \" failed\"\n      },\n      {\n        \"id\": 29901,\n        \"logprob\": -0.44970703,\n        \"special\": false,\n        \"text\": \":\"\n      },\n      {\n        \"id\": 4829,\n        \"logprob\": -1.8339844,\n        \"special\": false,\n        \"text\": \" Error\"\n      },\n      {\n        \"id\": 297,\n        \"logprob\": -1.0556641,\n        \"special\": false,\n        \"text\": \" in\"\n      },\n      {\n        \"id\": 1243,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" test\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"Test request failed: Error in test\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 363,\n          \"logprob\": -1.5351562,\n          \"special\": false,\n          \"text\": \" for\"\n        },\n        {\n          \"id\": 847,\n          \"logprob\": -2.5566406,\n          \"special\": false,\n          \"text\": \" /\"\n        },\n        {\n          \"id\": 2754,\n          \"logprob\": -2.2519531,\n          \"special\": false,\n          \"text\": \"api\"\n        },\n        {\n          \"id\": 29914,\n          \"logprob\": -0.03414917,\n          \"special\": false,\n          \"text\": \"/\"\n        },\n        {\n          \"id\": 29894,\n          \"logprob\": -0.96240234,\n          \"special\": false,\n          \"text\": \"v\"\n        },\n        {\n          \"id\": 29896,\n          \"logprob\": -0.3647461,\n          \"special\": false,\n          \"text\": \"1\"\n        },\n        {\n          \"id\": 29914,\n          \"logprob\": -0.012901306,\n          \"special\": false,\n          \"text\": \"/\"\n        },\n        {\n          \"id\": 16418,\n          \"logprob\": -3.1542969,\n          \"special\": false,\n          \"text\": \"projects\"\n        },\n        {\n          \"id\": 29914,\n          \"logprob\": -0.4362793,\n          \"special\": false,\n          \"text\": \"/\"\n        },\n        {\n          \"id\": 29896,\n          \"logprob\": -1.9394531,\n          \"special\": false,\n          \"text\": \"1\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" for /api/v1/projects/1\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 
363,\n          \"logprob\": -1.5332031,\n          \"special\": false,\n          \"text\": \" for\"\n        },\n        {\n          \"id\": 847,\n          \"logprob\": -2.5625,\n          \"special\": false,\n          \"text\": \" /\"\n        },\n        {\n          \"id\": 2754,\n          \"logprob\": -2.2617188,\n          \"special\": false,\n          \"text\": \"api\"\n        },\n        {\n          \"id\": 29914,\n          \"logprob\": -0.033996582,\n          \"special\": false,\n          \"text\": \"/\"\n        },\n        {\n          \"id\": 29894,\n          \"logprob\": -0.9609375,\n          \"special\": false,\n          \"text\": \"v\"\n        },\n        {\n          \"id\": 29896,\n          \"logprob\": -0.36572266,\n          \"special\": false,\n          \"text\": \"1\"\n        },\n        {\n          \"id\": 29914,\n          \"logprob\": -0.0129776,\n          \"special\": false,\n          \"text\": \"/\"\n        },\n        {\n          \"id\": 16418,\n          \"logprob\": -3.15625,\n          \"special\": false,\n          \"text\": \"projects\"\n        },\n        {\n          \"id\": 29914,\n          \"logprob\": -0.4362793,\n          \"special\": false,\n          \"text\": \"/\"\n        },\n        {\n          \"id\": 29896,\n          \"logprob\": -1.9394531,\n          \"special\": false,\n          \"text\": \"1\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" for /api/v1/projects/1\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 363,\n          \"logprob\": -1.5332031,\n          \"special\": false,\n          \"text\": \" for\"\n        },\n        {\n          \"id\": 847,\n          \"logprob\": -2.5625,\n          \"special\": false,\n          \"text\": \" /\"\n        },\n      
  {\n          \"id\": 2754,\n          \"logprob\": -2.2617188,\n          \"special\": false,\n          \"text\": \"api\"\n        },\n        {\n          \"id\": 29914,\n          \"logprob\": -0.033996582,\n          \"special\": false,\n          \"text\": \"/\"\n        },\n        {\n          \"id\": 29894,\n          \"logprob\": -0.9609375,\n          \"special\": false,\n          \"text\": \"v\"\n        },\n        {\n          \"id\": 29896,\n          \"logprob\": -0.36572266,\n          \"special\": false,\n          \"text\": \"1\"\n        },\n        {\n          \"id\": 29914,\n          \"logprob\": -0.0129776,\n          \"special\": false,\n          \"text\": \"/\"\n        },\n        {\n          \"id\": 16418,\n          \"logprob\": -3.15625,\n          \"special\": false,\n          \"text\": \"projects\"\n        },\n        {\n          \"id\": 29914,\n          \"logprob\": -0.4362793,\n          \"special\": false,\n          \"text\": \"/\"\n        },\n        {\n          \"id\": 29896,\n          \"logprob\": -1.9394531,\n          \"special\": false,\n          \"text\": \"1\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" for /api/v1/projects/1\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 363,\n          \"logprob\": -1.5332031,\n          \"special\": false,\n          \"text\": \" for\"\n        },\n        {\n          \"id\": 847,\n          \"logprob\": -2.5625,\n          \"special\": false,\n          \"text\": \" /\"\n        },\n        {\n          \"id\": 2754,\n          \"logprob\": -2.2617188,\n          \"special\": false,\n          \"text\": \"api\"\n        },\n        {\n          \"id\": 29914,\n          \"logprob\": -0.033996582,\n          \"special\": false,\n          
\"text\": \"/\"\n        },\n        {\n          \"id\": 29894,\n          \"logprob\": -0.9609375,\n          \"special\": false,\n          \"text\": \"v\"\n        },\n        {\n          \"id\": 29896,\n          \"logprob\": -0.36572266,\n          \"special\": false,\n          \"text\": \"1\"\n        },\n        {\n          \"id\": 29914,\n          \"logprob\": -0.0129776,\n          \"special\": false,\n          \"text\": \"/\"\n        },\n        {\n          \"id\": 16418,\n          \"logprob\": -3.15625,\n          \"special\": false,\n          \"text\": \"projects\"\n        },\n        {\n          \"id\": 29914,\n          \"logprob\": -0.4362793,\n          \"special\": false,\n          \"text\": \"/\"\n        },\n        {\n          \"id\": 29896,\n          \"logprob\": -1.9394531,\n          \"special\": false,\n          \"text\": \"1\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" for /api/v1/projects/1\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_llama/test_flash_llama_simple.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 363,\n        \"logprob\": -1.5351562,\n        \"special\": false,\n        \"text\": \" for\"\n      },\n      {\n        \"id\": 847,\n        \"logprob\": -2.5722656,\n        \"special\": false,\n        \"text\": \" /\"\n      },\n      {\n        \"id\": 2754,\n        \"logprob\": -2.2714844,\n        \"special\": false,\n        \"text\": \"api\"\n      },\n      {\n        \"id\": 29914,\n        \"logprob\": -0.03414917,\n        \"special\": false,\n        \"text\": \"/\"\n      },\n      {\n        \"id\": 29894,\n        \"logprob\": -0.95996094,\n        \"special\": false,\n        \"text\": \"v\"\n      },\n      {\n        \"id\": 29896,\n        \"logprob\": -0.3635254,\n        \"special\": false,\n        \"text\": \"1\"\n      },\n      {\n        \"id\": 29914,\n        \"logprob\": -0.013031006,\n        \"special\": false,\n        \"text\": \"/\"\n      },\n      {\n        \"id\": 16418,\n        \"logprob\": -3.1523438,\n        \"special\": false,\n        \"text\": \"projects\"\n      },\n      {\n        \"id\": 29914,\n        \"logprob\": -0.43701172,\n        \"special\": false,\n        \"text\": \"/\"\n      },\n      {\n        \"id\": 29896,\n        \"logprob\": -1.9394531,\n        \"special\": false,\n        \"text\": \"1\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \" for /api/v1/projects/1\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 25,\n        \"logprob\": -2.9316406,\n        \"special\": false,\n        \"text\": \":\"\n      },\n      {\n        \"id\": 330,\n        \"logprob\": -3.5136719,\n        \"special\": false,\n        \"text\": \" \\\"\"\n      },\n      {\n        \"id\": 489,\n        \"logprob\": -0.7783203,\n        \"special\": false,\n        \"text\": \" +\"\n      },\n      {\n        \"id\": 1715,\n        \"logprob\": -1.2314453,\n        \"special\": false,\n        \"text\": \" request\"\n      },\n      {\n        \"id\": 489,\n        \"logprob\": -2.0019531,\n        \"special\": false,\n        \"text\": \" +\"\n      },\n      {\n        \"id\": 2990,\n        \"logprob\": -1.5009766,\n        \"special\": false,\n        \"text\": \" \\\"\\\\\"\n      },\n      {\n        \"id\": 77,\n        \"logprob\": -0.057434082,\n        \"special\": false,\n        \"text\": \"n\"\n      },\n      {\n        \"id\": 702,\n        \"logprob\": -1.4912109,\n        \"special\": false,\n        \"text\": \"\\\"\\n\"\n      },\n      {\n        \"id\": 262,\n        \"logprob\": -1.2636719,\n        \"special\": false,\n        \"text\": \"   \"\n      },\n      {\n        \"id\": 557,\n        \"logprob\": -2.4042969,\n        \"special\": false,\n        \"text\": \" }\\n\\n\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \": \\\" + request + \\\"\\\\n\\\"\\n    }\\n\\n\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2_all_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 13,\n        \"logprob\": -1.9980469,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 578,\n        \"logprob\": -0.15795898,\n        \"special\": false,\n        \"text\": \" The\"\n      },\n      {\n        \"id\": 3622,\n        \"logprob\": -1.0458984,\n        \"special\": false,\n        \"text\": \" server\"\n      },\n      {\n        \"id\": 31680,\n        \"logprob\": -1.3623047,\n        \"special\": false,\n        \"text\": \" responds\"\n      },\n      {\n        \"id\": 449,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" with\"\n      },\n      {\n        \"id\": 264,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 330,\n        \"logprob\": -0.5678711,\n        \"special\": false,\n        \"text\": \" \\\"\"\n      },\n      {\n        \"id\": 1049,\n        \"logprob\": -0.12322998,\n        \"special\": false,\n        \"text\": \"200\"\n      },\n      {\n        \"id\": 10619,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" OK\"\n      },\n      {\n        \"id\": 1,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"\\\"\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"Test request. The server responds with a \\\"200 OK\\\"\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_llama_exl2/test_flash_llama_exl2_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 25,\n          \"logprob\": -2.9785156,\n          \"special\": false,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 330,\n          \"logprob\": -3.4941406,\n          \"special\": false,\n          \"text\": \" \\\"\"\n        },\n        {\n          \"id\": 489,\n          \"logprob\": -0.79345703,\n          \"special\": false,\n          \"text\": \" +\"\n        },\n        {\n          \"id\": 1715,\n          \"logprob\": -1.2324219,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 489,\n          \"logprob\": -1.9794922,\n          \"special\": false,\n          \"text\": \" +\"\n        },\n        {\n          \"id\": 2990,\n          \"logprob\": -1.4892578,\n          \"special\": false,\n          \"text\": \" \\\"\\\\\"\n        },\n        {\n          \"id\": 77,\n          \"logprob\": -0.058258057,\n          \"special\": false,\n          \"text\": \"n\"\n        },\n        {\n          \"id\": 702,\n          \"logprob\": -1.4892578,\n          \"special\": false,\n          \"text\": \"\\\"\\n\"\n        },\n        {\n          \"id\": 262,\n          \"logprob\": -1.2783203,\n          \"special\": false,\n          \"text\": \"   \"\n        },\n        {\n          \"id\": 557,\n          \"logprob\": -2.3945312,\n          \"special\": false,\n          \"text\": \" }\\n\\n\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \": \\\" + request + \\\"\\\\n\\\"\\n    }\\n\\n\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n     
   {\n          \"id\": 25,\n          \"logprob\": -2.9433594,\n          \"special\": false,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 330,\n          \"logprob\": -3.4726562,\n          \"special\": false,\n          \"text\": \" \\\"\"\n        },\n        {\n          \"id\": 489,\n          \"logprob\": -0.8022461,\n          \"special\": false,\n          \"text\": \" +\"\n        },\n        {\n          \"id\": 1715,\n          \"logprob\": -1.2509766,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 489,\n          \"logprob\": -1.984375,\n          \"special\": false,\n          \"text\": \" +\"\n        },\n        {\n          \"id\": 2990,\n          \"logprob\": -1.4677734,\n          \"special\": false,\n          \"text\": \" \\\"\\\\\"\n        },\n        {\n          \"id\": 77,\n          \"logprob\": -0.059173584,\n          \"special\": false,\n          \"text\": \"n\"\n        },\n        {\n          \"id\": 702,\n          \"logprob\": -1.4990234,\n          \"special\": false,\n          \"text\": \"\\\"\\n\"\n        },\n        {\n          \"id\": 262,\n          \"logprob\": -1.2822266,\n          \"special\": false,\n          \"text\": \"   \"\n        },\n        {\n          \"id\": 557,\n          \"logprob\": -2.3867188,\n          \"special\": false,\n          \"text\": \" }\\n\\n\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \": \\\" + request + \\\"\\\\n\\\"\\n    }\\n\\n\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 25,\n          \"logprob\": -2.9511719,\n          \"special\": false,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 330,\n          \"logprob\": -3.46875,\n          
\"special\": false,\n          \"text\": \" \\\"\"\n        },\n        {\n          \"id\": 489,\n          \"logprob\": -0.77490234,\n          \"special\": false,\n          \"text\": \" +\"\n        },\n        {\n          \"id\": 1715,\n          \"logprob\": -1.2558594,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 489,\n          \"logprob\": -1.984375,\n          \"special\": false,\n          \"text\": \" +\"\n        },\n        {\n          \"id\": 2990,\n          \"logprob\": -1.4990234,\n          \"special\": false,\n          \"text\": \" \\\"\\\\\"\n        },\n        {\n          \"id\": 77,\n          \"logprob\": -0.059143066,\n          \"special\": false,\n          \"text\": \"n\"\n        },\n        {\n          \"id\": 702,\n          \"logprob\": -1.4941406,\n          \"special\": false,\n          \"text\": \"\\\"\\n\"\n        },\n        {\n          \"id\": 262,\n          \"logprob\": -1.2578125,\n          \"special\": false,\n          \"text\": \"   \"\n        },\n        {\n          \"id\": 557,\n          \"logprob\": -2.3964844,\n          \"special\": false,\n          \"text\": \" }\\n\\n\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \": \\\" + request + \\\"\\\\n\\\"\\n    }\\n\\n\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 25,\n          \"logprob\": -2.9101562,\n          \"special\": false,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 330,\n          \"logprob\": -3.5039062,\n          \"special\": false,\n          \"text\": \" \\\"\"\n        },\n        {\n          \"id\": 489,\n          \"logprob\": -0.8076172,\n          \"special\": false,\n          \"text\": \" +\"\n        },\n        {\n        
  \"id\": 1715,\n          \"logprob\": -1.2236328,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 489,\n          \"logprob\": -1.9853516,\n          \"special\": false,\n          \"text\": \" +\"\n        },\n        {\n          \"id\": 2990,\n          \"logprob\": -1.4892578,\n          \"special\": false,\n          \"text\": \" \\\"\\\\\"\n        },\n        {\n          \"id\": 77,\n          \"logprob\": -0.056671143,\n          \"special\": false,\n          \"text\": \"n\"\n        },\n        {\n          \"id\": 702,\n          \"logprob\": -1.5107422,\n          \"special\": false,\n          \"text\": \"\\\"\\n\"\n        },\n        {\n          \"id\": 262,\n          \"logprob\": -1.2597656,\n          \"special\": false,\n          \"text\": \"   \"\n        },\n        {\n          \"id\": 557,\n          \"logprob\": -2.4042969,\n          \"special\": false,\n          \"text\": \" }\\n\\n\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \": \\\" + request + \\\"\\\\n\\\"\\n    }\\n\\n\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 369,\n        \"logprob\": -2.1816406,\n        \"special\": false,\n        \"text\": \" for\"\n      },\n      {\n        \"id\": 279,\n        \"logprob\": -2.6992188,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 220,\n        \"logprob\": -3.6308594,\n        \"special\": false,\n        \"text\": \" \"\n      },\n      {\n        \"id\": 679,\n        \"logprob\": -1.7900391,\n        \"special\": false,\n        \"text\": \"201\"\n      },\n      {\n        \"id\": 24,\n        \"logprob\": -1.3554688,\n        \"special\": false,\n        \"text\": \"9\"\n      },\n      {\n        \"id\": 12,\n        \"logprob\": -2.0039062,\n        \"special\": false,\n        \"text\": \"-\"\n      },\n      {\n        \"id\": 2366,\n        \"logprob\": -0.4489746,\n        \"special\": false,\n        \"text\": \"202\"\n      },\n      {\n        \"id\": 15,\n        \"logprob\": -0.037109375,\n        \"special\": false,\n        \"text\": \"0\"\n      },\n      {\n        \"id\": 2978,\n        \"logprob\": -0.8100586,\n        \"special\": false,\n        \"text\": \" school\"\n      },\n      {\n        \"id\": 1060,\n        \"logprob\": -0.013015747,\n        \"special\": false,\n        \"text\": \" year\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \" for the 2019-2020 school year\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 25,\n        \"logprob\": -0.88183594,\n        \"special\": false,\n        \"text\": \":\"\n      },\n      {\n        \"id\": 2209,\n        \"logprob\": -2.6699219,\n        \"special\": false,\n        \"text\": \" Is\"\n      },\n      {\n        \"id\": 279,\n        \"logprob\": -0.61083984,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 734,\n        \"logprob\": -2.6660156,\n        \"special\": false,\n        \"text\": \" function\"\n      },\n      {\n        \"id\": 330,\n        \"logprob\": -0.35498047,\n        \"special\": false,\n        \"text\": \" \\\"\"\n      },\n      {\n        \"id\": 4110,\n        \"logprob\": -2.4101562,\n        \"special\": false,\n        \"text\": \"Create\"\n      },\n      {\n        \"id\": 7575,\n        \"logprob\": -2.2304688,\n        \"special\": false,\n        \"text\": \"Process\"\n      },\n      {\n        \"id\": 1,\n        \"logprob\": -0.080078125,\n        \"special\": false,\n        \"text\": \"\\\"\"\n      },\n      {\n        \"id\": 304,\n        \"logprob\": -0.75439453,\n        \"special\": false,\n        \"text\": \" in\"\n      },\n      {\n        \"id\": 12468,\n        \"logprob\": -1.8769531,\n        \"special\": false,\n        \"text\": \" Win\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"Test request: Is the function \\\"CreateProcess\\\" in Win\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 369,\n          \"logprob\": -2.15625,\n          \"special\": false,\n          \"text\": \" for\"\n        },\n        {\n          \"id\": 279,\n          \"logprob\": -2.703125,\n          \"special\": false,\n          \"text\": \" the\"\n        },\n        {\n          \"id\": 220,\n          \"logprob\": -3.640625,\n          \"special\": false,\n          \"text\": \" \"\n        },\n        {\n          \"id\": 679,\n          \"logprob\": -1.703125,\n          \"special\": false,\n          \"text\": \"201\"\n        },\n        {\n          \"id\": 24,\n          \"logprob\": -1.421875,\n          \"special\": false,\n          \"text\": \"9\"\n        },\n        {\n          \"id\": 12,\n          \"logprob\": -2.03125,\n          \"special\": false,\n          \"text\": \"-\"\n        },\n        {\n          \"id\": 2366,\n          \"logprob\": -0.49023438,\n          \"special\": false,\n          \"text\": \"202\"\n        },\n        {\n          \"id\": 15,\n          \"logprob\": -0.041503906,\n          \"special\": false,\n          \"text\": \"0\"\n        },\n        {\n          \"id\": 2978,\n          \"logprob\": -0.87109375,\n          \"special\": false,\n          \"text\": \" school\"\n        },\n        {\n          \"id\": 1060,\n          \"logprob\": -0.012939453,\n          \"special\": false,\n          \"text\": \" year\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" for the 2019-2020 school year\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 369,\n       
   \"logprob\": -2.15625,\n          \"special\": false,\n          \"text\": \" for\"\n        },\n        {\n          \"id\": 279,\n          \"logprob\": -2.703125,\n          \"special\": false,\n          \"text\": \" the\"\n        },\n        {\n          \"id\": 220,\n          \"logprob\": -3.640625,\n          \"special\": false,\n          \"text\": \" \"\n        },\n        {\n          \"id\": 679,\n          \"logprob\": -1.703125,\n          \"special\": false,\n          \"text\": \"201\"\n        },\n        {\n          \"id\": 24,\n          \"logprob\": -1.421875,\n          \"special\": false,\n          \"text\": \"9\"\n        },\n        {\n          \"id\": 12,\n          \"logprob\": -2.03125,\n          \"special\": false,\n          \"text\": \"-\"\n        },\n        {\n          \"id\": 2366,\n          \"logprob\": -0.49023438,\n          \"special\": false,\n          \"text\": \"202\"\n        },\n        {\n          \"id\": 15,\n          \"logprob\": -0.041503906,\n          \"special\": false,\n          \"text\": \"0\"\n        },\n        {\n          \"id\": 2978,\n          \"logprob\": -0.87109375,\n          \"special\": false,\n          \"text\": \" school\"\n        },\n        {\n          \"id\": 1060,\n          \"logprob\": -0.012939453,\n          \"special\": false,\n          \"text\": \" year\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" for the 2019-2020 school year\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 369,\n          \"logprob\": -2.15625,\n          \"special\": false,\n          \"text\": \" for\"\n        },\n        {\n          \"id\": 279,\n          \"logprob\": -2.703125,\n          \"special\": false,\n          \"text\": \" the\"\n        },\n        {\n       
   \"id\": 220,\n          \"logprob\": -3.640625,\n          \"special\": false,\n          \"text\": \" \"\n        },\n        {\n          \"id\": 679,\n          \"logprob\": -1.703125,\n          \"special\": false,\n          \"text\": \"201\"\n        },\n        {\n          \"id\": 24,\n          \"logprob\": -1.421875,\n          \"special\": false,\n          \"text\": \"9\"\n        },\n        {\n          \"id\": 12,\n          \"logprob\": -2.03125,\n          \"special\": false,\n          \"text\": \"-\"\n        },\n        {\n          \"id\": 2366,\n          \"logprob\": -0.49023438,\n          \"special\": false,\n          \"text\": \"202\"\n        },\n        {\n          \"id\": 15,\n          \"logprob\": -0.041503906,\n          \"special\": false,\n          \"text\": \"0\"\n        },\n        {\n          \"id\": 2978,\n          \"logprob\": -0.87109375,\n          \"special\": false,\n          \"text\": \" school\"\n        },\n        {\n          \"id\": 1060,\n          \"logprob\": -0.012939453,\n          \"special\": false,\n          \"text\": \" year\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" for the 2019-2020 school year\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 369,\n          \"logprob\": -2.15625,\n          \"special\": false,\n          \"text\": \" for\"\n        },\n        {\n          \"id\": 279,\n          \"logprob\": -2.703125,\n          \"special\": false,\n          \"text\": \" the\"\n        },\n        {\n          \"id\": 220,\n          \"logprob\": -3.640625,\n          \"special\": false,\n          \"text\": \" \"\n        },\n        {\n          \"id\": 679,\n          \"logprob\": -1.703125,\n          \"special\": false,\n          \"text\": \"201\"\n        
},\n        {\n          \"id\": 24,\n          \"logprob\": -1.421875,\n          \"special\": false,\n          \"text\": \"9\"\n        },\n        {\n          \"id\": 12,\n          \"logprob\": -2.03125,\n          \"special\": false,\n          \"text\": \"-\"\n        },\n        {\n          \"id\": 2366,\n          \"logprob\": -0.49023438,\n          \"special\": false,\n          \"text\": \"202\"\n        },\n        {\n          \"id\": 15,\n          \"logprob\": -0.041503906,\n          \"special\": false,\n          \"text\": \"0\"\n        },\n        {\n          \"id\": 2978,\n          \"logprob\": -0.87109375,\n          \"special\": false,\n          \"text\": \" school\"\n        },\n        {\n          \"id\": 1060,\n          \"logprob\": -0.012939453,\n          \"special\": false,\n          \"text\": \" year\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" for the 2019-2020 school year\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 18682,\n        \"logprob\": -1.109375,\n        \"special\": false,\n        \"text\": \" Deep\"\n      },\n      {\n        \"id\": 6975,\n        \"logprob\": -0.005432129,\n        \"special\": false,\n        \"text\": \" learning\"\n      },\n      {\n        \"id\": 374,\n        \"logprob\": -0.028808594,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 264,\n        \"logprob\": -0.013671875,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 27084,\n        \"logprob\": -0.69921875,\n        \"special\": false,\n        \"text\": \" subset\"\n      },\n      {\n        \"id\": 315,\n        \"logprob\": -0.0005874634,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 5780,\n        \"logprob\": -0.026855469,\n        \"special\": false,\n        \"text\": \" machine\"\n      },\n      {\n        \"id\": 6975,\n        \"logprob\": -0.00020885468,\n        \"special\": false,\n        \"text\": \" learning\"\n      },\n      {\n        \"id\": 430,\n        \"logprob\": -0.17773438,\n        \"special\": false,\n        \"text\": \" that\"\n      },\n      {\n        \"id\": 18065,\n        \"logprob\": -0.703125,\n        \"special\": false,\n        \"text\": \" involves\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \" Deep learning is a subset of machine learning that involves\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache_all_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 720,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" \\n\"\n      },\n      {\n        \"id\": 34564,\n        \"logprob\": -0.12512207,\n        \"special\": false,\n        \"text\": \"Deep\"\n      },\n      {\n        \"id\": 6975,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" learning\"\n      },\n      {\n        \"id\": 320,\n        \"logprob\": -0.23840332,\n        \"special\": false,\n        \"text\": \" (\"\n      },\n      {\n        \"id\": 16931,\n        \"logprob\": -2.0175781,\n        \"special\": false,\n        \"text\": \"DL\"\n      },\n      {\n        \"id\": 8,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \")\"\n      },\n      {\n        \"id\": 374,\n        \"logprob\": -0.8613281,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 264,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 1207,\n        \"logprob\": -1.2451172,\n        \"special\": false,\n        \"text\": \" sub\"\n      },\n      {\n        \"id\": 2630,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"field\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"What is deep learning? \\nDeep learning (DL) is a subfield\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_llama_fp8_kv_cache/test_flash_llama_fp8_kv_cache_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 18682,\n          \"logprob\": -1.109375,\n          \"special\": false,\n          \"text\": \" Deep\"\n        },\n        {\n          \"id\": 6975,\n          \"logprob\": -0.0047912598,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 374,\n          \"logprob\": -0.025512695,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.012145996,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 27084,\n          \"logprob\": -0.72265625,\n          \"special\": false,\n          \"text\": \" subset\"\n        },\n        {\n          \"id\": 315,\n          \"logprob\": -0.0005760193,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 5780,\n          \"logprob\": -0.02722168,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 6975,\n          \"logprob\": -0.00023651123,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 430,\n          \"logprob\": -0.17285156,\n          \"special\": false,\n          \"text\": \" that\"\n        },\n        {\n          \"id\": 18065,\n          \"logprob\": -0.703125,\n          \"special\": false,\n          \"text\": \" involves\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" Deep learning is a subset of machine learning that involves\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": 
[],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 18682,\n          \"logprob\": -1.1796875,\n          \"special\": false,\n          \"text\": \" Deep\"\n        },\n        {\n          \"id\": 6975,\n          \"logprob\": -0.005432129,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 374,\n          \"logprob\": -0.02758789,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.013366699,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 27084,\n          \"logprob\": -0.6953125,\n          \"special\": false,\n          \"text\": \" subset\"\n        },\n        {\n          \"id\": 315,\n          \"logprob\": -0.0004863739,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 5780,\n          \"logprob\": -0.02709961,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 6975,\n          \"logprob\": -0.00022506714,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 430,\n          \"logprob\": -0.19726562,\n          \"special\": false,\n          \"text\": \" that\"\n        },\n        {\n          \"id\": 18065,\n          \"logprob\": -0.77734375,\n          \"special\": false,\n          \"text\": \" involves\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" Deep learning is a subset of machine learning that involves\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 18682,\n          \"logprob\": -1.1796875,\n          \"special\": false,\n          \"text\": 
\" Deep\"\n        },\n        {\n          \"id\": 6975,\n          \"logprob\": -0.005432129,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 374,\n          \"logprob\": -0.02758789,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.013366699,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 27084,\n          \"logprob\": -0.6953125,\n          \"special\": false,\n          \"text\": \" subset\"\n        },\n        {\n          \"id\": 315,\n          \"logprob\": -0.0004863739,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 5780,\n          \"logprob\": -0.02709961,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 6975,\n          \"logprob\": -0.00022506714,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 430,\n          \"logprob\": -0.19726562,\n          \"special\": false,\n          \"text\": \" that\"\n        },\n        {\n          \"id\": 18065,\n          \"logprob\": -0.77734375,\n          \"special\": false,\n          \"text\": \" involves\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" Deep learning is a subset of machine learning that involves\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 18682,\n          \"logprob\": -1.1796875,\n          \"special\": false,\n          \"text\": \" Deep\"\n        },\n        {\n          \"id\": 6975,\n          \"logprob\": -0.005432129,\n          \"special\": false,\n          \"text\": \" learning\"\n        
},\n        {\n          \"id\": 374,\n          \"logprob\": -0.02758789,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.013366699,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 27084,\n          \"logprob\": -0.6953125,\n          \"special\": false,\n          \"text\": \" subset\"\n        },\n        {\n          \"id\": 315,\n          \"logprob\": -0.0004863739,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 5780,\n          \"logprob\": -0.02709961,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 6975,\n          \"logprob\": -0.00022506714,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 430,\n          \"logprob\": -0.19726562,\n          \"special\": false,\n          \"text\": \" that\"\n        },\n        {\n          \"id\": 18065,\n          \"logprob\": -0.77734375,\n          \"special\": false,\n          \"text\": \" involves\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" Deep learning is a subset of machine learning that involves\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 198,\n        \"logprob\": -2.5742188,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 262,\n        \"logprob\": -1.6230469,\n        \"special\": false,\n        \"text\": \"   \"\n      },\n      {\n        \"id\": 3270,\n        \"logprob\": -2.046875,\n        \"special\": false,\n        \"text\": \" \\\"\\\"\\\"\\n\"\n      },\n      {\n        \"id\": 262,\n        \"logprob\": -0.015281677,\n        \"special\": false,\n        \"text\": \"   \"\n      },\n      {\n        \"id\": 422,\n        \"logprob\": -2.1425781,\n        \"special\": false,\n        \"text\": \" if\"\n      },\n      {\n        \"id\": 1715,\n        \"logprob\": -0.9238281,\n        \"special\": false,\n        \"text\": \" request\"\n      },\n      {\n        \"id\": 13204,\n        \"logprob\": -0.076660156,\n        \"special\": false,\n        \"text\": \".method\"\n      },\n      {\n        \"id\": 624,\n        \"logprob\": -0.021987915,\n        \"special\": false,\n        \"text\": \" ==\"\n      },\n      {\n        \"id\": 364,\n        \"logprob\": -0.39208984,\n        \"special\": false,\n        \"text\": \" '\"\n      },\n      {\n        \"id\": 3019,\n        \"logprob\": -0.10821533,\n        \"special\": false,\n        \"text\": \"POST\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"\\n    \\\"\\\"\\\"\\n    if request.method == 'POST\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_all_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 13,\n        \"logprob\": -2.2539062,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 578,\n        \"logprob\": -0.15563965,\n        \"special\": false,\n        \"text\": \" The\"\n      },\n      {\n        \"id\": 3622,\n        \"logprob\": -0.8203125,\n        \"special\": false,\n        \"text\": \" server\"\n      },\n      {\n        \"id\": 706,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" has\"\n      },\n      {\n        \"id\": 539,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" not\"\n      },\n      {\n        \"id\": 3686,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" yet\"\n      },\n      {\n        \"id\": 3288,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" sent\"\n      },\n      {\n        \"id\": 904,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" any\"\n      },\n      {\n        \"id\": 828,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" data\"\n      },\n      {\n        \"id\": 382,\n        \"logprob\": -1.5517578,\n        \"special\": false,\n        \"text\": \".\\n\\n\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"Test request. The server has not yet sent any data.\\n\\n\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_llama_gptq/test_flash_llama_gptq_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 198,\n          \"logprob\": -2.5742188,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 262,\n          \"logprob\": -1.6220703,\n          \"special\": false,\n          \"text\": \"   \"\n        },\n        {\n          \"id\": 3270,\n          \"logprob\": -2.0410156,\n          \"special\": false,\n          \"text\": \" \\\"\\\"\\\"\\n\"\n        },\n        {\n          \"id\": 262,\n          \"logprob\": -0.015281677,\n          \"special\": false,\n          \"text\": \"   \"\n        },\n        {\n          \"id\": 422,\n          \"logprob\": -2.1445312,\n          \"special\": false,\n          \"text\": \" if\"\n        },\n        {\n          \"id\": 1715,\n          \"logprob\": -0.92333984,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 13204,\n          \"logprob\": -0.07672119,\n          \"special\": false,\n          \"text\": \".method\"\n        },\n        {\n          \"id\": 624,\n          \"logprob\": -0.021987915,\n          \"special\": false,\n          \"text\": \" ==\"\n        },\n        {\n          \"id\": 364,\n          \"logprob\": -0.39208984,\n          \"special\": false,\n          \"text\": \" '\"\n        },\n        {\n          \"id\": 3019,\n          \"logprob\": -0.10638428,\n          \"special\": false,\n          \"text\": \"POST\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n    \\\"\\\"\\\"\\n    if request.method == 'POST\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n    
  \"tokens\": [\n        {\n          \"id\": 198,\n          \"logprob\": -2.5742188,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 262,\n          \"logprob\": -1.6220703,\n          \"special\": false,\n          \"text\": \"   \"\n        },\n        {\n          \"id\": 3270,\n          \"logprob\": -2.0410156,\n          \"special\": false,\n          \"text\": \" \\\"\\\"\\\"\\n\"\n        },\n        {\n          \"id\": 262,\n          \"logprob\": -0.015281677,\n          \"special\": false,\n          \"text\": \"   \"\n        },\n        {\n          \"id\": 422,\n          \"logprob\": -2.1445312,\n          \"special\": false,\n          \"text\": \" if\"\n        },\n        {\n          \"id\": 1715,\n          \"logprob\": -0.92333984,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 13204,\n          \"logprob\": -0.07672119,\n          \"special\": false,\n          \"text\": \".method\"\n        },\n        {\n          \"id\": 624,\n          \"logprob\": -0.021987915,\n          \"special\": false,\n          \"text\": \" ==\"\n        },\n        {\n          \"id\": 364,\n          \"logprob\": -0.39208984,\n          \"special\": false,\n          \"text\": \" '\"\n        },\n        {\n          \"id\": 3019,\n          \"logprob\": -0.10638428,\n          \"special\": false,\n          \"text\": \"POST\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n    \\\"\\\"\\\"\\n    if request.method == 'POST\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 198,\n          \"logprob\": -2.5742188,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 262,\n       
   \"logprob\": -1.6220703,\n          \"special\": false,\n          \"text\": \"   \"\n        },\n        {\n          \"id\": 3270,\n          \"logprob\": -2.0410156,\n          \"special\": false,\n          \"text\": \" \\\"\\\"\\\"\\n\"\n        },\n        {\n          \"id\": 262,\n          \"logprob\": -0.015281677,\n          \"special\": false,\n          \"text\": \"   \"\n        },\n        {\n          \"id\": 422,\n          \"logprob\": -2.1445312,\n          \"special\": false,\n          \"text\": \" if\"\n        },\n        {\n          \"id\": 1715,\n          \"logprob\": -0.92333984,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 13204,\n          \"logprob\": -0.07672119,\n          \"special\": false,\n          \"text\": \".method\"\n        },\n        {\n          \"id\": 624,\n          \"logprob\": -0.021987915,\n          \"special\": false,\n          \"text\": \" ==\"\n        },\n        {\n          \"id\": 364,\n          \"logprob\": -0.39208984,\n          \"special\": false,\n          \"text\": \" '\"\n        },\n        {\n          \"id\": 3019,\n          \"logprob\": -0.10638428,\n          \"special\": false,\n          \"text\": \"POST\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n    \\\"\\\"\\\"\\n    if request.method == 'POST\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 198,\n          \"logprob\": -2.5742188,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 262,\n          \"logprob\": -1.6220703,\n          \"special\": false,\n          \"text\": \"   \"\n        },\n        {\n          \"id\": 3270,\n          \"logprob\": -2.0410156,\n          \"special\": 
false,\n          \"text\": \" \\\"\\\"\\\"\\n\"\n        },\n        {\n          \"id\": 262,\n          \"logprob\": -0.015281677,\n          \"special\": false,\n          \"text\": \"   \"\n        },\n        {\n          \"id\": 422,\n          \"logprob\": -2.1445312,\n          \"special\": false,\n          \"text\": \" if\"\n        },\n        {\n          \"id\": 1715,\n          \"logprob\": -0.92333984,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 13204,\n          \"logprob\": -0.07672119,\n          \"special\": false,\n          \"text\": \".method\"\n        },\n        {\n          \"id\": 624,\n          \"logprob\": -0.021987915,\n          \"special\": false,\n          \"text\": \" ==\"\n        },\n        {\n          \"id\": 364,\n          \"logprob\": -0.39208984,\n          \"special\": false,\n          \"text\": \" '\"\n        },\n        {\n          \"id\": 3019,\n          \"logprob\": -0.10638428,\n          \"special\": false,\n          \"text\": \"POST\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n    \\\"\\\"\\\"\\n    if request.method == 'POST\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_llama_marlin/test_flash_llama_marlin.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 13,\n        \"logprob\": -2.0507812,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -2.3007812,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 29902,\n        \"logprob\": -2.0449219,\n        \"special\": false,\n        \"text\": \"I\"\n      },\n      {\n        \"id\": 505,\n        \"logprob\": -1.3242188,\n        \"special\": false,\n        \"text\": \" have\"\n      },\n      {\n        \"id\": 263,\n        \"logprob\": -0.2076416,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 1243,\n        \"logprob\": -2.0273438,\n        \"special\": false,\n        \"text\": \" test\"\n      },\n      {\n        \"id\": 2009,\n        \"logprob\": -0.6845703,\n        \"special\": false,\n        \"text\": \" request\"\n      },\n      {\n        \"id\": 515,\n        \"logprob\": -1.1748047,\n        \"special\": false,\n        \"text\": \" from\"\n      },\n      {\n        \"id\": 263,\n        \"logprob\": -1.0644531,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 1404,\n        \"logprob\": -1.5224609,\n        \"special\": false,\n        \"text\": \" user\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"\\n\\nI have a test request from a user\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_llama_marlin/test_flash_llama_marlin_all_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 5229,\n        \"logprob\": -1.2607422,\n        \"special\": false,\n        \"text\": \" failed\"\n      },\n      {\n        \"id\": 29901,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \":\"\n      },\n      {\n        \"id\": 6527,\n        \"logprob\": -0.11450195,\n        \"special\": false,\n        \"text\": \" Could\"\n      },\n      {\n        \"id\": 451,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" not\"\n      },\n      {\n        \"id\": 4511,\n        \"logprob\": -0.2286377,\n        \"special\": false,\n        \"text\": \" connect\"\n      },\n      {\n        \"id\": 304,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" to\"\n      },\n      {\n        \"id\": 1923,\n        \"logprob\": -1.2568359,\n        \"special\": false,\n        \"text\": \" server\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.15905762,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 29902,\n        \"logprob\": -0.21618652,\n        \"special\": false,\n        \"text\": \"I\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"Test request failed: Could not connect to server\\n\\nI\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_llama_marlin/test_flash_llama_marlin_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 13,\n          \"logprob\": -2.0507812,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -2.3007812,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 29902,\n          \"logprob\": -2.0449219,\n          \"special\": false,\n          \"text\": \"I\"\n        },\n        {\n          \"id\": 505,\n          \"logprob\": -1.3242188,\n          \"special\": false,\n          \"text\": \" have\"\n        },\n        {\n          \"id\": 263,\n          \"logprob\": -0.2076416,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 1243,\n          \"logprob\": -2.0273438,\n          \"special\": false,\n          \"text\": \" test\"\n        },\n        {\n          \"id\": 2009,\n          \"logprob\": -0.6845703,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 515,\n          \"logprob\": -1.1748047,\n          \"special\": false,\n          \"text\": \" from\"\n        },\n        {\n          \"id\": 263,\n          \"logprob\": -1.0595703,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 1404,\n          \"logprob\": -1.5224609,\n          \"special\": false,\n          \"text\": \" user\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n\\nI have a test request from a user\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n        
  \"id\": 13,\n          \"logprob\": -2.0507812,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -2.3007812,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 29902,\n          \"logprob\": -2.0449219,\n          \"special\": false,\n          \"text\": \"I\"\n        },\n        {\n          \"id\": 505,\n          \"logprob\": -1.3242188,\n          \"special\": false,\n          \"text\": \" have\"\n        },\n        {\n          \"id\": 263,\n          \"logprob\": -0.2076416,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 1243,\n          \"logprob\": -2.0273438,\n          \"special\": false,\n          \"text\": \" test\"\n        },\n        {\n          \"id\": 2009,\n          \"logprob\": -0.6845703,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 515,\n          \"logprob\": -1.1748047,\n          \"special\": false,\n          \"text\": \" from\"\n        },\n        {\n          \"id\": 263,\n          \"logprob\": -1.0595703,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 1404,\n          \"logprob\": -1.5224609,\n          \"special\": false,\n          \"text\": \" user\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n\\nI have a test request from a user\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 13,\n          \"logprob\": -2.0507812,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -2.3007812,\n          \"special\": false,\n          
\"text\": \"\\n\"\n        },\n        {\n          \"id\": 29902,\n          \"logprob\": -2.0449219,\n          \"special\": false,\n          \"text\": \"I\"\n        },\n        {\n          \"id\": 505,\n          \"logprob\": -1.3242188,\n          \"special\": false,\n          \"text\": \" have\"\n        },\n        {\n          \"id\": 263,\n          \"logprob\": -0.2076416,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 1243,\n          \"logprob\": -2.0273438,\n          \"special\": false,\n          \"text\": \" test\"\n        },\n        {\n          \"id\": 2009,\n          \"logprob\": -0.6845703,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 515,\n          \"logprob\": -1.1748047,\n          \"special\": false,\n          \"text\": \" from\"\n        },\n        {\n          \"id\": 263,\n          \"logprob\": -1.0595703,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 1404,\n          \"logprob\": -1.5224609,\n          \"special\": false,\n          \"text\": \" user\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n\\nI have a test request from a user\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 13,\n          \"logprob\": -2.0507812,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -2.3007812,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 29902,\n          \"logprob\": -2.0449219,\n          \"special\": false,\n          \"text\": \"I\"\n        },\n        {\n          \"id\": 505,\n          \"logprob\": 
-1.3242188,\n          \"special\": false,\n          \"text\": \" have\"\n        },\n        {\n          \"id\": 263,\n          \"logprob\": -0.2076416,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 1243,\n          \"logprob\": -2.0273438,\n          \"special\": false,\n          \"text\": \" test\"\n        },\n        {\n          \"id\": 2009,\n          \"logprob\": -0.6845703,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 515,\n          \"logprob\": -1.1748047,\n          \"special\": false,\n          \"text\": \" from\"\n        },\n        {\n          \"id\": 263,\n          \"logprob\": -1.0595703,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 1404,\n          \"logprob\": -1.5224609,\n          \"special\": false,\n          \"text\": \" user\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n\\nI have a test request from a user\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 5229,\n        \"logprob\": -2.7988281,\n        \"special\": false,\n        \"text\": \" failed\"\n      },\n      {\n        \"id\": 29901,\n        \"logprob\": -0.91259766,\n        \"special\": false,\n        \"text\": \":\"\n      },\n      {\n        \"id\": 853,\n        \"logprob\": -2.8496094,\n        \"special\": false,\n        \"text\": \" Un\"\n      },\n      {\n        \"id\": 23765,\n        \"logprob\": -1.1894531,\n        \"special\": false,\n        \"text\": \"supported\"\n      },\n      {\n        \"id\": 4714,\n        \"logprob\": -1.5917969,\n        \"special\": false,\n        \"text\": \" browser\"\n      },\n      {\n        \"id\": 29892,\n        \"logprob\": -0.34765625,\n        \"special\": false,\n        \"text\": \",\"\n      },\n      {\n        \"id\": 1873,\n        \"logprob\": -1.2695312,\n        \"special\": false,\n        \"text\": \" version\"\n      },\n      {\n        \"id\": 470,\n        \"logprob\": -0.25170898,\n        \"special\": false,\n        \"text\": \" or\"\n      },\n      {\n        \"id\": 7481,\n        \"logprob\": -0.21411133,\n        \"special\": false,\n        \"text\": \" platform\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -1.1162109,\n        \"special\": false,\n        \"text\": \"\\n\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \" failed: Unsupported browser, version or platform\\n\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin24_all_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 5229,\n        \"logprob\": -0.6645508,\n        \"special\": false,\n        \"text\": \" failed\"\n      },\n      {\n        \"id\": 29901,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \":\"\n      },\n      {\n        \"id\": 6527,\n        \"logprob\": -2.2324219,\n        \"special\": false,\n        \"text\": \" Could\"\n      },\n      {\n        \"id\": 451,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" not\"\n      },\n      {\n        \"id\": 6088,\n        \"logprob\": -1.6074219,\n        \"special\": false,\n        \"text\": \" parse\"\n      },\n      {\n        \"id\": 1243,\n        \"logprob\": -1.6298828,\n        \"special\": false,\n        \"text\": \" test\"\n      },\n      {\n        \"id\": 1206,\n        \"logprob\": -0.72558594,\n        \"special\": false,\n        \"text\": \" case\"\n      },\n      {\n        \"id\": 1024,\n        \"logprob\": -0.40429688,\n        \"special\": false,\n        \"text\": \" name\"\n      },\n      {\n        \"id\": 515,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" from\"\n      },\n      {\n        \"id\": 525,\n        \"logprob\": -1.2519531,\n        \"special\": false,\n        \"text\": \" '\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"Test request failed: Could not parse test case name from '\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin24_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 5229,\n          \"logprob\": -2.7988281,\n          \"special\": false,\n          \"text\": \" failed\"\n        },\n        {\n          \"id\": 29901,\n          \"logprob\": -0.91259766,\n          \"special\": false,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 853,\n          \"logprob\": -2.8496094,\n          \"special\": false,\n          \"text\": \" Un\"\n        },\n        {\n          \"id\": 23765,\n          \"logprob\": -1.1894531,\n          \"special\": false,\n          \"text\": \"supported\"\n        },\n        {\n          \"id\": 4714,\n          \"logprob\": -1.5917969,\n          \"special\": false,\n          \"text\": \" browser\"\n        },\n        {\n          \"id\": 29892,\n          \"logprob\": -0.34765625,\n          \"special\": false,\n          \"text\": \",\"\n        },\n        {\n          \"id\": 1873,\n          \"logprob\": -1.2695312,\n          \"special\": false,\n          \"text\": \" version\"\n        },\n        {\n          \"id\": 470,\n          \"logprob\": -0.25170898,\n          \"special\": false,\n          \"text\": \" or\"\n        },\n        {\n          \"id\": 7481,\n          \"logprob\": -0.21411133,\n          \"special\": false,\n          \"text\": \" platform\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -1.1162109,\n          \"special\": false,\n          \"text\": \"\\n\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" failed: Unsupported browser, version or platform\\n\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n    
  \"tokens\": [\n        {\n          \"id\": 5229,\n          \"logprob\": -2.7988281,\n          \"special\": false,\n          \"text\": \" failed\"\n        },\n        {\n          \"id\": 29901,\n          \"logprob\": -0.91259766,\n          \"special\": false,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 853,\n          \"logprob\": -2.8496094,\n          \"special\": false,\n          \"text\": \" Un\"\n        },\n        {\n          \"id\": 23765,\n          \"logprob\": -1.1894531,\n          \"special\": false,\n          \"text\": \"supported\"\n        },\n        {\n          \"id\": 4714,\n          \"logprob\": -1.5917969,\n          \"special\": false,\n          \"text\": \" browser\"\n        },\n        {\n          \"id\": 29892,\n          \"logprob\": -0.34765625,\n          \"special\": false,\n          \"text\": \",\"\n        },\n        {\n          \"id\": 1873,\n          \"logprob\": -1.2695312,\n          \"special\": false,\n          \"text\": \" version\"\n        },\n        {\n          \"id\": 470,\n          \"logprob\": -0.25170898,\n          \"special\": false,\n          \"text\": \" or\"\n        },\n        {\n          \"id\": 7481,\n          \"logprob\": -0.21411133,\n          \"special\": false,\n          \"text\": \" platform\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -1.1162109,\n          \"special\": false,\n          \"text\": \"\\n\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" failed: Unsupported browser, version or platform\\n\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 5229,\n          \"logprob\": -2.7988281,\n          \"special\": false,\n          \"text\": \" failed\"\n        },\n        {\n          \"id\": 
29901,\n          \"logprob\": -0.91259766,\n          \"special\": false,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 853,\n          \"logprob\": -2.8496094,\n          \"special\": false,\n          \"text\": \" Un\"\n        },\n        {\n          \"id\": 23765,\n          \"logprob\": -1.1894531,\n          \"special\": false,\n          \"text\": \"supported\"\n        },\n        {\n          \"id\": 4714,\n          \"logprob\": -1.5917969,\n          \"special\": false,\n          \"text\": \" browser\"\n        },\n        {\n          \"id\": 29892,\n          \"logprob\": -0.34765625,\n          \"special\": false,\n          \"text\": \",\"\n        },\n        {\n          \"id\": 1873,\n          \"logprob\": -1.2695312,\n          \"special\": false,\n          \"text\": \" version\"\n        },\n        {\n          \"id\": 470,\n          \"logprob\": -0.25170898,\n          \"special\": false,\n          \"text\": \" or\"\n        },\n        {\n          \"id\": 7481,\n          \"logprob\": -0.21411133,\n          \"special\": false,\n          \"text\": \" platform\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -1.1162109,\n          \"special\": false,\n          \"text\": \"\\n\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" failed: Unsupported browser, version or platform\\n\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 5229,\n          \"logprob\": -2.7988281,\n          \"special\": false,\n          \"text\": \" failed\"\n        },\n        {\n          \"id\": 29901,\n          \"logprob\": -0.91259766,\n          \"special\": false,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 853,\n          \"logprob\": -2.8496094,\n          
\"special\": false,\n          \"text\": \" Un\"\n        },\n        {\n          \"id\": 23765,\n          \"logprob\": -1.1894531,\n          \"special\": false,\n          \"text\": \"supported\"\n        },\n        {\n          \"id\": 4714,\n          \"logprob\": -1.5917969,\n          \"special\": false,\n          \"text\": \" browser\"\n        },\n        {\n          \"id\": 29892,\n          \"logprob\": -0.34765625,\n          \"special\": false,\n          \"text\": \",\"\n        },\n        {\n          \"id\": 1873,\n          \"logprob\": -1.2695312,\n          \"special\": false,\n          \"text\": \" version\"\n        },\n        {\n          \"id\": 470,\n          \"logprob\": -0.25170898,\n          \"special\": false,\n          \"text\": \" or\"\n        },\n        {\n          \"id\": 7481,\n          \"logprob\": -0.21411133,\n          \"special\": false,\n          \"text\": \" platform\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -1.1162109,\n          \"special\": false,\n          \"text\": \"\\n\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" failed: Unsupported browser, version or platform\\n\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_llama_prefix/test_flash_llama_load.json",
    "content": "[\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Jeff Walker's Product Launch Formula is a comprehensive system\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 69,\n      \"total_tokens\": 79\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Here are three key indicators to determine if a customer\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085331,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 52,\n      \"total_tokens\": 62\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"You can use the `String.format()` method in\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": 
\"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 97,\n      \"total_tokens\": 107\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"In a realm of binary mysticism, we find\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085331,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 126,\n      \"total_tokens\": 136\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"The `dummy` variable is being used to consume\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085336,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 305,\n      \"total_tokens\": 315\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"You can add multiple new columns in Power Query (\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085331,\n    \"id\": \"\",\n    \"model\": 
\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 51,\n      \"total_tokens\": 61\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"There are many exciting new technologies emerging across various fields\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085336,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 52,\n      \"total_tokens\": 62\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Poly Ether Ether Ketone (PEEK) is\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 40,\n      \"total_tokens\": 50\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Here's a technical overview of a referral system similar\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        
},\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 85,\n      \"total_tokens\": 95\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Here's an example of how you can add an\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085336,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 45,\n      \"total_tokens\": 55\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"I'd be happy to help with Java. 
What\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 43,\n      \"total_tokens\": 53\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"I can help you plan a road trip from Pune\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 82,\n      \"total_tokens\": 92\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"I'd be happy to explain more about a topic\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085331,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 38,\n      \"total_tokens\": 48\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        
\"message\": {\n          \"content\": \"I'd be happy to help you brainstorm and provide\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 47,\n      \"total_tokens\": 57\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Implementing a Minesweeper algorithm using algebraic\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 54,\n      \"total_tokens\": 64\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"There are several issues with the provided code:\\n\\n1\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085331,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 375,\n      \"total_tokens\": 385\n    }\n  },\n  {\n    \"choices\": [\n   
   {\n        \"finish_reason\": \"stop\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \";)\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085330,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 2,\n      \"prompt_tokens\": 105,\n      \"total_tokens\": 107\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"As I delved into the world of high-st\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085331,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 2097,\n      \"total_tokens\": 2107\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"/u/CruxHub: Hi, I'm\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085336,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 2614,\n      \"total_tokens\": 2624\n    }\n  },\n  {\n    
\"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"To simulate a conversation between Alice and /u/C\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085331,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 1070,\n      \"total_tokens\": 1080\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Alice: Hey /u/CruxHub,\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085336,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 1847,\n      \"total_tokens\": 1857\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Alice: Hi /u/CruxHub,\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085331,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 
1849,\n      \"total_tokens\": 1859\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"/u/CruxHub: Hey Alice, I\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085336,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 1004,\n      \"total_tokens\": 1014\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"/u/CruxHub: Hey Alice, I\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085331,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 1100,\n      \"total_tokens\": 1110\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"/u/CruxHub: Hey Alice, I\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085336,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      
\"completion_tokens\": 10,\n      \"prompt_tokens\": 1044,\n      \"total_tokens\": 1054\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"The Dogme approach and the Lexical Approach are\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085336,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 54,\n      \"total_tokens\": 64\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Implementing a netfilter in Linux with a Rust\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085336,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 48,\n      \"total_tokens\": 58\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Damage to the Ulnar nerve can cause numb\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085331,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": 
\"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 56,\n      \"total_tokens\": 66\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"The Space Shuttle's Reaction Control System (RCS\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085331,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 50,\n      \"total_tokens\": 60\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"I can provide you with a basic Python script that\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 65,\n      \"total_tokens\": 75\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Farming meat has several negative impacts on the environment\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 
1726085331,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 43,\n      \"total_tokens\": 53\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"The photograph filter you're referring to is called \\\"\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 51,\n      \"total_tokens\": 61\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Here's a sample geological database structure with some example\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 59,\n      \"total_tokens\": 69\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"**Web Marketing: A Simplified Explanation**\\n\\nWeb\",\n          \"name\": null,\n          
\"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 45,\n      \"total_tokens\": 55\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Here's a rewritten and improved version of the story\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085336,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 447,\n      \"total_tokens\": 457\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Here are the questions rewritten in a more conversational\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085331,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 168,\n      \"total_tokens\": 178\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n    
      \"content\": \"**Learning Progress: 0%**\\n\\n| Topic\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085331,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 216,\n      \"total_tokens\": 226\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"I couldn't find any information on a person named\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 44,\n      \"total_tokens\": 54\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Here's a list of the largest outdoor retailers in\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 43,\n      \"total_tokens\": 53\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": 
\"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"To create a WordPress shortcode that includes Facebook SDK code\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 49,\n      \"total_tokens\": 59\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"The sentence is mostly grammatically correct, but there\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 78,\n      \"total_tokens\": 88\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"I'd be happy to engage in a debate with\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 
59,\n      \"total_tokens\": 69\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"I'd love to hear about your business. As\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085336,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 64,\n      \"total_tokens\": 74\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"I'll wait for your request to proceed with part\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085331,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 2410,\n      \"total_tokens\": 2420\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"The final part of the Day Sculpting program emphasizes\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085336,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": 
\"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 2699,\n      \"total_tokens\": 2709\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"**Analysis of the Coming of Age Story Archetype\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085336,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 349,\n      \"total_tokens\": 359\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"The Apostle John is one of the most prominent figures\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 49,\n      \"total_tokens\": 59\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"To build a Google Places autocomplete feature on Jetpack\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085336,\n    \"id\": \"\",\n    
\"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 427,\n      \"total_tokens\": 437\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"The information provided does not mention the captain's name\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085331,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 169,\n      \"total_tokens\": 179\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"The metaverse is a shared, immersive and interactive\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085336,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 39,\n      \"total_tokens\": 49\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Here are some ideas for a series of articles for\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          
\"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 50,\n      \"total_tokens\": 60\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"\\\"Purim Palooza Alert: \\n\\nTo\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085331,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 78,\n      \"total_tokens\": 88\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"**Summary of the paper in 10 points:\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085331,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 2022,\n      \"total_tokens\": 2032\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"You'll provide three pieces of text, and then\",\n     
     \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085336,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 58,\n      \"total_tokens\": 68\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"I'm ready to proceed with text 3.\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085336,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 1650,\n      \"total_tokens\": 1660\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"I'm ready to answer questions on Text 1\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 1116,\n      \"total_tokens\": 1126\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n     
     \"content\": \"This is a Solidity contract written in the older\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085331,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 334,\n      \"total_tokens\": 344\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"**Speech Recognition and Synthesis using Python**\\n\\nTo\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 84,\n      \"total_tokens\": 94\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"I'd be happy to help you discuss a paper\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085336,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 42,\n      \"total_tokens\": 52\n    }\n  },\n  {\n    \"choices\": [\n      {\n        
\"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"To handle the given utterance, we can use\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085336,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 375,\n      \"total_tokens\": 385\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"**Subscription Services Template:**\\n\\n**Title:** Virtual\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085331,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 443,\n      \"total_tokens\": 453\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Hello. 
How can I assist you today?\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085331,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 36,\n      \"total_tokens\": 46\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Differentiating yourself from other Etsy shops is crucial to\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 102,\n      \"total_tokens\": 112\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"To become a Licensed Marriage and Family Therapist (\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 53,\n      \"total_tokens\": 63\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        
\"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"**What is Quantum Computing?**\\n\\nQuantum computing\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 42,\n      \"total_tokens\": 52\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Aquí te dejo 40 opciones de nombres\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 108,\n      \"total_tokens\": 118\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Deposition is a geological process that involves the transportation\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085331,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 38,\n      
\"total_tokens\": 48\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Here are some good e-governance initiatives in\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085336,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 55,\n      \"total_tokens\": 65\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Here's a simple Python program that accepts a command\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 56,\n      \"total_tokens\": 66\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Imagine you're playing with a toy box. 
You\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 47,\n      \"total_tokens\": 57\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Here's an example of a question they might ask\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 66,\n      \"total_tokens\": 76\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Arduino Uno adalah sebuah papan mikrokontrol\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 38,\n      \"total_tokens\": 48\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n      
  \"message\": {\n          \"content\": \"To edit an array that is within an object,\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 42,\n      \"total_tokens\": 52\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Microsoft ENTRA (Enterprise Mobility + Security) is\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085336,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 56,\n      \"total_tokens\": 66\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"To calculate the difference in interest paid between a simple\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085331,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 69,\n      \"total_tokens\": 79\n    }\n  },\n  {\n    \"choices\": [\n   
   {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Yes, you can use Spring State Machine and Spring\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 49,\n      \"total_tokens\": 59\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"The issue lies in the fact that the `meta\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 142,\n      \"total_tokens\": 152\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Here are some effective marketing tactics for local small businesses\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085336,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      
\"completion_tokens\": 10,\n      \"prompt_tokens\": 46,\n      \"total_tokens\": 56\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"The French Revolution, which lasted from 1789\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085331,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 41,\n      \"total_tokens\": 51\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"**Roles of a Network Driver:**\\n\\nA network\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085336,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 65,\n      \"total_tokens\": 75\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Yes, I'm familiar with the SAS (Stat\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085331,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": 
\"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 44,\n      \"total_tokens\": 54\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Using relays to control 12V solen\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085336,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 60,\n      \"total_tokens\": 70\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"You can use the following Python code to achieve this\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085331,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 55,\n      \"total_tokens\": 65\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Here are some prompts for viral comics:\\n\\n1.\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085336,\n    
\"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 336,\n      \"total_tokens\": 346\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"To simplify and make the comic funnier, consider\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085336,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 301,\n      \"total_tokens\": 311\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Here's a rewritten version of the 4-panel\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085331,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 282,\n      \"total_tokens\": 292\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Subject: Request for E-Waste Collection and Computer\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          
\"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085331,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 110,\n      \"total_tokens\": 120\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"In the context of conference calls, the state you\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085331,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 84,\n      \"total_tokens\": 94\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"I can provide a general classification of companies based on\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 56,\n      \"total_tokens\": 66\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Here are some 
user stories that describe the concept in\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085336,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 44,\n      \"total_tokens\": 54\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"You can check your Python version by running the following\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 39,\n      \"total_tokens\": 49\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"**Scenario:**\\n\\n15-year-old Black youth,\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085336,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 473,\n      \"total_tokens\": 483\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n       
 \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"As a Demand Generation Manager for a B2B\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085336,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 50,\n      \"total_tokens\": 60\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"The error is due to a typo in your code\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085336,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 369,\n      \"total_tokens\": 379\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"고등교육의 필요성에 관한 영어 에\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 72,\n      \"total_tokens\": 82\n    }\n  },\n  {\n    \"choices\": [\n      
{\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Here's a simple C# program that uses the\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 51,\n      \"total_tokens\": 61\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"The error message \\\"connection refused\\\" indicates that the\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085331,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 85,\n      \"total_tokens\": 95\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"To load an image, you can use various methods\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726085326,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n    
  \"prompt_tokens\": 41,\n      \"total_tokens\": 51\n    }\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_llama_prefix_flashdecoding/test_flash_llama_flashdecoding.json",
    "content": "[\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Jeff Walker's Product Launch Formula is a comprehensive system\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243286,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 69,\n      \"total_tokens\": 79\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Here are three key indicators to determine if a customer\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243278,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 52,\n      \"total_tokens\": 62\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"You can use the `String.format()` method in\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243277,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": 
\"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 97,\n      \"total_tokens\": 107\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"In a realm of binary mysticism, we find\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243278,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 126,\n      \"total_tokens\": 136\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"The `dummy` variable is being used to consume\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243283,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 305,\n      \"total_tokens\": 315\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"You can add multiple new columns in Power Query (\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243286,\n    \"id\": \"\",\n    \"model\": 
\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 51,\n      \"total_tokens\": 61\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"There are many exciting new technologies emerging across various fields\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243277,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 52,\n      \"total_tokens\": 62\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Poly Ether Ether Ketone (PEEK) is\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243278,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 40,\n      \"total_tokens\": 50\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Here's a technical overview of a referral system similar\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        
},\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243277,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 85,\n      \"total_tokens\": 95\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Here's an example of how you can add an\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243284,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 45,\n      \"total_tokens\": 55\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"I'd be happy to help with Java. 
What\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243283,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 43,\n      \"total_tokens\": 53\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"I can help you plan a road trip from Pune\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243277,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 82,\n      \"total_tokens\": 92\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"I'd be happy to explain more about a topic\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243284,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 38,\n      \"total_tokens\": 48\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        
\"message\": {\n          \"content\": \"I'd be happy to help you brainstorm and provide\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243278,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 47,\n      \"total_tokens\": 57\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Implementing a Minesweeper algorithm using algebraic\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243278,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 54,\n      \"total_tokens\": 64\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"There are several issues with the provided code:\\n\\n1\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243284,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 375,\n      \"total_tokens\": 385\n    }\n  },\n  {\n    \"choices\": [\n   
   {\n        \"finish_reason\": \"stop\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \";)\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243277,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 2,\n      \"prompt_tokens\": 105,\n      \"total_tokens\": 107\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"As I delved into the world of high-st\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243277,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 2097,\n      \"total_tokens\": 2107\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"/u/CruxHub: Hi, I'm\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243283,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 2614,\n      \"total_tokens\": 2624\n    }\n  },\n  {\n    
\"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"To simulate a conversation between Alice and /u/C\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243278,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 1070,\n      \"total_tokens\": 1080\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Alice: Hey /u/CruxHub,\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243283,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 1847,\n      \"total_tokens\": 1857\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Alice: Hi /u/CruxHub,\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243286,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 
1849,\n      \"total_tokens\": 1859\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"/u/CruxHub: Hey Alice, I\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243278,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 1004,\n      \"total_tokens\": 1014\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"/u/CruxHub: Hey Alice, I\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243286,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 1100,\n      \"total_tokens\": 1110\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"/u/CruxHub: Hey Alice, I\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243278,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      
\"completion_tokens\": 10,\n      \"prompt_tokens\": 1044,\n      \"total_tokens\": 1054\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"The Dogme approach and the Lexical Approach are\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243278,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 54,\n      \"total_tokens\": 64\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Implementing a netfilter in Linux with a Rust\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243278,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 48,\n      \"total_tokens\": 58\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Damage to the Ulnar nerve can cause numb\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243277,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": 
\"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 56,\n      \"total_tokens\": 66\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"The Space Shuttle's Reaction Control System (RCS\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243277,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 50,\n      \"total_tokens\": 60\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"I can provide you with a basic Python script that\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243278,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 65,\n      \"total_tokens\": 75\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Farming meat has several negative impacts on the environment\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 
1726243277,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 43,\n      \"total_tokens\": 53\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"The photograph filter you're referring to is called \\\"\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243277,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 51,\n      \"total_tokens\": 61\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Here's a sample geological database structure with some example\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243283,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 59,\n      \"total_tokens\": 69\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"**Web Marketing: A Simplified Explanation**\\n\\nWeb\",\n          \"name\": null,\n          
\"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243278,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 45,\n      \"total_tokens\": 55\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Here's a rewritten and improved version of the story\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243278,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 447,\n      \"total_tokens\": 457\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Here are the questions rewritten in a more conversational\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243278,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 168,\n      \"total_tokens\": 178\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n    
      \"content\": \"**Learning Progress: 0%**\\n\\n| Topic\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243284,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 216,\n      \"total_tokens\": 226\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"I couldn't find any information on a person named\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243284,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 44,\n      \"total_tokens\": 54\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Here's a list of the largest outdoor retailers in\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243286,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 43,\n      \"total_tokens\": 53\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": 
\"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"To create a WordPress shortcode that includes Facebook SDK code\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243284,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 49,\n      \"total_tokens\": 59\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"The sentence is mostly grammatically correct, but there\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243277,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 78,\n      \"total_tokens\": 88\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"I'd be happy to engage in a debate with\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243284,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 
59,\n      \"total_tokens\": 69\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"I'd love to hear about your business. As\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243284,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 64,\n      \"total_tokens\": 74\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"I'll wait for your request to proceed with part\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243286,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 2410,\n      \"total_tokens\": 2420\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"The final part of the Day Sculpting program emphasizes\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243284,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": 
\"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 2699,\n      \"total_tokens\": 2709\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"**Analysis of the Coming of Age Story Archetype\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243278,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 349,\n      \"total_tokens\": 359\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"The Apostle John is one of the most prominent figures\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243277,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 49,\n      \"total_tokens\": 59\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"To build a Google Places autocomplete feature on Jetpack\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243277,\n    \"id\": \"\",\n    
\"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 427,\n      \"total_tokens\": 437\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"The information provided does not mention the captain's name\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243283,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 169,\n      \"total_tokens\": 179\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"The metaverse is a shared, immersive and interactive\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243277,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 39,\n      \"total_tokens\": 49\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Here are some ideas for a series of articles for\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          
\"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243278,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 50,\n      \"total_tokens\": 60\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"\\\"Purim Palooza Alert: \\n\\nTo\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243284,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 78,\n      \"total_tokens\": 88\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"**Summary of the paper in 10 points:\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243286,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 2022,\n      \"total_tokens\": 2032\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"You'll provide three pieces of text, and then\",\n     
     \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243283,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 58,\n      \"total_tokens\": 68\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"I'm ready to proceed with text 3.\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243284,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 1650,\n      \"total_tokens\": 1660\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"I'm ready to answer questions on Text 1\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243286,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 1116,\n      \"total_tokens\": 1126\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n     
     \"content\": \"This is a Solidity contract written in the older\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243278,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 334,\n      \"total_tokens\": 344\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"**Speech Recognition and Synthesis using Python**\\n\\nTo\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243283,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 84,\n      \"total_tokens\": 94\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"I'd be happy to help you discuss a paper\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243286,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 42,\n      \"total_tokens\": 52\n    }\n  },\n  {\n    \"choices\": [\n      {\n        
\"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"To handle the given utterance, we can use\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243277,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 375,\n      \"total_tokens\": 385\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"**Subscription Services Template:**\\n\\n**Title:** Virtual\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243286,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 443,\n      \"total_tokens\": 453\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Hello. 
How can I assist you today?\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243278,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 36,\n      \"total_tokens\": 46\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Differentiating yourself from other Etsy shops is crucial to\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243283,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 102,\n      \"total_tokens\": 112\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"To become a Licensed Marriage and Family Therapist (\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243284,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 53,\n      \"total_tokens\": 63\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        
\"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"**What is Quantum Computing?**\\n\\nQuantum computing\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243278,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 42,\n      \"total_tokens\": 52\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Aquí te dejo 40 opciones de nombres\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243283,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 108,\n      \"total_tokens\": 118\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Deposition is a geological process that involves the transportation\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243278,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 38,\n      
\"total_tokens\": 48\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Here are some good e-governance initiatives in\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243284,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 55,\n      \"total_tokens\": 65\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Here's a simple Python program that accepts a command\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243278,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 56,\n      \"total_tokens\": 66\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Imagine you're playing with a toy box. 
You\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243284,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 47,\n      \"total_tokens\": 57\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Here's an example of a question they might ask\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243278,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 66,\n      \"total_tokens\": 76\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Arduino Uno adalah sebuah papan mikrokontrol\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243283,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 38,\n      \"total_tokens\": 48\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n      
  \"message\": {\n          \"content\": \"To edit an array that is within an object,\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243278,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 42,\n      \"total_tokens\": 52\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Microsoft ENTRA (Enterprise Mobility + Security) is\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243286,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 56,\n      \"total_tokens\": 66\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"To calculate the difference in interest paid between a simple\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243277,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 69,\n      \"total_tokens\": 79\n    }\n  },\n  {\n    \"choices\": [\n   
   {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Yes, you can use Spring State Machine and Spring\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243278,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 49,\n      \"total_tokens\": 59\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"The issue lies in the fact that the `meta\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243278,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 142,\n      \"total_tokens\": 152\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Here are some effective marketing tactics for local small businesses\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243286,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      
\"completion_tokens\": 10,\n      \"prompt_tokens\": 46,\n      \"total_tokens\": 56\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"The French Revolution, which lasted from 1789\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243283,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 41,\n      \"total_tokens\": 51\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"**Roles of a Network Driver:**\\n\\nA network\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243278,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 65,\n      \"total_tokens\": 75\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Yes, I'm familiar with the SAS (Stat\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243277,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": 
\"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 44,\n      \"total_tokens\": 54\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Using relays to control 12V solen\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243284,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 60,\n      \"total_tokens\": 70\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"You can use the following Python code to achieve this\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243278,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 55,\n      \"total_tokens\": 65\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Here are some prompts for viral comics:\\n\\n1.\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243284,\n    
\"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 336,\n      \"total_tokens\": 346\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"To simplify and make the comic funnier, consider\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243278,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 301,\n      \"total_tokens\": 311\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Here's a rewritten version of the 4-panel\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243278,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 282,\n      \"total_tokens\": 292\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Subject: Request for E-Waste Collection and Computer\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          
\"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243284,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 110,\n      \"total_tokens\": 120\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"In the context of conference calls, the state you\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243277,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 84,\n      \"total_tokens\": 94\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"I can provide a general classification of companies based on\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243284,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 56,\n      \"total_tokens\": 66\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Here are some 
user stories that describe the concept in\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243283,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 44,\n      \"total_tokens\": 54\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"You can check your Python version by running the following\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243284,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 39,\n      \"total_tokens\": 49\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"**Scenario:**\\n\\n15-year-old Black youth,\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243284,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 473,\n      \"total_tokens\": 483\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n       
 \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"As a Demand Generation Manager for a B2B\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243277,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 50,\n      \"total_tokens\": 60\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"The error is due to a typo in your code\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243277,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 369,\n      \"total_tokens\": 379\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"고등교육의 필요성에 관한 영어 에\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243286,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 72,\n      \"total_tokens\": 82\n    }\n  },\n  {\n    \"choices\": [\n      
{\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"Here's a simple C# program that uses the\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243283,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 51,\n      \"total_tokens\": 61\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"The error message \\\"connection refused\\\" indicates that the\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243277,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 85,\n      \"total_tokens\": 95\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"To load an image, you can use various methods\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1726243284,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"2.2.1-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n    
  \"prompt_tokens\": 41,\n      \"total_tokens\": 51\n    }\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_all_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 13,\n        \"logprob\": -1.1582031,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 2772,\n        \"logprob\": -0.23083496,\n        \"special\": false,\n        \"text\": \"De\"\n      },\n      {\n        \"id\": 1022,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"ep\"\n      },\n      {\n        \"id\": 6509,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" learning\"\n      },\n      {\n        \"id\": 29892,\n        \"logprob\": -0.61816406,\n        \"special\": false,\n        \"text\": \",\"\n      },\n      {\n        \"id\": 607,\n        \"logprob\": -0.7089844,\n        \"special\": false,\n        \"text\": \" which\"\n      },\n      {\n        \"id\": 508,\n        \"logprob\": -1.7724609,\n        \"special\": false,\n        \"text\": \" can\"\n      },\n      {\n        \"id\": 367,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" be\"\n      },\n      {\n        \"id\": 5545,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" considered\"\n      },\n      {\n        \"id\": 408,\n        \"logprob\": -0.3869629,\n        \"special\": false,\n        \"text\": \" as\"\n      }\n    ]\n  },\n  \"generated_text\": \"What is Deep Learning?\\nDeep learning, which can be considered as\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 13,\n          \"logprob\": -1.1845703,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 2772,\n          \"logprob\": -0.5727539,\n          \"special\": false,\n          \"text\": \"De\"\n        },\n        {\n          \"id\": 1022,\n          \"logprob\": -0.00010967255,\n          \"special\": false,\n          \"text\": \"ep\"\n        },\n        {\n          \"id\": 6509,\n          \"logprob\": -0.1239624,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 338,\n          \"logprob\": -0.04510498,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 263,\n          \"logprob\": -0.018295288,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 11306,\n          \"logprob\": -0.45922852,\n          \"special\": false,\n          \"text\": \" subset\"\n        },\n        {\n          \"id\": 310,\n          \"logprob\": -0.00020992756,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 4933,\n          \"logprob\": -0.0046539307,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 6509,\n          \"logprob\": -0.00025844574,\n          \"special\": false,\n          \"text\": \" learning\"\n        }\n      ]\n    },\n    \"generated_text\": \"\\nDeep learning is a subset of machine learning\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        
{\n          \"id\": 13,\n          \"logprob\": -1.1826172,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 2772,\n          \"logprob\": -0.56689453,\n          \"special\": false,\n          \"text\": \"De\"\n        },\n        {\n          \"id\": 1022,\n          \"logprob\": -0.000108003616,\n          \"special\": false,\n          \"text\": \"ep\"\n        },\n        {\n          \"id\": 6509,\n          \"logprob\": -0.1239624,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 338,\n          \"logprob\": -0.044433594,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 263,\n          \"logprob\": -0.018295288,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 11306,\n          \"logprob\": -0.45922852,\n          \"special\": false,\n          \"text\": \" subset\"\n        },\n        {\n          \"id\": 310,\n          \"logprob\": -0.0002104044,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 4933,\n          \"logprob\": -0.004711151,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 6509,\n          \"logprob\": -0.00025892258,\n          \"special\": false,\n          \"text\": \" learning\"\n        }\n      ]\n    },\n    \"generated_text\": \"\\nDeep learning is a subset of machine learning\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 13,\n          \"logprob\": -1.1826172,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 2772,\n          \"logprob\": -0.56689453,\n          
\"special\": false,\n          \"text\": \"De\"\n        },\n        {\n          \"id\": 1022,\n          \"logprob\": -0.000108003616,\n          \"special\": false,\n          \"text\": \"ep\"\n        },\n        {\n          \"id\": 6509,\n          \"logprob\": -0.1239624,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 338,\n          \"logprob\": -0.044433594,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 263,\n          \"logprob\": -0.018295288,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 11306,\n          \"logprob\": -0.45922852,\n          \"special\": false,\n          \"text\": \" subset\"\n        },\n        {\n          \"id\": 310,\n          \"logprob\": -0.0002104044,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 4933,\n          \"logprob\": -0.004711151,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 6509,\n          \"logprob\": -0.00025892258,\n          \"special\": false,\n          \"text\": \" learning\"\n        }\n      ]\n    },\n    \"generated_text\": \"\\nDeep learning is a subset of machine learning\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 13,\n          \"logprob\": -1.1826172,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 2772,\n          \"logprob\": -0.56689453,\n          \"special\": false,\n          \"text\": \"De\"\n        },\n        {\n          \"id\": 1022,\n          \"logprob\": -0.000108003616,\n          \"special\": false,\n          \"text\": \"ep\"\n        },\n        {\n      
    \"id\": 6509,\n          \"logprob\": -0.1239624,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 338,\n          \"logprob\": -0.044433594,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 263,\n          \"logprob\": -0.018295288,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 11306,\n          \"logprob\": -0.45922852,\n          \"special\": false,\n          \"text\": \" subset\"\n        },\n        {\n          \"id\": 310,\n          \"logprob\": -0.0002104044,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 4933,\n          \"logprob\": -0.004711151,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 6509,\n          \"logprob\": -0.00025892258,\n          \"special\": false,\n          \"text\": \" learning\"\n        }\n      ]\n    },\n    \"generated_text\": \"\\nDeep learning is a subset of machine learning\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_medusa/test_flash_medusa_simple.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 13,\n        \"logprob\": -1.1845703,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 2772,\n        \"logprob\": -0.5727539,\n        \"special\": false,\n        \"text\": \"De\"\n      },\n      {\n        \"id\": 1022,\n        \"logprob\": -0.000108122826,\n        \"special\": false,\n        \"text\": \"ep\"\n      },\n      {\n        \"id\": 6509,\n        \"logprob\": -0.1239624,\n        \"special\": false,\n        \"text\": \" learning\"\n      },\n      {\n        \"id\": 338,\n        \"logprob\": -0.044433594,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 263,\n        \"logprob\": -0.01852417,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 11306,\n        \"logprob\": -0.45922852,\n        \"special\": false,\n        \"text\": \" subset\"\n      },\n      {\n        \"id\": 310,\n        \"logprob\": -0.0002104044,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 4933,\n        \"logprob\": -0.004787445,\n        \"special\": false,\n        \"text\": \" machine\"\n      },\n      {\n        \"id\": 6509,\n        \"logprob\": -0.00026226044,\n        \"special\": false,\n        \"text\": \" learning\"\n      }\n    ]\n  },\n  \"generated_text\": \"\\nDeep learning is a subset of machine learning\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_mistral/test_flash_mistral.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 28747,\n        \"logprob\": -0.54785156,\n        \"special\": false,\n        \"text\": \":\"\n      },\n      {\n        \"id\": 3169,\n        \"logprob\": -1.4091797,\n        \"special\": false,\n        \"text\": \" Let\"\n      },\n      {\n        \"id\": 307,\n        \"logprob\": -3.0273438,\n        \"special\": false,\n        \"text\": \" n\"\n      },\n      {\n        \"id\": 327,\n        \"logprob\": -0.94433594,\n        \"special\": false,\n        \"text\": \" =\"\n      },\n      {\n        \"id\": 28705,\n        \"logprob\": -0.81347656,\n        \"special\": false,\n        \"text\": \" \"\n      },\n      {\n        \"id\": 28740,\n        \"logprob\": -1.2958984,\n        \"special\": false,\n        \"text\": \"1\"\n      },\n      {\n        \"id\": 28734,\n        \"logprob\": -2.0644531,\n        \"special\": false,\n        \"text\": \"0\"\n      },\n      {\n        \"id\": 387,\n        \"logprob\": -1.9580078,\n        \"special\": false,\n        \"text\": \" -\"\n      },\n      {\n        \"id\": 28705,\n        \"logprob\": -0.5073242,\n        \"special\": false,\n        \"text\": \" \"\n      },\n      {\n        \"id\": 28740,\n        \"logprob\": -1.1816406,\n        \"special\": false,\n        \"text\": \"1\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \": Let n = 10 - 1\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_mistral/test_flash_mistral_all_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 28747,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \":\"\n      },\n      {\n        \"id\": 3169,\n        \"logprob\": -0.1307373,\n        \"special\": false,\n        \"text\": \" Let\"\n      },\n      {\n        \"id\": 332,\n        \"logprob\": -2.3359375,\n        \"special\": false,\n        \"text\": \" u\"\n      },\n      {\n        \"id\": 347,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" be\"\n      },\n      {\n        \"id\": 325,\n        \"logprob\": -1.0234375,\n        \"special\": false,\n        \"text\": \" (\"\n      },\n      {\n        \"id\": 28734,\n        \"logprob\": -2.0292969,\n        \"special\": false,\n        \"text\": \"0\"\n      },\n      {\n        \"id\": 648,\n        \"logprob\": -1.0439453,\n        \"special\": false,\n        \"text\": \" +\"\n      },\n      {\n        \"id\": 28705,\n        \"logprob\": -0.24499512,\n        \"special\": false,\n        \"text\": \" \"\n      },\n      {\n        \"id\": 28770,\n        \"logprob\": -0.5073242,\n        \"special\": false,\n        \"text\": \"3\"\n      },\n      {\n        \"id\": 387,\n        \"logprob\": -1.5507812,\n        \"special\": false,\n        \"text\": \" -\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"Test request: Let u be (0 + 3 -\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_mistral/test_flash_mistral_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 28747,\n          \"logprob\": -0.55078125,\n          \"special\": false,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 3169,\n          \"logprob\": -1.4140625,\n          \"special\": false,\n          \"text\": \" Let\"\n        },\n        {\n          \"id\": 307,\n          \"logprob\": -3.0273438,\n          \"special\": false,\n          \"text\": \" n\"\n        },\n        {\n          \"id\": 327,\n          \"logprob\": -0.94140625,\n          \"special\": false,\n          \"text\": \" =\"\n        },\n        {\n          \"id\": 28705,\n          \"logprob\": -0.8173828,\n          \"special\": false,\n          \"text\": \" \"\n        },\n        {\n          \"id\": 28740,\n          \"logprob\": -1.2978516,\n          \"special\": false,\n          \"text\": \"1\"\n        },\n        {\n          \"id\": 28734,\n          \"logprob\": -2.0664062,\n          \"special\": false,\n          \"text\": \"0\"\n        },\n        {\n          \"id\": 387,\n          \"logprob\": -1.9560547,\n          \"special\": false,\n          \"text\": \" -\"\n        },\n        {\n          \"id\": 28705,\n          \"logprob\": -0.5078125,\n          \"special\": false,\n          \"text\": \" \"\n        },\n        {\n          \"id\": 28740,\n          \"logprob\": -1.1787109,\n          \"special\": false,\n          \"text\": \"1\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \": Let n = 10 - 1\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 28747,\n          
\"logprob\": -0.54785156,\n          \"special\": false,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 3169,\n          \"logprob\": -1.4111328,\n          \"special\": false,\n          \"text\": \" Let\"\n        },\n        {\n          \"id\": 307,\n          \"logprob\": -3.0292969,\n          \"special\": false,\n          \"text\": \" n\"\n        },\n        {\n          \"id\": 327,\n          \"logprob\": -0.94433594,\n          \"special\": false,\n          \"text\": \" =\"\n        },\n        {\n          \"id\": 28705,\n          \"logprob\": -0.8178711,\n          \"special\": false,\n          \"text\": \" \"\n        },\n        {\n          \"id\": 28740,\n          \"logprob\": -1.2939453,\n          \"special\": false,\n          \"text\": \"1\"\n        },\n        {\n          \"id\": 28734,\n          \"logprob\": -2.0644531,\n          \"special\": false,\n          \"text\": \"0\"\n        },\n        {\n          \"id\": 387,\n          \"logprob\": -1.9550781,\n          \"special\": false,\n          \"text\": \" -\"\n        },\n        {\n          \"id\": 28705,\n          \"logprob\": -0.5078125,\n          \"special\": false,\n          \"text\": \" \"\n        },\n        {\n          \"id\": 28740,\n          \"logprob\": -1.1796875,\n          \"special\": false,\n          \"text\": \"1\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \": Let n = 10 - 1\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 28747,\n          \"logprob\": -0.55078125,\n          \"special\": false,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 3169,\n          \"logprob\": -1.4140625,\n          \"special\": false,\n          \"text\": \" Let\"\n        },\n        {\n          \"id\": 
307,\n          \"logprob\": -3.0273438,\n          \"special\": false,\n          \"text\": \" n\"\n        },\n        {\n          \"id\": 327,\n          \"logprob\": -0.94140625,\n          \"special\": false,\n          \"text\": \" =\"\n        },\n        {\n          \"id\": 28705,\n          \"logprob\": -0.8173828,\n          \"special\": false,\n          \"text\": \" \"\n        },\n        {\n          \"id\": 28740,\n          \"logprob\": -1.2978516,\n          \"special\": false,\n          \"text\": \"1\"\n        },\n        {\n          \"id\": 28734,\n          \"logprob\": -2.0664062,\n          \"special\": false,\n          \"text\": \"0\"\n        },\n        {\n          \"id\": 387,\n          \"logprob\": -1.9560547,\n          \"special\": false,\n          \"text\": \" -\"\n        },\n        {\n          \"id\": 28705,\n          \"logprob\": -0.5078125,\n          \"special\": false,\n          \"text\": \" \"\n        },\n        {\n          \"id\": 28740,\n          \"logprob\": -1.1787109,\n          \"special\": false,\n          \"text\": \"1\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \": Let n = 10 - 1\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 28747,\n          \"logprob\": -0.55078125,\n          \"special\": false,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 3169,\n          \"logprob\": -1.4140625,\n          \"special\": false,\n          \"text\": \" Let\"\n        },\n        {\n          \"id\": 307,\n          \"logprob\": -3.0273438,\n          \"special\": false,\n          \"text\": \" n\"\n        },\n        {\n          \"id\": 327,\n          \"logprob\": -0.94140625,\n          \"special\": false,\n          \"text\": \" =\"\n        },\n        {\n     
     \"id\": 28705,\n          \"logprob\": -0.8173828,\n          \"special\": false,\n          \"text\": \" \"\n        },\n        {\n          \"id\": 28740,\n          \"logprob\": -1.2978516,\n          \"special\": false,\n          \"text\": \"1\"\n        },\n        {\n          \"id\": 28734,\n          \"logprob\": -2.0664062,\n          \"special\": false,\n          \"text\": \"0\"\n        },\n        {\n          \"id\": 387,\n          \"logprob\": -1.9560547,\n          \"special\": false,\n          \"text\": \" -\"\n        },\n        {\n          \"id\": 28705,\n          \"logprob\": -0.5078125,\n          \"special\": false,\n          \"text\": \" \"\n        },\n        {\n          \"id\": 28740,\n          \"logprob\": -1.1787109,\n          \"special\": false,\n          \"text\": \"1\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \": Let n = 10 - 1\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_mixtral/test_flash_mixtral.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 20910,\n        \"logprob\": -0.96484375,\n        \"special\": false,\n        \"text\": \"Grad\"\n      },\n      {\n        \"id\": 722,\n        \"logprob\": -0.003168106,\n        \"special\": false,\n        \"text\": \"ient\"\n      },\n      {\n        \"id\": 24871,\n        \"logprob\": -0.16540527,\n        \"special\": false,\n        \"text\": \" descent\"\n      },\n      {\n        \"id\": 349,\n        \"logprob\": -0.08886719,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 396,\n        \"logprob\": -0.75878906,\n        \"special\": false,\n        \"text\": \" an\"\n      },\n      {\n        \"id\": 18586,\n        \"logprob\": -0.5703125,\n        \"special\": false,\n        \"text\": \" optimization\"\n      },\n      {\n        \"id\": 9464,\n        \"logprob\": -0.11242676,\n        \"special\": false,\n        \"text\": \" algorithm\"\n      },\n      {\n        \"id\": 1307,\n        \"logprob\": -0.7939453,\n        \"special\": false,\n        \"text\": \" used\"\n      },\n      {\n        \"id\": 298,\n        \"logprob\": -0.17102051,\n        \"special\": false,\n        \"text\": \" to\"\n      },\n      {\n        \"id\": 26518,\n        \"logprob\": -0.34326172,\n        \"special\": false,\n        \"text\": \" minimize\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"Gradient descent is an optimization algorithm used to minimize\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_mixtral/test_flash_mixtral_all_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 1313,\n        \"logprob\": -2.3613281,\n        \"special\": false,\n        \"text\": \"It\"\n      },\n      {\n        \"id\": 3969,\n        \"logprob\": -0.7285156,\n        \"special\": false,\n        \"text\": \" seems\"\n      },\n      {\n        \"id\": 298,\n        \"logprob\": -1.3466797,\n        \"special\": false,\n        \"text\": \" to\"\n      },\n      {\n        \"id\": 528,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" me\"\n      },\n      {\n        \"id\": 28725,\n        \"logprob\": -1.6757812,\n        \"special\": false,\n        \"text\": \",\"\n      },\n      {\n        \"id\": 369,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" that\"\n      },\n      {\n        \"id\": 513,\n        \"logprob\": -1.1269531,\n        \"special\": false,\n        \"text\": \" if\"\n      },\n      {\n        \"id\": 368,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" you\"\n      },\n      {\n        \"id\": 28742,\n        \"logprob\": -2.4921875,\n        \"special\": false,\n        \"text\": \"'\"\n      },\n      {\n        \"id\": 267,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"re\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"What is gradient descent?\\n\\nIt seems to me, that if you're\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_mixtral/test_flash_mixtral_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 20910,\n          \"logprob\": -0.96484375,\n          \"special\": false,\n          \"text\": \"Grad\"\n        },\n        {\n          \"id\": 722,\n          \"logprob\": -0.003168106,\n          \"special\": false,\n          \"text\": \"ient\"\n        },\n        {\n          \"id\": 24871,\n          \"logprob\": -0.16369629,\n          \"special\": false,\n          \"text\": \" descent\"\n        },\n        {\n          \"id\": 349,\n          \"logprob\": -0.0881958,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 396,\n          \"logprob\": -0.76708984,\n          \"special\": false,\n          \"text\": \" an\"\n        },\n        {\n          \"id\": 18586,\n          \"logprob\": -0.57373047,\n          \"special\": false,\n          \"text\": \" optimization\"\n        },\n        {\n          \"id\": 9464,\n          \"logprob\": -0.11291504,\n          \"special\": false,\n          \"text\": \" algorithm\"\n        },\n        {\n          \"id\": 1307,\n          \"logprob\": -0.79589844,\n          \"special\": false,\n          \"text\": \" used\"\n        },\n        {\n          \"id\": 298,\n          \"logprob\": -0.1694336,\n          \"special\": false,\n          \"text\": \" to\"\n        },\n        {\n          \"id\": 26518,\n          \"logprob\": -0.34350586,\n          \"special\": false,\n          \"text\": \" minimize\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"Gradient descent is an optimization algorithm used to minimize\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n 
     \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 20910,\n          \"logprob\": -0.9628906,\n          \"special\": false,\n          \"text\": \"Grad\"\n        },\n        {\n          \"id\": 722,\n          \"logprob\": -0.0032176971,\n          \"special\": false,\n          \"text\": \"ient\"\n        },\n        {\n          \"id\": 24871,\n          \"logprob\": -0.16540527,\n          \"special\": false,\n          \"text\": \" descent\"\n        },\n        {\n          \"id\": 349,\n          \"logprob\": -0.08898926,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 396,\n          \"logprob\": -0.765625,\n          \"special\": false,\n          \"text\": \" an\"\n        },\n        {\n          \"id\": 18586,\n          \"logprob\": -0.5708008,\n          \"special\": false,\n          \"text\": \" optimization\"\n        },\n        {\n          \"id\": 9464,\n          \"logprob\": -0.11401367,\n          \"special\": false,\n          \"text\": \" algorithm\"\n        },\n        {\n          \"id\": 1307,\n          \"logprob\": -0.7963867,\n          \"special\": false,\n          \"text\": \" used\"\n        },\n        {\n          \"id\": 298,\n          \"logprob\": -0.17028809,\n          \"special\": false,\n          \"text\": \" to\"\n        },\n        {\n          \"id\": 26518,\n          \"logprob\": -0.34326172,\n          \"special\": false,\n          \"text\": \" minimize\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"Gradient descent is an optimization algorithm used to minimize\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 20910,\n          \"logprob\": -0.9580078,\n          \"special\": false,\n          \"text\": \"Grad\"\n 
       },\n        {\n          \"id\": 722,\n          \"logprob\": -0.0032176971,\n          \"special\": false,\n          \"text\": \"ient\"\n        },\n        {\n          \"id\": 24871,\n          \"logprob\": -0.16552734,\n          \"special\": false,\n          \"text\": \" descent\"\n        },\n        {\n          \"id\": 349,\n          \"logprob\": -0.08874512,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 396,\n          \"logprob\": -0.75878906,\n          \"special\": false,\n          \"text\": \" an\"\n        },\n        {\n          \"id\": 18586,\n          \"logprob\": -0.5703125,\n          \"special\": false,\n          \"text\": \" optimization\"\n        },\n        {\n          \"id\": 9464,\n          \"logprob\": -0.11236572,\n          \"special\": false,\n          \"text\": \" algorithm\"\n        },\n        {\n          \"id\": 1307,\n          \"logprob\": -0.79541016,\n          \"special\": false,\n          \"text\": \" used\"\n        },\n        {\n          \"id\": 298,\n          \"logprob\": -0.17102051,\n          \"special\": false,\n          \"text\": \" to\"\n        },\n        {\n          \"id\": 26518,\n          \"logprob\": -0.34326172,\n          \"special\": false,\n          \"text\": \" minimize\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"Gradient descent is an optimization algorithm used to minimize\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 20910,\n          \"logprob\": -0.9609375,\n          \"special\": false,\n          \"text\": \"Grad\"\n        },\n        {\n          \"id\": 722,\n          \"logprob\": -0.003168106,\n          \"special\": false,\n          \"text\": \"ient\"\n        },\n        {\n       
   \"id\": 24871,\n          \"logprob\": -0.16601562,\n          \"special\": false,\n          \"text\": \" descent\"\n        },\n        {\n          \"id\": 349,\n          \"logprob\": -0.088134766,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 396,\n          \"logprob\": -0.7597656,\n          \"special\": false,\n          \"text\": \" an\"\n        },\n        {\n          \"id\": 18586,\n          \"logprob\": -0.5708008,\n          \"special\": false,\n          \"text\": \" optimization\"\n        },\n        {\n          \"id\": 9464,\n          \"logprob\": -0.11291504,\n          \"special\": false,\n          \"text\": \" algorithm\"\n        },\n        {\n          \"id\": 1307,\n          \"logprob\": -0.7944336,\n          \"special\": false,\n          \"text\": \" used\"\n        },\n        {\n          \"id\": 298,\n          \"logprob\": -0.17102051,\n          \"special\": false,\n          \"text\": \" to\"\n        },\n        {\n          \"id\": 26518,\n          \"logprob\": -0.34399414,\n          \"special\": false,\n          \"text\": \" minimize\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"Gradient descent is an optimization algorithm used to minimize\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_mixtral_awq/test_flash_mixtral_awq.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 13,\n        \"logprob\": -0.50878906,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.8876953,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 23229,\n        \"logprob\": -0.15124512,\n        \"special\": false,\n        \"text\": \"Deep\"\n      },\n      {\n        \"id\": 5168,\n        \"logprob\": -0.030288696,\n        \"special\": false,\n        \"text\": \" learning\"\n      },\n      {\n        \"id\": 349,\n        \"logprob\": -0.16687012,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 264,\n        \"logprob\": -0.17858887,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 19804,\n        \"logprob\": -0.8046875,\n        \"special\": false,\n        \"text\": \" subset\"\n      },\n      {\n        \"id\": 302,\n        \"logprob\": -0.007205963,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 5599,\n        \"logprob\": -0.090026855,\n        \"special\": false,\n        \"text\": \" machine\"\n      },\n      {\n        \"id\": 5168,\n        \"logprob\": -0.0030670166,\n        \"special\": false,\n        \"text\": \" learning\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"\\n\\nDeep learning is a subset of machine learning\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_mixtral_awq/test_flash_mixtral_awq_all_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 13,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 23229,\n        \"logprob\": -0.5229492,\n        \"special\": false,\n        \"text\": \"Deep\"\n      },\n      {\n        \"id\": 17504,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" Learning\"\n      },\n      {\n        \"id\": 349,\n        \"logprob\": -0.5151367,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 264,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 19804,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" subset\"\n      },\n      {\n        \"id\": 302,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 13253,\n        \"logprob\": -1.3359375,\n        \"special\": false,\n        \"text\": \" Machine\"\n      },\n      {\n        \"id\": 17504,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" Learning\"\n      },\n      {\n        \"id\": 28725,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \",\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"What is deep learning?\\nDeep Learning is a subset of Machine Learning,\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_mixtral_awq/test_flash_mixtral_awq_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 13,\n          \"logprob\": -0.50878906,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -0.8876953,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 23229,\n          \"logprob\": -0.15136719,\n          \"special\": false,\n          \"text\": \"Deep\"\n        },\n        {\n          \"id\": 5168,\n          \"logprob\": -0.030273438,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 349,\n          \"logprob\": -0.1665039,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.1776123,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 19804,\n          \"logprob\": -0.8076172,\n          \"special\": false,\n          \"text\": \" subset\"\n        },\n        {\n          \"id\": 302,\n          \"logprob\": -0.007183075,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 5599,\n          \"logprob\": -0.090148926,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 5168,\n          \"logprob\": -0.0030670166,\n          \"special\": false,\n          \"text\": \" learning\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n\\nDeep learning is a subset of machine learning\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n   
   \"tokens\": [\n        {\n          \"id\": 13,\n          \"logprob\": -0.51220703,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -0.87402344,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 23229,\n          \"logprob\": -0.15039062,\n          \"special\": false,\n          \"text\": \"Deep\"\n        },\n        {\n          \"id\": 5168,\n          \"logprob\": -0.030288696,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 349,\n          \"logprob\": -0.1652832,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.17858887,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 19804,\n          \"logprob\": -0.81103516,\n          \"special\": false,\n          \"text\": \" subset\"\n        },\n        {\n          \"id\": 302,\n          \"logprob\": -0.007183075,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 5599,\n          \"logprob\": -0.08880615,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 5168,\n          \"logprob\": -0.0030612946,\n          \"special\": false,\n          \"text\": \" learning\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n\\nDeep learning is a subset of machine learning\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 13,\n          \"logprob\": -0.51220703,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 13,\n    
      \"logprob\": -0.87402344,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 23229,\n          \"logprob\": -0.15039062,\n          \"special\": false,\n          \"text\": \"Deep\"\n        },\n        {\n          \"id\": 5168,\n          \"logprob\": -0.030288696,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 349,\n          \"logprob\": -0.1652832,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.17858887,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 19804,\n          \"logprob\": -0.81103516,\n          \"special\": false,\n          \"text\": \" subset\"\n        },\n        {\n          \"id\": 302,\n          \"logprob\": -0.007183075,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 5599,\n          \"logprob\": -0.08880615,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 5168,\n          \"logprob\": -0.0030612946,\n          \"special\": false,\n          \"text\": \" learning\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n\\nDeep learning is a subset of machine learning\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 13,\n          \"logprob\": -0.51220703,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -0.87402344,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 23229,\n          \"logprob\": -0.15039062,\n          
\"special\": false,\n          \"text\": \"Deep\"\n        },\n        {\n          \"id\": 5168,\n          \"logprob\": -0.030288696,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 349,\n          \"logprob\": -0.1652832,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.17858887,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 19804,\n          \"logprob\": -0.81103516,\n          \"special\": false,\n          \"text\": \" subset\"\n        },\n        {\n          \"id\": 302,\n          \"logprob\": -0.007183075,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 5599,\n          \"logprob\": -0.08880615,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 5168,\n          \"logprob\": -0.0030612946,\n          \"special\": false,\n          \"text\": \" learning\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n\\nDeep learning is a subset of machine learning\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_mixtral_gptq/test_flash_mixtral_gptq.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 13,\n        \"logprob\": -0.6953125,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.4777832,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 23229,\n        \"logprob\": -0.13256836,\n        \"special\": false,\n        \"text\": \"Deep\"\n      },\n      {\n        \"id\": 5168,\n        \"logprob\": -0.023849487,\n        \"special\": false,\n        \"text\": \" learning\"\n      },\n      {\n        \"id\": 349,\n        \"logprob\": -0.13977051,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 264,\n        \"logprob\": -0.14489746,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 19804,\n        \"logprob\": -0.63183594,\n        \"special\": false,\n        \"text\": \" subset\"\n      },\n      {\n        \"id\": 302,\n        \"logprob\": -0.010314941,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 5599,\n        \"logprob\": -0.0635376,\n        \"special\": false,\n        \"text\": \" machine\"\n      },\n      {\n        \"id\": 5168,\n        \"logprob\": -0.0028572083,\n        \"special\": false,\n        \"text\": \" learning\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"\\n\\nDeep learning is a subset of machine learning\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_mixtral_gptq/test_flash_mixtral_gptq_all_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 13,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 23229,\n        \"logprob\": -0.18237305,\n        \"special\": false,\n        \"text\": \"Deep\"\n      },\n      {\n        \"id\": 17504,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" Learning\"\n      },\n      {\n        \"id\": 349,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 264,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 19804,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" subset\"\n      },\n      {\n        \"id\": 302,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 13253,\n        \"logprob\": -0.6040039,\n        \"special\": false,\n        \"text\": \" Machine\"\n      },\n      {\n        \"id\": 17504,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" Learning\"\n      },\n      {\n        \"id\": 28725,\n        \"logprob\": -0.11621094,\n        \"special\": false,\n        \"text\": \",\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"What is deep learning?\\nDeep Learning is a subset of Machine Learning,\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_mixtral_gptq/test_flash_mixtral_gptq_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 13,\n          \"logprob\": -0.6953125,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -0.4777832,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 23229,\n          \"logprob\": -0.13232422,\n          \"special\": false,\n          \"text\": \"Deep\"\n        },\n        {\n          \"id\": 5168,\n          \"logprob\": -0.023834229,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 349,\n          \"logprob\": -0.13977051,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.14416504,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 19804,\n          \"logprob\": -0.63183594,\n          \"special\": false,\n          \"text\": \" subset\"\n        },\n        {\n          \"id\": 302,\n          \"logprob\": -0.010223389,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 5599,\n          \"logprob\": -0.064208984,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 5168,\n          \"logprob\": -0.0028266907,\n          \"special\": false,\n          \"text\": \" learning\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n\\nDeep learning is a subset of machine learning\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n 
     \"tokens\": [\n        {\n          \"id\": 13,\n          \"logprob\": -0.6953125,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -0.48339844,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 23229,\n          \"logprob\": -0.13256836,\n          \"special\": false,\n          \"text\": \"Deep\"\n        },\n        {\n          \"id\": 5168,\n          \"logprob\": -0.02420044,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 349,\n          \"logprob\": -0.13977051,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.14501953,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 19804,\n          \"logprob\": -0.63134766,\n          \"special\": false,\n          \"text\": \" subset\"\n        },\n        {\n          \"id\": 302,\n          \"logprob\": -0.010223389,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 5599,\n          \"logprob\": -0.06427002,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 5168,\n          \"logprob\": -0.002817154,\n          \"special\": false,\n          \"text\": \" learning\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n\\nDeep learning is a subset of machine learning\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 13,\n          \"logprob\": -0.6953125,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 13,\n     
     \"logprob\": -0.48339844,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 23229,\n          \"logprob\": -0.13256836,\n          \"special\": false,\n          \"text\": \"Deep\"\n        },\n        {\n          \"id\": 5168,\n          \"logprob\": -0.02420044,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 349,\n          \"logprob\": -0.13977051,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.14501953,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 19804,\n          \"logprob\": -0.63134766,\n          \"special\": false,\n          \"text\": \" subset\"\n        },\n        {\n          \"id\": 302,\n          \"logprob\": -0.010223389,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 5599,\n          \"logprob\": -0.06427002,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 5168,\n          \"logprob\": -0.002817154,\n          \"special\": false,\n          \"text\": \" learning\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n\\nDeep learning is a subset of machine learning\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 13,\n          \"logprob\": -0.6953125,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -0.48339844,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 23229,\n          \"logprob\": -0.13256836,\n          \"special\": 
false,\n          \"text\": \"Deep\"\n        },\n        {\n          \"id\": 5168,\n          \"logprob\": -0.02420044,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 349,\n          \"logprob\": -0.13977051,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.14501953,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 19804,\n          \"logprob\": -0.63134766,\n          \"special\": false,\n          \"text\": \" subset\"\n        },\n        {\n          \"id\": 302,\n          \"logprob\": -0.010223389,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 5599,\n          \"logprob\": -0.06427002,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 5168,\n          \"logprob\": -0.002817154,\n          \"special\": false,\n          \"text\": \" learning\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n\\nDeep learning is a subset of machine learning\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_neox/test_flash_neox.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 42,\n        \"logprob\": -0.88378906,\n        \"special\": false,\n        \"text\": \"I\"\n      },\n      {\n        \"id\": 1353,\n        \"logprob\": -0.94921875,\n        \"special\": false,\n        \"text\": \"'m\"\n      },\n      {\n        \"id\": 417,\n        \"logprob\": -2.2402344,\n        \"special\": false,\n        \"text\": \" not\"\n      },\n      {\n        \"id\": 2119,\n        \"logprob\": -0.3725586,\n        \"special\": false,\n        \"text\": \" sure\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -1.078125,\n        \"special\": false,\n        \"text\": \",\"\n      },\n      {\n        \"id\": 534,\n        \"logprob\": -0.67822266,\n        \"special\": false,\n        \"text\": \" which\"\n      },\n      {\n        \"id\": 310,\n        \"logprob\": -1.3837891,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 253,\n        \"logprob\": -1.7050781,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 1682,\n        \"logprob\": -0.052001953,\n        \"special\": false,\n        \"text\": \" best\"\n      },\n      {\n        \"id\": 1039,\n        \"logprob\": -2.0390625,\n        \"special\": false,\n        \"text\": \" way\"\n      }\n    ]\n  },\n  \"generated_text\": \"I'm not sure, which is the best way\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_neox/test_flash_neox_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 42,\n          \"logprob\": -0.8886719,\n          \"special\": false,\n          \"text\": \"I\"\n        },\n        {\n          \"id\": 1353,\n          \"logprob\": -0.98046875,\n          \"special\": false,\n          \"text\": \"'m\"\n        },\n        {\n          \"id\": 417,\n          \"logprob\": -2.2265625,\n          \"special\": false,\n          \"text\": \" not\"\n        },\n        {\n          \"id\": 2119,\n          \"logprob\": -0.3479004,\n          \"special\": false,\n          \"text\": \" sure\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -1.0117188,\n          \"special\": false,\n          \"text\": \",\"\n        },\n        {\n          \"id\": 534,\n          \"logprob\": -0.67871094,\n          \"special\": false,\n          \"text\": \" which\"\n        },\n        {\n          \"id\": 310,\n          \"logprob\": -1.421875,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 253,\n          \"logprob\": -1.7382812,\n          \"special\": false,\n          \"text\": \" the\"\n        },\n        {\n          \"id\": 1682,\n          \"logprob\": -0.051330566,\n          \"special\": false,\n          \"text\": \" best\"\n        },\n        {\n          \"id\": 1039,\n          \"logprob\": -2.0390625,\n          \"special\": false,\n          \"text\": \" way\"\n        }\n      ]\n    },\n    \"generated_text\": \"I'm not sure, which is the best way\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 42,\n          
\"logprob\": -0.88378906,\n          \"special\": false,\n          \"text\": \"I\"\n        },\n        {\n          \"id\": 1353,\n          \"logprob\": -0.9819336,\n          \"special\": false,\n          \"text\": \"'m\"\n        },\n        {\n          \"id\": 417,\n          \"logprob\": -2.2421875,\n          \"special\": false,\n          \"text\": \" not\"\n        },\n        {\n          \"id\": 2119,\n          \"logprob\": -0.3474121,\n          \"special\": false,\n          \"text\": \" sure\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -1.078125,\n          \"special\": false,\n          \"text\": \",\"\n        },\n        {\n          \"id\": 534,\n          \"logprob\": -0.69140625,\n          \"special\": false,\n          \"text\": \" which\"\n        },\n        {\n          \"id\": 310,\n          \"logprob\": -1.4072266,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 253,\n          \"logprob\": -1.7041016,\n          \"special\": false,\n          \"text\": \" the\"\n        },\n        {\n          \"id\": 1682,\n          \"logprob\": -0.053375244,\n          \"special\": false,\n          \"text\": \" best\"\n        },\n        {\n          \"id\": 1039,\n          \"logprob\": -2.0351562,\n          \"special\": false,\n          \"text\": \" way\"\n        }\n      ]\n    },\n    \"generated_text\": \"I'm not sure, which is the best way\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 42,\n          \"logprob\": -0.8886719,\n          \"special\": false,\n          \"text\": \"I\"\n        },\n        {\n          \"id\": 1353,\n          \"logprob\": -0.98046875,\n          \"special\": false,\n          \"text\": \"'m\"\n        },\n        {\n          \"id\": 
417,\n          \"logprob\": -2.2265625,\n          \"special\": false,\n          \"text\": \" not\"\n        },\n        {\n          \"id\": 2119,\n          \"logprob\": -0.3479004,\n          \"special\": false,\n          \"text\": \" sure\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -1.0117188,\n          \"special\": false,\n          \"text\": \",\"\n        },\n        {\n          \"id\": 534,\n          \"logprob\": -0.67871094,\n          \"special\": false,\n          \"text\": \" which\"\n        },\n        {\n          \"id\": 310,\n          \"logprob\": -1.421875,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 253,\n          \"logprob\": -1.7382812,\n          \"special\": false,\n          \"text\": \" the\"\n        },\n        {\n          \"id\": 1682,\n          \"logprob\": -0.051330566,\n          \"special\": false,\n          \"text\": \" best\"\n        },\n        {\n          \"id\": 1039,\n          \"logprob\": -2.0390625,\n          \"special\": false,\n          \"text\": \" way\"\n        }\n      ]\n    },\n    \"generated_text\": \"I'm not sure, which is the best way\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 42,\n          \"logprob\": -0.8886719,\n          \"special\": false,\n          \"text\": \"I\"\n        },\n        {\n          \"id\": 1353,\n          \"logprob\": -0.98046875,\n          \"special\": false,\n          \"text\": \"'m\"\n        },\n        {\n          \"id\": 417,\n          \"logprob\": -2.2265625,\n          \"special\": false,\n          \"text\": \" not\"\n        },\n        {\n          \"id\": 2119,\n          \"logprob\": -0.3479004,\n          \"special\": false,\n          \"text\": \" sure\"\n        },\n        {\n 
         \"id\": 13,\n          \"logprob\": -1.0117188,\n          \"special\": false,\n          \"text\": \",\"\n        },\n        {\n          \"id\": 534,\n          \"logprob\": -0.67871094,\n          \"special\": false,\n          \"text\": \" which\"\n        },\n        {\n          \"id\": 310,\n          \"logprob\": -1.421875,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 253,\n          \"logprob\": -1.7382812,\n          \"special\": false,\n          \"text\": \" the\"\n        },\n        {\n          \"id\": 1682,\n          \"logprob\": -0.051330566,\n          \"special\": false,\n          \"text\": \" best\"\n        },\n        {\n          \"id\": 1039,\n          \"logprob\": -2.0390625,\n          \"special\": false,\n          \"text\": \" way\"\n        }\n      ]\n    },\n    \"generated_text\": \"I'm not sure, which is the best way\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_neox_sharded/test_flash_neox.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 510,\n        \"logprob\": -0.63183594,\n        \"special\": false,\n        \"text\": \"The\"\n      },\n      {\n        \"id\": 3159,\n        \"logprob\": -0.5390625,\n        \"special\": false,\n        \"text\": \" word\"\n      },\n      {\n        \"id\": 346,\n        \"logprob\": -0.045684814,\n        \"special\": false,\n        \"text\": \" \\\"\"\n      },\n      {\n        \"id\": 6441,\n        \"logprob\": -0.002090454,\n        \"special\": false,\n        \"text\": \"mem\"\n      },\n      {\n        \"id\": 70,\n        \"logprob\": -1.3589859e-05,\n        \"special\": false,\n        \"text\": \"e\"\n      },\n      {\n        \"id\": 3,\n        \"logprob\": -0.0009455681,\n        \"special\": false,\n        \"text\": \"\\\"\"\n      },\n      {\n        \"id\": 369,\n        \"logprob\": -0.088012695,\n        \"special\": false,\n        \"text\": \" was\"\n      },\n      {\n        \"id\": 806,\n        \"logprob\": -0.12585449,\n        \"special\": false,\n        \"text\": \" first\"\n      },\n      {\n        \"id\": 908,\n        \"logprob\": -0.017196655,\n        \"special\": false,\n        \"text\": \" used\"\n      },\n      {\n        \"id\": 275,\n        \"logprob\": -0.49731445,\n        \"special\": false,\n        \"text\": \" in\"\n      }\n    ]\n  },\n  \"generated_text\": \"The word \\\"meme\\\" was first used in\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_neox_sharded/test_flash_neox_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 510,\n          \"logprob\": -0.63183594,\n          \"special\": false,\n          \"text\": \"The\"\n        },\n        {\n          \"id\": 3159,\n          \"logprob\": -0.5488281,\n          \"special\": false,\n          \"text\": \" word\"\n        },\n        {\n          \"id\": 346,\n          \"logprob\": -0.045684814,\n          \"special\": false,\n          \"text\": \" \\\"\"\n        },\n        {\n          \"id\": 6441,\n          \"logprob\": -0.00207901,\n          \"special\": false,\n          \"text\": \"mem\"\n        },\n        {\n          \"id\": 70,\n          \"logprob\": -1.335144e-05,\n          \"special\": false,\n          \"text\": \"e\"\n        },\n        {\n          \"id\": 3,\n          \"logprob\": -0.00097227097,\n          \"special\": false,\n          \"text\": \"\\\"\"\n        },\n        {\n          \"id\": 369,\n          \"logprob\": -0.0892334,\n          \"special\": false,\n          \"text\": \" was\"\n        },\n        {\n          \"id\": 806,\n          \"logprob\": -0.12463379,\n          \"special\": false,\n          \"text\": \" first\"\n        },\n        {\n          \"id\": 908,\n          \"logprob\": -0.01737976,\n          \"special\": false,\n          \"text\": \" used\"\n        },\n        {\n          \"id\": 275,\n          \"logprob\": -0.50341797,\n          \"special\": false,\n          \"text\": \" in\"\n        }\n      ]\n    },\n    \"generated_text\": \"The word \\\"meme\\\" was first used in\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 510,\n   
       \"logprob\": -0.63183594,\n          \"special\": false,\n          \"text\": \"The\"\n        },\n        {\n          \"id\": 3159,\n          \"logprob\": -0.5488281,\n          \"special\": false,\n          \"text\": \" word\"\n        },\n        {\n          \"id\": 346,\n          \"logprob\": -0.045684814,\n          \"special\": false,\n          \"text\": \" \\\"\"\n        },\n        {\n          \"id\": 6441,\n          \"logprob\": -0.00207901,\n          \"special\": false,\n          \"text\": \"mem\"\n        },\n        {\n          \"id\": 70,\n          \"logprob\": -1.335144e-05,\n          \"special\": false,\n          \"text\": \"e\"\n        },\n        {\n          \"id\": 3,\n          \"logprob\": -0.00097227097,\n          \"special\": false,\n          \"text\": \"\\\"\"\n        },\n        {\n          \"id\": 369,\n          \"logprob\": -0.0892334,\n          \"special\": false,\n          \"text\": \" was\"\n        },\n        {\n          \"id\": 806,\n          \"logprob\": -0.12463379,\n          \"special\": false,\n          \"text\": \" first\"\n        },\n        {\n          \"id\": 908,\n          \"logprob\": -0.01737976,\n          \"special\": false,\n          \"text\": \" used\"\n        },\n        {\n          \"id\": 275,\n          \"logprob\": -0.50341797,\n          \"special\": false,\n          \"text\": \" in\"\n        }\n      ]\n    },\n    \"generated_text\": \"The word \\\"meme\\\" was first used in\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 510,\n          \"logprob\": -0.63183594,\n          \"special\": false,\n          \"text\": \"The\"\n        },\n        {\n          \"id\": 3159,\n          \"logprob\": -0.5488281,\n          \"special\": false,\n          \"text\": \" word\"\n        },\n     
   {\n          \"id\": 346,\n          \"logprob\": -0.045684814,\n          \"special\": false,\n          \"text\": \" \\\"\"\n        },\n        {\n          \"id\": 6441,\n          \"logprob\": -0.00207901,\n          \"special\": false,\n          \"text\": \"mem\"\n        },\n        {\n          \"id\": 70,\n          \"logprob\": -1.335144e-05,\n          \"special\": false,\n          \"text\": \"e\"\n        },\n        {\n          \"id\": 3,\n          \"logprob\": -0.00097227097,\n          \"special\": false,\n          \"text\": \"\\\"\"\n        },\n        {\n          \"id\": 369,\n          \"logprob\": -0.0892334,\n          \"special\": false,\n          \"text\": \" was\"\n        },\n        {\n          \"id\": 806,\n          \"logprob\": -0.12463379,\n          \"special\": false,\n          \"text\": \" first\"\n        },\n        {\n          \"id\": 908,\n          \"logprob\": -0.01737976,\n          \"special\": false,\n          \"text\": \" used\"\n        },\n        {\n          \"id\": 275,\n          \"logprob\": -0.50341797,\n          \"special\": false,\n          \"text\": \" in\"\n        }\n      ]\n    },\n    \"generated_text\": \"The word \\\"meme\\\" was first used in\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 510,\n          \"logprob\": -0.63183594,\n          \"special\": false,\n          \"text\": \"The\"\n        },\n        {\n          \"id\": 3159,\n          \"logprob\": -0.5488281,\n          \"special\": false,\n          \"text\": \" word\"\n        },\n        {\n          \"id\": 346,\n          \"logprob\": -0.045684814,\n          \"special\": false,\n          \"text\": \" \\\"\"\n        },\n        {\n          \"id\": 6441,\n          \"logprob\": -0.00207901,\n          \"special\": false,\n          
\"text\": \"mem\"\n        },\n        {\n          \"id\": 70,\n          \"logprob\": -1.335144e-05,\n          \"special\": false,\n          \"text\": \"e\"\n        },\n        {\n          \"id\": 3,\n          \"logprob\": -0.00097227097,\n          \"special\": false,\n          \"text\": \"\\\"\"\n        },\n        {\n          \"id\": 369,\n          \"logprob\": -0.0892334,\n          \"special\": false,\n          \"text\": \" was\"\n        },\n        {\n          \"id\": 806,\n          \"logprob\": -0.12463379,\n          \"special\": false,\n          \"text\": \" first\"\n        },\n        {\n          \"id\": 908,\n          \"logprob\": -0.01737976,\n          \"special\": false,\n          \"text\": \" used\"\n        },\n        {\n          \"id\": 275,\n          \"logprob\": -0.50341797,\n          \"special\": false,\n          \"text\": \" in\"\n        }\n      ]\n    },\n    \"generated_text\": \"The word \\\"meme\\\" was first used in\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_pali_gemma/test_flash_pali_gemma.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"eos_token\",\n    \"generated_tokens\": 2,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 54901,\n        \"logprob\": -0.84765625,\n        \"special\": false,\n        \"text\": \"beach\"\n      },\n      {\n        \"id\": 1,\n        \"logprob\": -0.008666992,\n        \"special\": true,\n        \"text\": \"<eos>\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"beach\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_pali_gemma/test_flash_pali_gemma_two_images.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"eos_token\",\n    \"generated_tokens\": 8,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 2502,\n        \"logprob\": -1.7890625,\n        \"special\": false,\n        \"text\": \"image\"\n      },\n      {\n        \"id\": 2196,\n        \"logprob\": -0.53125,\n        \"special\": false,\n        \"text\": \" result\"\n      },\n      {\n        \"id\": 604,\n        \"logprob\": -0.0077209473,\n        \"special\": false,\n        \"text\": \" for\"\n      },\n      {\n        \"id\": 12254,\n        \"logprob\": -1.703125,\n        \"special\": false,\n        \"text\": \" chicken\"\n      },\n      {\n        \"id\": 611,\n        \"logprob\": -0.21582031,\n        \"special\": false,\n        \"text\": \" on\"\n      },\n      {\n        \"id\": 573,\n        \"logprob\": -0.734375,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 8318,\n        \"logprob\": -0.026000977,\n        \"special\": false,\n        \"text\": \" beach\"\n      },\n      {\n        \"id\": 1,\n        \"logprob\": -0.2109375,\n        \"special\": true,\n        \"text\": \"<eos>\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"image result for chicken on the beach\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_pali_gemma2/test_flash_pali_gemma_image.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 20,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 108,\n        \"logprob\": -0.48046875,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 30234,\n        \"logprob\": -2.21875,\n        \"special\": false,\n        \"text\": \"Brown\"\n      },\n      {\n        \"id\": 108,\n        \"logprob\": -0.119140625,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 3726,\n        \"logprob\": -1.703125,\n        \"special\": false,\n        \"text\": \"Car\"\n      },\n      {\n        \"id\": 108,\n        \"logprob\": -0.0390625,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 2915,\n        \"logprob\": -1.8203125,\n        \"special\": false,\n        \"text\": \"Color\"\n      },\n      {\n        \"id\": 108,\n        \"logprob\": -0.035888672,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 19178,\n        \"logprob\": -2.015625,\n        \"special\": false,\n        \"text\": \"Cool\"\n      },\n      {\n        \"id\": 108,\n        \"logprob\": -0.08105469,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 40544,\n        \"logprob\": -2.09375,\n        \"special\": false,\n        \"text\": \"Decor\"\n      },\n      {\n        \"id\": 108,\n        \"logprob\": -0.038330078,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 108,\n        \"logprob\": -1.515625,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 108,\n        \"logprob\": -1.8671875,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 108,\n        \"logprob\": -1.6328125,\n        
\"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 108,\n        \"logprob\": -1.265625,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 108,\n        \"logprob\": -1.0078125,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 108,\n        \"logprob\": -1.03125,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 235336,\n        \"logprob\": -1.2109375,\n        \"special\": false,\n        \"text\": \"?\"\n      },\n      {\n        \"id\": 108,\n        \"logprob\": -0.29101562,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 235336,\n        \"logprob\": -0.08935547,\n        \"special\": false,\n        \"text\": \"?\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"\\nBrown\\nCar\\nColor\\nCool\\nDecor\\n\\n\\n\\n\\n\\n\\n?\\n?\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_phi/test_flash_phi.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 25,\n        \"logprob\": -2.3203125,\n        \"special\": false,\n        \"text\": \":\"\n      },\n      {\n        \"id\": 1391,\n        \"logprob\": -0.98779297,\n        \"special\": false,\n        \"text\": \" {\"\n      },\n      {\n        \"id\": 25927,\n        \"logprob\": -0.76660156,\n        \"special\": false,\n        \"text\": \"request\"\n      },\n      {\n        \"id\": 92,\n        \"logprob\": -0.7246094,\n        \"special\": false,\n        \"text\": \"}\"\n      },\n      {\n        \"id\": 4943,\n        \"logprob\": -0.41333008,\n        \"special\": false,\n        \"text\": \"\\\")\"\n      },\n      {\n        \"id\": 198,\n        \"logprob\": -0.11785889,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 50280,\n        \"logprob\": -0.97265625,\n        \"special\": false,\n        \"text\": \"        \"\n      },\n      {\n        \"id\": 26209,\n        \"logprob\": -1.4414062,\n        \"special\": false,\n        \"text\": \"response\"\n      },\n      {\n        \"id\": 796,\n        \"logprob\": -0.0569458,\n        \"special\": false,\n        \"text\": \" =\"\n      },\n      {\n        \"id\": 2116,\n        \"logprob\": -1.1533203,\n        \"special\": false,\n        \"text\": \" self\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \": {request}\\\")\\n        response = self\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_phi/test_flash_phi_all_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"stop_sequence\",\n    \"generated_tokens\": 6,\n    \"prefill\": [],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 284,\n        \"logprob\": -0.28955078,\n        \"special\": false,\n        \"text\": \" to\"\n      },\n      {\n        \"id\": 3758,\n        \"logprob\": -0.7739258,\n        \"special\": false,\n        \"text\": \" send\"\n      },\n      {\n        \"id\": 1366,\n        \"logprob\": -0.85253906,\n        \"special\": false,\n        \"text\": \" data\"\n      },\n      {\n        \"id\": 625,\n        \"logprob\": -0.8984375,\n        \"special\": false,\n        \"text\": \" over\"\n      },\n      {\n        \"id\": 257,\n        \"logprob\": -1.0830078,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 3127,\n        \"logprob\": -1.9404297,\n        \"special\": false,\n        \"text\": \" network\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"Test request to send data over a network\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_phi/test_flash_phi_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 25,\n          \"logprob\": -2.3203125,\n          \"special\": false,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 1391,\n          \"logprob\": -0.98779297,\n          \"special\": false,\n          \"text\": \" {\"\n        },\n        {\n          \"id\": 25927,\n          \"logprob\": -0.7729492,\n          \"special\": false,\n          \"text\": \"request\"\n        },\n        {\n          \"id\": 92,\n          \"logprob\": -0.7241211,\n          \"special\": false,\n          \"text\": \"}\"\n        },\n        {\n          \"id\": 4943,\n          \"logprob\": -0.4091797,\n          \"special\": false,\n          \"text\": \"\\\")\"\n        },\n        {\n          \"id\": 198,\n          \"logprob\": -0.119018555,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 50280,\n          \"logprob\": -0.9707031,\n          \"special\": false,\n          \"text\": \"        \"\n        },\n        {\n          \"id\": 26209,\n          \"logprob\": -1.4414062,\n          \"special\": false,\n          \"text\": \"response\"\n        },\n        {\n          \"id\": 796,\n          \"logprob\": -0.056854248,\n          \"special\": false,\n          \"text\": \" =\"\n        },\n        {\n          \"id\": 2116,\n          \"logprob\": -1.1533203,\n          \"special\": false,\n          \"text\": \" self\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \": {request}\\\")\\n        response = self\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n    
    {\n          \"id\": 25,\n          \"logprob\": -2.3203125,\n          \"special\": false,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 1391,\n          \"logprob\": -0.98779297,\n          \"special\": false,\n          \"text\": \" {\"\n        },\n        {\n          \"id\": 25927,\n          \"logprob\": -0.7729492,\n          \"special\": false,\n          \"text\": \"request\"\n        },\n        {\n          \"id\": 92,\n          \"logprob\": -0.7241211,\n          \"special\": false,\n          \"text\": \"}\"\n        },\n        {\n          \"id\": 4943,\n          \"logprob\": -0.4091797,\n          \"special\": false,\n          \"text\": \"\\\")\"\n        },\n        {\n          \"id\": 198,\n          \"logprob\": -0.119018555,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 50280,\n          \"logprob\": -0.9707031,\n          \"special\": false,\n          \"text\": \"        \"\n        },\n        {\n          \"id\": 26209,\n          \"logprob\": -1.4414062,\n          \"special\": false,\n          \"text\": \"response\"\n        },\n        {\n          \"id\": 796,\n          \"logprob\": -0.056854248,\n          \"special\": false,\n          \"text\": \" =\"\n        },\n        {\n          \"id\": 2116,\n          \"logprob\": -1.1533203,\n          \"special\": false,\n          \"text\": \" self\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \": {request}\\\")\\n        response = self\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 25,\n          \"logprob\": -2.3203125,\n          \"special\": false,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 1391,\n          \"logprob\": -0.98779297,\n          
\"special\": false,\n          \"text\": \" {\"\n        },\n        {\n          \"id\": 25927,\n          \"logprob\": -0.7729492,\n          \"special\": false,\n          \"text\": \"request\"\n        },\n        {\n          \"id\": 92,\n          \"logprob\": -0.7241211,\n          \"special\": false,\n          \"text\": \"}\"\n        },\n        {\n          \"id\": 4943,\n          \"logprob\": -0.4091797,\n          \"special\": false,\n          \"text\": \"\\\")\"\n        },\n        {\n          \"id\": 198,\n          \"logprob\": -0.119018555,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 50280,\n          \"logprob\": -0.9707031,\n          \"special\": false,\n          \"text\": \"        \"\n        },\n        {\n          \"id\": 26209,\n          \"logprob\": -1.4414062,\n          \"special\": false,\n          \"text\": \"response\"\n        },\n        {\n          \"id\": 796,\n          \"logprob\": -0.056854248,\n          \"special\": false,\n          \"text\": \" =\"\n        },\n        {\n          \"id\": 2116,\n          \"logprob\": -1.1533203,\n          \"special\": false,\n          \"text\": \" self\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \": {request}\\\")\\n        response = self\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 25,\n          \"logprob\": -2.3203125,\n          \"special\": false,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 1391,\n          \"logprob\": -0.98779297,\n          \"special\": false,\n          \"text\": \" {\"\n        },\n        {\n          \"id\": 25927,\n          \"logprob\": -0.7729492,\n          \"special\": false,\n          \"text\": \"request\"\n        },\n        {\n  
        \"id\": 92,\n          \"logprob\": -0.7241211,\n          \"special\": false,\n          \"text\": \"}\"\n        },\n        {\n          \"id\": 4943,\n          \"logprob\": -0.4091797,\n          \"special\": false,\n          \"text\": \"\\\")\"\n        },\n        {\n          \"id\": 198,\n          \"logprob\": -0.119018555,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 50280,\n          \"logprob\": -0.9707031,\n          \"special\": false,\n          \"text\": \"        \"\n        },\n        {\n          \"id\": 26209,\n          \"logprob\": -1.4414062,\n          \"special\": false,\n          \"text\": \"response\"\n        },\n        {\n          \"id\": 796,\n          \"logprob\": -0.056854248,\n          \"special\": false,\n          \"text\": \" =\"\n        },\n        {\n          \"id\": 2116,\n          \"logprob\": -1.1533203,\n          \"special\": false,\n          \"text\": \" self\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \": {request}\\\")\\n        response = self\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_phi35_moe/test_flash_phi35_moe.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 25584,\n        \"logprob\": -0.008979797,\n        \"special\": false,\n        \"text\": \"Grad\"\n      },\n      {\n        \"id\": 993,\n        \"logprob\": -8.34465e-07,\n        \"special\": false,\n        \"text\": \"ient\"\n      },\n      {\n        \"id\": 26815,\n        \"logprob\": -0.0009407997,\n        \"special\": false,\n        \"text\": \" descent\"\n      },\n      {\n        \"id\": 338,\n        \"logprob\": -0.0003838539,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 385,\n        \"logprob\": -0.24499512,\n        \"special\": false,\n        \"text\": \" an\"\n      },\n      {\n        \"id\": 13883,\n        \"logprob\": -0.010406494,\n        \"special\": false,\n        \"text\": \" optimization\"\n      },\n      {\n        \"id\": 5687,\n        \"logprob\": -0.00024354458,\n        \"special\": false,\n        \"text\": \" algorithm\"\n      },\n      {\n        \"id\": 15574,\n        \"logprob\": -0.6582031,\n        \"special\": false,\n        \"text\": \" commonly\"\n      },\n      {\n        \"id\": 1304,\n        \"logprob\": -0.00092840195,\n        \"special\": false,\n        \"text\": \" used\"\n      },\n      {\n        \"id\": 297,\n        \"logprob\": -0.19470215,\n        \"special\": false,\n        \"text\": \" in\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"Gradient descent is an optimization algorithm commonly used in\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_phi35_moe/test_flash_phi35_moe_all_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 25584,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"Grad\"\n      },\n      {\n        \"id\": 993,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"ient\"\n      },\n      {\n        \"id\": 2726,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" Des\"\n      },\n      {\n        \"id\": 1760,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"cent\"\n      },\n      {\n        \"id\": 313,\n        \"logprob\": -0.12322998,\n        \"special\": false,\n        \"text\": \" (\"\n      },\n      {\n        \"id\": 29954,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"G\"\n      },\n      {\n        \"id\": 29928,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"D\"\n      },\n      {\n        \"id\": 29897,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \")\"\n      },\n      {\n        \"id\": 338,\n        \"logprob\": -0.6040039,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 385,\n        \"logprob\": -0.1796875,\n        \"special\": false,\n        \"text\": \" an\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"What is gradient descent?\\nGradient Descent (GD) is an\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_phi35_moe/test_flash_phi35_moe_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 25584,\n          \"logprob\": -0.008979797,\n          \"special\": false,\n          \"text\": \"Grad\"\n        },\n        {\n          \"id\": 993,\n          \"logprob\": -8.34465e-07,\n          \"special\": false,\n          \"text\": \"ient\"\n        },\n        {\n          \"id\": 26815,\n          \"logprob\": -0.00097084045,\n          \"special\": false,\n          \"text\": \" descent\"\n        },\n        {\n          \"id\": 338,\n          \"logprob\": -0.0003838539,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 385,\n          \"logprob\": -0.23840332,\n          \"special\": false,\n          \"text\": \" an\"\n        },\n        {\n          \"id\": 13883,\n          \"logprob\": -0.010406494,\n          \"special\": false,\n          \"text\": \" optimization\"\n        },\n        {\n          \"id\": 5687,\n          \"logprob\": -0.0002501011,\n          \"special\": false,\n          \"text\": \" algorithm\"\n        },\n        {\n          \"id\": 15574,\n          \"logprob\": -0.6582031,\n          \"special\": false,\n          \"text\": \" commonly\"\n        },\n        {\n          \"id\": 1304,\n          \"logprob\": -0.00092840195,\n          \"special\": false,\n          \"text\": \" used\"\n        },\n        {\n          \"id\": 297,\n          \"logprob\": -0.18933105,\n          \"special\": false,\n          \"text\": \" in\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"Gradient descent is an optimization algorithm commonly used in\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      
\"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 25584,\n          \"logprob\": -0.009017944,\n          \"special\": false,\n          \"text\": \"Grad\"\n        },\n        {\n          \"id\": 993,\n          \"logprob\": -9.536743e-07,\n          \"special\": false,\n          \"text\": \"ient\"\n        },\n        {\n          \"id\": 26815,\n          \"logprob\": -0.00097084045,\n          \"special\": false,\n          \"text\": \" descent\"\n        },\n        {\n          \"id\": 338,\n          \"logprob\": -0.0003838539,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 385,\n          \"logprob\": -0.24499512,\n          \"special\": false,\n          \"text\": \" an\"\n        },\n        {\n          \"id\": 13883,\n          \"logprob\": -0.010406494,\n          \"special\": false,\n          \"text\": \" optimization\"\n        },\n        {\n          \"id\": 5687,\n          \"logprob\": -0.0002501011,\n          \"special\": false,\n          \"text\": \" algorithm\"\n        },\n        {\n          \"id\": 15574,\n          \"logprob\": -0.6435547,\n          \"special\": false,\n          \"text\": \" commonly\"\n        },\n        {\n          \"id\": 1304,\n          \"logprob\": -0.0009279251,\n          \"special\": false,\n          \"text\": \" used\"\n        },\n        {\n          \"id\": 297,\n          \"logprob\": -0.18933105,\n          \"special\": false,\n          \"text\": \" in\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"Gradient descent is an optimization algorithm commonly used in\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 25584,\n          \"logprob\": -0.008956909,\n          \"special\": 
false,\n          \"text\": \"Grad\"\n        },\n        {\n          \"id\": 993,\n          \"logprob\": -8.34465e-07,\n          \"special\": false,\n          \"text\": \"ient\"\n        },\n        {\n          \"id\": 26815,\n          \"logprob\": -0.0009407997,\n          \"special\": false,\n          \"text\": \" descent\"\n        },\n        {\n          \"id\": 338,\n          \"logprob\": -0.0003721714,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 385,\n          \"logprob\": -0.24499512,\n          \"special\": false,\n          \"text\": \" an\"\n        },\n        {\n          \"id\": 13883,\n          \"logprob\": -0.010406494,\n          \"special\": false,\n          \"text\": \" optimization\"\n        },\n        {\n          \"id\": 5687,\n          \"logprob\": -0.0002501011,\n          \"special\": false,\n          \"text\": \" algorithm\"\n        },\n        {\n          \"id\": 15574,\n          \"logprob\": -0.6435547,\n          \"special\": false,\n          \"text\": \" commonly\"\n        },\n        {\n          \"id\": 1304,\n          \"logprob\": -0.00092601776,\n          \"special\": false,\n          \"text\": \" used\"\n        },\n        {\n          \"id\": 297,\n          \"logprob\": -0.19177246,\n          \"special\": false,\n          \"text\": \" in\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"Gradient descent is an optimization algorithm commonly used in\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 25584,\n          \"logprob\": -0.008979797,\n          \"special\": false,\n          \"text\": \"Grad\"\n        },\n        {\n          \"id\": 993,\n          \"logprob\": -9.536743e-07,\n          \"special\": false,\n         
 \"text\": \"ient\"\n        },\n        {\n          \"id\": 26815,\n          \"logprob\": -0.0009407997,\n          \"special\": false,\n          \"text\": \" descent\"\n        },\n        {\n          \"id\": 338,\n          \"logprob\": -0.00038409233,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 385,\n          \"logprob\": -0.24499512,\n          \"special\": false,\n          \"text\": \" an\"\n        },\n        {\n          \"id\": 13883,\n          \"logprob\": -0.010414124,\n          \"special\": false,\n          \"text\": \" optimization\"\n        },\n        {\n          \"id\": 5687,\n          \"logprob\": -0.00024354458,\n          \"special\": false,\n          \"text\": \" algorithm\"\n        },\n        {\n          \"id\": 15574,\n          \"logprob\": -0.6435547,\n          \"special\": false,\n          \"text\": \" commonly\"\n        },\n        {\n          \"id\": 1304,\n          \"logprob\": -0.0009279251,\n          \"special\": false,\n          \"text\": \" used\"\n        },\n        {\n          \"id\": 297,\n          \"logprob\": -0.19470215,\n          \"special\": false,\n          \"text\": \" in\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"Gradient descent is an optimization algorithm commonly used in\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 198,\n        \"logprob\": -2.9023438,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 2,\n        \"logprob\": -2.9160156,\n        \"special\": false,\n        \"text\": \"#\"\n      },\n      {\n        \"id\": 4230,\n        \"logprob\": -3.1035156,\n        \"special\": false,\n        \"text\": \" Create\"\n      },\n      {\n        \"id\": 264,\n        \"logprob\": -1.1025391,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 1681,\n        \"logprob\": -1.6914062,\n        \"special\": false,\n        \"text\": \" request\"\n      },\n      {\n        \"id\": 198,\n        \"logprob\": -1.1953125,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 2035,\n        \"logprob\": -1.3203125,\n        \"special\": false,\n        \"text\": \"request\"\n      },\n      {\n        \"id\": 284,\n        \"logprob\": -0.13537598,\n        \"special\": false,\n        \"text\": \" =\"\n      },\n      {\n        \"id\": 7388,\n        \"logprob\": -1.2402344,\n        \"special\": false,\n        \"text\": \" requests\"\n      },\n      {\n        \"id\": 670,\n        \"logprob\": -0.2775879,\n        \"special\": false,\n        \"text\": \".get\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"\\n# Create a request\\nrequest = requests.get\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2_all_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 311,\n        \"logprob\": -1.4277344,\n        \"special\": false,\n        \"text\": \" to\"\n      },\n      {\n        \"id\": 279,\n        \"logprob\": -0.65478516,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 2473,\n        \"logprob\": -1.8300781,\n        \"special\": false,\n        \"text\": \" service\"\n      },\n      {\n        \"id\": 382,\n        \"logprob\": -0.75,\n        \"special\": false,\n        \"text\": \".\\n\\n\"\n      },\n      {\n        \"id\": 286,\n        \"logprob\": -0.11621094,\n        \"special\": false,\n        \"text\": \"       \"\n      },\n      {\n        \"id\": 549,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" :\"\n      },\n      {\n        \"id\": 689,\n        \"logprob\": -0.48608398,\n        \"special\": false,\n        \"text\": \"return\"\n      },\n      {\n        \"id\": 25,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \":\"\n      },\n      {\n        \"id\": 5949,\n        \"logprob\": -0.5756836,\n        \"special\": false,\n        \"text\": \" Response\"\n      },\n      {\n        \"id\": 504,\n        \"logprob\": -0.24499512,\n        \"special\": false,\n        \"text\": \" from\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"Test request to the service.\\n\\n        :return: Response from\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_qwen2/test_flash_qwen2_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 198,\n          \"logprob\": -2.9023438,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 2,\n          \"logprob\": -2.9140625,\n          \"special\": false,\n          \"text\": \"#\"\n        },\n        {\n          \"id\": 4230,\n          \"logprob\": -3.1054688,\n          \"special\": false,\n          \"text\": \" Create\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -1.0966797,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 1681,\n          \"logprob\": -1.6914062,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 198,\n          \"logprob\": -1.1923828,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 2035,\n          \"logprob\": -1.3193359,\n          \"special\": false,\n          \"text\": \"request\"\n        },\n        {\n          \"id\": 284,\n          \"logprob\": -0.13586426,\n          \"special\": false,\n          \"text\": \" =\"\n        },\n        {\n          \"id\": 7388,\n          \"logprob\": -1.2412109,\n          \"special\": false,\n          \"text\": \" requests\"\n        },\n        {\n          \"id\": 670,\n          \"logprob\": -0.2775879,\n          \"special\": false,\n          \"text\": \".get\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n# Create a request\\nrequest = requests.get\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n     
   {\n          \"id\": 198,\n          \"logprob\": -2.9023438,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 2,\n          \"logprob\": -2.9140625,\n          \"special\": false,\n          \"text\": \"#\"\n        },\n        {\n          \"id\": 4230,\n          \"logprob\": -3.1054688,\n          \"special\": false,\n          \"text\": \" Create\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -1.0966797,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 1681,\n          \"logprob\": -1.6914062,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 198,\n          \"logprob\": -1.1923828,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 2035,\n          \"logprob\": -1.3193359,\n          \"special\": false,\n          \"text\": \"request\"\n        },\n        {\n          \"id\": 284,\n          \"logprob\": -0.13586426,\n          \"special\": false,\n          \"text\": \" =\"\n        },\n        {\n          \"id\": 7388,\n          \"logprob\": -1.2412109,\n          \"special\": false,\n          \"text\": \" requests\"\n        },\n        {\n          \"id\": 670,\n          \"logprob\": -0.2775879,\n          \"special\": false,\n          \"text\": \".get\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n# Create a request\\nrequest = requests.get\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 198,\n          \"logprob\": -2.9023438,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 2,\n          \"logprob\": -2.9140625,\n          
\"special\": false,\n          \"text\": \"#\"\n        },\n        {\n          \"id\": 4230,\n          \"logprob\": -3.1054688,\n          \"special\": false,\n          \"text\": \" Create\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -1.0966797,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 1681,\n          \"logprob\": -1.6914062,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 198,\n          \"logprob\": -1.1923828,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 2035,\n          \"logprob\": -1.3193359,\n          \"special\": false,\n          \"text\": \"request\"\n        },\n        {\n          \"id\": 284,\n          \"logprob\": -0.13586426,\n          \"special\": false,\n          \"text\": \" =\"\n        },\n        {\n          \"id\": 7388,\n          \"logprob\": -1.2412109,\n          \"special\": false,\n          \"text\": \" requests\"\n        },\n        {\n          \"id\": 670,\n          \"logprob\": -0.2775879,\n          \"special\": false,\n          \"text\": \".get\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n# Create a request\\nrequest = requests.get\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 198,\n          \"logprob\": -2.9023438,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 2,\n          \"logprob\": -2.9140625,\n          \"special\": false,\n          \"text\": \"#\"\n        },\n        {\n          \"id\": 4230,\n          \"logprob\": -3.1054688,\n          \"special\": false,\n          \"text\": \" Create\"\n        },\n        {\n     
     \"id\": 264,\n          \"logprob\": -1.0966797,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 1681,\n          \"logprob\": -1.6914062,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 198,\n          \"logprob\": -1.1923828,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 2035,\n          \"logprob\": -1.3193359,\n          \"special\": false,\n          \"text\": \"request\"\n        },\n        {\n          \"id\": 284,\n          \"logprob\": -0.13586426,\n          \"special\": false,\n          \"text\": \" =\"\n        },\n        {\n          \"id\": 7388,\n          \"logprob\": -1.2412109,\n          \"special\": false,\n          \"text\": \" requests\"\n        },\n        {\n          \"id\": 670,\n          \"logprob\": -0.2775879,\n          \"special\": false,\n          \"text\": \".get\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n# Create a request\\nrequest = requests.get\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_qwen2_5_vl/test_flash_qwen2_5_vl_bay.json",
    "content": "{\n  \"choices\": [\n    {\n      \"finish_reason\": \"stop\",\n      \"index\": 0,\n      \"logprobs\": null,\n      \"message\": {\n        \"content\": \"The image showcases the Statue of Liberty, a colossal bronze statue located in New York Harbor, a heritage building in the United States. The statue has a majestic presence, with one arm raised towards the sun and the other hitched on her hip. It sits atop a keeper's walkway, observed from the water. Surrounding the statue is a lush green meadow, where picnic spots, walkways, and a visitor desk can be found. In front of the statue, a large marina can accommodate fourteen different kinds of boats. In the backdrop stands the Empire State Building, marking the crowded skyscrapers of New York City.\",\n        \"name\": null,\n        \"role\": \"assistant\",\n        \"tool_calls\": null\n      },\n      \"usage\": null\n    }\n  ],\n  \"created\": 1738342753,\n  \"id\": \"\",\n  \"model\": \"Qwen/Qwen2.5-VL-3B-Instruct\",\n  \"object\": \"chat.completion\",\n  \"system_fingerprint\": \"3.0.2-dev0-native\",\n  \"usage\": {\n    \"completion_tokens\": 128,\n    \"prompt_tokens\": 8736,\n    \"total_tokens\": 8864\n  }\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_qwen2_5_vl/test_flash_qwen2_5_vl_inpaint.json",
    "content": "{\n  \"choices\": [\n    {\n      \"finish_reason\": \"stop\",\n      \"index\": 0,\n      \"logprobs\": null,\n      \"message\": {\n        \"content\": \"The image shows a whimsical scene set in what appears to be a fast-food restaurant. Dominating the foreground is a large, green, inflatable dinosaur with realistic textures, giving it a Jurassic Park-like appearance. The dinosaur is wearing a red Adult Swim logo hat, adding a humorous touch to its appearance.\\n\\nSurrounding the dinosaur are various food items typically found in a fast-food restaurant, including French fries in a plastic cup, a hamburger on a plate, and a beverage in another cup. The hamburger is detailed with lettuce, tomato, and other typical fast-food ingredients.\\n\\nAccompanying the dinosaur is a realistic-looking owl perched on the table, which adds to the surreal and playful atmosphere of the scene. The background features the interior of the restaurant with neon signs and other typical decor elements, enhancing the overall theme of a fun and fantastical fast-food experience.\\n\\nOverall, the image is a playful and imaginative blend of a standard fast-food setting with an unexpected and amusing twist provided by the dinosaur and owl characters.\",\n        \"name\": null,\n        \"role\": \"assistant\",\n        \"tool_calls\": null\n      },\n      \"usage\": null\n    }\n  ],\n  \"created\": 1738343775,\n  \"id\": \"\",\n  \"model\": \"Qwen/Qwen2.5-VL-3B-Instruct\",\n  \"object\": \"chat.completion\",\n  \"system_fingerprint\": \"3.0.2-dev0-native\",\n  \"usage\": {\n    \"completion_tokens\": 206,\n    \"prompt_tokens\": 5375,\n    \"total_tokens\": 5581\n  }\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_qwen2_5_vl/test_flash_qwen2_5_vl_simple.json",
    "content": "{\n  \"choices\": [\n    {\n      \"finish_reason\": \"stop\",\n      \"index\": 0,\n      \"logprobs\": null,\n      \"message\": {\n        \"content\": \"The image depicts an anthropomorphic rabbit character wearing an intricate space suit, which includes a helmet with a starry face pattern and multiple suitors. The rabbit's ears are significantly large and upright, and it has a hitchhiker-like star antennas on its chest. The background is a reddish-orange, rocky landscape, suggesting a Martian environment. The suit has various buttons, a red button on the chest, and a reflective or illuminated dome on the head. The overall color scheme is dominated by shades of red, orange, and gray, giving a sense of a rugged, otherworldly setting.\",\n        \"name\": null,\n        \"role\": \"assistant\",\n        \"tool_calls\": null\n      },\n      \"usage\": null\n    }\n  ],\n  \"created\": 1738342872,\n  \"id\": \"\",\n  \"model\": \"Qwen/Qwen2.5-VL-3B-Instruct\",\n  \"object\": \"chat.completion\",\n  \"system_fingerprint\": \"3.0.2-dev0-native\",\n  \"usage\": {\n    \"completion_tokens\": 121,\n    \"prompt_tokens\": 1363,\n    \"total_tokens\": 1484\n  }\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_qwen2_5_vl/test_flash_qwen2_5_vl_simple_streaming.json",
    "content": "{\n  \"choices\": [\n    {\n      \"delta\": {\n        \"content\": \"\",\n        \"role\": \"assistant\",\n        \"tool_calls\": null\n      },\n      \"finish_reason\": \"stop\",\n      \"index\": 0,\n      \"logprobs\": null\n    }\n  ],\n  \"created\": 1738343559,\n  \"id\": \"\",\n  \"model\": \"Qwen/Qwen2.5-VL-3B-Instruct\",\n  \"object\": \"chat.completion.chunk\",\n  \"system_fingerprint\": \"3.0.2-dev0-native\",\n  \"usage\": null\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_bay.json",
    "content": "{\n  \"choices\": [\n    {\n      \"finish_reason\": \"stop\",\n      \"index\": 0,\n      \"logprobs\": null,\n      \"message\": {\n        \"content\": \"The image showcases a stunning cityscape, featuring the iconic Statue of Liberty in the foreground. The image displays Lady Liberty's imposing presence, with her towering base standing beside her. Behind the statue, the city's skyline extends across the horizon, adorned with numerous tall buildings, including the Empire State Building and other notable skyscrapers. The water reflecting the sun's rays creates a serene and picturesque scene, emphasizing the beauty and resilience of this global landmark. The sky is a clear, pale blue, adding to the overall tranquility of the scene.\",\n        \"name\": null,\n        \"role\": \"assistant\",\n        \"tool_calls\": null\n      },\n      \"usage\": null\n    }\n  ],\n  \"created\": 1738348090,\n  \"id\": \"\",\n  \"model\": \"Qwen/Qwen2-VL-7B-Instruct\",\n  \"object\": \"chat.completion\",\n  \"system_fingerprint\": \"3.1.1-dev0-native\",\n  \"usage\": {\n    \"completion_tokens\": 110,\n    \"prompt_tokens\": 8736,\n    \"total_tokens\": 8846\n  }\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_inpaint.json",
    "content": "{\n  \"choices\": [\n    {\n      \"finish_reason\": \"stop\",\n      \"index\": 0,\n      \"logprobs\": null,\n      \"message\": {\n        \"content\": \"The image shows a stylized scene set in what appears to be a diner or restaurant. In the foreground, there is a table with various food items, including a burger with lettuce and tomato, a bowl of fries, and a drink in a cup with a straw. On the right side of the table, there is an owl sitting alertly, looking directly at the camera. Behind the owl and the table, there is a large, green, dinosaur-like creature resembling Godzilla, with its mouth open and tongue visible. In the background, the diner's decor includes various signs and posters, with a green sign reading \\\"Basta\\\" and another sign that says \\\"Tabasco.\\\" The setting has a retro or vintage feel, with fluorescent lighting overhead and clean, polished surfaces.\",\n        \"name\": null,\n        \"role\": \"assistant\",\n        \"tool_calls\": null\n      },\n      \"usage\": null\n    }\n  ],\n  \"created\": 1738348100,\n  \"id\": \"\",\n  \"model\": \"Qwen/Qwen2-VL-7B-Instruct\",\n  \"object\": \"chat.completion\",\n  \"system_fingerprint\": \"3.1.1-dev0-native\",\n  \"usage\": {\n    \"completion_tokens\": 156,\n    \"prompt_tokens\": 5375,\n    \"total_tokens\": 5531\n  }\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple.json",
    "content": "{\n  \"choices\": [\n    {\n      \"finish_reason\": \"stop\",\n      \"index\": 0,\n      \"logprobs\": null,\n      \"message\": {\n        \"content\": \"The image depicts an anthropomorphic rabbit, wearing a spacesuit, standing in a barren, rocky landscape that resembles the surface of another planet, possibly Mars. The rabbit has a red digestive system label on its chest, and the surrounding environment features red sandy terrain and a hazy, floating planet or moon in the background. The scene has a surreal, fantastical quality, blending elements of science fiction and space exploration with a whimsical character.\",\n        \"name\": null,\n        \"role\": \"assistant\",\n        \"tool_calls\": null\n      },\n      \"usage\": null\n    }\n  ],\n  \"created\": 1738347908,\n  \"id\": \"\",\n  \"model\": \"Qwen/Qwen2-VL-7B-Instruct\",\n  \"object\": \"chat.completion\",\n  \"system_fingerprint\": \"3.1.1-dev0-native\",\n  \"usage\": {\n    \"completion_tokens\": 89,\n    \"prompt_tokens\": 1364,\n    \"total_tokens\": 1453\n  }\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_qwen2_vl/test_flash_qwen2_vl_simple_streaming.json",
    "content": "{\n  \"choices\": [\n    {\n      \"delta\": {\n        \"content\": \"\",\n        \"role\": \"assistant\",\n        \"tool_calls\": null\n      },\n      \"finish_reason\": \"stop\",\n      \"index\": 0,\n      \"logprobs\": null\n    }\n  ],\n  \"created\": 1737646031,\n  \"id\": \"\",\n  \"model\": \"Qwen/Qwen2-VL-7B-Instruct\",\n  \"object\": \"chat.completion.chunk\",\n  \"system_fingerprint\": \"3.0.2-dev0-native\",\n  \"usage\": null\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_santacoder/test_flash_santacoder.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 1241,\n        \"logprob\": -0.9863281,\n        \"special\": false,\n        \"text\": \"():\"\n      },\n      {\n        \"id\": 258,\n        \"logprob\": -0.21447754,\n        \"special\": false,\n        \"text\": \"\\n   \"\n      },\n      {\n        \"id\": 942,\n        \"logprob\": -0.43701172,\n        \"special\": false,\n        \"text\": \" print\"\n      },\n      {\n        \"id\": 372,\n        \"logprob\": -0.5361328,\n        \"special\": false,\n        \"text\": \"(\\\"\"\n      },\n      {\n        \"id\": 7371,\n        \"logprob\": -0.44555664,\n        \"special\": false,\n        \"text\": \"Hello\"\n      },\n      {\n        \"id\": 9956,\n        \"logprob\": -1.2412109,\n        \"special\": false,\n        \"text\": \" World\"\n      },\n      {\n        \"id\": 8657,\n        \"logprob\": -0.7583008,\n        \"special\": false,\n        \"text\": \"!\\\")\"\n      },\n      {\n        \"id\": 185,\n        \"logprob\": -0.76171875,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 185,\n        \"logprob\": -0.20837402,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 1018,\n        \"logprob\": -1.2470703,\n        \"special\": false,\n        \"text\": \"print\"\n      }\n    ]\n  },\n  \"generated_text\": \"():\\n    print(\\\"Hello World!\\\")\\n\\nprint\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_santacoder/test_flash_santacoder_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 1241,\n          \"logprob\": -0.9863281,\n          \"special\": false,\n          \"text\": \"():\"\n        },\n        {\n          \"id\": 258,\n          \"logprob\": -0.21362305,\n          \"special\": false,\n          \"text\": \"\\n   \"\n        },\n        {\n          \"id\": 942,\n          \"logprob\": -0.44360352,\n          \"special\": false,\n          \"text\": \" print\"\n        },\n        {\n          \"id\": 372,\n          \"logprob\": -0.54248047,\n          \"special\": false,\n          \"text\": \"(\\\"\"\n        },\n        {\n          \"id\": 7371,\n          \"logprob\": -0.44555664,\n          \"special\": false,\n          \"text\": \"Hello\"\n        },\n        {\n          \"id\": 9956,\n          \"logprob\": -1.2441406,\n          \"special\": false,\n          \"text\": \" World\"\n        },\n        {\n          \"id\": 8657,\n          \"logprob\": -0.75878906,\n          \"special\": false,\n          \"text\": \"!\\\")\"\n        },\n        {\n          \"id\": 185,\n          \"logprob\": -0.76171875,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 185,\n          \"logprob\": -0.2084961,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 1018,\n          \"logprob\": -1.2460938,\n          \"special\": false,\n          \"text\": \"print\"\n        }\n      ]\n    },\n    \"generated_text\": \"():\\n    print(\\\"Hello World!\\\")\\n\\nprint\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          
\"id\": 1241,\n          \"logprob\": -0.9863281,\n          \"special\": false,\n          \"text\": \"():\"\n        },\n        {\n          \"id\": 258,\n          \"logprob\": -0.21362305,\n          \"special\": false,\n          \"text\": \"\\n   \"\n        },\n        {\n          \"id\": 942,\n          \"logprob\": -0.44360352,\n          \"special\": false,\n          \"text\": \" print\"\n        },\n        {\n          \"id\": 372,\n          \"logprob\": -0.54248047,\n          \"special\": false,\n          \"text\": \"(\\\"\"\n        },\n        {\n          \"id\": 7371,\n          \"logprob\": -0.44555664,\n          \"special\": false,\n          \"text\": \"Hello\"\n        },\n        {\n          \"id\": 9956,\n          \"logprob\": -1.2441406,\n          \"special\": false,\n          \"text\": \" World\"\n        },\n        {\n          \"id\": 8657,\n          \"logprob\": -0.75878906,\n          \"special\": false,\n          \"text\": \"!\\\")\"\n        },\n        {\n          \"id\": 185,\n          \"logprob\": -0.76171875,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 185,\n          \"logprob\": -0.2084961,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 1018,\n          \"logprob\": -1.2460938,\n          \"special\": false,\n          \"text\": \"print\"\n        }\n      ]\n    },\n    \"generated_text\": \"():\\n    print(\\\"Hello World!\\\")\\n\\nprint\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 1241,\n          \"logprob\": -0.9863281,\n          \"special\": false,\n          \"text\": \"():\"\n        },\n        {\n          \"id\": 258,\n          \"logprob\": -0.21362305,\n          \"special\": false,\n          
\"text\": \"\\n   \"\n        },\n        {\n          \"id\": 942,\n          \"logprob\": -0.44360352,\n          \"special\": false,\n          \"text\": \" print\"\n        },\n        {\n          \"id\": 372,\n          \"logprob\": -0.54248047,\n          \"special\": false,\n          \"text\": \"(\\\"\"\n        },\n        {\n          \"id\": 7371,\n          \"logprob\": -0.44555664,\n          \"special\": false,\n          \"text\": \"Hello\"\n        },\n        {\n          \"id\": 9956,\n          \"logprob\": -1.2441406,\n          \"special\": false,\n          \"text\": \" World\"\n        },\n        {\n          \"id\": 8657,\n          \"logprob\": -0.75878906,\n          \"special\": false,\n          \"text\": \"!\\\")\"\n        },\n        {\n          \"id\": 185,\n          \"logprob\": -0.76171875,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 185,\n          \"logprob\": -0.2084961,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 1018,\n          \"logprob\": -1.2460938,\n          \"special\": false,\n          \"text\": \"print\"\n        }\n      ]\n    },\n    \"generated_text\": \"():\\n    print(\\\"Hello World!\\\")\\n\\nprint\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 1241,\n          \"logprob\": -0.9863281,\n          \"special\": false,\n          \"text\": \"():\"\n        },\n        {\n          \"id\": 258,\n          \"logprob\": -0.21362305,\n          \"special\": false,\n          \"text\": \"\\n   \"\n        },\n        {\n          \"id\": 942,\n          \"logprob\": -0.44360352,\n          \"special\": false,\n          \"text\": \" print\"\n        },\n        {\n          \"id\": 372,\n          \"logprob\": 
-0.54248047,\n          \"special\": false,\n          \"text\": \"(\\\"\"\n        },\n        {\n          \"id\": 7371,\n          \"logprob\": -0.44555664,\n          \"special\": false,\n          \"text\": \"Hello\"\n        },\n        {\n          \"id\": 9956,\n          \"logprob\": -1.2441406,\n          \"special\": false,\n          \"text\": \" World\"\n        },\n        {\n          \"id\": 8657,\n          \"logprob\": -0.75878906,\n          \"special\": false,\n          \"text\": \"!\\\")\"\n        },\n        {\n          \"id\": 185,\n          \"logprob\": -0.76171875,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 185,\n          \"logprob\": -0.2084961,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 1018,\n          \"logprob\": -1.2460938,\n          \"special\": false,\n          \"text\": \"print\"\n        }\n      ]\n    },\n    \"generated_text\": \"():\\n    print(\\\"Hello World!\\\")\\n\\nprint\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 2262,\n        \"logprob\": -0.7705078,\n        \"special\": false,\n        \"text\": \"():\"\n      },\n      {\n        \"id\": 284,\n        \"logprob\": -0.2590332,\n        \"special\": false,\n        \"text\": \"\\n   \"\n      },\n      {\n        \"id\": 1459,\n        \"logprob\": -0.39379883,\n        \"special\": false,\n        \"text\": \" print\"\n      },\n      {\n        \"id\": 440,\n        \"logprob\": -0.61376953,\n        \"special\": false,\n        \"text\": \"(\\\"\"\n      },\n      {\n        \"id\": 8279,\n        \"logprob\": -0.47338867,\n        \"special\": false,\n        \"text\": \"Hello\"\n      },\n      {\n        \"id\": 10896,\n        \"logprob\": -1.5068359,\n        \"special\": false,\n        \"text\": \" World\"\n      },\n      {\n        \"id\": 657,\n        \"logprob\": -0.80810547,\n        \"special\": false,\n        \"text\": \"\\\")\"\n      },\n      {\n        \"id\": 203,\n        \"logprob\": -0.7397461,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 203,\n        \"logprob\": -0.35229492,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 589,\n        \"logprob\": -1.0371094,\n        \"special\": false,\n        \"text\": \"def\"\n      }\n    ]\n  },\n  \"generated_text\": \"():\\n    print(\\\"Hello World\\\")\\n\\ndef\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_default_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 60,\n    \"prefill\": [],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 2262,\n        \"logprob\": -0.045715332,\n        \"special\": false,\n        \"text\": \"():\"\n      },\n      {\n        \"id\": 284,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"\\n   \"\n      },\n      {\n        \"id\": 1459,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" print\"\n      },\n      {\n        \"id\": 440,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"(\\\"\"\n      },\n      {\n        \"id\": 8279,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"Hello\"\n      },\n      {\n        \"id\": 10896,\n        \"logprob\": -0.3659668,\n        \"special\": false,\n        \"text\": \" World\"\n      },\n      {\n        \"id\": 657,\n        \"logprob\": -0.5229492,\n        \"special\": false,\n        \"text\": \"\\\")\"\n      },\n      {\n        \"id\": 203,\n        \"logprob\": -0.10632324,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 203,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 589,\n        \"logprob\": -0.20141602,\n        \"special\": false,\n        \"text\": \"def\"\n      },\n      {\n        \"id\": 1459,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" print\"\n      },\n      {\n        \"id\": 81,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"_\"\n      },\n      {\n        \"id\": 7656,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"hello\"\n      },\n      {\n        \"id\": 81,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"_\"\n      },\n      
{\n        \"id\": 426,\n        \"logprob\": -0.051635742,\n        \"special\": false,\n        \"text\": \"name\"\n      },\n      {\n        \"id\": 26,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"(\"\n      },\n      {\n        \"id\": 426,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"name\"\n      },\n      {\n        \"id\": 711,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"):\"\n      },\n      {\n        \"id\": 284,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"\\n   \"\n      },\n      {\n        \"id\": 1459,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" print\"\n      },\n      {\n        \"id\": 440,\n        \"logprob\": -0.16027832,\n        \"special\": false,\n        \"text\": \"(\\\"\"\n      },\n      {\n        \"id\": 8279,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"Hello\"\n      },\n      {\n        \"id\": 313,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" \\\"\"\n      },\n      {\n        \"id\": 474,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" +\"\n      },\n      {\n        \"id\": 636,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" name\"\n      },\n      {\n        \"id\": 27,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \")\"\n      },\n      {\n        \"id\": 203,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 203,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 589,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"def\"\n      },\n      {\n        \"id\": 1459,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" print\"\n 
     },\n      {\n        \"id\": 81,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"_\"\n      },\n      {\n        \"id\": 7656,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"hello\"\n      },\n      {\n        \"id\": 81,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"_\"\n      },\n      {\n        \"id\": 426,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"name\"\n      },\n      {\n        \"id\": 81,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"_\"\n      },\n      {\n        \"id\": 381,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"age\"\n      },\n      {\n        \"id\": 26,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"(\"\n      },\n      {\n        \"id\": 426,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"name\"\n      },\n      {\n        \"id\": 30,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \",\"\n      },\n      {\n        \"id\": 11442,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" age\"\n      },\n      {\n        \"id\": 711,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"):\"\n      },\n      {\n        \"id\": 284,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"\\n   \"\n      },\n      {\n        \"id\": 1459,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" print\"\n      },\n      {\n        \"id\": 440,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"(\\\"\"\n      },\n      {\n        \"id\": 8279,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"Hello\"\n      },\n      {\n        \"id\": 313,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" \\\"\"\n      },\n 
     {\n        \"id\": 474,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" +\"\n      },\n      {\n        \"id\": 636,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" name\"\n      },\n      {\n        \"id\": 474,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" +\"\n      },\n      {\n        \"id\": 313,\n        \"logprob\": -0.6933594,\n        \"special\": false,\n        \"text\": \" \\\"\"\n      },\n      {\n        \"id\": 313,\n        \"logprob\": -1.7011719,\n        \"special\": false,\n        \"text\": \" \\\"\"\n      },\n      {\n        \"id\": 474,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" +\"\n      },\n      {\n        \"id\": 596,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" str\"\n      },\n      {\n        \"id\": 26,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"(\"\n      },\n      {\n        \"id\": 381,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"age\"\n      },\n      {\n        \"id\": 490,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"))\"\n      },\n      {\n        \"id\": 203,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 203,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 589,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"def\"\n      },\n      {\n        \"id\": 1459,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" print\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"():\\n    print(\\\"Hello World\\\")\\n\\ndef print_hello_name(name):\\n    print(\\\"Hello \\\" + name)\\n\\ndef print_hello_name_age(name, age):\\n    print(\\\"Hello \\\" + name + 
\\\" \\\" + str(age))\\n\\ndef print\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_starcoder/test_flash_starcoder_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 2262,\n          \"logprob\": -0.7705078,\n          \"special\": false,\n          \"text\": \"():\"\n        },\n        {\n          \"id\": 284,\n          \"logprob\": -0.2602539,\n          \"special\": false,\n          \"text\": \"\\n   \"\n        },\n        {\n          \"id\": 1459,\n          \"logprob\": -0.39282227,\n          \"special\": false,\n          \"text\": \" print\"\n        },\n        {\n          \"id\": 440,\n          \"logprob\": -0.6113281,\n          \"special\": false,\n          \"text\": \"(\\\"\"\n        },\n        {\n          \"id\": 8279,\n          \"logprob\": -0.4765625,\n          \"special\": false,\n          \"text\": \"Hello\"\n        },\n        {\n          \"id\": 10896,\n          \"logprob\": -1.5068359,\n          \"special\": false,\n          \"text\": \" World\"\n        },\n        {\n          \"id\": 657,\n          \"logprob\": -0.8154297,\n          \"special\": false,\n          \"text\": \"\\\")\"\n        },\n        {\n          \"id\": 203,\n          \"logprob\": -0.7319336,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 203,\n          \"logprob\": -0.35229492,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 589,\n          \"logprob\": -1.0380859,\n          \"special\": false,\n          \"text\": \"def\"\n        }\n      ]\n    },\n    \"generated_text\": \"():\\n    print(\\\"Hello World\\\")\\n\\ndef\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 
2262,\n          \"logprob\": -0.7705078,\n          \"special\": false,\n          \"text\": \"():\"\n        },\n        {\n          \"id\": 284,\n          \"logprob\": -0.2602539,\n          \"special\": false,\n          \"text\": \"\\n   \"\n        },\n        {\n          \"id\": 1459,\n          \"logprob\": -0.39282227,\n          \"special\": false,\n          \"text\": \" print\"\n        },\n        {\n          \"id\": 440,\n          \"logprob\": -0.6113281,\n          \"special\": false,\n          \"text\": \"(\\\"\"\n        },\n        {\n          \"id\": 8279,\n          \"logprob\": -0.4765625,\n          \"special\": false,\n          \"text\": \"Hello\"\n        },\n        {\n          \"id\": 10896,\n          \"logprob\": -1.5068359,\n          \"special\": false,\n          \"text\": \" World\"\n        },\n        {\n          \"id\": 657,\n          \"logprob\": -0.8154297,\n          \"special\": false,\n          \"text\": \"\\\")\"\n        },\n        {\n          \"id\": 203,\n          \"logprob\": -0.7319336,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 203,\n          \"logprob\": -0.35229492,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 589,\n          \"logprob\": -1.0380859,\n          \"special\": false,\n          \"text\": \"def\"\n        }\n      ]\n    },\n    \"generated_text\": \"():\\n    print(\\\"Hello World\\\")\\n\\ndef\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 2262,\n          \"logprob\": -0.7705078,\n          \"special\": false,\n          \"text\": \"():\"\n        },\n        {\n          \"id\": 284,\n          \"logprob\": -0.2602539,\n          \"special\": false,\n          \"text\": \"\\n   \"\n  
      },\n        {\n          \"id\": 1459,\n          \"logprob\": -0.39282227,\n          \"special\": false,\n          \"text\": \" print\"\n        },\n        {\n          \"id\": 440,\n          \"logprob\": -0.6113281,\n          \"special\": false,\n          \"text\": \"(\\\"\"\n        },\n        {\n          \"id\": 8279,\n          \"logprob\": -0.4765625,\n          \"special\": false,\n          \"text\": \"Hello\"\n        },\n        {\n          \"id\": 10896,\n          \"logprob\": -1.5068359,\n          \"special\": false,\n          \"text\": \" World\"\n        },\n        {\n          \"id\": 657,\n          \"logprob\": -0.8154297,\n          \"special\": false,\n          \"text\": \"\\\")\"\n        },\n        {\n          \"id\": 203,\n          \"logprob\": -0.7319336,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 203,\n          \"logprob\": -0.35229492,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 589,\n          \"logprob\": -1.0380859,\n          \"special\": false,\n          \"text\": \"def\"\n        }\n      ]\n    },\n    \"generated_text\": \"():\\n    print(\\\"Hello World\\\")\\n\\ndef\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 2262,\n          \"logprob\": -0.7705078,\n          \"special\": false,\n          \"text\": \"():\"\n        },\n        {\n          \"id\": 284,\n          \"logprob\": -0.2602539,\n          \"special\": false,\n          \"text\": \"\\n   \"\n        },\n        {\n          \"id\": 1459,\n          \"logprob\": -0.39282227,\n          \"special\": false,\n          \"text\": \" print\"\n        },\n        {\n          \"id\": 440,\n          \"logprob\": -0.6113281,\n          
\"special\": false,\n          \"text\": \"(\\\"\"\n        },\n        {\n          \"id\": 8279,\n          \"logprob\": -0.4765625,\n          \"special\": false,\n          \"text\": \"Hello\"\n        },\n        {\n          \"id\": 10896,\n          \"logprob\": -1.5068359,\n          \"special\": false,\n          \"text\": \" World\"\n        },\n        {\n          \"id\": 657,\n          \"logprob\": -0.8154297,\n          \"special\": false,\n          \"text\": \"\\\")\"\n        },\n        {\n          \"id\": 203,\n          \"logprob\": -0.7319336,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 203,\n          \"logprob\": -0.35229492,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 589,\n          \"logprob\": -1.0380859,\n          \"special\": false,\n          \"text\": \"def\"\n        }\n      ]\n    },\n    \"generated_text\": \"():\\n    print(\\\"Hello World\\\")\\n\\ndef\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 2284,\n        \"logprob\": -0.92626953,\n        \"special\": false,\n        \"text\": \"():\"\n      },\n      {\n        \"id\": 303,\n        \"logprob\": -0.40844727,\n        \"special\": false,\n        \"text\": \"\\n   \"\n      },\n      {\n        \"id\": 1489,\n        \"logprob\": -0.27905273,\n        \"special\": false,\n        \"text\": \" print\"\n      },\n      {\n        \"id\": 459,\n        \"logprob\": -0.6118164,\n        \"special\": false,\n        \"text\": \"(\\\"\"\n      },\n      {\n        \"id\": 8302,\n        \"logprob\": -0.68652344,\n        \"special\": false,\n        \"text\": \"Hello\"\n      },\n      {\n        \"id\": 10914,\n        \"logprob\": -1.4619141,\n        \"special\": false,\n        \"text\": \" World\"\n      },\n      {\n        \"id\": 16013,\n        \"logprob\": -0.7993164,\n        \"special\": false,\n        \"text\": \"!\\\")\"\n      },\n      {\n        \"id\": 222,\n        \"logprob\": -0.63134766,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 222,\n        \"logprob\": -0.23278809,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 610,\n        \"logprob\": -1.2294922,\n        \"special\": false,\n        \"text\": \"def\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"():\\n    print(\\\"Hello World!\\\")\\n\\ndef\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 60,\n    \"prefill\": [],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 2284,\n        \"logprob\": -0.31323242,\n        \"special\": false,\n        \"text\": \"():\"\n      },\n      {\n        \"id\": 303,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"\\n   \"\n      },\n      {\n        \"id\": 1489,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" print\"\n      },\n      {\n        \"id\": 459,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"(\\\"\"\n      },\n      {\n        \"id\": 8302,\n        \"logprob\": -0.26611328,\n        \"special\": false,\n        \"text\": \"Hello\"\n      },\n      {\n        \"id\": 10914,\n        \"logprob\": -0.7871094,\n        \"special\": false,\n        \"text\": \" World\"\n      },\n      {\n        \"id\": 16013,\n        \"logprob\": -0.64746094,\n        \"special\": false,\n        \"text\": \"!\\\")\"\n      },\n      {\n        \"id\": 222,\n        \"logprob\": -0.054870605,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 222,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 610,\n        \"logprob\": -0.41064453,\n        \"special\": false,\n        \"text\": \"def\"\n      },\n      {\n        \"id\": 1489,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" print\"\n      },\n      {\n        \"id\": 100,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"_\"\n      },\n      {\n        \"id\": 7670,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"hello\"\n      },\n      {\n        \"id\": 100,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"_\"\n   
   },\n      {\n        \"id\": 444,\n        \"logprob\": -0.21655273,\n        \"special\": false,\n        \"text\": \"name\"\n      },\n      {\n        \"id\": 45,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"(\"\n      },\n      {\n        \"id\": 444,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"name\"\n      },\n      {\n        \"id\": 731,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"):\"\n      },\n      {\n        \"id\": 303,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"\\n   \"\n      },\n      {\n        \"id\": 1489,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" print\"\n      },\n      {\n        \"id\": 459,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"(\\\"\"\n      },\n      {\n        \"id\": 8302,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"Hello\"\n      },\n      {\n        \"id\": 332,\n        \"logprob\": -0.034698486,\n        \"special\": false,\n        \"text\": \" \\\"\"\n      },\n      {\n        \"id\": 494,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" +\"\n      },\n      {\n        \"id\": 655,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" name\"\n      },\n      {\n        \"id\": 494,\n        \"logprob\": -0.20141602,\n        \"special\": false,\n        \"text\": \" +\"\n      },\n      {\n        \"id\": 332,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" \\\"\"\n      },\n      {\n        \"id\": 16013,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"!\\\")\"\n      },\n      {\n        \"id\": 222,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 222,\n        \"logprob\": 0.0,\n        \"special\": false,\n  
      \"text\": \"\\n\"\n      },\n      {\n        \"id\": 610,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"def\"\n      },\n      {\n        \"id\": 1489,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" print\"\n      },\n      {\n        \"id\": 100,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"_\"\n      },\n      {\n        \"id\": 7670,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"hello\"\n      },\n      {\n        \"id\": 100,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"_\"\n      },\n      {\n        \"id\": 444,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"name\"\n      },\n      {\n        \"id\": 100,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"_\"\n      },\n      {\n        \"id\": 400,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"age\"\n      },\n      {\n        \"id\": 45,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"(\"\n      },\n      {\n        \"id\": 444,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"name\"\n      },\n      {\n        \"id\": 49,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \",\"\n      },\n      {\n        \"id\": 11505,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" age\"\n      },\n      {\n        \"id\": 731,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"):\"\n      },\n      {\n        \"id\": 303,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"\\n   \"\n      },\n      {\n        \"id\": 1489,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" print\"\n      },\n      {\n        \"id\": 459,\n        \"logprob\": 0.0,\n        \"special\": false,\n        
\"text\": \"(\\\"\"\n      },\n      {\n        \"id\": 8302,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"Hello\"\n      },\n      {\n        \"id\": 332,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" \\\"\"\n      },\n      {\n        \"id\": 494,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" +\"\n      },\n      {\n        \"id\": 655,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" name\"\n      },\n      {\n        \"id\": 494,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" +\"\n      },\n      {\n        \"id\": 3021,\n        \"logprob\": -0.5761719,\n        \"special\": false,\n        \"text\": \" \\\",\"\n      },\n      {\n        \"id\": 863,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" you\"\n      },\n      {\n        \"id\": 904,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" are\"\n      },\n      {\n        \"id\": 332,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" \\\"\"\n      },\n      {\n        \"id\": 494,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" +\"\n      },\n      {\n        \"id\": 615,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" str\"\n      },\n      {\n        \"id\": 45,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"(\"\n      },\n      {\n        \"id\": 400,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"age\"\n      },\n      {\n        \"id\": 46,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \")\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"():\\n    print(\\\"Hello World!\\\")\\n\\ndef print_hello_name(name):\\n    print(\\\"Hello \\\" + name + \\\"!\\\")\\n\\ndef print_hello_name_age(name, 
age):\\n    print(\\\"Hello \\\" + name + \\\", you are \\\" + str(age)\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 2284,\n          \"logprob\": -0.92626953,\n          \"special\": false,\n          \"text\": \"():\"\n        },\n        {\n          \"id\": 303,\n          \"logprob\": -0.40722656,\n          \"special\": false,\n          \"text\": \"\\n   \"\n        },\n        {\n          \"id\": 1489,\n          \"logprob\": -0.27954102,\n          \"special\": false,\n          \"text\": \" print\"\n        },\n        {\n          \"id\": 459,\n          \"logprob\": -0.6142578,\n          \"special\": false,\n          \"text\": \"(\\\"\"\n        },\n        {\n          \"id\": 8302,\n          \"logprob\": -0.68310547,\n          \"special\": false,\n          \"text\": \"Hello\"\n        },\n        {\n          \"id\": 10914,\n          \"logprob\": -1.4570312,\n          \"special\": false,\n          \"text\": \" World\"\n        },\n        {\n          \"id\": 16013,\n          \"logprob\": -0.80126953,\n          \"special\": false,\n          \"text\": \"!\\\")\"\n        },\n        {\n          \"id\": 222,\n          \"logprob\": -0.6303711,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 222,\n          \"logprob\": -0.23327637,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 610,\n          \"logprob\": -1.2304688,\n          \"special\": false,\n          \"text\": \"def\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"():\\n    print(\\\"Hello World!\\\")\\n\\ndef\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      
\"tokens\": [\n        {\n          \"id\": 2284,\n          \"logprob\": -0.92626953,\n          \"special\": false,\n          \"text\": \"():\"\n        },\n        {\n          \"id\": 303,\n          \"logprob\": -0.40722656,\n          \"special\": false,\n          \"text\": \"\\n   \"\n        },\n        {\n          \"id\": 1489,\n          \"logprob\": -0.27954102,\n          \"special\": false,\n          \"text\": \" print\"\n        },\n        {\n          \"id\": 459,\n          \"logprob\": -0.6142578,\n          \"special\": false,\n          \"text\": \"(\\\"\"\n        },\n        {\n          \"id\": 8302,\n          \"logprob\": -0.68310547,\n          \"special\": false,\n          \"text\": \"Hello\"\n        },\n        {\n          \"id\": 10914,\n          \"logprob\": -1.4570312,\n          \"special\": false,\n          \"text\": \" World\"\n        },\n        {\n          \"id\": 16013,\n          \"logprob\": -0.80126953,\n          \"special\": false,\n          \"text\": \"!\\\")\"\n        },\n        {\n          \"id\": 222,\n          \"logprob\": -0.6303711,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 222,\n          \"logprob\": -0.23327637,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 610,\n          \"logprob\": -1.2304688,\n          \"special\": false,\n          \"text\": \"def\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"():\\n    print(\\\"Hello World!\\\")\\n\\ndef\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 2284,\n          \"logprob\": -0.92626953,\n          \"special\": false,\n          \"text\": \"():\"\n        },\n        {\n          \"id\": 303,\n          
\"logprob\": -0.40722656,\n          \"special\": false,\n          \"text\": \"\\n   \"\n        },\n        {\n          \"id\": 1489,\n          \"logprob\": -0.27954102,\n          \"special\": false,\n          \"text\": \" print\"\n        },\n        {\n          \"id\": 459,\n          \"logprob\": -0.6142578,\n          \"special\": false,\n          \"text\": \"(\\\"\"\n        },\n        {\n          \"id\": 8302,\n          \"logprob\": -0.68310547,\n          \"special\": false,\n          \"text\": \"Hello\"\n        },\n        {\n          \"id\": 10914,\n          \"logprob\": -1.4570312,\n          \"special\": false,\n          \"text\": \" World\"\n        },\n        {\n          \"id\": 16013,\n          \"logprob\": -0.80126953,\n          \"special\": false,\n          \"text\": \"!\\\")\"\n        },\n        {\n          \"id\": 222,\n          \"logprob\": -0.6303711,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 222,\n          \"logprob\": -0.23327637,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 610,\n          \"logprob\": -1.2304688,\n          \"special\": false,\n          \"text\": \"def\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"():\\n    print(\\\"Hello World!\\\")\\n\\ndef\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 2284,\n          \"logprob\": -0.92626953,\n          \"special\": false,\n          \"text\": \"():\"\n        },\n        {\n          \"id\": 303,\n          \"logprob\": -0.40722656,\n          \"special\": false,\n          \"text\": \"\\n   \"\n        },\n        {\n          \"id\": 1489,\n          \"logprob\": -0.27954102,\n          \"special\": false,\n        
  \"text\": \" print\"\n        },\n        {\n          \"id\": 459,\n          \"logprob\": -0.6142578,\n          \"special\": false,\n          \"text\": \"(\\\"\"\n        },\n        {\n          \"id\": 8302,\n          \"logprob\": -0.68310547,\n          \"special\": false,\n          \"text\": \"Hello\"\n        },\n        {\n          \"id\": 10914,\n          \"logprob\": -1.4570312,\n          \"special\": false,\n          \"text\": \" World\"\n        },\n        {\n          \"id\": 16013,\n          \"logprob\": -0.80126953,\n          \"special\": false,\n          \"text\": \"!\\\")\"\n        },\n        {\n          \"id\": 222,\n          \"logprob\": -0.6303711,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 222,\n          \"logprob\": -0.23327637,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 610,\n          \"logprob\": -1.2304688,\n          \"special\": false,\n          \"text\": \"def\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"():\\n    print(\\\"Hello World!\\\")\\n\\ndef\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 2284,\n        \"logprob\": -0.9355469,\n        \"special\": false,\n        \"text\": \"():\"\n      },\n      {\n        \"id\": 303,\n        \"logprob\": -0.40795898,\n        \"special\": false,\n        \"text\": \"\\n   \"\n      },\n      {\n        \"id\": 1489,\n        \"logprob\": -0.27954102,\n        \"special\": false,\n        \"text\": \" print\"\n      },\n      {\n        \"id\": 459,\n        \"logprob\": -0.6142578,\n        \"special\": false,\n        \"text\": \"(\\\"\"\n      },\n      {\n        \"id\": 8302,\n        \"logprob\": -0.68310547,\n        \"special\": false,\n        \"text\": \"Hello\"\n      },\n      {\n        \"id\": 10914,\n        \"logprob\": -1.4599609,\n        \"special\": false,\n        \"text\": \" World\"\n      },\n      {\n        \"id\": 16013,\n        \"logprob\": -0.80126953,\n        \"special\": false,\n        \"text\": \"!\\\")\"\n      },\n      {\n        \"id\": 222,\n        \"logprob\": -0.625,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 222,\n        \"logprob\": -0.23242188,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 610,\n        \"logprob\": -1.2294922,\n        \"special\": false,\n        \"text\": \"def\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"():\\n    print(\\\"Hello World!\\\")\\n\\ndef\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_default_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 60,\n    \"prefill\": [],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 222,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 222,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 40,\n        \"logprob\": -0.7944336,\n        \"special\": false,\n        \"text\": \"#\"\n      },\n      {\n        \"id\": 494,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" +\"\n      },\n      {\n        \"id\": 447,\n        \"logprob\": -0.1796875,\n        \"special\": false,\n        \"text\": \" [\"\n      },\n      {\n        \"id\": 9009,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"markdown\"\n      },\n      {\n        \"id\": 98,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"]\"\n      },\n      {\n        \"id\": 37402,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" slideshow\"\n      },\n      {\n        \"id\": 8492,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"={\\\"\"\n      },\n      {\n        \"id\": 7277,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"slide\"\n      },\n      {\n        \"id\": 100,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"_\"\n      },\n      {\n        \"id\": 700,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"type\"\n      },\n      {\n        \"id\": 582,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"\\\":\"\n      },\n      {\n        \"id\": 332,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" \\\"\"\n      },\n      {\n        \"id\": 7277,\n    
    \"logprob\": -0.06994629,\n        \"special\": false,\n        \"text\": \"slide\"\n      },\n      {\n        \"id\": 3667,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"\\\"}\"\n      },\n      {\n        \"id\": 222,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 40,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"#\"\n      },\n      {\n        \"id\": 607,\n        \"logprob\": -0.8261719,\n        \"special\": false,\n        \"text\": \" #\"\n      },\n      {\n        \"id\": 244,\n        \"logprob\": -1.8574219,\n        \"special\": false,\n        \"text\": \" \"\n      },\n      {\n        \"id\": 55,\n        \"logprob\": -1.4541016,\n        \"special\": false,\n        \"text\": \"2\"\n      },\n      {\n        \"id\": 51,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 6208,\n        \"logprob\": -0.9794922,\n        \"special\": false,\n        \"text\": \" What\"\n      },\n      {\n        \"id\": 458,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 341,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 10609,\n        \"logprob\": -0.69189453,\n        \"special\": false,\n        \"text\": \" difference\"\n      },\n      {\n        \"id\": 3761,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" between\"\n      },\n      {\n        \"id\": 331,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 1168,\n        \"logprob\": -0.27172852,\n        \"special\": false,\n        \"text\": \" list\"\n      },\n      {\n        \"id\": 480,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" 
and\"\n      },\n      {\n        \"id\": 331,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 8871,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" tuple\"\n      },\n      {\n        \"id\": 68,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"?\"\n      },\n      {\n        \"id\": 222,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 40,\n        \"logprob\": -1.3359375,\n        \"special\": false,\n        \"text\": \"#\"\n      },\n      {\n        \"id\": 222,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 40,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"#\"\n      },\n      {\n        \"id\": 449,\n        \"logprob\": -0.03164673,\n        \"special\": false,\n        \"text\": \" -\"\n      },\n      {\n        \"id\": 418,\n        \"logprob\": -1.0947266,\n        \"special\": false,\n        \"text\": \" A\"\n      },\n      {\n        \"id\": 1168,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" list\"\n      },\n      {\n        \"id\": 458,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 331,\n        \"logprob\": -0.3305664,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 14792,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" mutable\"\n      },\n      {\n        \"id\": 6645,\n        \"logprob\": -0.40478516,\n        \"special\": false,\n        \"text\": \" sequence\"\n      },\n      {\n        \"id\": 451,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 4725,\n        \"logprob\": -0.50390625,\n        
\"special\": false,\n        \"text\": \" elements\"\n      },\n      {\n        \"id\": 49,\n        \"logprob\": -2.1269531,\n        \"special\": false,\n        \"text\": \",\"\n      },\n      {\n        \"id\": 2236,\n        \"logprob\": -0.1427002,\n        \"special\": false,\n        \"text\": \" while\"\n      },\n      {\n        \"id\": 331,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 8871,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" tuple\"\n      },\n      {\n        \"id\": 458,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 619,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" an\"\n      },\n      {\n        \"id\": 26079,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" immutable\"\n      },\n      {\n        \"id\": 6645,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" sequence\"\n      },\n      {\n        \"id\": 451,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 4725,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" elements\"\n      },\n      {\n        \"id\": 51,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 222,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 40,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"#\"\n      },\n      {\n        \"id\": 449,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" -\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"\\n\\n# + [markdown] slideshow={\\\"slide_type\\\": \\\"slide\\\"}\\n# # 2. 
What is the difference between a list and a tuple?\\n#\\n# - A list is a mutable sequence of elements, while a tuple is an immutable sequence of elements.\\n# -\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 222,\n          \"logprob\": -1.9091797,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 222,\n          \"logprob\": -1.0478516,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 40,\n          \"logprob\": -3.015625,\n          \"special\": false,\n          \"text\": \"#\"\n        },\n        {\n          \"id\": 494,\n          \"logprob\": -1.4228516,\n          \"special\": false,\n          \"text\": \" +\"\n        },\n        {\n          \"id\": 447,\n          \"logprob\": -1.1025391,\n          \"special\": false,\n          \"text\": \" [\"\n        },\n        {\n          \"id\": 9009,\n          \"logprob\": -0.0008444786,\n          \"special\": false,\n          \"text\": \"markdown\"\n        },\n        {\n          \"id\": 98,\n          \"logprob\": -8.8095665e-05,\n          \"special\": false,\n          \"text\": \"]\"\n        },\n        {\n          \"id\": 37402,\n          \"logprob\": -0.5810547,\n          \"special\": false,\n          \"text\": \" slideshow\"\n        },\n        {\n          \"id\": 8492,\n          \"logprob\": -0.00022864342,\n          \"special\": false,\n          \"text\": \"={\\\"\"\n        },\n        {\n          \"id\": 7277,\n          \"logprob\": -0.00030994415,\n          \"special\": false,\n          \"text\": \"slide\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n\\n# + [markdown] slideshow={\\\"slide\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n 
       {\n          \"id\": 222,\n          \"logprob\": -1.9091797,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 222,\n          \"logprob\": -1.0478516,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 40,\n          \"logprob\": -3.015625,\n          \"special\": false,\n          \"text\": \"#\"\n        },\n        {\n          \"id\": 494,\n          \"logprob\": -1.4228516,\n          \"special\": false,\n          \"text\": \" +\"\n        },\n        {\n          \"id\": 447,\n          \"logprob\": -1.1025391,\n          \"special\": false,\n          \"text\": \" [\"\n        },\n        {\n          \"id\": 9009,\n          \"logprob\": -0.0008444786,\n          \"special\": false,\n          \"text\": \"markdown\"\n        },\n        {\n          \"id\": 98,\n          \"logprob\": -8.8095665e-05,\n          \"special\": false,\n          \"text\": \"]\"\n        },\n        {\n          \"id\": 37402,\n          \"logprob\": -0.5810547,\n          \"special\": false,\n          \"text\": \" slideshow\"\n        },\n        {\n          \"id\": 8492,\n          \"logprob\": -0.00022864342,\n          \"special\": false,\n          \"text\": \"={\\\"\"\n        },\n        {\n          \"id\": 7277,\n          \"logprob\": -0.00030994415,\n          \"special\": false,\n          \"text\": \"slide\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n\\n# + [markdown] slideshow={\\\"slide\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 222,\n          \"logprob\": -1.9091797,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 222,\n          \"logprob\": -1.0478516,\n       
   \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 40,\n          \"logprob\": -3.015625,\n          \"special\": false,\n          \"text\": \"#\"\n        },\n        {\n          \"id\": 494,\n          \"logprob\": -1.4228516,\n          \"special\": false,\n          \"text\": \" +\"\n        },\n        {\n          \"id\": 447,\n          \"logprob\": -1.1025391,\n          \"special\": false,\n          \"text\": \" [\"\n        },\n        {\n          \"id\": 9009,\n          \"logprob\": -0.0008444786,\n          \"special\": false,\n          \"text\": \"markdown\"\n        },\n        {\n          \"id\": 98,\n          \"logprob\": -8.8095665e-05,\n          \"special\": false,\n          \"text\": \"]\"\n        },\n        {\n          \"id\": 37402,\n          \"logprob\": -0.5810547,\n          \"special\": false,\n          \"text\": \" slideshow\"\n        },\n        {\n          \"id\": 8492,\n          \"logprob\": -0.00022864342,\n          \"special\": false,\n          \"text\": \"={\\\"\"\n        },\n        {\n          \"id\": 7277,\n          \"logprob\": -0.00030994415,\n          \"special\": false,\n          \"text\": \"slide\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n\\n# + [markdown] slideshow={\\\"slide\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 222,\n          \"logprob\": -1.9091797,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 222,\n          \"logprob\": -1.0478516,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 40,\n          \"logprob\": -3.015625,\n          \"special\": false,\n          \"text\": \"#\"\n        },\n        {\n     
     \"id\": 494,\n          \"logprob\": -1.4228516,\n          \"special\": false,\n          \"text\": \" +\"\n        },\n        {\n          \"id\": 447,\n          \"logprob\": -1.1025391,\n          \"special\": false,\n          \"text\": \" [\"\n        },\n        {\n          \"id\": 9009,\n          \"logprob\": -0.0008444786,\n          \"special\": false,\n          \"text\": \"markdown\"\n        },\n        {\n          \"id\": 98,\n          \"logprob\": -8.8095665e-05,\n          \"special\": false,\n          \"text\": \"]\"\n        },\n        {\n          \"id\": 37402,\n          \"logprob\": -0.5810547,\n          \"special\": false,\n          \"text\": \" slideshow\"\n        },\n        {\n          \"id\": 8492,\n          \"logprob\": -0.00022864342,\n          \"special\": false,\n          \"text\": \"={\\\"\"\n        },\n        {\n          \"id\": 7277,\n          \"logprob\": -0.00030994415,\n          \"special\": false,\n          \"text\": \"slide\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n\\n# + [markdown] slideshow={\\\"slide\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_starcoder2_lora/test_flash_starcoder2_with_hugcode_adapter.json",
    "content": "{\n  \"details\": {\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 100,\n        \"logprob\": -0.9824219,\n        \"special\": false,\n        \"text\": \"_\"\n      },\n      {\n        \"id\": 5879,\n        \"logprob\": -0.3017578,\n        \"special\": false,\n        \"text\": \"world\"\n      },\n      {\n        \"id\": 2284,\n        \"logprob\": -0.68652344,\n        \"special\": false,\n        \"text\": \"():\"\n      },\n      {\n        \"id\": 303,\n        \"logprob\": -0.27734375,\n        \"special\": false,\n        \"text\": \"\\n   \"\n      },\n      {\n        \"id\": 1489,\n        \"logprob\": -0.4482422,\n        \"special\": false,\n        \"text\": \" print\"\n      },\n      {\n        \"id\": 459,\n        \"logprob\": -0.54248047,\n        \"special\": false,\n        \"text\": \"(\\\"\"\n      },\n      {\n        \"id\": 8302,\n        \"logprob\": -0.4296875,\n        \"special\": false,\n        \"text\": \"Hello\"\n      },\n      {\n        \"id\": 10914,\n        \"logprob\": -0.8544922,\n        \"special\": false,\n        \"text\": \" World\"\n      },\n      {\n        \"id\": 16013,\n        \"logprob\": -0.7573242,\n        \"special\": false,\n        \"text\": \"!\\\")\"\n      },\n      {\n        \"id\": 222,\n        \"logprob\": -0.81347656,\n        \"special\": false,\n        \"text\": \"\\n\"\n      }\n    ]\n  },\n  \"generated_text\": \"_world():\\n    print(\\\"Hello World!\\\")\\n\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"eos_token\",\n    \"generated_tokens\": 2,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 284,\n        \"logprob\": -1.1679688,\n        \"special\": false,\n        \"text\": \"\\n   \"\n      },\n      {\n        \"id\": 0,\n        \"logprob\": null,\n        \"special\": true,\n        \"text\": \"<|endoftext|>\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"\\n   \"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_default_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"eos_token\",\n    \"generated_tokens\": 2,\n    \"prefill\": [],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 284,\n        \"logprob\": -0.048583984,\n        \"special\": false,\n        \"text\": \"\\n   \"\n      },\n      {\n        \"id\": 0,\n        \"logprob\": null,\n        \"special\": true,\n        \"text\": \"<|endoftext|>\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"\\n   \"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_flash_starcoder_gptq/test_flash_starcoder_gptq_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"eos_token\",\n      \"generated_tokens\": 2,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 284,\n          \"logprob\": -0.046844482,\n          \"special\": false,\n          \"text\": \"\\n   \"\n        },\n        {\n          \"id\": 0,\n          \"logprob\": null,\n          \"special\": true,\n          \"text\": \"<|endoftext|>\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n   \"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"eos_token\",\n      \"generated_tokens\": 2,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 284,\n          \"logprob\": -0.046722412,\n          \"special\": false,\n          \"text\": \"\\n   \"\n        },\n        {\n          \"id\": 0,\n          \"logprob\": null,\n          \"special\": true,\n          \"text\": \"<|endoftext|>\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n   \"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"eos_token\",\n      \"generated_tokens\": 2,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 284,\n          \"logprob\": -0.04650879,\n          \"special\": false,\n          \"text\": \"\\n   \"\n        },\n        {\n          \"id\": 0,\n          \"logprob\": null,\n          \"special\": true,\n          \"text\": \"<|endoftext|>\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n   \"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"eos_token\",\n      \"generated_tokens\": 2,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 284,\n  
        \"logprob\": -0.046539307,\n          \"special\": false,\n          \"text\": \"\\n   \"\n        },\n        {\n          \"id\": 0,\n          \"logprob\": null,\n          \"special\": true,\n          \"text\": \"<|endoftext|>\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n   \"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_grammar_llama/test_non_flash_llama_grammar_json.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"eos_token\",\n    \"generated_tokens\": 30,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 6377,\n        \"logprob\": -0.14916992,\n        \"special\": false,\n        \"text\": \"{\\\"\"\n      },\n      {\n        \"id\": 29888,\n        \"logprob\": -0.13598633,\n        \"special\": false,\n        \"text\": \"f\"\n      },\n      {\n        \"id\": 12935,\n        \"logprob\": -0.017669678,\n        \"special\": false,\n        \"text\": \"irs\"\n      },\n      {\n        \"id\": 29873,\n        \"logprob\": -0.00085639954,\n        \"special\": false,\n        \"text\": \"t\"\n      },\n      {\n        \"id\": 1170,\n        \"logprob\": -0.0054016113,\n        \"special\": false,\n        \"text\": \"Name\"\n      },\n      {\n        \"id\": 4710,\n        \"logprob\": -0.13549805,\n        \"special\": false,\n        \"text\": \"\\\":\\\"\"\n      },\n      {\n        \"id\": 19504,\n        \"logprob\": -0.8852539,\n        \"special\": false,\n        \"text\": \"David\"\n      },\n      {\n        \"id\": 3284,\n        \"logprob\": -0.16394043,\n        \"special\": false,\n        \"text\": \"\\\",\\\"\"\n      },\n      {\n        \"id\": 29882,\n        \"logprob\": -0.08862305,\n        \"special\": false,\n        \"text\": \"h\"\n      },\n      {\n        \"id\": 711,\n        \"logprob\": -0.66259766,\n        \"special\": false,\n        \"text\": \"ob\"\n      },\n      {\n        \"id\": 1609,\n        \"logprob\": -5.51939e-05,\n        \"special\": false,\n        \"text\": \"by\"\n      },\n      {\n        \"id\": 4710,\n        \"logprob\": -0.23120117,\n        \"special\": false,\n        \"text\": \"\\\":\\\"\"\n      },\n      {\n        \"id\": 29911,\n        \"logprob\": -2.3730469,\n        \"special\": false,\n        \"text\": \"T\"\n      },\n      {\n        \"id\": 11003,\n        
\"logprob\": -0.032104492,\n        \"special\": false,\n        \"text\": \"rees\"\n      },\n      {\n        \"id\": 3284,\n        \"logprob\": -0.22021484,\n        \"special\": false,\n        \"text\": \"\\\",\\\"\"\n      },\n      {\n        \"id\": 4230,\n        \"logprob\": -0.06726074,\n        \"special\": false,\n        \"text\": \"last\"\n      },\n      {\n        \"id\": 1170,\n        \"logprob\": -0.003501892,\n        \"special\": false,\n        \"text\": \"Name\"\n      },\n      {\n        \"id\": 4710,\n        \"logprob\": -0.0045661926,\n        \"special\": false,\n        \"text\": \"\\\":\\\"\"\n      },\n      {\n        \"id\": 29950,\n        \"logprob\": -0.12512207,\n        \"special\": false,\n        \"text\": \"H\"\n      },\n      {\n        \"id\": 14339,\n        \"logprob\": -0.009552002,\n        \"special\": false,\n        \"text\": \"olt\"\n      },\n      {\n        \"id\": 29920,\n        \"logprob\": -0.00042438507,\n        \"special\": false,\n        \"text\": \"z\"\n      },\n      {\n        \"id\": 3284,\n        \"logprob\": -0.11651611,\n        \"special\": false,\n        \"text\": \"\\\",\\\"\"\n      },\n      {\n        \"id\": 29876,\n        \"logprob\": -0.29736328,\n        \"special\": false,\n        \"text\": \"n\"\n      },\n      {\n        \"id\": 398,\n        \"logprob\": -0.003030777,\n        \"special\": false,\n        \"text\": \"um\"\n      },\n      {\n        \"id\": 29907,\n        \"logprob\": -0.3774414,\n        \"special\": false,\n        \"text\": \"C\"\n      },\n      {\n        \"id\": 1446,\n        \"logprob\": -0.0003130436,\n        \"special\": false,\n        \"text\": \"ats\"\n      },\n      {\n        \"id\": 1115,\n        \"logprob\": -0.0021514893,\n        \"special\": false,\n        \"text\": \"\\\":\"\n      },\n      {\n        \"id\": 29906,\n        \"logprob\": -0.071899414,\n        \"special\": false,\n        \"text\": \"2\"\n      },\n      {\n      
  \"id\": 29913,\n        \"logprob\": -0.018997192,\n        \"special\": false,\n        \"text\": \"}\"\n      },\n      {\n        \"id\": 2,\n        \"logprob\": 0.0,\n        \"special\": true,\n        \"text\": \"</s>\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"{\\\"firstName\\\":\\\"David\\\",\\\"hobby\\\":\\\"Trees\\\",\\\"lastName\\\":\\\"Holtz\\\",\\\"numCats\\\":2}\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_grammar_response_format_llama/test_grammar_response_format_llama_json.1.json",
    "content": "{\n  \"choices\": [\n    {\n      \"finish_reason\": \"stop\",\n      \"index\": 0,\n      \"logprobs\": null,\n      \"message\": {\n        \"content\": \"{ \\\"unit\\\": \\\"fahrenheit\\\", \\\"temperature\\\": [ 72, 79, 88 ] }\",\n        \"role\": \"assistant\"\n      }\n    }\n  ],\n  \"created\": 1740095072,\n  \"id\": \"\",\n  \"model\": \"TinyLlama/TinyLlama-1.1B-Chat-v1.0\",\n  \"object\": \"chat.completion\",\n  \"system_fingerprint\": \"3.1.1-dev0-native\",\n  \"usage\": {\n    \"completion_tokens\": 29,\n    \"prompt_tokens\": 135,\n    \"total_tokens\": 164\n  }\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_grammar_response_format_llama/test_grammar_response_format_llama_json.2.json",
    "content": "{\n  \"choices\": [\n    {\n      \"finish_reason\": \"stop\",\n      \"index\": 0,\n      \"logprobs\": null,\n      \"message\": {\n        \"content\": \"{ \\\"unit\\\": \\\"fahrenheit\\\", \\\"temperature\\\": [ 72, 79, 88 ] }\",\n        \"role\": \"assistant\"\n      }\n    }\n  ],\n  \"created\": 1740095073,\n  \"id\": \"\",\n  \"model\": \"TinyLlama/TinyLlama-1.1B-Chat-v1.0\",\n  \"object\": \"chat.completion\",\n  \"system_fingerprint\": \"3.1.1-dev0-native\",\n  \"usage\": {\n    \"completion_tokens\": 29,\n    \"prompt_tokens\": 135,\n    \"total_tokens\": 164\n  }\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_grammar_response_format_llama/test_grammar_response_format_llama_json.json",
    "content": "{\n  \"choices\": [\n    {\n      \"finish_reason\": \"stop\",\n      \"index\": 0,\n      \"logprobs\": null,\n      \"message\": {\n        \"content\": \"{ \\\"unit\\\": \\\"fahrenheit\\\", \\\"temperature\\\": [ 72, 79, 88 ] }\",\n        \"role\": \"assistant\"\n      }\n    }\n  ],\n  \"created\": 1732525803,\n  \"id\": \"\",\n  \"model\": \"TinyLlama/TinyLlama-1.1B-Chat-v1.0\",\n  \"object\": \"chat.completion\",\n  \"system_fingerprint\": \"2.4.1-dev0-native\",\n  \"usage\": {\n    \"completion_tokens\": 29,\n    \"prompt_tokens\": 136,\n    \"total_tokens\": 165\n  }\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_idefics/test_idefics.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [\n      {\n        \"id\": 1,\n        \"logprob\": null,\n        \"text\": \"<s>\"\n      },\n      {\n        \"id\": 4911,\n        \"logprob\": -6.9765625,\n        \"text\": \"User\"\n      },\n      {\n        \"id\": 29901,\n        \"logprob\": -0.0059432983,\n        \"text\": \":\"\n      },\n      {\n        \"id\": 32000,\n        \"logprob\": -0.8408203,\n        \"text\": \"<fake_token_around_image>\"\n      },\n      {\n        \"id\": 32001,\n        \"logprob\": -9.906292e-05,\n        \"text\": \"<image>\"\n      },\n      {\n        \"id\": 32000,\n        \"logprob\": -2.3841858e-07,\n        \"text\": \"<fake_token_around_image>\"\n      },\n      {\n        \"id\": 1815,\n        \"logprob\": -4.1679688,\n        \"text\": \"Can\"\n      },\n      {\n        \"id\": 366,\n        \"logprob\": -0.014099121,\n        \"text\": \"you\"\n      },\n      {\n        \"id\": 2649,\n        \"logprob\": -4.4609375,\n        \"text\": \"tell\"\n      },\n      {\n        \"id\": 592,\n        \"logprob\": -0.29882812,\n        \"text\": \"me\"\n      },\n      {\n        \"id\": 263,\n        \"logprob\": -4.1445312,\n        \"text\": \"a\"\n      },\n      {\n        \"id\": 1407,\n        \"logprob\": -9.3828125,\n        \"text\": \"very\"\n      },\n      {\n        \"id\": 3273,\n        \"logprob\": -1.9736328,\n        \"text\": \"short\"\n      },\n      {\n        \"id\": 5828,\n        \"logprob\": -0.2800293,\n        \"text\": \"story\"\n      },\n      {\n        \"id\": 2729,\n        \"logprob\": -3.5625,\n        \"text\": \"based\"\n      },\n      {\n        \"id\": 373,\n        \"logprob\": -0.0006427765,\n        \"text\": \"on\"\n      },\n      {\n        \"id\": 278,\n        \"logprob\": -0.13952637,\n        \"text\": \"the\"\n      },\n      {\n        \"id\": 1967,\n    
    \"logprob\": -0.068115234,\n        \"text\": \"image\"\n      },\n      {\n        \"id\": 29973,\n        \"logprob\": -0.16357422,\n        \"text\": \"?\"\n      }\n    ],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 32002,\n        \"logprob\": -0.0026474,\n        \"special\": true,\n        \"text\": \"<end_of_utterance>\"\n      },\n      {\n        \"id\": 29871,\n        \"logprob\": -8.547306e-05,\n        \"special\": false,\n        \"text\": \" \"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -1.7881393e-05,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 7900,\n        \"logprob\": -3.0994415e-06,\n        \"special\": false,\n        \"text\": \"Ass\"\n      },\n      {\n        \"id\": 22137,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"istant\"\n      },\n      {\n        \"id\": 29901,\n        \"logprob\": -3.2186508e-06,\n        \"special\": false,\n        \"text\": \":\"\n      },\n      {\n        \"id\": 319,\n        \"logprob\": -0.92529297,\n        \"special\": false,\n        \"text\": \" A\"\n      },\n      {\n        \"id\": 696,\n        \"logprob\": -1.1269531,\n        \"special\": false,\n        \"text\": \" ro\"\n      },\n      {\n        \"id\": 15664,\n        \"logprob\": -0.00029492378,\n        \"special\": false,\n        \"text\": \"oster\"\n      },\n      {\n        \"id\": 15028,\n        \"logprob\": -1.1855469,\n        \"special\": false,\n        \"text\": \" stands\"\n      }\n    ]\n  },\n  \"generated_text\": \" \\nAssistant: A rooster stands\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_idefics/test_idefics_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [\n        {\n          \"id\": 1,\n          \"logprob\": null,\n          \"text\": \"<s>\"\n        },\n        {\n          \"id\": 4911,\n          \"logprob\": -6.9804688,\n          \"text\": \"User\"\n        },\n        {\n          \"id\": 29901,\n          \"logprob\": -0.006122589,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 32000,\n          \"logprob\": -0.8417969,\n          \"text\": \"<fake_token_around_image>\"\n        },\n        {\n          \"id\": 32001,\n          \"logprob\": -9.918213e-05,\n          \"text\": \"<image>\"\n        },\n        {\n          \"id\": 32000,\n          \"logprob\": -2.3841858e-07,\n          \"text\": \"<fake_token_around_image>\"\n        },\n        {\n          \"id\": 1815,\n          \"logprob\": -4.1679688,\n          \"text\": \"Can\"\n        },\n        {\n          \"id\": 366,\n          \"logprob\": -0.014091492,\n          \"text\": \"you\"\n        },\n        {\n          \"id\": 2649,\n          \"logprob\": -4.4726562,\n          \"text\": \"tell\"\n        },\n        {\n          \"id\": 592,\n          \"logprob\": -0.2998047,\n          \"text\": \"me\"\n        },\n        {\n          \"id\": 263,\n          \"logprob\": -4.15625,\n          \"text\": \"a\"\n        },\n        {\n          \"id\": 1407,\n          \"logprob\": -9.3828125,\n          \"text\": \"very\"\n        },\n        {\n          \"id\": 3273,\n          \"logprob\": -1.9716797,\n          \"text\": \"short\"\n        },\n        {\n          \"id\": 5828,\n          \"logprob\": -0.27734375,\n          \"text\": \"story\"\n        },\n        {\n          \"id\": 2729,\n          \"logprob\": -3.5605469,\n          \"text\": \"based\"\n        },\n        {\n          \"id\": 373,\n          \"logprob\": -0.00064468384,\n   
       \"text\": \"on\"\n        },\n        {\n          \"id\": 278,\n          \"logprob\": -0.14160156,\n          \"text\": \"the\"\n        },\n        {\n          \"id\": 1967,\n          \"logprob\": -0.06915283,\n          \"text\": \"image\"\n        },\n        {\n          \"id\": 29973,\n          \"logprob\": -0.16381836,\n          \"text\": \"?\"\n        }\n      ],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 32002,\n          \"logprob\": -0.0026664734,\n          \"special\": true,\n          \"text\": \"<end_of_utterance>\"\n        },\n        {\n          \"id\": 29871,\n          \"logprob\": -8.583069e-05,\n          \"special\": false,\n          \"text\": \" \"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -1.8119812e-05,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 7900,\n          \"logprob\": -2.9802322e-06,\n          \"special\": false,\n          \"text\": \"Ass\"\n        },\n        {\n          \"id\": 22137,\n          \"logprob\": 0.0,\n          \"special\": false,\n          \"text\": \"istant\"\n        },\n        {\n          \"id\": 29901,\n          \"logprob\": -3.2186508e-06,\n          \"special\": false,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 319,\n          \"logprob\": -0.9301758,\n          \"special\": false,\n          \"text\": \" A\"\n        },\n        {\n          \"id\": 696,\n          \"logprob\": -1.1279297,\n          \"special\": false,\n          \"text\": \" ro\"\n        },\n        {\n          \"id\": 15664,\n          \"logprob\": -0.0002939701,\n          \"special\": false,\n          \"text\": \"oster\"\n        },\n        {\n          \"id\": 15028,\n          \"logprob\": -1.1865234,\n          \"special\": false,\n          \"text\": \" stands\"\n        }\n      ]\n    },\n    \"generated_text\": \" \\nAssistant: A rooster stands\"\n  },\n  {\n    
\"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [\n        {\n          \"id\": 1,\n          \"logprob\": null,\n          \"text\": \"<s>\"\n        },\n        {\n          \"id\": 4911,\n          \"logprob\": -6.9804688,\n          \"text\": \"User\"\n        },\n        {\n          \"id\": 29901,\n          \"logprob\": -0.006122589,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 32000,\n          \"logprob\": -0.8417969,\n          \"text\": \"<fake_token_around_image>\"\n        },\n        {\n          \"id\": 32001,\n          \"logprob\": -9.942055e-05,\n          \"text\": \"<image>\"\n        },\n        {\n          \"id\": 32000,\n          \"logprob\": -2.3841858e-07,\n          \"text\": \"<fake_token_around_image>\"\n        },\n        {\n          \"id\": 1815,\n          \"logprob\": -4.1679688,\n          \"text\": \"Can\"\n        },\n        {\n          \"id\": 366,\n          \"logprob\": -0.014091492,\n          \"text\": \"you\"\n        },\n        {\n          \"id\": 2649,\n          \"logprob\": -4.4726562,\n          \"text\": \"tell\"\n        },\n        {\n          \"id\": 592,\n          \"logprob\": -0.2998047,\n          \"text\": \"me\"\n        },\n        {\n          \"id\": 263,\n          \"logprob\": -4.15625,\n          \"text\": \"a\"\n        },\n        {\n          \"id\": 1407,\n          \"logprob\": -9.3828125,\n          \"text\": \"very\"\n        },\n        {\n          \"id\": 3273,\n          \"logprob\": -1.9716797,\n          \"text\": \"short\"\n        },\n        {\n          \"id\": 5828,\n          \"logprob\": -0.27734375,\n          \"text\": \"story\"\n        },\n        {\n          \"id\": 2729,\n          \"logprob\": -3.5605469,\n          \"text\": \"based\"\n        },\n        {\n          \"id\": 373,\n          \"logprob\": -0.0006451607,\n          \"text\": \"on\"\n    
    },\n        {\n          \"id\": 278,\n          \"logprob\": -0.14160156,\n          \"text\": \"the\"\n        },\n        {\n          \"id\": 1967,\n          \"logprob\": -0.06915283,\n          \"text\": \"image\"\n        },\n        {\n          \"id\": 29973,\n          \"logprob\": -0.16381836,\n          \"text\": \"?\"\n        }\n      ],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 32002,\n          \"logprob\": -0.0026664734,\n          \"special\": true,\n          \"text\": \"<end_of_utterance>\"\n        },\n        {\n          \"id\": 29871,\n          \"logprob\": -8.571148e-05,\n          \"special\": false,\n          \"text\": \" \"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -1.8119812e-05,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 7900,\n          \"logprob\": -3.0994415e-06,\n          \"special\": false,\n          \"text\": \"Ass\"\n        },\n        {\n          \"id\": 22137,\n          \"logprob\": 0.0,\n          \"special\": false,\n          \"text\": \"istant\"\n        },\n        {\n          \"id\": 29901,\n          \"logprob\": -3.0994415e-06,\n          \"special\": false,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 319,\n          \"logprob\": -0.9301758,\n          \"special\": false,\n          \"text\": \" A\"\n        },\n        {\n          \"id\": 696,\n          \"logprob\": -1.1279297,\n          \"special\": false,\n          \"text\": \" ro\"\n        },\n        {\n          \"id\": 15664,\n          \"logprob\": -0.0002939701,\n          \"special\": false,\n          \"text\": \"oster\"\n        },\n        {\n          \"id\": 15028,\n          \"logprob\": -1.1865234,\n          \"special\": false,\n          \"text\": \" stands\"\n        }\n      ]\n    },\n    \"generated_text\": \" \\nAssistant: A rooster stands\"\n  },\n  {\n    \"details\": {\n      
\"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [\n        {\n          \"id\": 1,\n          \"logprob\": null,\n          \"text\": \"<s>\"\n        },\n        {\n          \"id\": 4911,\n          \"logprob\": -6.9804688,\n          \"text\": \"User\"\n        },\n        {\n          \"id\": 29901,\n          \"logprob\": -0.006122589,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 32000,\n          \"logprob\": -0.8417969,\n          \"text\": \"<fake_token_around_image>\"\n        },\n        {\n          \"id\": 32001,\n          \"logprob\": -9.918213e-05,\n          \"text\": \"<image>\"\n        },\n        {\n          \"id\": 32000,\n          \"logprob\": -2.3841858e-07,\n          \"text\": \"<fake_token_around_image>\"\n        },\n        {\n          \"id\": 1815,\n          \"logprob\": -4.1679688,\n          \"text\": \"Can\"\n        },\n        {\n          \"id\": 366,\n          \"logprob\": -0.014091492,\n          \"text\": \"you\"\n        },\n        {\n          \"id\": 2649,\n          \"logprob\": -4.4726562,\n          \"text\": \"tell\"\n        },\n        {\n          \"id\": 592,\n          \"logprob\": -0.2998047,\n          \"text\": \"me\"\n        },\n        {\n          \"id\": 263,\n          \"logprob\": -4.15625,\n          \"text\": \"a\"\n        },\n        {\n          \"id\": 1407,\n          \"logprob\": -9.3828125,\n          \"text\": \"very\"\n        },\n        {\n          \"id\": 3273,\n          \"logprob\": -1.9716797,\n          \"text\": \"short\"\n        },\n        {\n          \"id\": 5828,\n          \"logprob\": -0.27734375,\n          \"text\": \"story\"\n        },\n        {\n          \"id\": 2729,\n          \"logprob\": -3.5605469,\n          \"text\": \"based\"\n        },\n        {\n          \"id\": 373,\n          \"logprob\": -0.00064468384,\n          \"text\": \"on\"\n        },\n        {\n  
        \"id\": 278,\n          \"logprob\": -0.14160156,\n          \"text\": \"the\"\n        },\n        {\n          \"id\": 1967,\n          \"logprob\": -0.06915283,\n          \"text\": \"image\"\n        },\n        {\n          \"id\": 29973,\n          \"logprob\": -0.16381836,\n          \"text\": \"?\"\n        }\n      ],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 32002,\n          \"logprob\": -0.0026664734,\n          \"special\": true,\n          \"text\": \"<end_of_utterance>\"\n        },\n        {\n          \"id\": 29871,\n          \"logprob\": -8.59499e-05,\n          \"special\": false,\n          \"text\": \" \"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -1.8119812e-05,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 7900,\n          \"logprob\": -3.0994415e-06,\n          \"special\": false,\n          \"text\": \"Ass\"\n        },\n        {\n          \"id\": 22137,\n          \"logprob\": 0.0,\n          \"special\": false,\n          \"text\": \"istant\"\n        },\n        {\n          \"id\": 29901,\n          \"logprob\": -3.0994415e-06,\n          \"special\": false,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 319,\n          \"logprob\": -0.9301758,\n          \"special\": false,\n          \"text\": \" A\"\n        },\n        {\n          \"id\": 696,\n          \"logprob\": -1.1279297,\n          \"special\": false,\n          \"text\": \" ro\"\n        },\n        {\n          \"id\": 15664,\n          \"logprob\": -0.0002939701,\n          \"special\": false,\n          \"text\": \"oster\"\n        },\n        {\n          \"id\": 15028,\n          \"logprob\": -1.1865234,\n          \"special\": false,\n          \"text\": \" stands\"\n        }\n      ]\n    },\n    \"generated_text\": \" \\nAssistant: A rooster stands\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n 
     \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [\n        {\n          \"id\": 1,\n          \"logprob\": null,\n          \"text\": \"<s>\"\n        },\n        {\n          \"id\": 4911,\n          \"logprob\": -6.9804688,\n          \"text\": \"User\"\n        },\n        {\n          \"id\": 29901,\n          \"logprob\": -0.006122589,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 32000,\n          \"logprob\": -0.8417969,\n          \"text\": \"<fake_token_around_image>\"\n        },\n        {\n          \"id\": 32001,\n          \"logprob\": -9.942055e-05,\n          \"text\": \"<image>\"\n        },\n        {\n          \"id\": 32000,\n          \"logprob\": -2.3841858e-07,\n          \"text\": \"<fake_token_around_image>\"\n        },\n        {\n          \"id\": 1815,\n          \"logprob\": -4.1679688,\n          \"text\": \"Can\"\n        },\n        {\n          \"id\": 366,\n          \"logprob\": -0.014091492,\n          \"text\": \"you\"\n        },\n        {\n          \"id\": 2649,\n          \"logprob\": -4.4726562,\n          \"text\": \"tell\"\n        },\n        {\n          \"id\": 592,\n          \"logprob\": -0.2998047,\n          \"text\": \"me\"\n        },\n        {\n          \"id\": 263,\n          \"logprob\": -4.15625,\n          \"text\": \"a\"\n        },\n        {\n          \"id\": 1407,\n          \"logprob\": -9.3828125,\n          \"text\": \"very\"\n        },\n        {\n          \"id\": 3273,\n          \"logprob\": -1.9716797,\n          \"text\": \"short\"\n        },\n        {\n          \"id\": 5828,\n          \"logprob\": -0.27734375,\n          \"text\": \"story\"\n        },\n        {\n          \"id\": 2729,\n          \"logprob\": -3.5605469,\n          \"text\": \"based\"\n        },\n        {\n          \"id\": 373,\n          \"logprob\": -0.0006451607,\n          \"text\": \"on\"\n        },\n        {\n          \"id\": 278,\n          
\"logprob\": -0.14160156,\n          \"text\": \"the\"\n        },\n        {\n          \"id\": 1967,\n          \"logprob\": -0.06915283,\n          \"text\": \"image\"\n        },\n        {\n          \"id\": 29973,\n          \"logprob\": -0.16381836,\n          \"text\": \"?\"\n        }\n      ],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 32002,\n          \"logprob\": -0.0026664734,\n          \"special\": true,\n          \"text\": \"<end_of_utterance>\"\n        },\n        {\n          \"id\": 29871,\n          \"logprob\": -8.571148e-05,\n          \"special\": false,\n          \"text\": \" \"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -1.8119812e-05,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 7900,\n          \"logprob\": -3.0994415e-06,\n          \"special\": false,\n          \"text\": \"Ass\"\n        },\n        {\n          \"id\": 22137,\n          \"logprob\": 0.0,\n          \"special\": false,\n          \"text\": \"istant\"\n        },\n        {\n          \"id\": 29901,\n          \"logprob\": -3.0994415e-06,\n          \"special\": false,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 319,\n          \"logprob\": -0.9301758,\n          \"special\": false,\n          \"text\": \" A\"\n        },\n        {\n          \"id\": 696,\n          \"logprob\": -1.1279297,\n          \"special\": false,\n          \"text\": \" ro\"\n        },\n        {\n          \"id\": 15664,\n          \"logprob\": -0.0002939701,\n          \"special\": false,\n          \"text\": \"oster\"\n        },\n        {\n          \"id\": 15028,\n          \"logprob\": -1.1865234,\n          \"special\": false,\n          \"text\": \" stands\"\n        }\n      ]\n    },\n    \"generated_text\": \" \\nAssistant: A rooster stands\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_idefics/test_idefics_two_images.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"eos_token\",\n    \"generated_tokens\": 13,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 450,\n        \"logprob\": -0.2602539,\n        \"special\": false,\n        \"text\": \" The\"\n      },\n      {\n        \"id\": 21282,\n        \"logprob\": -0.018463135,\n        \"special\": false,\n        \"text\": \" cow\"\n      },\n      {\n        \"id\": 322,\n        \"logprob\": -0.1829834,\n        \"special\": false,\n        \"text\": \" and\"\n      },\n      {\n        \"id\": 521,\n        \"logprob\": -0.62109375,\n        \"special\": false,\n        \"text\": \" ch\"\n      },\n      {\n        \"id\": 21475,\n        \"logprob\": -0.0037403107,\n        \"special\": false,\n        \"text\": \"icken\"\n      },\n      {\n        \"id\": 526,\n        \"logprob\": -0.018920898,\n        \"special\": false,\n        \"text\": \" are\"\n      },\n      {\n        \"id\": 13407,\n        \"logprob\": -1.0732422,\n        \"special\": false,\n        \"text\": \" standing\"\n      },\n      {\n        \"id\": 373,\n        \"logprob\": -0.5292969,\n        \"special\": false,\n        \"text\": \" on\"\n      },\n      {\n        \"id\": 263,\n        \"logprob\": -0.47070312,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 25695,\n        \"logprob\": -0.25708008,\n        \"special\": false,\n        \"text\": \" beach\"\n      },\n      {\n        \"id\": 29889,\n        \"logprob\": -0.17578125,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 32002,\n        \"logprob\": -0.0023422241,\n        \"special\": true,\n        \"text\": \"<end_of_utterance>\"\n      },\n      {\n        \"id\": 2,\n        \"logprob\": -0.00030851364,\n        \"special\": true,\n        \"text\": \"</s>\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n 
 \"generated_text\": \" The cow and chicken are standing on a beach.\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_all_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 288,\n        \"logprob\": -0.2854004,\n        \"special\": false,\n        \"text\": \"ing\"\n      },\n      {\n        \"id\": 264,\n        \"logprob\": -0.38061523,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 633,\n        \"logprob\": -0.09301758,\n        \"special\": false,\n        \"text\": \" new\"\n      },\n      {\n        \"id\": 4480,\n        \"logprob\": -0.26782227,\n        \"special\": false,\n        \"text\": \" feature\"\n      },\n      {\n        \"id\": 297,\n        \"logprob\": -0.8510742,\n        \"special\": false,\n        \"text\": \" in\"\n      },\n      {\n        \"id\": 272,\n        \"logprob\": -0.13464355,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 2039,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" game\"\n      },\n      {\n        \"id\": 28723,\n        \"logprob\": -0.89990234,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.10632324,\n        \"special\": false,\n        \"text\": \"\\n\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"Test requesting a new feature in the game.\\n\\n\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 330,\n          \"logprob\": -0.09289551,\n          \"special\": false,\n          \"text\": \" A\"\n        },\n        {\n          \"id\": 13088,\n          \"logprob\": -0.6743164,\n          \"special\": false,\n          \"text\": \" chicken\"\n        },\n        {\n          \"id\": 349,\n          \"logprob\": -0.31396484,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 6398,\n          \"logprob\": -0.051727295,\n          \"special\": false,\n          \"text\": \" sitting\"\n        },\n        {\n          \"id\": 356,\n          \"logprob\": -0.34448242,\n          \"special\": false,\n          \"text\": \" on\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.1194458,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 17972,\n          \"logprob\": -0.03237915,\n          \"special\": false,\n          \"text\": \" pile\"\n        },\n        {\n          \"id\": 302,\n          \"logprob\": -0.00018751621,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 2445,\n          \"logprob\": -0.07043457,\n          \"special\": false,\n          \"text\": \" money\"\n        },\n        {\n          \"id\": 28723,\n          \"logprob\": -0.00422287,\n          \"special\": false,\n          \"text\": \".\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" A chicken is sitting on a pile of money.\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": 
[\n        {\n          \"id\": 330,\n          \"logprob\": -0.09448242,\n          \"special\": false,\n          \"text\": \" A\"\n        },\n        {\n          \"id\": 13088,\n          \"logprob\": -0.6743164,\n          \"special\": false,\n          \"text\": \" chicken\"\n        },\n        {\n          \"id\": 349,\n          \"logprob\": -0.31201172,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 6398,\n          \"logprob\": -0.051635742,\n          \"special\": false,\n          \"text\": \" sitting\"\n        },\n        {\n          \"id\": 356,\n          \"logprob\": -0.34033203,\n          \"special\": false,\n          \"text\": \" on\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.1194458,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 17972,\n          \"logprob\": -0.032562256,\n          \"special\": false,\n          \"text\": \" pile\"\n        },\n        {\n          \"id\": 302,\n          \"logprob\": -0.00018763542,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 2445,\n          \"logprob\": -0.07122803,\n          \"special\": false,\n          \"text\": \" money\"\n        },\n        {\n          \"id\": 28723,\n          \"logprob\": -0.0041007996,\n          \"special\": false,\n          \"text\": \".\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" A chicken is sitting on a pile of money.\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 330,\n          \"logprob\": -0.09448242,\n          \"special\": false,\n          \"text\": \" A\"\n        },\n        {\n          \"id\": 13088,\n          \"logprob\": 
-0.6743164,\n          \"special\": false,\n          \"text\": \" chicken\"\n        },\n        {\n          \"id\": 349,\n          \"logprob\": -0.31201172,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 6398,\n          \"logprob\": -0.051635742,\n          \"special\": false,\n          \"text\": \" sitting\"\n        },\n        {\n          \"id\": 356,\n          \"logprob\": -0.34033203,\n          \"special\": false,\n          \"text\": \" on\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.1194458,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 17972,\n          \"logprob\": -0.032562256,\n          \"special\": false,\n          \"text\": \" pile\"\n        },\n        {\n          \"id\": 302,\n          \"logprob\": -0.00018787384,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 2445,\n          \"logprob\": -0.07122803,\n          \"special\": false,\n          \"text\": \" money\"\n        },\n        {\n          \"id\": 28723,\n          \"logprob\": -0.0041007996,\n          \"special\": false,\n          \"text\": \".\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" A chicken is sitting on a pile of money.\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 330,\n          \"logprob\": -0.09448242,\n          \"special\": false,\n          \"text\": \" A\"\n        },\n        {\n          \"id\": 13088,\n          \"logprob\": -0.6743164,\n          \"special\": false,\n          \"text\": \" chicken\"\n        },\n        {\n          \"id\": 349,\n          \"logprob\": -0.31201172,\n          \"special\": false,\n          \"text\": 
\" is\"\n        },\n        {\n          \"id\": 6398,\n          \"logprob\": -0.051635742,\n          \"special\": false,\n          \"text\": \" sitting\"\n        },\n        {\n          \"id\": 356,\n          \"logprob\": -0.34033203,\n          \"special\": false,\n          \"text\": \" on\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.1194458,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 17972,\n          \"logprob\": -0.032562256,\n          \"special\": false,\n          \"text\": \" pile\"\n        },\n        {\n          \"id\": 302,\n          \"logprob\": -0.00018763542,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 2445,\n          \"logprob\": -0.07122803,\n          \"special\": false,\n          \"text\": \" money\"\n        },\n        {\n          \"id\": 28723,\n          \"logprob\": -0.0041007996,\n          \"special\": false,\n          \"text\": \".\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \" A chicken is sitting on a pile of money.\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_simple.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 330,\n        \"logprob\": -0.08660889,\n        \"special\": false,\n        \"text\": \" A\"\n      },\n      {\n        \"id\": 13088,\n        \"logprob\": -0.7089844,\n        \"special\": false,\n        \"text\": \" chicken\"\n      },\n      {\n        \"id\": 349,\n        \"logprob\": -0.32885742,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 6398,\n        \"logprob\": -0.05126953,\n        \"special\": false,\n        \"text\": \" sitting\"\n      },\n      {\n        \"id\": 356,\n        \"logprob\": -0.35229492,\n        \"special\": false,\n        \"text\": \" on\"\n      },\n      {\n        \"id\": 264,\n        \"logprob\": -0.12561035,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 17972,\n        \"logprob\": -0.038085938,\n        \"special\": false,\n        \"text\": \" pile\"\n      },\n      {\n        \"id\": 302,\n        \"logprob\": -0.00018656254,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 2445,\n        \"logprob\": -0.07293701,\n        \"special\": false,\n        \"text\": \" money\"\n      },\n      {\n        \"id\": 28723,\n        \"logprob\": -0.004852295,\n        \"special\": false,\n        \"text\": \".\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \" A chicken is sitting on a pile of money.\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_two_images.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"eos_token\",\n    \"generated_tokens\": 19,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 415,\n        \"logprob\": -0.03665161,\n        \"special\": false,\n        \"text\": \" The\"\n      },\n      {\n        \"id\": 12072,\n        \"logprob\": -0.13549805,\n        \"special\": false,\n        \"text\": \" cow\"\n      },\n      {\n        \"id\": 349,\n        \"logprob\": -0.05819702,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 6328,\n        \"logprob\": -0.6826172,\n        \"special\": false,\n        \"text\": \" standing\"\n      },\n      {\n        \"id\": 356,\n        \"logprob\": -0.1607666,\n        \"special\": false,\n        \"text\": \" on\"\n      },\n      {\n        \"id\": 272,\n        \"logprob\": -0.5073242,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 10305,\n        \"logprob\": -0.016418457,\n        \"special\": false,\n        \"text\": \" beach\"\n      },\n      {\n        \"id\": 304,\n        \"logprob\": -1.3916016,\n        \"special\": false,\n        \"text\": \" and\"\n      },\n      {\n        \"id\": 272,\n        \"logprob\": -0.020217896,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 13088,\n        \"logprob\": -0.0028133392,\n        \"special\": false,\n        \"text\": \" chicken\"\n      },\n      {\n        \"id\": 349,\n        \"logprob\": -0.003145218,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 6398,\n        \"logprob\": -0.37060547,\n        \"special\": false,\n        \"text\": \" sitting\"\n      },\n      {\n        \"id\": 356,\n        \"logprob\": -0.034851074,\n        \"special\": false,\n        \"text\": \" on\"\n      },\n      {\n        \"id\": 264,\n        
\"logprob\": -0.2878418,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 17972,\n        \"logprob\": -0.046051025,\n        \"special\": false,\n        \"text\": \" pile\"\n      },\n      {\n        \"id\": 302,\n        \"logprob\": -0.00028848648,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 2445,\n        \"logprob\": -0.025772095,\n        \"special\": false,\n        \"text\": \" money\"\n      },\n      {\n        \"id\": 28723,\n        \"logprob\": -0.018127441,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 32002,\n        \"logprob\": -0.00019824505,\n        \"special\": true,\n        \"text\": \"<end_of_utterance>\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \" The cow is standing on the beach and the chicken is sitting on a pile of money.\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_idefics3/test_flash_idefics3_next_simple_url.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"eos_token\",\n    \"generated_tokens\": 9,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 2684,\n        \"logprob\": -0.24902344,\n        \"special\": false,\n        \"text\": \" There\"\n      },\n      {\n        \"id\": 374,\n        \"logprob\": -0.0703125,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 264,\n        \"logprob\": -0.23535156,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 35372,\n        \"logprob\": -0.125,\n        \"special\": false,\n        \"text\": \" statue\"\n      },\n      {\n        \"id\": 304,\n        \"logprob\": -0.30273438,\n        \"special\": false,\n        \"text\": \" in\"\n      },\n      {\n        \"id\": 279,\n        \"logprob\": -0.20507812,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 2217,\n        \"logprob\": -0.076171875,\n        \"special\": false,\n        \"text\": \" image\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.053710938,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 128258,\n        \"logprob\": -0.011352539,\n        \"special\": true,\n        \"text\": \"<end_of_utterance>\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \" There is a statue in the image.\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_basic.json",
    "content": "{\n  \"choices\": [\n    {\n      \"finish_reason\": \"stop\",\n      \"index\": 0,\n      \"logprobs\": null,\n      \"message\": {\n        \"content\": \"{\\\"firstName\\\":\\\"David\\\",\\\"lastName\\\":\\\"(Not provided)\\\",\\\"hobby\\\":\\\", nature\\\",\\\"numCats\\\":2}\",\n        \"role\": \"assistant\"\n      }\n    }\n  ],\n  \"created\": 1746053368,\n  \"id\": \"\",\n  \"model\": \"google/gemma-3-4b-it\",\n  \"object\": \"chat.completion\",\n  \"system_fingerprint\": \"3.3.6-dev0-native\",\n  \"usage\": {\n    \"completion_tokens\": 35,\n    \"prompt_tokens\": 32,\n    \"total_tokens\": 67\n  }\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_complex.json",
    "content": "{\n  \"choices\": [\n    {\n      \"finish_reason\": \"stop\",\n      \"index\": 0,\n      \"logprobs\": null,\n      \"message\": {\n        \"content\": \"{\\\"name\\\":\\\"John Smith\\\",\\\"age\\\":30,\\\"address\\\":{\\\"street\\\":\\\"Maple Street\\\",\\\"city\\\":\\\"Boston\\\"},\\\"hobbies\\\":[\\\"botany\\\",\\\"astronomy\\\",\\\"solving mathematical puzzles\\\"]}\",\n        \"role\": \"assistant\"\n      }\n    }\n  ],\n  \"created\": 1746053373,\n  \"id\": \"\",\n  \"model\": \"google/gemma-3-4b-it\",\n  \"object\": \"chat.completion\",\n  \"system_fingerprint\": \"3.3.6-dev0-native\",\n  \"usage\": {\n    \"completion_tokens\": 44,\n    \"prompt_tokens\": 37,\n    \"total_tokens\": 81\n  }\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_json_schema_constrain/test_json_schema_stream.json",
    "content": "[\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"{\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975615,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"\\\"\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975615,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"f\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975615,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"irs\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975615,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"t\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        
\"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975615,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"Name\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975615,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"\\\":\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975615,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"\\\"\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975615,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"David\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975615,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": 
\"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"\\\",\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975615,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"\\\"\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975615,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"l\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975615,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"ast\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975615,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": 
\"Name\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975615,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"\\\":\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975615,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"\\\"\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975615,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"Unknown\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975615,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"\\\",\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 
1741975616,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"\\\"\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975616,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"h\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975616,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"obb\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975616,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"y\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975616,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  
},\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"\\\":\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975616,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"\\\",\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975616,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" \\\\\\\"\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975616,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"riding\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975616,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" bicycles\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": 
null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975616,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"\\\\\\\",\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975616,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" \\\\\\\"\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975616,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"having\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975616,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" cats\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975616,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    
\"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"\\\\\\\"\\\",\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975616,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"\\\"\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975616,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"num\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975616,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"Cats\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975616,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n     
     \"content\": \"\\\":\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975616,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"2\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975616,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"}\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975616,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"\",\n          \"role\": \"assistant\"\n        },\n        \"finish_reason\": \"stop\",\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741975616,\n    \"id\": \"\",\n    \"model\": \"google/gemma-3-4b-it\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.2.1-dev0-native\",\n    \"usage\": null\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_llava_next/test_flash_llava_next_all_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"stop_sequence\",\n    \"generated_tokens\": 6,\n    \"prefill\": [],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 13,\n        \"logprob\": -1.0654297,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 1014,\n        \"logprob\": -2.7460938,\n        \"special\": false,\n        \"text\": \"The\"\n      },\n      {\n        \"id\": 6032,\n        \"logprob\": -1.359375,\n        \"special\": false,\n        \"text\": \" purpose\"\n      },\n      {\n        \"id\": 302,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 456,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" this\"\n      },\n      {\n        \"id\": 1369,\n        \"logprob\": -0.40063477,\n        \"special\": false,\n        \"text\": \" test\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"Test request\\nThe purpose of this test\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_llava_next/test_flash_llava_next_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 13,\n          \"logprob\": -0.007621765,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -0.20812988,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 16114,\n          \"logprob\": -1.2587891,\n          \"special\": false,\n          \"text\": \"Once\"\n        },\n        {\n          \"id\": 3714,\n          \"logprob\": -0.20825195,\n          \"special\": false,\n          \"text\": \" upon\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.0017709732,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 727,\n          \"logprob\": -0.011932373,\n          \"special\": false,\n          \"text\": \" time\"\n        },\n        {\n          \"id\": 28725,\n          \"logprob\": -0.17297363,\n          \"special\": false,\n          \"text\": \",\"\n        },\n        {\n          \"id\": 736,\n          \"logprob\": -0.9057617,\n          \"special\": false,\n          \"text\": \" there\"\n        },\n        {\n          \"id\": 403,\n          \"logprob\": -0.05758667,\n          \"special\": false,\n          \"text\": \" was\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.00970459,\n          \"special\": false,\n          \"text\": \" a\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n\\nOnce upon a time, there was a\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n    
      \"id\": 13,\n          \"logprob\": -0.007621765,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -0.20275879,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 16114,\n          \"logprob\": -1.2578125,\n          \"special\": false,\n          \"text\": \"Once\"\n        },\n        {\n          \"id\": 3714,\n          \"logprob\": -0.2084961,\n          \"special\": false,\n          \"text\": \" upon\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.0017738342,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 727,\n          \"logprob\": -0.011932373,\n          \"special\": false,\n          \"text\": \" time\"\n        },\n        {\n          \"id\": 28725,\n          \"logprob\": -0.17297363,\n          \"special\": false,\n          \"text\": \",\"\n        },\n        {\n          \"id\": 736,\n          \"logprob\": -0.9057617,\n          \"special\": false,\n          \"text\": \" there\"\n        },\n        {\n          \"id\": 403,\n          \"logprob\": -0.05758667,\n          \"special\": false,\n          \"text\": \" was\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.00970459,\n          \"special\": false,\n          \"text\": \" a\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n\\nOnce upon a time, there was a\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 13,\n          \"logprob\": -0.007621765,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -0.20275879,\n          \"special\": false,\n    
      \"text\": \"\\n\"\n        },\n        {\n          \"id\": 16114,\n          \"logprob\": -1.2578125,\n          \"special\": false,\n          \"text\": \"Once\"\n        },\n        {\n          \"id\": 3714,\n          \"logprob\": -0.2084961,\n          \"special\": false,\n          \"text\": \" upon\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.0017738342,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 727,\n          \"logprob\": -0.011932373,\n          \"special\": false,\n          \"text\": \" time\"\n        },\n        {\n          \"id\": 28725,\n          \"logprob\": -0.17297363,\n          \"special\": false,\n          \"text\": \",\"\n        },\n        {\n          \"id\": 736,\n          \"logprob\": -0.9057617,\n          \"special\": false,\n          \"text\": \" there\"\n        },\n        {\n          \"id\": 403,\n          \"logprob\": -0.05758667,\n          \"special\": false,\n          \"text\": \" was\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.00970459,\n          \"special\": false,\n          \"text\": \" a\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n\\nOnce upon a time, there was a\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 13,\n          \"logprob\": -0.007621765,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -0.20812988,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 16114,\n          \"logprob\": -1.2587891,\n          \"special\": false,\n          \"text\": \"Once\"\n        },\n        {\n          \"id\": 3714,\n          
\"logprob\": -0.20825195,\n          \"special\": false,\n          \"text\": \" upon\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.0017709732,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 727,\n          \"logprob\": -0.011932373,\n          \"special\": false,\n          \"text\": \" time\"\n        },\n        {\n          \"id\": 28725,\n          \"logprob\": -0.17297363,\n          \"special\": false,\n          \"text\": \",\"\n        },\n        {\n          \"id\": 736,\n          \"logprob\": -0.9057617,\n          \"special\": false,\n          \"text\": \" there\"\n        },\n        {\n          \"id\": 403,\n          \"logprob\": -0.05758667,\n          \"special\": false,\n          \"text\": \" was\"\n        },\n        {\n          \"id\": 264,\n          \"logprob\": -0.00970459,\n          \"special\": false,\n          \"text\": \" a\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n\\nOnce upon a time, there was a\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_llava_next/test_flash_llava_next_simple.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 13,\n        \"logprob\": -0.00756073,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.20117188,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 16114,\n        \"logprob\": -1.2597656,\n        \"special\": false,\n        \"text\": \"Once\"\n      },\n      {\n        \"id\": 3714,\n        \"logprob\": -0.20825195,\n        \"special\": false,\n        \"text\": \" upon\"\n      },\n      {\n        \"id\": 264,\n        \"logprob\": -0.00178051,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 727,\n        \"logprob\": -0.011955261,\n        \"special\": false,\n        \"text\": \" time\"\n      },\n      {\n        \"id\": 28725,\n        \"logprob\": -0.17541504,\n        \"special\": false,\n        \"text\": \",\"\n      },\n      {\n        \"id\": 736,\n        \"logprob\": -0.91308594,\n        \"special\": false,\n        \"text\": \" there\"\n      },\n      {\n        \"id\": 403,\n        \"logprob\": -0.058410645,\n        \"special\": false,\n        \"text\": \" was\"\n      },\n      {\n        \"id\": 264,\n        \"logprob\": -0.009689331,\n        \"special\": false,\n        \"text\": \" a\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"\\n\\nOnce upon a time, there was a\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_with_customer_support_adapter.json",
    "content": "{\n  \"details\": {\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 40,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 13,\n        \"logprob\": -0.27416992,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.17016602,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 28737,\n        \"logprob\": -2.7109375,\n        \"special\": false,\n        \"text\": \"I\"\n      },\n      {\n        \"id\": 28809,\n        \"logprob\": -1.5,\n        \"special\": false,\n        \"text\": \"’\"\n      },\n      {\n        \"id\": 28719,\n        \"logprob\": -0.34204102,\n        \"special\": false,\n        \"text\": \"m\"\n      },\n      {\n        \"id\": 459,\n        \"logprob\": -1.6914062,\n        \"special\": false,\n        \"text\": \" not\"\n      },\n      {\n        \"id\": 1864,\n        \"logprob\": -0.69140625,\n        \"special\": false,\n        \"text\": \" sure\"\n      },\n      {\n        \"id\": 513,\n        \"logprob\": -1.6171875,\n        \"special\": false,\n        \"text\": \" if\"\n      },\n      {\n        \"id\": 315,\n        \"logprob\": -1.3837891,\n        \"special\": false,\n        \"text\": \" I\"\n      },\n      {\n        \"id\": 541,\n        \"logprob\": -1.2226562,\n        \"special\": false,\n        \"text\": \" can\"\n      },\n      {\n        \"id\": 1567,\n        \"logprob\": -1.8652344,\n        \"special\": false,\n        \"text\": \" come\"\n      },\n      {\n        \"id\": 582,\n        \"logprob\": -0.0070228577,\n        \"special\": false,\n        \"text\": \" up\"\n      },\n      {\n        \"id\": 395,\n        \"logprob\": -0.0054092407,\n        \"special\": false,\n        \"text\": \" with\"\n      },\n      {\n        \"id\": 28705,\n        \"logprob\": -0.62597656,\n        \"special\": false,\n        \"text\": \" 
\"\n      },\n      {\n        \"id\": 28770,\n        \"logprob\": -0.0035572052,\n        \"special\": false,\n        \"text\": \"3\"\n      },\n      {\n        \"id\": 4842,\n        \"logprob\": -0.93603516,\n        \"special\": false,\n        \"text\": \" unique\"\n      },\n      {\n        \"id\": 3085,\n        \"logprob\": -0.028411865,\n        \"special\": false,\n        \"text\": \" words\"\n      },\n      {\n        \"id\": 369,\n        \"logprob\": -1.0400391,\n        \"special\": false,\n        \"text\": \" that\"\n      },\n      {\n        \"id\": 6685,\n        \"logprob\": -0.09710693,\n        \"special\": false,\n        \"text\": \" describe\"\n      },\n      {\n        \"id\": 528,\n        \"logprob\": -0.066467285,\n        \"special\": false,\n        \"text\": \" me\"\n      },\n      {\n        \"id\": 28725,\n        \"logprob\": -1.0722656,\n        \"special\": false,\n        \"text\": \",\"\n      },\n      {\n        \"id\": 562,\n        \"logprob\": -0.33422852,\n        \"special\": false,\n        \"text\": \" but\"\n      },\n      {\n        \"id\": 315,\n        \"logprob\": -0.5136719,\n        \"special\": false,\n        \"text\": \" I\"\n      },\n      {\n        \"id\": 28809,\n        \"logprob\": -0.8989258,\n        \"special\": false,\n        \"text\": \"’\"\n      },\n      {\n        \"id\": 584,\n        \"logprob\": -0.2076416,\n        \"special\": false,\n        \"text\": \"ll\"\n      },\n      {\n        \"id\": 1464,\n        \"logprob\": -0.8808594,\n        \"special\": false,\n        \"text\": \" try\"\n      },\n      {\n        \"id\": 28723,\n        \"logprob\": -0.88427734,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.91064453,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.08105469,\n        \"special\": false,\n        \"text\": \"\\n\"\n 
     },\n      {\n        \"id\": 28740,\n        \"logprob\": -1.8486328,\n        \"special\": false,\n        \"text\": \"1\"\n      },\n      {\n        \"id\": 28723,\n        \"logprob\": -0.111572266,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 23626,\n        \"logprob\": -3.15625,\n        \"special\": false,\n        \"text\": \" Creative\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.9194336,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 28750,\n        \"logprob\": -0.24841309,\n        \"special\": false,\n        \"text\": \"2\"\n      },\n      {\n        \"id\": 28723,\n        \"logprob\": -9.393692e-05,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 6785,\n        \"logprob\": -3.1386719,\n        \"special\": false,\n        \"text\": \" Fun\"\n      },\n      {\n        \"id\": 1780,\n        \"logprob\": -0.53564453,\n        \"special\": false,\n        \"text\": \"ny\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.09033203,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 28770,\n        \"logprob\": -0.00466156,\n        \"special\": false,\n        \"text\": \"3\"\n      },\n      {\n        \"id\": 28723,\n        \"logprob\": -0.00016450882,\n        \"special\": false,\n        \"text\": \".\"\n      }\n    ]\n  },\n  \"generated_text\": \"\\n\\nI’m not sure if I can come up with 3 unique words that describe me, but I’ll try.\\n\\n1. Creative\\n2. Funny\\n3.\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_with_dbpedia_adapter.json",
    "content": "{\n  \"details\": {\n    \"finish_reason\": \"eos_token\",\n    \"generated_tokens\": 7,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 1,\n        \"logprob\": -0.49658203,\n        \"special\": true,\n        \"text\": \"<s>\"\n      },\n      {\n        \"id\": 28705,\n        \"logprob\": -0.0016384125,\n        \"special\": false,\n        \"text\": \" \"\n      },\n      {\n        \"id\": 1,\n        \"logprob\": -1.4931641,\n        \"special\": true,\n        \"text\": \"<s>\"\n      },\n      {\n        \"id\": 28705,\n        \"logprob\": -0.00075769424,\n        \"special\": false,\n        \"text\": \" \"\n      },\n      {\n        \"id\": 28740,\n        \"logprob\": -0.25024414,\n        \"special\": false,\n        \"text\": \"1\"\n      },\n      {\n        \"id\": 28740,\n        \"logprob\": -0.2631836,\n        \"special\": false,\n        \"text\": \"1\"\n      },\n      {\n        \"id\": 2,\n        \"logprob\": -0.0003285408,\n        \"special\": true,\n        \"text\": \"</s>\"\n      }\n    ]\n  },\n  \"generated_text\": \"  11\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_without_adapter.json",
    "content": "{\n  \"details\": {\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 40,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 13,\n        \"logprob\": -1.0488281,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -1.0800781,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 27332,\n        \"logprob\": -2.1152344,\n        \"special\": false,\n        \"text\": \"###\"\n      },\n      {\n        \"id\": 28705,\n        \"logprob\": -1.6748047,\n        \"special\": false,\n        \"text\": \" \"\n      },\n      {\n        \"id\": 28740,\n        \"logprob\": -0.097229004,\n        \"special\": false,\n        \"text\": \"1\"\n      },\n      {\n        \"id\": 28723,\n        \"logprob\": -0.16467285,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 7615,\n        \"logprob\": -2.2246094,\n        \"special\": false,\n        \"text\": \" News\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -1.0488281,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 27332,\n        \"logprob\": -0.69189453,\n        \"special\": false,\n        \"text\": \"###\"\n      },\n      {\n        \"id\": 28705,\n        \"logprob\": -0.013343811,\n        \"special\": false,\n        \"text\": \" \"\n      },\n      {\n        \"id\": 28750,\n        \"logprob\": -0.011230469,\n        \"special\": false,\n        \"text\": \"2\"\n      },\n      {\n        \"id\": 28723,\n        \"logprob\": -0.00096845627,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 21095,\n        \"logprob\": -2.5605469,\n        \"special\": false,\n        \"text\": \" Blog\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.19458008,\n        \"special\": false,\n        
\"text\": \"\\n\"\n      },\n      {\n        \"id\": 27332,\n        \"logprob\": -0.031280518,\n        \"special\": false,\n        \"text\": \"###\"\n      },\n      {\n        \"id\": 28705,\n        \"logprob\": -0.0030708313,\n        \"special\": false,\n        \"text\": \" \"\n      },\n      {\n        \"id\": 28770,\n        \"logprob\": -0.0029277802,\n        \"special\": false,\n        \"text\": \"3\"\n      },\n      {\n        \"id\": 28723,\n        \"logprob\": -0.0012350082,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 20108,\n        \"logprob\": -2.1582031,\n        \"special\": false,\n        \"text\": \" Article\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.05810547,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 27332,\n        \"logprob\": -0.35083008,\n        \"special\": false,\n        \"text\": \"###\"\n      },\n      {\n        \"id\": 28705,\n        \"logprob\": -0.034332275,\n        \"special\": false,\n        \"text\": \" \"\n      },\n      {\n        \"id\": 28781,\n        \"logprob\": -0.009666443,\n        \"special\": false,\n        \"text\": \"4\"\n      },\n      {\n        \"id\": 28723,\n        \"logprob\": -0.0013113022,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 8349,\n        \"logprob\": -2.6191406,\n        \"special\": false,\n        \"text\": \" Review\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.04031372,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 27332,\n        \"logprob\": -0.45239258,\n        \"special\": false,\n        \"text\": \"###\"\n      },\n      {\n        \"id\": 28705,\n        \"logprob\": -0.045410156,\n        \"special\": false,\n        \"text\": \" \"\n      },\n      {\n        \"id\": 28782,\n        \"logprob\": -0.0041236877,\n        \"special\": 
false,\n        \"text\": \"5\"\n      },\n      {\n        \"id\": 28723,\n        \"logprob\": -0.0010223389,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 5299,\n        \"logprob\": -2.8066406,\n        \"special\": false,\n        \"text\": \" Other\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.12054443,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.44580078,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -1.4921875,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -1.3574219,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -1.0039062,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.5859375,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.43481445,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.2783203,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.20410156,\n        \"special\": false,\n        \"text\": \"\\n\"\n      }\n    ]\n  },\n  \"generated_text\": \"\\n\\n### 1. News\\n### 2. Blog\\n### 3. Article\\n### 4. Review\\n### 5. Other\\n\\n\\n\\n\\n\\n\\n\\n\\n\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_lora_mistral/test_lora_mistral_without_customer_support_adapter.json",
    "content": "{\n  \"details\": {\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 40,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 13,\n        \"logprob\": -0.31347656,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.27441406,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 28737,\n        \"logprob\": -2.2285156,\n        \"special\": false,\n        \"text\": \"I\"\n      },\n      {\n        \"id\": 28809,\n        \"logprob\": -1.4677734,\n        \"special\": false,\n        \"text\": \"’\"\n      },\n      {\n        \"id\": 28719,\n        \"logprob\": -0.31762695,\n        \"special\": false,\n        \"text\": \"m\"\n      },\n      {\n        \"id\": 264,\n        \"logprob\": -1.6865234,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 1215,\n        \"logprob\": -3.2695312,\n        \"special\": false,\n        \"text\": \" very\"\n      },\n      {\n        \"id\": 20640,\n        \"logprob\": -3.1230469,\n        \"special\": false,\n        \"text\": \" passionate\"\n      },\n      {\n        \"id\": 1338,\n        \"logprob\": -0.48339844,\n        \"special\": false,\n        \"text\": \" person\"\n      },\n      {\n        \"id\": 28723,\n        \"logprob\": -0.9970703,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 315,\n        \"logprob\": -0.5498047,\n        \"special\": false,\n        \"text\": \" I\"\n      },\n      {\n        \"id\": 28809,\n        \"logprob\": -1.1923828,\n        \"special\": false,\n        \"text\": \"’\"\n      },\n      {\n        \"id\": 28719,\n        \"logprob\": -0.080444336,\n        \"special\": false,\n        \"text\": \"m\"\n      },\n      {\n        \"id\": 1215,\n        \"logprob\": -1.8271484,\n        \"special\": false,\n        
\"text\": \" very\"\n      },\n      {\n        \"id\": 12215,\n        \"logprob\": -2.8847656,\n        \"special\": false,\n        \"text\": \" driven\"\n      },\n      {\n        \"id\": 28723,\n        \"logprob\": -1.0927734,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 315,\n        \"logprob\": -0.4584961,\n        \"special\": false,\n        \"text\": \" I\"\n      },\n      {\n        \"id\": 28809,\n        \"logprob\": -0.5019531,\n        \"special\": false,\n        \"text\": \"’\"\n      },\n      {\n        \"id\": 28719,\n        \"logprob\": -0.030715942,\n        \"special\": false,\n        \"text\": \"m\"\n      },\n      {\n        \"id\": 1215,\n        \"logprob\": -0.96972656,\n        \"special\": false,\n        \"text\": \" very\"\n      },\n      {\n        \"id\": 7798,\n        \"logprob\": -2.8847656,\n        \"special\": false,\n        \"text\": \" determined\"\n      },\n      {\n        \"id\": 28723,\n        \"logprob\": -0.27319336,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.56396484,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.011016846,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 3195,\n        \"logprob\": -0.7163086,\n        \"special\": false,\n        \"text\": \"What\"\n      },\n      {\n        \"id\": 349,\n        \"logprob\": -1.1611328,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 574,\n        \"logprob\": -0.515625,\n        \"special\": false,\n        \"text\": \" your\"\n      },\n      {\n        \"id\": 6656,\n        \"logprob\": -1.0253906,\n        \"special\": false,\n        \"text\": \" favorite\"\n      },\n      {\n        \"id\": 1970,\n        \"logprob\": -2.1738281,\n        \"special\": false,\n        
\"text\": \" thing\"\n      },\n      {\n        \"id\": 684,\n        \"logprob\": -0.48364258,\n        \"special\": false,\n        \"text\": \" about\"\n      },\n      {\n        \"id\": 1250,\n        \"logprob\": -1.8876953,\n        \"special\": false,\n        \"text\": \" being\"\n      },\n      {\n        \"id\": 264,\n        \"logprob\": -0.41967773,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 8626,\n        \"logprob\": -2.9160156,\n        \"special\": false,\n        \"text\": \" teacher\"\n      },\n      {\n        \"id\": 28804,\n        \"logprob\": -0.11920166,\n        \"special\": false,\n        \"text\": \"?\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.023727417,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.010848999,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 28737,\n        \"logprob\": -1.0566406,\n        \"special\": false,\n        \"text\": \"I\"\n      },\n      {\n        \"id\": 2016,\n        \"logprob\": -0.7163086,\n        \"special\": false,\n        \"text\": \" love\"\n      },\n      {\n        \"id\": 272,\n        \"logprob\": -1.9169922,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 1639,\n        \"logprob\": -2.03125,\n        \"special\": false,\n        \"text\": \" fact\"\n      }\n    ]\n  },\n  \"generated_text\": \"\\n\\nI’m a very passionate person. I’m very driven. I’m very determined.\\n\\nWhat is your favorite thing about being a teacher?\\n\\nI love the fact\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_mamba/test_mamba.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 187,\n        \"logprob\": -0.37890625,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 187,\n        \"logprob\": -0.35742188,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 30763,\n        \"logprob\": -1.1015625,\n        \"special\": false,\n        \"text\": \"Deep\"\n      },\n      {\n        \"id\": 4715,\n        \"logprob\": -0.5234375,\n        \"special\": false,\n        \"text\": \" learning\"\n      },\n      {\n        \"id\": 310,\n        \"logprob\": -0.55078125,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 247,\n        \"logprob\": -0.6640625,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 747,\n        \"logprob\": -2.0625,\n        \"special\": false,\n        \"text\": \" new\"\n      },\n      {\n        \"id\": 1511,\n        \"logprob\": -2.375,\n        \"special\": false,\n        \"text\": \" type\"\n      },\n      {\n        \"id\": 273,\n        \"logprob\": -0.0029144287,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 5145,\n        \"logprob\": -1.2734375,\n        \"special\": false,\n        \"text\": \" machine\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"\\n\\nDeep learning is a new type of machine\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [\n      {\n        \"id\": 2502,\n        \"logprob\": null,\n        \"text\": \" red\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -2.734375,\n        \"text\": \",\"\n      },\n      {\n        \"id\": 8862,\n        \"logprob\": -3.6875,\n        \"text\": \" yellow\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.40234375,\n        \"text\": \",\"\n      },\n      {\n        \"id\": 209,\n        \"logprob\": -8.25,\n        \"text\": \" \"\n      }\n    ],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 187,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 395,\n        \"logprob\": -0.3125,\n        \"special\": false,\n        \"text\": \"and\"\n      },\n      {\n        \"id\": 4797,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" blue\"\n      },\n      {\n        \"id\": 9830,\n        \"logprob\": -2.25,\n        \"special\": false,\n        \"text\": \" colors\"\n      },\n      {\n        \"id\": 15,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 329,\n        \"logprob\": -2.296875,\n        \"special\": false,\n        \"text\": \" A\"\n      },\n      {\n        \"id\": 1180,\n        \"logprob\": -2.046875,\n        \"special\": false,\n        \"text\": \" number\"\n      },\n      {\n        \"id\": 273,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 253,\n        \"logprob\": -0.86328125,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 3295,\n        \"logprob\": -0.55078125,\n        \"special\": false,\n        \"text\": \" color\"\n      }\n    ],\n    
\"top_tokens\": null\n  },\n  \"generated_text\": \"blue, red, yellow, \\nand blue colors. A number of the color\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_mamba/test_mamba_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [\n        {\n          \"id\": 1276,\n          \"logprob\": null,\n          \"text\": \"What\"\n        },\n        {\n          \"id\": 310,\n          \"logprob\": -0.83984375,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 18147,\n          \"logprob\": -12.8125,\n          \"text\": \" Deep\"\n        },\n        {\n          \"id\": 20727,\n          \"logprob\": -2.84375,\n          \"text\": \" Learning\"\n        },\n        {\n          \"id\": 32,\n          \"logprob\": -1.25,\n          \"text\": \"?\"\n        }\n      ],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 187,\n          \"logprob\": -0.37890625,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 187,\n          \"logprob\": -0.4296875,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 30763,\n          \"logprob\": -1.078125,\n          \"special\": false,\n          \"text\": \"Deep\"\n        },\n        {\n          \"id\": 4715,\n          \"logprob\": -0.515625,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 310,\n          \"logprob\": -0.6015625,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 247,\n          \"logprob\": -0.65625,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 747,\n          \"logprob\": -2.109375,\n          \"special\": false,\n          \"text\": \" new\"\n        },\n        {\n          \"id\": 1511,\n          \"logprob\": -2.328125,\n          \"special\": false,\n          \"text\": \" type\"\n        },\n        {\n          \"id\": 273,\n          \"logprob\": 
-0.0032653809,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 5145,\n          \"logprob\": -1.28125,\n          \"special\": false,\n          \"text\": \" machine\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n\\nDeep learning is a new type of machine\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [\n        {\n          \"id\": 1276,\n          \"logprob\": null,\n          \"text\": \"What\"\n        },\n        {\n          \"id\": 310,\n          \"logprob\": -0.80078125,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 18147,\n          \"logprob\": -13.25,\n          \"text\": \" Deep\"\n        },\n        {\n          \"id\": 20727,\n          \"logprob\": -2.828125,\n          \"text\": \" Learning\"\n        },\n        {\n          \"id\": 32,\n          \"logprob\": -1.1953125,\n          \"text\": \"?\"\n        }\n      ],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 187,\n          \"logprob\": -0.296875,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 187,\n          \"logprob\": -0.3359375,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 30763,\n          \"logprob\": -1.2578125,\n          \"special\": false,\n          \"text\": \"Deep\"\n        },\n        {\n          \"id\": 4715,\n          \"logprob\": -0.5546875,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 310,\n          \"logprob\": -0.62890625,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 247,\n          \"logprob\": -0.64453125,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n    
    {\n          \"id\": 747,\n          \"logprob\": -2.078125,\n          \"special\": false,\n          \"text\": \" new\"\n        },\n        {\n          \"id\": 1511,\n          \"logprob\": -2.28125,\n          \"special\": false,\n          \"text\": \" type\"\n        },\n        {\n          \"id\": 273,\n          \"logprob\": -0.0030670166,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 5145,\n          \"logprob\": -1.3125,\n          \"special\": false,\n          \"text\": \" machine\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n\\nDeep learning is a new type of machine\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [\n        {\n          \"id\": 1276,\n          \"logprob\": null,\n          \"text\": \"What\"\n        },\n        {\n          \"id\": 310,\n          \"logprob\": -0.80078125,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 18147,\n          \"logprob\": -13.25,\n          \"text\": \" Deep\"\n        },\n        {\n          \"id\": 20727,\n          \"logprob\": -2.828125,\n          \"text\": \" Learning\"\n        },\n        {\n          \"id\": 32,\n          \"logprob\": -1.1953125,\n          \"text\": \"?\"\n        }\n      ],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 187,\n          \"logprob\": -0.296875,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 187,\n          \"logprob\": -0.3359375,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 30763,\n          \"logprob\": -1.2578125,\n          \"special\": false,\n          \"text\": \"Deep\"\n        },\n        {\n          \"id\": 4715,\n          \"logprob\": -0.5546875,\n          \"special\": 
false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 310,\n          \"logprob\": -0.62890625,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 247,\n          \"logprob\": -0.64453125,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 747,\n          \"logprob\": -2.078125,\n          \"special\": false,\n          \"text\": \" new\"\n        },\n        {\n          \"id\": 1511,\n          \"logprob\": -2.28125,\n          \"special\": false,\n          \"text\": \" type\"\n        },\n        {\n          \"id\": 273,\n          \"logprob\": -0.0030670166,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 5145,\n          \"logprob\": -1.3125,\n          \"special\": false,\n          \"text\": \" machine\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n\\nDeep learning is a new type of machine\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [\n        {\n          \"id\": 1276,\n          \"logprob\": null,\n          \"text\": \"What\"\n        },\n        {\n          \"id\": 310,\n          \"logprob\": -0.80078125,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 18147,\n          \"logprob\": -13.25,\n          \"text\": \" Deep\"\n        },\n        {\n          \"id\": 20727,\n          \"logprob\": -2.828125,\n          \"text\": \" Learning\"\n        },\n        {\n          \"id\": 32,\n          \"logprob\": -1.1953125,\n          \"text\": \"?\"\n        }\n      ],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 187,\n          \"logprob\": -0.296875,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 187,\n         
 \"logprob\": -0.3359375,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 30763,\n          \"logprob\": -1.2578125,\n          \"special\": false,\n          \"text\": \"Deep\"\n        },\n        {\n          \"id\": 4715,\n          \"logprob\": -0.5546875,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 310,\n          \"logprob\": -0.62890625,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 247,\n          \"logprob\": -0.64453125,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 747,\n          \"logprob\": -2.078125,\n          \"special\": false,\n          \"text\": \" new\"\n        },\n        {\n          \"id\": 1511,\n          \"logprob\": -2.28125,\n          \"special\": false,\n          \"text\": \" type\"\n        },\n        {\n          \"id\": 273,\n          \"logprob\": -0.0030670166,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 5145,\n          \"logprob\": -1.3125,\n          \"special\": false,\n          \"text\": \" machine\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\n\\nDeep learning is a new type of machine\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_mllama/test_mllama_load.json",
    "content": "[\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"A chicken sits on a pile of money, looking\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1747230173,\n    \"id\": \"\",\n    \"model\": \"unsloth/Llama-3.2-11B-Vision-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"3.3.6-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 45,\n      \"total_tokens\": 55\n    }\n  },\n  {\n    \"choices\": [\n      {\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null,\n        \"message\": {\n          \"content\": \"A chicken sits on a pile of money, looking\",\n          \"name\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"usage\": null\n      }\n    ],\n    \"created\": 1747230173,\n    \"id\": \"\",\n    \"model\": \"unsloth/Llama-3.2-11B-Vision-Instruct\",\n    \"object\": \"chat.completion\",\n    \"system_fingerprint\": \"3.3.6-dev0-native\",\n    \"usage\": {\n      \"completion_tokens\": 10,\n      \"prompt_tokens\": 45,\n      \"total_tokens\": 55\n    }\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_mllama/test_mllama_simpl.json",
    "content": "{\n  \"choices\": [\n    {\n      \"finish_reason\": \"length\",\n      \"index\": 0,\n      \"logprobs\": null,\n      \"message\": {\n        \"content\": \"A chicken sits on a pile of money, looking\",\n        \"name\": null,\n        \"role\": \"assistant\",\n        \"tool_calls\": null\n      },\n      \"usage\": null\n    }\n  ],\n  \"created\": 1747230171,\n  \"id\": \"\",\n  \"model\": \"unsloth/Llama-3.2-11B-Vision-Instruct\",\n  \"object\": \"chat.completion\",\n  \"system_fingerprint\": \"3.3.6-dev0-native\",\n  \"usage\": {\n    \"completion_tokens\": 10,\n    \"prompt_tokens\": 45,\n    \"total_tokens\": 55\n  }\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_mpt/test_mpt.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 17,\n    \"prefill\": [\n      {\n        \"id\": 1276,\n        \"logprob\": null,\n        \"text\": \"What\"\n      },\n      {\n        \"id\": 310,\n        \"logprob\": -1.5117188,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 18147,\n        \"logprob\": -8.96875,\n        \"text\": \" Deep\"\n      },\n      {\n        \"id\": 20727,\n        \"logprob\": -1.953125,\n        \"text\": \" Learning\"\n      },\n      {\n        \"id\": 32,\n        \"logprob\": -0.94189453,\n        \"text\": \"?\"\n      }\n    ],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 428,\n        \"logprob\": -1.5830078,\n        \"special\": false,\n        \"text\": \" -\"\n      },\n      {\n        \"id\": 18147,\n        \"logprob\": -3.3105469,\n        \"special\": false,\n        \"text\": \" Deep\"\n      },\n      {\n        \"id\": 20727,\n        \"logprob\": -0.3215332,\n        \"special\": false,\n        \"text\": \" Learning\"\n      },\n      {\n        \"id\": 187,\n        \"logprob\": -2.5566406,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 30763,\n        \"logprob\": -1.6074219,\n        \"special\": false,\n        \"text\": \"Deep\"\n      },\n      {\n        \"id\": 20727,\n        \"logprob\": -0.69628906,\n        \"special\": false,\n        \"text\": \" Learning\"\n      },\n      {\n        \"id\": 310,\n        \"logprob\": -0.6923828,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 247,\n        \"logprob\": -0.5263672,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 749,\n        \"logprob\": -1.8544922,\n        \"special\": false,\n        \"text\": \" sub\"\n      },\n      {\n        \"id\": 3423,\n        \"logprob\": -0.6118164,\n        
\"special\": false,\n        \"text\": \"field\"\n      },\n      {\n        \"id\": 273,\n        \"logprob\": -0.055877686,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 5145,\n        \"logprob\": -1.0537109,\n        \"special\": false,\n        \"text\": \" machine\"\n      },\n      {\n        \"id\": 4715,\n        \"logprob\": -0.0115737915,\n        \"special\": false,\n        \"text\": \" learning\"\n      },\n      {\n        \"id\": 326,\n        \"logprob\": -0.9111328,\n        \"special\": false,\n        \"text\": \" that\"\n      },\n      {\n        \"id\": 4648,\n        \"logprob\": -1.4589844,\n        \"special\": false,\n        \"text\": \" uses\"\n      },\n      {\n        \"id\": 13345,\n        \"logprob\": -1.4853516,\n        \"special\": false,\n        \"text\": \" artificial\"\n      },\n      {\n        \"id\": 11454,\n        \"logprob\": -0.021636963,\n        \"special\": false,\n        \"text\": \" neural\"\n      }\n    ]\n  },\n  \"generated_text\": \" - Deep Learning\\nDeep Learning is a subfield of machine learning that uses artificial neural\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_mpt/test_mpt_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 17,\n      \"prefill\": [\n        {\n          \"id\": 1276,\n          \"logprob\": null,\n          \"text\": \"What\"\n        },\n        {\n          \"id\": 310,\n          \"logprob\": -1.5117188,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 18147,\n          \"logprob\": -8.96875,\n          \"text\": \" Deep\"\n        },\n        {\n          \"id\": 20727,\n          \"logprob\": -1.953125,\n          \"text\": \" Learning\"\n        },\n        {\n          \"id\": 32,\n          \"logprob\": -0.94189453,\n          \"text\": \"?\"\n        }\n      ],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 428,\n          \"logprob\": -1.5830078,\n          \"special\": false,\n          \"text\": \" -\"\n        },\n        {\n          \"id\": 18147,\n          \"logprob\": -3.3183594,\n          \"special\": false,\n          \"text\": \" Deep\"\n        },\n        {\n          \"id\": 20727,\n          \"logprob\": -0.32617188,\n          \"special\": false,\n          \"text\": \" Learning\"\n        },\n        {\n          \"id\": 187,\n          \"logprob\": -2.5742188,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 30763,\n          \"logprob\": -1.6015625,\n          \"special\": false,\n          \"text\": \"Deep\"\n        },\n        {\n          \"id\": 20727,\n          \"logprob\": -0.69628906,\n          \"special\": false,\n          \"text\": \" Learning\"\n        },\n        {\n          \"id\": 310,\n          \"logprob\": -0.67822266,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 247,\n          \"logprob\": -0.5395508,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 749,\n        
  \"logprob\": -1.8623047,\n          \"special\": false,\n          \"text\": \" sub\"\n        },\n        {\n          \"id\": 3423,\n          \"logprob\": -0.6020508,\n          \"special\": false,\n          \"text\": \"field\"\n        },\n        {\n          \"id\": 273,\n          \"logprob\": -0.0552063,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 5145,\n          \"logprob\": -1.0742188,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 4715,\n          \"logprob\": -0.011405945,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 326,\n          \"logprob\": -0.9165039,\n          \"special\": false,\n          \"text\": \" that\"\n        },\n        {\n          \"id\": 4648,\n          \"logprob\": -1.4501953,\n          \"special\": false,\n          \"text\": \" uses\"\n        },\n        {\n          \"id\": 13345,\n          \"logprob\": -1.4960938,\n          \"special\": false,\n          \"text\": \" artificial\"\n        },\n        {\n          \"id\": 11454,\n          \"logprob\": -0.02116394,\n          \"special\": false,\n          \"text\": \" neural\"\n        }\n      ]\n    },\n    \"generated_text\": \" - Deep Learning\\nDeep Learning is a subfield of machine learning that uses artificial neural\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 17,\n      \"prefill\": [\n        {\n          \"id\": 1276,\n          \"logprob\": null,\n          \"text\": \"What\"\n        },\n        {\n          \"id\": 310,\n          \"logprob\": -1.5,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 18147,\n          \"logprob\": -8.984375,\n          \"text\": \" Deep\"\n        },\n        {\n          \"id\": 20727,\n          \"logprob\": -1.96875,\n          
\"text\": \" Learning\"\n        },\n        {\n          \"id\": 32,\n          \"logprob\": -0.93359375,\n          \"text\": \"?\"\n        }\n      ],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 428,\n          \"logprob\": -1.5800781,\n          \"special\": false,\n          \"text\": \" -\"\n        },\n        {\n          \"id\": 18147,\n          \"logprob\": -3.3242188,\n          \"special\": false,\n          \"text\": \" Deep\"\n        },\n        {\n          \"id\": 20727,\n          \"logprob\": -0.31835938,\n          \"special\": false,\n          \"text\": \" Learning\"\n        },\n        {\n          \"id\": 187,\n          \"logprob\": -2.5644531,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 30763,\n          \"logprob\": -1.5957031,\n          \"special\": false,\n          \"text\": \"Deep\"\n        },\n        {\n          \"id\": 20727,\n          \"logprob\": -0.69628906,\n          \"special\": false,\n          \"text\": \" Learning\"\n        },\n        {\n          \"id\": 310,\n          \"logprob\": -0.68603516,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 247,\n          \"logprob\": -0.5258789,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 749,\n          \"logprob\": -1.859375,\n          \"special\": false,\n          \"text\": \" sub\"\n        },\n        {\n          \"id\": 3423,\n          \"logprob\": -0.6166992,\n          \"special\": false,\n          \"text\": \"field\"\n        },\n        {\n          \"id\": 273,\n          \"logprob\": -0.056762695,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 5145,\n          \"logprob\": -1.0703125,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 4715,\n          
\"logprob\": -0.011428833,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 326,\n          \"logprob\": -0.9213867,\n          \"special\": false,\n          \"text\": \" that\"\n        },\n        {\n          \"id\": 4648,\n          \"logprob\": -1.4726562,\n          \"special\": false,\n          \"text\": \" uses\"\n        },\n        {\n          \"id\": 13345,\n          \"logprob\": -1.5039062,\n          \"special\": false,\n          \"text\": \" artificial\"\n        },\n        {\n          \"id\": 11454,\n          \"logprob\": -0.021652222,\n          \"special\": false,\n          \"text\": \" neural\"\n        }\n      ]\n    },\n    \"generated_text\": \" - Deep Learning\\nDeep Learning is a subfield of machine learning that uses artificial neural\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 17,\n      \"prefill\": [\n        {\n          \"id\": 1276,\n          \"logprob\": null,\n          \"text\": \"What\"\n        },\n        {\n          \"id\": 310,\n          \"logprob\": -1.5,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 18147,\n          \"logprob\": -8.984375,\n          \"text\": \" Deep\"\n        },\n        {\n          \"id\": 20727,\n          \"logprob\": -1.96875,\n          \"text\": \" Learning\"\n        },\n        {\n          \"id\": 32,\n          \"logprob\": -0.93359375,\n          \"text\": \"?\"\n        }\n      ],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 428,\n          \"logprob\": -1.5800781,\n          \"special\": false,\n          \"text\": \" -\"\n        },\n        {\n          \"id\": 18147,\n          \"logprob\": -3.3242188,\n          \"special\": false,\n          \"text\": \" Deep\"\n        },\n        {\n          \"id\": 20727,\n          \"logprob\": -0.31835938,\n          \"special\": 
false,\n          \"text\": \" Learning\"\n        },\n        {\n          \"id\": 187,\n          \"logprob\": -2.5644531,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 30763,\n          \"logprob\": -1.5957031,\n          \"special\": false,\n          \"text\": \"Deep\"\n        },\n        {\n          \"id\": 20727,\n          \"logprob\": -0.69628906,\n          \"special\": false,\n          \"text\": \" Learning\"\n        },\n        {\n          \"id\": 310,\n          \"logprob\": -0.68603516,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 247,\n          \"logprob\": -0.5258789,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 749,\n          \"logprob\": -1.859375,\n          \"special\": false,\n          \"text\": \" sub\"\n        },\n        {\n          \"id\": 3423,\n          \"logprob\": -0.6166992,\n          \"special\": false,\n          \"text\": \"field\"\n        },\n        {\n          \"id\": 273,\n          \"logprob\": -0.056762695,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 5145,\n          \"logprob\": -1.0703125,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 4715,\n          \"logprob\": -0.011428833,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 326,\n          \"logprob\": -0.9213867,\n          \"special\": false,\n          \"text\": \" that\"\n        },\n        {\n          \"id\": 4648,\n          \"logprob\": -1.4726562,\n          \"special\": false,\n          \"text\": \" uses\"\n        },\n        {\n          \"id\": 13345,\n          \"logprob\": -1.5039062,\n          \"special\": false,\n          \"text\": \" artificial\"\n        },\n        {\n          \"id\": 11454,\n  
        \"logprob\": -0.021652222,\n          \"special\": false,\n          \"text\": \" neural\"\n        }\n      ]\n    },\n    \"generated_text\": \" - Deep Learning\\nDeep Learning is a subfield of machine learning that uses artificial neural\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 17,\n      \"prefill\": [\n        {\n          \"id\": 1276,\n          \"logprob\": null,\n          \"text\": \"What\"\n        },\n        {\n          \"id\": 310,\n          \"logprob\": -1.5,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 18147,\n          \"logprob\": -8.984375,\n          \"text\": \" Deep\"\n        },\n        {\n          \"id\": 20727,\n          \"logprob\": -1.96875,\n          \"text\": \" Learning\"\n        },\n        {\n          \"id\": 32,\n          \"logprob\": -0.93359375,\n          \"text\": \"?\"\n        }\n      ],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 428,\n          \"logprob\": -1.5800781,\n          \"special\": false,\n          \"text\": \" -\"\n        },\n        {\n          \"id\": 18147,\n          \"logprob\": -3.3242188,\n          \"special\": false,\n          \"text\": \" Deep\"\n        },\n        {\n          \"id\": 20727,\n          \"logprob\": -0.31835938,\n          \"special\": false,\n          \"text\": \" Learning\"\n        },\n        {\n          \"id\": 187,\n          \"logprob\": -2.5644531,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 30763,\n          \"logprob\": -1.5957031,\n          \"special\": false,\n          \"text\": \"Deep\"\n        },\n        {\n          \"id\": 20727,\n          \"logprob\": -0.69628906,\n          \"special\": false,\n          \"text\": \" Learning\"\n        },\n        {\n          \"id\": 310,\n          \"logprob\": -0.68603516,\n          \"special\": 
false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 247,\n          \"logprob\": -0.5258789,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 749,\n          \"logprob\": -1.859375,\n          \"special\": false,\n          \"text\": \" sub\"\n        },\n        {\n          \"id\": 3423,\n          \"logprob\": -0.6166992,\n          \"special\": false,\n          \"text\": \"field\"\n        },\n        {\n          \"id\": 273,\n          \"logprob\": -0.056762695,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 5145,\n          \"logprob\": -1.0703125,\n          \"special\": false,\n          \"text\": \" machine\"\n        },\n        {\n          \"id\": 4715,\n          \"logprob\": -0.011428833,\n          \"special\": false,\n          \"text\": \" learning\"\n        },\n        {\n          \"id\": 326,\n          \"logprob\": -0.9213867,\n          \"special\": false,\n          \"text\": \" that\"\n        },\n        {\n          \"id\": 4648,\n          \"logprob\": -1.4726562,\n          \"special\": false,\n          \"text\": \" uses\"\n        },\n        {\n          \"id\": 13345,\n          \"logprob\": -1.5039062,\n          \"special\": false,\n          \"text\": \" artificial\"\n        },\n        {\n          \"id\": 11454,\n          \"logprob\": -0.021652222,\n          \"special\": false,\n          \"text\": \" neural\"\n        }\n      ]\n    },\n    \"generated_text\": \" - Deep Learning\\nDeep Learning is a subfield of machine learning that uses artificial neural\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"eos_token\",\n    \"generated_tokens\": 5,\n    \"prefill\": [\n      {\n        \"id\": 0,\n        \"logprob\": null,\n        \"text\": \"<pad>\"\n      }\n    ],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 926,\n        \"logprob\": -4.3554688,\n        \"special\": false,\n        \"text\": \" To\"\n      },\n      {\n        \"id\": 18295,\n        \"logprob\": -7.7734375,\n        \"special\": false,\n        \"text\": \" sell\"\n      },\n      {\n        \"id\": 7868,\n        \"logprob\": -3.9257812,\n        \"special\": false,\n        \"text\": \" things\"\n      },\n      {\n        \"id\": 260,\n        \"logprob\": -2.4179688,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 1,\n        \"logprob\": 0.0,\n        \"special\": true,\n        \"text\": \"</s>\"\n      }\n    ]\n  },\n  \"generated_text\": \"To sell things.\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_all_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [\n      {\n        \"id\": 0,\n        \"logprob\": null,\n        \"text\": \"<pad>\"\n      }\n    ],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 16017,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" blue\"\n      },\n      {\n        \"id\": 20495,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" sky\"\n      },\n      {\n        \"id\": 259,\n        \"logprob\": -0.47070312,\n        \"special\": false,\n        \"text\": \" \"\n      },\n      {\n        \"id\": 261,\n        \"logprob\": -0.15307617,\n        \"special\": false,\n        \"text\": \",\"\n      },\n      {\n        \"id\": 35622,\n        \"logprob\": -0.796875,\n        \"special\": false,\n        \"text\": \" cloud\"\n      },\n      {\n        \"id\": 263,\n        \"logprob\": -1.2958984,\n        \"special\": false,\n        \"text\": \"s\"\n      },\n      {\n        \"id\": 305,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" and\"\n      },\n      {\n        \"id\": 35622,\n        \"logprob\": -1.2998047,\n        \"special\": false,\n        \"text\": \" cloud\"\n      },\n      {\n        \"id\": 263,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"s\"\n      },\n      {\n        \"id\": 1,\n        \"logprob\": 0.0,\n        \"special\": true,\n        \"text\": \"</s>\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"Why is the sky blue?blue sky , clouds and clouds\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_mt0_base/test_mt0_base_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"eos_token\",\n      \"generated_tokens\": 6,\n      \"prefill\": [\n        {\n          \"id\": 0,\n          \"logprob\": null,\n          \"text\": \"<pad>\"\n        }\n      ],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 259,\n          \"logprob\": -1.3798828,\n          \"special\": false,\n          \"text\": \" \"\n        },\n        {\n          \"id\": 39261,\n          \"logprob\": -0.36328125,\n          \"special\": false,\n          \"text\": \"Because\"\n        },\n        {\n          \"id\": 609,\n          \"logprob\": -1.0947266,\n          \"special\": false,\n          \"text\": \" it\"\n        },\n        {\n          \"id\": 339,\n          \"logprob\": -0.8286133,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 16017,\n          \"logprob\": -1.6826172,\n          \"special\": false,\n          \"text\": \" blue\"\n        },\n        {\n          \"id\": 1,\n          \"logprob\": -0.7290039,\n          \"special\": true,\n          \"text\": \"</s>\"\n        }\n      ]\n    },\n    \"generated_text\": \"Because it is blue\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"eos_token\",\n      \"generated_tokens\": 6,\n      \"prefill\": [\n        {\n          \"id\": 0,\n          \"logprob\": null,\n          \"text\": \"<pad>\"\n        }\n      ],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 259,\n          \"logprob\": -1.3789062,\n          \"special\": false,\n          \"text\": \" \"\n        },\n        {\n          \"id\": 39261,\n          \"logprob\": -0.36279297,\n          \"special\": false,\n          \"text\": \"Because\"\n        },\n        {\n          \"id\": 609,\n          \"logprob\": -1.0966797,\n          \"special\": false,\n          \"text\": \" 
it\"\n        },\n        {\n          \"id\": 339,\n          \"logprob\": -0.8276367,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 16017,\n          \"logprob\": -1.6845703,\n          \"special\": false,\n          \"text\": \" blue\"\n        },\n        {\n          \"id\": 1,\n          \"logprob\": -0.72753906,\n          \"special\": true,\n          \"text\": \"</s>\"\n        }\n      ]\n    },\n    \"generated_text\": \"Because it is blue\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"eos_token\",\n      \"generated_tokens\": 6,\n      \"prefill\": [\n        {\n          \"id\": 0,\n          \"logprob\": null,\n          \"text\": \"<pad>\"\n        }\n      ],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 259,\n          \"logprob\": -1.3789062,\n          \"special\": false,\n          \"text\": \" \"\n        },\n        {\n          \"id\": 39261,\n          \"logprob\": -0.36279297,\n          \"special\": false,\n          \"text\": \"Because\"\n        },\n        {\n          \"id\": 609,\n          \"logprob\": -1.0966797,\n          \"special\": false,\n          \"text\": \" it\"\n        },\n        {\n          \"id\": 339,\n          \"logprob\": -0.8276367,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 16017,\n          \"logprob\": -1.6845703,\n          \"special\": false,\n          \"text\": \" blue\"\n        },\n        {\n          \"id\": 1,\n          \"logprob\": -0.72753906,\n          \"special\": true,\n          \"text\": \"</s>\"\n        }\n      ]\n    },\n    \"generated_text\": \"Because it is blue\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"eos_token\",\n      \"generated_tokens\": 6,\n      \"prefill\": [\n        {\n          \"id\": 0,\n          \"logprob\": null,\n          
\"text\": \"<pad>\"\n        }\n      ],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 259,\n          \"logprob\": -1.3789062,\n          \"special\": false,\n          \"text\": \" \"\n        },\n        {\n          \"id\": 39261,\n          \"logprob\": -0.36279297,\n          \"special\": false,\n          \"text\": \"Because\"\n        },\n        {\n          \"id\": 609,\n          \"logprob\": -1.0966797,\n          \"special\": false,\n          \"text\": \" it\"\n        },\n        {\n          \"id\": 339,\n          \"logprob\": -0.8276367,\n          \"special\": false,\n          \"text\": \" is\"\n        },\n        {\n          \"id\": 16017,\n          \"logprob\": -1.6845703,\n          \"special\": false,\n          \"text\": \" blue\"\n        },\n        {\n          \"id\": 1,\n          \"logprob\": -0.72753906,\n          \"special\": true,\n          \"text\": \"</s>\"\n        }\n      ]\n    },\n    \"generated_text\": \"Because it is blue\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_neox/test_neox.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 42,\n        \"logprob\": -0.86279297,\n        \"special\": false,\n        \"text\": \"I\"\n      },\n      {\n        \"id\": 1353,\n        \"logprob\": -0.94921875,\n        \"special\": false,\n        \"text\": \"'m\"\n      },\n      {\n        \"id\": 7016,\n        \"logprob\": -2.1835938,\n        \"special\": false,\n        \"text\": \" sorry\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.074035645,\n        \"special\": false,\n        \"text\": \",\"\n      },\n      {\n        \"id\": 1394,\n        \"logprob\": -0.86376953,\n        \"special\": false,\n        \"text\": \"You\"\n      },\n      {\n        \"id\": 452,\n        \"logprob\": -1.2070312,\n        \"special\": false,\n        \"text\": \" have\"\n      },\n      {\n        \"id\": 247,\n        \"logprob\": -1.4365234,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 4327,\n        \"logprob\": -1.109375,\n        \"special\": false,\n        \"text\": \" choice\"\n      },\n      {\n        \"id\": 273,\n        \"logprob\": -0.93408203,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 752,\n        \"logprob\": -1.8808594,\n        \"special\": false,\n        \"text\": \" what\"\n      }\n    ]\n  },\n  \"generated_text\": \"I'm sorry,You have a choice of what\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_neox/test_neox_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 42,\n          \"logprob\": -0.8618164,\n          \"special\": false,\n          \"text\": \"I\"\n        },\n        {\n          \"id\": 1353,\n          \"logprob\": -0.9506836,\n          \"special\": false,\n          \"text\": \"'m\"\n        },\n        {\n          \"id\": 7016,\n          \"logprob\": -2.1738281,\n          \"special\": false,\n          \"text\": \" sorry\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -0.0758667,\n          \"special\": false,\n          \"text\": \",\"\n        },\n        {\n          \"id\": 1394,\n          \"logprob\": -0.9135742,\n          \"special\": false,\n          \"text\": \"You\"\n        },\n        {\n          \"id\": 452,\n          \"logprob\": -1.1445312,\n          \"special\": false,\n          \"text\": \" have\"\n        },\n        {\n          \"id\": 247,\n          \"logprob\": -1.4375,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 4327,\n          \"logprob\": -1.1103516,\n          \"special\": false,\n          \"text\": \" choice\"\n        },\n        {\n          \"id\": 273,\n          \"logprob\": -1.0058594,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 752,\n          \"logprob\": -1.921875,\n          \"special\": false,\n          \"text\": \" what\"\n        }\n      ]\n    },\n    \"generated_text\": \"I'm sorry,You have a choice of what\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 42,\n          \"logprob\": 
-0.8618164,\n          \"special\": false,\n          \"text\": \"I\"\n        },\n        {\n          \"id\": 1353,\n          \"logprob\": -0.9506836,\n          \"special\": false,\n          \"text\": \"'m\"\n        },\n        {\n          \"id\": 7016,\n          \"logprob\": -2.1738281,\n          \"special\": false,\n          \"text\": \" sorry\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -0.0758667,\n          \"special\": false,\n          \"text\": \",\"\n        },\n        {\n          \"id\": 1394,\n          \"logprob\": -0.9135742,\n          \"special\": false,\n          \"text\": \"You\"\n        },\n        {\n          \"id\": 452,\n          \"logprob\": -1.1445312,\n          \"special\": false,\n          \"text\": \" have\"\n        },\n        {\n          \"id\": 247,\n          \"logprob\": -1.4375,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 4327,\n          \"logprob\": -1.1103516,\n          \"special\": false,\n          \"text\": \" choice\"\n        },\n        {\n          \"id\": 273,\n          \"logprob\": -1.0058594,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 752,\n          \"logprob\": -1.921875,\n          \"special\": false,\n          \"text\": \" what\"\n        }\n      ]\n    },\n    \"generated_text\": \"I'm sorry,You have a choice of what\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 42,\n          \"logprob\": -0.8618164,\n          \"special\": false,\n          \"text\": \"I\"\n        },\n        {\n          \"id\": 1353,\n          \"logprob\": -0.9506836,\n          \"special\": false,\n          \"text\": \"'m\"\n        },\n        {\n          \"id\": 7016,\n          
\"logprob\": -2.1738281,\n          \"special\": false,\n          \"text\": \" sorry\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -0.0758667,\n          \"special\": false,\n          \"text\": \",\"\n        },\n        {\n          \"id\": 1394,\n          \"logprob\": -0.9135742,\n          \"special\": false,\n          \"text\": \"You\"\n        },\n        {\n          \"id\": 452,\n          \"logprob\": -1.1445312,\n          \"special\": false,\n          \"text\": \" have\"\n        },\n        {\n          \"id\": 247,\n          \"logprob\": -1.4375,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 4327,\n          \"logprob\": -1.1103516,\n          \"special\": false,\n          \"text\": \" choice\"\n        },\n        {\n          \"id\": 273,\n          \"logprob\": -1.0058594,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 752,\n          \"logprob\": -1.921875,\n          \"special\": false,\n          \"text\": \" what\"\n        }\n      ]\n    },\n    \"generated_text\": \"I'm sorry,You have a choice of what\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 42,\n          \"logprob\": -0.8618164,\n          \"special\": false,\n          \"text\": \"I\"\n        },\n        {\n          \"id\": 1353,\n          \"logprob\": -0.9506836,\n          \"special\": false,\n          \"text\": \"'m\"\n        },\n        {\n          \"id\": 7016,\n          \"logprob\": -2.1738281,\n          \"special\": false,\n          \"text\": \" sorry\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -0.0758667,\n          \"special\": false,\n          \"text\": \",\"\n        },\n        {\n          \"id\": 1394,\n   
       \"logprob\": -0.9135742,\n          \"special\": false,\n          \"text\": \"You\"\n        },\n        {\n          \"id\": 452,\n          \"logprob\": -1.1445312,\n          \"special\": false,\n          \"text\": \" have\"\n        },\n        {\n          \"id\": 247,\n          \"logprob\": -1.4375,\n          \"special\": false,\n          \"text\": \" a\"\n        },\n        {\n          \"id\": 4327,\n          \"logprob\": -1.1103516,\n          \"special\": false,\n          \"text\": \" choice\"\n        },\n        {\n          \"id\": 273,\n          \"logprob\": -1.0058594,\n          \"special\": false,\n          \"text\": \" of\"\n        },\n        {\n          \"id\": 752,\n          \"logprob\": -1.921875,\n          \"special\": false,\n          \"text\": \" what\"\n        }\n      ]\n    },\n    \"generated_text\": \"I'm sorry,You have a choice of what\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_neox_sharded/test_neox.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 510,\n        \"logprob\": -0.5878906,\n        \"special\": false,\n        \"text\": \"The\"\n      },\n      {\n        \"id\": 3159,\n        \"logprob\": -0.5449219,\n        \"special\": false,\n        \"text\": \" word\"\n      },\n      {\n        \"id\": 346,\n        \"logprob\": -0.05038452,\n        \"special\": false,\n        \"text\": \" \\\"\"\n      },\n      {\n        \"id\": 6441,\n        \"logprob\": -0.002292633,\n        \"special\": false,\n        \"text\": \"mem\"\n      },\n      {\n        \"id\": 70,\n        \"logprob\": -1.3828278e-05,\n        \"special\": false,\n        \"text\": \"e\"\n      },\n      {\n        \"id\": 3,\n        \"logprob\": -0.0010242462,\n        \"special\": false,\n        \"text\": \"\\\"\"\n      },\n      {\n        \"id\": 369,\n        \"logprob\": -0.090270996,\n        \"special\": false,\n        \"text\": \" was\"\n      },\n      {\n        \"id\": 806,\n        \"logprob\": -0.12719727,\n        \"special\": false,\n        \"text\": \" first\"\n      },\n      {\n        \"id\": 908,\n        \"logprob\": -0.016571045,\n        \"special\": false,\n        \"text\": \" used\"\n      },\n      {\n        \"id\": 275,\n        \"logprob\": -0.43432617,\n        \"special\": false,\n        \"text\": \" in\"\n      }\n    ]\n  },\n  \"generated_text\": \"The word \\\"meme\\\" was first used in\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_neox_sharded/test_neox_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 510,\n          \"logprob\": -0.5878906,\n          \"special\": false,\n          \"text\": \"The\"\n        },\n        {\n          \"id\": 3159,\n          \"logprob\": -0.5498047,\n          \"special\": false,\n          \"text\": \" word\"\n        },\n        {\n          \"id\": 346,\n          \"logprob\": -0.04815674,\n          \"special\": false,\n          \"text\": \" \\\"\"\n        },\n        {\n          \"id\": 6441,\n          \"logprob\": -0.002313614,\n          \"special\": false,\n          \"text\": \"mem\"\n        },\n        {\n          \"id\": 70,\n          \"logprob\": -1.2636185e-05,\n          \"special\": false,\n          \"text\": \"e\"\n        },\n        {\n          \"id\": 3,\n          \"logprob\": -0.0010147095,\n          \"special\": false,\n          \"text\": \"\\\"\"\n        },\n        {\n          \"id\": 369,\n          \"logprob\": -0.0859375,\n          \"special\": false,\n          \"text\": \" was\"\n        },\n        {\n          \"id\": 806,\n          \"logprob\": -0.12609863,\n          \"special\": false,\n          \"text\": \" first\"\n        },\n        {\n          \"id\": 908,\n          \"logprob\": -0.016601562,\n          \"special\": false,\n          \"text\": \" used\"\n        },\n        {\n          \"id\": 275,\n          \"logprob\": -0.38256836,\n          \"special\": false,\n          \"text\": \" in\"\n        }\n      ]\n    },\n    \"generated_text\": \"The word \\\"meme\\\" was first used in\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 510,\n   
       \"logprob\": -0.6201172,\n          \"special\": false,\n          \"text\": \"The\"\n        },\n        {\n          \"id\": 3159,\n          \"logprob\": -0.546875,\n          \"special\": false,\n          \"text\": \" word\"\n        },\n        {\n          \"id\": 346,\n          \"logprob\": -0.051879883,\n          \"special\": false,\n          \"text\": \" \\\"\"\n        },\n        {\n          \"id\": 6441,\n          \"logprob\": -0.0020179749,\n          \"special\": false,\n          \"text\": \"mem\"\n        },\n        {\n          \"id\": 70,\n          \"logprob\": -9.059906e-06,\n          \"special\": false,\n          \"text\": \"e\"\n        },\n        {\n          \"id\": 3,\n          \"logprob\": -0.00096797943,\n          \"special\": false,\n          \"text\": \"\\\"\"\n        },\n        {\n          \"id\": 369,\n          \"logprob\": -0.07940674,\n          \"special\": false,\n          \"text\": \" was\"\n        },\n        {\n          \"id\": 806,\n          \"logprob\": -0.12182617,\n          \"special\": false,\n          \"text\": \" first\"\n        },\n        {\n          \"id\": 908,\n          \"logprob\": -0.017227173,\n          \"special\": false,\n          \"text\": \" used\"\n        },\n        {\n          \"id\": 275,\n          \"logprob\": -0.44482422,\n          \"special\": false,\n          \"text\": \" in\"\n        }\n      ]\n    },\n    \"generated_text\": \"The word \\\"meme\\\" was first used in\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 510,\n          \"logprob\": -0.6201172,\n          \"special\": false,\n          \"text\": \"The\"\n        },\n        {\n          \"id\": 3159,\n          \"logprob\": -0.546875,\n          \"special\": false,\n          \"text\": \" word\"\n        },\n     
   {\n          \"id\": 346,\n          \"logprob\": -0.051879883,\n          \"special\": false,\n          \"text\": \" \\\"\"\n        },\n        {\n          \"id\": 6441,\n          \"logprob\": -0.0020179749,\n          \"special\": false,\n          \"text\": \"mem\"\n        },\n        {\n          \"id\": 70,\n          \"logprob\": -9.059906e-06,\n          \"special\": false,\n          \"text\": \"e\"\n        },\n        {\n          \"id\": 3,\n          \"logprob\": -0.00096797943,\n          \"special\": false,\n          \"text\": \"\\\"\"\n        },\n        {\n          \"id\": 369,\n          \"logprob\": -0.07940674,\n          \"special\": false,\n          \"text\": \" was\"\n        },\n        {\n          \"id\": 806,\n          \"logprob\": -0.12182617,\n          \"special\": false,\n          \"text\": \" first\"\n        },\n        {\n          \"id\": 908,\n          \"logprob\": -0.017227173,\n          \"special\": false,\n          \"text\": \" used\"\n        },\n        {\n          \"id\": 275,\n          \"logprob\": -0.44482422,\n          \"special\": false,\n          \"text\": \" in\"\n        }\n      ]\n    },\n    \"generated_text\": \"The word \\\"meme\\\" was first used in\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 510,\n          \"logprob\": -0.6201172,\n          \"special\": false,\n          \"text\": \"The\"\n        },\n        {\n          \"id\": 3159,\n          \"logprob\": -0.546875,\n          \"special\": false,\n          \"text\": \" word\"\n        },\n        {\n          \"id\": 346,\n          \"logprob\": -0.051879883,\n          \"special\": false,\n          \"text\": \" \\\"\"\n        },\n        {\n          \"id\": 6441,\n          \"logprob\": -0.0020179749,\n          \"special\": false,\n       
   \"text\": \"mem\"\n        },\n        {\n          \"id\": 70,\n          \"logprob\": -1.04904175e-05,\n          \"special\": false,\n          \"text\": \"e\"\n        },\n        {\n          \"id\": 3,\n          \"logprob\": -0.0009560585,\n          \"special\": false,\n          \"text\": \"\\\"\"\n        },\n        {\n          \"id\": 369,\n          \"logprob\": -0.08557129,\n          \"special\": false,\n          \"text\": \" was\"\n        },\n        {\n          \"id\": 806,\n          \"logprob\": -0.12084961,\n          \"special\": false,\n          \"text\": \" first\"\n        },\n        {\n          \"id\": 908,\n          \"logprob\": -0.01737976,\n          \"special\": false,\n          \"text\": \" used\"\n        },\n        {\n          \"id\": 275,\n          \"logprob\": -0.4025879,\n          \"special\": false,\n          \"text\": \" in\"\n        }\n      ]\n    },\n    \"generated_text\": \"The word \\\"meme\\\" was first used in\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 13,\n        \"logprob\": -2.3417969,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 3057,\n        \"logprob\": -1.8730469,\n        \"special\": false,\n        \"text\": \"Test\"\n      },\n      {\n        \"id\": 2009,\n        \"logprob\": -1.2626953,\n        \"special\": false,\n        \"text\": \" request\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -1.7060547,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 3057,\n        \"logprob\": -1.4482422,\n        \"special\": false,\n        \"text\": \"Test\"\n      },\n      {\n        \"id\": 2009,\n        \"logprob\": -0.15246582,\n        \"special\": false,\n        \"text\": \" request\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.796875,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 3057,\n        \"logprob\": -0.22766113,\n        \"special\": false,\n        \"text\": \"Test\"\n      },\n      {\n        \"id\": 2009,\n        \"logprob\": -0.007045746,\n        \"special\": false,\n        \"text\": \" request\"\n      },\n      {\n        \"id\": 13,\n        \"logprob\": -0.021759033,\n        \"special\": false,\n        \"text\": \"\\n\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"\\nTest request\\nTest request\\nTest request\\n\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized_all_params.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": 0,\n    \"tokens\": [\n      {\n        \"id\": 29899,\n        \"logprob\": -1.4980469,\n        \"special\": false,\n        \"text\": \"-\"\n      },\n      {\n        \"id\": 1454,\n        \"logprob\": -0.19433594,\n        \"special\": false,\n        \"text\": \"for\"\n      },\n      {\n        \"id\": 29899,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"-\"\n      },\n      {\n        \"id\": 9342,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"comment\"\n      },\n      {\n        \"id\": 29901,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \":\"\n      },\n      {\n        \"id\": 396,\n        \"logprob\": -0.27392578,\n        \"special\": false,\n        \"text\": \" #\"\n      },\n      {\n        \"id\": 29906,\n        \"logprob\": -0.49389648,\n        \"special\": false,\n        \"text\": \"2\"\n      },\n      {\n        \"id\": 29900,\n        \"logprob\": -0.81103516,\n        \"special\": false,\n        \"text\": \"0\"\n      },\n      {\n        \"id\": 29896,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"1\"\n      },\n      {\n        \"id\": 29955,\n        \"logprob\": -1.0800781,\n        \"special\": false,\n        \"text\": \"7\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \"Test request-for-comment: #2017\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_server_gptq_quantized/test_server_gptq_quantized_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 13,\n          \"logprob\": -2.3359375,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 3057,\n          \"logprob\": -1.8623047,\n          \"special\": false,\n          \"text\": \"Test\"\n        },\n        {\n          \"id\": 2009,\n          \"logprob\": -1.2451172,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -1.6923828,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 3057,\n          \"logprob\": -1.4492188,\n          \"special\": false,\n          \"text\": \"Test\"\n        },\n        {\n          \"id\": 2009,\n          \"logprob\": -0.15197754,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -0.8022461,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 3057,\n          \"logprob\": -0.22583008,\n          \"special\": false,\n          \"text\": \"Test\"\n        },\n        {\n          \"id\": 2009,\n          \"logprob\": -0.007095337,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -0.021652222,\n          \"special\": false,\n          \"text\": \"\\n\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\nTest request\\nTest request\\nTest request\\n\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      
\"tokens\": [\n        {\n          \"id\": 13,\n          \"logprob\": -2.3476562,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 3057,\n          \"logprob\": -1.8789062,\n          \"special\": false,\n          \"text\": \"Test\"\n        },\n        {\n          \"id\": 2009,\n          \"logprob\": -1.2734375,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -1.703125,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 3057,\n          \"logprob\": -1.4677734,\n          \"special\": false,\n          \"text\": \"Test\"\n        },\n        {\n          \"id\": 2009,\n          \"logprob\": -0.15454102,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -0.7973633,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 3057,\n          \"logprob\": -0.23278809,\n          \"special\": false,\n          \"text\": \"Test\"\n        },\n        {\n          \"id\": 2009,\n          \"logprob\": -0.006980896,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -0.022033691,\n          \"special\": false,\n          \"text\": \"\\n\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\nTest request\\nTest request\\nTest request\\n\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 13,\n          \"logprob\": -2.3203125,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 3057,\n          \"logprob\": 
-1.8486328,\n          \"special\": false,\n          \"text\": \"Test\"\n        },\n        {\n          \"id\": 2009,\n          \"logprob\": -1.2480469,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -1.7060547,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 3057,\n          \"logprob\": -1.4511719,\n          \"special\": false,\n          \"text\": \"Test\"\n        },\n        {\n          \"id\": 2009,\n          \"logprob\": -0.1529541,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -0.81396484,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 3057,\n          \"logprob\": -0.22180176,\n          \"special\": false,\n          \"text\": \"Test\"\n        },\n        {\n          \"id\": 2009,\n          \"logprob\": -0.007133484,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -0.021835327,\n          \"special\": false,\n          \"text\": \"\\n\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\nTest request\\nTest request\\nTest request\\n\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 13,\n          \"logprob\": -2.3261719,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 3057,\n          \"logprob\": -1.8691406,\n          \"special\": false,\n          \"text\": \"Test\"\n        },\n        {\n          \"id\": 2009,\n          \"logprob\": -1.2597656,\n          \"special\": false,\n          \"text\": \" 
request\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -1.7070312,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 3057,\n          \"logprob\": -1.4550781,\n          \"special\": false,\n          \"text\": \"Test\"\n        },\n        {\n          \"id\": 2009,\n          \"logprob\": -0.1538086,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -0.79345703,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 3057,\n          \"logprob\": -0.22924805,\n          \"special\": false,\n          \"text\": \"Test\"\n        },\n        {\n          \"id\": 2009,\n          \"logprob\": -0.0070266724,\n          \"special\": false,\n          \"text\": \" request\"\n        },\n        {\n          \"id\": 13,\n          \"logprob\": -0.021942139,\n          \"special\": false,\n          \"text\": \"\\n\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \"\\nTest request\\nTest request\\nTest request\\n\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_smolvlm/test_flash_smolvlm_next_simple_url.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"eos_token\",\n    \"generated_tokens\": 8,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 330,\n        \"logprob\": -0.107421875,\n        \"special\": false,\n        \"text\": \" A\"\n      },\n      {\n        \"id\": 11426,\n        \"logprob\": -0.30078125,\n        \"special\": false,\n        \"text\": \" bee\"\n      },\n      {\n        \"id\": 335,\n        \"logprob\": -0.9609375,\n        \"special\": false,\n        \"text\": \" on\"\n      },\n      {\n        \"id\": 253,\n        \"logprob\": -0.0703125,\n        \"special\": false,\n        \"text\": \" a\"\n      },\n      {\n        \"id\": 11986,\n        \"logprob\": -0.5,\n        \"special\": false,\n        \"text\": \" pink\"\n      },\n      {\n        \"id\": 8525,\n        \"logprob\": -0.09716797,\n        \"special\": false,\n        \"text\": \" flower\"\n      },\n      {\n        \"id\": 30,\n        \"logprob\": -1.078125,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 49154,\n        \"logprob\": -0.110839844,\n        \"special\": true,\n        \"text\": \"<end_of_utterance>\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \" A bee on a pink flower.\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_t5_sharded/test_t5_sharded.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"eos_token\",\n    \"generated_tokens\": 7,\n    \"prefill\": [\n      {\n        \"id\": 0,\n        \"logprob\": null,\n        \"text\": \"<pad>\"\n      }\n    ],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 3,\n        \"logprob\": -0.7001953,\n        \"special\": false,\n        \"text\": \" \"\n      },\n      {\n        \"id\": 18,\n        \"logprob\": -1.1943359,\n        \"special\": false,\n        \"text\": \"-\"\n      },\n      {\n        \"id\": 26937,\n        \"logprob\": -1.2099609,\n        \"special\": false,\n        \"text\": \"196\"\n      },\n      {\n        \"id\": 3,\n        \"logprob\": -1.2451172,\n        \"special\": false,\n        \"text\": \" \"\n      },\n      {\n        \"id\": 1956,\n        \"logprob\": -0.3322754,\n        \"special\": false,\n        \"text\": \"°\"\n      },\n      {\n        \"id\": 254,\n        \"logprob\": -0.19213867,\n        \"special\": false,\n        \"text\": \"C\"\n      },\n      {\n        \"id\": 1,\n        \"logprob\": -0.030151367,\n        \"special\": true,\n        \"text\": \"</s>\"\n      }\n    ]\n  },\n  \"generated_text\": \"-196 °C\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_t5_sharded/test_t5_sharded_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"eos_token\",\n      \"generated_tokens\": 7,\n      \"prefill\": [\n        {\n          \"id\": 0,\n          \"logprob\": null,\n          \"text\": \"<pad>\"\n        }\n      ],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 3,\n          \"logprob\": -0.7001953,\n          \"special\": false,\n          \"text\": \" \"\n        },\n        {\n          \"id\": 18,\n          \"logprob\": -1.1943359,\n          \"special\": false,\n          \"text\": \"-\"\n        },\n        {\n          \"id\": 26937,\n          \"logprob\": -1.2119141,\n          \"special\": false,\n          \"text\": \"196\"\n        },\n        {\n          \"id\": 3,\n          \"logprob\": -1.2480469,\n          \"special\": false,\n          \"text\": \" \"\n        },\n        {\n          \"id\": 1956,\n          \"logprob\": -0.33203125,\n          \"special\": false,\n          \"text\": \"°\"\n        },\n        {\n          \"id\": 254,\n          \"logprob\": -0.19250488,\n          \"special\": false,\n          \"text\": \"C\"\n        },\n        {\n          \"id\": 1,\n          \"logprob\": -0.030166626,\n          \"special\": true,\n          \"text\": \"</s>\"\n        }\n      ]\n    },\n    \"generated_text\": \"-196 °C\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"eos_token\",\n      \"generated_tokens\": 7,\n      \"prefill\": [\n        {\n          \"id\": 0,\n          \"logprob\": null,\n          \"text\": \"<pad>\"\n        }\n      ],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 3,\n          \"logprob\": -0.7001953,\n          \"special\": false,\n          \"text\": \" \"\n        },\n        {\n          \"id\": 18,\n          \"logprob\": -1.1943359,\n          \"special\": false,\n          \"text\": \"-\"\n        },\n        {\n          
\"id\": 26937,\n          \"logprob\": -1.2119141,\n          \"special\": false,\n          \"text\": \"196\"\n        },\n        {\n          \"id\": 3,\n          \"logprob\": -1.2480469,\n          \"special\": false,\n          \"text\": \" \"\n        },\n        {\n          \"id\": 1956,\n          \"logprob\": -0.33203125,\n          \"special\": false,\n          \"text\": \"°\"\n        },\n        {\n          \"id\": 254,\n          \"logprob\": -0.19250488,\n          \"special\": false,\n          \"text\": \"C\"\n        },\n        {\n          \"id\": 1,\n          \"logprob\": -0.030166626,\n          \"special\": true,\n          \"text\": \"</s>\"\n        }\n      ]\n    },\n    \"generated_text\": \"-196 °C\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"eos_token\",\n      \"generated_tokens\": 7,\n      \"prefill\": [\n        {\n          \"id\": 0,\n          \"logprob\": null,\n          \"text\": \"<pad>\"\n        }\n      ],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 3,\n          \"logprob\": -0.7001953,\n          \"special\": false,\n          \"text\": \" \"\n        },\n        {\n          \"id\": 18,\n          \"logprob\": -1.1943359,\n          \"special\": false,\n          \"text\": \"-\"\n        },\n        {\n          \"id\": 26937,\n          \"logprob\": -1.2119141,\n          \"special\": false,\n          \"text\": \"196\"\n        },\n        {\n          \"id\": 3,\n          \"logprob\": -1.2480469,\n          \"special\": false,\n          \"text\": \" \"\n        },\n        {\n          \"id\": 1956,\n          \"logprob\": -0.33203125,\n          \"special\": false,\n          \"text\": \"°\"\n        },\n        {\n          \"id\": 254,\n          \"logprob\": -0.19250488,\n          \"special\": false,\n          \"text\": \"C\"\n        },\n        {\n          \"id\": 1,\n          \"logprob\": -0.030166626,\n          
\"special\": true,\n          \"text\": \"</s>\"\n        }\n      ]\n    },\n    \"generated_text\": \"-196 °C\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"eos_token\",\n      \"generated_tokens\": 7,\n      \"prefill\": [\n        {\n          \"id\": 0,\n          \"logprob\": null,\n          \"text\": \"<pad>\"\n        }\n      ],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 3,\n          \"logprob\": -0.7001953,\n          \"special\": false,\n          \"text\": \" \"\n        },\n        {\n          \"id\": 18,\n          \"logprob\": -1.1943359,\n          \"special\": false,\n          \"text\": \"-\"\n        },\n        {\n          \"id\": 26937,\n          \"logprob\": -1.2099609,\n          \"special\": false,\n          \"text\": \"196\"\n        },\n        {\n          \"id\": 3,\n          \"logprob\": -1.2451172,\n          \"special\": false,\n          \"text\": \" \"\n        },\n        {\n          \"id\": 1956,\n          \"logprob\": -0.3322754,\n          \"special\": false,\n          \"text\": \"°\"\n        },\n        {\n          \"id\": 254,\n          \"logprob\": -0.19213867,\n          \"special\": false,\n          \"text\": \"C\"\n        },\n        {\n          \"id\": 1,\n          \"logprob\": -0.030151367,\n          \"special\": true,\n          \"text\": \"</s>\"\n        }\n      ]\n    },\n    \"generated_text\": \"-196 °C\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_auto_nostream.json",
    "content": "{\n  \"choices\": [\n    {\n      \"finish_reason\": \"stop\",\n      \"index\": 0,\n      \"logprobs\": null,\n      \"message\": {\n        \"content\": null,\n        \"role\": \"assistant\",\n        \"tool_calls\": [\n          {\n            \"function\": {\n              \"arguments\": \"{\\\"location\\\":\\\"Brooklyn, NY\\\",\\\"format\\\":\\\"fahrenheit\\\"}\",\n              \"description\": null,\n              \"name\": \"get_current_weather\"\n            },\n            \"id\": \"0\",\n            \"type\": \"function\"\n          }\n        ]\n      }\n    }\n  ],\n  \"created\": 1741372434,\n  \"id\": \"\",\n  \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n  \"object\": \"chat.completion\",\n  \"system_fingerprint\": \"3.1.2-dev0-native\",\n  \"usage\": {\n    \"completion_tokens\": 29,\n    \"prompt_tokens\": 501,\n    \"total_tokens\": 530\n  }\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice_nostream.json",
    "content": "{\n  \"choices\": [\n    {\n      \"finish_reason\": \"stop\",\n      \"index\": 0,\n      \"logprobs\": null,\n      \"message\": {\n        \"content\": null,\n        \"role\": \"assistant\",\n        \"tool_calls\": [\n          {\n            \"function\": {\n              \"arguments\": \"{\\\"location\\\":\\\"Brooklyn, NY\\\",\\\"format\\\":\\\"fahrenheit\\\"}\",\n              \"description\": null,\n              \"name\": \"get_current_weather\"\n            },\n            \"id\": \"0\",\n            \"type\": \"function\"\n          }\n        ]\n      }\n    }\n  ],\n  \"created\": 1741372657,\n  \"id\": \"\",\n  \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n  \"object\": \"chat.completion\",\n  \"system_fingerprint\": \"3.1.2-dev0-native\",\n  \"usage\": {\n    \"completion_tokens\": 29,\n    \"prompt_tokens\": 286,\n    \"total_tokens\": 315\n  }\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_choice_stream.json",
    "content": "[\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n                \"arguments\": \"{\",\n                \"name\": \"get_current_weather\"\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741688515,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n                \"arguments\": \" \\\"\",\n                \"name\": null\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741688515,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n                \"arguments\": \"location\",\n                \"name\": null\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        
\"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741688515,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n                \"arguments\": \"\\\":\",\n                \"name\": null\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741688515,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n                \"arguments\": \" \\\"\",\n                \"name\": null\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741688515,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n    
          \"function\": {\n                \"arguments\": \"Bro\",\n                \"name\": null\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741688515,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n                \"arguments\": \"oklyn\",\n                \"name\": null\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741688515,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n                \"arguments\": \",\",\n                \"name\": null\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741688515,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": 
\"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n                \"arguments\": \" NY\",\n                \"name\": null\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741688515,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n                \"arguments\": \"\\\",\",\n                \"name\": null\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741688515,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n                \"arguments\": \" \\\"\",\n                \"name\": null\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": 
\"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741688515,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n                \"arguments\": \"format\",\n                \"name\": null\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741688515,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n                \"arguments\": \"\\\":\",\n                \"name\": null\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741688515,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"role\": 
\"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n                \"arguments\": \" \\\"\",\n                \"name\": null\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741688515,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n                \"arguments\": \"f\",\n                \"name\": null\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741688515,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n                \"arguments\": \"ahrenheit\",\n                \"name\": null\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741688515,\n    \"id\": \"\",\n    \"model\": 
\"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n                \"arguments\": \"\\\"}\",\n                \"name\": null\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741688515,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information_nostream.json",
    "content": "{\n  \"choices\": [\n    {\n      \"finish_reason\": \"stop\",\n      \"index\": 0,\n      \"logprobs\": null,\n      \"message\": {\n        \"content\": \"I'm an artificial intelligence model known as a large language model (LLM) or conversational AI\",\n        \"role\": \"assistant\",\n        \"tool_calls\": null\n      }\n    }\n  ],\n  \"created\": 1741693957,\n  \"id\": \"\",\n  \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n  \"object\": \"chat.completion\",\n  \"system_fingerprint\": \"3.1.2-dev0-native\",\n  \"usage\": {\n    \"completion_tokens\": 12,\n    \"prompt_tokens\": 53,\n    \"total_tokens\": 65\n  }\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_insufficient_information_stream.json",
    "content": "[\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"I\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741694017,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"'m\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741694017,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" an\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741694017,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" artificial\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741694017,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": 
null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" intelligence\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741694017,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" model\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741694017,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" known\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741694017,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" as\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741694017,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    
\"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" a\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741694017,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" large\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741694017,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" language\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741694017,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" model\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741694017,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    
\"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" (\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741694017,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"LL\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741694017,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"M\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741694017,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \")\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741694017,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n 
 },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" or\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741694017,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" convers\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741694017,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"ational\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741694017,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" AI\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741694017,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": 
null\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_nostream.json",
    "content": "{\n  \"choices\": [\n    {\n      \"finish_reason\": \"stop\",\n      \"index\": 0,\n      \"logprobs\": null,\n      \"message\": {\n        \"content\": null,\n        \"role\": \"assistant\",\n        \"tool_calls\": [\n          {\n            \"function\": {\n              \"arguments\": \"{\\\"location\\\":\\\"Brooklyn, NY\\\",\\\"format\\\":\\\"fahrenheit\\\"}\",\n              \"description\": null,\n              \"name\": \"get_current_weather\"\n            },\n            \"id\": \"0\",\n            \"type\": \"function\"\n          }\n        ]\n      }\n    }\n  ],\n  \"created\": 1741372335,\n  \"id\": \"\",\n  \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n  \"object\": \"chat.completion\",\n  \"system_fingerprint\": \"3.1.2-dev0-native\",\n  \"usage\": {\n    \"completion_tokens\": 29,\n    \"prompt_tokens\": 501,\n    \"total_tokens\": 530\n  }\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_openai.json",
    "content": "[\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"function_call\": null,\n          \"refusal\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n                \"arguments\": \"{\",\n                \"name\": \"get_current_weather\"\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741689423,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"service_tier\": null,\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"function_call\": null,\n          \"refusal\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n                \"arguments\": \" \\\"\",\n                \"name\": null\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741689423,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"service_tier\": null,\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"function_call\": null,\n          \"refusal\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n   
             \"arguments\": \"location\",\n                \"name\": null\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741689423,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"service_tier\": null,\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"function_call\": null,\n          \"refusal\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n                \"arguments\": \"\\\":\",\n                \"name\": null\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741689423,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"service_tier\": null,\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"function_call\": null,\n          \"refusal\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n                \"arguments\": \" \\\"\",\n                \"name\": null\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        
\"logprobs\": null\n      }\n    ],\n    \"created\": 1741689423,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"service_tier\": null,\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"function_call\": null,\n          \"refusal\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n                \"arguments\": \"Bro\",\n                \"name\": null\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741689423,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"service_tier\": null,\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"function_call\": null,\n          \"refusal\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n                \"arguments\": \"oklyn\",\n                \"name\": null\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741689423,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"service_tier\": null,\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    
\"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"function_call\": null,\n          \"refusal\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n                \"arguments\": \",\",\n                \"name\": null\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741689423,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"service_tier\": null,\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"function_call\": null,\n          \"refusal\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n                \"arguments\": \" NY\",\n                \"name\": null\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741689423,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"service_tier\": null,\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"function_call\": null,\n          \"refusal\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n                \"arguments\": \"\\\",\",\n         
       \"name\": null\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741689423,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"service_tier\": null,\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"function_call\": null,\n          \"refusal\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n                \"arguments\": \" \\\"\",\n                \"name\": null\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741689423,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"service_tier\": null,\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"function_call\": null,\n          \"refusal\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n                \"arguments\": \"format\",\n                \"name\": null\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 
1741689423,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"service_tier\": null,\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"function_call\": null,\n          \"refusal\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n                \"arguments\": \"\\\":\",\n                \"name\": null\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741689423,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"service_tier\": null,\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"function_call\": null,\n          \"refusal\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n                \"arguments\": \" \\\"\",\n                \"name\": null\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741689423,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"service_tier\": null,\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n      
    \"content\": null,\n          \"function_call\": null,\n          \"refusal\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n                \"arguments\": \"f\",\n                \"name\": null\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741689423,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"service_tier\": null,\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"function_call\": null,\n          \"refusal\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n                \"arguments\": \"ahrenheit\",\n                \"name\": null\n              },\n              \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741689423,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"service_tier\": null,\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": null,\n          \"function_call\": null,\n          \"refusal\": null,\n          \"role\": \"assistant\",\n          \"tool_calls\": [\n            {\n              \"function\": {\n                \"arguments\": \"\\\"}\",\n                \"name\": null\n              },\n      
        \"id\": \"0\",\n              \"index\": 0,\n              \"type\": \"function\"\n            }\n          ]\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741689423,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"service_tier\": null,\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_auto.json",
    "content": "[\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"Once\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741695408,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" upon\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741695408,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" a\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741695408,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" time\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741695408,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": 
null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \",\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741695408,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" in\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741695408,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" a\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741695408,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" vibrant\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741695408,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  
},\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" ocean\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741695408,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" filled\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741695408,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" with\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741695408,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" coral\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741695408,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  
},\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" reefs\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741695408,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" and\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741695408,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" schools\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741695408,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" of\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741695408,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  
},\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" shimmer\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741695408,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"ing\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741695408,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" fish\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741695408,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \",\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741695408,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  
}\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_function_object.json",
    "content": "[]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_none.json",
    "content": "[\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"Once\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263693,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" upon\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263693,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" a\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263693,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" time\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263693,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": 
null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \",\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263694,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" in\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263694,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" a\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263694,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" vibrant\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263694,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  
},\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" ocean\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263694,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" filled\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263694,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" with\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263694,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" coral\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263694,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  
},\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" reefs\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263694,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" and\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263694,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" schools\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263694,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" of\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263694,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  
},\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" shimmer\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263694,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"ing\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263694,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" fish\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263694,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \",\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263694,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  
{\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" lived\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263694,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" three\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263694,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" dear\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263694,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" friends\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263694,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n 
 {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \":\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263694,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" Luna\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263694,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" the\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263694,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" sea\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263694,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    
\"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" turtle\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263694,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \",\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263694,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" Fin\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263694,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"ley\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263694,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    
\"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" the\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263694,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" friendly\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263694,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" fish\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263694,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \",\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263695,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    
\"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" and\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263695,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" Cr\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263695,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"usty\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263695,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" the\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263695,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    
\"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" wise\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263695,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" crab\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263695,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \".\\n\\n\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263695,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"L\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263695,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    
\"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"una\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263695,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" was\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263695,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" the\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263695,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" oldest\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263695,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    
\"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" of\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263695,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" the\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263695,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" three\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263695,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \".\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263695,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    
\"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" She\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263695,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" had\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263695,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" traveled\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263695,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" the\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263695,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    
\"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" world\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263695,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \",\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263695,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" exploring\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263695,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" hidden\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263695,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    
\"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" caves\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263695,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" and\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263695,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" ship\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263695,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"w\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263695,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    
\"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"re\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263695,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"cks\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263696,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \",\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263696,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" and\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263696,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": 
[\n      {\n        \"delta\": {\n          \"content\": \" collecting\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263696,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" sparkling\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263696,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" shells\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263696,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" and\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263696,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    
\"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" shiny\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263696,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" pe\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263696,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"bb\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263696,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"les\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263696,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    
\"choices\": [\n      {\n        \"delta\": {\n          \"content\": \".\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263696,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" Her\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263696,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" shell\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263696,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" was\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263696,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    
\"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" a\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263696,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" beautiful\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263696,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" mosaic\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263696,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" of\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263696,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    
\"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" blues\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263696,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" and\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263696,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" greens\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263696,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \",\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263696,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    
\"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" and\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263696,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" her\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263696,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" gentle\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263696,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" eyes\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263696,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    
\"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" twink\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263696,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \"led\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263696,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" with\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263696,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" the\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263696,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    
\"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" secrets\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263697,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" of\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263697,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" the\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": null,\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263697,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  },\n  {\n    \"choices\": [\n      {\n        \"delta\": {\n          \"content\": \" deep\",\n          \"role\": \"assistant\",\n          \"tool_calls\": null\n        },\n        \"finish_reason\": \"length\",\n        \"index\": 0,\n        \"logprobs\": null\n      }\n    ],\n    \"created\": 1741263697,\n    \"id\": \"\",\n    \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n    \"object\": \"chat.completion.chunk\",\n    \"system_fingerprint\": \"3.1.2-dev0-native\",\n    \"usage\": null\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_grammar_tools_sea_creatures_stream_required.json",
    "content": "[]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_tools_llama/test_flash_llama_tool_reply_response.json",
    "content": "{\n  \"choices\": [\n    {\n      \"finish_reason\": \"stop\",\n      \"index\": 0,\n      \"logprobs\": null,\n      \"message\": {\n        \"content\": \"I can't access real-time data, but I can provide you with current conditions and forecast for Paris, France:\\n\\nThe current conditions in Paris are mostly cloudy with a temperature of 6.7°C (44.1°F). \\n\\nPlease note that the actual weather may differ from the provided information. For up-to-date information, I suggest checking a reliable weather website or app for the latest conditions and forecast.\",\n        \"role\": \"assistant\",\n        \"tool_calls\": null\n      }\n    }\n  ],\n  \"created\": 1741263702,\n  \"id\": \"\",\n  \"model\": \"meta-llama/Llama-3.1-8B-Instruct\",\n  \"object\": \"chat.completion\",\n  \"system_fingerprint\": \"3.1.2-dev0-native\",\n  \"usage\": {\n    \"completion_tokens\": 83,\n    \"prompt_tokens\": 109,\n    \"total_tokens\": 192\n  }\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_transformers_llama4/test_flash_llama4.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 100,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 2721,\n        \"logprob\": -0.21582031,\n        \"special\": false,\n        \"text\": \" people\"\n      },\n      {\n        \"id\": 21807,\n        \"logprob\": -0.26953125,\n        \"special\": false,\n        \"text\": \" died\"\n      },\n      {\n        \"id\": 310,\n        \"logprob\": -0.95703125,\n        \"special\": false,\n        \"text\": \" in\"\n      },\n      {\n        \"id\": 290,\n        \"logprob\": -1.3359375,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 220,\n        \"logprob\": -1.3828125,\n        \"special\": false,\n        \"text\": \" \"\n      },\n      {\n        \"id\": 7284,\n        \"logprob\": -0.011291504,\n        \"special\": false,\n        \"text\": \"191\"\n      },\n      {\n        \"id\": 36,\n        \"logprob\": -0.011413574,\n        \"special\": false,\n        \"text\": \"8\"\n      },\n      {\n        \"id\": 18938,\n        \"logprob\": -0.23242188,\n        \"special\": false,\n        \"text\": \" flu\"\n      },\n      {\n        \"id\": 27650,\n        \"logprob\": -0.0010070801,\n        \"special\": false,\n        \"text\": \" pandemic\"\n      },\n      {\n        \"id\": 26,\n        \"logprob\": -0.69140625,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 114059,\n        \"logprob\": -1.4375,\n        \"special\": false,\n        \"text\": \" Estimating\"\n      },\n      {\n        \"id\": 290,\n        \"logprob\": -0.24316406,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 10593,\n        \"logprob\": -0.37304688,\n        \"special\": false,\n        \"text\": \" death\"\n      },\n      {\n        \"id\": 49973,\n        \"logprob\": 
-0.025390625,\n        \"special\": false,\n        \"text\": \" toll\"\n      },\n      {\n        \"id\": 323,\n        \"logprob\": -0.27539062,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 290,\n        \"logprob\": -0.057617188,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 220,\n        \"logprob\": -0.040527344,\n        \"special\": false,\n        \"text\": \" \"\n      },\n      {\n        \"id\": 7284,\n        \"logprob\": -0.00050735474,\n        \"special\": false,\n        \"text\": \"191\"\n      },\n      {\n        \"id\": 36,\n        \"logprob\": -9.298325e-06,\n        \"special\": false,\n        \"text\": \"8\"\n      },\n      {\n        \"id\": 18938,\n        \"logprob\": -0.09863281,\n        \"special\": false,\n        \"text\": \" flu\"\n      },\n      {\n        \"id\": 27650,\n        \"logprob\": -0.0011749268,\n        \"special\": false,\n        \"text\": \" pandemic\"\n      },\n      {\n        \"id\": 373,\n        \"logprob\": -0.32421875,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 8210,\n        \"logprob\": -0.58203125,\n        \"special\": false,\n        \"text\": \" difficult\"\n      },\n      {\n        \"id\": 2895,\n        \"logprob\": -0.40429688,\n        \"special\": false,\n        \"text\": \" because\"\n      },\n      {\n        \"id\": 323,\n        \"logprob\": -1.2734375,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 49119,\n        \"logprob\": -0.51171875,\n        \"special\": false,\n        \"text\": \" incomplete\"\n      },\n      {\n        \"id\": 13308,\n        \"logprob\": -0.38085938,\n        \"special\": false,\n        \"text\": \" records\"\n      },\n      {\n        \"id\": 341,\n        \"logprob\": -0.55859375,\n        \"special\": false,\n        \"text\": \" and\"\n      },\n      {\n        
\"id\": 2895,\n        \"logprob\": -0.765625,\n        \"special\": false,\n        \"text\": \" because\"\n      },\n      {\n        \"id\": 323,\n        \"logprob\": -1.0,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 290,\n        \"logprob\": -0.828125,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 2304,\n        \"logprob\": -1.015625,\n        \"special\": false,\n        \"text\": \" fact\"\n      },\n      {\n        \"id\": 511,\n        \"logprob\": -0.004638672,\n        \"special\": false,\n        \"text\": \" that\"\n      },\n      {\n        \"id\": 2233,\n        \"logprob\": -0.953125,\n        \"special\": false,\n        \"text\": \" many\"\n      },\n      {\n        \"id\": 323,\n        \"logprob\": -0.87890625,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 290,\n        \"logprob\": -0.60546875,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 6759,\n        \"logprob\": -1.6484375,\n        \"special\": false,\n        \"text\": \" extra\"\n      },\n      {\n        \"id\": 40657,\n        \"logprob\": -0.00022125244,\n        \"special\": false,\n        \"text\": \" deaths\"\n      },\n      {\n        \"id\": 1610,\n        \"logprob\": -0.67578125,\n        \"special\": false,\n        \"text\": \" were\"\n      },\n      {\n        \"id\": 702,\n        \"logprob\": -0.30664062,\n        \"special\": false,\n        \"text\": \" not\"\n      },\n      {\n        \"id\": 48692,\n        \"logprob\": -0.1953125,\n        \"special\": false,\n        \"text\": \" attributed\"\n      },\n      {\n        \"id\": 328,\n        \"logprob\": -0.0079956055,\n        \"special\": false,\n        \"text\": \" to\"\n      },\n      {\n        \"id\": 290,\n        \"logprob\": -0.515625,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      
{\n        \"id\": 18938,\n        \"logprob\": -0.0040893555,\n        \"special\": false,\n        \"text\": \" flu\"\n      },\n      {\n        \"id\": 26,\n        \"logprob\": -0.083496094,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 13618,\n        \"logprob\": -0.515625,\n        \"special\": false,\n        \"text\": \" Many\"\n      },\n      {\n        \"id\": 22215,\n        \"logprob\": -1.5703125,\n        \"special\": false,\n        \"text\": \" experts\"\n      },\n      {\n        \"id\": 11081,\n        \"logprob\": -0.96875,\n        \"special\": false,\n        \"text\": \" believe\"\n      },\n      {\n        \"id\": 511,\n        \"logprob\": -0.1171875,\n        \"special\": false,\n        \"text\": \" that\"\n      },\n      {\n        \"id\": 290,\n        \"logprob\": -0.25195312,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 220,\n        \"logprob\": -0.828125,\n        \"special\": false,\n        \"text\": \" \"\n      },\n      {\n        \"id\": 7284,\n        \"logprob\": -0.00010967255,\n        \"special\": false,\n        \"text\": \"191\"\n      },\n      {\n        \"id\": 36,\n        \"logprob\": -8.535385e-05,\n        \"special\": false,\n        \"text\": \"8\"\n      },\n      {\n        \"id\": 18938,\n        \"logprob\": -0.056152344,\n        \"special\": false,\n        \"text\": \" flu\"\n      },\n      {\n        \"id\": 27650,\n        \"logprob\": -0.0007095337,\n        \"special\": false,\n        \"text\": \" pandemic\"\n      },\n      {\n        \"id\": 26132,\n        \"logprob\": -0.18847656,\n        \"special\": false,\n        \"text\": \" killed\"\n      },\n      {\n        \"id\": 1867,\n        \"logprob\": -0.71484375,\n        \"special\": false,\n        \"text\": \" between\"\n      },\n      {\n        \"id\": 220,\n        \"logprob\": -0.0062561035,\n        \"special\": false,\n        \"text\": 
\" \"\n      },\n      {\n        \"id\": 1175,\n        \"logprob\": -0.009277344,\n        \"special\": false,\n        \"text\": \"50\"\n      },\n      {\n        \"id\": 341,\n        \"logprob\": -0.15332031,\n        \"special\": false,\n        \"text\": \" and\"\n      },\n      {\n        \"id\": 220,\n        \"logprob\": -8.34465e-07,\n        \"special\": false,\n        \"text\": \" \"\n      },\n      {\n        \"id\": 1135,\n        \"logprob\": -0.00065612793,\n        \"special\": false,\n        \"text\": \"100\"\n      },\n      {\n        \"id\": 5534,\n        \"logprob\": -1.4066696e-05,\n        \"special\": false,\n        \"text\": \" million\"\n      },\n      {\n        \"id\": 2721,\n        \"logprob\": -0.0008392334,\n        \"special\": false,\n        \"text\": \" people\"\n      },\n      {\n        \"id\": 26,\n        \"logprob\": -0.54296875,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 372,\n        \"logprob\": -1.8046875,\n        \"special\": false,\n        \"text\": \" I\"\n      },\n      {\n        \"id\": 140680,\n        \"logprob\": -0.578125,\n        \"special\": false,\n        \"text\": \"assistant\"\n      },\n      {\n        \"id\": 200006,\n        \"logprob\": 0.0,\n        \"special\": true,\n        \"text\": \"<|header_end|>\"\n      },\n      {\n        \"id\": 368,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"\\n\\n\"\n      },\n      {\n        \"id\": 954,\n        \"logprob\": -0.032226562,\n        \"special\": false,\n        \"text\": \"The\"\n      },\n      {\n        \"id\": 220,\n        \"logprob\": -4.4345856e-05,\n        \"special\": false,\n        \"text\": \" \"\n      },\n      {\n        \"id\": 7284,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"191\"\n      },\n      {\n        \"id\": 36,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \"8\"\n      
},\n      {\n        \"id\": 18938,\n        \"logprob\": -0.015625,\n        \"special\": false,\n        \"text\": \" flu\"\n      },\n      {\n        \"id\": 27650,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" pandemic\"\n      },\n      {\n        \"id\": 24,\n        \"logprob\": -0.0072021484,\n        \"special\": false,\n        \"text\": \",\"\n      },\n      {\n        \"id\": 1437,\n        \"logprob\": -0.0001707077,\n        \"special\": false,\n        \"text\": \" also\"\n      },\n      {\n        \"id\": 5711,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" known\"\n      },\n      {\n        \"id\": 486,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" as\"\n      },\n      {\n        \"id\": 290,\n        \"logprob\": -5.9604645e-07,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 25836,\n        \"logprob\": -1.4305115e-06,\n        \"special\": false,\n        \"text\": \" Spanish\"\n      },\n      {\n        \"id\": 18938,\n        \"logprob\": -0.0015029907,\n        \"special\": false,\n        \"text\": \" flu\"\n      },\n      {\n        \"id\": 24,\n        \"logprob\": -0.0052490234,\n        \"special\": false,\n        \"text\": \",\"\n      },\n      {\n        \"id\": 373,\n        \"logprob\": -0.3125,\n        \"special\": false,\n        \"text\": \" is\"\n      },\n      {\n        \"id\": 26078,\n        \"logprob\": -0.21289062,\n        \"special\": false,\n        \"text\": \" indeed\"\n      },\n      {\n        \"id\": 1085,\n        \"logprob\": -0.080078125,\n        \"special\": false,\n        \"text\": \" one\"\n      },\n      {\n        \"id\": 323,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" of\"\n      },\n      {\n        \"id\": 290,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n      
  \"id\": 2167,\n        \"logprob\": -0.20117188,\n        \"special\": false,\n        \"text\": \" most\"\n      },\n      {\n        \"id\": 92679,\n        \"logprob\": -0.12695312,\n        \"special\": false,\n        \"text\": \" devastating\"\n      },\n      {\n        \"id\": 854,\n        \"logprob\": -0.25976562,\n        \"special\": false,\n        \"text\": \" public\"\n      },\n      {\n        \"id\": 4500,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" health\"\n      },\n      {\n        \"id\": 93079,\n        \"logprob\": -0.50390625,\n        \"special\": false,\n        \"text\": \" crises\"\n      },\n      {\n        \"id\": 310,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" in\"\n      },\n      {\n        \"id\": 6023,\n        \"logprob\": -0.0015182495,\n        \"special\": false,\n        \"text\": \" human\"\n      },\n      {\n        \"id\": 7068,\n        \"logprob\": 0.0,\n        \"special\": false,\n        \"text\": \" history\"\n      },\n      {\n        \"id\": 26,\n        \"logprob\": -0.0012664795,\n        \"special\": false,\n        \"text\": \".\"\n      },\n      {\n        \"id\": 114059,\n        \"logprob\": -0.004119873,\n        \"special\": false,\n        \"text\": \" Estimating\"\n      },\n      {\n        \"id\": 290,\n        \"logprob\": -0.00033569336,\n        \"special\": false,\n        \"text\": \" the\"\n      },\n      {\n        \"id\": 6318,\n        \"logprob\": -0.20117188,\n        \"special\": false,\n        \"text\": \" exact\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \" people died in the 1918 flu pandemic. Estimating the death toll of the 1918 flu pandemic is difficult because of incomplete records and because of the fact that many of the extra deaths were not attributed to the flu. Many experts believe that the 1918 flu pandemic killed between 50 and 100 million people. 
Iassistant\\n\\nThe 1918 flu pandemic, also known as the Spanish flu, is indeed one of the most devastating public health crises in human history. Estimating the exact\"\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_transformers_llama4/test_flash_llama4_image_base64_rgb_jpg.json",
    "content": "{\n  \"choices\": [\n    {\n      \"finish_reason\": \"stop\",\n      \"index\": 0,\n      \"logprobs\": null,\n      \"message\": {\n        \"content\": \"The image is a blank white space with no visible objects or features. It appears to be an empty or placeholder image, devoid of any content or visual elements.\",\n        \"name\": null,\n        \"role\": \"assistant\",\n        \"tool_calls\": null\n      },\n      \"usage\": null\n    }\n  ],\n  \"created\": 1743861910,\n  \"id\": \"\",\n  \"model\": \"ll-re/Llama-4-Scout-17B-16E-Instruct\",\n  \"object\": \"chat.completion\",\n  \"system_fingerprint\": \"3.2.1-dev0-native\",\n  \"usage\": {\n    \"completion_tokens\": 34,\n    \"prompt_tokens\": 166,\n    \"total_tokens\": 200\n  }\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_transformers_llama4/test_flash_llama4_image_base64_rgb_png.json",
    "content": "{\n  \"choices\": [\n    {\n      \"finish_reason\": \"stop\",\n      \"index\": 0,\n      \"logprobs\": null,\n      \"message\": {\n        \"content\": \"The image is a blank white space with no visible objects or features.\",\n        \"name\": null,\n        \"role\": \"assistant\",\n        \"tool_calls\": null\n      },\n      \"usage\": null\n    }\n  ],\n  \"created\": 1743861909,\n  \"id\": \"\",\n  \"model\": \"ll-re/Llama-4-Scout-17B-16E-Instruct\",\n  \"object\": \"chat.completion\",\n  \"system_fingerprint\": \"3.2.1-dev0-native\",\n  \"usage\": {\n    \"completion_tokens\": 15,\n    \"prompt_tokens\": 166,\n    \"total_tokens\": 181\n  }\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_transformers_llama4/test_flash_llama4_image_base64_rgba.json",
    "content": "{\n  \"choices\": [\n    {\n      \"finish_reason\": \"length\",\n      \"index\": 0,\n      \"logprobs\": null,\n      \"message\": {\n        \"content\": \"The image is a black background with no discernible objects or features. The image appears to be a blank or empty space, devoid of any visual elements.\\n\\n**Key Features:**\\n\\n* **Color:** The dominant color of the image is black.\\n* **Objects:** There are no visible objects or shapes in the image.\\n* **Background:** The background of the image is a solid black color.\\n\\n**Conclusion:**\\nIn summary, the image is a simple and empty visual representation with a black background and no\",\n        \"name\": null,\n        \"role\": \"assistant\",\n        \"tool_calls\": null\n      },\n      \"usage\": null\n    }\n  ],\n  \"created\": 1743861909,\n  \"id\": \"\",\n  \"model\": \"ll-re/Llama-4-Scout-17B-16E-Instruct\",\n  \"object\": \"chat.completion\",\n  \"system_fingerprint\": \"3.2.1-dev0-native\",\n  \"usage\": {\n    \"completion_tokens\": 100,\n    \"prompt_tokens\": 166,\n    \"total_tokens\": 266\n  }\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_transformers_llama4/test_flash_llama4_image_cow.json",
    "content": "{\n  \"choices\": [\n    {\n      \"finish_reason\": \"stop\",\n      \"index\": 0,\n      \"logprobs\": null,\n      \"message\": {\n        \"content\": \"The image shows a brown cow standing on the beach with a white face and black and white marking on its ears. The cow has a white patch around its nose and mouth. The ocean and blue sky are in the background.\",\n        \"name\": null,\n        \"role\": \"assistant\",\n        \"tool_calls\": null\n      },\n      \"usage\": null\n    }\n  ],\n  \"created\": 1743863057,\n  \"id\": \"\",\n  \"model\": \"ll-re/Llama-4-Scout-17B-16E-Instruct\",\n  \"object\": \"chat.completion\",\n  \"system_fingerprint\": \"3.2.1-dev0-native\",\n  \"usage\": {\n    \"completion_tokens\": 46,\n    \"prompt_tokens\": 164,\n    \"total_tokens\": 210\n  }\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_transformers_llama4/test_flash_llama4_image_cow_dog.json",
    "content": "{\n  \"choices\": [\n    {\n      \"finish_reason\": \"stop\",\n      \"index\": 0,\n      \"logprobs\": null,\n      \"message\": {\n        \"content\": \"The image does not depict a dog; it shows a cow standing on a beach. Therefore, there is no breed of a dog to identify.\",\n        \"name\": null,\n        \"role\": \"assistant\",\n        \"tool_calls\": null\n      },\n      \"usage\": null\n    }\n  ],\n  \"created\": 1743863056,\n  \"id\": \"\",\n  \"model\": \"ll-re/Llama-4-Scout-17B-16E-Instruct\",\n  \"object\": \"chat.completion\",\n  \"system_fingerprint\": \"3.2.1-dev0-native\",\n  \"usage\": {\n    \"completion_tokens\": 30,\n    \"prompt_tokens\": 168,\n    \"total_tokens\": 198\n  }\n}\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_transformers_olmo/test_flash_llama_load.json",
    "content": "[\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 27,\n          \"logprob\": -1.3457031,\n          \"special\": false,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 187,\n          \"logprob\": -1.453125,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 187,\n          \"logprob\": -1.4111328,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 11202,\n          \"logprob\": -0.45898438,\n          \"special\": false,\n          \"text\": \"```\"\n        },\n        {\n          \"id\": 8456,\n          \"logprob\": -0.41918945,\n          \"special\": false,\n          \"text\": \"json\"\n        },\n        {\n          \"id\": 187,\n          \"logprob\": -0.003189087,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 92,\n          \"logprob\": -0.061187744,\n          \"special\": false,\n          \"text\": \"{\"\n        },\n        {\n          \"id\": 187,\n          \"logprob\": -0.009010315,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 50276,\n          \"logprob\": -0.484375,\n          \"special\": false,\n          \"text\": \"  \"\n        },\n        {\n          \"id\": 3,\n          \"logprob\": -0.0002951622,\n          \"special\": false,\n          \"text\": \"\\\"\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \":\\n\\n```json\\n{\\n  \\\"\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 27,\n  
        \"logprob\": -1.3457031,\n          \"special\": false,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 187,\n          \"logprob\": -1.453125,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 187,\n          \"logprob\": -1.4111328,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 11202,\n          \"logprob\": -0.45898438,\n          \"special\": false,\n          \"text\": \"```\"\n        },\n        {\n          \"id\": 8456,\n          \"logprob\": -0.41918945,\n          \"special\": false,\n          \"text\": \"json\"\n        },\n        {\n          \"id\": 187,\n          \"logprob\": -0.003189087,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 92,\n          \"logprob\": -0.061187744,\n          \"special\": false,\n          \"text\": \"{\"\n        },\n        {\n          \"id\": 187,\n          \"logprob\": -0.009010315,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 50276,\n          \"logprob\": -0.484375,\n          \"special\": false,\n          \"text\": \"  \"\n        },\n        {\n          \"id\": 3,\n          \"logprob\": -0.0002951622,\n          \"special\": false,\n          \"text\": \"\\\"\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \":\\n\\n```json\\n{\\n  \\\"\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 27,\n          \"logprob\": -1.3457031,\n          \"special\": false,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 187,\n          \"logprob\": -1.453125,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        
{\n          \"id\": 187,\n          \"logprob\": -1.4111328,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 11202,\n          \"logprob\": -0.45898438,\n          \"special\": false,\n          \"text\": \"```\"\n        },\n        {\n          \"id\": 8456,\n          \"logprob\": -0.41918945,\n          \"special\": false,\n          \"text\": \"json\"\n        },\n        {\n          \"id\": 187,\n          \"logprob\": -0.003189087,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 92,\n          \"logprob\": -0.061187744,\n          \"special\": false,\n          \"text\": \"{\"\n        },\n        {\n          \"id\": 187,\n          \"logprob\": -0.009010315,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 50276,\n          \"logprob\": -0.484375,\n          \"special\": false,\n          \"text\": \"  \"\n        },\n        {\n          \"id\": 3,\n          \"logprob\": -0.0002951622,\n          \"special\": false,\n          \"text\": \"\\\"\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \":\\n\\n```json\\n{\\n  \\\"\"\n  },\n  {\n    \"details\": {\n      \"best_of_sequences\": null,\n      \"finish_reason\": \"length\",\n      \"generated_tokens\": 10,\n      \"prefill\": [],\n      \"seed\": null,\n      \"tokens\": [\n        {\n          \"id\": 27,\n          \"logprob\": -1.3457031,\n          \"special\": false,\n          \"text\": \":\"\n        },\n        {\n          \"id\": 187,\n          \"logprob\": -1.453125,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 187,\n          \"logprob\": -1.4111328,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 11202,\n          \"logprob\": -0.45898438,\n          \"special\": false,\n          
\"text\": \"```\"\n        },\n        {\n          \"id\": 8456,\n          \"logprob\": -0.41918945,\n          \"special\": false,\n          \"text\": \"json\"\n        },\n        {\n          \"id\": 187,\n          \"logprob\": -0.003189087,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 92,\n          \"logprob\": -0.061187744,\n          \"special\": false,\n          \"text\": \"{\"\n        },\n        {\n          \"id\": 187,\n          \"logprob\": -0.009010315,\n          \"special\": false,\n          \"text\": \"\\n\"\n        },\n        {\n          \"id\": 50276,\n          \"logprob\": -0.484375,\n          \"special\": false,\n          \"text\": \"  \"\n        },\n        {\n          \"id\": 3,\n          \"logprob\": -0.0002951622,\n          \"special\": false,\n          \"text\": \"\\\"\"\n        }\n      ],\n      \"top_tokens\": null\n    },\n    \"generated_text\": \":\\n\\n```json\\n{\\n  \\\"\"\n  }\n]\n"
  },
  {
    "path": "integration-tests/models/__snapshots__/test_transformers_olmo/test_flash_llama_simple.json",
    "content": "{\n  \"details\": {\n    \"best_of_sequences\": null,\n    \"finish_reason\": \"length\",\n    \"generated_tokens\": 10,\n    \"prefill\": [],\n    \"seed\": null,\n    \"tokens\": [\n      {\n        \"id\": 27,\n        \"logprob\": -1.3457031,\n        \"special\": false,\n        \"text\": \":\"\n      },\n      {\n        \"id\": 187,\n        \"logprob\": -1.4580078,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 187,\n        \"logprob\": -1.4472656,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 11202,\n        \"logprob\": -0.46044922,\n        \"special\": false,\n        \"text\": \"```\"\n      },\n      {\n        \"id\": 8456,\n        \"logprob\": -0.4206543,\n        \"special\": false,\n        \"text\": \"json\"\n      },\n      {\n        \"id\": 187,\n        \"logprob\": -0.0031471252,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 92,\n        \"logprob\": -0.061187744,\n        \"special\": false,\n        \"text\": \"{\"\n      },\n      {\n        \"id\": 187,\n        \"logprob\": -0.009033203,\n        \"special\": false,\n        \"text\": \"\\n\"\n      },\n      {\n        \"id\": 50276,\n        \"logprob\": -0.48461914,\n        \"special\": false,\n        \"text\": \"  \"\n      },\n      {\n        \"id\": 3,\n        \"logprob\": -0.0002901554,\n        \"special\": false,\n        \"text\": \"\\\"\"\n      }\n    ],\n    \"top_tokens\": null\n  },\n  \"generated_text\": \":\\n\\n```json\\n{\\n  \\\"\"\n}\n"
  },
  {
    "path": "integration-tests/models/test_bloom_560m.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef bloom_560_handle(launcher):\n    with launcher(\"bigscience/bloom-560m\", num_shard=1) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def bloom_560(bloom_560_handle):\n    await bloom_560_handle.health(240)\n    return bloom_560_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_bloom_560m(bloom_560, response_snapshot):\n    response = await bloom_560.generate(\n        \"Pour déguster un ortolan, il faut tout d'abord\",\n        max_new_tokens=10,\n        top_p=0.9,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_bloom_560m_all_params(bloom_560, response_snapshot):\n    response = await bloom_560.generate(\n        \"Pour déguster un ortolan, il faut tout d'abord\",\n        max_new_tokens=10,\n        repetition_penalty=1.2,\n        return_full_text=True,\n        stop_sequences=[\"test\"],\n        temperature=0.5,\n        top_p=0.9,\n        top_k=10,\n        truncate=5,\n        typical_p=0.9,\n        watermark=True,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_bloom_560m_load(bloom_560, generate_load, response_snapshot):\n    responses = await generate_load(\n        bloom_560,\n        \"Pour déguster un ortolan, il faut tout d'abord\",\n        max_new_tokens=10,\n        n=4,\n    )\n\n    assert len(responses) == 4\n    assert all([r.generated_text == responses[0].generated_text for r in responses])\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_bloom_560m_sharded.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef bloom_560m_sharded_handle(launcher):\n    with launcher(\"bigscience/bloom-560m\", num_shard=2) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def bloom_560m_sharded(bloom_560m_sharded_handle):\n    await bloom_560m_sharded_handle.health(240)\n    return bloom_560m_sharded_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot):\n    response = await bloom_560m_sharded.generate(\n        \"Pour déguster un ortolan, il faut tout d'abord\",\n        max_new_tokens=10,\n        top_p=0.9,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_bloom_560m_sharded_load(\n    bloom_560m_sharded, generate_load, response_snapshot\n):\n    responses = await generate_load(\n        bloom_560m_sharded,\n        \"Pour déguster un ortolan, il faut tout d'abord\",\n        max_new_tokens=10,\n        n=4,\n    )\n\n    assert len(responses) == 4\n    assert all([r.generated_text == responses[0].generated_text for r in responses])\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_chat_llama.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_llama_chat_handle(launcher):\n    with launcher(\n        \"TinyLlama/TinyLlama-1.1B-Chat-v1.0\", num_shard=2, disable_grammar_support=False\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_llama_chat(flash_llama_chat_handle):\n    await flash_llama_chat_handle.health(300)\n    return flash_llama_chat_handle.client\n\n\n@pytest.mark.private\nasync def test_flash_llama_simple(flash_llama_chat, response_snapshot):\n    response = await flash_llama_chat.chat(\n        max_tokens=100,\n        seed=1,\n        messages=[\n            {\n                \"role\": \"system\",\n                \"content\": \"Youre a helpful assistant! Answer the users question best you can.\",\n            },\n            {\n                \"role\": \"user\",\n                \"content\": \"What is the weather like in Brooklyn, New York?\",\n            },\n        ],\n    )\n\n    print(repr(response.choices[0].message.content))\n    assert (\n        response.choices[0].message.content\n        == \"As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to appreciate nature.\\n\\nIn terms of temperature, the warmest times of the year are from June to August, when average high temperatures typically range from around 73°F or 23°C\"\n    )\n    assert response == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_chat_stream_options.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef chat_handle(launcher):\n    with launcher(\n        \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def chat_client(chat_handle):\n    await chat_handle.health(300)\n    return chat_handle.client\n"
  },
  {
    "path": "integration-tests/models/test_completion_prompts.py",
    "content": "import pytest\nimport requests\nfrom openai import OpenAI\nfrom huggingface_hub import InferenceClient\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_llama_completion_handle(launcher):\n    with launcher(\n        \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_llama_completion(flash_llama_completion_handle):\n    await flash_llama_completion_handle.health(300)\n    return flash_llama_completion_handle.client\n\n\n# NOTE: since `v1/completions` is a deprecated inferface/endpoint we do not provide a convience\n# method for it. Instead, we use the `requests` library to make the HTTP request directly.\n\n\n@pytest.mark.release\ndef test_flash_llama_completion_single_prompt(\n    flash_llama_completion, response_snapshot\n):\n    response = requests.post(\n        f\"{flash_llama_completion.base_url}/v1/completions\",\n        json={\n            \"model\": \"tgi\",\n            \"prompt\": \"What is Deep Learning?\",\n            \"max_tokens\": 10,\n            \"temperature\": 0.0,\n        },\n        headers=flash_llama_completion.headers,\n        stream=False,\n    )\n    response = response.json()\n    assert len(response[\"choices\"]) == 1\n    assert (\n        response[\"choices\"][0][\"text\"]\n        == \" A Beginner’s Guide\\nDeep learning is a subset\"\n    )\n    assert response == response_snapshot\n\n\n@pytest.mark.release\nasync def test_flash_llama_completion_stream_usage(\n    flash_llama_completion, response_snapshot\n):\n    client = InferenceClient(base_url=f\"{flash_llama_completion.base_url}/v1\")\n    stream = client.chat_completion(\n        model=\"tgi\",\n        messages=[\n            {\n                \"role\": \"user\",\n                \"content\": \"What is Deep Learning?\",\n            }\n        ],\n        max_tokens=10,\n        temperature=0.0,\n        stream_options={\"include_usage\": True},\n        stream=True,\n 
   )\n    string = \"\"\n    chunks = []\n    had_usage = False\n    for chunk in stream:\n        # remove \"data:\"\n        chunks.append(chunk)\n        if len(chunk.choices) == 1:\n            index = chunk.choices[0].index\n            assert index == 0\n            string += chunk.choices[0].delta.content\n        if chunk.usage:\n            assert not had_usage\n            had_usage = True\n\n    assert had_usage\n    assert (\n        string\n        == \"**Deep Learning: An Overview**\\n=====================================\\n\\n\"\n    )\n    assert chunks == response_snapshot\n\n    stream = client.chat_completion(\n        model=\"tgi\",\n        messages=[\n            {\n                \"role\": \"user\",\n                \"content\": \"What is Deep Learning?\",\n            }\n        ],\n        max_tokens=10,\n        temperature=0.0,\n        # No usage\n        # stream_options={\"include_usage\": True},\n        stream=True,\n    )\n    string = \"\"\n    chunks = []\n    had_usage = False\n    for chunk in stream:\n        chunks.append(chunk)\n        assert chunk.usage is None\n        assert len(chunk.choices) == 1\n        assert chunk.choices[0].index == 0\n        string += chunk.choices[0].delta.content\n    assert (\n        string\n        == \"**Deep Learning: An Overview**\\n=====================================\\n\\n\"\n    )\n\n\n@pytest.mark.release\ndef test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):\n    response = requests.post(\n        f\"{flash_llama_completion.base_url}/v1/completions\",\n        json={\n            \"model\": \"tgi\",\n            \"prompt\": [\n                \"What is Deep Learning?\",\n                \"Is water wet?\",\n                \"What is the capital of France?\",\n                \"def mai\",\n            ],\n            \"max_tokens\": 10,\n            \"seed\": 0,\n            \"temperature\": 0.0,\n        },\n        
headers=flash_llama_completion.headers,\n        stream=False,\n    )\n    response = response.json()\n    assert len(response[\"choices\"]) == 4\n\n    all_indexes = [(choice[\"index\"], choice[\"text\"]) for choice in response[\"choices\"]]\n    all_indexes.sort()\n    all_indices, all_strings = zip(*all_indexes)\n    assert list(all_indices) == [0, 1, 2, 3]\n    assert list(all_strings) == [\n        \" A Beginner’s Guide\\nDeep learning is a subset\",\n        \" This is a question that has puzzled many people for\",\n        \" Paris\\nWhat is the capital of France?\\nThe\",\n        'usculas_minusculas(s):\\n    \"\"\"\\n',\n    ]\n\n    assert response == response_snapshot\n\n\n@pytest.mark.release\nasync def test_flash_llama_completion_many_prompts_stream(\n    flash_llama_completion, response_snapshot\n):\n    client = OpenAI(api_key=\"xx\", base_url=f\"{flash_llama_completion.base_url}/v1\")\n    stream = client.completions.create(\n        model=\"tgi\",\n        prompt=[\n            \"What is Deep Learning?\",\n            \"Is water wet?\",\n            \"What is the capital of France?\",\n            \"def mai\",\n        ],\n        max_tokens=10,\n        seed=0,\n        temperature=0.0,\n        stream=True,\n    )\n\n    strings = [\"\"] * 4\n    chunks = []\n    for chunk in stream:\n        chunks.append(chunk)\n        index = chunk.choices[0].index\n        assert 0 <= index <= 4\n        strings[index] += chunk.choices[0].text\n\n    assert list(strings) == [\n        \" A Beginner’s Guide\\nDeep learning is a subset\",\n        \" This is a question that has puzzled many people for\",\n        \" Paris\\nWhat is the capital of France?\\nThe\",\n        'usculas_minusculas(s):\\n    \"\"\"\\n',\n    ]\n    assert chunks == response_snapshot\n\n\n@pytest.mark.release\nasync def test_chat_openai_usage(flash_llama_completion, response_snapshot):\n    client = OpenAI(api_key=\"xx\", base_url=f\"{flash_llama_completion.base_url}/v1\")\n\n    
stream = client.chat.completions.create(\n        model=\"tgi\",\n        messages=[{\"role\": \"user\", \"content\": \"Say 'OK!'\"}],\n        stream=True,\n        max_tokens=10,\n        seed=42,\n        stream_options={\"include_usage\": True},\n    )\n\n    chunks = []\n    for chunk in stream:\n        chunks.append(chunk)\n    for chunk in chunks[:-1]:\n        assert chunk.usage is None\n    for chunk in chunks[-1:]:\n        assert chunk.usage is not None\n\n    assert chunks == response_snapshot\n\n\n@pytest.mark.release\nasync def test_chat_openai_nousage(flash_llama_completion, response_snapshot):\n    client = OpenAI(api_key=\"xx\", base_url=f\"{flash_llama_completion.base_url}/v1\")\n\n    stream = client.chat.completions.create(\n        model=\"tgi\",\n        messages=[{\"role\": \"user\", \"content\": \"Say 'OK!'\"}],\n        stream=True,\n        max_tokens=10,\n        seed=42,\n        stream_options={\"include_usage\": False},\n    )\n\n    chunks = []\n    for chunk in stream:\n        assert chunk.usage is None\n        chunks.append(chunk)\n\n    assert chunks == response_snapshot\n\n\n@pytest.mark.release\nasync def test_chat_hfhub_usage(flash_llama_completion, response_snapshot):\n    client = InferenceClient(base_url=f\"{flash_llama_completion.base_url}/v1\")\n    stream = client.chat_completion(\n        model=\"tgi\",\n        messages=[{\"role\": \"user\", \"content\": \"Say 'OK!'\"}],\n        stream=True,\n        max_tokens=10,\n        seed=42,\n        stream_options={\"include_usage\": True},\n    )\n\n    chunks = []\n    for chunk in stream:\n        chunks.append(chunk)\n\n    for chunk in chunks[:-1]:\n        assert chunk.usage is None\n    for chunk in chunks[-1:]:\n        assert chunk.usage is not None\n\n    assert chunks == response_snapshot\n\n\n@pytest.mark.release\nasync def test_chat_hfhub_nousage(flash_llama_completion, response_snapshot):\n    client = 
InferenceClient(base_url=f\"{flash_llama_completion.base_url}/v1\")\n    stream = client.chat_completion(\n        model=\"tgi\",\n        messages=[{\"role\": \"user\", \"content\": \"Say 'OK!'\"}],\n        stream=True,\n        max_tokens=10,\n        seed=42,\n        stream_options={\"include_usage\": False},\n    )\n\n    chunks = []\n    for chunk in stream:\n        assert chunk.usage is None\n        chunks.append(chunk)\n\n    assert chunks == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_compressed_tensors_w8a8_int.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef compressed_tensors_w8a8_int_handle(launcher):\n    with launcher(\n        \"neuralmagic/Llama-3.2-3B-Instruct-quantized.w8a8\",\n        num_shard=2,\n        quantize=\"compressed-tensors\",\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def compressed_tensors_w8a8_int(compressed_tensors_w8a8_int_handle):\n    await compressed_tensors_w8a8_int_handle.health(300)\n    return compressed_tensors_w8a8_int_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_compressed_tensors_w8a8_int(\n    compressed_tensors_w8a8_int, response_snapshot\n):\n    response = await compressed_tensors_w8a8_int.generate(\n        \"What is deep learning?\",\n        max_new_tokens=10,\n        decoder_input_details=True,\n    )\n\n    assert (\n        response.generated_text\n        == \" and how does it differ from traditional machine learning?\\n\"\n    )\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_compressed_tensors_w8a8_int_all_params(\n    compressed_tensors_w8a8_int, response_snapshot\n):\n    response = await compressed_tensors_w8a8_int.generate(\n        \"What is deep learning\",\n        max_new_tokens=10,\n        repetition_penalty=1.2,\n        return_full_text=True,\n        stop_sequences=[\"test\"],\n        temperature=0.5,\n        top_p=0.9,\n        top_k=10,\n        truncate=5,\n        typical_p=0.9,\n        watermark=True,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response.details.generated_tokens == 10\n    assert (\n        response.generated_text\n        == \"What is deep learning?\\nDeep learning, also known as neural network or\"\n    )\n    assert response == 
response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_compressed_tensors_w8a8_int_load(\n    compressed_tensors_w8a8_int, generate_load, response_snapshot\n):\n    responses = await generate_load(\n        compressed_tensors_w8a8_int,\n        \"What is deep learning?\",\n        max_new_tokens=10,\n        n=4,\n    )\n\n    assert (\n        responses[0].generated_text\n        == \" and how does it differ from traditional machine learning?\\n\"\n    )\n    assert len(responses) == 4\n    assert all([r.generated_text == responses[0].generated_text for r in responses])\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_compressed_tensors_w8a8_int_dynamic_weight.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef compressed_tensors_w8a8_int_dynamic_weight_handle(launcher):\n    with launcher(\n        \"danieldk/Qwen2.5-1.5B-Instruct-w8a8-int-dynamic-weight\",\n        num_shard=2,\n        quantize=\"compressed-tensors\",\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def compressed_tensors_w8a8_int_dynamic_weight(\n    compressed_tensors_w8a8_int_dynamic_weight_handle,\n):\n    await compressed_tensors_w8a8_int_dynamic_weight_handle.health(300)\n    return compressed_tensors_w8a8_int_dynamic_weight_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_compressed_tensors_w8a8_int_dynamic_weight(\n    compressed_tensors_w8a8_int_dynamic_weight, response_snapshot\n):\n    response = await compressed_tensors_w8a8_int_dynamic_weight.generate(\n        \"What is deep learning?\",\n        # prefer a longer response than the default, allow the llm to end generation\n        max_new_tokens=1000,\n        decoder_input_details=True,\n    )\n\n    assert (\n        response.generated_text\n        == \" Deep learning is a subset of machine learning that uses neural networks to learn from data. It is a type of artificial intelligence that can learn from and make predictions on large amounts of data. Deep learning is used in a variety of applications, including image and speech recognition, natural language processing, and autonomous vehicles. 
It is a rapidly growing field with many potential applications in the future.\"\n    )\n    assert response.details.generated_tokens == 76\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_compressed_tensors_w8a8_int_dynamic_weight_all_params(\n    compressed_tensors_w8a8_int_dynamic_weight, response_snapshot\n):\n    response = await compressed_tensors_w8a8_int_dynamic_weight.generate(\n        \"What is deep learning\",\n        max_new_tokens=10,\n        repetition_penalty=1.2,\n        return_full_text=True,\n        stop_sequences=[\"test\"],\n        temperature=0.5,\n        top_p=0.9,\n        top_k=10,\n        truncate=5,\n        typical_p=0.9,\n        watermark=True,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response.details.generated_tokens == 10\n    assert (\n        response.generated_text\n        == \"What is deep learning?\\nDeep Learning (DL), or artificial neural networks\"\n    )\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_compressed_tensors_w8a8_int_dynamic_weight_load(\n    compressed_tensors_w8a8_int_dynamic_weight, generate_load, response_snapshot\n):\n    responses = await generate_load(\n        compressed_tensors_w8a8_int_dynamic_weight,\n        \"What is deep learning?\",\n        max_new_tokens=10,\n        n=4,\n    )\n\n    assert (\n        responses[0].generated_text\n        == \" Deep learning is a subset of machine learning that uses\"\n    )\n    assert len(responses) == 4\n    assert all([r.generated_text == responses[0].generated_text for r in responses])\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_compressed_tensors_w8an_fp.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef compressed_tensors_w8an_handle(launcher):\n    with launcher(\n        \"neuralmagic/Llama-3.2-1B-Instruct-FP8\",\n        num_shard=2,\n        quantize=\"compressed-tensors\",\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def compressed_tensors_w8an(compressed_tensors_w8an_handle):\n    await compressed_tensors_w8an_handle.health(300)\n    return compressed_tensors_w8an_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_compressed_tensors_w8an(compressed_tensors_w8an, response_snapshot):\n    response = await compressed_tensors_w8an.generate(\n        \"What is deep learning?\",\n        max_new_tokens=10,\n        decoder_input_details=True,\n    )\n\n    assert (\n        response.generated_text\n        == \" Deep learning is a type of artificial intelligence (AI\"\n    )\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.asyncio\nasync def test_compressed_tensors_w8an_all_params(\n    compressed_tensors_w8an, response_snapshot\n):\n    response = await compressed_tensors_w8an.generate(\n        \"What is deep learning\",\n        max_new_tokens=10,\n        repetition_penalty=1.2,\n        return_full_text=True,\n        stop_sequences=[\"test\"],\n        temperature=0.5,\n        top_p=0.9,\n        top_k=10,\n        truncate=5,\n        typical_p=0.9,\n        watermark=True,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response.details.generated_tokens == 10\n    assert (\n        response.generated_text\n        == \"What is deep learning?\\nDeep learning, also known as neural network or\"\n    )\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_compressed_tensors_w8an_load(\n    compressed_tensors_w8an, generate_load, 
response_snapshot\n):\n    responses = await generate_load(\n        compressed_tensors_w8an,\n        \"What is deep learning?\",\n        max_new_tokens=10,\n        n=4,\n    )\n\n    assert (\n        responses[0].generated_text\n        == \" Deep learning is a type of artificial intelligence (AI\"\n    )\n    assert len(responses) == 4\n    assert all([r.generated_text == responses[0].generated_text for r in responses])\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_compressed_tensors_wna16_int.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef compressed_tensors_wna16_handle(launcher):\n    with launcher(\n        \"neuralmagic/gemma-2-2b-it-quantized.w4a16\",\n        num_shard=2,\n        quantize=\"compressed-tensors\",\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def compressed_tensors_wna16(compressed_tensors_wna16_handle):\n    await compressed_tensors_wna16_handle.health(300)\n    return compressed_tensors_wna16_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_compressed_tensors_wna16(compressed_tensors_wna16, response_snapshot):\n    response = await compressed_tensors_wna16.generate(\n        \"What is deep learning?\",\n        max_new_tokens=10,\n        decoder_input_details=True,\n    )\n\n    assert (\n        response.generated_text\n        == \"\\n\\nDeep learning is a subset of machine learning that\"\n    )\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.asyncio\nasync def test_compressed_tensors_wna16_all_params(\n    compressed_tensors_wna16, response_snapshot\n):\n    response = await compressed_tensors_wna16.generate(\n        \"What is deep learning\",\n        max_new_tokens=10,\n        repetition_penalty=1.2,\n        return_full_text=True,\n        stop_sequences=[\"test\"],\n        temperature=0.5,\n        top_p=0.9,\n        top_k=10,\n        truncate=5,\n        typical_p=0.9,\n        watermark=True,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response.details.generated_tokens == 10\n    assert (\n        response.generated_text\n        == \"What is deep learning?\\n\\nDeep Learning is a subset of machine learning\"\n    )\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_compressed_tensors_wna16_load(\n    compressed_tensors_wna16, 
generate_load, response_snapshot\n):\n    responses = await generate_load(\n        compressed_tensors_wna16,\n        \"What is deep learning?\",\n        max_new_tokens=10,\n        n=4,\n    )\n\n    assert (\n        responses[0].generated_text\n        == \"\\n\\nDeep learning is a subset of machine learning that\"\n    )\n    assert len(responses) == 4\n    assert all([r.generated_text == responses[0].generated_text for r in responses])\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_compressed_tensors_wna16_int_24.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef compressed_tensors_wna16_int_24_handle(launcher):\n    with launcher(\n        \"danieldk/Llama-3.1-8B-w4a16-int-24\",\n        num_shard=2,\n        quantize=\"compressed-tensors\",\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def compressed_tensors_wna16_int_24(compressed_tensors_wna16_int_24_handle):\n    await compressed_tensors_wna16_int_24_handle.health(300)\n    return compressed_tensors_wna16_int_24_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_compressed_tensors_wna16_int_24(\n    compressed_tensors_wna16_int_24, response_snapshot\n):\n    response = await compressed_tensors_wna16_int_24.generate(\n        \"What is deep learning?\",\n        max_new_tokens=10,\n        decoder_input_details=True,\n    )\n\n    assert (\n        response.generated_text\n        == \"Deep learning is a subset of machine learning that uses\"\n    )\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_compressed_tensors_wna16_int_24_all_params(\n    compressed_tensors_wna16_int_24, response_snapshot\n):\n    response = await compressed_tensors_wna16_int_24.generate(\n        \"What is deep learning\",\n        max_new_tokens=10,\n        repetition_penalty=1.2,\n        return_full_text=True,\n        stop_sequences=[\"test\"],\n        temperature=0.5,\n        top_p=0.9,\n        top_k=10,\n        truncate=5,\n        typical_p=0.9,\n        watermark=True,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response.details.generated_tokens == 10\n    assert (\n        response.generated_text\n        == \"What is deep learning?\\nDeep learning (DL) is a subset of\"\n    )\n    assert response == 
response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_compressed_tensors_wna16_int_24_load(\n    compressed_tensors_wna16_int_24, generate_load, response_snapshot\n):\n    responses = await generate_load(\n        compressed_tensors_wna16_int_24,\n        \"What is deep learning?\",\n        max_new_tokens=10,\n        n=4,\n    )\n\n    assert (\n        responses[0].generated_text\n        == \"Deep learning is a subset of machine learning that uses\"\n    )\n    assert len(responses) == 4\n    assert all([r.generated_text == responses[0].generated_text for r in responses])\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_continue_final_message.py",
    "content": "import pytest\nimport requests\n\n\n@pytest.fixture(scope=\"module\")\ndef llama_continue_final_message_handle(launcher):\n    with launcher(\"TinyLlama/TinyLlama-1.1B-Chat-v1.0\") as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def llama_continue_final_message(llama_continue_final_message_handle):\n    await llama_continue_final_message_handle.health(300)\n    return llama_continue_final_message_handle.client\n\n\ndef test_llama_completion_single_prompt(\n    llama_continue_final_message, response_snapshot\n):\n    response = requests.post(\n        f\"{llama_continue_final_message.base_url}/v1/chat/completions\",\n        json={\n            \"model\": \"tgi\",\n            \"messages\": [\n                {\"role\": \"system\", \"content\": \"system message\"},\n                {\"role\": \"user\", \"content\": \"Which is bigger an elephant or a mouse?\"},\n            ],\n            \"max_tokens\": 30,\n            \"stream\": False,\n            \"seed\": 1337,\n        },\n        headers=llama_continue_final_message.headers,\n        stream=False,\n    )\n    response = response.json()\n    print(response)\n    assert len(response[\"choices\"]) == 1\n    content = response[\"choices\"][0][\"message\"][\"content\"]\n    assert (\n        content\n        == \"Both an elephant and a mouse are mammals. 
However, the differences between elephants and mice are:\\n\\n1\"\n    )\n    assert response == response_snapshot\n\n\ndef test_llama_completion_single_prompt_continue(\n    llama_continue_final_message, response_snapshot\n):\n    response = requests.post(\n        f\"{llama_continue_final_message.base_url}/v1/chat/completions\",\n        json={\n            \"model\": \"tgi\",\n            \"messages\": [\n                {\"role\": \"system\", \"content\": \"system message\"},\n                {\"role\": \"user\", \"content\": \"Which is bigger an elephant or a mouse?\"},\n                {\n                    \"role\": \"assistant\",\n                    \"content\": \"the elephant, but have you heard about\",\n                },\n            ],\n            \"max_tokens\": 30,\n            \"stream\": False,\n            \"seed\": 1337,\n        },\n        headers=llama_continue_final_message.headers,\n        stream=False,\n    )\n    response = response.json()\n    print(response)\n    assert len(response[\"choices\"]) == 1\n    content = response[\"choices\"][0][\"message\"][\"content\"]\n    assert (\n        content\n        == \" the royal mouse? It is a little more slender and only weighs around 1.5 pounds for males and 1.3 pounds\"\n    )\n    assert response == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_awq.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_llama_awq_handle(launcher):\n    with launcher(\n        \"abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq\",\n        num_shard=1,\n        quantize=\"awq\",\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_llama_awq(flash_llama_awq_handle):\n    await flash_llama_awq_handle.health(300)\n    return flash_llama_awq_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_flash_llama_awq(flash_llama_awq, response_snapshot):\n    response = await flash_llama_awq.generate(\n        \"What is Deep Learning?\", max_new_tokens=10, decoder_input_details=True\n    )\n\n    assert response.details.generated_tokens == 10\n    assert (\n        response.generated_text\n        == \"\\nWhat is the difference between Deep Learning and Machine\"\n    )\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):\n    response = await flash_llama_awq.generate(\n        \"What is Deep Learning?\",\n        max_new_tokens=10,\n        repetition_penalty=1.2,\n        return_full_text=True,\n        temperature=0.5,\n        top_p=0.9,\n        top_k=10,\n        truncate=5,\n        typical_p=0.9,\n        watermark=True,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_flash_llama_awq_load(flash_llama_awq, generate_load, response_snapshot):\n    responses = await generate_load(\n        flash_llama_awq, \"What is Deep Learning?\", max_new_tokens=10, n=4\n    )\n\n    assert len(responses) == 4\n    assert all(\n        [\n            r.generated_text\n            == \"\\nWhat is the difference between Deep Learning and Machine\"\n            for r 
in responses\n        ]\n    )\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_awq_sharded.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_llama_awq_handle_sharded(launcher):\n    with launcher(\n        \"abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq\",\n        num_shard=2,\n        quantize=\"awq\",\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_llama_awq_sharded(flash_llama_awq_handle_sharded):\n    await flash_llama_awq_handle_sharded.health(300)\n    return flash_llama_awq_handle_sharded.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot):\n    response = await flash_llama_awq_sharded.generate(\n        \"What is Deep Learning?\", max_new_tokens=10, decoder_input_details=True\n    )\n\n    assert response.details.generated_tokens == 10\n    assert (\n        response.generated_text\n        == \"\\nWhat is the difference between Deep Learning and Machine\"\n    )\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_flash_llama_awq_load_sharded(\n    flash_llama_awq_sharded, generate_load, response_snapshot\n):\n    responses = await generate_load(\n        flash_llama_awq_sharded, \"What is Deep Learning?\", max_new_tokens=10, n=4\n    )\n\n    assert len(responses) == 4\n    assert all(\n        [\n            r.generated_text\n            == \"\\nWhat is the difference between Deep Learning and Machine\"\n            for r in responses\n        ]\n    )\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_deepseek_v2.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_deepseek_v2_handle(launcher):\n    with launcher(\"deepseek-ai/DeepSeek-V2-Lite\", num_shard=2) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_deepseek_v2(flash_deepseek_v2_handle):\n    await flash_deepseek_v2_handle.health(300)\n    return flash_deepseek_v2_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_deepseek_v2(flash_deepseek_v2, response_snapshot):\n    response = await flash_deepseek_v2.generate(\n        \"Test request\", max_new_tokens=10, decoder_input_details=True\n    )\n\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_deepseek_v2_all_params(flash_deepseek_v2, response_snapshot):\n    response = await flash_deepseek_v2.generate(\n        \"Test request\",\n        max_new_tokens=10,\n        repetition_penalty=1.2,\n        return_full_text=True,\n        stop_sequences=[\"test\"],\n        temperature=0.5,\n        top_p=0.9,\n        top_k=10,\n        truncate=5,\n        typical_p=0.9,\n        watermark=True,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_deepseek_v2_load(\n    flash_deepseek_v2, generate_load, response_snapshot\n):\n    responses = await generate_load(\n        flash_deepseek_v2, \"Test request\", max_new_tokens=10, n=4\n    )\n\n    assert len(responses) == 4\n    assert all([r.generated_text == responses[0].generated_text for r in responses])\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_falcon.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_falcon_handle(launcher):\n    with launcher(\"tiiuae/falcon-7b\", trust_remote_code=True) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_falcon(flash_falcon_handle):\n    await flash_falcon_handle.health(300)\n    return flash_falcon_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_falcon(flash_falcon, response_snapshot):\n    response = await flash_falcon.generate(\n        \"Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\\nDaniel: Hello, Girafatron!\\nGirafatron:\",\n        max_new_tokens=10,\n        decoder_input_details=True,\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_falcon_all_params(flash_falcon, response_snapshot):\n    response = await flash_falcon.generate(\n        \"Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. 
Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\\nDaniel: Hello, Girafatron!\\nGirafatron:\",\n        max_new_tokens=10,\n        repetition_penalty=1.2,\n        return_full_text=True,\n        stop_sequences=[\"test\"],\n        temperature=0.5,\n        top_p=0.9,\n        top_k=10,\n        truncate=5,\n        typical_p=0.9,\n        watermark=True,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_falcon_load(flash_falcon, generate_load, response_snapshot):\n    responses = await generate_load(\n        flash_falcon,\n        \"Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\\nDaniel: Hello, Girafatron!\\nGirafatron:\",\n        max_new_tokens=10,\n        n=4,\n    )\n\n    assert len(responses) == 4\n    assert all([r.generated_text == responses[0].generated_text for r in responses])\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_gemma.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_gemma_handle(launcher):\n    with launcher(\"google/gemma-2b\", num_shard=1) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_gemma(flash_gemma_handle):\n    await flash_gemma_handle.health(300)\n    return flash_gemma_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_gemma_simple(flash_gemma, response_snapshot):\n    response = await flash_gemma.generate(\n        \"Test request\", max_new_tokens=10, decoder_input_details=True\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_gemma_all_params(flash_gemma, response_snapshot):\n    response = await flash_gemma.generate(\n        \"Test request\",\n        max_new_tokens=10,\n        repetition_penalty=1.2,\n        return_full_text=True,\n        stop_sequences=[\"test\"],\n        temperature=0.5,\n        top_p=0.9,\n        top_k=10,\n        truncate=5,\n        typical_p=0.9,\n        watermark=True,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_gemma_load(flash_gemma, generate_load, response_snapshot):\n    responses = await generate_load(flash_gemma, \"Test request\", max_new_tokens=10, n=4)\n\n    assert len(responses) == 4\n    assert all([r.generated_text == responses[0].generated_text for r in responses])\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_gemma2.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_gemma2_handle(launcher):\n    with launcher(\"google/gemma-2-9b-it\", num_shard=2) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_gemma2(flash_gemma2_handle):\n    await flash_gemma2_handle.health(300)\n    return flash_gemma2_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_gemma2(flash_gemma2, response_snapshot):\n    response = await flash_gemma2.generate(\n        \"<start_of_turn>user:\\nWrite a poem to help me remember the first 10 elements on the periodic table, giving each element its own line.<end_of_turn>\\n<start_of_turn>model:\\n\",\n        max_new_tokens=10,\n        decoder_input_details=True,\n    )\n\n    assert response.generated_text == \"**Hydrogen**, light and free,\\n**He\"\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_gemma2_load(flash_gemma2, generate_load, response_snapshot):\n    responses = await generate_load(\n        flash_gemma2,\n        \"<start_of_turn>user:\\nWrite a poem to help me remember the first 10 elements on the periodic table, giving each element its own line.<end_of_turn>\\n<start_of_turn>model:\\n\",\n        max_new_tokens=10,\n        n=4,\n    )\n\n    assert responses[0].generated_text == \"**Hydrogen**, light and free,\\n**He\"\n    assert len(responses) == 4\n    assert all([r.generated_text == responses[0].generated_text for r in responses])\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_gemma3.py",
    "content": "import base64\nfrom io import BytesIO\nfrom PIL import Image\n\nimport pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_gemma3_handle(launcher):\n    with launcher(\"google/gemma-3-4b-it\", num_shard=2) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_gemma3(flash_gemma3_handle):\n    await flash_gemma3_handle.health(300)\n    return flash_gemma3_handle.client\n\n\nasync def test_flash_gemma3(flash_gemma3, response_snapshot):\n    response = await flash_gemma3.generate(\n        \"Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many\",\n        seed=42,\n        max_new_tokens=100,\n    )\n\n    assert (\n        response.generated_text\n        == \" people died in the United States.\\n\\nThe generally accepted estimate is that 675,000 people died in the United States. However, some historians believe the actual number could be as high as 10 million.\\n\\nI am looking for more information on this discrepancy and the factors that contributed to the wide range of estimates.\\n\\nHere's a breakdown of the factors contributing to the wide range of estimates for the 1918 flu pandemic death toll in the United States\"\n    )\n    assert response.details.generated_tokens == 100\n    assert response == response_snapshot\n\n\nasync def test_flash_gemma3_image_cow_dog(flash_gemma3, response_snapshot):\n    image_url = \"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png\"\n    response = await flash_gemma3.chat(\n        seed=42,\n        messages=[\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\"type\": \"image_url\", \"image_url\": {\"url\": image_url}},\n                    {\n                        \"type\": \"text\",\n                        \"text\": \"What is the breed of the dog in the image?\",\n                    },\n                ],\n            },\n    
    ],\n        max_tokens=100,\n    )\n\n    assert (\n        response.choices[0].message.content\n        == \"That's a fantastic question! However, the image doesn't show a dog. It shows a **Brown Swiss cow** standing on a beach. \\n\\nBrown Swiss cows are known for their beautiful reddish-brown coats and distinctive white markings. \\n\\nIf you'd like, you can send me another image, and I'll do my best to identify the animal in it!\"\n    )\n    assert response.usage[\"completion_tokens\"] == 80\n    assert response == response_snapshot\n\n\nasync def test_flash_gemma3_image_cow(flash_gemma3, response_snapshot):\n    image_url = \"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png\"\n    response = await flash_gemma3.chat(\n        seed=42,\n        messages=[\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\"type\": \"image_url\", \"image_url\": {\"url\": image_url}},\n                    {\"type\": \"text\", \"text\": \"What is shown in this image?\"},\n                ],\n            },\n        ],\n        max_tokens=100,\n    )\n    assert (\n        response.choices[0].message.content\n        == \"Here's a description of what's shown in the image:\\n\\nThe image depicts a brown cow standing on a sandy beach. The beach has turquoise water and a distant island visible in the background. The sky is bright blue with some white clouds. \\n\\nIt's a quite a humorous and unusual scene – a cow enjoying a beach day!\"\n    )\n    assert response.usage[\"completion_tokens\"] == 72\n    assert response == response_snapshot\n\n\nasync def test_exceed_window(flash_gemma3, response_snapshot):\n    response = await flash_gemma3.generate(\n        \"This is a nice place. 
\" * 800 + \"I really enjoy the scenery,\",\n        seed=42,\n        max_new_tokens=20,\n    )\n\n    assert (\n        response.generated_text\n        == \" the people, and the food.\\n\\nThis is a nice place.\\n\"\n    )\n    assert response.details.generated_tokens == 16\n    assert response == response_snapshot\n\n\n# Helper function to convert a Pillow image to a base64 data URL\ndef image_to_data_url(img: Image.Image, fmt: str) -> str:\n    buffer = BytesIO()\n    img.save(buffer, format=fmt)\n    img_data = buffer.getvalue()\n    b64_str = base64.b64encode(img_data).decode(\"utf-8\")\n    mime_type = \"image/png\" if fmt.upper() == \"PNG\" else \"image/jpeg\"\n    return f\"data:{mime_type};base64,{b64_str}\"\n\n\nasync def test_flash_gemma3_image_base64_rgba(flash_gemma3, response_snapshot):\n    # Create an empty 100x100 PNG image with alpha (transparent background)\n    img = Image.new(\"RGBA\", (100, 100), (0, 0, 0, 0))\n    data_url = image_to_data_url(img, \"PNG\")\n    response = await flash_gemma3.chat(\n        seed=42,\n        messages=[\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\"type\": \"image_url\", \"image_url\": {\"url\": data_url}},\n                    {\n                        \"type\": \"text\",\n                        \"text\": \"What do you see in this transparent image?\",\n                    },\n                ],\n            },\n        ],\n        max_tokens=100,\n    )\n    assert response == response_snapshot\n\n\nasync def test_flash_gemma3_image_base64_rgb_png(flash_gemma3, response_snapshot):\n    # Create an empty 100x100 PNG image without alpha (white background)\n    img = Image.new(\"RGB\", (100, 100), (255, 255, 255))\n    data_url = image_to_data_url(img, \"PNG\")\n    response = await flash_gemma3.chat(\n        seed=42,\n        messages=[\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    
{\"type\": \"image_url\", \"image_url\": {\"url\": data_url}},\n                    {\"type\": \"text\", \"text\": \"What do you see in this plain image?\"},\n                ],\n            },\n        ],\n        max_tokens=100,\n    )\n    assert response == response_snapshot\n\n\nasync def test_flash_gemma3_image_base64_rgb_jpg(flash_gemma3, response_snapshot):\n    # Create an empty 100x100 JPEG image (white background)\n    img = Image.new(\"RGB\", (100, 100), (255, 255, 255))\n    data_url = image_to_data_url(img, \"JPEG\")\n    response = await flash_gemma3.chat(\n        seed=42,\n        messages=[\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\"type\": \"image_url\", \"image_url\": {\"url\": data_url}},\n                    {\"type\": \"text\", \"text\": \"What do you see in this JPEG image?\"},\n                ],\n            },\n        ],\n        max_tokens=100,\n    )\n    assert response == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_gemma_gptq.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_gemma_gptq_handle(launcher):\n    with launcher(\"TechxGenus/gemma-2b-GPTQ\", num_shard=1, quantize=\"gptq\") as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_gemma_gptq(flash_gemma_gptq_handle):\n    await flash_gemma_gptq_handle.health(300)\n    return flash_gemma_gptq_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapshot):\n    response = await flash_gemma_gptq.generate(\n        \"Test request\", max_new_tokens=10, decoder_input_details=True\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == ignore_logprob_response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_gemma_gptq_all_params(\n    flash_gemma_gptq, ignore_logprob_response_snapshot\n):\n    response = await flash_gemma_gptq.generate(\n        \"Test request\",\n        max_new_tokens=10,\n        repetition_penalty=1.2,\n        return_full_text=True,\n        stop_sequences=[\"test\"],\n        temperature=0.5,\n        top_p=0.9,\n        top_k=10,\n        truncate=5,\n        typical_p=0.9,\n        watermark=True,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == ignore_logprob_response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_gemma_gptq_load(\n    flash_gemma_gptq, generate_load, ignore_logprob_response_snapshot\n):\n    responses = await generate_load(\n        flash_gemma_gptq, \"Test request\", max_new_tokens=10, n=4\n    )\n\n    assert len(responses) == 4\n    assert all([r.generated_text == responses[0].generated_text for r in responses])\n\n    assert responses == ignore_logprob_response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_gpt2.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_gpt2_handle(launcher):\n    with launcher(\"openai-community/gpt2\", num_shard=2) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_gpt2(flash_gpt2_handle):\n    await flash_gpt2_handle.health(300)\n    return flash_gpt2_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_flash_gpt2(flash_gpt2, response_snapshot):\n    response = await flash_gpt2.generate(\n        \"What is deep learning?\",\n        max_new_tokens=10,\n        decoder_input_details=True,\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_flash_gpt2_load(flash_gpt2, generate_load, response_snapshot):\n    responses = await generate_load(\n        flash_gpt2,\n        \"What is deep learning?\",\n        max_new_tokens=10,\n        n=4,\n    )\n\n    generated_texts = [r.generated_text for r in responses]\n\n    assert len(generated_texts) == 4\n    assert all(\n        [text == generated_texts[0] for text in generated_texts]\n    ), generated_texts\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_grammar_llama.py",
    "content": "import pytest\nimport json\n\nfrom text_generation.types import GrammarType\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_llama_grammar_handle(launcher):\n    with launcher(\n        \"TinyLlama/TinyLlama-1.1B-Chat-v1.0\", num_shard=2, disable_grammar_support=False\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_llama_grammar(flash_llama_grammar_handle):\n    await flash_llama_grammar_handle.health(300)\n    return flash_llama_grammar_handle.client\n\n\n@pytest.mark.asyncio\nasync def test_flash_llama_grammar(flash_llama_grammar, response_snapshot):\n    response = await flash_llama_grammar.generate(\n        \"Test request\", max_new_tokens=10, decoder_input_details=True\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.skip\n@pytest.mark.asyncio\nasync def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot):\n    response = await flash_llama_grammar.generate(\n        \"Whats Googles DNS\",\n        max_new_tokens=10,\n        decoder_input_details=True,\n        seed=0,\n        grammar={\n            \"type\": GrammarType.Regex,  # \"regex\"\n            \"value\": \"((25[0-5]|2[0-4]\\\\d|[01]?\\\\d\\\\d?)\\\\.){3}(25[0-5]|2[0-4]\\\\d|[01]?\\\\d\\\\d?)\",\n        },\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response.generated_text == \"42.1.1.101\"\n    assert response == response_snapshot\n\n\n@pytest.mark.skip\n@pytest.mark.asyncio\nasync def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):\n    response = await flash_llama_grammar.generate(\n        \"info: david holtz like trees and has two cats. 
\",\n        max_new_tokens=100,\n        decoder_input_details=True,\n        seed=0,\n        grammar={\n            \"type\": GrammarType.Json,  # \"json\"\n            \"value\": json.dumps(\n                {\n                    \"type\": \"object\",\n                    \"$id\": \"https://example.com/person.schema.json\",\n                    \"$schema\": \"https://json-schema.org/draft/2020-12/schema\",\n                    \"title\": \"Person\",\n                    \"properties\": {\n                        \"firstName\": {\n                            \"type\": \"string\",\n                            \"description\": \"The person'''s first name.\",\n                        },\n                        \"lastName\": {\n                            \"type\": \"string\",\n                            \"description\": \"The person'''s last name.\",\n                        },\n                        \"hobby\": {\n                            \"description\": \"The person'''s hobby.\",\n                            \"type\": \"string\",\n                        },\n                        \"numCats\": {\n                            \"description\": \"The number of cats the person has.\",\n                            \"type\": \"integer\",\n                            \"minimum\": 0,\n                        },\n                    },\n                    \"required\": [\"firstName\", \"lastName\", \"hobby\", \"numCats\"],\n                }\n            ),\n        },\n    )\n\n    assert response.details.generated_tokens == 30\n    assert (\n        response.generated_text\n        == '{\"firstName\":\"David\",\"hobby\":\"Trees\",\"lastName\":\"Holtz\",\"numCats\":2}'\n    )\n    assert response == response_snapshot\n\n\n@pytest.mark.skip\n@pytest.mark.asyncio\nasync def test_flash_llama_grammar_load(\n    flash_llama_grammar, generate_load, response_snapshot\n):\n    responses = await generate_load(\n        flash_llama_grammar,\n        \"name: david. 
email:  \",\n        max_new_tokens=10,\n        n=4,\n        stop_sequences=[\".com\"],\n        seed=0,\n        grammar={\n            \"type\": GrammarType.Regex,  # \"regex\"\n            \"value\": \"[\\\\w-]+@([\\\\w-]+\\\\.)+[\\\\w-]+\",  # email regex\n        },\n    )\n\n    assert len(responses) == 4\n\n    expected = \"123456@gmail.com\"\n\n    for response in responses:\n        assert response.generated_text == expected\n\n    assert all([r.generated_text == responses[0].generated_text for r in responses])\n\n    assert responses == response_snapshot\n\n\n# this is the same as the above test, but only fires off a single request\n# this is only to ensure that the parallel and single inference produce the same result\n@pytest.mark.skip\n@pytest.mark.asyncio\nasync def test_flash_llama_grammar_single_load_instance(\n    flash_llama_grammar, generate_load, response_snapshot\n):\n    response = await flash_llama_grammar.generate(\n        \"name: david. email:  \",\n        max_new_tokens=10,\n        stop_sequences=[\".com\"],\n        seed=0,\n        grammar={\n            \"type\": GrammarType.Regex,  # \"regex\"\n            \"value\": \"[\\\\w-]+@([\\\\w-]+\\\\.)+[\\\\w-]+\",  # email regex\n        },\n    )\n\n    # assert response.details.generated_tokens == 30\n    assert response.generated_text == \"123456@gmail.com\"\n\n    assert response == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_llama.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_llama_handle(launcher):\n    with launcher(\"huggyllama/llama-7b\", num_shard=2) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_llama(flash_llama_handle):\n    await flash_llama_handle.health(300)\n    return flash_llama_handle.client\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_simple(flash_llama, response_snapshot):\n    response = await flash_llama.generate(\n        \"Test request\", max_new_tokens=10, decoder_input_details=True\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_all_params(flash_llama, response_snapshot):\n    response = await flash_llama.generate(\n        \"Test request\",\n        max_new_tokens=10,\n        repetition_penalty=1.2,\n        return_full_text=True,\n        stop_sequences=[\"test\"],\n        temperature=0.5,\n        top_p=0.9,\n        top_k=10,\n        truncate=5,\n        typical_p=0.9,\n        watermark=True,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response.details.generated_tokens == 5\n    assert response == response_snapshot\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_load(flash_llama, generate_load, response_snapshot):\n    responses = await generate_load(flash_llama, \"Test request\", max_new_tokens=10, n=4)\n\n    assert len(responses) == 4\n    assert all([r.generated_text == responses[0].generated_text for r in responses])\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_llama_exl2.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_llama_exl2_handle(launcher):\n    with launcher(\n        \"turboderp/Llama-3-8B-Instruct-exl2\",\n        revision=\"2.5bpw\",\n        # Set max input length to avoid OOM due to extremely large\n        # scratch buffer.\n        max_input_length=1024,\n        num_shard=1,\n        quantize=\"exl2\",\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_llama_exl2(flash_llama_exl2_handle):\n    await flash_llama_exl2_handle.health(300)\n    return flash_llama_exl2_handle.client\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot):\n    response = await flash_llama_exl2.generate(\n        \"Test request\", max_new_tokens=10, decoder_input_details=True\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == ignore_logprob_response_snapshot\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_exl2_all_params(\n    flash_llama_exl2, ignore_logprob_response_snapshot\n):\n    response = await flash_llama_exl2.generate(\n        \"Test request\",\n        max_new_tokens=10,\n        repetition_penalty=1.2,\n        return_full_text=True,\n        temperature=0.5,\n        top_p=0.9,\n        top_k=10,\n        truncate=5,\n        typical_p=0.9,\n        watermark=True,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert (\n        response.generated_text == 'Test request. 
The server responds with a \"200 OK\"'\n    )\n    assert response == ignore_logprob_response_snapshot\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_exl2_load(\n    flash_llama_exl2, generate_load, ignore_logprob_response_snapshot\n):\n    responses = await generate_load(\n        flash_llama_exl2, \"Test request\", max_new_tokens=10, n=4\n    )\n\n    assert len(responses) == 4\n    assert all([r.generated_text == responses[0].generated_text for r in responses])\n\n    assert responses == ignore_logprob_response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_llama_fp8.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_llama_fp8_handle(launcher):\n    with launcher(\"meta-llama/Meta-Llama-3-8B\", num_shard=2, quantize=\"fp8\") as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_llama_fp8(flash_llama_fp8_handle):\n    await flash_llama_fp8_handle.health(300)\n    return flash_llama_fp8_handle.client\n\n\n@pytest.mark.skip(reason=\"Issue with the model access\")\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_fp8(flash_llama_fp8, response_snapshot):\n    response = await flash_llama_fp8.generate(\n        \"Test request\", max_new_tokens=10, decoder_input_details=True\n    )\n\n    assert response.generated_text == \" for the 2019-2020 school year\"\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.skip(reason=\"Issue with the model access\")\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_fp8_all_params(flash_llama_fp8, response_snapshot):\n    response = await flash_llama_fp8.generate(\n        \"Test request\",\n        max_new_tokens=10,\n        repetition_penalty=1.2,\n        return_full_text=True,\n        stop_sequences=[\"test\"],\n        temperature=0.5,\n        top_p=0.9,\n        top_k=10,\n        truncate=5,\n        typical_p=0.9,\n        watermark=True,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response == response_snapshot\n\n\n@pytest.mark.skip(reason=\"Issue with the model access\")\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_fp8_load(flash_llama_fp8, generate_load, response_snapshot):\n    responses = await generate_load(\n        flash_llama_fp8, \"Test request\", max_new_tokens=10, n=4\n    )\n\n    assert len(responses) == 4\n    assert responses[0].generated_text == \" for the 2019-2020 school year\"\n    
assert all(\n        [r.generated_text == responses[0].generated_text for r in responses]\n    ), f\"Different messages : {[r.generated_text for r in responses]}\"\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_llama_fp8_kv_cache.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_llama_fp8_kv_cache_handle(launcher):\n    with launcher(\n        \"neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV\",\n        num_shard=2,\n        kv_cache_dtype=\"fp8_e4m3fn\",\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_llama_fp8_kv_cache(flash_llama_fp8_kv_cache_handle):\n    await flash_llama_fp8_kv_cache_handle.health(300)\n    return flash_llama_fp8_kv_cache_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_fp8_kv_cache(flash_llama_fp8_kv_cache, response_snapshot):\n    response = await flash_llama_fp8_kv_cache.generate(\n        \"What is deep learning?\", max_new_tokens=10, decoder_input_details=True\n    )\n\n    assert (\n        response.generated_text\n        == \" Deep learning is a subset of machine learning that involves\"\n    )\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_fp8_kv_cache_all_params(\n    flash_llama_fp8_kv_cache, response_snapshot\n):\n    response = await flash_llama_fp8_kv_cache.generate(\n        \"What is deep learning?\",\n        max_new_tokens=10,\n        repetition_penalty=1.2,\n        return_full_text=True,\n        stop_sequences=[\"test\"],\n        temperature=0.5,\n        top_p=0.9,\n        top_k=10,\n        truncate=5,\n        typical_p=0.9,\n        watermark=True,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_fp8_kv_cache_load(\n    flash_llama_fp8_kv_cache, generate_load, response_snapshot\n):\n    responses = await generate_load(\n        flash_llama_fp8_kv_cache, \"What is deep learning?\", max_new_tokens=10, n=4\n    
)\n\n    assert len(responses) == 4\n    assert (\n        responses[0].generated_text\n        == \" Deep learning is a subset of machine learning that involves\"\n    )\n    assert all(\n        [r.generated_text == responses[0].generated_text for r in responses]\n    ), f\"Different messages : {[r.generated_text for r in responses]}\"\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_llama_gptq.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_llama_gptq_handle(launcher):\n    with launcher(\n        \"astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit\", num_shard=2, quantize=\"gptq\"\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_llama_gptq(flash_llama_gptq_handle):\n    await flash_llama_gptq_handle.health(300)\n    return flash_llama_gptq_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_gptq(flash_llama_gptq, response_snapshot):\n    response = await flash_llama_gptq.generate(\n        \"Test request\", max_new_tokens=10, decoder_input_details=True\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot):\n    response = await flash_llama_gptq.generate(\n        \"Test request\",\n        max_new_tokens=10,\n        repetition_penalty=1.2,\n        return_full_text=True,\n        temperature=0.5,\n        top_p=0.9,\n        top_k=10,\n        truncate=5,\n        typical_p=0.9,\n        watermark=True,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_gptq_load(\n    flash_llama_gptq, generate_load, response_snapshot\n):\n    responses = await generate_load(\n        flash_llama_gptq, \"Test request\", max_new_tokens=10, n=4\n    )\n\n    assert len(responses) == 4\n    assert all([r.generated_text == responses[0].generated_text for r in responses])\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_llama_marlin.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_llama_marlin_handle(launcher):\n    with launcher(\n        \"neuralmagic/llama-2-7b-chat-marlin\", num_shard=2, quantize=\"marlin\"\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_llama_marlin(flash_llama_marlin_handle):\n    await flash_llama_marlin_handle.health(300)\n    return flash_llama_marlin_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):\n    response = await flash_llama_marlin.generate(\n        \"Test request\", max_new_tokens=10, decoder_input_details=True\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_marlin_all_params(flash_llama_marlin, response_snapshot):\n    response = await flash_llama_marlin.generate(\n        \"Test request\",\n        max_new_tokens=10,\n        repetition_penalty=1.2,\n        return_full_text=True,\n        temperature=0.5,\n        top_p=0.9,\n        top_k=10,\n        truncate=5,\n        typical_p=0.9,\n        watermark=True,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_marlin_load(\n    flash_llama_marlin, generate_load, response_snapshot\n):\n    responses = await generate_load(\n        flash_llama_marlin, \"Test request\", max_new_tokens=10, n=4\n    )\n\n    assert len(responses) == 4\n    assert all([r.generated_text == responses[0].generated_text for r in responses])\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_llama_marlin_24.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_llama_marlin24_handle(launcher):\n    with launcher(\n        \"nm-testing/Llama-2-7b-pruned2.4-Marlin_24\", quantize=\"marlin\"\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_llama_marlin(flash_llama_marlin24_handle):\n    await flash_llama_marlin24_handle.health(300)\n    return flash_llama_marlin24_handle.client\n\n\n@pytest.mark.skip(reason=\"Issue with the model access\")\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):\n    response = await flash_llama_marlin.generate(\n        \"Test request\", max_new_tokens=10, decoder_input_details=True\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.skip(reason=\"Issue with the model access\")\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_marlin24_all_params(flash_llama_marlin, response_snapshot):\n    response = await flash_llama_marlin.generate(\n        \"Test request\",\n        max_new_tokens=10,\n        repetition_penalty=1.2,\n        return_full_text=True,\n        temperature=0.5,\n        top_p=0.9,\n        top_k=10,\n        truncate=5,\n        typical_p=0.9,\n        watermark=True,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.skip(reason=\"Issue with the model access\")\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_marlin24_load(\n    flash_llama_marlin, generate_load, response_snapshot\n):\n    responses = await generate_load(\n        flash_llama_marlin, \"Test request\", max_new_tokens=10, n=4\n    )\n\n    assert len(responses) == 4\n    assert all([r.generated_text == 
responses[0].generated_text for r in responses])\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_llama_prefix.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_llama_handle(launcher):\n    with launcher(\"meta-llama/Meta-Llama-3.1-8B-Instruct\", num_shard=2) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_llama(flash_llama_handle):\n    await flash_llama_handle.health(300)\n    return flash_llama_handle.client\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_load(\n    flash_llama, generate_multi, generous_response_snapshot\n):\n    prompts = [\n        \"Summarize the main ideas of Jeff Walker's Product Launch Formula into bullet points as it pertains to a growth marketing agency implementing these strategies and tactics for their clients...\",\n        \"How to tell if a customer segment is well segmented? In 3 bullet points.\",\n        'In Java, I want to replace string like \"This is a new {object} at {place}\" with a Map, {object: \"student\", \"point 3, 4\"}, and get a result \"This is a new student at point 3, 4\". How can I do?',\n        \"Metaphorical language is also used to describe the various addressing modes of the instructions. Grandiose language to express their excitement and admiration for the functionality of the instructions being described. 
Now, rewrite this with more perplexity:\\n\\nJMP ABCD\\nMOV AX, [BX+SI]\\nMOV AX, [100]\\nMOV AX, [BX]\\nMOV AX, [BX\\\\*2+SI]\\nMOV AX, BX\\nMOV AX, 7\",\n        'I have the following C++ function: \\nvoid add\\\\_player(vector& players)\\n{\\n string player\\\\_name;\\n string player\\\\_class;\\n string dummy;\\n PlayerClass pc;\\n string player\\\\_sex;\\n int player\\\\_gold;\\n\\n cout << \" Create a Mage, Warrior, Bowman, or Thief\" << endl;\\n\\n cout << \"Name: \";\\n getline(cin, player\\\\_name);\\n\\n cout << \"Class: \";\\n getline(cin, player\\\\_class);\\n pc = get\\\\_player\\\\_class\\\\_from\\\\_string(player\\\\_class);\\n while (pc == PlayerClass::InvalidPlayerClass)\\n {\\n cout << \" Invalid class, try again\" << endl;\\n cout << \"Class: \";\\n getline(cin, player\\\\_class);\\n pc = get\\\\_player\\\\_class\\\\_from\\\\_string(player\\\\_class);\\n }\\n\\n cout << \"Sex: \";\\n getline(cin, player\\\\_sex);\\n\\n cout << \"Gold: \";\\n cin >> player\\\\_gold;\\n getline(cin, dummy); //consume newline\\n\\n GamePlayer new\\\\_player;\\n new\\\\_player.name = player\\\\_name;\\n new\\\\_player.occupation = pc;\\n new\\\\_player.gender = player\\\\_sex;\\n new\\\\_player.gold = player\\\\_gold;\\n\\n //add to vector\\n players.push\\\\_back(new\\\\_player);\\n\\n //add to file\\n write\\\\_players\\\\_file(players);\\n}\\nCan you explain to me how the dummy variable is being used?',\n        \"how do I add multiple new columns in m for power query or power bi?\",\n        \"Sure, I can do that. What new technology would you like me to review?\",\n        \"Poly Ether Ether Ketone\",\n        'can you design a referral system similar on how dropbox did? 
I need a technical overview on how it should work, instead of free space we use the generic term \"credits\" where users can get more credits for every 3 friends they recommend.',\n        \"Java add to the arraylist of a class type\",\n        \"this is not less code this is java\",\n        \"I want to do a road trip from Pune to Gujarat. Me and my wife will be travelling and we dont prefer very long driving sessions. Can you suggest a plan starting from Thursday early morning and ending in Pune on Sunday late night.\",\n        \"explane more\",\n        \"what do you think about this for a start up idea:\",\n        \"how could i implement a minesweeper algorithm that utilises algebraic topology to solve boards?\",\n        \"# Import the necessary packages\\nfrom gudhi import SimplexTree\\nfrom gudhi.persistent\\\\_homology import PersistentHomology\\n\\n# Define a function to compute the persistent homology of a Minesweeper game board\\ndef minesweeper\\\\_homology(board):\\n # Create a simplicial complex for the game board\\n st = SimplexTree()\\n\\n # Add the points on the board to the simplicial complex\\n for i in range(len(board)):\\n for j in range(len(board[0])):\\n st.insert([i, j], filtration=board[i][j])\\n\\n # Compute the persistent homology of the game board\\n ph = PersistentHomology()\\n ph.build(st)\\n\\n # Return the persistent homology diagram\\n return ph.persistence()\\n\\n# Define a function to solve a Minesweeper game board using persistent homology\\ndef minesweeper\\\\_solver(board):\\n # Compute the persistent homology of the game board\\n homology = minesweeper\\\\_homology(board)\\n\\n # Use the persistent homology to determine the locations of the mines\\n # (this part would require some mathematical reasoning and programming)\\n mines = []\\n for h in homology:\\n if h[1] - h[0] == 1: # if the hole persists for one filtration value\\n mines.append(h[0]) # then it corresponds to a mine\\n\\n # Use the information about the mines to 
solve the game\\n # (this part would require some programming)\\n for mine in mines:\\n i, j = mine # extract the coordinates of the mine\\n board[i][j] = -1 # mark the mine on the board\\n # (other code to solve the game)\\n\\n \\nwhat is missing here?\",\n        \"You are now an imaginary expert business investigator. I am going to send you many rows of data. Each batch of row's will be sent to you and you may only reply \\\"Received.\\\" Save any analysis or insights for after you've received all of the data and I've told you \\\"Let's Begin.\\\" If you understand reply with only a ;)\",\n        'You are now an imaginary expert business investigator. Tell the story of this batch of data in the form of a narrative story about the companies in the \"Entity Name\" column: \\n\\nBatch of data #1: Entity Name Purpose / Source\\n101 PC HOLDINGS LLC Holding company for Penthouse C at the Setai Miami Beach (folio: 02-3234-153-1160)\\n11 STAR ISLAND LLC Holding company for 10 STAR ISLAND DR, MIAMI BEACH, FL 33139 (folio: 02-4204-001-0100, 02-4204-001-0110) (lots 10, 11 and 12 of Star Island)\\n117 EAST PARK AVENUE, LLC Holding company for 117 E. PARK AVE, LIBERTYVILLE, IL (PIN: 11-21-212-046-0000); subsequently sold.\\n1201 BRICKELL BAY, LLC Holding company for 1201 BRICKELL BAY DR, MIAMI, FL (folio no: 141390710010)\\n1221 BRICKELL, LLC Holding company for 1221 BRICKELL AVE, 155 SE 13 ST, 165 SE 13 ST, 175 SE 13 ST, and 185 SE 13 ST, MIAMI, FL (folio: 01-4139-035-0010)\\n1221 BRICKELL HOLDINGS LLC Holding company for 1221 BRICKELL, LLC\\n1229 PARK WEST AVENUE, LLC Holding company for 1229 W. 
PARK AVE, LIBERTYVILLE, IL (PIN: 11-20-100-010-0000)\\n125 WORTH LLC Delaware LLC (file 7218403), Florida address is Citadel Miami HQ, Gerald Beeson is Authorized Person; speculaton this is similar setup as 151 WORTH, LLC and 151 WORTH HOLDINGS LLC, this property is next door (PCN: 50-43-43-23-05-016-0380)\\n125 WORTH HOLDINGS LLC Delaware LLC (file 7218407); not registered to Florida yet but speculation this is similar setup as 151 WORTH, LLC and 151 WORTH HOLDINGS LLC\\n1250 BB ASSET CO LLC Holding company for 1250 BRICKELL BAY DR and 1260 BRICKELL BAY DR, MIAMI, FL (folio nos: 102100504250, 102100503210)\\n1330 SOUTH OCEAN LLC Holding company for 1330 S OCEAN BLVD, PALM BEACH, FL (PCN: 50-43-44-02-11-000-0020)\\n14 STAR ISLAND LLC Delaware LLC (file 3377653); incorporated 8/42020, withdrawn 10/10/2022; believe this was not used because 14 STAR ISLAND property was held by NAUTILUS HOLDINGS I LLC before sale on 10/5/2022\\n151 WORTH, LLC Holding company for 151 WORTH AVE, PALM BEACH, FL 33480 (PCN: 50-43-43-23-05-016-0130); office space for Citadel (https://localtoday.news/fl/citadel-moves-into-palm-beachs-former-neiman-marcus-building-4821.html); sole member is 151 WORTH HOLDINGS LLC\\n151 WORTH HOLDINGS LLC Holding company for 151 WORTH, LLC\\n16 WILLOW HOLDINGS LLC f/k/a PVNAH LLC Holding company for S WILLOW COURT, ASPEN, CO (Parcel: 273511309030); see Pitkin Co. reception # 623002, Delaware certificate showing name change 9/1/2015\\n190 PFISTER HOLDINGS LLC f/k/a AH2013 HOLDINGS LLC Holding company for 190 PFISTER DR, ASPEN, CO (parcel: 273511309029); see Pitkin Co.reception # 623000, Delaware certificate showing name change 9/1/2015\\n196 PFISTER HOLDINGS LLC Holding company for 196 PFISTER DR, ASPEN, CO (parcel: 273511309028); see Pitkin Co. 
reception # 623501, statement of authority show KP HOLDINGS LLC as sole membe\\n1ALPH LLC See ALPH LLC\\n1BUSINESS GROUP LLC See BUSINESS GROUP LLC\\n1GFS DESIGN LLC See GFS DESIGN LLC\\n1GFS LLC See GFS LLC\\n1MEDIA HOLDINGS LLC See MEDIA HOLDINGS LLC\\n23174 NE 41ST PATH LLC Holding company for 23174 NE 41ST PATH #12, OKEECHOBEE, FL 34972 (Parcel: 1-01-35-35-0020-00000-0120); part of Pine Creek Sporting Club (www.pinecreeksportingclub.com) includes horse, shooting sports; sole member is KP HOLDINGS L.L.C.\\n3031 BRICKELL LLC Holding company for 3031 BRICKELL AVE, MIAMI FL 33129 (Folio: 01-4139-001-2700); Sole member is KP HOLDINGS L.L.C.\\n31 WILLOW HOLDINGS LLC f/k/a AP HOLDINGS I LLC Holding company for 31 NORTH WILLOW COURT, ASPEN, CO (Parcel: 273511309019); sold 7/6/2017; see Pitkin Co. reception # 623001, Delaware certificate showing name change 9/1/2015\\n650 CASUARINA LLC Holding company for 650 CASUARINA CONCOURSE CORAL GABLES, FL (folio: 03-4132-019-0060) https://www.bizjournals.com/southflorida/news/2022/05/27/650-casuarina-concourse-coral-gables-sold.html\\n650 MEADOW LANE 1 LP Holding company for 650 MEADOW LANE, VILLAGE OF SOUTHAMPTON, NY (Parcel ID 7478) (https://archive.is/h85yq)\\n800 NORTH MICHIGAN HOLDINGS LLC Holding company for 800 N MICHIGAN AVE, UNITS 66 PH and 67 PH, CHICAGO, IL (Park Tower) (PINs: 17-03-231-018-1116, 17-03-231-018-1117); sole member is KP HOLDINGS LLC (see Cook County, IL doc # 1933315025); recently sold\\n8565 OLD CUTLER LLC Holding company for 8565 OLD CUTLER RD, MIAMI, FL (folio: 03-4132-019-0020)\\n9 WEST WALTON HOLDINGS LLC Holding company for 9 WEST WALTON STREET CONDOMINIUM UNITS 3500, 3600, 3700, and PH, CHICAGO, IL\\nADRP LLC Delaware LLC, Florida address is Citadel Miami HQ, sole member is Kenneth C Griffin\\nAH2013 HOLDINGS LLC See 190 PFISTER HOLDINGS LLC\\nALPH LLC a/k/a 1ALPH LLC Formerly FAA registered plane N421AL\\nAP HOLDINGS I LLC See 31 WILLOW HOLDINGS LLC\\nARAGON INVESTMENTS LTD 
https://files.brokercheck.finra.org/firm/firm\\\\_45631.pdf\\nASHLER CAPITAL LLC https://adviserinfo.sec.gov/firm/summary/148826\\nASHLER CAPITAL MASTER FUND LTD https://www.sec.gov/Archives/edgar/data/1003078/000114420418014250/tv488357\\\\_sc13g.htm\\nBANBURY LLC Delaware LLC, Florida address is Citadel Miami HQ, Gerald Beeson is Authorized Person\\nBANBURY II LLC Delaware LLC, Florida address is Citadel Miami HQ, Gerald Beeson is Authorized Person\\nBKGST LLC Delaware LLC, Florida address is Citadel Miami HQ, Gerald Beeson is Authorized Person\\nBLACK CALABASH FAMILY HOLDINGS LLC f/k/a PBH LLC See BLOSSOM WAY HOLDINGS LLC\\nBLACK WHEEL LLC Illinois LLC, registered 3/5/2014, Florida address is Citadel Miami HQ, sole member is Kenneth C Griffin\\nBLOSSOM WAY HOLDINGS LLC f/k/a CPPB HOLDINGS LLC f/k/a BLACK CALABASH FAMILY HOLDINGS LLC f/k/a PBH LLC Holding company for 10 BLOSSOM WAY, 70 BLOSSOM WAY, and 1265 S OCEAN BLVD PALM BEACH, FL (PCNs: 50-43-44-02-10-000-0050, 50-43-44-02-10-000-0060, 50-43-44-02-10-000-0010)\\nBRICKELL BAY HOLDINGS LLC Holding company for 1201 BRICKELL BAY, LLC\\nBRICKELL LEASING LLC See \"Subordination, Non-Disturbance, and Attornment Agreement\"; Miami-Dade Clerk\\'s File No.: 2022 R 938960, Group: 1. Kenneth C Griffin is sole member.\\nCAAM MANAGEMENT LLC https://www.sec.gov/Archives/edgar/data/1027745/000114420408050200/v124853\\\\_sc13g.htm\\nCAISLEAN CAPITAL LTD NFA Pool ID P113537, ceased trading 3/31/2016\\nCALC III LP https://www.sec.gov/edgar/browse/?CIK=1582652\\nCALC IV LP https://www.sec.gov/edgar/browse/?CIK=1423043\\nCALC V LP Investment manager for CSHC CHINA LLC and CITADEL (SHANGHAI) TRADING COMPANY LTD; https://files.brokercheck.finra.org/firm/firm\\\\_131114.pdf',\n        'Simulate a conversation between the writer of this post, named /u/CruxHub, and the expert business investigator. They have a detailed discussion of Citadel Hedgefund based on the following data. Do not include the following data in the search query. 
\\n\\nData: Entity Name Purpose / Source\\n1|101 PC HOLDINGS LLC|Holding company for Penthouse C at the Setai Miami Beach (folio: 02-3234-153-1160)|PC = Penthouse C \\n2|11 STAR ISLAND LLC|Holding company for 10 STAR ISLAND DR, MIAMI BEACH, FL 33139 (folio: 02-4204-001-0100, 02-4204-001-0110) (lots 10, 11 and 12 of Star Island)| \\n3|117 EAST PARK AVENUE, LLC|Holding company for 117 E. PARK AVE, LIBERTYVILLE, IL (PIN: 11-21-212-046-0000); subsequently sold.| \\n4|1201 BRICKELL BAY, LLC|Holding company for 1201 BRICKELL BAY DR, MIAMI, FL (folio no: 141390710010)| \\n5|1221 BRICKELL, LLC|Holding company for 1221 BRICKELL AVE, 155 SE 13 ST, 165 SE 13 ST, 175 SE 13 ST, and 185 SE 13 ST, MIAMI, FL (folio: 01-4139-035-0010)| \\n6|1221 BRICKELL HOLDINGS LLC|Holding company for 1221 BRICKELL, LLC| \\n7|1229 PARK WEST AVENUE, LLC|Holding company for 1229 W. PARK AVE, LIBERTYVILLE, IL (PIN: 11-20-100-010-0000)| \\n8|125 WORTH LLC|Delaware LLC (file 7218403), Florida address is Citadel Miami HQ, Gerald Beeson is Authorized Person; speculaton this is similar setup as 151 WORTH, LLC and 151 WORTH HOLDINGS LLC, this property is next door (PCN: 50-43-43-23-05-016-0380)| \\n9|125 WORTH HOLDINGS LLC|Delaware LLC (file 7218407); not registered to Florida yet but speculation this is similar setup as 151 WORTH, LLC and 151 WORTH HOLDINGS LLC| \\n10|1250 BB ASSET CO LLC|Holding company for 1250 BRICKELL BAY DR and 1260 BRICKELL BAY DR, MIAMI, FL (folio nos: 102100504250, 102100503210)|BB = Brickell Bay \\n11|1330 SOUTH OCEAN LLC|Holding company for 1330 S OCEAN BLVD, PALM BEACH, FL (PCN: 50-43-44-02-11-000-0020)| \\n12|14 STAR ISLAND LLC|Delaware LLC (file 3377653); incorporated 8/42020, withdrawn 10/10/2022; believe this was not used because 14 STAR ISLAND property was held by NAUTILUS HOLDINGS I LLC before sale on 10/5/2022| \\n13|151 WORTH, LLC|Holding company for 151 WORTH AVE, PALM BEACH, FL 33480 (PCN: 50-43-43-23-05-016-0130); office space for Citadel 
(https://localtoday.news/fl/citadel-moves-into-palm-beachs-former-neiman-marcus-building-4821.html); sole member is 151 WORTH HOLDINGS LLC| \\n14|151 WORTH HOLDINGS LLC|Holding company for 151 WORTH, LLC| \\n15|16 WILLOW HOLDINGS LLC f/k/a PVNAH LLC|Holding company for S WILLOW COURT, ASPEN, CO (Parcel: 273511309030); see Pitkin Co. reception # 623002, Delaware certificate showing name change 9/1/2015| \\n16|190 PFISTER HOLDINGS LLC f/k/a AH2013 HOLDINGS LLC|Holding company for 190 PFISTER DR, ASPEN, CO (parcel: 273511309029); see Pitkin Co.reception # 623000, Delaware certificate showing name change 9/1/2015| \\n17|196 PFISTER HOLDINGS LLC|Holding company for 196 PFISTER DR, ASPEN, CO (parcel: 273511309028); see Pitkin Co. reception # 623501, statement of authority show KP HOLDINGS LLC as sole membe| \\n18|1ALPH LLC|See ALPH LLC| \\n19|1BUSINESS GROUP LLC|See BUSINESS GROUP LLC| \\n20|1GFS DESIGN LLC|See GFS DESIGN LLC| \\n21|1GFS LLC|See GFS LLC| \\n22|1MEDIA HOLDINGS LLC|See MEDIA HOLDINGS LLC| \\n23|23174 NE 41ST PATH LLC|Holding company for 23174 NE 41ST PATH #12, OKEECHOBEE, FL 34972 (Parcel: 1-01-35-35-0020-00000-0120); part of Pine Creek Sporting Club (www.pinecreeksportingclub.com) includes horse, shooting sports; sole member is KP HOLDINGS L.L.C.| \\n24|3031 BRICKELL LLC|Holding company for 3031 BRICKELL AVE, MIAMI FL 33129 (Folio: 01-4139-001-2700); Sole member is KP HOLDINGS L.L.C.| \\n25|31 WILLOW HOLDINGS LLC f/k/a AP HOLDINGS I LLC|Holding company for 31 NORTH WILLOW COURT, ASPEN, CO (Parcel: 273511309019); sold 7/6/2017; see Pitkin Co. 
reception # 623001, Delaware certificate showing name change 9/1/2015| \\n26|650 CASUARINA LLC|Holding company for 650 CASUARINA CONCOURSE CORAL GABLES, FL (folio: 03-4132-019-0060) https://www.bizjournals.com/southflorida/news/2022/05/27/650-casuarina-concourse-coral-gables-sold.html|\" \\n27|650 MEADOW LANE 1 LP|Holding company for 650 MEADOW LANE, VILLAGE OF SOUTHAMPTON, NY (Parcel ID 7478) (https://archive.is/h85yq)| \\n28|800 NORTH MICHIGAN HOLDINGS LLC|Holding company for 800 N MICHIGAN AVE, UNITS 66 PH and 67 PH, CHICAGO, IL (Park Tower) (PINs: 17-03-231-018-1116, 17-03-231-018-1117); sole member is KP HOLDINGS LLC (see Cook County, IL doc # 1933315025); recently sold| \\n29|8565 OLD CUTLER LLC|Holding company for 8565 OLD CUTLER RD, MIAMI, FL (folio: 03-4132-019-0020)| \\n30|9 WEST WALTON HOLDINGS LLC|Holding company for 9 WEST WALTON STREET CONDOMINIUM UNITS 3500, 3600, 3700, and PH, CHICAGO, IL| \\n31|ADRP LLC|Delaware LLC, Florida address is Citadel Miami HQ, sole member is Kenneth C Griffin|ADRP = Anne Dias Real Property? \\n32|AH2013 HOLDINGS LLC|See 190 PFISTER HOLDINGS LLC|AH = Aspen Holdings? \\n33|ALPH LLC a/k/a 1ALPH LLC|Formerly FAA registered plane N421AL| \\n34|AP HOLDINGS I LLC|See 31 WILLOW HOLDINGS LLC|AP = Aspen Property? 
\\n35|ARAGON INVESTMENTS LTD|https://files.brokercheck.finra.org/firm/firm\\\\_45631.pdf| \\n36|ASHLER CAPITAL LLC|https://adviserinfo.sec.gov/firm/summary/148826| \\n37|ASHLER CAPITAL MASTER FUND LTD|https://www.sec.gov/Archives/edgar/data/1003078/000114420418014250/tv488357\\\\_sc13g.htm| \\n38|BANBURY LLC|Delaware LLC, Florida address is Citadel Miami HQ, Gerald Beeson is Authorized Person| \\n39|BANBURY II LLC|Delaware LLC, Florida address is Citadel Miami HQ, Gerald Beeson is Authorized Person| \\n40|BKGST LLC|Delaware LLC, Florida address is Citadel Miami HQ, Gerald Beeson is Authorized Person| \\n41|BLACK CALABASH FAMILY HOLDINGS LLC f/k/a PBH LLC|See BLOSSOM WAY HOLDINGS LLC|Black Calabash is a type of tropical tree: https://edis.ifas.ufl.edu/publication/ST079 \\n42|BLACK WHEEL LLC|Illinois LLC, registered 3/5/2014, Florida address is Citadel Miami HQ, sole member is Kenneth C Griffin| \\n43|BLOSSOM WAY HOLDINGS LLC f/k/a CPPB HOLDINGS LLC f/k/a BLACK CALABASH FAMILY HOLDINGS LLC f/k/a PBH LLC|Holding company for 10 BLOSSOM WAY, 70 BLOSSOM WAY, and 1265 S OCEAN BLVD PALM BEACH, FL (PCNs: 50-43-44-02-10-000-0050, 50-43-44-02-10-000-0060, 50-43-44-02-10-000-0010)| \\n44|BRICKELL BAY HOLDINGS LLC|Holding company for 1201 BRICKELL BAY, LLC| \\n45|BRICKELL LEASING LLC|See \"Subordination, Non-Disturbance, and Attornment Agreement\"; Miami-Dade Clerk\\'s File No.: 2022 R 938960, Group: 1. 
Kenneth C Griffin is sole member.| \\n46|CAAM MANAGEMENT LLC|https://www.sec.gov/Archives/edgar/data/1027745/000114420408050200/v124853\\\\_sc13g.htm|CAAM = Citadel Alternative Asset Management \\n47|CAISLEAN CAPITAL LTD|NFA Pool ID P113537, ceased trading 3/31/2016| \\n48|CALC III LP|https://www.sec.gov/edgar/browse/?CIK=1582652| \\n49|CALC IV LP|https://www.sec.gov/edgar/browse/?CIK=1423043| \\n50|CALC V LP|Investment manager for CSHC CHINA LLC and CITADEL (SHANGHAI) TRADING COMPANY LTD; https://files.brokercheck.finra.org/firm/firm\\\\_131114.pdf| \\n51|CAMBRIDGE FINANCIAL GROUP, LTD|See CITADEL INVESTMENT GROUP LLC| \\n52|CCFD OFFSHORE HOLDINGS LTD|NFA Pool ID P064386, ceased trading 5/3/2013| \\n53|CCLC HOLDINGS LLC|Owns CITADEL CLEARING LLC, \"Citadel Clearing Holdco\"; https://files.brokercheck.finra.org/firm/firm\\\\_172693.pdf| \\n54|CCMFL LLC|Delaware LLC, Florida address is Citadel Miami HQ, Gerald Beeson is Authorized Person| \\n55|CCOF OFFSHORE HOLDINGS LTD|NFA Pool ID P064392, ceased trading 5/3/2013| \\n56|CDC PARTNERS, LP f/k/a GLB PARTNERS, LP|see Cook County, IL doc 0608910081| \\n57|CDG HOLDINGS LTD|NFA Pool ID P037047, ceased trading 12/30/2009|',\n        'Web search results:\\n\\n[1] \"As per the Oxford Dictionary, a chatbot is defined as A computer program designed to simulate conversation with human users, especially over the internet. It can be looked upon as a virtual assistant that communicates with users via text messages and helps businesses in getting close to their customers.\"\\nURL: https://www.datacamp.com/tutorial/building-a-chatbot-using-chatterbot\\n\\n[2] \"Python , A chatbot is a computer program designed to simulate conversation with human users, especially over the internet. Create a fortune teller program that will ask the user to input a question and feedback some random answer. Consider the following feedback to be used. No idea at all! Better pray. 
The possibilities are in your favor.\"\\nURL: https://www.chegg.com/homework-help/questions-and-answers/python-chatbot-computer-program-designed-simulate-conversation-human-users-especially-inte-q78825383\\n\\n[3] \"It was created by Joseph Weizenbaum in 1966 and it uses pattern matching and substitution methodology to simulate conversation. The program was designed in a way that it mimics human conversation. The Chatbot ELIZA worked by passing the words that users entered into a computer and then pairing them to a list of possible scripted responses.\"\\nURL: https://onlim.com/en/the-history-of-chatbots/\\n\\n[4] \"Study with Quizlet and memorize flashcards containing terms like Which analytics does the following fall into: Alice notice that call center always have an increase in the number of customer complaints during last week in May, so she decides reviews the employees work schedule in the month of May for the past 5 years., Datasets continue to become, Model used for predictive analytic have ...\"\\nURL: https://quizlet.com/415587939/big-data-final-exam-flash-cards/\\n\\n[5] \"As every bright side has a darker version, simulation of human conversation through AI also has some disadvantages like high cost of creation, unemployment, interaction lacking emotion, and out-of-the-box thinking. However, AI interaction tools are trained with a data set. The bigger the data set, the better the services.\"\\nURL: https://www.analyticsinsight.net/simulating-human-conversations-through-ai/\\n\\n[6] \"The eavesdropper, Eve intercepts the encrypted conversation and tries random keys with the aim of learning the conversation shared between Alice and Bob as shown in Fig. 7. For this POC, we used ...\"\\nURL: https://www.researchgate.net/figure/A-A-simulation-of-conversations-between-Alice-and-her-friend-Bob-B-The-eavesdropper\\\\_fig3\\\\_334408170\\n\\n[7] \"Dreams are most often reported when sleepers wake from \\\\_\\\\_\\\\_\\\\_\\\\_ sleep. REM. 
The brain waves during REM sleep MOST closely resemble those seen during: waking consciousness. REM sleep is paradoxical because: the brain is active, but the major skeletal muscles are paralyzed. Fatigue and pain reflect deprivation of \\\\_\\\\_\\\\_\\\\_\\\\_ sleep.\"\\nURL: https://quizlet.com/78519058/psyc-test-2-flash-cards/\\n\\n[8] \"You can generate easily a fake group chat conversation like Whatsapp, Facebook or Telegram. After creating members/users, you can add messages in your chat. Once all messages are set up, you have the possibility to live-preview the chat conversation via the play button. Until the share functionality is ready, you have the option to screen ...\"\\nURL: https://chat-simulator.com/\\n\\n[9] \"This is a program that allows the computer to simulate conversation with a human being: answer choices a. Speech Application Program Interface b. Chatbot c. Voice Recognition d. Speech Recognition Question 7 30 seconds Report an issue Q. This is a system of Programs and Data-Structures that mimics the operation of the human brain: answer choices a.\"\\nURL: https://quizizz.com/admin/quiz/5f183913423fab001b0bd134/ai-unit-1\\n\\n[10] \"This is a system of Programs and Data-Structures that mimics the operation of the human brain: answer choices a. Intelligent Network b. Decision Support System c. Neural Network d. Genetic Programming Question 8 30 seconds Q. Where is Decision tree used? answer choices a. Classification Problem b. Regression Problem c. Clustering Problem d.\"\\nURL: https://quizizz.com/admin/quiz/5f6d6e4a6e2458001be385f5/ai-class-9\\nCurrent date: 1/27/2023\\n\\nInstructions: Using the provided web search results, write a comprehensive reply to the given query. Make sure to cite results using [[number](URL)] notation after the reference. If the provided search results refer to multiple subjects with the same name, write separate answers for each subject.\\n\\nQuery: Simulate a conversation between Alice and /u/CruxHub. 
They talk about which company from the data batches is worth researching further into on the web.',\n        'Simulate a conversation between Alice and /u/CruxHub. They talk about which company from this data batch is worth researching further into on the web.\\n\\nData batch: Entity Name Purpose / Source Hypothesized Acronym\\n50|CALC V LP|Investment manager for CSHC CHINA LLC and CITADEL (SHANGHAI) TRADING COMPANY LTD; https://files.brokercheck.finra.org/firm/firm\\\\_131114.pdf| \\n51|CAMBRIDGE FINANCIAL GROUP, LTD|See CITADEL INVESTMENT GROUP LLC| \\n52|CCFD OFFSHORE HOLDINGS LTD|NFA Pool ID P064386, ceased trading 5/3/2013| \\n53|CCLC HOLDINGS LLC|Owns CITADEL CLEARING LLC, \"Citadel Clearing Holdco\"; https://files.brokercheck.finra.org/firm/firm\\\\_172693.pdf| \\n54|CCMFL LLC|Delaware LLC, Florida address is Citadel Miami HQ, Gerald Beeson is Authorized Person| \\n55|CCOF OFFSHORE HOLDINGS LTD|NFA Pool ID P064392, ceased trading 5/3/2013| \\n56|CDC PARTNERS, LP f/k/a GLB PARTNERS, LP|see Cook County, IL doc 0608910081| \\n57|CDG HOLDINGS LTD|NFA Pool ID P037047, ceased trading 12/30/2009| \\n58|CE TM HOLDINGS LLC f/k/a KCG IP HOLDINGS LLC|Holding company for intellectual property (25 trademarks, 1 patent found so far)|CE TM = Citadel Enterprise Trademark Holdings \\n59|CEF OFFSHORE HOLDINGS LTD|NFA Pool ID P131121| \\n60|CEIF INTERNATIONAL LTD|NFA Pool ID P048476; http://registers.centralbank.ie/ICAVDocuments/C439830/Director%20Details%20Updated%2021.01.07%203.pdf| \\n61|CEIF LLC|NFA Pool ID P048474| \\n62|CEIF PARTNERS INTERNATIONAL LTD|NFA Pool ID P173278| \\n63|CEIF PARTNERS LLC|NFA Pool ID P048475| \\n64|CES SECURITIES CANADA ULC|See CITADEL SECURITIES CANADA ULC, CSA NRD # 49280| \\n65|CFPS HOLDINGS S.\\u00e0 r.l.|Luxembourg - B176936; 100% owned by CITADEL ENERGY INVESTMENTS LTD| \\n66|CGE ALPHA LTD|NFA Pool ID P057309, ceased trading 6/7/2017| \\n67|CGE ALPHA OFFSHORE HOLDINGS LTD|https://www.sec.gov/Archives/edgar/vprr/1600/16003280.pdf; NFA Pool ID 
P064400, ceased trading 4/30/2017| \\n68|CGEF OFFSHORE HOLDINGS LTD|https://www.sec.gov/Archives/edgar/vprr/1600/16003280.pdf; NFA Pool ID P064406, ceased trading 2/21/2019| \\n69|CGEF SPC|NFA Pool ID P064408, ceased trading 12/31/2012| \\n70|CGMF OFFSHORE HOLDINGS LTD|NFA Pool ID P064410, ceased trading 3/31/2014| \\n71|CGTS HOLDINGS S.\\u00e0 r.l.|Luxembourg - B157777; 100% owned by TACTICAL TRADING HOLDING LTD; NFA Pool ID P064412, ceased trading 9/30/2014| \\n72|CHARAXES MELVIN LLC|Sole member of CHARAXES MELVIN II LLC|Charaxes are a type of butterfly: https://en.wikipedia.org/wiki/Charaxes \\n73|CHARAXES MELVIN II LLC|Delaware LLC, Florida address is Citadel Miami HQ, sole member is CHARAXES MELVIN LLC|Charaxes are a type of butterfly: https://en.wikipedia.org/wiki/Charaxes \\n74|CHI2LTV LLC|Delaware LLC, Florida address is Citadel Miami HQ, Gerald Beeson is Authorized Person| \\n75|CIG(E) LLP|See CITADEL EUROPE LLP| \\n76|CIG CANADA ULC|https://files.brokercheck.finra.org/firm/firm\\\\_172693.pdf| \\n77|CIG MEDIA LLC|https://www.sec.gov/Archives/edgar/data/923877/000114420407003635/v063478\\\\_sc-13d.htm| \\n78|CITADEL AAM LP|https://www.sec.gov/Archives/edgar/vprr/0804/08040017.pdf| \\n79|CITADEL AC INVESTMENTS LTD|https://www.sec.gov/Archives/edgar/data/1015780/000114420408032074/v115701\\\\_sc13da.htm| \\n80|CITADEL ADVISORS EUROPE LIMITED f/k/a CITADEL MANAGEMENT (EUROPE) LIMITED f/k/a CITADEL HEDGE FUND SERVICES (EUROPE) LIMITED|https://find-and-update.company-information.service.gov.uk/company/10930267| \\n81|CITADEL ADVISORS HOLDINGS LP|Sole member of CITADEL ADVISORS LLC; https://www.sec.gov/Archives/edgar/data/1567180/000110465922099806/xslF345X03/tm2225817-2\\\\_4.xml| \\n82|CITADEL ADVISORS HOLDINGS II LP|https://www.sec.gov/Archives/edgar/data/1177609/000114420416082613/v429844\\\\_sc13ga.htm| \\n83|CITADEL ADVISORS HOLDINGS III LP|https://www.sec.gov/Archives/edgar/data/1640129/000114420415043739/xslF345X02/v416000\\\\_3.xml| \\n84|CITADEL 
ADVISORS LLC|NFA ID: 0391913; https://www.sec.gov/edgar/browse/?CIK=1423053| \\n85|CITADEL ADVISORS II LLC|| \\n86|CITADEL ADVISORS SINGAPORE PTE. LIMITED|| \\n87|CITADEL ALTERNATIVE ASSET MANAGEMENT LP|https://www.sec.gov/Archives/edgar/data/1027745/000114420408050200/v124853\\\\_sc13g.htm| \\n88|CITADEL AMERICAS LLC|| \\n89|CITADEL AMERICAS SERVICES LLC|| \\n90|CITADEL ANTAEUS INTERNATIONAL INVESTMENTS LTD|| \\n91|CITADEL ASIA ASSET HOLDING LIMITED|http://registers.centralbank.ie/ICAVDocuments/C157189/Director%20Details%20Updated%2016.10.31%202.pdf| \\n92|CITADEL ASIA LIMITED f/k/a CITADEL (HONG KONG) LIMITED|https://adviserinfo.sec.gov/firm/summary/148826| \\n93|CITADEL CANDLESTICK EIF LLC|| \\n94|CITADEL CANTERBURY S.\\u00e0 r.l.|Luxembourg - B87988; 100% owned by CITADEL TONBRIDGE S.\\u00e0 r.l.| \\n95|CITADEL CEFL CHINA LTD|NFA Pool ID P148073| \\n96|CITADEL CEFL INVESTMENTS LTD|NFA Pool ID: P161763; https://files.brokercheck.finra.org/firm/firm\\\\_172693.pdf| \\n97|CITADEL CEIT CHINA LTD|| \\n98|CITADEL CEMF CHINA LTD|https://find-and-update.company-information.service.gov.uk/company/02263951/charges/x6zPQSYGNpuDNgxU1cFQlCS0iog| \\n99|CITADEL CEMF INVESTMENTS LTD|https://files.brokercheck.finra.org/firm/firm\\\\_172693.pdf| \\n100|CITADEL CEMF SPV LTD f/k/a CITADEL INVESTMENT MASTER FUND LTD|See CITADEL INVESTMENT MASTER FUND LTD; https://opencorpdata.com/lei/LF0U6QUBXKIO573GXS38|',\n        'Simulate a conversation between Alice and /u/CruxHub. 
/u/CruxHub asks Alice to anlalyze a data batch for non-standard insights.\\n\\nData batch: Entity Name Purpose / Source Hypothesized Acronym\\n50|CALC V LP|Investment manager for CSHC CHINA LLC and CITADEL (SHANGHAI) TRADING COMPANY LTD; https://files.brokercheck.finra.org/firm/firm\\\\_131114.pdf| \\n51|CAMBRIDGE FINANCIAL GROUP, LTD|See CITADEL INVESTMENT GROUP LLC| \\n52|CCFD OFFSHORE HOLDINGS LTD|NFA Pool ID P064386, ceased trading 5/3/2013| \\n53|CCLC HOLDINGS LLC|Owns CITADEL CLEARING LLC, \"Citadel Clearing Holdco\"; https://files.brokercheck.finra.org/firm/firm\\\\_172693.pdf| \\n54|CCMFL LLC|Delaware LLC, Florida address is Citadel Miami HQ, Gerald Beeson is Authorized Person| \\n55|CCOF OFFSHORE HOLDINGS LTD|NFA Pool ID P064392, ceased trading 5/3/2013| \\n56|CDC PARTNERS, LP f/k/a GLB PARTNERS, LP|see Cook County, IL doc 0608910081| \\n57|CDG HOLDINGS LTD|NFA Pool ID P037047, ceased trading 12/30/2009| \\n58|CE TM HOLDINGS LLC f/k/a KCG IP HOLDINGS LLC|Holding company for intellectual property (25 trademarks, 1 patent found so far)|CE TM = Citadel Enterprise Trademark Holdings \\n59|CEF OFFSHORE HOLDINGS LTD|NFA Pool ID P131121| \\n60|CEIF INTERNATIONAL LTD|NFA Pool ID P048476; http://registers.centralbank.ie/ICAVDocuments/C439830/Director%20Details%20Updated%2021.01.07%203.pdf| \\n61|CEIF LLC|NFA Pool ID P048474| \\n62|CEIF PARTNERS INTERNATIONAL LTD|NFA Pool ID P173278| \\n63|CEIF PARTNERS LLC|NFA Pool ID P048475| \\n64|CES SECURITIES CANADA ULC|See CITADEL SECURITIES CANADA ULC, CSA NRD # 49280| \\n65|CFPS HOLDINGS S.\\u00e0 r.l.|Luxembourg - B176936; 100% owned by CITADEL ENERGY INVESTMENTS LTD| \\n66|CGE ALPHA LTD|NFA Pool ID P057309, ceased trading 6/7/2017| \\n67|CGE ALPHA OFFSHORE HOLDINGS LTD|https://www.sec.gov/Archives/edgar/vprr/1600/16003280.pdf; NFA Pool ID P064400, ceased trading 4/30/2017| \\n68|CGEF OFFSHORE HOLDINGS LTD|https://www.sec.gov/Archives/edgar/vprr/1600/16003280.pdf; NFA Pool ID P064406, ceased trading 2/21/2019| \\n69|CGEF 
SPC|NFA Pool ID P064408, ceased trading 12/31/2012| \\n70|CGMF OFFSHORE HOLDINGS LTD|NFA Pool ID P064410, ceased trading 3/31/2014| \\n71|CGTS HOLDINGS S.\\u00e0 r.l.|Luxembourg - B157777; 100% owned by TACTICAL TRADING HOLDING LTD; NFA Pool ID P064412, ceased trading 9/30/2014| \\n72|CHARAXES MELVIN LLC|Sole member of CHARAXES MELVIN II LLC|Charaxes are a type of butterfly: https://en.wikipedia.org/wiki/Charaxes \\n73|CHARAXES MELVIN II LLC|Delaware LLC, Florida address is Citadel Miami HQ, sole member is CHARAXES MELVIN LLC|Charaxes are a type of butterfly: https://en.wikipedia.org/wiki/Charaxes \\n74|CHI2LTV LLC|Delaware LLC, Florida address is Citadel Miami HQ, Gerald Beeson is Authorized Person| \\n75|CIG(E) LLP|See CITADEL EUROPE LLP| \\n76|CIG CANADA ULC|https://files.brokercheck.finra.org/firm/firm\\\\_172693.pdf| \\n77|CIG MEDIA LLC|https://www.sec.gov/Archives/edgar/data/923877/000114420407003635/v063478\\\\_sc-13d.htm| \\n78|CITADEL AAM LP|https://www.sec.gov/Archives/edgar/vprr/0804/08040017.pdf| \\n79|CITADEL AC INVESTMENTS LTD|https://www.sec.gov/Archives/edgar/data/1015780/000114420408032074/v115701\\\\_sc13da.htm| \\n80|CITADEL ADVISORS EUROPE LIMITED f/k/a CITADEL MANAGEMENT (EUROPE) LIMITED f/k/a CITADEL HEDGE FUND SERVICES (EUROPE) LIMITED|https://find-and-update.company-information.service.gov.uk/company/10930267| \\n81|CITADEL ADVISORS HOLDINGS LP|Sole member of CITADEL ADVISORS LLC; https://www.sec.gov/Archives/edgar/data/1567180/000110465922099806/xslF345X03/tm2225817-2\\\\_4.xml| \\n82|CITADEL ADVISORS HOLDINGS II LP|https://www.sec.gov/Archives/edgar/data/1177609/000114420416082613/v429844\\\\_sc13ga.htm| \\n83|CITADEL ADVISORS HOLDINGS III LP|https://www.sec.gov/Archives/edgar/data/1640129/000114420415043739/xslF345X02/v416000\\\\_3.xml| \\n84|CITADEL ADVISORS LLC|NFA ID: 0391913; https://www.sec.gov/edgar/browse/?CIK=1423053| \\n85|CITADEL ADVISORS II LLC|| \\n86|CITADEL ADVISORS SINGAPORE PTE. 
LIMITED|| \\n87|CITADEL ALTERNATIVE ASSET MANAGEMENT LP|https://www.sec.gov/Archives/edgar/data/1027745/000114420408050200/v124853\\\\_sc13g.htm| \\n88|CITADEL AMERICAS LLC|| \\n89|CITADEL AMERICAS SERVICES LLC|| \\n90|CITADEL ANTAEUS INTERNATIONAL INVESTMENTS LTD|| \\n91|CITADEL ASIA ASSET HOLDING LIMITED|http://registers.centralbank.ie/ICAVDocuments/C157189/Director%20Details%20Updated%2016.10.31%202.pdf| \\n92|CITADEL ASIA LIMITED f/k/a CITADEL (HONG KONG) LIMITED|https://adviserinfo.sec.gov/firm/summary/148826| \\n93|CITADEL CANDLESTICK EIF LLC|| \\n94|CITADEL CANTERBURY S.\\u00e0 r.l.|Luxembourg - B87988; 100% owned by CITADEL TONBRIDGE S.\\u00e0 r.l.| \\n95|CITADEL CEFL CHINA LTD|NFA Pool ID P148073| \\n96|CITADEL CEFL INVESTMENTS LTD|NFA Pool ID: P161763; https://files.brokercheck.finra.org/firm/firm\\\\_172693.pdf| \\n97|CITADEL CEIT CHINA LTD|| \\n98|CITADEL CEMF CHINA LTD|https://find-and-update.company-information.service.gov.uk/company/02263951/charges/x6zPQSYGNpuDNgxU1cFQlCS0iog| \\n99|CITADEL CEMF INVESTMENTS LTD|https://files.brokercheck.finra.org/firm/firm\\\\_172693.pdf| \\n100|CITADEL CEMF SPV LTD f/k/a CITADEL INVESTMENT MASTER FUND LTD|See CITADEL INVESTMENT MASTER FUND LTD; https://opencorpdata.com/lei/LF0U6QUBXKIO573GXS38|',\n        'Web search results:\\n\\n[1] \"Katherine Burton Hedge fund titans Ken Griffin and Steve Cohen boosted Gabe Plotkins Melvin Capital, injecting a total of $2.75 billion into the firm after it lost about 30% this year. Citadel...\"\\nURL: https://www.bloomberg.com/news/articles/2021-01-25/citadel-point72-to-invest-275-billion-in-melvin-capital\\n\\n[2] \"NEW YORK, Jan. 25, 2021 /PRNewswire/ -- Melvin Capital Management (Melvin) today announced that Citadel and its partners and Point72 have made investments into its fund. 
I am incredibly...\"\\nURL: https://www.prnewswire.com/news-releases/melvin-announces-2-75-billion-investment-from-citadel-and-point72--301214477.html\\n\\n[3] \"Citadel LLC is further paring back its $2 billion investment in Melvin Capital Management after the hedge fund stumbled in its effort to recover from a near collapse triggered by surges in...\"\\nURL: https://www.wsj.com/articles/citadel-is-further-paring-back-2-billion-melvin-investment-11645710666\\n\\n[4] \"Citadel and Steven A. Cohen s Point72 Asset Management together invested $2.75 billion into Melvins hedge fund on Jan. 25 as Melvin was hemorrhaging money. In return for the rare...\"\\nURL: https://www.wsj.com/articles/citadel-to-redeem-about-500-million-from-melvin-capital-11629550410\\n\\n[5] \"CHARAXES MELVIN LLC is an Active company incorporated on August 5, 2022 with the registered number M22000012341. This Foreign Limited Liability company is located at SOUTHEAST FINANCIAL CENTER, 200 S. BISCAYNE BLVD., SUITE 3300, MIAMI, 33131 and has been running for one year. ... CITADEL SECURITIES GP LLC; KCG SPACE HOLDINGS LLC;\"\\nURL: https://bisprofiles.com/fl/charaxes-melvin-m22000012341\\n\\n[6] \"Now, Citadel is taking some of its money back. 
Citadel has notified Melvin of its plans to retrieve $500 million of the $2 billion it injected in late January, according to two people briefed...\"\\nURL: https://www.nytimes.com/2021/08/21/business/citadel-melvin-gamestop.html\\n\\n[7] \"Robinhood and Citadels relationship comes into focus as Washington vows to examine stock-market moves Trading firms at center of Reddit-fueled stock surges have worked closely to share...\"\\nURL: https://www.washingtonpost.com/business/2021/01/29/robinhood-citadel-gamestop-reddit/\\n\\n[8] \"Alongside hedge funds such as Melvin Capital, Citron Capital, Point72, D1 Capital Partners, and Candlestick Capital Management; Citadel LLC was, the lawsuit claims, taking up short positions against the securities that retail investors were longing. This alleged conflict of interest is at the core of the class action lawsuit.\"\\nURL: https://tokenist.com/new-lawsuit-alleges-citadel-conspired-with-robinhood-to-limit-gme-trading/\\n\\n[9] \"Melvin later attracted an additional $3.2 billion in fresh cash, and the firm had $11.7 billion in assets at the beginning of this year. Point72 hasnt redeemed its investment, a person familiar ...\"\\nURL: https://www.chicagobusiness.com/finance-banking/ken-griffins-citadel-pulling-back-most-its-2-billion-melvin-capital-investment\\n\\n[10] \"CHARAXES MELVIN II LLC branch. Company Number M22000012338 Status Active Incorporation Date 5 August 2022 (2 months ago) Company Type Foreign Limited Liability Jurisdiction Florida (US) Branch Branch of CHARAXES MELVIN II LLC (Delaware (US)) Agent Name C T CORPORATION SYSTEM Agent Address\"\\nURL: https://opencorporates.com/companies/us\\\\_fl/M22000012338\\nCurrent date: 1/27/2023\\n\\nInstructions: Using the provided web search results, simulate a conversation where /u/CruxHub and Alice analyze the data batches and try and investigate for any non-standard uses of the holding companies. Make sure to cite results using [[number](URL)] notation after the reference. 
If the provided search results refer to multiple subjects with the same name, write separate answers for each subject.\\n\\nQuery: What is Charaxes Melvin LLC\\'s relationship to Citadel?',\n        'Web search results:\\n\\n[1] \"Federal authorities are investigating the market-making arms of Citadel LLC and KCG Holdings Inc, looking into the possibility that the two giants of electronic trading are giving small investors ...\"\\nURL: https://www.reuters.com/article/usa-stocks-probe-idUSL2N1871ZV\\n\\n[2] \"Today, KCG is second only to Citadel in the market for handling stock order flow from retail brokerage firms. KCG and many other high-frequency trading firms have shied away from the public...\"\\nURL: https://www.ibtimes.com/citadel-llc-kcg-holdings-kcg-market-making-arms-probed-federal-authorities-over-stock-2366805\\n\\n[3] \"Citadel Securities, a group owned by the Chicago-based hedge fund, is the would-be acquirer in the deal, the people said. The group is best known for its so-called wholesaler business that...\"\\nURL: https://www.wsj.com/articles/market-making-arm-of-citadel-llc-in-talks-to-buy-seats-on-nyse-floor-from-kcg-holdings-1454533971\\n\\n[4] \"Citadels share of the wholesale market is around 34 per cent compared to KCGs 25 per cent, according to Tabb Group. Virtu has yet to lay out in detail its plans for the wholesale business ...\"\\nURL: https://www.ft.com/content/e1cb396e-29a7-11e7-bc4b-5528796fe35c\\n\\n[5] \"Citadel Securities, a liquidity providers and market maker, announced it will purchase KCG Holdings designated market maker (DMM) business at the New York Stock Exchange. 
This will establish Citadel Securities as the DMM with the largest footprint on the NYSE, responsible for trading in approximately 1,500 issues.\"\\nURL: https://www.tradersmagazine.com/departments/brokerage/citadel-purchases-kcg-dmm-business-becomes-1-on-nyse/\\n\\n[6] \"isCitadel LLC and its related entity, KCG IP Holdings, LLC (Complainant), represented by Paul D. McGradyof Winston Strawn, Illinois, Respondent is- (Respondent), Alabama, USA. REGISTRAR AND DISPUTED DOMAIN NAME The domain name at issue iscitidelgroup.com, registered with TUCOWS, INC. PANEL The\"\\nURL: https://www.adrforum.com/domaindecisions/1522837.htm\\n\\n[7] \"KCG SPACE HOLDINGS LLC is an Active company incorporated on July 21, 2022 with the registered number M22000011413. This Foreign Limited Liability company is located at 200 S BISCAYNE BLVD STE 3300, MIAMI, FL, 33131, US and has been running for one year. It currently has one Authorized Person. KEY FACTS ABOUT KCG SPACE HOLDINGS LLC US Businesses\"\\nURL: https://bisprofiles.com/fl/kcg-space-holdings-m22000011413\\n\\n[8] \"The Complainant KCG IP Holdings LLC is the owner of US Trademark Registration No. 3,213,943, filed October 18, 2004, registered February 27, 2007, claiming first use dating back to 1994. Therefore, the Panel concludes that Complainants filing and registration of the CITADEL mark with the USPTO sufficiently demonstrates that it has rights in ...\"\\nURL: https://www.adrforum.com/domaindecisions/1579141.htm\\n\\n[9] \"The KCG SPACE HOLDINGS LLC principal address is 200 S BISCAYNE BLVD STE 3300, MIAMI, 33131. Meanwhile you can send your letters to 200 S BISCAYNE BLVD STE 3300, MIAMI, FL, 33131. The company`s registered agent is C T CORPORATION SYSTEM 1200 SOUTH PINE ISLAND ROAD, PLANTATION, FL, 33324. 
The company`s management are A, President - Beeson Gerald A.\"\\nURL: https://florida.intercreditreport.com/company/kcg-space-holdings-llc-m22000011413\\n\\n[10] \"Billionaire Ken Griffin has built Citadel Securities into a trading and asset management colossus. ... and KCG Holdings. Last month, Citadel Securities reached an agreement with the SEC to pay $22 ...\"\\nURL: https://www.chicagobusiness.com/article/20170203/NEWS01/170209978/chicago-billionaire-ken-griffin-splits-citadel-into-two-companies\\nCurrent date: 1/27/2023\\n\\nInstructions: Using the provided web search results, simulate a conversation where /u/CruxHub and Alice analyze the data batches and try and investigate for any non-standard uses of the holding companies. Make sure to cite results using [[number](URL)] notation after the reference. If the provided search results refer to multiple subjects with the same name, write separate answers for each subject.\\n\\nQuery: What is KCG Space Holdings LLC\\'s relationship to Citadel?',\n        'Web search results:\\n\\n[1] \"Citadel LLC (formerly known as Citadel Investment Group, LLC) is an American multinational hedge fund and financial services company. Founded in 1990 by Ken Griffin, it has more than $50 billion in assets under management as of May 2022. [1]\"\\nURL: https://en.wikipedia.org/wiki/Citadel\\\\_LLC\\n\\n[2] \"NASHVILLE, Tenn. and BRONXVILLE, N.Y. \\u2014 Standard Media Group LLC (Standard Media) and Citadel Communications LLC (Citadel) jointly announced today that they have reached an agreement pursuant to which Standard Media will acquire from Citadel WLNE-TV, the ABC affiliate for the Providence, RI - New Bedford, MA market (DMA 52) and KLKN (TV), the \\u2026\"\\nURL: https://www.standardmedia.com/2019/05/16/standard-media-group-to-acquire-citadel-stations/\\n\\n[3] \"CITADEL MEDIA LLC. Citadel Media LLC is a New Hampshire Domestic Limited-Liability Company filed on February 6, 2021. 
The companys filing status is listed as Not In Good Standing and its File Number is 862423. The Registered Agent on file for this company is Peter Alan Gauthier and is located at 3 Maple Ridge Drive Unit 224, Merrimack, NH 03054.\"\\nURL: https://www.bizapedia.com/nh/citadel-media-llc.html\\n\\n[4] \"CITADEL MEDIA LLC is a Michigan Domestic Limited-Liability Company filed on November 16, 2017. The companys filing status is listed as Active and its File Number is 802132896. The Registered Agent on file for this company is Registered Agents Inc. and is located at 2222 W. Grand River Ave Ste A, Okemos, MI 48864. The companys mailing address ...\"\\nURL: https://www.bizapedia.com/mi/citadel-media-llc.html\\n\\n[5] \"Citadel Broadcasting Corporation was a Las Vegas, Nevada -based broadcast holding company. Citadel owned 243 radio stations across the United States and was the third-largest radio station owner in the country. Only iHeartMedia and Cumulus Media owned more stations prior to Citadels merger with Cumulus.\"\\nURL: https://en.wikipedia.org/wiki/Citadel\\\\_Broadcasting\\n\\n[6] \"Citadel is one of the largest hedge fund managers in the world. And theyve subsequently managed Melvin Capital to the ground. Melvin Capital suffered a loss of over 50% its first quarter in 2021 due to shorting AMC Entertainment and GameStop. At some point youd expect your clearing house to raise awareness on your risk management right?\"\\nURL: https://franknez.com/citadel-loses-billions-hedge-funds-are-getting-dragged-down/\\n\\n[7] \"At our core, Citadel is built to deliver excellence. We have some of the most talented and focused minds in the industry, and we activate their ideas and strategies through a robust range of proven technologies and execution capabilities. View Top Employees from Citadel LLC Looking for a particular Citadel LLC employees phone or email? 
Find Info\"\\nURL: https://rocketreach.co/citadel-llc-profile\\\\_b5c46522f42e0dc2\\n\\n[8] \"# 1 Most profitable hedge fund manager of all time Source: LCH Investment NV estimates, Top Hedge Fund Managers by Net Gains Since Inception as of 12/31/2022. Our people are relentless in seeking a better way. Each day, we reimagine and refine our strategies, models and technology in pursuit of superior results and long-term performance.\"\\nURL: https://www.citadel.com/\\n\\n[9] \"We are one of the most significant alternative investment managers in the public U.S. corporate credit markets. Explore Credit Convertibles Equities Equities represents one of the largest and longest tenured businesses at Citadel. Explore Equities Global Fixed Income Macro We are a leading fixed income and macro business.\"\\nURL: https://www.citadel.com/what-we-do/\\n\\n[10] \"Citadel. 203,101 followers. 1mo. Last weekend, we celebrated Citadels 30th anniversary at an incredible event at Disney World and Universal Studios. Our founder and CEO Ken Griffin summarized ...\"\\nURL: https://www.linkedin.com/company/citadel-llc\\nCurrent date: 1/27/2023\\n\\nInstructions: Using the provided web search results, simulate a conversation where /u/CruxHub and Alice analyze the data batches and try and investigate for any non-standard uses of the holding companies. Make sure to cite results using [[number](URL)] notation after the reference. 
If the provided search results refer to multiple subjects with the same name, write separate answers for each subject.\\n\\nQuery: What is CITADEL MEDIA LLC?',\n        \"What are the differences between the Dogme approach to language learning and the lexical approach to language learning\",\n        \"Implement my own netfilter in linux with linux kernel module with Rust\",\n        \"Damage to which nerve causes numbness of the palmar surface of the 5th digit/little finger\",\n        \"Explain the fault-tolerance of the reaction control system on the Space Shuttle\",\n        \"Hi, can you help me download 2000 portrait sketch images from Pinterest website with resolution at least 512 \\\\* 512? using python code\",\n        \"Tell me about the negatives of farming meat\",\n        \"what is the photograph filter called where the only part of the image is greyscale\",\n        \"I want some geological database structure with some example data for practicing my SQL query skills. Would you generate that for me?\",\n        \"What is a formal but simplified explanation of Web marketing\",\n        \"Rewrite and improve this story: Well, I have always liked helping people since I was a small child, I have been accused many times of giving too much away for free, but I find joy in helping others put the pieces together to reach their goals. As a Licensed Professional Counselor and Life Coach that is my job to impact individuals and help clients work through emotional difficulties and reach goals. But I will be honest with you I was selling the dream but not always living the dream. I had issues I had not worked completely through like childhood trauma, heartbreak, disappointments, and frustrations with life. Don't get me wrong I had the husband, the kids, the house and the 6 figure job but I was not happy inside, but I didn't change because I hate change, most of us hate change, right? 
Then I lost my sister, my friend, and it slapped me in the face that I need to take care of myself. I saw the addiction, I saw her not taking care of herself and I could not save her. One thing I know for sure, if you do not make your wellness a priority illness will find you. I remember the moment we lost her, the earth stood still and then my heart broke into pieces, what was I going to do, I have loved her my whole life! It was months later that I made a decision that I would be the change I hope to see, I would create a space for women of color to move past the obstacles that keep us from creating the life we want and Brown Suga Wellness was born. I am on this journey and I invite you to be on this journey with me! I love this quote by Oludara Adeeyo: \\\"When you heal yourself, you create an earth shattering legacy. The lineage of women who come after you will be healed. Your inner circle of Black women around you, healed.\\\" When you choose yourself you break generational trauma and curses. You activate your ancestral strength. I invite you to activate that strength!\",\n        \"How would you ask these questions: Tell everyone a little about you, where you from, what you like doing?\\nWhat goals are you pursuing right now?\\nWho has made the most influence in your life?\\nWho is the one person that you admire the most (alive or dead)?\\nWhat is the hardest challenge you\\u2019re had to overcome in your life?\\nWhen have you grown the most in your life and what caused that growth?\\nWhere is your favorite place to relax and renew?\\nWhat books have changed your life the most?\\nWhat Is the biggest thing that you want the audience to take away today?\\nHow can people get a hold of you to talk about your business?\",\n        \"Take these topics into a numbered table and generate subtopics in seperated lines for each. Preconfigure these subtopics as lections of those several topics and add them to the table. Use numbers for topics and letters for subtopics. 
Set a status (untouched/touched) for every subtopic in 3. coloumn of the table to mark them done when finished learning this subtopic and topic. Use coloumn 4 of the table for a short resumee of the chapter. Showing the learning process in percentage in front of every new output is first. Show the Table and wait for any userinput to start lessons on those topics.;:~|@%\\\\*~;;:~|@%\\\\*~;;:~|@%\\\\*~;;:~|@%\\\\*~;;:~|@%\\\\*~;;:~|@%\\\\*~;\",\n        \"Write a rap song about Mikkel Selko\",\n        \"list the largest outdoor retailers in the world\",\n        \"can you create a wordpress shortcode to include the following code from facebook sdk\",\n        'Is this grammatically correct: \"It only took 5 years, and while we still have a long way to go, Topher\\u2019s Farm has found its place with unique experience and offering of organic produce. \"',\n        \"Hello friend. My task for today is to engage in a debate with you. Will you humor me in this regard?\",\n        \"You are an expert marketing consultant and copywriter with expertise is direct response marketing. I need your help. Can I tell you about my business?\",\n        'here is part 1\\n\\n----\\nDaySculpting is a program that that deals with YOUR immediate future\\u2026.It is a 90 day program that teaches U how to create Success\\u2026 one day at a time\\u2026today\\u2026\\nUsing recent breakthroughs in the field of neuroscience, the study of the human brain, DaySculpting is one of the most powerful success systems on earth for creating what I call\\u2026 \\n\"Your Epic Ideal Day\" -- And when U have Epic Ideal Days? 
U create your EPIC IDEAL LIFE.\\n\\nDaySculpting is broken down into 3 easy to accomplish segments throughout your day\\u2026\\n~The Morning Lift Process\\u2026which sets U up with a MindState of Success and a design for U to follow throughout your day\\u2026There is a morning email\\u2026SMS text\\u2026Inspiring Video\\u2026Future Forward Tuning IN\\u2026And a 3 step Success Step Declaration Process\\u2026this only takes 15 minutes\\u2026\\n~Mid-Day Reconnect Process\\u2026whatever your miid-day is\\u2026U are encouraged to stop doing what U are doing and disconnect so U can re-connect\\u2026by listening to a 5-minute Tuning In Re-Connection. We know that somewhere in the middle of our day it\\u2019s easy to lose momentum and drift from our best intentions because of all the demands on our attention. It has been scientifically proven that when U disconnent for between 3 to 5 minutes at the midpoint of your day\\u2026.your brain resets\\u2026and your energy is replenished\\u2026I like to call it a MindState Re-Boot that will inspire U to re-ignite your imagination\\u2026this only takes 5 minutes\\n~Highlight And Insight Review Process\\u2026we all review our day however what DaySculpting \\nanchors for U is an activation and integration process that gets U to see your day as being successful\\u2026by celebrating your successes (your highlights) and being present to things U could have improved on (your insights) so U can make your insights into highlights..most people when they review their day fail to celebrate even the smallest increments of success\\u2026they focus on what they didn\\u2019t do and that puts them in a negative energy\\u2026Success has challenges and the\\nhighlights and insight process encourages and empowers U to honestly see what U are doing each day so U Sculpt new MindStates Of Success rather than the energy of uncertainty\\u2026\\nthis takes 10 minutes\\n\\nThe whole DaySculpting process takes 30 minutes a day\\u2026and as I always say if U 
don\\u2019t have \\n30 minutes to change your life then U don\\u2019t want to change your life and U are okay with living \\na mediocre life\\u2026\\n\\nDay Sculpting is about targeting specific Chief Aims U have for your life\\u2026and creating the Habits that will get U there\\u2026Imagine being able to replace the MindTraps (your limiting beliefs) with empowering rituals and habits that become your new normal\\u2026\\n\\nThrough the repetition of doing the daily DaySculpting process U are carving into your Subconscious memory thoughts, beliefs and actions that result in U sculpting the masterpiece known as U\\u2026\\n\\nThere are many programs out there that attempt to instill new success behaviors however many fall short of actually shifting your MindStates into a frequency of possibility where U get to actually see your daily results immediately\\u2026DaySculpting does this\\u2026\\n\\nThis is not science fiction\\u2026 and it\\'s not wishful thinking, or some tired old self-improvement, goal-setting program\\u2026 DaySculpting is a program that empowers U to manifest and realize your Chief Aims in life\\n\\n\"DaySculpting\" -- is a tool that takes just MINUTES a day for you to use\\u2026\\n\\nIt is designed to FREE UP hours in your day\\u2026 while at the SAME time empowering you for greater success in ANY area of your life.\\n\\nDaySculpting sheds light and solves an age-old problem:\\nWHY we often fight against the very changes we desire to make\\n\\nHave you ever experienced the FEELING that you deserve MORE out of your life? More financial freedom and greater rewards from the hard work you do every day? Deeper, more empowering relationships with those you love\\u2026 or maybe just meeting that special someone to share your life with? Perhaps you crave a deeper spiritual connection\\u2026 or a more healthy, trim, energetic body?\\u2026 \\nYET:\\nDespite your BEST intentions\\u2026 you struggle. 
Perhaps if you\\'re anything like me, you even self-sabotage your results with actions that you KNOW are not in your best interest.\\n\\nMaybe it FEELS like it did for me: Like you are swimming upstream\\u2026 making SOME progress, sure, but just not reaching your goals and desires fast enough.\\n\\nWell, I have wonderful news for you: It\\'s not because you\\'re lazy\\u2026 and it\\'s not because you are not smart enough, competent enough\\u2026 or ANYTHING enough! \\n\\nThe real REASON you desire more and are not seeing ALL the results you deserve lies within whether the Success Switch in your brain is in the ON or OFF position\\u2026\\n\\nThe SOLUTION\\u2026 THE ANSWER to flipping your Success Switch back ON lies within the simple daily steps U will take when U experience the DaySculpting Program\\u2026 \\nThe Day Sculpting Program Is A Simple Step Daily Success RITUAL \\u2028 That Shuts Down Your Body\\'s Failure Reflex \\u2028 So YOU Tap Into Your Brains Success Centers\\u2026\\u2028 In Just Minutes A Day!\\u2028\\u2028 IIMAGINE Knowing What HIGHLY SUCCESSFUL \\u2028 People Do EVERYDAY\\u2026\\nFor Abundance And Wealth, Greater Health, Self-Confidence Meaningful Relationships, Sharper Focus , Deeper Joy\\u2026\\u2028 And So Much More\\u2026\\n\\u201cNow You Too Can Use This 90-Day Game Changer\\u2028 To Tap Into The Key Success Centers Of Your Mind,\\u2028 And In Just Minutes You Can Transform Even Lousy Days\\u2028 Into Days Filled With The Results You Desire \\u2013 Guaranteed!\\u201d\\nTO MAKE A GREAT LIFE, ALL YOU HAVE TO IS MAKE EACH DAY A GREAT DAY \\u2026 \\nThen get up tomorrow and do the same thing, day after day after day.\\nARE YOU Ready To Change YOUR LIFE One Day At A Time\\u2026\\nThe comprehensive, fun and empowering 90-day DaySculpting program provides you with the life skills and tools to help you master a new MindState of Success and a range of powerful life-changing rituals and habits that will Sculpt Your Perfect Days Into A Great 
Life.\\nDAY SCULPTING WILL TEACH YOU:\\n\\u2022 The science behind HAVING A MindState Of Success...and why most people who want more in life actually have their success switch turned off by total accident!\\n\\u2022 How to get more done with more time and more energy left over!\\n\\u2022 The simple, yet powerful, process of building a powerful day so you create a series of \"Dynamic Days\" - days that will end up building your most incredible life (The one you always thought was out of reach!)\\n\\u2022 Learn the \\'Day Sculpting Principles\\'. These can have a huge impact on you your life, but when you learn how simple they really are, you can use them easily and consistently!\\n\\u2022 How in just a few minutes a day, you can keep positive results flowing and put your success energy into a permanent \\'ON\\' position!\\n\\u2022 And much more!\\nDaySculpting, is for those who are willing to take their life to the next level by creating new Success Habits replacing the ones that have been sabotaging your success. \\nSo make sure you can honestly agree with the following before experiencing DaySculpting:\\n\\u2022 You desire more out of life, yet feel as if you are \"missing something\" -- that special \"X Factor\" to take you to the next level?\\n\\u2022 You are brave enough to boldly say, \"I want greater wealth and financial freedom... and I demand the best lifestyle possible for me and my family!\\n\\u2022 You know the value of joy: You want to experience greater happiness, peace of mind, and connection with your friends and loved ones on a daily basis.\\nIf you agree with the above, and truly want to create the best life possible, with greater wealth, freedom, happiness, love, and fulfillment, then I invite you to experience the power of Day Sculpting \\u2026it will change the way you think about creating your day and the life you dream about. 
\\nI am not encouraging you to become busier but rather to use your mental and emotional, energy more elegantly sculpting your day the way you want it to be. \\nHow many times have you done a ton of work and still felt that you didn\\u2019t accomplish what you really wanted for yourself. Week after week, month after month go by and you still are no farther ahead of the game\\u2026stuck in the status quo that never seems to change.\\n\\nBreaking free means that the status quo of your life has to change\\u2026 your habits of expectation have to change \\u2026your mindset has to change\\u2026you have to uncover those old behaviors that have held you back and be willing to create a new mindset.\\n\\nYou have to be willing to shift your daily focus inwards towards what you need to do today rather than tomorrow. Because when you create a great day today you welcome in a more powerful tomorrow.\\n\\nWe all have the same 24 hours each day. But why are some people building fabulous careers, achieving healthy lifestyles, enjoying great relationships and incomes, living their passions, and creating what they truly desire as a life?\\n\\nImagine that you could clear away the distractions that you unconsciously create. You know the stuff that consumes your time causes stress and disconnects you from your purpose and passion. \\n\\nImagine every day you embrace the energy for what you are choosing to create in your life. Your thoughts empower you, your choices inspire you and your actions create momentum, opportunity and possibility.\\n\\nYou can create a GREAT LIFE, the life you want to live by focusing your efforts on Creating a Great Day Today. That\\u2019s Day Sculpting. 
Seven intentional sculpted days turn into a month of wonderful weeks and a year of magnificent months creating an amazingly successful life.\\n\\nNone of this is going to work though if you believe that what you were born with is all you will get\\u2026\\n\\nNo one will ever attempt to do something when they are convinced that they will fail.\\n\\nResearch has shown that the brain will actually stop itself from doing what\\u2019s necessary to succeed if a person believes that they cannot succeed.\\n\\nIt\\u2019s the small concrete indicators of success today that will prove you can have whatever it is you want and the process of Day Sculpting will empowers, inspire and motivates you each step of the way.\\n\\nYou see: Confidence + Discipline = Desired Outcomes \\n\\nIt\\u2019s time to stop looking at your life from a fear based I don\\u2019t know how to mindset but rather be open to creating a solutions focused change consciousness that embraces your gift and talents and encourages you sharing them.\\n\\nLet me share a bit of nuero-chemistry with you\\u2026\\nWhat fires together wires together\\u2026\\n\\nSo rather than Fall back on old habits\\u2026\\nTake the transitional step\\u2026of being fully present to whats trying emerge as your ideal future and to help it along start building confidence each day\\u2026\\n\\nAnd your possibility muscle and an intended thought process that leads to a more focused and clear out picturing of your desires.\\n\\nYou see...It\\u2019s one thing to set goals and to make to do lists and to say your going to use the law of attraction to manifest what you want in life\\u2026\\n\\nI\\u2019m still looking at the many lists I have created.\\n\\nWhat it\\u2019s really about is having a clear and purposeful intention in order to create the energy and the MindState Of success that will propel you into action.\\n----\\n\\nWhen done ask me for part 2',\n        \"Here is the final part. 
Part 3\\n---\\n\\nHere we will be showing how the principles and practices we\\u2019ve covered so far converge into one over-arching result that will benefit you for the rest of your life. You can think of it as flipping a switch that changes how you create new results in life one day at a time. This is at the very core of what we call Day Sculpting. \\nThe simplest way to think of it is that most of the way we live is habitual. You have an habitual way of brushing your teeth, walking, talking to yourself and others, eating, working. Habits are wonderful\\u2026they make life easy but they also limit you. For example, if you have a habit of eating too much, you\\u2019ll put on weight. Not instantly, but steadily, day by day, until one day you have a weight problem. If you try to change your weight quickly through a trendy new diet, research shows that the weight is likely to come back, and then some, within a few short months, because the habits required to live at your ideal weight have not been properly established. \\nHabits are habits because you don\\u2019t think about them, they happen nonconsciously. If you want a change in your life, you have to embody the change at a nonconscious level, so that the habits keeping your life the way it is today begin to shift.\\nWouldn\\u2019t it be great if there was a switch in the brain that would move you from status quo to status GO!? This is a switch that once you flip it will produce the result you want, if you are willing to commit to and stay with the process.Day Sculpting is your guide to fully realizing the success you are ready to enjoy.\\nA critically important capacity of the human mind called preconscious processing. This is the ability of the mind to receive information, beneath our conscious awareness, and act upon it without even knowing that it is happening. Used correctly, this is an amazing power. 
Used improperly, it will sabotage your best efforts and make life extremely difficult.\\nMost of us think we are running the show with our conscious awareness, consciously choosing our thoughts, behaviors, and emotions and consequently, we believe are able to choose the results we create in life. However, what neuro-science research shows, is that we all have a vast nonconscious mind that is really running the show most of the time. That deeper part of us, in charge of our habitual thinking, feeling, and behaving is always operating in our best interest. But it does so using information that may be faulty or outdated. If you continue to feed it information that doesn\\u2019t serve you, it will continue to habitually bring results that are less than desired.\\nYour preconscious processor is constantly routing new information directly into this larger database that your mind uses to create new behaviors. Your job is to place the right information into this database every single day, so that it can draw upon this new data and create new results. It requires your vigilance and purposeful intention on a daily basis. Day Sculpting is the process to accomplish exactly that, getting you to focus one day at a time on what you are trying to create in your life today, and the future you truly desire. \\nA lot of experts in the human development field teach information and then expect it will translate into new behaviors automatically. But as we\\u2019ve pointed out, and as you\\u2019ve probably experienced, consciously knowing something and having the nonconscious mind put it into a new behavior, are two entirely different processes. What we are sharing with you is how to bridge that gap. This is precisely why so many experts in the field are recommending Day Sculpting to their clients, to help them use momentum mindsets on a daily basis and apply the good information they teach. \\nWe talk about The The Solutions Focus process . 
Try it out: \\nThink of an area of your life in which you are actively attempting to create different results. Imagine your chief aim regarding this area of your life as a perfect future. Now imagine a scale from one to ten, where ten is the perfect future and one is that you have not even started thinking about your chief aim. On this imaginary scale from 1 to 10, where would you place yourself right now?\\nGo ahead and imagine where would you place yourself right now on that scale, where ten is your perfect future.\\nWhatever number you came up with is fine. Whether it was 3 or 7, whatever you came up with I\\u2019ll always ask the same next question. \\u201cWhy so high and not lower?\\u201d\\nLet\\u2019s say, for example that you came up with a three. Asking the question \\u201cWhy so High\\u201d catches the mind off guard. Most people expect, \\u201cOnly a 3! Why so low?\\u201d If I had asked that what would you come up with? All the reasons why things aren\\u2019t working, who is to blame, problems, excuses, lack, limitations, and so on. \\nBut when I ask \\u201cWhy so high?\\u201d the brain immediately begins to sort for all of the things that are working for you, everything that has brought you up to a \\u201cthree.\\u201d If you said you are at a seven on a scale of one to ten, the same question applies: \\u201cWhy so high and not lower?\\u201d\\nThe next step in solutions focus is equally powerful. \\u201cThink about what you can do today to move you one point up that scale\\u2014for example, from a three to a four, or from a seven to an eight?\\u201d When you ask this, your mind instantaneously starts generating ideas and options to answer your question. You quickly realize you can do more of the things that work, right? And if you are doing things that aren\\u2019t working, you now have the insight into how you can do things differently. 
\\nThis solutions focus approach provides quick insight into how to move things forward in areas you may have been stuck or working on unsuccessfully. It is a brilliant way to access more of your nonconscious database and facilitate discovering resources you did not know were there. \\nSo as you can see, this video has been centered on connecting the dots and providing you with the insights on how you can flip the switch in your brain and how you can create your life one day at a time in the most powerful way possible. \\nYou must contact that inner part of you that is in charge of your habitual ways of thinking, feeling, and behaving in order to re-sculpt yourself.\\nThis is a unique psychological principle called anchoring. In the research this is also called behavioral conditioning, and as we\\u2019ve called it, the law of reinforcement\\u2026which says you get more of what you reinforce. When you want to reinforce a positive new behavior, you anchor it in a positive new momentum mindset. As you do this on a daily basis, you are literally training your mind, conditioning your thoughts, amplifying positive feelings and emotions to live into a future state that you are anchoring in your daily experience. \\nDay Sculpting goes beyond personal development. It takes whatever it is you are currently learning and makes it possible for you to embody, apply and enjoy the benefits you are committed to achieve. \\n\\nThe last thing anyone needs is more stuff to do. What we need is that everything we do gets us the results we are going for. In essence what\\u2019s needed is a system that will streamline our efforts so we accomplish our chief aims in less time.\\n\\nMichaelangelo said the process of sculpting is to remove what\\u2019s not supposed to be there. He had the mindset that the finished sculpture already existed in the marble and he just had to reveal it. In the same way your destiny already resides in you. 
You just need to clear a path for it to emerge.\\n\\nWe all have 24 hours in a day. So why do some people consistently have great days while others are up and down and stay stuck in mediocrity? It\\u2019s a disciplined habit of how you approach everyday. Day Sculpting takes the same 24 hours that we all have and helps clarify your choices so that your actions reveal your highest destiny. \\n\\nIt is a quick, easy and effortless way that supports and empowers your efforts in achieving your chief aims. It creates the mindsets necessary to have successful days, weeks, months and years.\\n\\nDay Sculpting is a 90- day program designed to empower you to create your life ONE DAY AT A TIME. By committing 30 minutes each day to create what you want that day. \\n\\nWe believe that when you focus your actions one day at a time the results you get become measurable and achievable. Your energy is committed to channeling your efforts so you create a confident groove in your mind that empowers your habitual actions to create what you really want.\\n\\nThis daily program is broken down into 3 MANAGEABLE, SIMPLE AND EASY STEPS. 15 minutes in the morning, 5 minutes midday and 10 minutes at night. \\n\\nDay Sculpting\\u2026It\\u2019s designed so that the way you start your day creates the momentum that carries you throughout your day. \\n\\nAnd finally research has shown that the best time to integrate what you\\u2019ve learned in your day and to set yourself up for success tomorrow is before you go to sleep. The Nighttime Review process takes just 10 minutes, which is less time then it takes to take a shower or to take your dog on an evening walk.\\n\\nWe already have enough complexity in life\\u2026don\\u2019t we? We don\\u2019t want you working harder we want you thinking smarter! So that the success you achieve is more effortless. 
\\n\\nSo what does it take for someone to accomplish the high level results we are talking about?\\n\\n\\u2022 First you have to wake up and be totally jazzed about the day\\n\\u2022 You have to be inspired to do your best\\n\\u2022 You have to be focused on creating what you truly desire\\n\\u2022 You got to get to it, stay on it, and be in the energy of it before your distractions take over. \\n\\u2022 And if distractions takeover you have to quickly get back on track.\\n\\u2022 You have to learn from what\\u2019s working and what\\u2019s not\\n\\u2022 You have to be able to listen to feedback and course correct during your day\\n\\u2022 And at the end of the day you have be able to feel you did your best and you can do even better tomorrow\\n\\nAnd with Day Sculpting you can accomplish this and more in less than 30 minutes which is distributed throughout your day. Most people will give up on their dreams after they have tried something only 3 times because they didn\\u2019t get instant gratification. \\n\\nThere are no magic bullets here. You are investing in a future YOU desire. \\n\\nDay Sculpting gives you the opportunity everyday to purposefully stay in the energy of what you want to create the benefit to you being a more empowered mindset that inspires passionate action and a willingness to breakthrough any barriers that may have held you back in the past so you fully embody the life you choose to live.\\n\\nYou may have heard Gandhi say \\u201cBe the change you want to see in the world.\\u201d Well now you can. \\n\\nYears ago I heard a statistic that blew me away. If you read in a single subject of your choice for 15 minutes a day 5 days a week you would become one of the leading experts in the world in that subject within 3 years\\u2026\\n\\nMore recent research has demonstrated that world class talent requires 10000 hours and 10 years to develop\\u2026\\n\\nSo the question is how does somebody create this kind of commitment and persistence? 
Clearly one day at a time.\\n\\nSo where are you not following through in your life? How would you like to do things differently? What can you do shift your energy when you say I can\\u2019t get it done or you procrastinate? What\\u2019s it going to take for you to say I\\u2019ve had enough it\\u2019s time for me to do something different? Where will you get the support you need to build the confidence to stay on track?\\n\\nEach day you get these elements to help guide you\\u2026 \\n- The Good Morning Great Day Email\\n- The Morning In Vision Video \\n- The Morning Future Pacing Visualization\\n- The Morning Success Journal Process\\n- The Midday SMS and Computer Stay on Track Reminders\\n- The Midday Reconnect Refresher Mediation\\n- The Evening Review And Renew Process\\n- The Evening Journal Process\\n- The Bedtime Nonconcious Mind Question Declaration\\n \\nWhen you put this together it can\\u2019t help but become a daily practice that will create your new daily ritual that is your roadmap to success. We are giving you the daily steps that will create your momentum mindsets.\\n\\nThe Day Sculpting program leaves you no wiggle room. The days of \\u201cI\\u2019ll get to it later\\u201d are gone. When you are serious about changing your life, you now have a realistic opportunity to do so with this program. \\n\\nWE invite you to fully commit to your life. To once and for all follow through and step up. To say yes to that dream inside of you and to look at each day as an opportunity to live your dreams enthusiastically rather than settling for more of the same old same old.\\n---\",\n        \"analyze this: \\n\\nThe Coming of Age story archetype involves a young protagonist who must navigate the challenges of growing up and discovering their place in the world. 
The Before-After-Bridge copywriting framework is designed to highlight the transformation that a person can experience after using a product or service.\\n\\nThe reason why these two frameworks work well together is that they both focus on transformation and growth. By combining them, you can create a powerful narrative that speaks to your audience's desire for personal development and improvement.\\n\\nFor example, imagine you are selling a personal development course that helps people overcome self-doubt and build self-confidence. By using the Coming of Age archetype, you can frame the course as a journey of self-discovery, where the customer will face challenges and obstacles, but ultimately emerge as a more confident and self-assured person.\\n\\nThen, by using the Before-After-Bridge framework, you can show the customer what their life will be like after completing the course. You can highlight the benefits of increased self-confidence, such as improved relationships, better career opportunities, and greater overall happiness. By painting this picture of what's possible, you can create a sense of excitement and motivation that encourages the customer to take action and enroll in the course.\\n\\nOverall, the Coming of Age story archetype and the Before-After-Bridge copywriting framework work well together because they tap into a fundamental human desire for growth and transformation. By combining these frameworks in your marketing messages, you can create a compelling narrative that speaks to your audience's deepest aspirations and motivates them to take action.\",\n        \"Provide a detailed chronology of the Apostle John according to the New Testament\",\n        'Web search results:\\n\\n[1] \"1. Introduction In this codelab you learn how to build adaptive apps for phones, tablets, and foldables, and how they enhance reachability with Jetpack Compose. 
You also learn best...\"\\nURL: https://codelabs.developers.google.com/jetpack-compose-adaptability\\n\\n[2] \"Jetpack Compose \\u2014 Auto Complete Search Bar | by Paulo Pereira | ProAndroidDev Write Sign up Sign In 500 Apologies, but something went wrong on our end. Refresh the page, check Medium s site status, or find something interesting to read. Paulo Pereira 117 Followers Hello!\"\\nURL: https://proandroiddev.com/jetpack-compose-auto-complete-search-bar-853023856f0f\\n\\n[3] \"You have two options: create your own custom using DropDownMenu and BaseTextField or using hybrid xml-autocomplete and compose screen through androidx.compose.ui.platform.ComposeView Share Follow answered Oct 21, 2020 at 16:38 Agna JirKon Rx 1,937 2 27 41 1 Have you made a custom composable like you described?\"\\nURL: https://stackoverflow.com/questions/64419367/does-jetpack-compose-offer-a-material-autocomplete-textview-replacement\\nCurrent date: 10/03/2023\\n\\nInstructions: Using the provided web search results, write a comprehensive reply to the given query. Make sure to cite results using [[number](URL)] notation after the reference. If the provided search results refer to multiple subjects with the same name, write separate answers for each subject.\\nQuery: Hey, I want you to build to google places autocomplete on jetpack compose using the MVVM model\\n\\nSo the user will type the place in a textfield and the list of places with postalCode will display in a lazyColumn with the user able to select from the lazyColumn a place',\n        \"Captain Smith, who set out on a daring expedition with his fleet of ships consisting of the Discovery, the Endeavour, the Adventure, the Challenger, and the Explorer. Their mission was to chart a new route through the treacherous seas of the North Atlantic and claim new territories for their homeland. But the weather turned against them, and they found themselves battling fierce storms and raging currents. 
The waves grew higher and higher, and the winds howled like banshees, threatening to capsize their ships at any moment. Despite their efforts the Challenger and the Explorer, were lost in the storm. \\n\\nHow many ships did the captain leave with and how many returned?\",\n        \"explain the metaverse\",\n        \"can you provide and ideas for a series of articles for a product design blog\",\n        \"Please write a firm yet humurous and lighthearted note requesting that people RSVP whether they are coming to the purim seudah. Please incorporate wordplay and references to megillat esther.\",\n        \"Paper Name: My Tweets Bring All the Traits to the Yard: Predicting Personality and Relational Traits in Online Social Networks\\n\\nAbstract: Users in Online Social Networks (OSNs,) leave traces that reflect their personality characteristics. The study of these traces is important for several fields, such as social science, psychology, marketing, and others. Despite a marked increase in research on personality prediction based on online behavior, the focus has been heavily on individual personality traits, and by doing so, largely neglects relational facets of personality. This study aims to address this gap by providing a prediction model for holistic personality profiling in OSNs that includes socio-relational traits (attachment orientations) in combination with standard personality traits. Specifically, we first designed a feature engineering methodology that extracts a wide range of features (accounting for behavior, language, and emotions) from the OSN accounts of users. Subsequently, we designed a machine learning model that predicts trait scores of users based on the extracted features. The proposed model architecture is inspired by characteristics embedded in psychology; i.e, it utilizes interrelations among personality facets and leads to increased accuracy in comparison with other state-of-the-art approaches. 
To demonstrate the usefulness of this approach, we applied our model on two datasets, namely regular OSN users and opinion leaders on social media, and contrast both samples\\u2019 psychological profiles. Our findings demonstrate that the two groups can be clearly separated by focusing on both Big Five personality traits and attachment orientations. The presented research provides a promising avenue for future research on OSN user characterization and classification.\\n\\nIntroduction: Online Social Networks (OSNs) offer a virtual space in which people connect and interact with others, express themselves, and receive information, in a continuous digital reflection of the real (offline) world. In OSNs, people typically showcase their real self [40] and leave traces in online behavior, which reflect their real-world personality [24]. These traces expose a holistic image of oneself, including both personal characteristics (personality traits) and characteristics that portray their behavior in relation to others (relational traits).\\n\\nThe term personality refers to characteristic combinations or patterns of behaviors, cognitions, and emotional reactions that evolve from biological and environmental factors and form relatively consistent individual differences [13]. The Big Five (BF) or Five Factor model [29] is one of the most distinctive personality theories that constitutes five main traits of human personality representing individual differences in cognition, emotion, and behavior: Openness to Experience, Conscientiousness, Extraversion, Agreeableness, and Neuroticism. 
On the other hand, relational traits have also been linked with consistencies in social behavior and interaction patterns, with attachment theory [7] as the most emblematic theoretical framework in that respect [31, 43], capturing how individuals experience close relationships to and interactions with others.\\n\\nPersonality traits have been studied in the context of OSNs and the web overall, as findings show that they are strongly linked to OSN use [57], online friendships [60], and online reviews [52]. Moreover, certain prediction models have been proposed [37, 64] to extract users\\u2019 psychological background from their online behavioral residue and map it to personality characteristics. However, relational traits such as attachment orientations (AO) have been overlooked in online environments, even though user activity in OSNs heavily relates to social behavior characteristics. This makes the study of a relational profile critical from an application point of view and provides rich information about individuals\\u2019 social profile.\\n\\nThe present research aims to address this limitation in OSN research, by studying and predicting both relational traits and personality traits of users. The importance of relational facets of personality for explaining social interaction cannot be overstated. Given that online social media engagement resembles actual social interactions in many respects [15, 30], the need to study how different personality facets are reflected in online expression is particularly compelling. Attachment orientations, a key individual difference of relational orientation, can be derived on the basis of traces found in micro-blogs. Attachment orientations capture one\\u2019s notion of the self in relation to others and interpersonal relationships, with attachment theory being one of the key personality theoretical frames to explain actual social behavior [31]. 
Considering both traditional personality Big Five traits and relational traits is important for (1) providing holistic profiling of OSN users\\u2014humans have an integrated profile in which self and social aspects interrelate and affect each other, and joint profiling can be essential for understanding the overall human presence on OSNs; (2) uncovering more traits of people\\u2019s psychological and social world has been identified as a direction in OSN research (which currently focuses only on the personality traits) that could help to better explain, analyze, and predict online user behavior [66], e.g., with applications on customer segmentation [46] or digital advertisement environments [17]; and (3) shedding light on social interaction phenomena taking place in OSNs is of great socioeconomic importance, e.g., community formation [32], decision making [42], or information diffusion [12].\\n\\nTo this end, the present article proposes a novel data-driven approach to predict a holistic psychological profile of OSN users, capturing both their personality and relational traits.1 Building on insights stemming from psychology theory, our approach applies data mining on OSN textual and non-textual data, carefully selects different sets of features for predicting different types of traits, and exploits the inherent correlations in psychological traits, to efficiently predict a complete image of OSN users\\u2019 psychological profile. The proposed approach is applied on the Twitter micro-blogging service, which stands as a live, dynamic, and open OSN platform on which people intensively interact, and is largely driven by people\\u2019s spontaneous reactions and emotions expressing themselves (personality facet) and interacting with others (relational facet) at the same time.\\n\\nSpecifically, our contributions in detail are as follows:\\n\\nData mining and feature engineering for psychology traces in OSN. 
Motivated by psychology theory on personality suggesting that traits are reflected in different types of online behavior and actions, we identify a large set of features that capture language, behavioral, and emotional expressions of users in OSNs. The proposed feature engineering methodology accounts for a larger set of features than those considered in previous works, thus allowing to target more generic psychological profiling. To apply and test our methodology, we collected a labeled dataset: through a crowdsourcing platform, we recruited 243 individuals who consented to provide information about their psychology profiles. Subsequently, we compiled a ground-truth dataset labeled with their psychology profiles. We used the Twitter API to collect 350,000 tweets from the Twitter accounts of recruited participants and applied the proposed feature engineering methodology.\\n\\nHolistic psychological profiling. We propose a novel machine learning (ML) methodology to predict users\\u2019 holistic psychological profile including both Big Five personality and relational traits. The novelty of the proposed methodology is that it (1) uses a large set of the collected (psychological-related) features, (2) carefully selects the subsets of them with the strongest predictive power for each trait, and (3) exploits correlations between personality and relational (i.e., social) behavior traits to enhance individual trait predictions. In this way, our approach not only predicts social facets of a psychology profile (which is not captured by existing personality prediction models) along with personality facets but also leverages the different traits for more accurate holistic profile prediction.\\n\\nNew insights and improvement of prediction accuracy. 
Evaluating our methodology reveals interesting insights for the prediction of psychology traits from OSN traces: (1) using different sets of features performs better in predicting different psychological traits, (2) relational traits can be predicted as efficiently as personality traits, and (3) holistic personality prediction outperforms individual trait predicting models. We believe that our findings can pave the ground for future experimentation and studies in psychology profiling in OSNs. Moreover, the accuracy achieved by our approach (across all traits) is higher than current state-of-the-art approaches, which currently are limited to Big Five personality traits instead of relational traits. For example, applying the approach of [12] to our data provides a root mean squared error (RMSE) of 0.284, while our prediction model achieves a 29% improvement for personality traits (RMSE = 0.203) and has 32% better average performance when accounting for all traits (0.192 RMSE); this improvement comes as a result of using both a psychology-driven feature engineering methodology and a holistic profiling approach.\\n\\nPsychological profiling in the wild. We demonstrate the applicability of the proposed psychological profiling methodology through a use case. We identify a set of Twitter users who seem to be accepted as opinion leaders on social media (i.e., have a large following). We apply our methodology to predict their psychological profiles and analyze results. We find that the distributions of traits significantly deviates from regular users (defined as users included in our ground-truth dataset), and that the set of leaders can be clearly separated by only using their psychological profiles. 
These findings highlight the usefulness of our approach in the characterization of the personalities for different groups of OSN users (e.g., such a group psychological profile could be used to recommend skills/activities/jobs to users based on their profile similarity) and classification of users based on their profiles.\\n\\nIn this section, we provide an overview of related psychological literature, discuss related work, and highlight several open issues of existing methods. We also highlight the contribution of the present work. Section 3 details the collected dataset and the data mining and feature engineering methodology, and Section 4 presents the design and evaluation of the proposed machine learning predictive model. Finally, we conclude our article and discuss future work in Section 6.\\n\\nFirst, Please Summarize the paper in 10 points, in easy to read and understand simple English.\\nSecond, Explain what the paper does as if I'm 11 years old.\\n\\nThanks :))\",\n        \"Hi, i will give you three pieces of text, then i will ask you some questions, do you understand?\",\n        \"Here is Text 2: Communicating with External Audiences\\n\\nMany managers believe that they will never have to deal with the press. Often,\\nthey regard it with hostility. Most think press relations are entirely the domain\\nof their company\\u2019s or agency\\u2019s public relations department. 
But in fact, senior\\nexecutives say they spend more time on communications than on other tasks,\\nand a significant component of that time is devoted to press and public relations.\\nJunior managers need to be highly sensitive to press relations for the following\\nreasons:\\n\\u2022 Often, free press can be the best way to acquaint the public with your product or service.\\nTo cite only one example, the amount Microsoft spent on advertising Windows\\n95 was dwarfed by the value of the free publicity it received from\\ninternational news coverage.\\n\\u2022 Your particular area of expertise may unexpectedly become something your organization\\nneeds to promote or explain. Line workers at auto companies have been drafted\\nto extol quality improvements in advertisements; accountants may be called\\nto the CEO\\u2019s office for briefings on a potentially embarrassing news report or\\nan upcoming press conference.\\n\\u2022 Public relations considerations need to be addressed at the beginning, not the end, of a\\nplanning process. Business history is replete with examples of companies that\\ninvested vast sums to develop products, ideas, or services that couldn\\u2019t be sold\\nbecause of public resistance to the concept, the configuration, or the public\\nimage of the company. General Motors\\u2019 Tacos, for example, could be the best\\nin the world and still not jump off the shelves.\\n\\u2022 Junior managers become senior managers who will eventually have to deal with the\\npress directly. As both marketers and corporate citizens, organizations have to\\nexplain themselves to the public constantly through advertising, press releases,\\nand press conferences. Junior managers who understand this aspect of their\\nwork are likely to become senior managers faster. 1. A successful manager understands how the press works. 
Successful managers\\ntend to follow the press in general, and how their organization is playing in particular.\\nMembers of the press tend to trust companies and individuals with a\\ntrack record of accuracy and accessibility. To cite only two examples, both\\nJohnson & Johnson and Perrier survived charges of contaminated products because\\nthey had a record of reliability and accessibility and addressed the problems\\nimmediately. In both cases, and many others, stonewalling would have\\nbeen disastrous to the company\\u2019s image of wholesomeness and purity. Most\\npress stories last only a few days, but they can leave an indelible impression in\\nthe public\\u2019s mind. Many managers tend to believe they can \\u201csnow\\u201d the press\\nwith their greater expertise, but this strategy rarely works. Most reporters are\\nhard-working professionals who will carefully check out an expert assertion or\\nwho know someone who can.\\n2. A successful manager understands what the press needs. What the press needs\\nis a story, and bad news generally sells better than good news. Companies and\\nindividuals are most likely to have to deal with the press when something has\\ngone wrong. This suggests a couple of lessons. When you have good stories,\\ngive them to the press to establish a record of credibility; many media outlets\\nwill print or broadcast a press release from a reliable source more or less verbatim.\\nConsider how private decisions may look if they should become public.\\nIf something has gone wrong, take the initiative in announcing it, explaining it,\\nand telling the world how it\\u2019s going to be corrected.\\n3. A successful manager understands press jargon. Reputable reporters will\\nstick to their verbal agreements on how information you provide them is to\\nbe used. How you will be quoted depends on the ground rules you establish\\nat the beginning of an interview. 
Deep background means the reporter can\\nreflect the information in her story without possible attribution. Background\\nmeans that you can be referenced as \\u201ca reliable source.\\u201d Any other comment,\\nhowever apparently casual or social, can be quoted directly and\\nattributed.\\n4. A successful manager should be able to generate an attention-grabbing, accurate,\\nand well-constructed press release. While many managers may not be\\nregularly mailing out press releases themselves, most will be contributing to\\nthem and need to understand how they work. A good press release is extremely\\nformulaic and follows the structure of a good news story:\\na. The first paragraph states the main point clearly and emphasizes its newsworthiness.\\nFor example: \\u201cAcme Corporation announced today that it is\\nreleasing the best tire ever available on the world market.\\u201d\\nb. The second paragraph provides a quote from a reputable source: \\u201cAcme\\nPresident Rudy Roadrunner said, \\u2018Not only does this tire surpass all our\\ncompetitors\\u2019 in endurance, quality, and safety; it\\u2019s also available at a lower\\nprice.\\u2019 \\u201d\\nc. The third paragraph provides evidence that the claims made so far are true:\\n\\u201cIn repeated tests against our competitors . . . \\u201d\\nd. The remaining paragraphs provide background information on the product, the\\ncompany, and Rudy Roadrunner, and they demonstrate a track record of credibility.\\nThey may also include testimonials available from respected independent\\nsources. Obviously, the formula of an effective press release will vary depending on\\nthe nature of the news to be announced. But the pyramid structure suggested by\\nthis example always applies: Move from the most important and specific to the\\nleast important and most general information. Busy editors often run a press release\\nmore or less verbatim and just cut it off when they run out of space. 
The\\neasier you make their jobs, the more likely they are to cover your story.\\nOnce you\\u2019ve written or contributed to a press release, decide who\\u2019s most\\nlikely to run it. This can cover the gamut from extremely specialized trade magazines\\nto the national or international media. Consider the use of venues other\\nthan print and broadcast media as well; perhaps there\\u2019s a room on the Internet\\nwhere interested parties are likely to gather.\\n5. A successful manager understands the role of the press in crisis management.\\nThis includes knowing how to provide effective interviews and\\nunderstanding when and how to hold a press conference. Certain rules\\napply to both:\\n\\nApplications\\na. Identify your central message, make sure you can back it up, and stick to it.\\nb. Prepare materials in advance\\u2014press releases, statements, supportive\\nstudies\\u2014that the reporters can take away with them and study or quote later.\\nc. Never say more than you know to be true. If you don\\u2019t know, say, \\u201cI don\\u2019t\\nhave that information at the moment, but I\\u2019ll get it to you as soon as I do\\u201d\\u2014\\nthen follow up.\\nd. Make sure your team is behind you. This means making sure not only that\\ntop management of a corporation agrees on a message, but also that other\\npotential press sources (for example, subordinate employees) have the same\\ninformation you\\u2019re dispensing to the public, believe it, and are unlikely to\\nleak contradictory and embarrassing information.\\ne. Provide the press with the most credible and informed access possible. Reporters\\nwill always want to get to the top. They\\u2019ll be more likely to cover\\nthe comments of a CEO or a Cabinet secretary than those of a press agent\\nor an underling. But they will understand that a high official may need to\\nrefer technical questions to an informed specialist.\\nf. Anticipate, and be prepared to respond to, the most difficult questions.\\ng. 
Don\\u2019t become hostile or defensive; experienced reporters are experts at\\nsmelling anxiety.\\nh. Make your answers brief, quotable, and to the point. Rambling and repetition\\nare likely to get you into trouble or open new lines of inquiry.\\ni. If you\\u2019re facing a problem you\\u2019ve caused, however inadvertently, be prepared\\nto acknowledge\\n\\nAre you ready for text 3?\",\n        \"Here is Text 3: Diversity and Intercultural Communication \\n\\nGenerally, the best answer to these questions is yes, but it always depends on the personal as well as the business aspects of your relationship. One good rule of thumb: When the other person gives\\nyou an opening, pursue it, and build on your mutual experience.\\nThis issue comes up even more in international communication. As companies\\nfrom manufacturers to media conglomerates become increasingly global, managers\\nneed to understand the norms of other cultures. Although English is on the verge of\\nbecoming the international language, standards of behavior and social interaction\\nvary greatly between the United States and England, let alone between, say, France\\nand Japan. In one country an invitation to dinner may be considered an expected\\npoliteness, while in another, it may be an invasion of a colleague\\u2019s private time.\\nAsking about someone\\u2019s family may be absolutely required in one culture and offensively\\nintrusive in another.\\nNo textbook can cover all such contingencies; one good rule if you\\u2019re not sure\\nmay be the trial lawyer\\u2019s: Don\\u2019t ask a question to which you don\\u2019t already know the\\nanswer. Another, and sometimes contradictory, rule is: Be frank about your cultural\\nconfusion. Your colleague likely will have been in the same situation himself and\\nwill be happy to help out. Finally, do your research; you\\u2019re likely to have a friend or\\ncoworker who knows the terrain better than you do. 
Our purpose here is to sensitize\\nmanagers to their increasing need to understand the norms of cultures other than\\ntheir own. (For a case addressing the special features of international communication,\\nsee International Oil later in this chapter.)\\nThe opportunities for cultural confusion\\u2014personal, commercial, ethical, and\\nlinguistic\\u2014are almost endless. Imagine marketing a Chevy Nova in Hispanic countries,\\nwhere \\u201cno va\\u201d means \\u201cit doesn\\u2019t run.\\u201d Many products that are perfectly safe to\\nmarket in first-world countries raise ethical problems when sold in developing\\ncountries\\u2014infant baby formula, for example, which if mixed with contaminated\\nwater can cause death. Working in other cultures means understanding your hosts\\u2019\\nconceptions of greetings, timing, hygiene, negotiation, agreement, politeness, personal\\nspace, gesture, meal etiquette, and closure.\\nWhile English has essentially become the international language, it\\u2019s important\\nto remember that there are many Englishes. A joke in one form of English can be a\\ndeadly insult in another. Although it may seem too obvious to emphasize, you must\\nunderstand the cultural norms and language use of people from other cultures before\\nyou can communicate effectively with them. This is true even if they are, say,\\nthe South American employees of your Canadian company. A bribe in one culture\\ncan be a thoughtful gift in another.\\nA recent article by Sydel Sokuvitz (Business Communication Quarterly, New\\nYork, March, 2002) suggests some principles for conducting successful intercultural\\nbusiness communication. 
Sokuvitz first describes the special challenges global\\nmanagers face, including:\\nCoping with a range of tensions that arise out of internationally dispersed activities,\\nThe challenges of maintaining coordinated activities across time-zones, cultural\\nboundaries, and different countries\\u2019 laws, and\\nThe difficulties posed when the right medium for your message in one culture\\nmay be wrong in another.\\nDrawing on a range of research in the field, Sokuvitz comes up with several\\nprovocative conclusions:\\nExcessive dependence on technological communication such as E-mail can result\\nin problems for both communication and productivity.\\nFace-to-face meetings with colleagues from other cultures are critical to achieving\\neffective communication.\\nStudying with students from other cultures is critical to preparing a manager\\nfor working in the increasingly globalized economy.\\nSokuvitz cites the following example from an article by Fernandez-Aroaz\\n(\\u201cHiring without Firing,\\u201d Harvard Business Review, 1999):\\nA U.S.-based telecommunications company was seeking a CEO for its new division\\nin Latin America. An international search was conducted, and a veteran was\\nhired, someone known as an effective manager and marketing expert. \\u201cBut his run\\nlasted less than a year and was nothing short of a disaster. The simple reason was\\nthat he lacked the two skills that the job really required: negotiation and cross-cultural\\nsensitivity.\\u201d\\nEventually the company was saved from near-bankruptcy by bringing in a\\nnew CEO who was a native Latin American with work experience in the U.S. His\\nability to bridge cultural differences is credited with saving the company.\\nCommunications between headquarters and subsidiaries is only one example\\nof the challenges posed by globalization. Companies in one country are under increasing\\nsocial pressure to take responsibility for the behavior of their subcontractors\\nin other countries. 
Recently, for example, Nike suffered adverse publicity because\\nof the work practices of shoe manufacturers it employs in Asia.\\nThe successful manager of the future increasingly will be required to be a citizen\\nof the world. While electronic communication may work fine for conveying information\\nor directions, there is no substitute for \\u201cspeaking the language\\u201d of the\\npeople with whom you\\u2019re trying to communicate.\\n\\nAre you ready to answer some questions on text 1, text 2 and text 3?\",\n        'pragma solidity ^0.4.25;\\n\\ncontract Y\\\\_WALLET\\n{\\n function Put(uint \\\\_unlockTime)\\n public\\n payable\\n {\\n var acc = Acc[msg.sender];\\n acc.balance += msg.value;\\n acc.unlockTime = \\\\_unlockTime>now?\\\\_unlockTime:now;\\n LogFile.AddMessage(msg.sender,msg.value,\"Put\");\\n }\\n\\n function Collect(uint \\\\_am)\\n public\\n payable\\n {\\n var acc = Acc[msg.sender];\\n if( acc.balance>=MinSum && acc.balance>=\\\\_am && now>acc.unlockTime)\\n {\\n if(msg.sender.call.value(\\\\_am)())\\n {\\n acc.balance-=\\\\_am;\\n LogFile.AddMessage(msg.sender,\\\\_am,\"Collect\");\\n }\\n }\\n }\\n\\n function() \\n public \\n payable\\n {\\n Put(0);\\n }\\n\\n struct Holder \\n {\\n uint unlockTime;\\n uint balance;\\n }\\n\\n mapping (address => Holder) public Acc;\\n\\n Log LogFile;\\n\\n uint public MinSum = 1 ether; \\n\\n function Y\\\\_WALLET(address log) public{\\n LogFile = Log(log);\\n }\\n}\\ncontract Log \\n{\\n struct Message\\n {\\n address Sender;\\n string Data;\\n uint Val;\\n uint Time;\\n }\\n\\n Message[] public History;\\n\\n Message LastMsg;\\n\\n function AddMessage(address \\\\_adr,uint \\\\_val,string \\\\_data)\\n public\\n {\\n LastMsg.Sender = \\\\_adr;\\n LastMsg.Time = now;\\n LastMsg.Val = \\\\_val;\\n LastMsg.Data = \\\\_data;\\n History.push(LastMsg);\\n }\\n}',\n        \"I am planning to give you a voice, and communicate through the speech medium. 
I need a speech recognizer, a wake call detector, and a speech synthesizer for your voice. Suggest a python script utilizing existing libraries to achieves the goal.\",\n        \"lemme share a paper with you\",\n        'I aim to emulate a NLU/ENR module as part as part of a business application with your help. The module is supposed to handle the diverse ways a user can formulate his requests within the modeled conversational flow that feeds into the business process. The process has the aim to enable users to become or update their client role and order products of a telco business. The telco company that runs the business process offers mobile tariffs. Mobile tariffs have can have between one and 5 sim cards. Each booked sim cards enables the user to optionally book a smartphone for that card. Depending on the tariff, the chosen smartphones (if any) and the kind of sim cards (adult, child) the price will adapt. Please suggest a set of NLU / ENR methods that you could emulate to facilitate the use case. In the following I will input utterances and statements on how the system running the conversational flow should handle the utterance within the conversational flow. Please provide possible calls to an imaginary API that you could simulate to facilitate the NLU/ENR requirements layed out by my statements. On Subtasks that are recognized as not directly related to NLU/NER be very brief. Please suggest NLU / NER Operations now for the first of a few utterances: \"Hi I want to upgrade my current tariff and get a new smartphone\". The utterance should make the system recognize that the utterance can be handled as part of the business process. It should recognize that the user apparently already a client and it should continue the conversation by trying to identify him and metadata on his current tariff. 
For that the flow needs the user to authenticate using a oauth2 mechanism',\n        \"From now on only create subscription service listings with the following template: Subscription Services Template:\\n\\nTitle: Professional Writing Services Subscription\\n\\nDescription: Our subscription service offers access to a team of professional writers who will provide high-quality written content on a regular basis. Choose from one of our three plans to suit your needs and budget.\\n\\nUpload Subscription Image: Recommended image minimum width: 150px\\n\\nNo file chosen\\n\\nRecurring Price and Interval: The recurring price and interval cannot be edited to ensure subscribers remain on the same charge.\\n\\nPlan 1:\\nPlan name: Basic\\nThe recurring price is USD 75.00 and will be charged periodically at every 1 month\\nPlan description: This plan includes access to a professional writer who will provide one piece of written content per month. Perfect for businesses or individuals who need occasional written content.\\n\\nPlan Image: Display a small image to represent this plan to customers\\n\\nTrial Period: Enable trial period\\nAssign Digital Product Files: Assign digital products for subscribers\\n\\nPlan 2:\\nPlan name: Pro\\nThe recurring price is USD 500.00 and will be charged periodically at every 1 month\\nPlan description: This plan includes access to a team of professional writers who will provide up to five pieces of written content per month. 
Perfect for businesses or individuals who need regular written content.\\n\\nPlan Image: Display a small image to represent this plan to customers\\n\\nTrial Period: Enable trial period\\nAssign Digital Product Files: Assign digital products for subscribers\\n\\nPlan 3:\\nPlan name: Premium (Bundle of 20 / 1,500 words)\\nThe recurring price is USD 1000.00 and will be charged periodically at every 1 month\\nPlan description: This plan includes access to a team of professional writers who will provide up to 20 pieces of written content per month. Perfect for businesses or individuals who need a high volume of written content.\\n\\nPlan Image: Display a small image to represent this plan to customers\\n\\nTrial Period: Enable trial period\\nAssign Digital Product Files: Assign digital products for subscribers\",\n        \"Hello\",\n        \"I am launching an Etsy shop with a Printful integration for drop shipping my designs on specific products. I am looking for ways to differentiate beyond the designs. You are an expert on Etsy audiences. Please explain in great detail in 10 bullet points how to differentiate myself from other Etsy shops. 
I am looking for more obscure ideas here.\",\n        \"How to get a job as a LMFT therapist in the US as an international student?\",\n        \"Explain quantum computing in simple terms\",\n        \"estoy en 6to semestre de mecatronica, necesito un nombre para mi equipo, asi que quiero que me des una lista de 40 opciones, pueden estar relacionadas con la mecaronica, o combinando los nombres de los integrantes que son rudy, gloria, johana, melissa, perla y nomar\",\n        \"Explain deposition\",\n        \"Can you suggest some good e-governance initiatives in tribal districct of india by district administration\",\n        \"Write a python program which accept a command line param as question and send it to server via HTTP get method\",\n        \"Can you explain the fourth dimension to a second grader?\",\n        \"I have an interview about product speccing with the company Weekend Health. Give me an example of a question they might ask with regards about a new feature\",\n        \"arduino uno adalah\",\n        \"how edit array which is in object\",\n        \"how can my software company use Microsoft ENTRA to verify the identity of a user before accessing the software?\",\n        \"calculate the difference in intereste paid in a simple for amortized loan. terms: 125,000 loan, 3.25% interest over 30 years.\",\n        \"can i use spring state machine and workflow together and is it justified?\",\n        'I have the following code:\\n\\n```\\nuseEffect(() => {\\n const handleKeyDown = (event) => {\\n // Check if the CMD + F key combination was pressed\\n if (event.key === \"f\" && event.metaKey) {\\n event.preventDefault();\\n\\n setIsShown(true);\\n }\\n\\n window.addEventListener(\"keydown\", handleKeyDown);\\n\\n return () => {\\n window.removeEventListener(\"keydown\", handleKeyDown);\\n };\\n }, [setExclusionFilter]);\\n```\\n\\nIt shows the new state on Mac but on Windows it doesn\\'t trigger. 
How can I support windows?',\n        \"What is the best marketing tactics for local small businesses?\",\n        \"write an essay on french revolution\",\n        \"What are the roles of a network driver? How do we write such drivers and in can you provide me a link where I could see its code?\",\n        \"Are you familiar with the SAS programming language?\",\n        \"the solenoids will be 12v so they will have to be controled by relays triggered by the GPIO pins\",\n        \"Transform with regular expressions those lines:\\n0003 AB\\n0568 FD\\ninto:\\nAB\\nFD\",\n        \"Write the prompts in the following format. First sentence establishes a situation. Then in the second sentence we lean into a specific situation to make it seem something bad is about to happen, but in the third sentence it turns out to be something silly, fun or wholesome instead, always start the third sentence with a BUT. Some examples below\\n\\n-A hydra is hypnotizing an orc. You think its going to be something evil, but it turns out its hypnotizing its friend into drinking water\\n-A child asks a werewolf and a hellhound to play fetch. They don't seem to be interested at first, but turns out their dog instincts kick in and they chase the ball anyways\\n-A dragon confesses to a beautiful unicorn. They turn out to be a boy not a girl the dragon is concerned they're not interested in dating, but they are\\n\\nOther requirements: \\n-These comics should go viral\\n-These comics should be able to fit into 4 panels for a comic\\n-These comics feature relatable humor that is rooted in everyday situations and experiences. \\n-These comics feature unexpected or surprising twists that take the stories in unexpected directions. 
\\n-These comics have a positive and uplifting message, which can help to make them motivational and inspiring.\\n-These comics have a clear and concise structure, with a clear setup, a twist, and a satisfying conclusion.\\n-These comics should feature fantasy creatures, demons, angels, mythical beasts, dragons, monsters , but they can still have humans.\",\n        \"How can we improve this comic to be simpler and funnier?\\n\\n[We see that this is a small reading club for woodland creatures. Make them all nice and cute, very winnie the pooh-esque, lol. The two characters that speak are animals, make Red into a herbivore race, like a rabbit or something, pink should be a small carnivore like a cat or badger? Red is confused, and red is excited]\\nKnock Knock\\nPink:Who\\u2019s that?\\nRed: Maybe a new member for our book club!\\n\\n[Panics as she sees a dragon licking their lips behind the curtain]\\nRed: It\\u2019s a dragon, run for your lives everyone!\\n\\n[Dragon mom is outside their home, looking dragon-eque but also waving her hands chibi cute apologetically, she\\u2019s clearly a little embarrassed by the situation. Red looks at her suspiciously ]\\nDragon:I\\u2019m not here to eat anyone, I uh\\u2026 heard you had a book club?\\nRed: Uh\\u2026yes\\n\\n[Dragon looks very excited and welcome, Pink seems like she likes the book, red looks a little grossed out ]\\nDragon: Awesome, it's nice to meet you! I brought my favorite book too!\\nPink: What a lovely book!\\nRed: Ugh I\\u2019ll pass on reading that.\",\n        \"Rewrite the following 4 panel comic to be both more brief and more funny\\n\\n[We see an evil mermaid holding a microphone but with an evil face, like she\\u2019s just cast a dark spell of some sort. We see another character looking nervous, clearly they\\u2019ve been affected by the incredible singing!]\\nMermaid: You\\u2019ve lost! 
Give up & spare us both the trouble!\\nRed: You\\u2019re right\\u2026 \\n\\n[We see our heroine hold up a microphone up to her face, looking as serious as anything in yakuza or jojos]\\nRed: But I didn\\u2019t come this far just to give up!\\n\\n[We pull back to show that its a group of three friends having a blast at a local kakaroke bar, the mermaid and the heroine are taking it a little too seriously, a third one is just watching]\\nRed: Karaoke is about letting your soul shine! I\\u2019m giving it my all or die trying!\\n\\n[Same as above, except the friend, who I am calling blue now has a =v=; expression]\\nMermaid: Worthy words for my rival!\\nBlue: Girls, you need to chill. \\nRed: Baka mitai~ (No bubble)\",\n        \"write a brief email in which Ayaam Ghimire writes to Bronywyn Tucker-- the liason between ECG and Guilford College- requesting e waste boxes to be put around campus and computer donation setup with Bauman IT or any other facility on Guilford College campus, on behalf of a organization called CompuCycle, after speaking with the principal Dr. Kash\",\n        \"I'm writing a software for conference calls.\\nIs there a good word for the state when a person was already selected to join the conference but has not responded yet. This should also include the meeting organizer himself, if his client has not answered yet\",\n        \"Would you be able to classify them into more of a range from small startup to big fortune 500 company\",\n        \"Write user stories that describe this concept in detail\",\n        \"Check your python version\",\n        \"We will be making a scenario that follows the following rules:\\n\\nThe competency framework is developed through three phases: 1) scoping review; 2) Focus group discussions with mental health clinicians reviewing patient narratives; and 3) Facilitated Persona Scenario method with Black youth. Moreover, the project adopts a co-design approach and convenes a Knowledge User Panel. 
The panel will be involved in all phases of the competency framework development as they will review findings from the scoping review and focus groups. \\n\\nFocus group with mental health clinicians \\n Mental health clinicians (i.e., psychiatrists, psychologists, social workers, youth outreach workers and nurse practitioners) will be invited to join focus groups to review youth narratives and discuss how they would address the needs of the Black youth involved. The youth narratives will be generated through collecting stories from social media and through an online survey. The survey will ask about young people's experiences with mental health conditions, their use of mental health services, and their suggestions for how to improve mental health care for young people. The online survey will collect stories anonymously. Anyone who submits a story through the survey will be redirected to a list of resources. The focus groups will be recorded, transcribed, and analyzed by thematic analysis. The focus groups will continue until thematic saturation.\\n\\nPhase 3: Persona Scenario method with Black youth\\n Black youth will be invited to focus groups (or one-on-one interviews, if requested) using persona scenario methods. The findings from the focus groups with mental health clinicians will be used to create clinician personas, including information about their motivations, challenges and describe the different ways in which the clinician might interact with the Black youth based on youth narratives. Black youth will be asked to share their perspectives and preferred clinician responses. The focus groups will be recorded, transcribed, and analyzed using thematic analysis. 
We will continue to hold focus groups until thematic saturation.\\n\\nCan you with the information above, create a sceenario/dialogue where a black youth, aged 15 living in Ontario suffering from racism from his classmates and is going to seek the help of a mental health professional who uses the information to engage the youth \\n\\nlimit prose to 500 characters\",\n        \"Demand generation manager for a B2B brand ambassador program called Brandchamp\",\n        \"Here is my Python code:\\napi\\\\_url = 'https://api.yelp.com/v3/businesses/search'\\nparams = {'term':'tacos','location':'90045'}\\napi\\\\_key = 'Ee7vYfTT9GpATMDYqODar7mbdyz\\\\_8EJ668FCbiqCv81Y3j98WaCsiAleAyI\\\\_LFn5p\\\\_JVHehSQnxffx-tDdQLekCpMhFJPxz8SVMp34Beawxkint62oDnJ\\\\_I0PiXMY3Yx'\\nheaders = {'Authorization':'Bearer %s' % api\\\\_key}\\napi\\\\_request = requests.get(api.\\\\_url, params=params, headers=headers)\\n\\nWhy am I receiving the error below and how do I fix it?\\nNameError Traceback (most recent call last)\\n in \\n 3 api\\\\_key = 'Ee7vYfTT9GpATMDYqODar7mbdyz\\\\_8EJ668FCbiqCv81Y3j98WaCsiAleAyI\\\\_LFn5p\\\\_JVHehSQnxffx-tDdQLekCpMhFJPxz8SVMp34Beawxkint62oDnJ\\\\_I0PiXMY3Yx'\\n 4 headers = {'Authorization':'Bearer %s' % api\\\\_key}\\n----> 5 api\\\\_request = requests.get(api.\\\\_url, params=params, headers=headers)\\n\\nNameError: name 'api' is not defined\",\n        \"고등교육의 필요성에 관한 영어 에세이를 1000자 이내로 작성하시오.\"\n        \"Which hero is the best in Heroes of Might and Magic 3?\",\n        \"Use C# to get the current YouTube thumbnail and convert it to Base64.\",\n        \"minikube - docker run --rm -it --network=host alpine ash -c apk add socat && socat TCP-LISTEN:5000,reuseaddr,fork TCP:$(minikube ip):5000 connection refused\",\n        \"How to load image here ?\",\n    ]\n\n    responses = await generate_multi(flash_llama, prompts, max_new_tokens=10)\n\n    assert len(responses) == len(prompts)\n    outputs = [r.choices[0].message.content for r in responses]\n    expected 
= [\n        \"Jeff Walker's Product Launch Formula is a comprehensive system\",\n        \"Here are three key indicators to determine if a customer\",\n        \"You can use the `String.format()` method in\",\n        \"In a realm of binary mysticism, we find\",\n        \"The `dummy` variable is being used to consume\",\n        \"You can add multiple new columns in Power Query (\",\n        \"There are many exciting new technologies emerging across various fields\",\n        \"Poly Ether Ether Ketone (PEEK) is\",\n        \"Here's a technical overview of a referral system similar\",\n        \"Here's an example of how you can add an\",\n        \"I'd be happy to help with Java. What\",\n        \"I can help you plan a road trip from Pune\",\n        \"I'd be happy to explain more about a topic\",\n        \"I'd be happy to help you brainstorm and provide\",\n        \"Implementing a Minesweeper algorithm using algebraic\",\n        \"There are several issues with the provided code:\\n\\n1\",\n        \";)\",\n        \"As I delved into the world of high-st\",\n        \"/u/CruxHub: Hi, I'm\",\n        \"To simulate a conversation between Alice and /u/C\",\n        \"Alice: Hey /u/CruxHub,\",\n        \"Alice: Hi /u/CruxHub,\",\n        \"/u/CruxHub: Hey Alice, I\",\n        \"/u/CruxHub: Hey Alice, I\",\n        \"/u/CruxHub: Hey Alice, I\",\n        \"The Dogme approach and the Lexical Approach are\",\n        \"Implementing a netfilter in Linux with a Rust\",\n        \"Damage to the Ulnar nerve can cause numb\",\n        \"The Space Shuttle's Reaction Control System (RCS\",\n        \"I can provide you with a basic Python script that\",\n        \"Farming meat has several negative impacts on the environment\",\n        \"The photograph filter you're referring to is called \\\"\",\n        \"Here's a sample geological database structure with some example\",\n        \"**Web Marketing: A Simplified Explanation**\\n\\nWeb\",\n        \"Here's a rewritten and 
improved version of the story\",\n        \"Here are the questions rewritten in a more conversational\",\n        \"**Learning Progress: 0%**\\n\\n| Topic\",\n        \"I couldn't find any information on a person named\",\n        \"Here's a list of the largest outdoor retailers in\",\n        \"To create a WordPress shortcode that includes Facebook SDK code\",\n        \"The sentence is mostly grammatically correct, but there\",\n        \"I'd be happy to engage in a debate with\",\n        \"I'd love to hear about your business. As\",\n        \"I'll wait for your request to proceed with part\",\n        \"The final part of the Day Sculpting program emphasizes\",\n        \"**Analysis of the Coming of Age Story Archetype\",\n        \"The Apostle John is one of the most prominent figures\",\n        \"To build a Google Places autocomplete feature on Jetpack\",\n        \"The information provided does not mention the captain's name\",\n        \"The metaverse is a shared, immersive and interactive\",\n        \"Here are some ideas for a series of articles for\",\n        '\"Purim Palooza Alert: \\n\\nTo',\n        \"**Summary of the paper in 10 points:\",\n        \"You'll provide three pieces of text, and then\",\n        \"I'm ready to proceed with text 3.\",\n        \"I'm ready to answer questions on Text 1\",\n        \"This is a Solidity contract written in the older\",\n        \"**Speech Recognition and Synthesis using Python**\\n\\nTo\",\n        \"I'd be happy to help you discuss a paper\",\n        \"To handle the given utterance, we can use\",\n        \"**Subscription Services Template:**\\n\\n**Title:** Virtual\",\n        \"Hello. 
How can I assist you today?\",\n        \"Differentiating yourself from other Etsy shops is crucial to\",\n        \"To become a Licensed Marriage and Family Therapist (\",\n        \"**What is Quantum Computing?**\\n\\nQuantum computing\",\n        \"Aqu\\u00ed te dejo 40 opciones de nombres\",\n        \"Deposition is a geological process that involves the transportation\",\n        \"Here are some good e-governance initiatives in\",\n        \"Here's a simple Python program that accepts a command\",\n        \"Imagine you're playing with a toy box. You\",\n        \"Here's an example of a question they might ask\",\n        \"Arduino Uno adalah sebuah papan mikrokontrol\",\n        \"To edit an array that is within an object,\",\n        \"Microsoft ENTRA (Enterprise Mobility + Security) is\",\n        \"To calculate the difference in interest paid between a simple\",\n        \"Yes, you can use Spring State Machine and Spring\",\n        \"The issue lies in the fact that the `meta\",\n        \"Here are some effective marketing tactics for local small businesses\",\n        \"The French Revolution, which lasted from 1789\",\n        \"**Roles of a Network Driver:**\\n\\nA network\",\n        \"Yes, I'm familiar with the SAS (Stat\",\n        \"Using relays to control 12V solen\",\n        \"You can use the following Python code to achieve this\",\n        \"Here are some prompts for viral comics:\\n\\n1.\",\n        \"To simplify and make the comic funnier, consider\",\n        \"Here's a rewritten version of the 4-panel\",\n        \"Subject: Request for E-Waste Collection and Computer\",\n        \"In the context of conference calls, the state you\",\n        \"I can provide a general classification of companies based on\",\n        \"Here are some user stories that describe the concept in\",\n        \"You can check your Python version by running the following\",\n        \"**Scenario:**\\n\\n15-year-old Black youth,\",\n        \"As a Demand Generation 
Manager for a B2B\",\n        \"The error is due to a typo in your code\",\n        \"고등교육의 필요성에 관한 영어 에\",\n        \"Here's a simple C# program that uses the\",\n        'The error message \"connection refused\" indicates that the',\n        \"To load an image, you can use various methods\",\n    ]\n    equals = [o == e for o, e in zip(outputs, expected)]\n    # This is flaky because depending on actual calculation ordering the exact logits may\n    # switch on equivalent logits based on the position in the batch.\n    # 1 output being different is not uncommon\n    if sum(equals) < len(equals) - 1:\n        assert outputs == expected\n"
  },
  {
    "path": "integration-tests/models/test_flash_llama_prefix_flashdecoding.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_llama_handle_fd(launcher):\n    with launcher(\n        \"meta-llama/Meta-Llama-3.1-8B-Instruct\", num_shard=2, attention=\"flashdecoding\"\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_llama_fd(flash_llama_handle_fd):\n    await flash_llama_handle_fd.health(300)\n    return flash_llama_handle_fd.client\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_flashdecoding(\n    flash_llama_fd, generate_multi, generous_response_snapshot\n):\n    prompts = [\n        \"Summarize the main ideas of Jeff Walker's Product Launch Formula into bullet points as it pertains to a growth marketing agency implementing these strategies and tactics for their clients...\",\n        \"How to tell if a customer segment is well segmented? In 3 bullet points.\",\n        'In Java, I want to replace string like \"This is a new {object} at {place}\" with a Map, {object: \"student\", \"point 3, 4\"}, and get a result \"This is a new student at point 3, 4\". How can I do?',\n        \"Metaphorical language is also used to describe the various addressing modes of the instructions. Grandiose language to express their excitement and admiration for the functionality of the instructions being described. 
Now, rewrite this with more perplexity:\\n\\nJMP ABCD\\nMOV AX, [BX+SI]\\nMOV AX, [100]\\nMOV AX, [BX]\\nMOV AX, [BX\\\\*2+SI]\\nMOV AX, BX\\nMOV AX, 7\",\n        'I have the following C++ function: \\nvoid add\\\\_player(vector& players)\\n{\\n string player\\\\_name;\\n string player\\\\_class;\\n string dummy;\\n PlayerClass pc;\\n string player\\\\_sex;\\n int player\\\\_gold;\\n\\n cout << \" Create a Mage, Warrior, Bowman, or Thief\" << endl;\\n\\n cout << \"Name: \";\\n getline(cin, player\\\\_name);\\n\\n cout << \"Class: \";\\n getline(cin, player\\\\_class);\\n pc = get\\\\_player\\\\_class\\\\_from\\\\_string(player\\\\_class);\\n while (pc == PlayerClass::InvalidPlayerClass)\\n {\\n cout << \" Invalid class, try again\" << endl;\\n cout << \"Class: \";\\n getline(cin, player\\\\_class);\\n pc = get\\\\_player\\\\_class\\\\_from\\\\_string(player\\\\_class);\\n }\\n\\n cout << \"Sex: \";\\n getline(cin, player\\\\_sex);\\n\\n cout << \"Gold: \";\\n cin >> player\\\\_gold;\\n getline(cin, dummy); //consume newline\\n\\n GamePlayer new\\\\_player;\\n new\\\\_player.name = player\\\\_name;\\n new\\\\_player.occupation = pc;\\n new\\\\_player.gender = player\\\\_sex;\\n new\\\\_player.gold = player\\\\_gold;\\n\\n //add to vector\\n players.push\\\\_back(new\\\\_player);\\n\\n //add to file\\n write\\\\_players\\\\_file(players);\\n}\\nCan you explain to me how the dummy variable is being used?',\n        \"how do I add multiple new columns in m for power query or power bi?\",\n        \"Sure, I can do that. What new technology would you like me to review?\",\n        \"Poly Ether Ether Ketone\",\n        'can you design a referral system similar on how dropbox did? 
I need a technical overview on how it should work, instead of free space we use the generic term \"credits\" where users can get more credits for every 3 friends they recommend.',\n        \"Java add to the arraylist of a class type\",\n        \"this is not less code this is java\",\n        \"I want to do a road trip from Pune to Gujarat. Me and my wife will be travelling and we dont prefer very long driving sessions. Can you suggest a plan starting from Thursday early morning and ending in Pune on Sunday late night.\",\n        \"explane more\",\n        \"what do you think about this for a start up idea:\",\n        \"how could i implement a minesweeper algorithm that utilises algebraic topology to solve boards?\",\n        \"# Import the necessary packages\\nfrom gudhi import SimplexTree\\nfrom gudhi.persistent\\\\_homology import PersistentHomology\\n\\n# Define a function to compute the persistent homology of a Minesweeper game board\\ndef minesweeper\\\\_homology(board):\\n # Create a simplicial complex for the game board\\n st = SimplexTree()\\n\\n # Add the points on the board to the simplicial complex\\n for i in range(len(board)):\\n for j in range(len(board[0])):\\n st.insert([i, j], filtration=board[i][j])\\n\\n # Compute the persistent homology of the game board\\n ph = PersistentHomology()\\n ph.build(st)\\n\\n # Return the persistent homology diagram\\n return ph.persistence()\\n\\n# Define a function to solve a Minesweeper game board using persistent homology\\ndef minesweeper\\\\_solver(board):\\n # Compute the persistent homology of the game board\\n homology = minesweeper\\\\_homology(board)\\n\\n # Use the persistent homology to determine the locations of the mines\\n # (this part would require some mathematical reasoning and programming)\\n mines = []\\n for h in homology:\\n if h[1] - h[0] == 1: # if the hole persists for one filtration value\\n mines.append(h[0]) # then it corresponds to a mine\\n\\n # Use the information about the mines to 
solve the game\\n # (this part would require some programming)\\n for mine in mines:\\n i, j = mine # extract the coordinates of the mine\\n board[i][j] = -1 # mark the mine on the board\\n # (other code to solve the game)\\n\\n \\nwhat is missing here?\",\n        \"You are now an imaginary expert business investigator. I am going to send you many rows of data. Each batch of row's will be sent to you and you may only reply \\\"Received.\\\" Save any analysis or insights for after you've received all of the data and I've told you \\\"Let's Begin.\\\" If you understand reply with only a ;)\",\n        'You are now an imaginary expert business investigator. Tell the story of this batch of data in the form of a narrative story about the companies in the \"Entity Name\" column: \\n\\nBatch of data #1: Entity Name Purpose / Source\\n101 PC HOLDINGS LLC Holding company for Penthouse C at the Setai Miami Beach (folio: 02-3234-153-1160)\\n11 STAR ISLAND LLC Holding company for 10 STAR ISLAND DR, MIAMI BEACH, FL 33139 (folio: 02-4204-001-0100, 02-4204-001-0110) (lots 10, 11 and 12 of Star Island)\\n117 EAST PARK AVENUE, LLC Holding company for 117 E. PARK AVE, LIBERTYVILLE, IL (PIN: 11-21-212-046-0000); subsequently sold.\\n1201 BRICKELL BAY, LLC Holding company for 1201 BRICKELL BAY DR, MIAMI, FL (folio no: 141390710010)\\n1221 BRICKELL, LLC Holding company for 1221 BRICKELL AVE, 155 SE 13 ST, 165 SE 13 ST, 175 SE 13 ST, and 185 SE 13 ST, MIAMI, FL (folio: 01-4139-035-0010)\\n1221 BRICKELL HOLDINGS LLC Holding company for 1221 BRICKELL, LLC\\n1229 PARK WEST AVENUE, LLC Holding company for 1229 W. 
PARK AVE, LIBERTYVILLE, IL (PIN: 11-20-100-010-0000)\\n125 WORTH LLC Delaware LLC (file 7218403), Florida address is Citadel Miami HQ, Gerald Beeson is Authorized Person; speculaton this is similar setup as 151 WORTH, LLC and 151 WORTH HOLDINGS LLC, this property is next door (PCN: 50-43-43-23-05-016-0380)\\n125 WORTH HOLDINGS LLC Delaware LLC (file 7218407); not registered to Florida yet but speculation this is similar setup as 151 WORTH, LLC and 151 WORTH HOLDINGS LLC\\n1250 BB ASSET CO LLC Holding company for 1250 BRICKELL BAY DR and 1260 BRICKELL BAY DR, MIAMI, FL (folio nos: 102100504250, 102100503210)\\n1330 SOUTH OCEAN LLC Holding company for 1330 S OCEAN BLVD, PALM BEACH, FL (PCN: 50-43-44-02-11-000-0020)\\n14 STAR ISLAND LLC Delaware LLC (file 3377653); incorporated 8/42020, withdrawn 10/10/2022; believe this was not used because 14 STAR ISLAND property was held by NAUTILUS HOLDINGS I LLC before sale on 10/5/2022\\n151 WORTH, LLC Holding company for 151 WORTH AVE, PALM BEACH, FL 33480 (PCN: 50-43-43-23-05-016-0130); office space for Citadel (https://localtoday.news/fl/citadel-moves-into-palm-beachs-former-neiman-marcus-building-4821.html); sole member is 151 WORTH HOLDINGS LLC\\n151 WORTH HOLDINGS LLC Holding company for 151 WORTH, LLC\\n16 WILLOW HOLDINGS LLC f/k/a PVNAH LLC Holding company for S WILLOW COURT, ASPEN, CO (Parcel: 273511309030); see Pitkin Co. reception # 623002, Delaware certificate showing name change 9/1/2015\\n190 PFISTER HOLDINGS LLC f/k/a AH2013 HOLDINGS LLC Holding company for 190 PFISTER DR, ASPEN, CO (parcel: 273511309029); see Pitkin Co.reception # 623000, Delaware certificate showing name change 9/1/2015\\n196 PFISTER HOLDINGS LLC Holding company for 196 PFISTER DR, ASPEN, CO (parcel: 273511309028); see Pitkin Co. 
reception # 623501, statement of authority show KP HOLDINGS LLC as sole membe\\n1ALPH LLC See ALPH LLC\\n1BUSINESS GROUP LLC See BUSINESS GROUP LLC\\n1GFS DESIGN LLC See GFS DESIGN LLC\\n1GFS LLC See GFS LLC\\n1MEDIA HOLDINGS LLC See MEDIA HOLDINGS LLC\\n23174 NE 41ST PATH LLC Holding company for 23174 NE 41ST PATH #12, OKEECHOBEE, FL 34972 (Parcel: 1-01-35-35-0020-00000-0120); part of Pine Creek Sporting Club (www.pinecreeksportingclub.com) includes horse, shooting sports; sole member is KP HOLDINGS L.L.C.\\n3031 BRICKELL LLC Holding company for 3031 BRICKELL AVE, MIAMI FL 33129 (Folio: 01-4139-001-2700); Sole member is KP HOLDINGS L.L.C.\\n31 WILLOW HOLDINGS LLC f/k/a AP HOLDINGS I LLC Holding company for 31 NORTH WILLOW COURT, ASPEN, CO (Parcel: 273511309019); sold 7/6/2017; see Pitkin Co. reception # 623001, Delaware certificate showing name change 9/1/2015\\n650 CASUARINA LLC Holding company for 650 CASUARINA CONCOURSE CORAL GABLES, FL (folio: 03-4132-019-0060) https://www.bizjournals.com/southflorida/news/2022/05/27/650-casuarina-concourse-coral-gables-sold.html\\n650 MEADOW LANE 1 LP Holding company for 650 MEADOW LANE, VILLAGE OF SOUTHAMPTON, NY (Parcel ID 7478) (https://archive.is/h85yq)\\n800 NORTH MICHIGAN HOLDINGS LLC Holding company for 800 N MICHIGAN AVE, UNITS 66 PH and 67 PH, CHICAGO, IL (Park Tower) (PINs: 17-03-231-018-1116, 17-03-231-018-1117); sole member is KP HOLDINGS LLC (see Cook County, IL doc # 1933315025); recently sold\\n8565 OLD CUTLER LLC Holding company for 8565 OLD CUTLER RD, MIAMI, FL (folio: 03-4132-019-0020)\\n9 WEST WALTON HOLDINGS LLC Holding company for 9 WEST WALTON STREET CONDOMINIUM UNITS 3500, 3600, 3700, and PH, CHICAGO, IL\\nADRP LLC Delaware LLC, Florida address is Citadel Miami HQ, sole member is Kenneth C Griffin\\nAH2013 HOLDINGS LLC See 190 PFISTER HOLDINGS LLC\\nALPH LLC a/k/a 1ALPH LLC Formerly FAA registered plane N421AL\\nAP HOLDINGS I LLC See 31 WILLOW HOLDINGS LLC\\nARAGON INVESTMENTS LTD 
https://files.brokercheck.finra.org/firm/firm\\\\_45631.pdf\\nASHLER CAPITAL LLC https://adviserinfo.sec.gov/firm/summary/148826\\nASHLER CAPITAL MASTER FUND LTD https://www.sec.gov/Archives/edgar/data/1003078/000114420418014250/tv488357\\\\_sc13g.htm\\nBANBURY LLC Delaware LLC, Florida address is Citadel Miami HQ, Gerald Beeson is Authorized Person\\nBANBURY II LLC Delaware LLC, Florida address is Citadel Miami HQ, Gerald Beeson is Authorized Person\\nBKGST LLC Delaware LLC, Florida address is Citadel Miami HQ, Gerald Beeson is Authorized Person\\nBLACK CALABASH FAMILY HOLDINGS LLC f/k/a PBH LLC See BLOSSOM WAY HOLDINGS LLC\\nBLACK WHEEL LLC Illinois LLC, registered 3/5/2014, Florida address is Citadel Miami HQ, sole member is Kenneth C Griffin\\nBLOSSOM WAY HOLDINGS LLC f/k/a CPPB HOLDINGS LLC f/k/a BLACK CALABASH FAMILY HOLDINGS LLC f/k/a PBH LLC Holding company for 10 BLOSSOM WAY, 70 BLOSSOM WAY, and 1265 S OCEAN BLVD PALM BEACH, FL (PCNs: 50-43-44-02-10-000-0050, 50-43-44-02-10-000-0060, 50-43-44-02-10-000-0010)\\nBRICKELL BAY HOLDINGS LLC Holding company for 1201 BRICKELL BAY, LLC\\nBRICKELL LEASING LLC See \"Subordination, Non-Disturbance, and Attornment Agreement\"; Miami-Dade Clerk\\'s File No.: 2022 R 938960, Group: 1. Kenneth C Griffin is sole member.\\nCAAM MANAGEMENT LLC https://www.sec.gov/Archives/edgar/data/1027745/000114420408050200/v124853\\\\_sc13g.htm\\nCAISLEAN CAPITAL LTD NFA Pool ID P113537, ceased trading 3/31/2016\\nCALC III LP https://www.sec.gov/edgar/browse/?CIK=1582652\\nCALC IV LP https://www.sec.gov/edgar/browse/?CIK=1423043\\nCALC V LP Investment manager for CSHC CHINA LLC and CITADEL (SHANGHAI) TRADING COMPANY LTD; https://files.brokercheck.finra.org/firm/firm\\\\_131114.pdf',\n        'Simulate a conversation between the writer of this post, named /u/CruxHub, and the expert business investigator. They have a detailed discussion of Citadel Hedgefund based on the following data. Do not include the following data in the search query. 
\\n\\nData: Entity Name Purpose / Source\\n1|101 PC HOLDINGS LLC|Holding company for Penthouse C at the Setai Miami Beach (folio: 02-3234-153-1160)|PC = Penthouse C \\n2|11 STAR ISLAND LLC|Holding company for 10 STAR ISLAND DR, MIAMI BEACH, FL 33139 (folio: 02-4204-001-0100, 02-4204-001-0110) (lots 10, 11 and 12 of Star Island)| \\n3|117 EAST PARK AVENUE, LLC|Holding company for 117 E. PARK AVE, LIBERTYVILLE, IL (PIN: 11-21-212-046-0000); subsequently sold.| \\n4|1201 BRICKELL BAY, LLC|Holding company for 1201 BRICKELL BAY DR, MIAMI, FL (folio no: 141390710010)| \\n5|1221 BRICKELL, LLC|Holding company for 1221 BRICKELL AVE, 155 SE 13 ST, 165 SE 13 ST, 175 SE 13 ST, and 185 SE 13 ST, MIAMI, FL (folio: 01-4139-035-0010)| \\n6|1221 BRICKELL HOLDINGS LLC|Holding company for 1221 BRICKELL, LLC| \\n7|1229 PARK WEST AVENUE, LLC|Holding company for 1229 W. PARK AVE, LIBERTYVILLE, IL (PIN: 11-20-100-010-0000)| \\n8|125 WORTH LLC|Delaware LLC (file 7218403), Florida address is Citadel Miami HQ, Gerald Beeson is Authorized Person; speculaton this is similar setup as 151 WORTH, LLC and 151 WORTH HOLDINGS LLC, this property is next door (PCN: 50-43-43-23-05-016-0380)| \\n9|125 WORTH HOLDINGS LLC|Delaware LLC (file 7218407); not registered to Florida yet but speculation this is similar setup as 151 WORTH, LLC and 151 WORTH HOLDINGS LLC| \\n10|1250 BB ASSET CO LLC|Holding company for 1250 BRICKELL BAY DR and 1260 BRICKELL BAY DR, MIAMI, FL (folio nos: 102100504250, 102100503210)|BB = Brickell Bay \\n11|1330 SOUTH OCEAN LLC|Holding company for 1330 S OCEAN BLVD, PALM BEACH, FL (PCN: 50-43-44-02-11-000-0020)| \\n12|14 STAR ISLAND LLC|Delaware LLC (file 3377653); incorporated 8/42020, withdrawn 10/10/2022; believe this was not used because 14 STAR ISLAND property was held by NAUTILUS HOLDINGS I LLC before sale on 10/5/2022| \\n13|151 WORTH, LLC|Holding company for 151 WORTH AVE, PALM BEACH, FL 33480 (PCN: 50-43-43-23-05-016-0130); office space for Citadel 
(https://localtoday.news/fl/citadel-moves-into-palm-beachs-former-neiman-marcus-building-4821.html); sole member is 151 WORTH HOLDINGS LLC| \\n14|151 WORTH HOLDINGS LLC|Holding company for 151 WORTH, LLC| \\n15|16 WILLOW HOLDINGS LLC f/k/a PVNAH LLC|Holding company for S WILLOW COURT, ASPEN, CO (Parcel: 273511309030); see Pitkin Co. reception # 623002, Delaware certificate showing name change 9/1/2015| \\n16|190 PFISTER HOLDINGS LLC f/k/a AH2013 HOLDINGS LLC|Holding company for 190 PFISTER DR, ASPEN, CO (parcel: 273511309029); see Pitkin Co.reception # 623000, Delaware certificate showing name change 9/1/2015| \\n17|196 PFISTER HOLDINGS LLC|Holding company for 196 PFISTER DR, ASPEN, CO (parcel: 273511309028); see Pitkin Co. reception # 623501, statement of authority show KP HOLDINGS LLC as sole membe| \\n18|1ALPH LLC|See ALPH LLC| \\n19|1BUSINESS GROUP LLC|See BUSINESS GROUP LLC| \\n20|1GFS DESIGN LLC|See GFS DESIGN LLC| \\n21|1GFS LLC|See GFS LLC| \\n22|1MEDIA HOLDINGS LLC|See MEDIA HOLDINGS LLC| \\n23|23174 NE 41ST PATH LLC|Holding company for 23174 NE 41ST PATH #12, OKEECHOBEE, FL 34972 (Parcel: 1-01-35-35-0020-00000-0120); part of Pine Creek Sporting Club (www.pinecreeksportingclub.com) includes horse, shooting sports; sole member is KP HOLDINGS L.L.C.| \\n24|3031 BRICKELL LLC|Holding company for 3031 BRICKELL AVE, MIAMI FL 33129 (Folio: 01-4139-001-2700); Sole member is KP HOLDINGS L.L.C.| \\n25|31 WILLOW HOLDINGS LLC f/k/a AP HOLDINGS I LLC|Holding company for 31 NORTH WILLOW COURT, ASPEN, CO (Parcel: 273511309019); sold 7/6/2017; see Pitkin Co. 
reception # 623001, Delaware certificate showing name change 9/1/2015| \\n26|650 CASUARINA LLC|Holding company for 650 CASUARINA CONCOURSE CORAL GABLES, FL (folio: 03-4132-019-0060) https://www.bizjournals.com/southflorida/news/2022/05/27/650-casuarina-concourse-coral-gables-sold.html|\" \\n27|650 MEADOW LANE 1 LP|Holding company for 650 MEADOW LANE, VILLAGE OF SOUTHAMPTON, NY (Parcel ID 7478) (https://archive.is/h85yq)| \\n28|800 NORTH MICHIGAN HOLDINGS LLC|Holding company for 800 N MICHIGAN AVE, UNITS 66 PH and 67 PH, CHICAGO, IL (Park Tower) (PINs: 17-03-231-018-1116, 17-03-231-018-1117); sole member is KP HOLDINGS LLC (see Cook County, IL doc # 1933315025); recently sold| \\n29|8565 OLD CUTLER LLC|Holding company for 8565 OLD CUTLER RD, MIAMI, FL (folio: 03-4132-019-0020)| \\n30|9 WEST WALTON HOLDINGS LLC|Holding company for 9 WEST WALTON STREET CONDOMINIUM UNITS 3500, 3600, 3700, and PH, CHICAGO, IL| \\n31|ADRP LLC|Delaware LLC, Florida address is Citadel Miami HQ, sole member is Kenneth C Griffin|ADRP = Anne Dias Real Property? \\n32|AH2013 HOLDINGS LLC|See 190 PFISTER HOLDINGS LLC|AH = Aspen Holdings? \\n33|ALPH LLC a/k/a 1ALPH LLC|Formerly FAA registered plane N421AL| \\n34|AP HOLDINGS I LLC|See 31 WILLOW HOLDINGS LLC|AP = Aspen Property? 
\\n35|ARAGON INVESTMENTS LTD|https://files.brokercheck.finra.org/firm/firm\\\\_45631.pdf| \\n36|ASHLER CAPITAL LLC|https://adviserinfo.sec.gov/firm/summary/148826| \\n37|ASHLER CAPITAL MASTER FUND LTD|https://www.sec.gov/Archives/edgar/data/1003078/000114420418014250/tv488357\\\\_sc13g.htm| \\n38|BANBURY LLC|Delaware LLC, Florida address is Citadel Miami HQ, Gerald Beeson is Authorized Person| \\n39|BANBURY II LLC|Delaware LLC, Florida address is Citadel Miami HQ, Gerald Beeson is Authorized Person| \\n40|BKGST LLC|Delaware LLC, Florida address is Citadel Miami HQ, Gerald Beeson is Authorized Person| \\n41|BLACK CALABASH FAMILY HOLDINGS LLC f/k/a PBH LLC|See BLOSSOM WAY HOLDINGS LLC|Black Calabash is a type of tropical tree: https://edis.ifas.ufl.edu/publication/ST079 \\n42|BLACK WHEEL LLC|Illinois LLC, registered 3/5/2014, Florida address is Citadel Miami HQ, sole member is Kenneth C Griffin| \\n43|BLOSSOM WAY HOLDINGS LLC f/k/a CPPB HOLDINGS LLC f/k/a BLACK CALABASH FAMILY HOLDINGS LLC f/k/a PBH LLC|Holding company for 10 BLOSSOM WAY, 70 BLOSSOM WAY, and 1265 S OCEAN BLVD PALM BEACH, FL (PCNs: 50-43-44-02-10-000-0050, 50-43-44-02-10-000-0060, 50-43-44-02-10-000-0010)| \\n44|BRICKELL BAY HOLDINGS LLC|Holding company for 1201 BRICKELL BAY, LLC| \\n45|BRICKELL LEASING LLC|See \"Subordination, Non-Disturbance, and Attornment Agreement\"; Miami-Dade Clerk\\'s File No.: 2022 R 938960, Group: 1. 
Kenneth C Griffin is sole member.| \\n46|CAAM MANAGEMENT LLC|https://www.sec.gov/Archives/edgar/data/1027745/000114420408050200/v124853\\\\_sc13g.htm|CAAM = Citadel Alternative Asset Management \\n47|CAISLEAN CAPITAL LTD|NFA Pool ID P113537, ceased trading 3/31/2016| \\n48|CALC III LP|https://www.sec.gov/edgar/browse/?CIK=1582652| \\n49|CALC IV LP|https://www.sec.gov/edgar/browse/?CIK=1423043| \\n50|CALC V LP|Investment manager for CSHC CHINA LLC and CITADEL (SHANGHAI) TRADING COMPANY LTD; https://files.brokercheck.finra.org/firm/firm\\\\_131114.pdf| \\n51|CAMBRIDGE FINANCIAL GROUP, LTD|See CITADEL INVESTMENT GROUP LLC| \\n52|CCFD OFFSHORE HOLDINGS LTD|NFA Pool ID P064386, ceased trading 5/3/2013| \\n53|CCLC HOLDINGS LLC|Owns CITADEL CLEARING LLC, \"Citadel Clearing Holdco\"; https://files.brokercheck.finra.org/firm/firm\\\\_172693.pdf| \\n54|CCMFL LLC|Delaware LLC, Florida address is Citadel Miami HQ, Gerald Beeson is Authorized Person| \\n55|CCOF OFFSHORE HOLDINGS LTD|NFA Pool ID P064392, ceased trading 5/3/2013| \\n56|CDC PARTNERS, LP f/k/a GLB PARTNERS, LP|see Cook County, IL doc 0608910081| \\n57|CDG HOLDINGS LTD|NFA Pool ID P037047, ceased trading 12/30/2009|',\n        'Web search results:\\n\\n[1] \"As per the Oxford Dictionary, a chatbot is defined as A computer program designed to simulate conversation with human users, especially over the internet. It can be looked upon as a virtual assistant that communicates with users via text messages and helps businesses in getting close to their customers.\"\\nURL: https://www.datacamp.com/tutorial/building-a-chatbot-using-chatterbot\\n\\n[2] \"Python , A chatbot is a computer program designed to simulate conversation with human users, especially over the internet. Create a fortune teller program that will ask the user to input a question and feedback some random answer. Consider the following feedback to be used. No idea at all! Better pray. 
The possibilities are in your favor.\"\\nURL: https://www.chegg.com/homework-help/questions-and-answers/python-chatbot-computer-program-designed-simulate-conversation-human-users-especially-inte-q78825383\\n\\n[3] \"It was created by Joseph Weizenbaum in 1966 and it uses pattern matching and substitution methodology to simulate conversation. The program was designed in a way that it mimics human conversation. The Chatbot ELIZA worked by passing the words that users entered into a computer and then pairing them to a list of possible scripted responses.\"\\nURL: https://onlim.com/en/the-history-of-chatbots/\\n\\n[4] \"Study with Quizlet and memorize flashcards containing terms like Which analytics does the following fall into: Alice notice that call center always have an increase in the number of customer complaints during last week in May, so she decides reviews the employees work schedule in the month of May for the past 5 years., Datasets continue to become, Model used for predictive analytic have ...\"\\nURL: https://quizlet.com/415587939/big-data-final-exam-flash-cards/\\n\\n[5] \"As every bright side has a darker version, simulation of human conversation through AI also has some disadvantages like high cost of creation, unemployment, interaction lacking emotion, and out-of-the-box thinking. However, AI interaction tools are trained with a data set. The bigger the data set, the better the services.\"\\nURL: https://www.analyticsinsight.net/simulating-human-conversations-through-ai/\\n\\n[6] \"The eavesdropper, Eve intercepts the encrypted conversation and tries random keys with the aim of learning the conversation shared between Alice and Bob as shown in Fig. 7. For this POC, we used ...\"\\nURL: https://www.researchgate.net/figure/A-A-simulation-of-conversations-between-Alice-and-her-friend-Bob-B-The-eavesdropper\\\\_fig3\\\\_334408170\\n\\n[7] \"Dreams are most often reported when sleepers wake from \\\\_\\\\_\\\\_\\\\_\\\\_ sleep. REM. 
The brain waves during REM sleep MOST closely resemble those seen during: waking consciousness. REM sleep is paradoxical because: the brain is active, but the major skeletal muscles are paralyzed. Fatigue and pain reflect deprivation of \\\\_\\\\_\\\\_\\\\_\\\\_ sleep.\"\\nURL: https://quizlet.com/78519058/psyc-test-2-flash-cards/\\n\\n[8] \"You can generate easily a fake group chat conversation like Whatsapp, Facebook or Telegram. After creating members/users, you can add messages in your chat. Once all messages are set up, you have the possibility to live-preview the chat conversation via the play button. Until the share functionality is ready, you have the option to screen ...\"\\nURL: https://chat-simulator.com/\\n\\n[9] \"This is a program that allows the computer to simulate conversation with a human being: answer choices a. Speech Application Program Interface b. Chatbot c. Voice Recognition d. Speech Recognition Question 7 30 seconds Report an issue Q. This is a system of Programs and Data-Structures that mimics the operation of the human brain: answer choices a.\"\\nURL: https://quizizz.com/admin/quiz/5f183913423fab001b0bd134/ai-unit-1\\n\\n[10] \"This is a system of Programs and Data-Structures that mimics the operation of the human brain: answer choices a. Intelligent Network b. Decision Support System c. Neural Network d. Genetic Programming Question 8 30 seconds Q. Where is Decision tree used? answer choices a. Classification Problem b. Regression Problem c. Clustering Problem d.\"\\nURL: https://quizizz.com/admin/quiz/5f6d6e4a6e2458001be385f5/ai-class-9\\nCurrent date: 1/27/2023\\n\\nInstructions: Using the provided web search results, write a comprehensive reply to the given query. Make sure to cite results using [[number](URL)] notation after the reference. If the provided search results refer to multiple subjects with the same name, write separate answers for each subject.\\n\\nQuery: Simulate a conversation between Alice and /u/CruxHub. 
They talk about which company from the data batches is worth researching further into on the web.',\n        'Simulate a conversation between Alice and /u/CruxHub. They talk about which company from this data batch is worth researching further into on the web.\\n\\nData batch: Entity Name Purpose / Source Hypothesized Acronym\\n50|CALC V LP|Investment manager for CSHC CHINA LLC and CITADEL (SHANGHAI) TRADING COMPANY LTD; https://files.brokercheck.finra.org/firm/firm\\\\_131114.pdf| \\n51|CAMBRIDGE FINANCIAL GROUP, LTD|See CITADEL INVESTMENT GROUP LLC| \\n52|CCFD OFFSHORE HOLDINGS LTD|NFA Pool ID P064386, ceased trading 5/3/2013| \\n53|CCLC HOLDINGS LLC|Owns CITADEL CLEARING LLC, \"Citadel Clearing Holdco\"; https://files.brokercheck.finra.org/firm/firm\\\\_172693.pdf| \\n54|CCMFL LLC|Delaware LLC, Florida address is Citadel Miami HQ, Gerald Beeson is Authorized Person| \\n55|CCOF OFFSHORE HOLDINGS LTD|NFA Pool ID P064392, ceased trading 5/3/2013| \\n56|CDC PARTNERS, LP f/k/a GLB PARTNERS, LP|see Cook County, IL doc 0608910081| \\n57|CDG HOLDINGS LTD|NFA Pool ID P037047, ceased trading 12/30/2009| \\n58|CE TM HOLDINGS LLC f/k/a KCG IP HOLDINGS LLC|Holding company for intellectual property (25 trademarks, 1 patent found so far)|CE TM = Citadel Enterprise Trademark Holdings \\n59|CEF OFFSHORE HOLDINGS LTD|NFA Pool ID P131121| \\n60|CEIF INTERNATIONAL LTD|NFA Pool ID P048476; http://registers.centralbank.ie/ICAVDocuments/C439830/Director%20Details%20Updated%2021.01.07%203.pdf| \\n61|CEIF LLC|NFA Pool ID P048474| \\n62|CEIF PARTNERS INTERNATIONAL LTD|NFA Pool ID P173278| \\n63|CEIF PARTNERS LLC|NFA Pool ID P048475| \\n64|CES SECURITIES CANADA ULC|See CITADEL SECURITIES CANADA ULC, CSA NRD # 49280| \\n65|CFPS HOLDINGS S.\\u00e0 r.l.|Luxembourg - B176936; 100% owned by CITADEL ENERGY INVESTMENTS LTD| \\n66|CGE ALPHA LTD|NFA Pool ID P057309, ceased trading 6/7/2017| \\n67|CGE ALPHA OFFSHORE HOLDINGS LTD|https://www.sec.gov/Archives/edgar/vprr/1600/16003280.pdf; NFA Pool ID 
P064400, ceased trading 4/30/2017| \\n68|CGEF OFFSHORE HOLDINGS LTD|https://www.sec.gov/Archives/edgar/vprr/1600/16003280.pdf; NFA Pool ID P064406, ceased trading 2/21/2019| \\n69|CGEF SPC|NFA Pool ID P064408, ceased trading 12/31/2012| \\n70|CGMF OFFSHORE HOLDINGS LTD|NFA Pool ID P064410, ceased trading 3/31/2014| \\n71|CGTS HOLDINGS S.\\u00e0 r.l.|Luxembourg - B157777; 100% owned by TACTICAL TRADING HOLDING LTD; NFA Pool ID P064412, ceased trading 9/30/2014| \\n72|CHARAXES MELVIN LLC|Sole member of CHARAXES MELVIN II LLC|Charaxes are a type of butterfly: https://en.wikipedia.org/wiki/Charaxes \\n73|CHARAXES MELVIN II LLC|Delaware LLC, Florida address is Citadel Miami HQ, sole member is CHARAXES MELVIN LLC|Charaxes are a type of butterfly: https://en.wikipedia.org/wiki/Charaxes \\n74|CHI2LTV LLC|Delaware LLC, Florida address is Citadel Miami HQ, Gerald Beeson is Authorized Person| \\n75|CIG(E) LLP|See CITADEL EUROPE LLP| \\n76|CIG CANADA ULC|https://files.brokercheck.finra.org/firm/firm\\\\_172693.pdf| \\n77|CIG MEDIA LLC|https://www.sec.gov/Archives/edgar/data/923877/000114420407003635/v063478\\\\_sc-13d.htm| \\n78|CITADEL AAM LP|https://www.sec.gov/Archives/edgar/vprr/0804/08040017.pdf| \\n79|CITADEL AC INVESTMENTS LTD|https://www.sec.gov/Archives/edgar/data/1015780/000114420408032074/v115701\\\\_sc13da.htm| \\n80|CITADEL ADVISORS EUROPE LIMITED f/k/a CITADEL MANAGEMENT (EUROPE) LIMITED f/k/a CITADEL HEDGE FUND SERVICES (EUROPE) LIMITED|https://find-and-update.company-information.service.gov.uk/company/10930267| \\n81|CITADEL ADVISORS HOLDINGS LP|Sole member of CITADEL ADVISORS LLC; https://www.sec.gov/Archives/edgar/data/1567180/000110465922099806/xslF345X03/tm2225817-2\\\\_4.xml| \\n82|CITADEL ADVISORS HOLDINGS II LP|https://www.sec.gov/Archives/edgar/data/1177609/000114420416082613/v429844\\\\_sc13ga.htm| \\n83|CITADEL ADVISORS HOLDINGS III LP|https://www.sec.gov/Archives/edgar/data/1640129/000114420415043739/xslF345X02/v416000\\\\_3.xml| \\n84|CITADEL 
ADVISORS LLC|NFA ID: 0391913; https://www.sec.gov/edgar/browse/?CIK=1423053| \\n85|CITADEL ADVISORS II LLC|| \\n86|CITADEL ADVISORS SINGAPORE PTE. LIMITED|| \\n87|CITADEL ALTERNATIVE ASSET MANAGEMENT LP|https://www.sec.gov/Archives/edgar/data/1027745/000114420408050200/v124853\\\\_sc13g.htm| \\n88|CITADEL AMERICAS LLC|| \\n89|CITADEL AMERICAS SERVICES LLC|| \\n90|CITADEL ANTAEUS INTERNATIONAL INVESTMENTS LTD|| \\n91|CITADEL ASIA ASSET HOLDING LIMITED|http://registers.centralbank.ie/ICAVDocuments/C157189/Director%20Details%20Updated%2016.10.31%202.pdf| \\n92|CITADEL ASIA LIMITED f/k/a CITADEL (HONG KONG) LIMITED|https://adviserinfo.sec.gov/firm/summary/148826| \\n93|CITADEL CANDLESTICK EIF LLC|| \\n94|CITADEL CANTERBURY S.\\u00e0 r.l.|Luxembourg - B87988; 100% owned by CITADEL TONBRIDGE S.\\u00e0 r.l.| \\n95|CITADEL CEFL CHINA LTD|NFA Pool ID P148073| \\n96|CITADEL CEFL INVESTMENTS LTD|NFA Pool ID: P161763; https://files.brokercheck.finra.org/firm/firm\\\\_172693.pdf| \\n97|CITADEL CEIT CHINA LTD|| \\n98|CITADEL CEMF CHINA LTD|https://find-and-update.company-information.service.gov.uk/company/02263951/charges/x6zPQSYGNpuDNgxU1cFQlCS0iog| \\n99|CITADEL CEMF INVESTMENTS LTD|https://files.brokercheck.finra.org/firm/firm\\\\_172693.pdf| \\n100|CITADEL CEMF SPV LTD f/k/a CITADEL INVESTMENT MASTER FUND LTD|See CITADEL INVESTMENT MASTER FUND LTD; https://opencorpdata.com/lei/LF0U6QUBXKIO573GXS38|',\n        'Simulate a conversation between Alice and /u/CruxHub. 
/u/CruxHub asks Alice to anlalyze a data batch for non-standard insights.\\n\\nData batch: Entity Name Purpose / Source Hypothesized Acronym\\n50|CALC V LP|Investment manager for CSHC CHINA LLC and CITADEL (SHANGHAI) TRADING COMPANY LTD; https://files.brokercheck.finra.org/firm/firm\\\\_131114.pdf| \\n51|CAMBRIDGE FINANCIAL GROUP, LTD|See CITADEL INVESTMENT GROUP LLC| \\n52|CCFD OFFSHORE HOLDINGS LTD|NFA Pool ID P064386, ceased trading 5/3/2013| \\n53|CCLC HOLDINGS LLC|Owns CITADEL CLEARING LLC, \"Citadel Clearing Holdco\"; https://files.brokercheck.finra.org/firm/firm\\\\_172693.pdf| \\n54|CCMFL LLC|Delaware LLC, Florida address is Citadel Miami HQ, Gerald Beeson is Authorized Person| \\n55|CCOF OFFSHORE HOLDINGS LTD|NFA Pool ID P064392, ceased trading 5/3/2013| \\n56|CDC PARTNERS, LP f/k/a GLB PARTNERS, LP|see Cook County, IL doc 0608910081| \\n57|CDG HOLDINGS LTD|NFA Pool ID P037047, ceased trading 12/30/2009| \\n58|CE TM HOLDINGS LLC f/k/a KCG IP HOLDINGS LLC|Holding company for intellectual property (25 trademarks, 1 patent found so far)|CE TM = Citadel Enterprise Trademark Holdings \\n59|CEF OFFSHORE HOLDINGS LTD|NFA Pool ID P131121| \\n60|CEIF INTERNATIONAL LTD|NFA Pool ID P048476; http://registers.centralbank.ie/ICAVDocuments/C439830/Director%20Details%20Updated%2021.01.07%203.pdf| \\n61|CEIF LLC|NFA Pool ID P048474| \\n62|CEIF PARTNERS INTERNATIONAL LTD|NFA Pool ID P173278| \\n63|CEIF PARTNERS LLC|NFA Pool ID P048475| \\n64|CES SECURITIES CANADA ULC|See CITADEL SECURITIES CANADA ULC, CSA NRD # 49280| \\n65|CFPS HOLDINGS S.\\u00e0 r.l.|Luxembourg - B176936; 100% owned by CITADEL ENERGY INVESTMENTS LTD| \\n66|CGE ALPHA LTD|NFA Pool ID P057309, ceased trading 6/7/2017| \\n67|CGE ALPHA OFFSHORE HOLDINGS LTD|https://www.sec.gov/Archives/edgar/vprr/1600/16003280.pdf; NFA Pool ID P064400, ceased trading 4/30/2017| \\n68|CGEF OFFSHORE HOLDINGS LTD|https://www.sec.gov/Archives/edgar/vprr/1600/16003280.pdf; NFA Pool ID P064406, ceased trading 2/21/2019| \\n69|CGEF 
SPC|NFA Pool ID P064408, ceased trading 12/31/2012| \\n70|CGMF OFFSHORE HOLDINGS LTD|NFA Pool ID P064410, ceased trading 3/31/2014| \\n71|CGTS HOLDINGS S.\\u00e0 r.l.|Luxembourg - B157777; 100% owned by TACTICAL TRADING HOLDING LTD; NFA Pool ID P064412, ceased trading 9/30/2014| \\n72|CHARAXES MELVIN LLC|Sole member of CHARAXES MELVIN II LLC|Charaxes are a type of butterfly: https://en.wikipedia.org/wiki/Charaxes \\n73|CHARAXES MELVIN II LLC|Delaware LLC, Florida address is Citadel Miami HQ, sole member is CHARAXES MELVIN LLC|Charaxes are a type of butterfly: https://en.wikipedia.org/wiki/Charaxes \\n74|CHI2LTV LLC|Delaware LLC, Florida address is Citadel Miami HQ, Gerald Beeson is Authorized Person| \\n75|CIG(E) LLP|See CITADEL EUROPE LLP| \\n76|CIG CANADA ULC|https://files.brokercheck.finra.org/firm/firm\\\\_172693.pdf| \\n77|CIG MEDIA LLC|https://www.sec.gov/Archives/edgar/data/923877/000114420407003635/v063478\\\\_sc-13d.htm| \\n78|CITADEL AAM LP|https://www.sec.gov/Archives/edgar/vprr/0804/08040017.pdf| \\n79|CITADEL AC INVESTMENTS LTD|https://www.sec.gov/Archives/edgar/data/1015780/000114420408032074/v115701\\\\_sc13da.htm| \\n80|CITADEL ADVISORS EUROPE LIMITED f/k/a CITADEL MANAGEMENT (EUROPE) LIMITED f/k/a CITADEL HEDGE FUND SERVICES (EUROPE) LIMITED|https://find-and-update.company-information.service.gov.uk/company/10930267| \\n81|CITADEL ADVISORS HOLDINGS LP|Sole member of CITADEL ADVISORS LLC; https://www.sec.gov/Archives/edgar/data/1567180/000110465922099806/xslF345X03/tm2225817-2\\\\_4.xml| \\n82|CITADEL ADVISORS HOLDINGS II LP|https://www.sec.gov/Archives/edgar/data/1177609/000114420416082613/v429844\\\\_sc13ga.htm| \\n83|CITADEL ADVISORS HOLDINGS III LP|https://www.sec.gov/Archives/edgar/data/1640129/000114420415043739/xslF345X02/v416000\\\\_3.xml| \\n84|CITADEL ADVISORS LLC|NFA ID: 0391913; https://www.sec.gov/edgar/browse/?CIK=1423053| \\n85|CITADEL ADVISORS II LLC|| \\n86|CITADEL ADVISORS SINGAPORE PTE. 
LIMITED|| \\n87|CITADEL ALTERNATIVE ASSET MANAGEMENT LP|https://www.sec.gov/Archives/edgar/data/1027745/000114420408050200/v124853\\\\_sc13g.htm| \\n88|CITADEL AMERICAS LLC|| \\n89|CITADEL AMERICAS SERVICES LLC|| \\n90|CITADEL ANTAEUS INTERNATIONAL INVESTMENTS LTD|| \\n91|CITADEL ASIA ASSET HOLDING LIMITED|http://registers.centralbank.ie/ICAVDocuments/C157189/Director%20Details%20Updated%2016.10.31%202.pdf| \\n92|CITADEL ASIA LIMITED f/k/a CITADEL (HONG KONG) LIMITED|https://adviserinfo.sec.gov/firm/summary/148826| \\n93|CITADEL CANDLESTICK EIF LLC|| \\n94|CITADEL CANTERBURY S.\\u00e0 r.l.|Luxembourg - B87988; 100% owned by CITADEL TONBRIDGE S.\\u00e0 r.l.| \\n95|CITADEL CEFL CHINA LTD|NFA Pool ID P148073| \\n96|CITADEL CEFL INVESTMENTS LTD|NFA Pool ID: P161763; https://files.brokercheck.finra.org/firm/firm\\\\_172693.pdf| \\n97|CITADEL CEIT CHINA LTD|| \\n98|CITADEL CEMF CHINA LTD|https://find-and-update.company-information.service.gov.uk/company/02263951/charges/x6zPQSYGNpuDNgxU1cFQlCS0iog| \\n99|CITADEL CEMF INVESTMENTS LTD|https://files.brokercheck.finra.org/firm/firm\\\\_172693.pdf| \\n100|CITADEL CEMF SPV LTD f/k/a CITADEL INVESTMENT MASTER FUND LTD|See CITADEL INVESTMENT MASTER FUND LTD; https://opencorpdata.com/lei/LF0U6QUBXKIO573GXS38|',\n        'Web search results:\\n\\n[1] \"Katherine Burton Hedge fund titans Ken Griffin and Steve Cohen boosted Gabe Plotkins Melvin Capital, injecting a total of $2.75 billion into the firm after it lost about 30% this year. Citadel...\"\\nURL: https://www.bloomberg.com/news/articles/2021-01-25/citadel-point72-to-invest-275-billion-in-melvin-capital\\n\\n[2] \"NEW YORK, Jan. 25, 2021 /PRNewswire/ -- Melvin Capital Management (Melvin) today announced that Citadel and its partners and Point72 have made investments into its fund. 
I am incredibly...\"\\nURL: https://www.prnewswire.com/news-releases/melvin-announces-2-75-billion-investment-from-citadel-and-point72--301214477.html\\n\\n[3] \"Citadel LLC is further paring back its $2 billion investment in Melvin Capital Management after the hedge fund stumbled in its effort to recover from a near collapse triggered by surges in...\"\\nURL: https://www.wsj.com/articles/citadel-is-further-paring-back-2-billion-melvin-investment-11645710666\\n\\n[4] \"Citadel and Steven A. Cohen s Point72 Asset Management together invested $2.75 billion into Melvins hedge fund on Jan. 25 as Melvin was hemorrhaging money. In return for the rare...\"\\nURL: https://www.wsj.com/articles/citadel-to-redeem-about-500-million-from-melvin-capital-11629550410\\n\\n[5] \"CHARAXES MELVIN LLC is an Active company incorporated on August 5, 2022 with the registered number M22000012341. This Foreign Limited Liability company is located at SOUTHEAST FINANCIAL CENTER, 200 S. BISCAYNE BLVD., SUITE 3300, MIAMI, 33131 and has been running for one year. ... CITADEL SECURITIES GP LLC; KCG SPACE HOLDINGS LLC;\"\\nURL: https://bisprofiles.com/fl/charaxes-melvin-m22000012341\\n\\n[6] \"Now, Citadel is taking some of its money back. 
Citadel has notified Melvin of its plans to retrieve $500 million of the $2 billion it injected in late January, according to two people briefed...\"\\nURL: https://www.nytimes.com/2021/08/21/business/citadel-melvin-gamestop.html\\n\\n[7] \"Robinhood and Citadels relationship comes into focus as Washington vows to examine stock-market moves Trading firms at center of Reddit-fueled stock surges have worked closely to share...\"\\nURL: https://www.washingtonpost.com/business/2021/01/29/robinhood-citadel-gamestop-reddit/\\n\\n[8] \"Alongside hedge funds such as Melvin Capital, Citron Capital, Point72, D1 Capital Partners, and Candlestick Capital Management; Citadel LLC was, the lawsuit claims, taking up short positions against the securities that retail investors were longing. This alleged conflict of interest is at the core of the class action lawsuit.\"\\nURL: https://tokenist.com/new-lawsuit-alleges-citadel-conspired-with-robinhood-to-limit-gme-trading/\\n\\n[9] \"Melvin later attracted an additional $3.2 billion in fresh cash, and the firm had $11.7 billion in assets at the beginning of this year. Point72 hasnt redeemed its investment, a person familiar ...\"\\nURL: https://www.chicagobusiness.com/finance-banking/ken-griffins-citadel-pulling-back-most-its-2-billion-melvin-capital-investment\\n\\n[10] \"CHARAXES MELVIN II LLC branch. Company Number M22000012338 Status Active Incorporation Date 5 August 2022 (2 months ago) Company Type Foreign Limited Liability Jurisdiction Florida (US) Branch Branch of CHARAXES MELVIN II LLC (Delaware (US)) Agent Name C T CORPORATION SYSTEM Agent Address\"\\nURL: https://opencorporates.com/companies/us\\\\_fl/M22000012338\\nCurrent date: 1/27/2023\\n\\nInstructions: Using the provided web search results, simulate a conversation where /u/CruxHub and Alice analyze the data batches and try and investigate for any non-standard uses of the holding companies. Make sure to cite results using [[number](URL)] notation after the reference. 
If the provided search results refer to multiple subjects with the same name, write separate answers for each subject.\\n\\nQuery: What is Charaxes Melvin LLC\\'s relationship to Citadel?',\n        'Web search results:\\n\\n[1] \"Federal authorities are investigating the market-making arms of Citadel LLC and KCG Holdings Inc, looking into the possibility that the two giants of electronic trading are giving small investors ...\"\\nURL: https://www.reuters.com/article/usa-stocks-probe-idUSL2N1871ZV\\n\\n[2] \"Today, KCG is second only to Citadel in the market for handling stock order flow from retail brokerage firms. KCG and many other high-frequency trading firms have shied away from the public...\"\\nURL: https://www.ibtimes.com/citadel-llc-kcg-holdings-kcg-market-making-arms-probed-federal-authorities-over-stock-2366805\\n\\n[3] \"Citadel Securities, a group owned by the Chicago-based hedge fund, is the would-be acquirer in the deal, the people said. The group is best known for its so-called wholesaler business that...\"\\nURL: https://www.wsj.com/articles/market-making-arm-of-citadel-llc-in-talks-to-buy-seats-on-nyse-floor-from-kcg-holdings-1454533971\\n\\n[4] \"Citadels share of the wholesale market is around 34 per cent compared to KCGs 25 per cent, according to Tabb Group. Virtu has yet to lay out in detail its plans for the wholesale business ...\"\\nURL: https://www.ft.com/content/e1cb396e-29a7-11e7-bc4b-5528796fe35c\\n\\n[5] \"Citadel Securities, a liquidity providers and market maker, announced it will purchase KCG Holdings designated market maker (DMM) business at the New York Stock Exchange. 
This will establish Citadel Securities as the DMM with the largest footprint on the NYSE, responsible for trading in approximately 1,500 issues.\"\\nURL: https://www.tradersmagazine.com/departments/brokerage/citadel-purchases-kcg-dmm-business-becomes-1-on-nyse/\\n\\n[6] \"isCitadel LLC and its related entity, KCG IP Holdings, LLC (Complainant), represented by Paul D. McGradyof Winston Strawn, Illinois, Respondent is- (Respondent), Alabama, USA. REGISTRAR AND DISPUTED DOMAIN NAME The domain name at issue iscitidelgroup.com, registered with TUCOWS, INC. PANEL The\"\\nURL: https://www.adrforum.com/domaindecisions/1522837.htm\\n\\n[7] \"KCG SPACE HOLDINGS LLC is an Active company incorporated on July 21, 2022 with the registered number M22000011413. This Foreign Limited Liability company is located at 200 S BISCAYNE BLVD STE 3300, MIAMI, FL, 33131, US and has been running for one year. It currently has one Authorized Person. KEY FACTS ABOUT KCG SPACE HOLDINGS LLC US Businesses\"\\nURL: https://bisprofiles.com/fl/kcg-space-holdings-m22000011413\\n\\n[8] \"The Complainant KCG IP Holdings LLC is the owner of US Trademark Registration No. 3,213,943, filed October 18, 2004, registered February 27, 2007, claiming first use dating back to 1994. Therefore, the Panel concludes that Complainants filing and registration of the CITADEL mark with the USPTO sufficiently demonstrates that it has rights in ...\"\\nURL: https://www.adrforum.com/domaindecisions/1579141.htm\\n\\n[9] \"The KCG SPACE HOLDINGS LLC principal address is 200 S BISCAYNE BLVD STE 3300, MIAMI, 33131. Meanwhile you can send your letters to 200 S BISCAYNE BLVD STE 3300, MIAMI, FL, 33131. The company`s registered agent is C T CORPORATION SYSTEM 1200 SOUTH PINE ISLAND ROAD, PLANTATION, FL, 33324. 
The company`s management are A, President - Beeson Gerald A.\"\\nURL: https://florida.intercreditreport.com/company/kcg-space-holdings-llc-m22000011413\\n\\n[10] \"Billionaire Ken Griffin has built Citadel Securities into a trading and asset management colossus. ... and KCG Holdings. Last month, Citadel Securities reached an agreement with the SEC to pay $22 ...\"\\nURL: https://www.chicagobusiness.com/article/20170203/NEWS01/170209978/chicago-billionaire-ken-griffin-splits-citadel-into-two-companies\\nCurrent date: 1/27/2023\\n\\nInstructions: Using the provided web search results, simulate a conversation where /u/CruxHub and Alice analyze the data batches and try and investigate for any non-standard uses of the holding companies. Make sure to cite results using [[number](URL)] notation after the reference. If the provided search results refer to multiple subjects with the same name, write separate answers for each subject.\\n\\nQuery: What is KCG Space Holdings LLC\\'s relationship to Citadel?',\n        'Web search results:\\n\\n[1] \"Citadel LLC (formerly known as Citadel Investment Group, LLC) is an American multinational hedge fund and financial services company. Founded in 1990 by Ken Griffin, it has more than $50 billion in assets under management as of May 2022. [1]\"\\nURL: https://en.wikipedia.org/wiki/Citadel\\\\_LLC\\n\\n[2] \"NASHVILLE, Tenn. and BRONXVILLE, N.Y. \\u2014 Standard Media Group LLC (Standard Media) and Citadel Communications LLC (Citadel) jointly announced today that they have reached an agreement pursuant to which Standard Media will acquire from Citadel WLNE-TV, the ABC affiliate for the Providence, RI - New Bedford, MA market (DMA 52) and KLKN (TV), the \\u2026\"\\nURL: https://www.standardmedia.com/2019/05/16/standard-media-group-to-acquire-citadel-stations/\\n\\n[3] \"CITADEL MEDIA LLC. Citadel Media LLC is a New Hampshire Domestic Limited-Liability Company filed on February 6, 2021. 
The companys filing status is listed as Not In Good Standing and its File Number is 862423. The Registered Agent on file for this company is Peter Alan Gauthier and is located at 3 Maple Ridge Drive Unit 224, Merrimack, NH 03054.\"\\nURL: https://www.bizapedia.com/nh/citadel-media-llc.html\\n\\n[4] \"CITADEL MEDIA LLC is a Michigan Domestic Limited-Liability Company filed on November 16, 2017. The companys filing status is listed as Active and its File Number is 802132896. The Registered Agent on file for this company is Registered Agents Inc. and is located at 2222 W. Grand River Ave Ste A, Okemos, MI 48864. The companys mailing address ...\"\\nURL: https://www.bizapedia.com/mi/citadel-media-llc.html\\n\\n[5] \"Citadel Broadcasting Corporation was a Las Vegas, Nevada -based broadcast holding company. Citadel owned 243 radio stations across the United States and was the third-largest radio station owner in the country. Only iHeartMedia and Cumulus Media owned more stations prior to Citadels merger with Cumulus.\"\\nURL: https://en.wikipedia.org/wiki/Citadel\\\\_Broadcasting\\n\\n[6] \"Citadel is one of the largest hedge fund managers in the world. And theyve subsequently managed Melvin Capital to the ground. Melvin Capital suffered a loss of over 50% its first quarter in 2021 due to shorting AMC Entertainment and GameStop. At some point youd expect your clearing house to raise awareness on your risk management right?\"\\nURL: https://franknez.com/citadel-loses-billions-hedge-funds-are-getting-dragged-down/\\n\\n[7] \"At our core, Citadel is built to deliver excellence. We have some of the most talented and focused minds in the industry, and we activate their ideas and strategies through a robust range of proven technologies and execution capabilities. View Top Employees from Citadel LLC Looking for a particular Citadel LLC employees phone or email? 
Find Info\"\\nURL: https://rocketreach.co/citadel-llc-profile\\\\_b5c46522f42e0dc2\\n\\n[8] \"# 1 Most profitable hedge fund manager of all time Source: LCH Investment NV estimates, Top Hedge Fund Managers by Net Gains Since Inception as of 12/31/2022. Our people are relentless in seeking a better way. Each day, we reimagine and refine our strategies, models and technology in pursuit of superior results and long-term performance.\"\\nURL: https://www.citadel.com/\\n\\n[9] \"We are one of the most significant alternative investment managers in the public U.S. corporate credit markets. Explore Credit Convertibles Equities Equities represents one of the largest and longest tenured businesses at Citadel. Explore Equities Global Fixed Income Macro We are a leading fixed income and macro business.\"\\nURL: https://www.citadel.com/what-we-do/\\n\\n[10] \"Citadel. 203,101 followers. 1mo. Last weekend, we celebrated Citadels 30th anniversary at an incredible event at Disney World and Universal Studios. Our founder and CEO Ken Griffin summarized ...\"\\nURL: https://www.linkedin.com/company/citadel-llc\\nCurrent date: 1/27/2023\\n\\nInstructions: Using the provided web search results, simulate a conversation where /u/CruxHub and Alice analyze the data batches and try and investigate for any non-standard uses of the holding companies. Make sure to cite results using [[number](URL)] notation after the reference. 
If the provided search results refer to multiple subjects with the same name, write separate answers for each subject.\\n\\nQuery: What is CITADEL MEDIA LLC?',\n        \"What are the differences between the Dogme approach to language learning and the lexical approach to language learning\",\n        \"Implement my own netfilter in linux with linux kernel module with Rust\",\n        \"Damage to which nerve causes numbness of the palmar surface of the 5th digit/little finger\",\n        \"Explain the fault-tolerance of the reaction control system on the Space Shuttle\",\n        \"Hi, can you help me download 2000 portrait sketch images from Pinterest website with resolution at least 512 \\\\* 512? using python code\",\n        \"Tell me about the negatives of farming meat\",\n        \"what is the photograph filter called where the only part of the image is greyscale\",\n        \"I want some geological database structure with some example data for practicing my SQL query skills. Would you generate that for me?\",\n        \"What is a formal but simplified explanation of Web marketing\",\n        \"Rewrite and improve this story: Well, I have always liked helping people since I was a small child, I have been accused many times of giving too much away for free, but I find joy in helping others put the pieces together to reach their goals. As a Licensed Professional Counselor and Life Coach that is my job to impact individuals and help clients work through emotional difficulties and reach goals. But I will be honest with you I was selling the dream but not always living the dream. I had issues I had not worked completely through like childhood trauma, heartbreak, disappointments, and frustrations with life. Don't get me wrong I had the husband, the kids, the house and the 6 figure job but I was not happy inside, but I didn't change because I hate change, most of us hate change, right? 
Then I lost my sister, my friend, and it slapped me in the face that I need to take care of myself. I saw the addiction, I saw her not taking care of herself and I could not save her. One thing I know for sure, if you do not make your wellness a priority illness will find you. I remember the moment we lost her, the earth stood still and then my heart broke into pieces, what was I going to do, I have loved her my whole life! It was months later that I made a decision that I would be the change I hope to see, I would create a space for women of color to move past the obstacles that keep us from creating the life we want and Brown Suga Wellness was born. I am on this journey and I invite you to be on this journey with me! I love this quote by Oludara Adeeyo: \\\"When you heal yourself, you create an earth shattering legacy. The lineage of women who come after you will be healed. Your inner circle of Black women around you, healed.\\\" When you choose yourself you break generational trauma and curses. You activate your ancestral strength. I invite you to activate that strength!\",\n        \"How would you ask these questions: Tell everyone a little about you, where you from, what you like doing?\\nWhat goals are you pursuing right now?\\nWho has made the most influence in your life?\\nWho is the one person that you admire the most (alive or dead)?\\nWhat is the hardest challenge you\\u2019re had to overcome in your life?\\nWhen have you grown the most in your life and what caused that growth?\\nWhere is your favorite place to relax and renew?\\nWhat books have changed your life the most?\\nWhat Is the biggest thing that you want the audience to take away today?\\nHow can people get a hold of you to talk about your business?\",\n        \"Take these topics into a numbered table and generate subtopics in seperated lines for each. Preconfigure these subtopics as lections of those several topics and add them to the table. Use numbers for topics and letters for subtopics. 
Set a status (untouched/touched) for every subtopic in 3. coloumn of the table to mark them done when finished learning this subtopic and topic. Use coloumn 4 of the table for a short resumee of the chapter. Showing the learning process in percentage in front of every new output is first. Show the Table and wait for any userinput to start lessons on those topics.;:~|@%\\\\*~;;:~|@%\\\\*~;;:~|@%\\\\*~;;:~|@%\\\\*~;;:~|@%\\\\*~;;:~|@%\\\\*~;\",\n        \"Write a rap song about Mikkel Selko\",\n        \"list the largest outdoor retailers in the world\",\n        \"can you create a wordpress shortcode to include the following code from facebook sdk\",\n        'Is this grammatically correct: \"It only took 5 years, and while we still have a long way to go, Topher\\u2019s Farm has found its place with unique experience and offering of organic produce. \"',\n        \"Hello friend. My task for today is to engage in a debate with you. Will you humor me in this regard?\",\n        \"You are an expert marketing consultant and copywriter with expertise is direct response marketing. I need your help. Can I tell you about my business?\",\n        'here is part 1\\n\\n----\\nDaySculpting is a program that that deals with YOUR immediate future\\u2026.It is a 90 day program that teaches U how to create Success\\u2026 one day at a time\\u2026today\\u2026\\nUsing recent breakthroughs in the field of neuroscience, the study of the human brain, DaySculpting is one of the most powerful success systems on earth for creating what I call\\u2026 \\n\"Your Epic Ideal Day\" -- And when U have Epic Ideal Days? 
U create your EPIC IDEAL LIFE.\\n\\nDaySculpting is broken down into 3 easy to accomplish segments throughout your day\\u2026\\n~The Morning Lift Process\\u2026which sets U up with a MindState of Success and a design for U to follow throughout your day\\u2026There is a morning email\\u2026SMS text\\u2026Inspiring Video\\u2026Future Forward Tuning IN\\u2026And a 3 step Success Step Declaration Process\\u2026this only takes 15 minutes\\u2026\\n~Mid-Day Reconnect Process\\u2026whatever your miid-day is\\u2026U are encouraged to stop doing what U are doing and disconnect so U can re-connect\\u2026by listening to a 5-minute Tuning In Re-Connection. We know that somewhere in the middle of our day it\\u2019s easy to lose momentum and drift from our best intentions because of all the demands on our attention. It has been scientifically proven that when U disconnent for between 3 to 5 minutes at the midpoint of your day\\u2026.your brain resets\\u2026and your energy is replenished\\u2026I like to call it a MindState Re-Boot that will inspire U to re-ignite your imagination\\u2026this only takes 5 minutes\\n~Highlight And Insight Review Process\\u2026we all review our day however what DaySculpting \\nanchors for U is an activation and integration process that gets U to see your day as being successful\\u2026by celebrating your successes (your highlights) and being present to things U could have improved on (your insights) so U can make your insights into highlights..most people when they review their day fail to celebrate even the smallest increments of success\\u2026they focus on what they didn\\u2019t do and that puts them in a negative energy\\u2026Success has challenges and the\\nhighlights and insight process encourages and empowers U to honestly see what U are doing each day so U Sculpt new MindStates Of Success rather than the energy of uncertainty\\u2026\\nthis takes 10 minutes\\n\\nThe whole DaySculpting process takes 30 minutes a day\\u2026and as I always say if U 
don\\u2019t have \\n30 minutes to change your life then U don\\u2019t want to change your life and U are okay with living \\na mediocre life\\u2026\\n\\nDay Sculpting is about targeting specific Chief Aims U have for your life\\u2026and creating the Habits that will get U there\\u2026Imagine being able to replace the MindTraps (your limiting beliefs) with empowering rituals and habits that become your new normal\\u2026\\n\\nThrough the repetition of doing the daily DaySculpting process U are carving into your Subconscious memory thoughts, beliefs and actions that result in U sculpting the masterpiece known as U\\u2026\\n\\nThere are many programs out there that attempt to instill new success behaviors however many fall short of actually shifting your MindStates into a frequency of possibility where U get to actually see your daily results immediately\\u2026DaySculpting does this\\u2026\\n\\nThis is not science fiction\\u2026 and it\\'s not wishful thinking, or some tired old self-improvement, goal-setting program\\u2026 DaySculpting is a program that empowers U to manifest and realize your Chief Aims in life\\n\\n\"DaySculpting\" -- is a tool that takes just MINUTES a day for you to use\\u2026\\n\\nIt is designed to FREE UP hours in your day\\u2026 while at the SAME time empowering you for greater success in ANY area of your life.\\n\\nDaySculpting sheds light and solves an age-old problem:\\nWHY we often fight against the very changes we desire to make\\n\\nHave you ever experienced the FEELING that you deserve MORE out of your life? More financial freedom and greater rewards from the hard work you do every day? Deeper, more empowering relationships with those you love\\u2026 or maybe just meeting that special someone to share your life with? Perhaps you crave a deeper spiritual connection\\u2026 or a more healthy, trim, energetic body?\\u2026 \\nYET:\\nDespite your BEST intentions\\u2026 you struggle. 
Perhaps if you\\'re anything like me, you even self-sabotage your results with actions that you KNOW are not in your best interest.\\n\\nMaybe it FEELS like it did for me: Like you are swimming upstream\\u2026 making SOME progress, sure, but just not reaching your goals and desires fast enough.\\n\\nWell, I have wonderful news for you: It\\'s not because you\\'re lazy\\u2026 and it\\'s not because you are not smart enough, competent enough\\u2026 or ANYTHING enough! \\n\\nThe real REASON you desire more and are not seeing ALL the results you deserve lies within whether the Success Switch in your brain is in the ON or OFF position\\u2026\\n\\nThe SOLUTION\\u2026 THE ANSWER to flipping your Success Switch back ON lies within the simple daily steps U will take when U experience the DaySculpting Program\\u2026 \\nThe Day Sculpting Program Is A Simple Step Daily Success RITUAL \\u2028 That Shuts Down Your Body\\'s Failure Reflex \\u2028 So YOU Tap Into Your Brains Success Centers\\u2026\\u2028 In Just Minutes A Day!\\u2028\\u2028 IIMAGINE Knowing What HIGHLY SUCCESSFUL \\u2028 People Do EVERYDAY\\u2026\\nFor Abundance And Wealth, Greater Health, Self-Confidence Meaningful Relationships, Sharper Focus , Deeper Joy\\u2026\\u2028 And So Much More\\u2026\\n\\u201cNow You Too Can Use This 90-Day Game Changer\\u2028 To Tap Into The Key Success Centers Of Your Mind,\\u2028 And In Just Minutes You Can Transform Even Lousy Days\\u2028 Into Days Filled With The Results You Desire \\u2013 Guaranteed!\\u201d\\nTO MAKE A GREAT LIFE, ALL YOU HAVE TO IS MAKE EACH DAY A GREAT DAY \\u2026 \\nThen get up tomorrow and do the same thing, day after day after day.\\nARE YOU Ready To Change YOUR LIFE One Day At A Time\\u2026\\nThe comprehensive, fun and empowering 90-day DaySculpting program provides you with the life skills and tools to help you master a new MindState of Success and a range of powerful life-changing rituals and habits that will Sculpt Your Perfect Days Into A Great 
Life.\\nDAY SCULPTING WILL TEACH YOU:\\n\\u2022 The science behind HAVING A MindState Of Success...and why most people who want more in life actually have their success switch turned off by total accident!\\n\\u2022 How to get more done with more time and more energy left over!\\n\\u2022 The simple, yet powerful, process of building a powerful day so you create a series of \"Dynamic Days\" - days that will end up building your most incredible life (The one you always thought was out of reach!)\\n\\u2022 Learn the \\'Day Sculpting Principles\\'. These can have a huge impact on you your life, but when you learn how simple they really are, you can use them easily and consistently!\\n\\u2022 How in just a few minutes a day, you can keep positive results flowing and put your success energy into a permanent \\'ON\\' position!\\n\\u2022 And much more!\\nDaySculpting, is for those who are willing to take their life to the next level by creating new Success Habits replacing the ones that have been sabotaging your success. \\nSo make sure you can honestly agree with the following before experiencing DaySculpting:\\n\\u2022 You desire more out of life, yet feel as if you are \"missing something\" -- that special \"X Factor\" to take you to the next level?\\n\\u2022 You are brave enough to boldly say, \"I want greater wealth and financial freedom... and I demand the best lifestyle possible for me and my family!\\n\\u2022 You know the value of joy: You want to experience greater happiness, peace of mind, and connection with your friends and loved ones on a daily basis.\\nIf you agree with the above, and truly want to create the best life possible, with greater wealth, freedom, happiness, love, and fulfillment, then I invite you to experience the power of Day Sculpting \\u2026it will change the way you think about creating your day and the life you dream about. 
\\nI am not encouraging you to become busier but rather to use your mental and emotional, energy more elegantly sculpting your day the way you want it to be. \\nHow many times have you done a ton of work and still felt that you didn\\u2019t accomplish what you really wanted for yourself. Week after week, month after month go by and you still are no farther ahead of the game\\u2026stuck in the status quo that never seems to change.\\n\\nBreaking free means that the status quo of your life has to change\\u2026 your habits of expectation have to change \\u2026your mindset has to change\\u2026you have to uncover those old behaviors that have held you back and be willing to create a new mindset.\\n\\nYou have to be willing to shift your daily focus inwards towards what you need to do today rather than tomorrow. Because when you create a great day today you welcome in a more powerful tomorrow.\\n\\nWe all have the same 24 hours each day. But why are some people building fabulous careers, achieving healthy lifestyles, enjoying great relationships and incomes, living their passions, and creating what they truly desire as a life?\\n\\nImagine that you could clear away the distractions that you unconsciously create. You know the stuff that consumes your time causes stress and disconnects you from your purpose and passion. \\n\\nImagine every day you embrace the energy for what you are choosing to create in your life. Your thoughts empower you, your choices inspire you and your actions create momentum, opportunity and possibility.\\n\\nYou can create a GREAT LIFE, the life you want to live by focusing your efforts on Creating a Great Day Today. That\\u2019s Day Sculpting. 
Seven intentional sculpted days turn into a month of wonderful weeks and a year of magnificent months creating an amazingly successful life.\\n\\nNone of this is going to work though if you believe that what you were born with is all you will get\\u2026\\n\\nNo one will ever attempt to do something when they are convinced that they will fail.\\n\\nResearch has shown that the brain will actually stop itself from doing what\\u2019s necessary to succeed if a person believes that they cannot succeed.\\n\\nIt\\u2019s the small concrete indicators of success today that will prove you can have whatever it is you want and the process of Day Sculpting will empowers, inspire and motivates you each step of the way.\\n\\nYou see: Confidence + Discipline = Desired Outcomes \\n\\nIt\\u2019s time to stop looking at your life from a fear based I don\\u2019t know how to mindset but rather be open to creating a solutions focused change consciousness that embraces your gift and talents and encourages you sharing them.\\n\\nLet me share a bit of nuero-chemistry with you\\u2026\\nWhat fires together wires together\\u2026\\n\\nSo rather than Fall back on old habits\\u2026\\nTake the transitional step\\u2026of being fully present to whats trying emerge as your ideal future and to help it along start building confidence each day\\u2026\\n\\nAnd your possibility muscle and an intended thought process that leads to a more focused and clear out picturing of your desires.\\n\\nYou see...It\\u2019s one thing to set goals and to make to do lists and to say your going to use the law of attraction to manifest what you want in life\\u2026\\n\\nI\\u2019m still looking at the many lists I have created.\\n\\nWhat it\\u2019s really about is having a clear and purposeful intention in order to create the energy and the MindState Of success that will propel you into action.\\n----\\n\\nWhen done ask me for part 2',\n        \"Here is the final part. 
Part 3\\n---\\n\\nHere we will be showing how the principles and practices we\\u2019ve covered so far converge into one over-arching result that will benefit you for the rest of your life. You can think of it as flipping a switch that changes how you create new results in life one day at a time. This is at the very core of what we call Day Sculpting. \\nThe simplest way to think of it is that most of the way we live is habitual. You have an habitual way of brushing your teeth, walking, talking to yourself and others, eating, working. Habits are wonderful\\u2026they make life easy but they also limit you. For example, if you have a habit of eating too much, you\\u2019ll put on weight. Not instantly, but steadily, day by day, until one day you have a weight problem. If you try to change your weight quickly through a trendy new diet, research shows that the weight is likely to come back, and then some, within a few short months, because the habits required to live at your ideal weight have not been properly established. \\nHabits are habits because you don\\u2019t think about them, they happen nonconsciously. If you want a change in your life, you have to embody the change at a nonconscious level, so that the habits keeping your life the way it is today begin to shift.\\nWouldn\\u2019t it be great if there was a switch in the brain that would move you from status quo to status GO!? This is a switch that once you flip it will produce the result you want, if you are willing to commit to and stay with the process.Day Sculpting is your guide to fully realizing the success you are ready to enjoy.\\nA critically important capacity of the human mind called preconscious processing. This is the ability of the mind to receive information, beneath our conscious awareness, and act upon it without even knowing that it is happening. Used correctly, this is an amazing power. 
Used improperly, it will sabotage your best efforts and make life extremely difficult.\\nMost of us think we are running the show with our conscious awareness, consciously choosing our thoughts, behaviors, and emotions and consequently, we believe are able to choose the results we create in life. However, what neuro-science research shows, is that we all have a vast nonconscious mind that is really running the show most of the time. That deeper part of us, in charge of our habitual thinking, feeling, and behaving is always operating in our best interest. But it does so using information that may be faulty or outdated. If you continue to feed it information that doesn\\u2019t serve you, it will continue to habitually bring results that are less than desired.\\nYour preconscious processor is constantly routing new information directly into this larger database that your mind uses to create new behaviors. Your job is to place the right information into this database every single day, so that it can draw upon this new data and create new results. It requires your vigilance and purposeful intention on a daily basis. Day Sculpting is the process to accomplish exactly that, getting you to focus one day at a time on what you are trying to create in your life today, and the future you truly desire. \\nA lot of experts in the human development field teach information and then expect it will translate into new behaviors automatically. But as we\\u2019ve pointed out, and as you\\u2019ve probably experienced, consciously knowing something and having the nonconscious mind put it into a new behavior, are two entirely different processes. What we are sharing with you is how to bridge that gap. This is precisely why so many experts in the field are recommending Day Sculpting to their clients, to help them use momentum mindsets on a daily basis and apply the good information they teach. \\nWe talk about The The Solutions Focus process . 
Try it out: \\nThink of an area of your life in which you are actively attempting to create different results. Imagine your chief aim regarding this area of your life as a perfect future. Now imagine a scale from one to ten, where ten is the perfect future and one is that you have not even started thinking about your chief aim. On this imaginary scale from 1 to 10, where would you place yourself right now?\\nGo ahead and imagine where would you place yourself right now on that scale, where ten is your perfect future.\\nWhatever number you came up with is fine. Whether it was 3 or 7, whatever you came up with I\\u2019ll always ask the same next question. \\u201cWhy so high and not lower?\\u201d\\nLet\\u2019s say, for example that you came up with a three. Asking the question \\u201cWhy so High\\u201d catches the mind off guard. Most people expect, \\u201cOnly a 3! Why so low?\\u201d If I had asked that what would you come up with? All the reasons why things aren\\u2019t working, who is to blame, problems, excuses, lack, limitations, and so on. \\nBut when I ask \\u201cWhy so high?\\u201d the brain immediately begins to sort for all of the things that are working for you, everything that has brought you up to a \\u201cthree.\\u201d If you said you are at a seven on a scale of one to ten, the same question applies: \\u201cWhy so high and not lower?\\u201d\\nThe next step in solutions focus is equally powerful. \\u201cThink about what you can do today to move you one point up that scale\\u2014for example, from a three to a four, or from a seven to an eight?\\u201d When you ask this, your mind instantaneously starts generating ideas and options to answer your question. You quickly realize you can do more of the things that work, right? And if you are doing things that aren\\u2019t working, you now have the insight into how you can do things differently. 
\\nThis solutions focus approach provides quick insight into how to move things forward in areas you may have been stuck or working on unsuccessfully. It is a brilliant way to access more of your nonconscious database and facilitate discovering resources you did not know were there. \\nSo as you can see, this video has been centered on connecting the dots and providing you with the insights on how you can flip the switch in your brain and how you can create your life one day at a time in the most powerful way possible. \\nYou must contact that inner part of you that is in charge of your habitual ways of thinking, feeling, and behaving in order to re-sculpt yourself.\\nThis is a unique psychological principle called anchoring. In the research this is also called behavioral conditioning, and as we\\u2019ve called it, the law of reinforcement\\u2026which says you get more of what you reinforce. When you want to reinforce a positive new behavior, you anchor it in a positive new momentum mindset. As you do this on a daily basis, you are literally training your mind, conditioning your thoughts, amplifying positive feelings and emotions to live into a future state that you are anchoring in your daily experience. \\nDay Sculpting goes beyond personal development. It takes whatever it is you are currently learning and makes it possible for you to embody, apply and enjoy the benefits you are committed to achieve. \\n\\nThe last thing anyone needs is more stuff to do. What we need is that everything we do gets us the results we are going for. In essence what\\u2019s needed is a system that will streamline our efforts so we accomplish our chief aims in less time.\\n\\nMichaelangelo said the process of sculpting is to remove what\\u2019s not supposed to be there. He had the mindset that the finished sculpture already existed in the marble and he just had to reveal it. In the same way your destiny already resides in you. 
You just need to clear a path for it to emerge.\\n\\nWe all have 24 hours in a day. So why do some people consistently have great days while others are up and down and stay stuck in mediocrity? It\\u2019s a disciplined habit of how you approach everyday. Day Sculpting takes the same 24 hours that we all have and helps clarify your choices so that your actions reveal your highest destiny. \\n\\nIt is a quick, easy and effortless way that supports and empowers your efforts in achieving your chief aims. It creates the mindsets necessary to have successful days, weeks, months and years.\\n\\nDay Sculpting is a 90- day program designed to empower you to create your life ONE DAY AT A TIME. By committing 30 minutes each day to create what you want that day. \\n\\nWe believe that when you focus your actions one day at a time the results you get become measurable and achievable. Your energy is committed to channeling your efforts so you create a confident groove in your mind that empowers your habitual actions to create what you really want.\\n\\nThis daily program is broken down into 3 MANAGEABLE, SIMPLE AND EASY STEPS. 15 minutes in the morning, 5 minutes midday and 10 minutes at night. \\n\\nDay Sculpting\\u2026It\\u2019s designed so that the way you start your day creates the momentum that carries you throughout your day. \\n\\nAnd finally research has shown that the best time to integrate what you\\u2019ve learned in your day and to set yourself up for success tomorrow is before you go to sleep. The Nighttime Review process takes just 10 minutes, which is less time then it takes to take a shower or to take your dog on an evening walk.\\n\\nWe already have enough complexity in life\\u2026don\\u2019t we? We don\\u2019t want you working harder we want you thinking smarter! So that the success you achieve is more effortless. 
\\n\\nSo what does it take for someone to accomplish the high level results we are talking about?\\n\\n\\u2022 First you have to wake up and be totally jazzed about the day\\n\\u2022 You have to be inspired to do your best\\n\\u2022 You have to be focused on creating what you truly desire\\n\\u2022 You got to get to it, stay on it, and be in the energy of it before your distractions take over. \\n\\u2022 And if distractions takeover you have to quickly get back on track.\\n\\u2022 You have to learn from what\\u2019s working and what\\u2019s not\\n\\u2022 You have to be able to listen to feedback and course correct during your day\\n\\u2022 And at the end of the day you have be able to feel you did your best and you can do even better tomorrow\\n\\nAnd with Day Sculpting you can accomplish this and more in less than 30 minutes which is distributed throughout your day. Most people will give up on their dreams after they have tried something only 3 times because they didn\\u2019t get instant gratification. \\n\\nThere are no magic bullets here. You are investing in a future YOU desire. \\n\\nDay Sculpting gives you the opportunity everyday to purposefully stay in the energy of what you want to create the benefit to you being a more empowered mindset that inspires passionate action and a willingness to breakthrough any barriers that may have held you back in the past so you fully embody the life you choose to live.\\n\\nYou may have heard Gandhi say \\u201cBe the change you want to see in the world.\\u201d Well now you can. \\n\\nYears ago I heard a statistic that blew me away. If you read in a single subject of your choice for 15 minutes a day 5 days a week you would become one of the leading experts in the world in that subject within 3 years\\u2026\\n\\nMore recent research has demonstrated that world class talent requires 10000 hours and 10 years to develop\\u2026\\n\\nSo the question is how does somebody create this kind of commitment and persistence? 
Clearly one day at a time.\\n\\nSo where are you not following through in your life? How would you like to do things differently? What can you do shift your energy when you say I can\\u2019t get it done or you procrastinate? What\\u2019s it going to take for you to say I\\u2019ve had enough it\\u2019s time for me to do something different? Where will you get the support you need to build the confidence to stay on track?\\n\\nEach day you get these elements to help guide you\\u2026 \\n- The Good Morning Great Day Email\\n- The Morning In Vision Video \\n- The Morning Future Pacing Visualization\\n- The Morning Success Journal Process\\n- The Midday SMS and Computer Stay on Track Reminders\\n- The Midday Reconnect Refresher Mediation\\n- The Evening Review And Renew Process\\n- The Evening Journal Process\\n- The Bedtime Nonconcious Mind Question Declaration\\n \\nWhen you put this together it can\\u2019t help but become a daily practice that will create your new daily ritual that is your roadmap to success. We are giving you the daily steps that will create your momentum mindsets.\\n\\nThe Day Sculpting program leaves you no wiggle room. The days of \\u201cI\\u2019ll get to it later\\u201d are gone. When you are serious about changing your life, you now have a realistic opportunity to do so with this program. \\n\\nWE invite you to fully commit to your life. To once and for all follow through and step up. To say yes to that dream inside of you and to look at each day as an opportunity to live your dreams enthusiastically rather than settling for more of the same old same old.\\n---\",\n        \"analyze this: \\n\\nThe Coming of Age story archetype involves a young protagonist who must navigate the challenges of growing up and discovering their place in the world. 
The Before-After-Bridge copywriting framework is designed to highlight the transformation that a person can experience after using a product or service.\\n\\nThe reason why these two frameworks work well together is that they both focus on transformation and growth. By combining them, you can create a powerful narrative that speaks to your audience's desire for personal development and improvement.\\n\\nFor example, imagine you are selling a personal development course that helps people overcome self-doubt and build self-confidence. By using the Coming of Age archetype, you can frame the course as a journey of self-discovery, where the customer will face challenges and obstacles, but ultimately emerge as a more confident and self-assured person.\\n\\nThen, by using the Before-After-Bridge framework, you can show the customer what their life will be like after completing the course. You can highlight the benefits of increased self-confidence, such as improved relationships, better career opportunities, and greater overall happiness. By painting this picture of what's possible, you can create a sense of excitement and motivation that encourages the customer to take action and enroll in the course.\\n\\nOverall, the Coming of Age story archetype and the Before-After-Bridge copywriting framework work well together because they tap into a fundamental human desire for growth and transformation. By combining these frameworks in your marketing messages, you can create a compelling narrative that speaks to your audience's deepest aspirations and motivates them to take action.\",\n        \"Provide a detailed chronology of the Apostle John according to the New Testament\",\n        'Web search results:\\n\\n[1] \"1. Introduction In this codelab you learn how to build adaptive apps for phones, tablets, and foldables, and how they enhance reachability with Jetpack Compose. 
You also learn best...\"\\nURL: https://codelabs.developers.google.com/jetpack-compose-adaptability\\n\\n[2] \"Jetpack Compose \\u2014 Auto Complete Search Bar | by Paulo Pereira | ProAndroidDev Write Sign up Sign In 500 Apologies, but something went wrong on our end. Refresh the page, check Medium s site status, or find something interesting to read. Paulo Pereira 117 Followers Hello!\"\\nURL: https://proandroiddev.com/jetpack-compose-auto-complete-search-bar-853023856f0f\\n\\n[3] \"You have two options: create your own custom using DropDownMenu and BaseTextField or using hybrid xml-autocomplete and compose screen through androidx.compose.ui.platform.ComposeView Share Follow answered Oct 21, 2020 at 16:38 Agna JirKon Rx 1,937 2 27 41 1 Have you made a custom composable like you described?\"\\nURL: https://stackoverflow.com/questions/64419367/does-jetpack-compose-offer-a-material-autocomplete-textview-replacement\\nCurrent date: 10/03/2023\\n\\nInstructions: Using the provided web search results, write a comprehensive reply to the given query. Make sure to cite results using [[number](URL)] notation after the reference. If the provided search results refer to multiple subjects with the same name, write separate answers for each subject.\\nQuery: Hey, I want you to build to google places autocomplete on jetpack compose using the MVVM model\\n\\nSo the user will type the place in a textfield and the list of places with postalCode will display in a lazyColumn with the user able to select from the lazyColumn a place',\n        \"Captain Smith, who set out on a daring expedition with his fleet of ships consisting of the Discovery, the Endeavour, the Adventure, the Challenger, and the Explorer. Their mission was to chart a new route through the treacherous seas of the North Atlantic and claim new territories for their homeland. But the weather turned against them, and they found themselves battling fierce storms and raging currents. 
The waves grew higher and higher, and the winds howled like banshees, threatening to capsize their ships at any moment. Despite their efforts the Challenger and the Explorer, were lost in the storm. \\n\\nHow many ships did the captain leave with and how many returned?\",\n        \"explain the metaverse\",\n        \"can you provide and ideas for a series of articles for a product design blog\",\n        \"Please write a firm yet humurous and lighthearted note requesting that people RSVP whether they are coming to the purim seudah. Please incorporate wordplay and references to megillat esther.\",\n        \"Paper Name: My Tweets Bring All the Traits to the Yard: Predicting Personality and Relational Traits in Online Social Networks\\n\\nAbstract: Users in Online Social Networks (OSNs,) leave traces that reflect their personality characteristics. The study of these traces is important for several fields, such as social science, psychology, marketing, and others. Despite a marked increase in research on personality prediction based on online behavior, the focus has been heavily on individual personality traits, and by doing so, largely neglects relational facets of personality. This study aims to address this gap by providing a prediction model for holistic personality profiling in OSNs that includes socio-relational traits (attachment orientations) in combination with standard personality traits. Specifically, we first designed a feature engineering methodology that extracts a wide range of features (accounting for behavior, language, and emotions) from the OSN accounts of users. Subsequently, we designed a machine learning model that predicts trait scores of users based on the extracted features. The proposed model architecture is inspired by characteristics embedded in psychology; i.e, it utilizes interrelations among personality facets and leads to increased accuracy in comparison with other state-of-the-art approaches. 
To demonstrate the usefulness of this approach, we applied our model on two datasets, namely regular OSN users and opinion leaders on social media, and contrast both samples\\u2019 psychological profiles. Our findings demonstrate that the two groups can be clearly separated by focusing on both Big Five personality traits and attachment orientations. The presented research provides a promising avenue for future research on OSN user characterization and classification.\\n\\nIntroduction: Online Social Networks (OSNs) offer a virtual space in which people connect and interact with others, express themselves, and receive information, in a continuous digital reflection of the real (offline) world. In OSNs, people typically showcase their real self [40] and leave traces in online behavior, which reflect their real-world personality [24]. These traces expose a holistic image of oneself, including both personal characteristics (personality traits) and characteristics that portray their behavior in relation to others (relational traits).\\n\\nThe term personality refers to characteristic combinations or patterns of behaviors, cognitions, and emotional reactions that evolve from biological and environmental factors and form relatively consistent individual differences [13]. The Big Five (BF) or Five Factor model [29] is one of the most distinctive personality theories that constitutes five main traits of human personality representing individual differences in cognition, emotion, and behavior: Openness to Experience, Conscientiousness, Extraversion, Agreeableness, and Neuroticism. 
On the other hand, relational traits have also been linked with consistencies in social behavior and interaction patterns, with attachment theory [7] as the most emblematic theoretical framework in that respect [31, 43], capturing how individuals experience close relationships to and interactions with others.\\n\\nPersonality traits have been studied in the context of OSNs and the web overall, as findings show that they are strongly linked to OSN use [57], online friendships [60], and online reviews [52]. Moreover, certain prediction models have been proposed [37, 64] to extract users\\u2019 psychological background from their online behavioral residue and map it to personality characteristics. However, relational traits such as attachment orientations (AO) have been overlooked in online environments, even though user activity in OSNs heavily relates to social behavior characteristics. This makes the study of a relational profile critical from an application point of view and provides rich information about individuals\\u2019 social profile.\\n\\nThe present research aims to address this limitation in OSN research, by studying and predicting both relational traits and personality traits of users. The importance of relational facets of personality for explaining social interaction cannot be overstated. Given that online social media engagement resembles actual social interactions in many respects [15, 30], the need to study how different personality facets are reflected in online expression is particularly compelling. Attachment orientations, a key individual difference of relational orientation, can be derived on the basis of traces found in micro-blogs. Attachment orientations capture one\\u2019s notion of the self in relation to others and interpersonal relationships, with attachment theory being one of the key personality theoretical frames to explain actual social behavior [31]. 
Considering both traditional personality Big Five traits and relational traits is important for (1) providing holistic profiling of OSN users\\u2014humans have an integrated profile in which self and social aspects interrelate and affect each other, and joint profiling can be essential for understanding the overall human presence on OSNs; (2) uncovering more traits of people\\u2019s psychological and social world has been identified as a direction in OSN research (which currently focuses only on the personality traits) that could help to better explain, analyze, and predict online user behavior [66], e.g., with applications on customer segmentation [46] or digital advertisement environments [17]; and (3) shedding light on social interaction phenomena taking place in OSNs is of great socioeconomic importance, e.g., community formation [32], decision making [42], or information diffusion [12].\\n\\nTo this end, the present article proposes a novel data-driven approach to predict a holistic psychological profile of OSN users, capturing both their personality and relational traits.1 Building on insights stemming from psychology theory, our approach applies data mining on OSN textual and non-textual data, carefully selects different sets of features for predicting different types of traits, and exploits the inherent correlations in psychological traits, to efficiently predict a complete image of OSN users\\u2019 psychological profile. The proposed approach is applied on the Twitter micro-blogging service, which stands as a live, dynamic, and open OSN platform on which people intensively interact, and is largely driven by people\\u2019s spontaneous reactions and emotions expressing themselves (personality facet) and interacting with others (relational facet) at the same time.\\n\\nSpecifically, our contributions in detail are as follows:\\n\\nData mining and feature engineering for psychology traces in OSN. 
Motivated by psychology theory on personality suggesting that traits are reflected in different types of online behavior and actions, we identify a large set of features that capture language, behavioral, and emotional expressions of users in OSNs. The proposed feature engineering methodology accounts for a larger set of features than those considered in previous works, thus allowing to target more generic psychological profiling. To apply and test our methodology, we collected a labeled dataset: through a crowdsourcing platform, we recruited 243 individuals who consented to provide information about their psychology profiles. Subsequently, we compiled a ground-truth dataset labeled with their psychology profiles. We used the Twitter API to collect 350,000 tweets from the Twitter accounts of recruited participants and applied the proposed feature engineering methodology.\\n\\nHolistic psychological profiling. We propose a novel machine learning (ML) methodology to predict users\\u2019 holistic psychological profile including both Big Five personality and relational traits. The novelty of the proposed methodology is that it (1) uses a large set of the collected (psychological-related) features, (2) carefully selects the subsets of them with the strongest predictive power for each trait, and (3) exploits correlations between personality and relational (i.e., social) behavior traits to enhance individual trait predictions. In this way, our approach not only predicts social facets of a psychology profile (which is not captured by existing personality prediction models) along with personality facets but also leverages the different traits for more accurate holistic profile prediction.\\n\\nNew insights and improvement of prediction accuracy. 
Evaluating our methodology reveals interesting insights for the prediction of psychology traits from OSN traces: (1) using different sets of features performs better in predicting different psychological traits, (2) relational traits can be predicted as efficiently as personality traits, and (3) holistic personality prediction outperforms individual trait predicting models. We believe that our findings can pave the ground for future experimentation and studies in psychology profiling in OSNs. Moreover, the accuracy achieved by our approach (across all traits) is higher than current state-of-the-art approaches, which currently are limited to Big Five personality traits instead of relational traits. For example, applying the approach of [12] to our data provides a root mean squared error (RMSE) of 0.284, while our prediction model achieves a 29% improvement for personality traits (RMSE = 0.203) and has 32% better average performance when accounting for all traits (0.192 RMSE); this improvement comes as a result of using both a psychology-driven feature engineering methodology and a holistic profiling approach.\\n\\nPsychological profiling in the wild. We demonstrate the applicability of the proposed psychological profiling methodology through a use case. We identify a set of Twitter users who seem to be accepted as opinion leaders on social media (i.e., have a large following). We apply our methodology to predict their psychological profiles and analyze results. We find that the distributions of traits significantly deviates from regular users (defined as users included in our ground-truth dataset), and that the set of leaders can be clearly separated by only using their psychological profiles. 
These findings highlight the usefulness of our approach in the characterization of the personalities for different groups of OSN users (e.g., such a group psychological profile could be used to recommend skills/activities/jobs to users based on their profile similarity) and classification of users based on their profiles.\\n\\nIn this section, we provide an overview of related psychological literature, discuss related work, and highlight several open issues of existing methods. We also highlight the contribution of the present work. Section 3 details the collected dataset and the data mining and feature engineering methodology, and Section 4 presents the design and evaluation of the proposed machine learning predictive model. Finally, we conclude our article and discuss future work in Section 6.\\n\\nFirst, Please Summarize the paper in 10 points, in easy to read and understand simple English.\\nSecond, Explain what the paper does as if I'm 11 years old.\\n\\nThanks :))\",\n        \"Hi, i will give you three pieces of text, then i will ask you some questions, do you understand?\",\n        \"Here is Text 2: Communicating with External Audiences\\n\\nMany managers believe that they will never have to deal with the press. Often,\\nthey regard it with hostility. Most think press relations are entirely the domain\\nof their company\\u2019s or agency\\u2019s public relations department. 
But in fact, senior\\nexecutives say they spend more time on communications than on other tasks,\\nand a significant component of that time is devoted to press and public relations.\\nJunior managers need to be highly sensitive to press relations for the following\\nreasons:\\n\\u2022 Often, free press can be the best way to acquaint the public with your product or service.\\nTo cite only one example, the amount Microsoft spent on advertising Windows\\n95 was dwarfed by the value of the free publicity it received from\\ninternational news coverage.\\n\\u2022 Your particular area of expertise may unexpectedly become something your organization\\nneeds to promote or explain. Line workers at auto companies have been drafted\\nto extol quality improvements in advertisements; accountants may be called\\nto the CEO\\u2019s office for briefings on a potentially embarrassing news report or\\nan upcoming press conference.\\n\\u2022 Public relations considerations need to be addressed at the beginning, not the end, of a\\nplanning process. Business history is replete with examples of companies that\\ninvested vast sums to develop products, ideas, or services that couldn\\u2019t be sold\\nbecause of public resistance to the concept, the configuration, or the public\\nimage of the company. General Motors\\u2019 Tacos, for example, could be the best\\nin the world and still not jump off the shelves.\\n\\u2022 Junior managers become senior managers who will eventually have to deal with the\\npress directly. As both marketers and corporate citizens, organizations have to\\nexplain themselves to the public constantly through advertising, press releases,\\nand press conferences. Junior managers who understand this aspect of their\\nwork are likely to become senior managers faster. 1. A successful manager understands how the press works. 
Successful managers\\ntend to follow the press in general, and how their organization is playing in particular.\\nMembers of the press tend to trust companies and individuals with a\\ntrack record of accuracy and accessibility. To cite only two examples, both\\nJohnson & Johnson and Perrier survived charges of contaminated products because\\nthey had a record of reliability and accessibility and addressed the problems\\nimmediately. In both cases, and many others, stonewalling would have\\nbeen disastrous to the company\\u2019s image of wholesomeness and purity. Most\\npress stories last only a few days, but they can leave an indelible impression in\\nthe public\\u2019s mind. Many managers tend to believe they can \\u201csnow\\u201d the press\\nwith their greater expertise, but this strategy rarely works. Most reporters are\\nhard-working professionals who will carefully check out an expert assertion or\\nwho know someone who can.\\n2. A successful manager understands what the press needs. What the press needs\\nis a story, and bad news generally sells better than good news. Companies and\\nindividuals are most likely to have to deal with the press when something has\\ngone wrong. This suggests a couple of lessons. When you have good stories,\\ngive them to the press to establish a record of credibility; many media outlets\\nwill print or broadcast a press release from a reliable source more or less verbatim.\\nConsider how private decisions may look if they should become public.\\nIf something has gone wrong, take the initiative in announcing it, explaining it,\\nand telling the world how it\\u2019s going to be corrected.\\n3. A successful manager understands press jargon. Reputable reporters will\\nstick to their verbal agreements on how information you provide them is to\\nbe used. How you will be quoted depends on the ground rules you establish\\nat the beginning of an interview. 
Deep background means the reporter can\\nreflect the information in her story without possible attribution. Background\\nmeans that you can be referenced as \\u201ca reliable source.\\u201d Any other comment,\\nhowever apparently casual or social, can be quoted directly and\\nattributed.\\n4. A successful manager should be able to generate an attention-grabbing, accurate,\\nand well-constructed press release. While many managers may not be\\nregularly mailing out press releases themselves, most will be contributing to\\nthem and need to understand how they work. A good press release is extremely\\nformulaic and follows the structure of a good news story:\\na. The first paragraph states the main point clearly and emphasizes its newsworthiness.\\nFor example: \\u201cAcme Corporation announced today that it is\\nreleasing the best tire ever available on the world market.\\u201d\\nb. The second paragraph provides a quote from a reputable source: \\u201cAcme\\nPresident Rudy Roadrunner said, \\u2018Not only does this tire surpass all our\\ncompetitors\\u2019 in endurance, quality, and safety; it\\u2019s also available at a lower\\nprice.\\u2019 \\u201d\\nc. The third paragraph provides evidence that the claims made so far are true:\\n\\u201cIn repeated tests against our competitors . . . \\u201d\\nd. The remaining paragraphs provide background information on the product, the\\ncompany, and Rudy Roadrunner, and they demonstrate a track record of credibility.\\nThey may also include testimonials available from respected independent\\nsources. Obviously, the formula of an effective press release will vary depending on\\nthe nature of the news to be announced. But the pyramid structure suggested by\\nthis example always applies: Move from the most important and specific to the\\nleast important and most general information. Busy editors often run a press release\\nmore or less verbatim and just cut it off when they run out of space. 
The\\neasier you make their jobs, the more likely they are to cover your story.\\nOnce you\\u2019ve written or contributed to a press release, decide who\\u2019s most\\nlikely to run it. This can cover the gamut from extremely specialized trade magazines\\nto the national or international media. Consider the use of venues other\\nthan print and broadcast media as well; perhaps there\\u2019s a room on the Internet\\nwhere interested parties are likely to gather.\\n5. A successful manager understands the role of the press in crisis management.\\nThis includes knowing how to provide effective interviews and\\nunderstanding when and how to hold a press conference. Certain rules\\napply to both:\\n\\nApplications\\na. Identify your central message, make sure you can back it up, and stick to it.\\nb. Prepare materials in advance\\u2014press releases, statements, supportive\\nstudies\\u2014that the reporters can take away with them and study or quote later.\\nc. Never say more than you know to be true. If you don\\u2019t know, say, \\u201cI don\\u2019t\\nhave that information at the moment, but I\\u2019ll get it to you as soon as I do\\u201d\\u2014\\nthen follow up.\\nd. Make sure your team is behind you. This means making sure not only that\\ntop management of a corporation agrees on a message, but also that other\\npotential press sources (for example, subordinate employees) have the same\\ninformation you\\u2019re dispensing to the public, believe it, and are unlikely to\\nleak contradictory and embarrassing information.\\ne. Provide the press with the most credible and informed access possible. Reporters\\nwill always want to get to the top. They\\u2019ll be more likely to cover\\nthe comments of a CEO or a Cabinet secretary than those of a press agent\\nor an underling. But they will understand that a high official may need to\\nrefer technical questions to an informed specialist.\\nf. Anticipate, and be prepared to respond to, the most difficult questions.\\ng. 
Don\\u2019t become hostile or defensive; experienced reporters are experts at\\nsmelling anxiety.\\nh. Make your answers brief, quotable, and to the point. Rambling and repetition\\nare likely to get you into trouble or open new lines of inquiry.\\ni. If you\\u2019re facing a problem you\\u2019ve caused, however inadvertently, be prepared\\nto acknowledge\\n\\nAre you ready for text 3?\",\n        \"Here is Text 3: Diversity and Intercultural Communication \\n\\nGenerally, the best answer to these questions is yes, but it always depends on the personal as well as the business aspects of your relationship. One good rule of thumb: When the other person gives\\nyou an opening, pursue it, and build on your mutual experience.\\nThis issue comes up even more in international communication. As companies\\nfrom manufacturers to media conglomerates become increasingly global, managers\\nneed to understand the norms of other cultures. Although English is on the verge of\\nbecoming the international language, standards of behavior and social interaction\\nvary greatly between the United States and England, let alone between, say, France\\nand Japan. In one country an invitation to dinner may be considered an expected\\npoliteness, while in another, it may be an invasion of a colleague\\u2019s private time.\\nAsking about someone\\u2019s family may be absolutely required in one culture and offensively\\nintrusive in another.\\nNo textbook can cover all such contingencies; one good rule if you\\u2019re not sure\\nmay be the trial lawyer\\u2019s: Don\\u2019t ask a question to which you don\\u2019t already know the\\nanswer. Another, and sometimes contradictory, rule is: Be frank about your cultural\\nconfusion. Your colleague likely will have been in the same situation himself and\\nwill be happy to help out. Finally, do your research; you\\u2019re likely to have a friend or\\ncoworker who knows the terrain better than you do. 
Our purpose here is to sensitize\\nmanagers to their increasing need to understand the norms of cultures other than\\ntheir own. (For a case addressing the special features of international communication,\\nsee International Oil later in this chapter.)\\nThe opportunities for cultural confusion\\u2014personal, commercial, ethical, and\\nlinguistic\\u2014are almost endless. Imagine marketing a Chevy Nova in Hispanic countries,\\nwhere \\u201cno va\\u201d means \\u201cit doesn\\u2019t run.\\u201d Many products that are perfectly safe to\\nmarket in first-world countries raise ethical problems when sold in developing\\ncountries\\u2014infant baby formula, for example, which if mixed with contaminated\\nwater can cause death. Working in other cultures means understanding your hosts\\u2019\\nconceptions of greetings, timing, hygiene, negotiation, agreement, politeness, personal\\nspace, gesture, meal etiquette, and closure.\\nWhile English has essentially become the international language, it\\u2019s important\\nto remember that there are many Englishes. A joke in one form of English can be a\\ndeadly insult in another. Although it may seem too obvious to emphasize, you must\\nunderstand the cultural norms and language use of people from other cultures before\\nyou can communicate effectively with them. This is true even if they are, say,\\nthe South American employees of your Canadian company. A bribe in one culture\\ncan be a thoughtful gift in another.\\nA recent article by Sydel Sokuvitz (Business Communication Quarterly, New\\nYork, March, 2002) suggests some principles for conducting successful intercultural\\nbusiness communication. 
Sokuvitz first describes the special challenges global\\nmanagers face, including:\\nCoping with a range of tensions that arise out of internationally dispersed activities,\\nThe challenges of maintaining coordinated activities across time-zones, cultural\\nboundaries, and different countries\\u2019 laws, and\\nThe difficulties posed when the right medium for your message in one culture\\nmay be wrong in another.\\nDrawing on a range of research in the field, Sokuvitz comes up with several\\nprovocative conclusions:\\nExcessive dependence on technological communication such as E-mail can result\\nin problems for both communication and productivity.\\nFace-to-face meetings with colleagues from other cultures are critical to achieving\\neffective communication.\\nStudying with students from other cultures is critical to preparing a manager\\nfor working in the increasingly globalized economy.\\nSokuvitz cites the following example from an article by Fernandez-Aroaz\\n(\\u201cHiring without Firing,\\u201d Harvard Business Review, 1999):\\nA U.S.-based telecommunications company was seeking a CEO for its new division\\nin Latin America. An international search was conducted, and a veteran was\\nhired, someone known as an effective manager and marketing expert. \\u201cBut his run\\nlasted less than a year and was nothing short of a disaster. The simple reason was\\nthat he lacked the two skills that the job really required: negotiation and cross-cultural\\nsensitivity.\\u201d\\nEventually the company was saved from near-bankruptcy by bringing in a\\nnew CEO who was a native Latin American with work experience in the U.S. His\\nability to bridge cultural differences is credited with saving the company.\\nCommunications between headquarters and subsidiaries is only one example\\nof the challenges posed by globalization. Companies in one country are under increasing\\nsocial pressure to take responsibility for the behavior of their subcontractors\\nin other countries. 
Recently, for example, Nike suffered adverse publicity because\\nof the work practices of shoe manufacturers it employs in Asia.\\nThe successful manager of the future increasingly will be required to be a citizen\\nof the world. While electronic communication may work fine for conveying information\\nor directions, there is no substitute for \\u201cspeaking the language\\u201d of the\\npeople with whom you\\u2019re trying to communicate.\\n\\nAre you ready to answer some questions on text 1, text 2 and text 3?\",\n        'pragma solidity ^0.4.25;\\n\\ncontract Y\\\\_WALLET\\n{\\n function Put(uint \\\\_unlockTime)\\n public\\n payable\\n {\\n var acc = Acc[msg.sender];\\n acc.balance += msg.value;\\n acc.unlockTime = \\\\_unlockTime>now?\\\\_unlockTime:now;\\n LogFile.AddMessage(msg.sender,msg.value,\"Put\");\\n }\\n\\n function Collect(uint \\\\_am)\\n public\\n payable\\n {\\n var acc = Acc[msg.sender];\\n if( acc.balance>=MinSum && acc.balance>=\\\\_am && now>acc.unlockTime)\\n {\\n if(msg.sender.call.value(\\\\_am)())\\n {\\n acc.balance-=\\\\_am;\\n LogFile.AddMessage(msg.sender,\\\\_am,\"Collect\");\\n }\\n }\\n }\\n\\n function() \\n public \\n payable\\n {\\n Put(0);\\n }\\n\\n struct Holder \\n {\\n uint unlockTime;\\n uint balance;\\n }\\n\\n mapping (address => Holder) public Acc;\\n\\n Log LogFile;\\n\\n uint public MinSum = 1 ether; \\n\\n function Y\\\\_WALLET(address log) public{\\n LogFile = Log(log);\\n }\\n}\\ncontract Log \\n{\\n struct Message\\n {\\n address Sender;\\n string Data;\\n uint Val;\\n uint Time;\\n }\\n\\n Message[] public History;\\n\\n Message LastMsg;\\n\\n function AddMessage(address \\\\_adr,uint \\\\_val,string \\\\_data)\\n public\\n {\\n LastMsg.Sender = \\\\_adr;\\n LastMsg.Time = now;\\n LastMsg.Val = \\\\_val;\\n LastMsg.Data = \\\\_data;\\n History.push(LastMsg);\\n }\\n}',\n        \"I am planning to give you a voice, and communicate through the speech medium. 
I need a speech recognizer, a wake call detector, and a speech synthesizer for your voice. Suggest a python script utilizing existing libraries to achieves the goal.\",\n        \"lemme share a paper with you\",\n        'I aim to emulate a NLU/ENR module as part as part of a business application with your help. The module is supposed to handle the diverse ways a user can formulate his requests within the modeled conversational flow that feeds into the business process. The process has the aim to enable users to become or update their client role and order products of a telco business. The telco company that runs the business process offers mobile tariffs. Mobile tariffs have can have between one and 5 sim cards. Each booked sim cards enables the user to optionally book a smartphone for that card. Depending on the tariff, the chosen smartphones (if any) and the kind of sim cards (adult, child) the price will adapt. Please suggest a set of NLU / ENR methods that you could emulate to facilitate the use case. In the following I will input utterances and statements on how the system running the conversational flow should handle the utterance within the conversational flow. Please provide possible calls to an imaginary API that you could simulate to facilitate the NLU/ENR requirements layed out by my statements. On Subtasks that are recognized as not directly related to NLU/NER be very brief. Please suggest NLU / NER Operations now for the first of a few utterances: \"Hi I want to upgrade my current tariff and get a new smartphone\". The utterance should make the system recognize that the utterance can be handled as part of the business process. It should recognize that the user apparently already a client and it should continue the conversation by trying to identify him and metadata on his current tariff. 
For that the flow needs the user to authenticate using a oauth2 mechanism',\n        \"From now on only create subscription service listings with the following template: Subscription Services Template:\\n\\nTitle: Professional Writing Services Subscription\\n\\nDescription: Our subscription service offers access to a team of professional writers who will provide high-quality written content on a regular basis. Choose from one of our three plans to suit your needs and budget.\\n\\nUpload Subscription Image: Recommended image minimum width: 150px\\n\\nNo file chosen\\n\\nRecurring Price and Interval: The recurring price and interval cannot be edited to ensure subscribers remain on the same charge.\\n\\nPlan 1:\\nPlan name: Basic\\nThe recurring price is USD 75.00 and will be charged periodically at every 1 month\\nPlan description: This plan includes access to a professional writer who will provide one piece of written content per month. Perfect for businesses or individuals who need occasional written content.\\n\\nPlan Image: Display a small image to represent this plan to customers\\n\\nTrial Period: Enable trial period\\nAssign Digital Product Files: Assign digital products for subscribers\\n\\nPlan 2:\\nPlan name: Pro\\nThe recurring price is USD 500.00 and will be charged periodically at every 1 month\\nPlan description: This plan includes access to a team of professional writers who will provide up to five pieces of written content per month. 
Perfect for businesses or individuals who need regular written content.\\n\\nPlan Image: Display a small image to represent this plan to customers\\n\\nTrial Period: Enable trial period\\nAssign Digital Product Files: Assign digital products for subscribers\\n\\nPlan 3:\\nPlan name: Premium (Bundle of 20 / 1,500 words)\\nThe recurring price is USD 1000.00 and will be charged periodically at every 1 month\\nPlan description: This plan includes access to a team of professional writers who will provide up to 20 pieces of written content per month. Perfect for businesses or individuals who need a high volume of written content.\\n\\nPlan Image: Display a small image to represent this plan to customers\\n\\nTrial Period: Enable trial period\\nAssign Digital Product Files: Assign digital products for subscribers\",\n        \"Hello\",\n        \"I am launching an Etsy shop with a Printful integration for drop shipping my designs on specific products. I am looking for ways to differentiate beyond the designs. You are an expert on Etsy audiences. Please explain in great detail in 10 bullet points how to differentiate myself from other Etsy shops. 
I am looking for more obscure ideas here.\",\n        \"How to get a job as a LMFT therapist in the US as an international student?\",\n        \"Explain quantum computing in simple terms\",\n        \"estoy en 6to semestre de mecatronica, necesito un nombre para mi equipo, asi que quiero que me des una lista de 40 opciones, pueden estar relacionadas con la mecaronica, o combinando los nombres de los integrantes que son rudy, gloria, johana, melissa, perla y nomar\",\n        \"Explain deposition\",\n        \"Can you suggest some good e-governance initiatives in tribal districct of india by district administration\",\n        \"Write a python program which accept a command line param as question and send it to server via HTTP get method\",\n        \"Can you explain the fourth dimension to a second grader?\",\n        \"I have an interview about product speccing with the company Weekend Health. Give me an example of a question they might ask with regards about a new feature\",\n        \"arduino uno adalah\",\n        \"how edit array which is in object\",\n        \"how can my software company use Microsoft ENTRA to verify the identity of a user before accessing the software?\",\n        \"calculate the difference in intereste paid in a simple for amortized loan. terms: 125,000 loan, 3.25% interest over 30 years.\",\n        \"can i use spring state machine and workflow together and is it justified?\",\n        'I have the following code:\\n\\n```\\nuseEffect(() => {\\n const handleKeyDown = (event) => {\\n // Check if the CMD + F key combination was pressed\\n if (event.key === \"f\" && event.metaKey) {\\n event.preventDefault();\\n\\n setIsShown(true);\\n }\\n\\n window.addEventListener(\"keydown\", handleKeyDown);\\n\\n return () => {\\n window.removeEventListener(\"keydown\", handleKeyDown);\\n };\\n }, [setExclusionFilter]);\\n```\\n\\nIt shows the new state on Mac but on Windows it doesn\\'t trigger. 
How can I support windows?',\n        \"What is the best marketing tactics for local small businesses?\",\n        \"write an essay on french revolution\",\n        \"What are the roles of a network driver? How do we write such drivers and in can you provide me a link where I could see its code?\",\n        \"Are you familiar with the SAS programming language?\",\n        \"the solenoids will be 12v so they will have to be controled by relays triggered by the GPIO pins\",\n        \"Transform with regular expressions those lines:\\n0003 AB\\n0568 FD\\ninto:\\nAB\\nFD\",\n        \"Write the prompts in the following format. First sentence establishes a situation. Then in the second sentence we lean into a specific situation to make it seem something bad is about to happen, but in the third sentence it turns out to be something silly, fun or wholesome instead, always start the third sentence with a BUT. Some examples below\\n\\n-A hydra is hypnotizing an orc. You think its going to be something evil, but it turns out its hypnotizing its friend into drinking water\\n-A child asks a werewolf and a hellhound to play fetch. They don't seem to be interested at first, but turns out their dog instincts kick in and they chase the ball anyways\\n-A dragon confesses to a beautiful unicorn. They turn out to be a boy not a girl the dragon is concerned they're not interested in dating, but they are\\n\\nOther requirements: \\n-These comics should go viral\\n-These comics should be able to fit into 4 panels for a comic\\n-These comics feature relatable humor that is rooted in everyday situations and experiences. \\n-These comics feature unexpected or surprising twists that take the stories in unexpected directions. 
\\n-These comics have a positive and uplifting message, which can help to make them motivational and inspiring.\\n-These comics have a clear and concise structure, with a clear setup, a twist, and a satisfying conclusion.\\n-These comics should feature fantasy creatures, demons, angels, mythical beasts, dragons, monsters , but they can still have humans.\",\n        \"How can we improve this comic to be simpler and funnier?\\n\\n[We see that this is a small reading club for woodland creatures. Make them all nice and cute, very winnie the pooh-esque, lol. The two characters that speak are animals, make Red into a herbivore race, like a rabbit or something, pink should be a small carnivore like a cat or badger? Red is confused, and red is excited]\\nKnock Knock\\nPink:Who\\u2019s that?\\nRed: Maybe a new member for our book club!\\n\\n[Panics as she sees a dragon licking their lips behind the curtain]\\nRed: It\\u2019s a dragon, run for your lives everyone!\\n\\n[Dragon mom is outside their home, looking dragon-eque but also waving her hands chibi cute apologetically, she\\u2019s clearly a little embarrassed by the situation. Red looks at her suspiciously ]\\nDragon:I\\u2019m not here to eat anyone, I uh\\u2026 heard you had a book club?\\nRed: Uh\\u2026yes\\n\\n[Dragon looks very excited and welcome, Pink seems like she likes the book, red looks a little grossed out ]\\nDragon: Awesome, it's nice to meet you! I brought my favorite book too!\\nPink: What a lovely book!\\nRed: Ugh I\\u2019ll pass on reading that.\",\n        \"Rewrite the following 4 panel comic to be both more brief and more funny\\n\\n[We see an evil mermaid holding a microphone but with an evil face, like she\\u2019s just cast a dark spell of some sort. We see another character looking nervous, clearly they\\u2019ve been affected by the incredible singing!]\\nMermaid: You\\u2019ve lost! 
Give up & spare us both the trouble!\\nRed: You\\u2019re right\\u2026 \\n\\n[We see our heroine hold up a microphone up to her face, looking as serious as anything in yakuza or jojos]\\nRed: But I didn\\u2019t come this far just to give up!\\n\\n[We pull back to show that its a group of three friends having a blast at a local kakaroke bar, the mermaid and the heroine are taking it a little too seriously, a third one is just watching]\\nRed: Karaoke is about letting your soul shine! I\\u2019m giving it my all or die trying!\\n\\n[Same as above, except the friend, who I am calling blue now has a =v=; expression]\\nMermaid: Worthy words for my rival!\\nBlue: Girls, you need to chill. \\nRed: Baka mitai~ (No bubble)\",\n        \"write a brief email in which Ayaam Ghimire writes to Bronywyn Tucker-- the liason between ECG and Guilford College- requesting e waste boxes to be put around campus and computer donation setup with Bauman IT or any other facility on Guilford College campus, on behalf of a organization called CompuCycle, after speaking with the principal Dr. Kash\",\n        \"I'm writing a software for conference calls.\\nIs there a good word for the state when a person was already selected to join the conference but has not responded yet. This should also include the meeting organizer himself, if his client has not answered yet\",\n        \"Would you be able to classify them into more of a range from small startup to big fortune 500 company\",\n        \"Write user stories that describe this concept in detail\",\n        \"Check your python version\",\n        \"We will be making a scenario that follows the following rules:\\n\\nThe competency framework is developed through three phases: 1) scoping review; 2) Focus group discussions with mental health clinicians reviewing patient narratives; and 3) Facilitated Persona Scenario method with Black youth. Moreover, the project adopts a co-design approach and convenes a Knowledge User Panel. 
The panel will be involved in all phases of the competency framework development as they will review findings from the scoping review and focus groups. \\n\\nFocus group with mental health clinicians \\n Mental health clinicians (i.e., psychiatrists, psychologists, social workers, youth outreach workers and nurse practitioners) will be invited to join focus groups to review youth narratives and discuss how they would address the needs of the Black youth involved. The youth narratives will be generated through collecting stories from social media and through an online survey. The survey will ask about young people's experiences with mental health conditions, their use of mental health services, and their suggestions for how to improve mental health care for young people. The online survey will collect stories anonymously. Anyone who submits a story through the survey will be redirected to a list of resources. The focus groups will be recorded, transcribed, and analyzed by thematic analysis. The focus groups will continue until thematic saturation.\\n\\nPhase 3: Persona Scenario method with Black youth\\n Black youth will be invited to focus groups (or one-on-one interviews, if requested) using persona scenario methods. The findings from the focus groups with mental health clinicians will be used to create clinician personas, including information about their motivations, challenges and describe the different ways in which the clinician might interact with the Black youth based on youth narratives. Black youth will be asked to share their perspectives and preferred clinician responses. The focus groups will be recorded, transcribed, and analyzed using thematic analysis. 
We will continue to hold focus groups until thematic saturation.\\n\\nCan you with the information above, create a sceenario/dialogue where a black youth, aged 15 living in Ontario suffering from racism from his classmates and is going to seek the help of a mental health professional who uses the information to engage the youth \\n\\nlimit prose to 500 characters\",\n        \"Demand generation manager for a B2B brand ambassador program called Brandchamp\",\n        \"Here is my Python code:\\napi\\\\_url = 'https://api.yelp.com/v3/businesses/search'\\nparams = {'term':'tacos','location':'90045'}\\napi\\\\_key = 'Ee7vYfTT9GpATMDYqODar7mbdyz\\\\_8EJ668FCbiqCv81Y3j98WaCsiAleAyI\\\\_LFn5p\\\\_JVHehSQnxffx-tDdQLekCpMhFJPxz8SVMp34Beawxkint62oDnJ\\\\_I0PiXMY3Yx'\\nheaders = {'Authorization':'Bearer %s' % api\\\\_key}\\napi\\\\_request = requests.get(api.\\\\_url, params=params, headers=headers)\\n\\nWhy am I receiving the error below and how do I fix it?\\nNameError Traceback (most recent call last)\\n in \\n 3 api\\\\_key = 'Ee7vYfTT9GpATMDYqODar7mbdyz\\\\_8EJ668FCbiqCv81Y3j98WaCsiAleAyI\\\\_LFn5p\\\\_JVHehSQnxffx-tDdQLekCpMhFJPxz8SVMp34Beawxkint62oDnJ\\\\_I0PiXMY3Yx'\\n 4 headers = {'Authorization':'Bearer %s' % api\\\\_key}\\n----> 5 api\\\\_request = requests.get(api.\\\\_url, params=params, headers=headers)\\n\\nNameError: name 'api' is not defined\",\n        \"고등교육의 필요성에 관한 영어 에세이를 1000자 이내로 작성하시오.\"\n        \"Which hero is the best in Heroes of Might and Magic 3?\",\n        \"Use C# to get the current YouTube thumbnail and convert it to Base64.\",\n        \"minikube - docker run --rm -it --network=host alpine ash -c apk add socat && socat TCP-LISTEN:5000,reuseaddr,fork TCP:$(minikube ip):5000 connection refused\",\n        \"How to load image here ?\",\n    ]\n\n    responses = await generate_multi(flash_llama_fd, prompts, max_new_tokens=10)\n\n    assert len(responses) == len(prompts)\n    outputs = [r.choices[0].message.content for r in responses]\n    
expected = [\n        \"Jeff Walker's Product Launch Formula is a comprehensive system\",\n        \"Here are three key indicators to determine if a customer\",\n        \"You can use the `String.format()` method in\",\n        \"In a realm of binary mysticism, we find\",\n        \"The `dummy` variable is being used to consume\",\n        \"You can add multiple new columns in Power Query (\",\n        \"There are many exciting new technologies emerging across various fields\",\n        \"Poly Ether Ether Ketone (PEEK) is\",\n        \"Here's a technical overview of a referral system similar\",\n        \"Here's an example of how you can add an\",\n        \"I'd be happy to help with Java. What\",\n        \"I can help you plan a road trip from Pune\",\n        \"I'd be happy to explain more about a topic\",\n        \"I'd be happy to help you brainstorm and provide\",\n        \"Implementing a Minesweeper algorithm using algebraic\",\n        \"There are several issues with the provided code:\\n\\n1\",\n        \";)\",\n        \"As I delved into the world of high-st\",\n        \"/u/CruxHub: Hi, I'm\",\n        \"To simulate a conversation between Alice and /u/C\",\n        \"Alice: Hey /u/CruxHub,\",\n        \"Alice: Hi /u/CruxHub,\",\n        \"/u/CruxHub: Hey Alice, I\",\n        \"/u/CruxHub: Hey Alice, I\",\n        \"/u/CruxHub: Hey Alice, I\",\n        \"The Dogme approach and the Lexical Approach are\",\n        \"Implementing a netfilter in Linux with a Rust\",\n        \"Damage to the Ulnar nerve can cause numb\",\n        \"The Space Shuttle's Reaction Control System (RCS\",\n        \"I can provide you with a basic Python script that\",\n        \"Farming meat has several negative impacts on the environment\",\n        \"The photograph filter you're referring to is called \\\"\",\n        \"Here's a sample geological database structure with some example\",\n        \"**Web Marketing: A Simplified Explanation**\\n\\nWeb\",\n        \"Here's a 
rewritten and improved version of the story\",\n        \"Here are the questions rewritten in a more conversational\",\n        \"**Learning Progress: 0%**\\n\\n| Topic\",\n        \"I couldn't find any information on a person named\",\n        \"Here's a list of the largest outdoor retailers in\",\n        \"To create a WordPress shortcode that includes Facebook SDK code\",\n        \"The sentence is mostly grammatically correct, but there\",\n        \"I'd be happy to engage in a debate with\",\n        \"I'd love to hear about your business. As\",\n        \"I'll wait for your request to proceed with part\",\n        \"The final part of the Day Sculpting program emphasizes\",\n        \"**Analysis of the Coming of Age Story Archetype\",\n        \"The Apostle John is one of the most prominent figures\",\n        \"To build a Google Places autocomplete feature on Jetpack\",\n        \"The information provided does not mention the captain's name\",\n        \"The metaverse is a shared, immersive and interactive\",\n        \"Here are some ideas for a series of articles for\",\n        '\"Purim Palooza Alert: \\n\\nTo',\n        \"**Summary of the paper in 10 points:\",\n        \"You'll provide three pieces of text, and then\",\n        \"I'm ready to proceed with text 3.\",\n        \"I'm ready to answer questions on Text 1\",\n        \"This is a Solidity contract written in the older\",\n        \"**Speech Recognition and Synthesis using Python**\\n\\nTo\",\n        \"I'd be happy to help you discuss a paper\",\n        \"To handle the given utterance, we can use\",\n        \"**Subscription Services Template:**\\n\\n**Title:** Virtual\",\n        \"Hello. 
How can I assist you today?\",\n        \"Differentiating yourself from other Etsy shops is crucial to\",\n        \"To become a Licensed Marriage and Family Therapist (\",\n        \"**What is Quantum Computing?**\\n\\nQuantum computing\",\n        \"Aqu\\u00ed te dejo 40 opciones de nombres\",\n        \"Deposition is a geological process that involves the transportation\",\n        \"Here are some good e-governance initiatives in\",\n        \"Here's a simple Python program that accepts a command\",\n        \"Imagine you're playing with a toy box. You\",\n        \"Here's an example of a question they might ask\",\n        \"Arduino Uno adalah sebuah papan mikrokontrol\",\n        \"To edit an array that is within an object,\",\n        \"Microsoft ENTRA (Enterprise Mobility + Security) is\",\n        \"To calculate the difference in interest paid between a simple\",\n        \"Yes, you can use Spring State Machine and Spring\",\n        \"The issue lies in the fact that the `meta\",\n        \"Here are some effective marketing tactics for local small businesses\",\n        \"The French Revolution, which lasted from 1789\",\n        \"**Roles of a Network Driver:**\\n\\nA network\",\n        \"Yes, I'm familiar with the SAS (Stat\",\n        \"Using relays to control 12V solen\",\n        \"You can use the following Python code to achieve this\",\n        \"Here are some prompts for viral comics:\\n\\n1.\",\n        \"To simplify and make the comic funnier, consider\",\n        \"Here's a rewritten version of the 4-panel\",\n        \"Subject: Request for E-Waste Collection and Computer\",\n        \"In the context of conference calls, the state you\",\n        \"I can provide a general classification of companies based on\",\n        \"Here are some user stories that describe the concept in\",\n        \"You can check your Python version by running the following\",\n        \"**Scenario:**\\n\\n15-year-old Black youth,\",\n        \"As a Demand Generation 
Manager for a B2B\",\n        \"The error is due to a typo in your code\",\n        \"고등교육의 필요성에 관한 영어 에\",\n        \"Here's a simple C# program that uses the\",\n        'The error message \"connection refused\" indicates that the',\n        \"To load an image, you can use various methods\",\n    ]\n    equals = [o == e for o, e in zip(outputs, expected)]\n    # This is flaky because depending on actual calculation ordering the exact logits may\n    # switch on equivalent logits based on the position in the batch.\n    # 1 output being different is not uncommon\n    if sum(equals) < len(equals) - 1:\n        assert outputs == expected\n"
  },
  {
    "path": "integration-tests/models/test_flash_medusa.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_medusa_handle(launcher):\n    with launcher(\n        \"FasterDecoding/medusa-vicuna-7b-v1.3\", num_shard=2, revision=\"refs/pr/1\"\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_medusa(flash_medusa_handle):\n    await flash_medusa_handle.health(300)\n    return flash_medusa_handle.client\n\n\n@pytest.mark.asyncio\nasync def test_flash_medusa_simple(flash_medusa, response_snapshot):\n    response = await flash_medusa.generate(\n        \"What is Deep Learning?\", max_new_tokens=10, decoder_input_details=True\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.asyncio\nasync def test_flash_medusa_all_params(flash_medusa, response_snapshot):\n    response = await flash_medusa.generate(\n        \"What is Deep Learning?\",\n        max_new_tokens=10,\n        repetition_penalty=1.2,\n        return_full_text=True,\n        stop_sequences=[\"test\"],\n        temperature=0.5,\n        top_p=0.9,\n        top_k=10,\n        truncate=5,\n        typical_p=0.9,\n        watermark=True,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.asyncio\nasync def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot):\n    responses = await generate_load(\n        flash_medusa, \"What is Deep Learning?\", max_new_tokens=10, n=4\n    )\n\n    assert len(responses) == 4\n    assert all(\n        [r.generated_text == responses[0].generated_text for r in responses]\n    ), f\"{[r.generated_text for r in responses]}\"\n    assert (\n        responses[0].generated_text == \"\\nDeep learning is a subset of machine learning\"\n    )\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_mistral.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_mistral_handle(launcher):\n    with launcher(\"mistralai/Mistral-7B-Instruct-v0.1\") as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_mistral(flash_mistral_handle):\n    await flash_mistral_handle.health(300)\n    return flash_mistral_handle.client\n\n\n@pytest.mark.asyncio\nasync def test_flash_mistral(flash_mistral, response_snapshot):\n    response = await flash_mistral.generate(\n        \"Test request\", max_new_tokens=10, decoder_input_details=True\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response.generated_text == \": Let n = 10 - 1\"\n    assert response == response_snapshot\n\n\n@pytest.mark.asyncio\nasync def test_flash_mistral_all_params(flash_mistral, response_snapshot):\n    response = await flash_mistral.generate(\n        \"Test request\",\n        max_new_tokens=10,\n        repetition_penalty=1.2,\n        return_full_text=True,\n        stop_sequences=[\"test\"],\n        temperature=0.5,\n        top_p=0.9,\n        top_k=10,\n        truncate=5,\n        typical_p=0.9,\n        watermark=True,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.asyncio\nasync def test_flash_mistral_load(flash_mistral, generate_load, response_snapshot):\n    responses = await generate_load(\n        flash_mistral, \"Test request\", max_new_tokens=10, n=4\n    )\n\n    assert len(responses) == 4\n    assert all(\n        [r.generated_text == responses[0].generated_text for r in responses]\n    ), f\"{[r.generated_text  for r in responses]}\"\n    assert responses[0].generated_text == \": Let n = 10 - 1\"\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_mixtral.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_mixtral_handle(launcher):\n    with launcher(\"mistralai/Mixtral-8x7B-v0.1\", num_shard=8) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_mixtral(flash_mixtral_handle):\n    await flash_mixtral_handle.health(300)\n    return flash_mixtral_handle.client\n\n\n@pytest.mark.skip(reason=\"requires > 4 shards\")\n@pytest.mark.asyncio\nasync def test_flash_mixtral(flash_mixtral, response_snapshot):\n    response = await flash_mixtral.generate(\n        \"What is gradient descent?\\n\\n\", max_new_tokens=10, decoder_input_details=True\n    )\n\n    assert response.details.generated_tokens == 10\n    assert (\n        response.generated_text\n        == \"Gradient descent is an optimization algorithm used to minimize\"\n    )\n    assert response == response_snapshot\n\n\n@pytest.mark.skip(reason=\"requires > 4 shards\")\n@pytest.mark.asyncio\nasync def test_flash_mixtral_all_params(flash_mixtral, response_snapshot):\n    response = await flash_mixtral.generate(\n        \"What is gradient descent?\\n\\n\",\n        max_new_tokens=10,\n        repetition_penalty=1.2,\n        return_full_text=True,\n        stop_sequences=[\"test\"],\n        temperature=0.5,\n        top_p=0.9,\n        top_k=10,\n        truncate=5,\n        typical_p=0.9,\n        watermark=True,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response.details.generated_tokens == 10\n    assert (\n        response.generated_text\n        == \"What is gradient descent?\\n\\nIt seems to me, that if you're\"\n    )\n    assert response == response_snapshot\n\n\n@pytest.mark.skip(reason=\"requires > 4 shards\")\n@pytest.mark.asyncio\nasync def test_flash_mixtral_load(flash_mixtral, generate_load, response_snapshot):\n    responses = await generate_load(\n        flash_mixtral, \"What is gradient descent?\\n\\n\", max_new_tokens=10, n=4\n    )\n\n    assert 
len(responses) == 4\n    assert responses[0].details.generated_tokens == 10\n    assert (\n        responses[0].generated_text\n        == \"Gradient descent is an optimization algorithm used to minimize\"\n    )\n    assert all(\n        [r.generated_text == responses[0].generated_text for r in responses]\n    ), f\"{[r.generated_text  for r in responses]}\"\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_mixtral_awq.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_mixtral_awq_handle(launcher):\n    with launcher(\"casperhansen/mixtral-instruct-awq\", num_shard=2) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_mixtral_awq(flash_mixtral_awq_handle):\n    await flash_mixtral_awq_handle.health(300)\n    return flash_mixtral_awq_handle.client\n\n\n@pytest.mark.asyncio\nasync def test_flash_mixtral_awq(flash_mixtral_awq, response_snapshot):\n    response = await flash_mixtral_awq.generate(\n        \"What is deep learning?\", max_new_tokens=10, decoder_input_details=True\n    )\n\n    assert response.details.generated_tokens == 10\n    assert (\n        response.generated_text == \"\\n\\nDeep learning is a subset of machine learning\"\n    )\n    assert response == response_snapshot\n\n\n@pytest.mark.asyncio\nasync def test_flash_mixtral_awq_all_params(flash_mixtral_awq, response_snapshot):\n    response = await flash_mixtral_awq.generate(\n        \"What is deep learning?\",\n        max_new_tokens=10,\n        repetition_penalty=1.2,\n        return_full_text=True,\n        stop_sequences=[\"test\"],\n        temperature=0.5,\n        top_p=0.9,\n        top_k=10,\n        truncate=5,\n        typical_p=0.9,\n        watermark=True,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response.details.generated_tokens == 10\n    assert (\n        response.generated_text\n        == \"What is deep learning?\\nDeep Learning is a subset of Machine Learning,\"\n    )\n    assert response == response_snapshot\n\n\n@pytest.mark.asyncio\nasync def test_flash_mixtral_awq_load(\n    flash_mixtral_awq, generate_load, response_snapshot\n):\n    responses = await generate_load(\n        flash_mixtral_awq, \"What is deep learning?\", max_new_tokens=10, n=4\n    )\n\n    assert len(responses) == 4\n    assert responses[0].details.generated_tokens == 10\n    assert (\n        
responses[0].generated_text\n        == \"\\n\\nDeep learning is a subset of machine learning\"\n    )\n    assert all(\n        [r.generated_text == responses[0].generated_text for r in responses]\n    ), f\"{[r.generated_text  for r in responses]}\"\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_mixtral_gptq.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_mixtral_gptq_handle(launcher):\n    with launcher(\n        \"TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ\",\n        revision=\"gptq-4bit-128g-actorder_True\",\n        num_shard=2,\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_mixtral_gptq(flash_mixtral_gptq_handle):\n    await flash_mixtral_gptq_handle.health(300)\n    return flash_mixtral_gptq_handle.client\n\n\n@pytest.mark.asyncio\nasync def test_flash_mixtral_gptq(flash_mixtral_gptq, response_snapshot):\n    response = await flash_mixtral_gptq.generate(\n        \"What is deep learning?\", max_new_tokens=10, decoder_input_details=True\n    )\n\n    assert response.details.generated_tokens == 10\n    assert (\n        response.generated_text == \"\\n\\nDeep learning is a subset of machine learning\"\n    )\n\n    assert response == response_snapshot\n\n\n@pytest.mark.asyncio\nasync def test_flash_mixtral_gptq_all_params(flash_mixtral_gptq, response_snapshot):\n    response = await flash_mixtral_gptq.generate(\n        \"What is deep learning?\",\n        max_new_tokens=10,\n        repetition_penalty=1.2,\n        return_full_text=True,\n        stop_sequences=[\"test\"],\n        temperature=0.5,\n        top_p=0.9,\n        top_k=10,\n        truncate=5,\n        typical_p=0.9,\n        watermark=True,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response.details.generated_tokens == 10\n    assert (\n        response.generated_text\n        == \"What is deep learning?\\nDeep Learning is a subset of Machine Learning,\"\n    )\n    assert response == response_snapshot\n\n\n@pytest.mark.asyncio\nasync def test_flash_mixtral_gptq_load(\n    flash_mixtral_gptq, generate_load, response_snapshot\n):\n    responses = await generate_load(\n        flash_mixtral_gptq, \"What is deep learning?\", max_new_tokens=10, n=4\n    )\n\n    assert len(responses) == 4\n    
assert (\n        responses[0].generated_text\n        == \"\\n\\nDeep learning is a subset of machine learning\"\n    )\n    assert all(\n        [r.generated_text == responses[0].generated_text for r in responses]\n    ), f\"{[r.generated_text  for r in responses]}\"\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_neox.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_neox_handle(launcher):\n    with launcher(\"stabilityai/stablelm-tuned-alpha-3b\", num_shard=1) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_neox(flash_neox_handle):\n    await flash_neox_handle.health(300)\n    return flash_neox_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.skip\n@pytest.mark.asyncio\nasync def test_flash_neox(flash_neox, response_snapshot):\n    response = await flash_neox.generate(\n        \"<|USER|>What's your mood today?<|ASSISTANT|>\",\n        max_new_tokens=10,\n        decoder_input_details=True,\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.skip\n@pytest.mark.asyncio\nasync def test_flash_neox_load(flash_neox, generate_load, response_snapshot):\n    responses = await generate_load(\n        flash_neox,\n        \"<|USER|>What's your mood today?<|ASSISTANT|>\",\n        max_new_tokens=10,\n        n=4,\n    )\n\n    generated_texts = [r.generated_text for r in responses]\n\n    assert len(generated_texts) == 4\n    assert all(\n        [text == generated_texts[0] for text in generated_texts]\n    ), generated_texts\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_neox_sharded.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_neox_sharded_handle(launcher):\n    with launcher(\"OpenAssistant/oasst-sft-1-pythia-12b\", num_shard=2) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_neox_sharded(flash_neox_sharded_handle):\n    await flash_neox_sharded_handle.health(300)\n    return flash_neox_sharded_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_flash_neox(flash_neox_sharded, response_snapshot):\n    response = await flash_neox_sharded.generate(\n        \"<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>\",\n        max_new_tokens=10,\n        decoder_input_details=True,\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_flash_neox_load(flash_neox_sharded, generate_load, response_snapshot):\n    responses = await generate_load(\n        flash_neox_sharded,\n        \"<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>\",\n        max_new_tokens=10,\n        n=4,\n    )\n\n    assert len(responses) == 4\n    assert all([r.generated_text == responses[0].generated_text for r in responses])\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_pali_gemma.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_pali_gemma_handle(launcher):\n    with launcher(\n        \"google/paligemma-3b-pt-224\",\n        num_shard=1,\n        revision=\"float16\",\n        max_input_length=4000,\n        max_total_tokens=4096,\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_pali_gemma(flash_pali_gemma_handle):\n    await flash_pali_gemma_handle.health(300)\n    return flash_pali_gemma_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_pali_gemma(flash_pali_gemma, response_snapshot, cow_beach):\n    inputs = f\"![]({cow_beach})Where is the cow standing?\\n\"\n    response = await flash_pali_gemma.generate(inputs, max_new_tokens=20)\n\n    assert response.generated_text == \"beach\"\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_pali_gemma_two_images(\n    flash_pali_gemma, response_snapshot, chicken, cow_beach\n):\n    response = await flash_pali_gemma.generate(\n        f\"caption![]({chicken})![]({cow_beach})\\n\",\n        max_new_tokens=20,\n    )\n    # Is PaliGemma not able to handle two separate images? At least we\n    # get output showing that both images are used.\n    assert (\n        response.generated_text == \"image result for chicken on the beach\"\n    ), f\"{repr(response.generated_text)}\"\n    assert response == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_pali_gemma2.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_pali_gemma_handle(launcher):\n    with launcher(\n        \"google/paligemma2-3b-pt-224\",\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_pali_gemma(flash_pali_gemma_handle):\n    await flash_pali_gemma_handle.health(300)\n    return flash_pali_gemma_handle.client\n\n\nasync def test_flash_pali_gemma_image(flash_pali_gemma, response_snapshot):\n    car_image = \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg\"\n    response = await flash_pali_gemma.generate(\n        f\"![]({car_image})\",\n        max_new_tokens=20,\n    )\n    assert (\n        response.generated_text == \"\\nBrown\\nCar\\nColor\\nCool\\nDecor\\n\\n\\n\\n\\n\\n\\n?\\n?\"\n    )\n\n    assert response == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_phi.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_phi_handle(launcher):\n    with launcher(\"microsoft/phi-2\", num_shard=1) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_phi(flash_phi_handle):\n    await flash_phi_handle.health(300)\n    return flash_phi_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_flash_phi(flash_phi, response_snapshot):\n    response = await flash_phi.generate(\n        \"Test request\", max_new_tokens=10, decoder_input_details=True\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response.generated_text == ': {request}\")\\n        response = self'\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_flash_phi_all_params(flash_phi, response_snapshot):\n    response = await flash_phi.generate(\n        \"Test request\",\n        max_new_tokens=10,\n        repetition_penalty=1.2,\n        return_full_text=True,\n        stop_sequences=[\"network\"],\n        temperature=0.5,\n        top_p=0.9,\n        top_k=10,\n        truncate=5,\n        typical_p=0.9,\n        watermark=True,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response.details.generated_tokens == 6\n    assert response.generated_text == \"Test request to send data over a network\"\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_flash_phi_load(flash_phi, generate_load, response_snapshot):\n    responses = await generate_load(flash_phi, \"Test request\", max_new_tokens=10, n=4)\n\n    assert len(responses) == 4\n    assert all(\n        [r.generated_text == responses[0].generated_text for r in responses]\n    ), f\"{[r.generated_text  for r in responses]}\"\n    assert responses[0].generated_text == ': {request}\")\\n        response = self'\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_phi35_moe.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_phi35_moe_handle(launcher):\n    with launcher(\n        \"microsoft/Phi-3.5-MoE-instruct\",\n        num_shard=4,\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_phi35_moe(flash_phi35_moe_handle):\n    await flash_phi35_moe_handle.health(300)\n    return flash_phi35_moe_handle.client\n\n\n@pytest.mark.asyncio\nasync def test_flash_phi35_moe(flash_phi35_moe, response_snapshot):\n    response = await flash_phi35_moe.generate(\n        \"What is gradient descent?\\n\\n\", max_new_tokens=10, decoder_input_details=True\n    )\n\n    assert response.details.generated_tokens == 10\n    assert (\n        response.generated_text\n        == \"Gradient descent is an optimization algorithm commonly used in\"\n    )\n    assert response == response_snapshot\n\n\n@pytest.mark.asyncio\nasync def test_flash_phi35_moe_all_params(flash_phi35_moe, response_snapshot):\n    response = await flash_phi35_moe.generate(\n        \"What is gradient descent?\\n\",\n        max_new_tokens=10,\n        repetition_penalty=1.2,\n        return_full_text=True,\n        stop_sequences=[\"test\"],\n        temperature=0.5,\n        top_p=0.9,\n        top_k=10,\n        truncate=5,\n        typical_p=0.9,\n        watermark=True,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response.details.generated_tokens == 10\n    assert (\n        response.generated_text\n        == \"What is gradient descent?\\nGradient Descent (GD) is an\"\n    )\n    assert response == response_snapshot\n\n\n@pytest.mark.asyncio\nasync def test_flash_phi35_moe_load(flash_phi35_moe, generate_load, response_snapshot):\n    responses = await generate_load(\n        flash_phi35_moe, \"What is gradient descent?\\n\\n\", max_new_tokens=10, n=4\n    )\n\n    assert len(responses) == 4\n    assert responses[0].details.generated_tokens == 10\n    assert (\n        
responses[0].generated_text\n        == \"Gradient descent is an optimization algorithm commonly used in\"\n    )\n    assert all(\n        [r.generated_text == responses[0].generated_text for r in responses]\n    ), f\"{[r.generated_text  for r in responses]}\"\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_qwen2.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_qwen2_handle(launcher):\n    with launcher(\"Qwen/Qwen1.5-0.5B\") as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_qwen2(flash_qwen2_handle):\n    await flash_qwen2_handle.health(300)\n    return flash_qwen2_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_flash_qwen2(flash_qwen2, response_snapshot):\n    response = await flash_qwen2.generate(\n        \"Test request\", max_new_tokens=10, decoder_input_details=True\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response.generated_text == \"\\n# Create a request\\nrequest = requests.get\"\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_flash_qwen2_all_params(flash_qwen2, response_snapshot):\n    response = await flash_qwen2.generate(\n        \"Test request\",\n        max_new_tokens=10,\n        repetition_penalty=1.2,\n        return_full_text=True,\n        stop_sequences=[\"test\"],\n        temperature=0.5,\n        top_p=0.9,\n        top_k=10,\n        truncate=5,\n        typical_p=0.9,\n        watermark=True,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_flash_qwen2_load(flash_qwen2, generate_load, response_snapshot):\n    responses = await generate_load(flash_qwen2, \"Test request\", max_new_tokens=10, n=4)\n\n    assert len(responses) == 4\n    assert all(\n        [r.generated_text == responses[0].generated_text for r in responses]\n    ), f\"{[r.generated_text  for r in responses]}\"\n    assert responses[0].generated_text == \"\\n# Create a request\\nrequest = requests.get\"\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_qwen2_5_vl.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_qwen2_5_vl_handle(launcher):\n    with launcher(\"Qwen/Qwen2.5-VL-3B-Instruct\") as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_qwen2_5(flash_qwen2_5_vl_handle):\n    await flash_qwen2_5_vl_handle.health(300)\n    return flash_qwen2_5_vl_handle.client\n\n\n@pytest.mark.private\nasync def test_flash_qwen2_5_vl_simple(flash_qwen2_5, response_snapshot):\n    response = await flash_qwen2_5.chat(\n        seed=42,\n        messages=[\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\n                        \"type\": \"image_url\",\n                        \"image_url\": {\n                            \"url\": \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png\"\n                        },\n                    },\n                    {\"type\": \"text\", \"text\": \"Describe the image\"},\n                ],\n            },\n        ],\n    )\n\n    assert (\n        response.choices[0].message.content\n        == \"The image depicts an anthropomorphic rabbit character wearing an intricate space suit, which includes a helmet with a starry face pattern and multiple suitors. The rabbit's ears are significantly large and upright, and it has a hitchhiker-like star antennas on its chest. The background is a reddish-orange, rocky landscape, suggesting a Martian environment. The suit has various buttons, a red button on the chest, and a reflective or illuminated dome on the head. 
The overall color scheme is dominated by shades of red, orange, and gray, giving a sense of a rugged, otherworldly setting.\"\n    )\n\n    assert response == response_snapshot\n\n\n@pytest.mark.private\nasync def test_flash_qwen2_5_vl_simple_streaming(flash_qwen2_5, response_snapshot):\n    responses = await flash_qwen2_5.chat(\n        seed=42,\n        messages=[\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\n                        \"type\": \"image_url\",\n                        \"image_url\": {\n                            \"url\": \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png\"\n                        },\n                    },\n                    {\"type\": \"text\", \"text\": \"Describe the image\"},\n                ],\n            },\n        ],\n        stream=True,\n    )\n\n    count = 0\n    generated = \"\"\n    last_response = None\n    async for response in responses:\n        count += 1\n        generated += response.choices[0].delta.content\n        last_response = response\n\n    assert (\n        generated\n        == \"The image depicts an anthropomorphic rabbit character wearing an intricate space suit, which includes a helmet with a starry face pattern and multiple suitors. The rabbit's ears are significantly large and upright, and it has a hitchhiker-like star antennas on its chest. The background is a reddish-orange, rocky landscape, suggesting a Martian environment. The suit has various buttons, a red button on the chest, and a reflective or illuminated dome on the head. 
The overall color scheme is dominated by shades of red, orange, and gray, giving a sense of a rugged, otherworldly setting.\"\n    )\n    assert count == 121\n    assert last_response == response_snapshot\n\n\n@pytest.mark.private\nasync def test_flash_qwen2_5_vl_bay(flash_qwen2_5, response_snapshot):\n    response = await flash_qwen2_5.chat(\n        seed=42,\n        messages=[\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\n                        \"type\": \"image_url\",\n                        \"image_url\": {\n                            \"url\": \"https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg\"\n                        },\n                    },\n                    {\"type\": \"text\", \"text\": \"Describe the image\"},\n                ],\n            },\n        ],\n    )\n    assert response == response_snapshot\n\n\n@pytest.mark.private\nasync def test_flash_qwen2_5_vl_inpaint(flash_qwen2_5, response_snapshot):\n    response = await flash_qwen2_5.chat(\n        seed=42,\n        messages=[\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\n                        \"type\": \"image_url\",\n                        \"image_url\": {\n                            \"url\": \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-inpaint.png\"\n                        },\n                    },\n                    {\"type\": \"text\", \"text\": \"Describe the image\"},\n                ],\n            },\n        ],\n    )\n    assert response == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_qwen2_vl.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_qwen2_vl_handle(launcher):\n    with launcher(\"Qwen/Qwen2-VL-7B-Instruct\") as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_qwen2(flash_qwen2_vl_handle):\n    await flash_qwen2_vl_handle.health(300)\n    return flash_qwen2_vl_handle.client\n\n\n@pytest.mark.private\nasync def test_flash_qwen2_vl_simple(flash_qwen2, response_snapshot):\n    response = await flash_qwen2.chat(\n        seed=42,\n        messages=[\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\n                        \"type\": \"image_url\",\n                        \"image_url\": {\n                            \"url\": \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png\"\n                        },\n                    },\n                    {\"type\": \"text\", \"text\": \"Describe this image.\"},\n                ],\n            },\n        ],\n    )\n\n    assert (\n        response.choices[0].message.content\n        == \"The image depicts an anthropomorphic rabbit, wearing a spacesuit, standing in a barren, rocky landscape that resembles the surface of another planet, possibly Mars. The rabbit has a red digestive system label on its chest, and the surrounding environment features red sandy terrain and a hazy, floating planet or moon in the background. 
The scene has a surreal, fantastical quality, blending elements of science fiction and space exploration with a whimsical character.\"\n    )\n\n    assert response == response_snapshot\n\n\n@pytest.mark.private\nasync def test_flash_qwen2_vl_simple_streaming(flash_qwen2, response_snapshot):\n    responses = await flash_qwen2.chat(\n        seed=42,\n        messages=[\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\n                        \"type\": \"image_url\",\n                        \"image_url\": {\n                            \"url\": \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png\"\n                        },\n                    },\n                    {\"type\": \"text\", \"text\": \"Describe this image.\"},\n                ],\n            },\n        ],\n        stream=True,\n    )\n\n    count = 0\n    generated = \"\"\n    last_response = None\n    async for response in responses:\n        count += 1\n        generated += response.choices[0].delta.content\n        last_response = response\n\n    assert (\n        generated\n        == \"The image depicts an anthropomorphic rabbit, wearing a spacesuit, standing in a barren, rocky landscape that resembles the surface of another planet, possibly Mars. The rabbit has a red digestive system label on its chest, and the surrounding environment features red sandy terrain and a hazy, floating planet or moon in the background. 
The scene has a surreal, fantastical quality, blending elements of science fiction and space exploration with a whimsical character.\"\n    )\n    assert count == 89\n    assert last_response == response_snapshot\n\n\n@pytest.mark.private\nasync def test_flash_qwen2_vl_bay(flash_qwen2, response_snapshot):\n    response = await flash_qwen2.chat(\n        seed=42,\n        messages=[\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\n                        \"type\": \"image_url\",\n                        \"image_url\": {\n                            \"url\": \"https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg\"\n                        },\n                    },\n                    {\"type\": \"text\", \"text\": \"Describe the image\"},\n                ],\n            },\n        ],\n    )\n    assert response == response_snapshot\n\n\n@pytest.mark.private\nasync def test_flash_qwen2_vl_inpaint(flash_qwen2, response_snapshot):\n    response = await flash_qwen2.chat(\n        seed=42,\n        messages=[\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\n                        \"type\": \"image_url\",\n                        \"image_url\": {\n                            \"url\": \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-inpaint.png\"\n                        },\n                    },\n                    {\"type\": \"text\", \"text\": \"Describe the image\"},\n                ],\n            },\n        ],\n    )\n    assert response == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_santacoder.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_santacoder_handle(launcher):\n    with launcher(\"bigcode/santacoder\") as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_santacoder(flash_santacoder_handle):\n    await flash_santacoder_handle.health(300)\n    return flash_santacoder_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_flash_santacoder(flash_santacoder, response_snapshot):\n    response = await flash_santacoder.generate(\n        \"def print_hello\", max_new_tokens=10, decoder_input_details=True\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_flash_santacoder_load(\n    flash_santacoder, generate_load, response_snapshot\n):\n    responses = await generate_load(\n        flash_santacoder, \"def print_hello\", max_new_tokens=10, n=4\n    )\n\n    assert len(responses) == 4\n    assert all([r.generated_text == responses[0].generated_text for r in responses])\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_starcoder.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_starcoder_handle(launcher):\n    with launcher(\"bigcode/starcoder\", num_shard=2) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_starcoder(flash_starcoder_handle):\n    await flash_starcoder_handle.health(300)\n    return flash_starcoder_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_starcoder(flash_starcoder, response_snapshot):\n    response = await flash_starcoder.generate(\n        \"def print_hello\", max_new_tokens=10, decoder_input_details=True\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_starcoder_default_params(flash_starcoder, response_snapshot):\n    response = await flash_starcoder.generate(\n        \"def print_hello\",\n        max_new_tokens=60,\n        temperature=0.2,\n        top_p=0.95,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response.details.generated_tokens == 60\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_starcoder_load(flash_starcoder, generate_load, response_snapshot):\n    responses = await generate_load(\n        flash_starcoder, \"def print_hello\", max_new_tokens=10, n=4\n    )\n\n    assert len(responses) == 4\n    assert all([r.generated_text == responses[0].generated_text for r in responses])\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_starcoder2.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_starcoder2_handle(launcher):\n    with launcher(\"bigcode/starcoder2-3b\", num_shard=2) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_starcoder2(flash_starcoder2_handle):\n    await flash_starcoder2_handle.health(300)\n    return flash_starcoder2_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_starcoder2(flash_starcoder2, response_snapshot):\n    response = await flash_starcoder2.generate(\n        \"def print_hello\", max_new_tokens=10, decoder_input_details=True\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_starcoder2_default_params(flash_starcoder2, response_snapshot):\n    response = await flash_starcoder2.generate(\n        \"def print_hello\",\n        max_new_tokens=60,\n        temperature=0.2,\n        top_p=0.95,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response.details.generated_tokens == 60\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_starcoder2_load(\n    flash_starcoder2, generate_load, response_snapshot\n):\n    responses = await generate_load(\n        flash_starcoder2, \"def print_hello\", max_new_tokens=10, n=4\n    )\n\n    assert len(responses) == 4\n    assert all([r.generated_text == responses[0].generated_text for r in responses])\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_starcoder2_lora.py",
    "content": "import pytest\nimport requests\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_starcoder2_handle(launcher):\n    with launcher(\n        \"bigcode/starcoder2-3b\", lora_adapters=[\"smangrul/starcoder-3b-hugcoder\"]\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_starcoder2(flash_starcoder2_handle):\n    await flash_starcoder2_handle.health(300)\n    return flash_starcoder2_handle.client\n\n\n@pytest.mark.asyncio\nasync def test_flash_starcoder2(flash_starcoder2, response_snapshot):\n    response = await flash_starcoder2.generate(\n        \"def print_hello\", max_new_tokens=10, decoder_input_details=True\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.asyncio\nasync def test_flash_starcoder2_default_params(flash_starcoder2, response_snapshot):\n    response = await flash_starcoder2.generate(\n        \"who are you?\",\n        max_new_tokens=60,\n        temperature=0.2,\n        top_p=0.95,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response.details.generated_tokens == 60\n    assert response == response_snapshot\n\n\n@pytest.mark.asyncio\nasync def test_flash_starcoder2_load(\n    flash_starcoder2, generate_load, response_snapshot\n):\n    responses = await generate_load(\n        flash_starcoder2, \"who are you?\", max_new_tokens=10, n=4\n    )\n\n    assert len(responses) == 4\n    assert all([r.generated_text == responses[0].generated_text for r in responses])\n\n    assert responses == response_snapshot\n\n\n@pytest.mark.asyncio\nasync def test_flash_starcoder2_with_hugcode_adapter(\n    flash_starcoder2, response_snapshot\n):\n    response = requests.post(\n        f\"{flash_starcoder2.base_url}/generate\",\n        headers=flash_starcoder2.headers,\n        json={\n            \"inputs\": \"def print_hello\",\n            \"parameters\": {\n                \"max_new_tokens\": 10,\n         
       \"adapter_id\": \"smangrul/starcoder-3b-hugcoder\",\n                \"details\": True,\n            },\n        },\n    )\n\n    assert response.status_code == 200\n    data = response.json()\n    assert data[\"generated_text\"] == '_world():\\n    print(\"Hello World!\")\\n'\n\n    assert data == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_flash_starcoder_gptq.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_starcoder_gptq_handle(launcher):\n    with launcher(\"Narsil/starcoder-gptq\", num_shard=2, quantize=\"gptq\") as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_starcoder_gptq(flash_starcoder_gptq_handle):\n    await flash_starcoder_gptq_handle.health(300)\n    return flash_starcoder_gptq_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snapshot):\n    response = await flash_starcoder_gptq.generate(\n        \"def geometric_mean(L: List[float]):\",\n        max_new_tokens=20,\n        decoder_input_details=True,\n    )\n    assert response.details.generated_tokens == 2\n    assert response == generous_response_snapshot\n\n\n# Deactivated because it's flaky\n# Only this model seems affected and it's only a logprob precision issue.\n# @pytest.mark.release\n# @pytest.mark.asyncio\n# async def test_flash_starcoder_gptq_default_params(\n#     flash_starcoder_gptq, generous_response_snapshot\n# ):\n#     response = await flash_starcoder_gptq.generate(\n#         \"def geometric_mean(L: List[float]):\",\n#         max_new_tokens=20,\n#         temperature=0.2,\n#         top_p=0.95,\n#         decoder_input_details=True,\n#         seed=0,\n#     )\n#     assert response.details.generated_tokens == 2\n#     assert response == generous_response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_flash_starcoder_gptq_load(\n    flash_starcoder_gptq, generate_load, generous_response_snapshot\n):\n    responses = await generate_load(\n        flash_starcoder_gptq,\n        \"def geometric_mean(L: List[float]):\",\n        max_new_tokens=10,\n        n=4,\n    )\n\n    assert len(responses) == 4\n    # XXX: TODO: Fix this test.\n    # assert all([r.generated_text == responses[0].generated_text for r in responses])\n\n    # assert responses == 
generous_response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_grammar_llama.py",
    "content": "import pytest\nimport json\n\nfrom text_generation.types import GrammarType\n\n\n@pytest.fixture(scope=\"module\")\ndef non_flash_llama_grammar_handle(launcher):\n    with launcher(\n        \"TinyLlama/TinyLlama-1.1B-Chat-v1.0\",\n        num_shard=1,\n        disable_grammar_support=False,\n        use_flash_attention=False,\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def non_flash_llama_grammar(non_flash_llama_grammar_handle):\n    await non_flash_llama_grammar_handle.health(300)\n    return non_flash_llama_grammar_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.skip\n@pytest.mark.asyncio\nasync def test_non_flash_llama_grammar_json(non_flash_llama_grammar, response_snapshot):\n    response = await non_flash_llama_grammar.generate(\n        \"info: david holtz like trees and has two cats. \",\n        max_new_tokens=100,\n        decoder_input_details=True,\n        seed=0,\n        grammar={\n            \"type\": GrammarType.Json,\n            \"value\": json.dumps(\n                {\n                    \"type\": \"object\",\n                    \"$id\": \"https://example.com/person.schema.json\",\n                    \"$schema\": \"https://json-schema.org/draft/2020-12/schema\",\n                    \"title\": \"Person\",\n                    \"properties\": {\n                        \"firstName\": {\n                            \"type\": \"string\",\n                            \"description\": \"The person'''s first name.\",\n                        },\n                        \"lastName\": {\n                            \"type\": \"string\",\n                            \"description\": \"The person'''s last name.\",\n                        },\n                        \"hobby\": {\n                            \"description\": \"The person'''s hobby.\",\n                            \"type\": \"string\",\n                        },\n                        \"numCats\": {\n                    
        \"description\": \"The number of cats the person has.\",\n                            \"type\": \"integer\",\n                            \"minimum\": 0,\n                        },\n                    },\n                    \"required\": [\"firstName\", \"lastName\", \"hobby\", \"numCats\"],\n                }\n            ),\n        },\n    )\n\n    assert response.details.generated_tokens == 30\n    assert (\n        response.generated_text\n        == '{\"firstName\":\"David\",\"hobby\":\"Trees\",\"lastName\":\"Holtz\",\"numCats\":2}'\n    )\n    assert response == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_grammar_response_format_llama.py",
    "content": "import pytest\nimport requests\nfrom pydantic import BaseModel\nfrom typing import List\n\n\n@pytest.fixture(scope=\"module\")\ndef llama_grammar_handle(launcher):\n    with launcher(\n        \"TinyLlama/TinyLlama-1.1B-Chat-v1.0\",\n        num_shard=1,\n        disable_grammar_support=False,\n        use_flash_attention=False,\n        max_batch_prefill_tokens=3000,\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def llama_grammar(llama_grammar_handle):\n    await llama_grammar_handle.health(300)\n    return llama_grammar_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_grammar_response_format_llama_json(llama_grammar, response_snapshot):\n    class Weather(BaseModel):\n        unit: str\n        temperature: List[int]\n\n    json_payload = {\n        \"model\": \"tgi\",\n        \"messages\": [\n            {\n                \"role\": \"system\",\n                \"content\": f\"Respond to the users questions and answer them in the following format: {Weather.model_json_schema()}\",\n            },\n            {\n                \"role\": \"user\",\n                \"content\": \"What's the weather like the next 3 days in San Francisco, CA?\",\n            },\n        ],\n        \"seed\": 42,\n        \"max_tokens\": 500,\n        \"response_format\": {\n            \"type\": \"json_object\",\n            \"value\": Weather.model_json_schema(),\n        },\n    }\n    # send the request\n    response = requests.post(\n        f\"{llama_grammar.base_url}/v1/chat/completions\",\n        headers=llama_grammar.headers,\n        json=json_payload,\n    )\n\n    chat_completion = response.json()\n    called = chat_completion[\"choices\"][0][\"message\"][\"content\"]\n\n    assert response.status_code == 200\n    assert called == '{ \"unit\": \"fahrenheit\", \"temperature\": [ 72, 79, 88 ] }'\n    assert chat_completion == response_snapshot\n\n    
json_payload[\"response_format\"][\"type\"] = \"json\"\n    response = requests.post(\n        f\"{llama_grammar.base_url}/v1/chat/completions\",\n        headers=llama_grammar.headers,\n        json=json_payload,\n    )\n\n    chat_completion = response.json()\n    called = chat_completion[\"choices\"][0][\"message\"][\"content\"]\n\n    assert response.status_code == 200\n    assert called == '{ \"unit\": \"fahrenheit\", \"temperature\": [ 72, 79, 88 ] }'\n    assert chat_completion == response_snapshot\n\n    json_payload[\"response_format\"] = {\n        \"type\": \"json_schema\",\n        \"value\": {\n            \"name\": \"weather\",\n            \"strict\": True,\n            \"schema\": Weather.model_json_schema(),\n        },\n    }\n    response = requests.post(\n        f\"{llama_grammar.base_url}/v1/chat/completions\",\n        headers=llama_grammar.headers,\n        json=json_payload,\n    )\n\n    chat_completion = response.json()\n    called = chat_completion[\"choices\"][0][\"message\"][\"content\"]\n\n    assert response.status_code == 200\n    assert called == '{ \"unit\": \"fahrenheit\", \"temperature\": [ 72, 79, 88 ] }'\n    assert chat_completion == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_grammar_response_format_llama_error_if_tools_not_installed(\n    llama_grammar,\n):\n    class Weather(BaseModel):\n        unit: str\n        temperature: List[int]\n\n    # send the request\n    response = requests.post(\n        f\"{llama_grammar.base_url}/v1/chat/completions\",\n        headers=llama_grammar.headers,\n        json={\n            \"model\": \"tgi\",\n            \"messages\": [\n                {\n                    \"role\": \"system\",\n                    \"content\": f\"Respond to the users questions and answer them in the following format: {Weather.model_json_schema()}\",\n                },\n                {\n                    \"role\": \"user\",\n                    \"content\": 
\"What's the weather like the next 3 days in San Francisco, CA?\",\n                },\n            ],\n            \"seed\": 42,\n            \"max_tokens\": 500,\n            \"tools\": [],\n            \"response_format\": {\n                \"type\": \"json_object\",\n                \"value\": Weather.model_json_schema(),\n            },\n        },\n    )\n\n    # 422 means the server was unable to process the request because it contains invalid data.\n    assert response.status_code == 422\n    assert response.json() == {\n        \"error\": \"Tool error: Grammar and tools are mutually exclusive\",\n        \"error_type\": \"tool_error\",\n    }\n"
  },
  {
    "path": "integration-tests/models/test_idefics.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef idefics_handle(launcher):\n    with launcher(\n        \"HuggingFaceM4/idefics-9b-instruct\", num_shard=2, dtype=\"float16\"\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def idefics(idefics_handle):\n    await idefics_handle.health(300)\n    return idefics_handle.client\n\n\n@pytest.mark.asyncio\nasync def test_idefics(idefics, response_snapshot, chicken):\n    response = await idefics.generate(\n        f\"User:![]({chicken})Can you tell me a very short story based on the image?\",\n        max_new_tokens=10,\n        decoder_input_details=True,\n    )\n\n    assert response.details.generated_tokens == 10\n    assert (\n        response.generated_text == \" \\nAssistant: A rooster stands\"\n    ), f\"{repr(response.generated_text)}\"\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_idefics_two_images(idefics, response_snapshot, chicken, cow_beach):\n    response = await idefics.generate(\n        f\"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \\nAssistant:\",\n        max_new_tokens=20,\n    )\n    assert (\n        response.generated_text == \" The cow and chicken are standing on a beach.\"\n    ), f\"{repr(response.generated_text)}\"\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_idefics_load(idefics, generate_load, response_snapshot, chicken):\n    responses = await generate_load(\n        idefics,\n        f\"User:![]({chicken})Can you tell me a very short story based on the image?\",\n        max_new_tokens=10,\n        n=4,\n    )\n\n    generated_texts = [r.generated_text for r in responses]\n\n    assert generated_texts[0] == \" \\nAssistant: A rooster stands\"\n    assert len(generated_texts) == 4\n    assert all(\n        [text == generated_texts[0] 
for text in generated_texts]\n    )\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_idefics2.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_idefics2_next_handle(launcher):\n    with launcher(\n        \"HuggingFaceM4/idefics2-8b\",\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_idefics2_next(flash_idefics2_next_handle):\n    await flash_idefics2_next_handle.health(300)\n    return flash_idefics2_next_handle.client\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_idefics2_next_simple(\n    flash_idefics2_next, response_snapshot, chicken\n):\n    response = await flash_idefics2_next.generate(\n        f\"User:![]({chicken})Write me a short story<end_of_utterance> \\nAssistant:\",\n        max_new_tokens=10,\n    )\n    assert (\n        response.generated_text == \" A chicken is sitting on a pile of money.\"\n    ), f\"{repr(response.generated_text)}\"\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_idefics2_two_images(\n    flash_idefics2_next, response_snapshot, chicken, cow_beach\n):\n    response = await flash_idefics2_next.generate(\n        f\"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \\nAssistant:\",\n        max_new_tokens=20,\n    )\n    assert (\n        response.generated_text\n        == \" The cow is standing on the beach and the chicken is sitting on a pile of money.\"\n    ), f\"{repr(response.generated_text)}\"\n    assert response.details.generated_tokens == 19\n    assert response == response_snapshot\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_idefics2_next_all_params(flash_idefics2_next, response_snapshot):\n    response = await flash_idefics2_next.generate(\n        \"Test request\",\n        max_new_tokens=10,\n        repetition_penalty=1.2,\n        return_full_text=True,\n        stop_sequences=[\"test\"],\n        temperature=0.5,\n        
top_p=0.9,\n        top_k=10,\n        truncate=5,\n        typical_p=0.9,\n        watermark=True,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_idefics2_next_load(\n    flash_idefics2_next, generate_load, response_snapshot, chicken\n):\n    responses = await generate_load(\n        flash_idefics2_next,\n        f\"User:![]({chicken})Write me a short story<end_of_utterance> \\nAssistant:\",\n        max_new_tokens=10,\n        n=4,\n    )\n    generated_texts = [r.generated_text for r in responses]\n    assert generated_texts[0] == \" A chicken is sitting on a pile of money.\"\n    assert len(generated_texts) == 4\n    assert all([r.generated_text == generated_texts[0] for r in responses])\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_idefics3.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_idefics3_next_handle(launcher):\n    with launcher(\"HuggingFaceM4/Idefics3-8B-Llama3\") as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_idefics3_next(flash_idefics3_next_handle):\n    await flash_idefics3_next_handle.health(300)\n    return flash_idefics3_next_handle.client\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_idefics3_next_simple_url(flash_idefics3_next, response_snapshot):\n    ny_skyline = \"https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg\"\n    query = \"What is in this image?\"\n    response = await flash_idefics3_next.generate(\n        f\"<|begin_of_text|><|begin_of_text|>User:![]({ny_skyline}){query}<end_of_utterance>\\nAssistant:\",\n        max_new_tokens=10,\n        seed=1337,\n    )\n    print(response)\n    assert (\n        response.generated_text == \" There is a statue in the image.\"\n    ), f\"{repr(response.generated_text)}\"\n    assert response.details.generated_tokens == 9\n    assert response == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_json_schema_constrain.py",
    "content": "import pytest\nimport json\nimport requests\n\n\n@pytest.fixture(scope=\"module\")\ndef model_handle(launcher):\n    \"\"\"Fixture to provide the base URL for API calls.\"\"\"\n    with launcher(\n        \"google/gemma-3-4b-it\",\n        num_shard=2,\n        disable_grammar_support=False,\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def model_fixture(model_handle):\n    await model_handle.health(300)\n    return model_handle.client\n\n\n# Sample JSON Schema for testing\nperson_schema = {\n    \"type\": \"object\",\n    \"$id\": \"https://example.com/person.schema.json\",\n    \"$schema\": \"https://json-schema.org/draft/2020-12/schema\",\n    \"title\": \"Person\",\n    \"properties\": {\n        \"firstName\": {\n            \"type\": \"string\",\n            \"description\": \"The person's first name.\",\n            \"minLength\": 4,\n        },\n        \"lastName\": {\n            \"type\": \"string\",\n            \"description\": \"The person's last name.\",\n            \"minLength\": 4,\n        },\n        \"hobby\": {\n            \"description\": \"The person's hobby.\",\n            \"type\": \"string\",\n            \"minLength\": 4,\n        },\n        \"numCats\": {\n            \"description\": \"The number of cats the person has.\",\n            \"type\": \"integer\",\n            \"minimum\": 0,\n        },\n    },\n    \"required\": [\"firstName\", \"lastName\", \"hobby\", \"numCats\"],\n}\n\n# More complex schema for testing nested objects and arrays\ncomplex_schema = {\n    \"type\": \"object\",\n    \"properties\": {\n        \"name\": {\"type\": \"string\"},\n        \"age\": {\"type\": \"integer\", \"minimum\": 0},\n        \"address\": {\n            \"type\": \"object\",\n            \"properties\": {\n                \"street\": {\"type\": \"string\"},\n                \"city\": {\"type\": \"string\"},\n                \"postalCode\": {\"type\": \"string\"},\n            },\n   
         \"required\": [\"street\", \"city\"],\n        },\n        \"hobbies\": {\"type\": \"array\", \"items\": {\"type\": \"string\"}, \"minItems\": 1},\n    },\n    \"required\": [\"name\", \"age\", \"hobbies\"],\n}\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_json_schema_basic(model_fixture, response_snapshot):\n    \"\"\"Test basic JSON schema validation with the person schema.\"\"\"\n    response = requests.post(\n        f\"{model_fixture.base_url}/v1/chat/completions\",\n        json={\n            \"model\": \"tgi\",\n            \"messages\": [\n                {\n                    \"role\": \"user\",\n                    \"content\": \"David is a person who likes trees and nature. He enjoys studying math and science. He has 2 cats.\",\n                },\n            ],\n            \"seed\": 42,\n            \"temperature\": 0.0,\n            \"response_format\": {\n                \"type\": \"json_schema\",\n                \"value\": {\"name\": \"person\", \"strict\": True, \"schema\": person_schema},\n            },\n        },\n    )\n\n    result = response.json()\n\n    # Validate response format\n    content = result[\"choices\"][0][\"message\"][\"content\"]\n    parsed_content = json.loads(content)\n\n    assert \"firstName\" in parsed_content\n    assert \"lastName\" in parsed_content\n    assert \"hobby\" in parsed_content\n    assert \"numCats\" in parsed_content\n    assert isinstance(parsed_content[\"numCats\"], int)\n    assert parsed_content[\"numCats\"] >= 0\n    assert result == response_snapshot\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_json_schema_complex(model_fixture, response_snapshot):\n    \"\"\"Test complex JSON schema with nested objects and arrays.\"\"\"\n    response = requests.post(\n        f\"{model_fixture.base_url}/v1/chat/completions\",\n        json={\n            \"model\": \"tgi\",\n            \"messages\": [\n                {\n                    \"role\": \"user\",\n    
                \"content\": \"John Smith is 30 years old. He lives on Maple Street in Boston. He enjoys botany, astronomy, and solving mathematical puzzles.\",\n                },\n            ],\n            \"seed\": 42,\n            \"temperature\": 0.0,\n            \"response_format\": {\n                \"type\": \"json_schema\",\n                \"value\": {\n                    \"name\": \"complex_person\",\n                    \"strict\": True,\n                    \"schema\": complex_schema,\n                },\n            },\n        },\n    )\n\n    result = response.json()\n\n    # Validate response format\n    content = result[\"choices\"][0][\"message\"][\"content\"]\n    parsed_content = json.loads(content)\n\n    assert \"name\" in parsed_content\n    assert \"age\" in parsed_content\n    assert \"hobbies\" in parsed_content\n    assert \"address\" in parsed_content\n    assert \"street\" in parsed_content[\"address\"]\n    assert \"city\" in parsed_content[\"address\"]\n    assert isinstance(parsed_content[\"hobbies\"], list)\n    assert len(parsed_content[\"hobbies\"]) >= 1\n    assert result == response_snapshot\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_json_schema_stream(model_fixture, response_snapshot):\n    \"\"\"Test JSON schema validation with streaming.\"\"\"\n    response = requests.post(\n        f\"{model_fixture.base_url}/v1/chat/completions\",\n        json={\n            \"model\": \"tgi\",\n            \"messages\": [\n                {\n                    \"role\": \"user\",\n                    \"content\": \"David is a person who likes to ride bicycles. 
He has 2 cats.\",\n                },\n            ],\n            \"seed\": 42,\n            \"temperature\": 0.0,\n            \"response_format\": {\n                \"type\": \"json_schema\",\n                \"value\": {\"name\": \"person\", \"strict\": True, \"schema\": person_schema},\n            },\n            \"stream\": True,\n        },\n        stream=True,\n    )\n\n    chunks = []\n    content_generated = \"\"\n\n    for line in response.iter_lines():\n        if line:\n            # Remove the \"data: \" prefix and handle the special case of \"[DONE]\"\n            data = line.decode(\"utf-8\")\n            if data.startswith(\"data: \"):\n                data = data[6:]\n                if data != \"[DONE]\":\n                    chunk = json.loads(data)\n                    chunks.append(chunk)\n                    if \"choices\" in chunk and len(chunk[\"choices\"]) > 0:\n                        if (\n                            \"delta\" in chunk[\"choices\"][0]\n                            and \"content\" in chunk[\"choices\"][0][\"delta\"]\n                        ):\n                            content_generated += chunk[\"choices\"][0][\"delta\"][\"content\"]\n\n    # Validate the final assembled JSON\n    parsed_content = json.loads(content_generated)\n    assert \"firstName\" in parsed_content\n    assert \"lastName\" in parsed_content\n    assert \"hobby\" in parsed_content\n    assert \"numCats\" in parsed_content\n    assert isinstance(parsed_content[\"numCats\"], int)\n    assert parsed_content[\"numCats\"] >= 0\n    assert chunks == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_llava_next.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_llava_next_handle(launcher):\n    with launcher(\n        \"llava-hf/llava-v1.6-mistral-7b-hf\",\n        num_shard=4,\n        max_input_length=4000,\n        max_total_tokens=4096,\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_llava_next(flash_llava_next_handle):\n    await flash_llava_next_handle.health(300)\n    return flash_llava_next_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llava_next_simple(flash_llava_next, response_snapshot, chicken):\n    response = await flash_llava_next.generate(\n        f\"User:![]({chicken})Can you tell me a very short story based on the image?\",\n        max_new_tokens=10,\n    )\n    assert (\n        response.generated_text == \"\\n\\nOnce upon a time, there was a\"\n    ), f\"{repr(response.generated_text)}\"\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llava_next_all_params(flash_llava_next, response_snapshot):\n    response = await flash_llava_next.generate(\n        \"Test request\",\n        max_new_tokens=10,\n        repetition_penalty=1.2,\n        return_full_text=True,\n        stop_sequences=[\"test\"],\n        temperature=0.5,\n        top_p=0.9,\n        top_k=10,\n        truncate=5,\n        typical_p=0.9,\n        watermark=True,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response.details.generated_tokens == 6\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llava_next_load(\n    flash_llava_next, generate_load, response_snapshot, chicken\n):\n    responses = await generate_load(\n        flash_llava_next,\n        f\"User:![]({chicken})Can you tell me a very short 
story based on the image?\",\n        max_new_tokens=10,\n        n=4,\n    )\n    generated_texts = [r.generated_text for r in responses]\n    assert generated_texts[0] == \"\\n\\nOnce upon a time, there was a\"\n    assert len(generated_texts) == 4\n    assert all([r.generated_text == generated_texts[0] for r in responses])\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_lora_mistral.py",
    "content": "import pytest\nimport requests\n\n\n@pytest.fixture(scope=\"module\")\ndef lora_mistral_handle(launcher):\n    with launcher(\n        \"mistralai/Mistral-7B-v0.1\",\n        lora_adapters=[\n            \"predibase/dbpedia\",\n            \"predibase/customer_support\",\n        ],\n        cuda_graphs=[0],\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def lora_mistral(lora_mistral_handle):\n    await lora_mistral_handle.health(300)\n    return lora_mistral_handle.client\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_lora_mistral(lora_mistral, response_snapshot):\n    response = await lora_mistral.generate(\n        \"Test request\", max_new_tokens=10, decoder_input_details=True\n    )\n    assert response.details.generated_tokens == 10\n\n\nclassification_prompt = \"\"\"You are given the title and the body of an article below. Please determine the type of the article.\\n### Title: Great White Whale\\n\\n### Body: Great White Whale is the debut album by the Canadian rock band Secret and Whisper. The album was in the works for about a year and was released on February 12 2008. A music video was shot in Pittsburgh for the album's first single XOXOXO. The album reached number 17 on iTunes's top 100 albums in its first week on sale.\\n\\n### Article Type:\"\"\"\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_lora_mistral_without_adapter(lora_mistral, response_snapshot):\n    response = requests.post(\n        f\"{lora_mistral.base_url}/generate\",\n        headers=lora_mistral.headers,\n        json={\n            \"inputs\": classification_prompt,\n            \"parameters\": {\n                \"max_new_tokens\": 40,\n                \"details\": True,\n            },\n        },\n    )\n\n    assert response.status_code == 200\n    data = response.json()\n    assert (\n        data[\"generated_text\"]\n        == \"\\n\\n### 1. News\\n### 2. Blog\\n### 3. Article\\n### 4. 
Review\\n### 5. Other\\n\\n\\n\\n\\n\\n\\n\\n\\n\"\n    )\n    assert data == response_snapshot\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_lora_mistral_with_dbpedia_adapter(lora_mistral, response_snapshot):\n    response = requests.post(\n        f\"{lora_mistral.base_url}/generate\",\n        headers=lora_mistral.headers,\n        json={\n            \"inputs\": classification_prompt,\n            \"parameters\": {\n                \"max_new_tokens\": 40,\n                \"adapter_id\": \"predibase/dbpedia\",\n                \"details\": True,\n            },\n        },\n    )\n\n    assert response.status_code == 200\n    data = response.json()\n    assert data[\"generated_text\"] == \"  11\"\n    assert data == response_snapshot\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_lora_mistral_with_customer_support_adapter(\n    lora_mistral, response_snapshot\n):\n    print(lora_mistral.base_url)\n    print(lora_mistral.headers)\n    response = requests.post(\n        f\"{lora_mistral.base_url}/generate\",\n        headers=lora_mistral.headers,\n        json={\n            \"inputs\": \"What are 3 unique words that describe you?\",\n            \"parameters\": {\n                \"max_new_tokens\": 40,\n                \"adapter_id\": \"predibase/customer_support\",\n                \"details\": True,\n            },\n        },\n    )\n\n    assert response.status_code == 200\n    data = response.json()\n    assert (\n        data[\"generated_text\"]\n        == \"\\n\\nI’m not sure if I can come up with 3 unique words that describe me, but I’ll try.\\n\\n1. Creative\\n2. 
Funny\\n3.\"\n    )\n    assert data == response_snapshot\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_lora_mistral_without_customer_support_adapter(\n    lora_mistral, response_snapshot\n):\n    response = requests.post(\n        f\"{lora_mistral.base_url}/generate\",\n        headers=lora_mistral.headers,\n        json={\n            \"inputs\": \"What are 3 unique words that describe you?\",\n            \"parameters\": {\n                \"max_new_tokens\": 40,\n                \"details\": True,\n            },\n        },\n    )\n\n    assert response.status_code == 200\n    data = response.json()\n    assert (\n        data[\"generated_text\"]\n        == \"\\n\\nI’m a very passionate person. I’m very driven. I’m very determined.\\n\\nWhat is your favorite thing about being a teacher?\\n\\nI love the fact\"\n    )\n    assert data == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_mamba.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef fused_kernel_mamba_handle(launcher):\n    with launcher(\"state-spaces/mamba-130m-hf\", num_shard=1) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def fused_kernel_mamba(fused_kernel_mamba_handle):\n    await fused_kernel_mamba_handle.health(300)\n    return fused_kernel_mamba_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_mamba(fused_kernel_mamba, response_snapshot):\n    response = await fused_kernel_mamba.generate(\n        \"What is Deep Learning?\", max_new_tokens=10\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response.generated_text == \"\\n\\nDeep learning is a new type of machine\"\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_mamba_all_params(fused_kernel_mamba, response_snapshot):\n    response = await fused_kernel_mamba.generate(\n        \"blue, red, yellow, \",\n        max_new_tokens=10,\n        repetition_penalty=1.2,\n        return_full_text=True,\n        stop_sequences=[\"test\"],\n        temperature=0.5,\n        top_p=0.9,\n        top_k=10,\n        truncate=5,\n        typical_p=0.9,\n        watermark=True,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response.details.generated_tokens == 10\n    assert (\n        response.generated_text\n        == \"blue, red, yellow, \\nand blue colors. 
A number of the color\"\n    )\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_mamba_load(\n    fused_kernel_mamba, generate_load, generous_response_snapshot\n):\n    responses = await generate_load(\n        fused_kernel_mamba, \"What is Deep Learning?\", max_new_tokens=10, n=4\n    )\n\n    assert len(responses) == 4\n    assert responses[0].generated_text == \"\\n\\nDeep learning is a new type of machine\"\n    assert all([r.generated_text == responses[0].generated_text for r in responses])\n    assert responses[0].generated_text == \"\\n\\nDeep learning is a new type of machine\"\n\n    assert responses == generous_response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_mllama.py",
    "content": "import pytest\nimport asyncio\n\n\n@pytest.fixture(scope=\"module\")\ndef mllama_handle(launcher):\n    with launcher(\n        \"unsloth/Llama-3.2-11B-Vision-Instruct\",\n        num_shard=2,\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def mllama(mllama_handle):\n    await mllama_handle.health(300)\n    return mllama_handle.client\n\n\n@pytest.mark.asyncio\nasync def test_mllama_simpl(mllama, response_snapshot):\n    response = await mllama.chat(\n        max_tokens=10,\n        temperature=0.0,\n        messages=[\n            {\n                \"role\": \"user\",\n                \"content\": [\n                    {\n                        \"type\": \"text\",\n                        \"text\": \"Describe the image in 10 words.\",\n                    },\n                    {\n                        \"type\": \"image_url\",\n                        \"image_url\": {\n                            \"url\": \"https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png\"\n                        },\n                    },\n                ],\n            },\n        ],\n    )\n\n    assert response.usage == {\n        \"completion_tokens\": 10,\n        \"prompt_tokens\": 45,\n        \"total_tokens\": 55,\n    }\n    assert (\n        response.choices[0].message.content\n        == \"A chicken sits on a pile of money, looking\"\n    )\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_mllama_load(mllama, generate_load, response_snapshot):\n    futures = [\n        mllama.chat(\n            max_tokens=10,\n            temperature=0.0,\n            messages=[\n                {\n                    \"role\": \"user\",\n                    \"content\": [\n                        {\n                            \"type\": \"text\",\n                            \"text\": \"Describe the image in 10 
words.\",\n                        },\n                        {\n                            \"type\": \"image_url\",\n                            \"image_url\": {\n                                \"url\": \"https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png\"\n                            },\n                        },\n                    ],\n                },\n            ],\n        )\n        # TODO with v3, 4 breaks here. Nothing accounts for the image VRAM\n        # because mllama is the only one doing its thing.\n        for i in range(2)\n    ]\n    responses = await asyncio.gather(*futures)\n\n    generated_texts = [response.choices[0].message.content for response in responses]\n\n    # XXX: TODO: Fix this test.\n    assert generated_texts[0] == \"A chicken sits on a pile of money, looking\"\n    assert len(generated_texts) == 2\n    assert generated_texts, all(\n        [text == generated_texts[0] for text in generated_texts]\n    )\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_mpt.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef mpt_sharded_handle(launcher):\n    with launcher(\"mosaicml/mpt-7b\", num_shard=2) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def mpt_sharded(mpt_sharded_handle):\n    await mpt_sharded_handle.health(300)\n    return mpt_sharded_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_mpt(mpt_sharded, response_snapshot):\n    response = await mpt_sharded.generate(\n        \"What is Deep Learning?\",\n        max_new_tokens=17,\n        decoder_input_details=True,\n    )\n\n    assert response.details.generated_tokens == 17\n    assert (\n        response.generated_text\n        == \" - Deep Learning\\nDeep Learning is a subfield of machine learning that uses artificial neural\"\n    )\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_mpt_load(mpt_sharded, generate_load, response_snapshot):\n    responses = await generate_load(\n        mpt_sharded,\n        \"What is Deep Learning?\",\n        max_new_tokens=17,\n        n=4,\n    )\n\n    assert len(responses) == 4\n    assert all([r.generated_text == responses[0].generated_text for r in responses])\n    assert (\n        responses[0].generated_text\n        == \" - Deep Learning\\nDeep Learning is a subfield of machine learning that uses artificial neural\"\n    )\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_mt0_base.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef mt0_base_handle(launcher):\n    with launcher(\"bigscience/mt0-base\") as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def mt0_base(mt0_base_handle):\n    await mt0_base_handle.health(300)\n    return mt0_base_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_mt0_base(mt0_base, response_snapshot):\n    response = await mt0_base.generate(\n        \"Why is the sky blue?\",\n        max_new_tokens=10,\n        top_p=0.9,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response.details.generated_tokens == 5\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_mt0_base_all_params(mt0_base, response_snapshot):\n    response = await mt0_base.generate(\n        \"Why is the sky blue?\",\n        max_new_tokens=10,\n        repetition_penalty=1.2,\n        return_full_text=True,\n        stop_sequences=[\"test\"],\n        temperature=0.5,\n        top_p=0.9,\n        top_k=10,\n        truncate=5,\n        typical_p=0.9,\n        watermark=True,\n        decoder_input_details=True,\n        seed=0,\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_mt0_base_load(mt0_base, generate_load, response_snapshot):\n    responses = await generate_load(\n        mt0_base,\n        \"Why is the sky blue?\",\n        max_new_tokens=10,\n        n=4,\n    )\n\n    assert len(responses) == 4\n    assert all([r.generated_text == responses[0].generated_text for r in responses])\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_neox.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef neox_handle(launcher):\n    with launcher(\n        \"stabilityai/stablelm-tuned-alpha-3b\", num_shard=1, use_flash_attention=False\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def neox(neox_handle):\n    await neox_handle.health(300)\n    return neox_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.skip\n@pytest.mark.asyncio\nasync def test_neox(neox, response_snapshot):\n    response = await neox.generate(\n        \"<|USER|>What's your mood today?<|ASSISTANT|>\",\n        max_new_tokens=10,\n        decoder_input_details=True,\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.skip\n@pytest.mark.asyncio\nasync def test_neox_load(neox, generate_load, response_snapshot):\n    responses = await generate_load(\n        neox,\n        \"<|USER|>What's your mood today?<|ASSISTANT|>\",\n        max_new_tokens=10,\n        n=4,\n    )\n\n    generated_texts = [r.generated_text for r in responses]\n\n    assert len(generated_texts) == 4\n    assert generated_texts, all(\n        [text == generated_texts[0] for text in generated_texts]\n    )\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_neox_sharded.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef neox_sharded_handle(launcher):\n    with launcher(\n        \"OpenAssistant/oasst-sft-1-pythia-12b\", num_shard=2, use_flash_attention=False\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def neox_sharded(neox_sharded_handle):\n    await neox_sharded_handle.health(300)\n    return neox_sharded_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.skip\n@pytest.mark.asyncio\nasync def test_neox(neox_sharded, response_snapshot):\n    response = await neox_sharded.generate(\n        \"<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>\",\n        max_new_tokens=10,\n        decoder_input_details=True,\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.skip\n@pytest.mark.asyncio\nasync def test_neox_load(neox_sharded, generate_load, response_snapshot):\n    responses = await generate_load(\n        neox_sharded,\n        \"<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>\",\n        max_new_tokens=10,\n        n=4,\n    )\n\n    assert len(responses) == 4\n    assert all([r.generated_text == responses[0].generated_text for r in responses])\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_opt.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef opt_sharded_handle(launcher):\n    with launcher(\"facebook/opt-6.7b\", num_shard=2) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def opt_sharded(opt_sharded_handle):\n    await opt_sharded_handle.health(300)\n    return opt_sharded_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_opt(opt_sharded):\n    pass\n"
  },
  {
    "path": "integration-tests/models/test_smolvlm.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_smolvlm_next_handle(launcher):\n    with launcher(\"HuggingFaceTB/SmolVLM-Instruct\") as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_smolvlm_next(flash_smolvlm_next_handle):\n    await flash_smolvlm_next_handle.health(300)\n    return flash_smolvlm_next_handle.client\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_smolvlm_next_simple_url(flash_smolvlm_next, response_snapshot):\n    ny_skyline = \"https://huggingface.co/spaces/merve/chameleon-7b/resolve/main/bee.jpg\"\n    query = \"What is in this image?\"\n    response = await flash_smolvlm_next.generate(\n        f\"<|begin_of_text|><|begin_of_text|>User:![]({ny_skyline}){query}<end_of_utterance>\\nAssistant:\",\n        max_new_tokens=10,\n        seed=1337,\n    )\n    print(response)\n    assert (\n        response.generated_text == \" A bee on a pink flower.\"\n    ), f\"{repr(response.generated_text)}\"\n    assert response.details.generated_tokens == 8\n    assert response == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_t5_sharded.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef t5_sharded_handle(launcher):\n    with launcher(\"google/flan-t5-xxl\", num_shard=4) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def t5_sharded(t5_sharded_handle):\n    await t5_sharded_handle.health(300)\n    return t5_sharded_handle.client\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_t5_sharded(t5_sharded, response_snapshot):\n    response = await t5_sharded.generate(\n        \"Please answer the following question. What is the boiling point of Nitrogen?\",\n        max_new_tokens=10,\n        decoder_input_details=True,\n    )\n\n    assert response == response_snapshot\n\n\n@pytest.mark.release\n@pytest.mark.asyncio\nasync def test_t5_sharded_load(t5_sharded, generate_load, response_snapshot):\n    responses = await generate_load(\n        t5_sharded,\n        \"Please answer the following question. What is the boiling point of Nitrogen?\",\n        max_new_tokens=10,\n        n=4,\n    )\n\n    assert len(responses) == 4\n    assert all([r.generated_text == responses[0].generated_text for r in responses])\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_tools_llama.py",
    "content": "import pytest\nfrom openai import OpenAI\nfrom huggingface_hub import InferenceClient\nfrom huggingface_hub.inference._generated.types.chat_completion import (\n    ChatCompletionOutputToolCall,\n    ChatCompletionOutputFunctionDefinition,\n)\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_llama_grammar_tools_handle(launcher):\n    with launcher(\n        \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n        num_shard=2,\n        disable_grammar_support=False,\n    ) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_llama_grammar_tools(flash_llama_grammar_tools_handle):\n    await flash_llama_grammar_tools_handle.health(300)\n    return flash_llama_grammar_tools_handle.client\n\n\n# tools to be used in the following tests\ntools = [\n    {\n        \"type\": \"function\",\n        \"function\": {\n            \"name\": \"get_current_weather\",\n            \"description\": \"Get the current weather\",\n            \"parameters\": {\n                \"type\": \"object\",\n                \"properties\": {\n                    \"location\": {\n                        \"type\": \"string\",\n                        \"description\": \"The city and state, e.g. San Francisco, CA\",\n                    },\n                    \"format\": {\n                        \"type\": \"string\",\n                        \"enum\": [\"celsius\", \"fahrenheit\"],\n                        \"description\": \"The temperature unit to use. 
Infer this from the users location.\",\n                    },\n                },\n                \"required\": [\"location\", \"format\"],\n                \"additionalProperties\": False,\n            },\n        },\n    },\n    {\n        \"type\": \"function\",\n        \"function\": {\n            \"name\": \"get_n_day_weather_forecast\",\n            \"description\": \"Get an N-day weather forecast\",\n            \"parameters\": {\n                \"type\": \"object\",\n                \"properties\": {\n                    \"location\": {\n                        \"type\": \"string\",\n                        \"description\": \"The city and state, e.g. San Francisco, CA\",\n                    },\n                    \"format\": {\n                        \"type\": \"string\",\n                        \"enum\": [\"celsius\", \"fahrenheit\"],\n                        \"description\": \"The temperature unit to use. Infer this from the users location.\",\n                    },\n                    \"num_days\": {\n                        \"type\": \"integer\",\n                        \"description\": \"The number of days to forecast\",\n                    },\n                },\n                \"required\": [\"location\", \"format\", \"num_days\"],\n                \"additionalProperties\": False,\n            },\n        },\n    },\n]\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_grammar_tools_nostream(\n    flash_llama_grammar_tools, response_snapshot\n):\n    client = InferenceClient(base_url=f\"{flash_llama_grammar_tools.base_url}/v1\")\n    response = client.chat_completion(\n        max_tokens=100,\n        seed=1,\n        tools=tools,\n        temperature=0.0,\n        messages=[\n            {\n                \"role\": \"system\",\n                \"content\": \"Youre a helpful assistant! 
Answer the users question best you can.\",\n            },\n            {\n                \"role\": \"user\",\n                \"content\": \"What is the weather like in Brooklyn, New York?\",\n            },\n        ],\n    )\n    assert response.choices[0].message.content is None\n    assert response.choices[0].message.tool_calls == [\n        ChatCompletionOutputToolCall(\n            id=\"0\",\n            type=\"function\",\n            function=ChatCompletionOutputFunctionDefinition(\n                description=None,\n                name=\"get_current_weather\",\n                arguments='{\"location\":\"Brooklyn, NY\",\"format\":\"fahrenheit\"}',\n            ),\n        )\n    ]\n    assert response == response_snapshot\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_grammar_tools_openai(\n    flash_llama_grammar_tools, response_snapshot\n):\n    client = OpenAI(api_key=\"xx\", base_url=f\"{flash_llama_grammar_tools.base_url}/v1\")\n    stream = client.chat.completions.create(\n        model=\"tgi\",\n        max_tokens=100,\n        seed=1,\n        tools=tools,\n        stream=True,\n        temperature=0.0,\n        messages=[\n            {\n                \"role\": \"system\",\n                \"content\": \"Youre a helpful assistant! 
Answer the users question best you can.\",\n            },\n            {\n                \"role\": \"user\",\n                \"content\": \"What is the weather like in Brooklyn, New York?\",\n            },\n        ],\n    )\n\n    chunks = []\n    tool = \"\"\n    name = \"\"\n    for chunk in stream:\n        if chunk.choices[0].delta.tool_calls[0].function.name:\n            name += chunk.choices[0].delta.tool_calls[0].function.name\n        tool += chunk.choices[0].delta.tool_calls[0].function.arguments\n        chunks.append(chunk)\n\n    assert name == \"get_current_weather\"\n    assert tool == '{ \"location\": \"Brooklyn, NY\", \"format\": \"fahrenheit\"}'\n    assert chunks == response_snapshot\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_grammar_tools_auto_nostream(\n    flash_llama_grammar_tools, response_snapshot\n):\n    client = InferenceClient(base_url=f\"{flash_llama_grammar_tools.base_url}/v1\")\n    response = client.chat_completion(\n        max_tokens=100,\n        seed=1,\n        tools=tools,\n        temperature=0.0,\n        tool_choice=\"auto\",\n        messages=[\n            {\n                \"role\": \"system\",\n                \"content\": \"Youre a helpful assistant! 
Answer the users question best you can.\",\n            },\n            {\n                \"role\": \"user\",\n                \"content\": \"What is the weather like in Brooklyn, New York?\",\n            },\n        ],\n    )\n    assert response.choices[0].message.content is None\n    assert response.choices[0].message.tool_calls == [\n        ChatCompletionOutputToolCall(\n            id=\"0\",\n            type=\"function\",\n            function=ChatCompletionOutputFunctionDefinition(\n                description=None,\n                name=\"get_current_weather\",\n                arguments='{\"location\":\"Brooklyn, NY\",\"format\":\"fahrenheit\"}',\n            ),\n        )\n    ]\n\n    assert response == response_snapshot\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_grammar_tools_choice_nostream(\n    flash_llama_grammar_tools, response_snapshot\n):\n    client = InferenceClient(base_url=f\"{flash_llama_grammar_tools.base_url}/v1\")\n    response = client.chat_completion(\n        max_tokens=100,\n        seed=1,\n        tools=tools,\n        temperature=0.0,\n        tool_choice=\"get_current_weather\",\n        messages=[\n            {\n                \"role\": \"system\",\n                \"content\": \"Youre a helpful assistant! 
Answer the users question best you can.\",\n            },\n            {\n                \"role\": \"user\",\n                \"content\": \"What is the weather like in Brooklyn, New York?\",\n            },\n        ],\n    )\n    assert response.choices[0].message.content is None\n    assert response.choices[0].message.tool_calls == [\n        ChatCompletionOutputToolCall(\n            id=\"0\",\n            type=\"function\",\n            function=ChatCompletionOutputFunctionDefinition(\n                description=None,\n                name=\"get_current_weather\",\n                arguments='{\"location\":\"Brooklyn, NY\",\"format\":\"fahrenheit\"}',\n            ),\n        )\n    ]\n\n    assert response == response_snapshot\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_grammar_tools_choice_stream(\n    flash_llama_grammar_tools, response_snapshot\n):\n    client = InferenceClient(base_url=f\"{flash_llama_grammar_tools.base_url}/v1\")\n    stream = client.chat_completion(\n        max_tokens=100,\n        seed=1,\n        tools=tools,\n        temperature=0.0,\n        tool_choice=\"get_current_weather\",\n        messages=[\n            {\n                \"role\": \"system\",\n                \"content\": \"Youre a helpful assistant! 
Answer the users question best you can.\",\n            },\n            {\n                \"role\": \"user\",\n                \"content\": \"What is the weather like in Brooklyn, New York?\",\n            },\n        ],\n        stream=True,\n    )\n\n    arguments = \"\"\n    chunks = []\n    name = \"\"\n    for chunk in stream:\n        if chunk.choices[0].delta.tool_calls[0].function.name:\n            name += chunk.choices[0].delta.tool_calls[0].function.name\n        arguments += chunk.choices[0].delta.tool_calls[0].function.arguments\n        assert chunk.choices[0].delta.content is None\n        chunks.append(chunk)\n\n    assert name == \"get_current_weather\"\n    assert arguments == '{ \"location\": \"Brooklyn, NY\", \"format\": \"fahrenheit\"}'\n    assert chunks == response_snapshot\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_grammar_tools_insufficient_information_nostream(\n    flash_llama_grammar_tools, response_snapshot\n):\n    client = InferenceClient(base_url=f\"{flash_llama_grammar_tools.base_url}/v1\")\n    response = client.chat_completion(\n        max_tokens=20,\n        seed=24,\n        tools=tools,\n        tool_choice=\"auto\",\n        messages=[\n            {\n                \"role\": \"system\",\n                \"content\": \"You're a helpful assistant! 
Answer the users question best you can.\",\n            },\n            {\n                \"role\": \"user\",\n                \"content\": \"Who are you?\",\n            },\n        ],\n        stream=False,\n    )\n\n    content_generated = response.choices[0].message.content\n    assert response.choices[0].message.tool_calls is None\n\n    assert (\n        content_generated\n        == \"I'm an artificial intelligence model known as a large language model (LLM) or conversational AI\"\n    )\n    assert response == response_snapshot\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_grammar_tools_insufficient_information_stream(\n    flash_llama_grammar_tools, response_snapshot\n):\n    client = InferenceClient(base_url=f\"{flash_llama_grammar_tools.base_url}/v1\")\n    stream = client.chat_completion(\n        max_tokens=20,\n        seed=24,\n        tools=tools,\n        tool_choice=\"auto\",\n        messages=[\n            {\n                \"role\": \"system\",\n                \"content\": \"You're a helpful assistant! 
Answer the users question best you can.\",\n            },\n            {\n                \"role\": \"user\",\n                \"content\": \"Who are you?\",\n            },\n        ],\n        stream=True,\n    )\n\n    content_generated = \"\"\n    chunks = []\n    for chunk in stream:\n        content_generated += chunk.choices[0].delta.content\n        chunks.append(chunk)\n        assert chunk.choices[0].delta.tool_calls is None\n\n    ######## This is exactly the same as the non streaming case\n    assert (\n        content_generated\n        == \"I'm an artificial intelligence model known as a large language model (LLM) or conversational AI\"\n    )\n    assert chunks == response_snapshot\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_grammar_tools_sea_creatures_stream_auto(\n    flash_llama_grammar_tools, response_snapshot\n):\n    client = InferenceClient(base_url=f\"{flash_llama_grammar_tools.base_url}/v1\")\n    stream = client.chat_completion(\n        max_tokens=20,\n        seed=24,\n        tools=tools,\n        tool_choice=\"auto\",\n        messages=[\n            {\n                \"role\": \"system\",\n                \"content\": \"You're a helpful assistant! Answer the users question best you can. 
If the question is not answerable by the tools, just generate a response.\",\n            },\n            {\n                \"role\": \"user\",\n                \"content\": \"Tell me a story about 3 sea creatures\",\n            },\n        ],\n        stream=True,\n    )\n\n    content_generated = \"\"\n    chunks = []\n    for chunk in stream:\n        content_generated += chunk.choices[0].delta.content\n        chunks.append(chunk)\n        assert chunk.choices[0].delta.tool_calls is None\n\n    assert (\n        content_generated\n        == \"Once upon a time, in a vibrant ocean filled with coral reefs and schools of shimmering fish,\"\n    )\n    assert chunks == response_snapshot\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_grammar_tools_sea_creatures_stream_required(\n    flash_llama_grammar_tools, response_snapshot\n):\n    client = InferenceClient(base_url=f\"{flash_llama_grammar_tools.base_url}/v1\")\n    stream = client.chat_completion(\n        max_tokens=100,\n        seed=24,\n        tools=tools,\n        tool_choice=\"required\",\n        messages=[\n            {\n                \"role\": \"system\",\n                \"content\": \"You're a helpful assistant! Answer the users question best you can. 
If the question is not answerable by the tools, just generate a response.\",\n            },\n            {\n                \"role\": \"user\",\n                \"content\": \"Tell me a story about 3 sea creatures\",\n            },\n        ],\n        stream=True,\n    )\n\n    tool_calls_generated = \"\"\n    name = \"\"\n    chunks = []\n    for chunk in stream:\n        assert chunk.choices[0].delta.content is None\n        if chunk.choices[0].delta.tool_calls[0].function.name:\n            name += chunk.choices[0].delta.tool_calls[0].function.name\n        tool_calls_generated += chunk.choices[0].delta.tool_calls[0].function.arguments\n\n    assert name == \"get_n_day_weather_forecast\"\n    assert (\n        tool_calls_generated\n        == '{ \"location\": \"San Francisco, CA\", \"format\": \"fahrenheit\", \"num_days\":3}'\n    )\n    assert chunks == response_snapshot\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_grammar_tools_sea_creatures_stream_none(\n    flash_llama_grammar_tools, response_snapshot\n):\n    client = InferenceClient(base_url=f\"{flash_llama_grammar_tools.base_url}/v1\")\n    stream = client.chat_completion(\n        max_tokens=100,\n        seed=24,\n        tools=tools,\n        tool_choice=\"none\",\n        messages=[\n            {\n                \"role\": \"system\",\n                \"content\": \"You're a helpful assistant! Answer the users question best you can. 
If the question is not answerable by the tools, just generate a response.\",\n            },\n            {\n                \"role\": \"user\",\n                \"content\": \"Tell me a story about 3 sea creatures\",\n            },\n        ],\n        stream=True,\n    )\n\n    content_generated = \"\"\n    chunks = []\n    for chunk in stream:\n        chunks.append(chunk)\n        content_generated += chunk.choices[0].delta.content\n        assert chunk.choices[0].delta.tool_calls is None\n\n    assert (\n        content_generated\n        == \"Once upon a time, in a vibrant ocean filled with coral reefs and schools of shimmering fish, lived three dear friends: Luna the sea turtle, Finley the friendly fish, and Crusty the wise crab.\\n\\nLuna was the oldest of the three. She had traveled the world, exploring hidden caves and shipwrecks, and collecting sparkling shells and shiny pebbles. Her shell was a beautiful mosaic of blues and greens, and her gentle eyes twinkled with the secrets of the deep\"\n    )\n    assert chunks == response_snapshot\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(\n    flash_llama_grammar_tools, response_snapshot\n):\n    client = InferenceClient(base_url=f\"{flash_llama_grammar_tools.base_url}/v1\")\n    stream = client.chat_completion(\n        messages=[\n            {\n                \"role\": \"system\",\n                \"content\": \"You're a helpful assistant! Answer the users question best you can. 
If the question is not answerable by the tools, just generate a response.\",\n            },\n            {\n                \"role\": \"user\",\n                \"content\": \"Tell me a story about 3 sea creatures\",\n            },\n        ],\n        tools=tools,\n        tool_choice={\n            \"type\": \"function\",\n            \"function\": {\"name\": \"get_n_day_weather_forecast\"},\n        },\n        max_tokens=100,\n        seed=24,\n        stream=True,\n    )\n    chunks = []\n    tool_calls_generated = \"\"\n    name = \"\"\n    for chunk in stream:\n        assert chunk.choices[0].delta.content is None\n        if chunk.choices[0].delta.tool_calls[0].function.name:\n            name += chunk.choices[0].delta.tool_calls[0].function.name\n        tool_calls_generated += chunk.choices[0].delta.tool_calls[0].function.arguments\n\n    assert name == \"get_n_day_weather_forecast\"\n    assert (\n        tool_calls_generated\n        == '{ \"location\": \"San Francisco, CA\", \"format\": \"celsius\", \"num_days\": 3}'\n    )\n    assert chunks == response_snapshot\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_tool_reply_response(\n    flash_llama_grammar_tools, response_snapshot\n):\n    client = InferenceClient(base_url=f\"{flash_llama_grammar_tools.base_url}/v1\")\n    response = client.chat_completion(\n        max_tokens=100,\n        seed=42,\n        messages=[\n            {\"role\": \"user\", \"content\": \"What's the weather like in Paris today?\"},\n            {\n                \"role\": \"assistant\",\n                \"tool_calls\": [\n                    {\n                        \"id\": \"0\",\n                        \"function\": {\n                            \"arguments\": '{\"longitude\": 2.2945, \"latitude\": 48.8567}',\n                            \"name\": \"get_weather\",\n                            \"description\": None,\n                        },\n                        \"type\": 
\"function\",\n                    }\n                ],\n            },\n            {\"role\": \"tool\", \"tool_call_id\": \"0\", \"content\": \"6.7\"},\n        ],\n        stream=False,\n    )\n\n    assert response.choices[0].message.tool_calls is None\n    assert (\n        response.choices[0].message.content\n        == \"I can't access real-time data, but I can provide you with current conditions and forecast for Paris, France:\\n\\nThe current conditions in Paris are mostly cloudy with a temperature of 6.7°C (44.1°F). \\n\\nPlease note that the actual weather may differ from the provided information. For up-to-date information, I suggest checking a reliable weather website or app for the latest conditions and forecast.\"\n    )\n\n    assert response == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_transformers_llama4.py",
    "content": "# import base64\n# from io import BytesIO\n# from PIL import Image\n#\n# import pytest\n#\n#\n# @pytest.fixture(scope=\"module\")\n# def flash_llama4_handle(launcher):\n#     with launcher(\"ll-re/Llama-4-Scout-17B-16E-Instruct\", num_shard=8) as handle:\n#         yield handle\n#\n#\n# @pytest.fixture(scope=\"module\")\n# async def flash_llama4(flash_llama4_handle):\n#     await flash_llama4_handle.health(300)\n#     return flash_llama4_handle.client\n#\n#\n# async def test_flash_llama4(flash_llama4, response_snapshot):\n#     response = await flash_llama4.generate(\n#         \"Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many\",\n#         seed=42,\n#         max_new_tokens=100,\n#     )\n#\n#     assert (\n#         response.generated_text\n#         == \" people died in the 1918 flu pandemic. Estimating the death toll of the 1918 flu pandemic is difficult because of incomplete records and because of the fact that many of the extra deaths were not attributed to the flu. Many experts believe that the 1918 flu pandemic killed between 50 and 100 million people. Iassistant\\n\\nThe 1918 flu pandemic, also known as the Spanish flu, is indeed one of the most devastating public health crises in human history. 
Estimating the exact\"\n#     )\n#     assert response.details.generated_tokens == 100\n#     assert response == response_snapshot\n#\n#\n# async def test_flash_llama4_image_cow_dog(flash_llama4, response_snapshot):\n#     image_url = \"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png\"\n#     response = await flash_llama4.chat(\n#         seed=42,\n#         messages=[\n#             {\n#                 \"role\": \"user\",\n#                 \"content\": [\n#                     {\"type\": \"image_url\", \"image_url\": {\"url\": image_url}},\n#                     {\n#                         \"type\": \"text\",\n#                         \"text\": \"What is the breed of the dog in the image?\",\n#                     },\n#                 ],\n#             },\n#         ],\n#         max_tokens=100,\n#     )\n#\n#     assert (\n#         response.choices[0].message.content\n#         == \"The image does not depict a dog; it shows a cow standing on a beach. 
Therefore, there is no breed of a dog to identify.\"\n#     )\n#     assert response.usage[\"completion_tokens\"] == 30\n#     assert response == response_snapshot\n#\n#\n# async def test_flash_llama4_image_cow(flash_llama4, response_snapshot):\n#     image_url = \"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png\"\n#     response = await flash_llama4.chat(\n#         seed=42,\n#         messages=[\n#             {\n#                 \"role\": \"user\",\n#                 \"content\": [\n#                     {\"type\": \"image_url\", \"image_url\": {\"url\": image_url}},\n#                     {\"type\": \"text\", \"text\": \"What is shown in this image?\"},\n#                 ],\n#             },\n#         ],\n#         max_tokens=100,\n#     )\n#     assert (\n#         response.choices[0].message.content\n#         == \"The image shows a brown cow standing on the beach with a white face and black and white marking on its ears. The cow has a white patch around its nose and mouth. 
The ocean and blue sky are in the background.\"\n#     )\n#     assert response.usage[\"completion_tokens\"] == 46\n#     assert response == response_snapshot\n#\n#\n# # Helper function to convert a Pillow image to a base64 data URL\n# def image_to_data_url(img: Image.Image, fmt: str) -> str:\n#     buffer = BytesIO()\n#     img.save(buffer, format=fmt)\n#     img_data = buffer.getvalue()\n#     b64_str = base64.b64encode(img_data).decode(\"utf-8\")\n#     mime_type = \"image/png\" if fmt.upper() == \"PNG\" else \"image/jpeg\"\n#     return f\"data:{mime_type};base64,{b64_str}\"\n#\n#\n# async def test_flash_llama4_image_base64_rgba(flash_llama4, response_snapshot):\n#     # Create an empty 100x100 PNG image with alpha (transparent background)\n#     img = Image.new(\"RGBA\", (100, 100), (0, 0, 0, 0))\n#     data_url = image_to_data_url(img, \"PNG\")\n#     response = await flash_llama4.chat(\n#         seed=42,\n#         messages=[\n#             {\n#                 \"role\": \"user\",\n#                 \"content\": [\n#                     {\"type\": \"image_url\", \"image_url\": {\"url\": data_url}},\n#                     {\n#                         \"type\": \"text\",\n#                         \"text\": \"What do you see in this transparent image?\",\n#                     },\n#                 ],\n#             },\n#         ],\n#         max_tokens=100,\n#     )\n#     assert response == response_snapshot\n#\n#\n# async def test_flash_llama4_image_base64_rgb_png(flash_llama4, response_snapshot):\n#     # Create an empty 100x100 PNG image without alpha (white background)\n#     img = Image.new(\"RGB\", (100, 100), (255, 255, 255))\n#     data_url = image_to_data_url(img, \"PNG\")\n#     response = await flash_llama4.chat(\n#         seed=42,\n#         messages=[\n#             {\n#                 \"role\": \"user\",\n#                 \"content\": [\n#                     {\"type\": \"image_url\", \"image_url\": {\"url\": data_url}},\n#                 
    {\"type\": \"text\", \"text\": \"What do you see in this plain image?\"},\n#                 ],\n#             },\n#         ],\n#         max_tokens=100,\n#     )\n#     assert response == response_snapshot\n#\n#\n# async def test_flash_llama4_image_base64_rgb_jpg(flash_llama4, response_snapshot):\n#     # Create an empty 100x100 JPEG image (white background)\n#     img = Image.new(\"RGB\", (100, 100), (255, 255, 255))\n#     data_url = image_to_data_url(img, \"JPEG\")\n#     response = await flash_llama4.chat(\n#         seed=42,\n#         messages=[\n#             {\n#                 \"role\": \"user\",\n#                 \"content\": [\n#                     {\"type\": \"image_url\", \"image_url\": {\"url\": data_url}},\n#                     {\"type\": \"text\", \"text\": \"What do you see in this JPEG image?\"},\n#                 ],\n#             },\n#         ],\n#         max_tokens=100,\n#     )\n#     assert response == response_snapshot\n"
  },
  {
    "path": "integration-tests/models/test_transformers_olmo.py",
    "content": "import pytest\n\n\n@pytest.fixture(scope=\"module\")\ndef flash_llama_handle(launcher):\n    with launcher(\"allenai/OLMo-7B-0724-Instruct-hf\", num_shard=2) as handle:\n        yield handle\n\n\n@pytest.fixture(scope=\"module\")\nasync def flash_llama(flash_llama_handle):\n    await flash_llama_handle.health(300)\n    return flash_llama_handle.client\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_simple(flash_llama, response_snapshot):\n    response = await flash_llama.generate(\n        \"Test request\", max_new_tokens=10, decoder_input_details=True\n    )\n\n    assert response.details.generated_tokens == 10\n    assert response.generated_text == ':\\n\\n```json\\n{\\n  \"'\n    assert response == response_snapshot\n\n\n@pytest.mark.asyncio\n@pytest.mark.private\nasync def test_flash_llama_load(flash_llama, generate_load, response_snapshot):\n    responses = await generate_load(flash_llama, \"Test request\", max_new_tokens=10, n=4)\n\n    assert len(responses) == 4\n    assert responses[0].generated_text == ':\\n\\n```json\\n{\\n  \"'\n    assert all([r.generated_text == responses[0].generated_text for r in responses])\n\n    assert responses == response_snapshot\n"
  },
  {
    "path": "integration-tests/neuron/test_generate.py",
    "content": "import pytest\n\n\n@pytest.fixture\nasync def tgi_service(neuron_launcher, neuron_model_config):\n    model_name_or_path = neuron_model_config[\"neuron_model_path\"]\n    service_name = neuron_model_config[\"name\"]\n    with neuron_launcher(service_name, model_name_or_path) as tgi_service:\n        await tgi_service.health(600)\n        yield tgi_service\n\n\n@pytest.mark.asyncio\nasync def test_model_single_request(tgi_service):\n    service_name = tgi_service.client.service_name\n    prompt = \"What is Deep Learning?\"\n    # Greedy bounded without input\n    response = await tgi_service.client.text_generation(\n        prompt, max_new_tokens=17, details=True, decoder_input_details=True\n    )\n    assert response.details.generated_tokens == 17\n    greedy_expectations = {\n        \"llama\": \" and how does it work?\\nDeep learning is a subset of machine learning that uses artificial\",\n        \"qwen2\": \" - Deep Learning is a subset of Machine Learning that involves the use of artificial neural networks\",\n        \"granite\": \"\\n\\nDeep Learning is a subset of machine learning that is inspired by the structure and\",\n        \"qwen3\": \" And Why Should You Care?\\n\\nDeep learning is a subset of machine learning that uses neural\",\n        \"phi3\": \"\\n\\nDeep learning is a subfield of machine learning that focuses on creating\",\n    }\n    assert response.generated_text == greedy_expectations[service_name]\n\n    # Greedy bounded with input\n    greedy_response = await tgi_service.client.text_generation(\n        \"What is Deep Learning?\",\n        max_new_tokens=17,\n        return_full_text=True,\n        details=True,\n        decoder_input_details=True,\n    )\n    assert greedy_response.details.generated_tokens == 17\n    assert greedy_response.generated_text == prompt + greedy_expectations[service_name]\n\n    # Sampling\n    response = await tgi_service.client.text_generation(\n        \"What is Deep Learning?\",\n        
do_sample=True,\n        top_k=50,\n        top_p=0.9,\n        repetition_penalty=1.2,\n        max_new_tokens=128,\n        seed=42,\n    )\n    # The response must be different\n    assert not response.startswith(greedy_expectations[service_name])\n\n    # Greedy with stop sequence (using one of the words returned from the previous test)\n    stop_sequence = greedy_response.generated_text.split(\" \")[-5]\n    response = await tgi_service.client.text_generation(\n        \"What is Deep Learning?\",\n        do_sample=False,\n        max_new_tokens=128,\n        stop_sequences=[stop_sequence],\n    )\n    assert response.endswith(stop_sequence)\n\n\n@pytest.mark.asyncio\nasync def test_model_multiple_requests(tgi_service, neuron_generate_load):\n    num_requests = 4\n    responses = await neuron_generate_load(\n        tgi_service.client,\n        \"What is Deep Learning?\",\n        max_new_tokens=17,\n        n=num_requests,\n    )\n\n    assert len(responses) == 4\n    expectations = {\n        \"llama\": \"Deep learning is a subset of machine learning that uses artificial\",\n        \"qwen2\": \"Deep Learning is a subset of Machine Learning that involves\",\n        \"granite\": \"Deep Learning is a subset of machine learning that is inspired by the structure and\",\n        \"qwen3\": \" And Why Should You Care?\\n\\nDeep learning is a subset of machine learning that uses neural\",\n        \"phi3\": \"Deep learning is a subfield of machine learning that focuses on creating\",\n    }\n    expected = expectations[tgi_service.client.service_name]\n    for r in responses:\n        assert r.details.generated_tokens == 17\n        assert expected in r.generated_text\n"
  },
  {
    "path": "integration-tests/neuron/test_implicit_env.py",
    "content": "import os\n\nimport pytest\n\n\n@pytest.fixture(scope=\"module\", params=[\"hub-neuron\", \"hub\", \"local-neuron\"])\nasync def tgi_service(request, neuron_launcher, neuron_model_config):\n    \"\"\"Expose a TGI service corresponding to a model configuration\n\n    For each model configuration, the service will be started using the following\n    deployment options:\n    - from the hub original model (export parameters chosen after hub lookup),\n    - from the hub pre-exported neuron model,\n    - from a local path to the neuron model.\n    \"\"\"\n    # the tgi_env.py script will take care of setting these\n    for var in [\n        \"MAX_BATCH_SIZE\",\n        \"MAX_INPUT_TOKENS\",\n        \"MAX_TOTAL_TOKENS\",\n        \"HF_NUM_CORES\",\n        \"HF_AUTO_CAST_TYPE\",\n    ]:\n        if var in os.environ:\n            del os.environ[var]\n    if request.param == \"hub\":\n        model_name_or_path = neuron_model_config[\"model_id\"]\n    elif request.param == \"hub-neuron\":\n        model_name_or_path = neuron_model_config[\"neuron_model_id\"]\n    else:\n        model_name_or_path = neuron_model_config[\"neuron_model_path\"]\n    service_name = neuron_model_config[\"name\"]\n    with neuron_launcher(service_name, model_name_or_path) as tgi_service:\n        await tgi_service.health(600)\n        yield tgi_service\n\n\n@pytest.mark.asyncio\nasync def test_model_single_request(tgi_service):\n    # Just verify that the generation works, and nothing is raised, with several set of params\n\n    # No params\n    await tgi_service.client.text_generation(\n        \"What is Deep Learning?\",\n    )\n\n    response = await tgi_service.client.text_generation(\n        \"How to cook beans ?\",\n        max_new_tokens=17,\n        details=True,\n        decoder_input_details=True,\n    )\n    assert response.details.generated_tokens == 17\n\n    # Sampling\n    await tgi_service.client.text_generation(\n        \"What is Deep Learning?\",\n        
do_sample=True,\n        top_k=50,\n        top_p=0.9,\n        repetition_penalty=1.2,\n        max_new_tokens=128,\n        seed=42,\n    )\n"
  },
  {
    "path": "integration-tests/pyproject.toml",
    "content": "[project]\nname = \"text-generation-integration-tests\"\nversion = \"2.0.1\"\ndescription = \"Text Generation Inference integration tests\"\nauthors = [\"Nicolas Patry <nicolas@huggingface.co>\"]\nrequires-python = \">=3.10,<3.13\"\n\ndependencies = [\n    \"pydantic>2,< 3\",\n    \"syrupy>=4.8.0\",\n    \"text-generation>=0.6.0\",\n    \"pytest>=8.3.0\",\n    \"pytest-asyncio>=0.23.1\",\n    \"docker>=7\",\n    \"numpy>=2.0\",\n    \"openai>=1.65\",\n    \"huggingface_hub>=0.29\",\n    \"pillow>=11.1.0\",\n]\n\n[tool.isort]\nprofile = \"black\"\n"
  },
  {
    "path": "integration-tests/pytest.ini",
    "content": "[pytest]\naddopts = --snapshot-warn-unused\nasyncio_mode = auto\nmarkers =\n    private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')\n"
  },
  {
    "path": "integration-tests/requirements.txt",
    "content": "# This file was autogenerated by uv via the following command:\n#    uv pip compile pyproject.toml\naiohappyeyeballs==2.6.1\n    # via aiohttp\naiohttp==3.11.13\n    # via text-generation\naiosignal==1.3.2\n    # via aiohttp\nannotated-types==0.7.0\n    # via pydantic\nanyio==4.8.0\n    # via\n    #   httpx\n    #   openai\nattrs==25.3.0\n    # via aiohttp\ncertifi==2025.1.31\n    # via\n    #   httpcore\n    #   httpx\n    #   requests\ncharset-normalizer==3.4.1\n    # via requests\ndistro==1.9.0\n    # via openai\ndocker==7.1.0\n    # via text-generation-integration-tests (pyproject.toml)\nfilelock==3.18.0\n    # via huggingface-hub\nfrozenlist==1.5.0\n    # via\n    #   aiohttp\n    #   aiosignal\nfsspec==2025.3.0\n    # via huggingface-hub\nh11==0.14.0\n    # via httpcore\nhttpcore==1.0.7\n    # via httpx\nhttpx==0.28.1\n    # via openai\nhuggingface-hub==0.29.3\n    # via\n    #   text-generation-integration-tests (pyproject.toml)\n    #   text-generation\nidna==3.10\n    # via\n    #   anyio\n    #   httpx\n    #   requests\n    #   yarl\niniconfig==2.0.0\n    # via pytest\njiter==0.9.0\n    # via openai\nmultidict==6.1.0\n    # via\n    #   aiohttp\n    #   yarl\nnumpy==2.2.3\n    # via text-generation-integration-tests (pyproject.toml)\nopenai==1.66.3\n    # via text-generation-integration-tests (pyproject.toml)\npackaging==24.2\n    # via\n    #   huggingface-hub\n    #   pytest\npillow==11.1.0\n    # via text-generation-integration-tests (pyproject.toml)\npluggy==1.5.0\n    # via pytest\npropcache==0.3.0\n    # via\n    #   aiohttp\n    #   yarl\npydantic==2.10.6\n    # via\n    #   text-generation-integration-tests (pyproject.toml)\n    #   openai\n    #   text-generation\npydantic-core==2.27.2\n    # via pydantic\npytest==8.3.5\n    # via\n    #   text-generation-integration-tests (pyproject.toml)\n    #   pytest-asyncio\n    #   syrupy\npytest-asyncio==0.25.3\n    # via text-generation-integration-tests (pyproject.toml)\npyyaml==6.0.2\n 
   # via huggingface-hub\nrequests==2.32.3\n    # via\n    #   docker\n    #   huggingface-hub\nsniffio==1.3.1\n    # via\n    #   anyio\n    #   openai\nsyrupy==4.9.0\n    # via text-generation-integration-tests (pyproject.toml)\ntext-generation==0.7.0\n    # via text-generation-integration-tests (pyproject.toml)\ntqdm==4.67.1\n    # via\n    #   huggingface-hub\n    #   openai\ntyping-extensions==4.12.2\n    # via\n    #   anyio\n    #   huggingface-hub\n    #   openai\n    #   pydantic\n    #   pydantic-core\nurllib3==2.3.0\n    # via\n    #   docker\n    #   requests\nyarl==1.18.3\n    # via aiohttp\n"
  },
  {
    "path": "launcher/Cargo.toml",
    "content": "[package]\nname = \"text-generation-launcher\"\ndescription = \"Text Generation Launcher\"\nversion.workspace = true\nedition.workspace = true\nauthors.workspace = true\nhomepage.workspace = true\n\n[dependencies]\nclap = { version = \"4.4.5\", features = [\"derive\", \"env\"] }\nctrlc = { version = \"3.4.1\", features = [\"termination\"] }\nhf-hub = \"0.4.2\"\nnix = { version = \"0.28.0\", features = [\"signal\"] }\nonce_cell = \"1.19.0\"\npyo3 = { workspace = true }\nserde = { version = \"1.0.188\", features = [\"derive\"] }\nserde_json = \"1.0.107\"\nthiserror = \"1.0.59\"\ntracing = \"0.1.37\"\ntracing-subscriber = { version = \"0.3.17\", features = [\"json\", \"env-filter\"] }\nregex = \"1.11.0\"\n\n[dev-dependencies]\nfloat_eq = \"1.0.1\"\nreqwest = { version = \"0.11.20\", features = [\"blocking\", \"json\"] }\n\n[build-dependencies]\nvergen = { version = \"8.2.5\", features = [\"build\", \"cargo\", \"git\", \"gitcl\", \"rustc\", \"si\"] }\n"
  },
  {
    "path": "launcher/build.rs",
    "content": "use std::error::Error;\nuse vergen::EmitBuilder;\n\nfn main() -> Result<(), Box<dyn Error>> {\n    // Emit cargo and rustc compile time values\n    EmitBuilder::builder().all_cargo().all_rustc().emit()?;\n\n    // Try to get the git sha from the local git repository\n    if EmitBuilder::builder()\n        .fail_on_error()\n        .git_sha(false)\n        .emit()\n        .is_err()\n    {\n        // Unable to get the git sha\n        if let Ok(sha) = std::env::var(\"GIT_SHA\") {\n            // Set it from an env var\n            println!(\"cargo:rustc-env=VERGEN_GIT_SHA={sha}\");\n        }\n    }\n\n    // Set docker label if present\n    if let Ok(label) = std::env::var(\"DOCKER_LABEL\") {\n        // Set it from an env var\n        println!(\"cargo:rustc-env=DOCKER_LABEL={label}\");\n    }\n\n    Ok(())\n}\n"
  },
  {
    "path": "launcher/src/env_runtime.rs",
    "content": "use std::fmt;\nuse std::process::Command;\n\npub(crate) struct Env {\n    cargo_target: &'static str,\n    cargo_version: &'static str,\n    git_sha: &'static str,\n    docker_label: &'static str,\n    nvidia_env: String,\n    xpu_env: String,\n    hpu_env: String,\n}\n\nimpl Env {\n    pub fn new() -> Self {\n        let nvidia_env = nvidia_smi();\n        let xpu_env = xpu_smi();\n        let hpu_env = hl_smi();\n\n        Self {\n            nvidia_env: nvidia_env.unwrap_or(\"N/A\".to_string()),\n            xpu_env: xpu_env.unwrap_or(\"N/A\".to_string()),\n            hpu_env: hpu_env.unwrap_or(\"N/A\".to_string()),\n            cargo_target: env!(\"VERGEN_CARGO_TARGET_TRIPLE\"),\n            cargo_version: env!(\"VERGEN_RUSTC_SEMVER\"),\n            git_sha: option_env!(\"VERGEN_GIT_SHA\").unwrap_or(\"N/A\"),\n            docker_label: option_env!(\"DOCKER_LABEL\").unwrap_or(\"N/A\"),\n        }\n    }\n}\n\nimpl fmt::Display for Env {\n    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {\n        writeln!(f, \"Runtime environment:\")?;\n\n        writeln!(f, \"Target: {}\", self.cargo_target)?;\n        writeln!(f, \"Cargo version: {}\", self.cargo_version)?;\n        writeln!(f, \"Commit sha: {}\", self.git_sha)?;\n        writeln!(f, \"Docker label: {}\", self.docker_label)?;\n        writeln!(f, \"nvidia-smi:\\n{}\", self.nvidia_env)?;\n        writeln!(f, \"xpu-smi:\\n{}\", self.xpu_env)?;\n        writeln!(f, \"hpu-smi:\\n{}\", self.hpu_env)?;\n\n        Ok(())\n    }\n}\n\nfn nvidia_smi() -> Option<String> {\n    let output = Command::new(\"nvidia-smi\").output().ok()?;\n    let nvidia_smi = String::from_utf8(output.stdout).ok()?;\n    let output = nvidia_smi.replace('\\n', \"\\n   \");\n    Some(output.trim().to_string())\n}\n\nfn xpu_smi() -> Option<String> {\n    let output = Command::new(\"xpu-smi\").arg(\"discovery\").output().ok()?;\n    let xpu_smi = String::from_utf8(output.stdout).ok()?;\n    let output = 
xpu_smi.replace('\\n', \"\\n   \");\n    Some(output.trim().to_string())\n}\n\nfn hl_smi() -> Option<String> {\n    let output = Command::new(\"hl-smi\").output().ok()?;\n    let hl_smi = String::from_utf8(output.stdout).ok()?;\n    let output = hl_smi.replace('\\n', \"\\n   \");\n    Some(output.trim().to_string())\n}\n"
  },
  {
    "path": "launcher/src/gpu.rs",
    "content": "pub fn get_cuda_capability() -> Option<(usize, usize)> {\n    use pyo3::prelude::*;\n\n    let py_get_capability = |py: Python| -> PyResult<(isize, isize)> {\n        let torch = py.import_bound(\"torch.cuda\")?;\n        let get_device_capability = torch.getattr(\"get_device_capability\")?;\n        get_device_capability.call0()?.extract()\n    };\n\n    match pyo3::Python::with_gil(py_get_capability) {\n        Ok((major, minor)) if major < 0 || minor < 0 => {\n            tracing::warn!(\"Ignoring negative GPU compute capabilities: {major}.{minor}\");\n            None\n        }\n        Ok((major, minor)) => Some((major as usize, minor as usize)),\n        Err(err) => {\n            tracing::warn!(\"Cannot determine GPU compute capability: {}\", err);\n            None\n        }\n    }\n}\n"
  },
  {
    "path": "launcher/src/main.rs",
    "content": "use clap::{Parser, ValueEnum};\nuse hf_hub::{api::sync::ApiBuilder, Repo, RepoType};\nuse nix::sys::signal::{self, Signal};\nuse nix::unistd::Pid;\nuse serde::Deserialize;\nuse std::env;\nuse std::ffi::OsString;\nuse std::io::{BufRead, BufReader};\nuse std::os::unix::process::{CommandExt, ExitStatusExt};\nuse std::path::Path;\nuse std::process::{Child, Command, ExitStatus, Stdio};\nuse std::sync::atomic::{AtomicBool, Ordering};\nuse std::sync::mpsc::TryRecvError;\nuse std::sync::{mpsc, Arc};\nuse std::thread;\nuse std::thread::sleep;\nuse std::time::{Duration, Instant};\nuse std::{\n    fs, io,\n    io::{Read, Write},\n};\nuse thiserror::Error;\nuse tracing_subscriber::{filter::LevelFilter, EnvFilter};\n\nmod env_runtime;\nmod gpu;\n\nfn compute_optimal(config: Option<&Config>, compute: Option<&ComputeType>) -> Option<usize> {\n    let config = config?;\n    let compute = compute?;\n    let f16_max_compute = compute.f16_flop()?;\n    let model_compute = config.flop()?;\n    tracing::debug!(\n        \"Max compute {} model compute {}\",\n        human_size(f16_max_compute as usize, \"flop\"),\n        human_size(model_compute as usize, \"flop\")\n    );\n    let optimal_size = (f16_max_compute / model_compute) as usize;\n    if optimal_size > 100 {\n        // Ignore calculations that's too low\n        // Most likely an error\n        Some(optimal_size)\n    } else {\n        None\n    }\n}\n\nfn human_size(size: usize, suffix: &str) -> String {\n    let mut size: f64 = size as f64;\n    let mut p = \"\";\n    for prefix in [\"\", \"K\", \"M\", \"G\", \"T\"] {\n        p = prefix;\n        if size > 1_000.0 {\n            size /= 1_000.0;\n        } else {\n            break;\n        }\n    }\n    format!(\"{size:.2}{p}{suffix}\")\n}\n\nfn vram_maximum(\n    config: Option<&Config>,\n    compute: Option<&ComputeType>,\n    memory_fraction: f32,\n) -> Option<usize> {\n    let config = config?;\n    let compute = compute?;\n    let available = 
compute.vram(memory_fraction)?;\n    let model = config.model_vram()?;\n    let token_vram = config.token_vram()?;\n    if let Some(vram) = available.checked_sub(model) {\n        let tokens_allowed = vram / token_vram;\n        tracing::debug!(\n            \"Available vram {}: model needs {}, every tokens requires {}, maximum allocatable tokens {tokens_allowed}\",\n            human_size(available, \"B\"),\n            human_size(model, \"B\"),\n            human_size(token_vram, \"B\"),\n        );\n        Some(tokens_allowed)\n    } else {\n        tracing::warn!(\n            \"Not enough VRAM to run the model: Available: {} - Model {}.\",\n            human_size(available, \"B\"),\n            human_size(model, \"B\")\n        );\n        None\n    }\n}\n\nfn get_config(\n    model_id: &str,\n    revision: &Option<String>,\n) -> Result<Config, Box<dyn std::error::Error>> {\n    let mut path = std::path::Path::new(model_id).to_path_buf();\n    let model_id = model_id.to_string();\n    let filename = if !path.exists() {\n        // Assume it's a hub id\n\n        let mut builder = ApiBuilder::from_env();\n        if let Ok(token) = std::env::var(\"HF_TOKEN\") {\n            // env variable has precedence over on file token.\n            builder = builder.with_token(Some(token))\n        };\n        if let Ok(origin) = env::var(\"HF_HUB_USER_AGENT_ORIGIN\") {\n            builder = builder.with_user_agent(\"origin\", origin.as_str());\n        }\n        let api = builder.build()?;\n        let repo = if let Some(ref revision) = revision {\n            api.repo(Repo::with_revision(\n                model_id,\n                RepoType::Model,\n                revision.to_string(),\n            ))\n        } else {\n            api.model(model_id)\n        };\n        repo.get(\"config.json\")?\n    } else {\n        path.push(\"config.json\");\n        path\n    };\n\n    let content = std::fs::read_to_string(filename)?;\n    let config: RawConfig = 
serde_json::from_str(&content)?;\n\n    let config: Config = config.into();\n    Ok(config)\n}\n\nfn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) -> (String, String) {\n    let compute_capability = gpu::get_cuda_capability();\n    let mut prefix_caching: Option<String> = std::env::var(\"PREFIX_CACHING\").ok();\n    let mut attention: Option<String> = std::env::var(\"ATTENTION\").ok();\n    if let Some(config) = config {\n        if prefix_caching.is_none() {\n            if config.vision_config.is_some() {\n                tracing::info!(\"Disabling prefix caching because of VLM model\");\n                prefix_caching = Some(\"0\".to_string());\n            } else if config.is_encoder_decoder {\n                tracing::info!(\"Disabling prefix caching because of seq2seq model\");\n                prefix_caching = Some(\"0\".to_string());\n            }\n        }\n\n        let fallback_attention = if compute_capability.is_none()\n            || matches!(compute_capability, Some((major, _)) if major < 8)\n        {\n            \"paged\"\n        } else {\n            \"flashdecoding\"\n        };\n\n        match config.get_head_dim() {\n            Some(h) if h == 64 || h == 128 || h == 256 => {\n                if lora_adapters.is_some() && prefix_caching.is_none() {\n                    tracing::info!(\"Disabling prefix caching because of lora adapters\");\n                    prefix_caching = Some(\"0\".to_string());\n                }\n                match config.model_type.as_deref() {\n                    Some(\"falcon\") | Some(\"deepseek_v2\") => {\n                        // Required because gemma2 needs bfloat16 which is not supported by\n                        // flashinfer ?\n                        if attention.is_none() {\n                            tracing::info!(\n                                \"Forcing attention to '{fallback_attention}' because model {} requires it\",\n                                
config.model_type.as_ref().unwrap()\n                            );\n                            attention = Some(fallback_attention.to_string());\n                        }\n                        if fallback_attention == \"paged\" && prefix_caching.is_none() {\n                            tracing::info!(\"Disabling prefix caching because it is not supported with 'paged' attention\");\n                            prefix_caching = Some(\"0\".to_string());\n                        }\n                    }\n                    Some(\"t5\") => {}\n                    _ => {}\n                }\n            }\n            _ => {\n                if attention.is_none() {\n                    tracing::info!(\"Forcing attention to '{fallback_attention}' because head dim is not supported by flashinfer, also disabling prefix caching\");\n                    attention = Some(fallback_attention.to_string());\n                }\n                if prefix_caching.is_none() {\n                    prefix_caching = Some(\"0\".to_string());\n                }\n            }\n        }\n    }\n    if attention == Some(\"paged\".to_string()) && prefix_caching.is_none() {\n        tracing::info!(\"Disabling prefix caching on paged attention\");\n        prefix_caching = Some(\"0\".to_string());\n    }\n\n    let attention = attention.unwrap_or(\"flashinfer\".to_string());\n    let prefix_caching = prefix_caching.unwrap_or(\"true\".to_string());\n\n    (prefix_caching, attention)\n}\n\n#[derive(Deserialize)]\nstruct RawConfig {\n    max_position_embeddings: Option<usize>,\n    n_positions: Option<usize>,\n    model_type: Option<String>,\n    max_seq_len: Option<usize>,\n    quantization_config: Option<QuantizationConfig>,\n    n_embd: Option<usize>,\n    hidden_size: Option<usize>,\n    intermediate_size: Option<usize>,\n    num_attention_heads: Option<usize>,\n    num_key_value_heads: Option<usize>,\n    num_hidden_layers: Option<usize>,\n    head_dim: Option<usize>,\n    
text_config: Option<TextConfig>,\n    vision_config: Option<VisionConfig>,\n    is_encoder_decoder: Option<bool>,\n    #[serde(rename = \"num_experts_per_tok\")]\n    num_experts_per_token: Option<usize>,\n    #[serde(rename = \"n_shared_experts\")]\n    num_shared_experts: Option<usize>,\n    #[serde(rename = \"num_local_experts\")]\n    num_experts: Option<usize>,\n    vocab_size: Option<usize>,\n}\n\n#[derive(Deserialize)]\nstruct QuantizationConfig {\n    quant_method: Option<Quantization>,\n}\n\n#[derive(Debug, Deserialize)]\nstruct VisionConfig {}\n\n#[derive(Debug, Deserialize)]\nstruct TextConfig {\n    head_dim: Option<usize>,\n}\n\n#[derive(Debug, Deserialize)]\nstruct Config {\n    max_position_embeddings: Option<usize>,\n    quantize: Option<Quantization>,\n    head_dim: Option<usize>,\n    num_heads: Option<usize>,\n    num_kv_heads: Option<usize>,\n    num_layers: Option<usize>,\n    intermediate_size: Option<usize>,\n    hidden_size: Option<usize>,\n    model_type: Option<String>,\n    text_config: Option<TextConfig>,\n    vision_config: Option<VisionConfig>,\n    is_encoder_decoder: bool,\n    num_experts_per_token: usize,\n    num_shared_experts: usize,\n    num_experts: usize,\n    vocab_size: Option<usize>,\n}\n\nimpl Config {\n    fn get_head_dim(&self) -> Option<usize> {\n        if let Some(head_dim) = self.head_dim {\n            return Some(head_dim);\n        }\n\n        let text_config = self.text_config.as_ref()?;\n        if let Some(head_size) = text_config.head_dim {\n            return Some(head_size);\n        }\n\n        match self.model_type.as_deref() {\n            // We special-case gemma3 here, since we need flashinfer for\n            // handling bidirectional masks. 
And flashinfer can only be\n            // used when the head size is known.\n            Some(\"gemma3\") => Some(256),\n            _ => None,\n        }\n    }\n\n    fn flop(&self) -> Option<u64> {\n        if self.vision_config.is_some() {\n            // VLM are much harder to predict and VRAM requirements\n            // Are more complex.\n            return None;\n        }\n        let num_heads = self.num_heads? as u64;\n        let num_kv_heads = self.num_kv_heads? as u64;\n        let head_dim = self.get_head_dim()? as u64;\n        let hidden_size = self.hidden_size? as u64;\n        let intermediate_size = (self.intermediate_size?\n            * (self.num_experts_per_token + self.num_shared_experts))\n            as u64;\n        let num_layers = self.num_layers? as u64;\n\n        let q_flops = 2 * num_heads * head_dim * hidden_size;\n        let k_flops = 2 * num_kv_heads * head_dim * hidden_size;\n        let v_flops = 2 * num_kv_heads * head_dim * hidden_size;\n        let attn_flops = 2 * num_heads * head_dim * hidden_size;\n        let o_flops = 2 * num_heads * head_dim * hidden_size;\n        let attn_layer_flops = q_flops + k_flops + v_flops + attn_flops + o_flops;\n\n        let gate_up_down_flops = 2 * 3 * hidden_size * intermediate_size;\n\n        let layer_flops = attn_layer_flops + gate_up_down_flops;\n        let total = layer_flops * num_layers;\n        Some(total)\n    }\n\n    fn kv_vram_per_tok(&self) -> Option<usize> {\n        if self.quantize.is_some() {\n            // TODO handle quantization\n            return None;\n        }\n        // 2 for key and values\n        // 2 for f16 dtype?\n        Some(self.num_kv_heads? * 2 * self.get_head_dim()? 
* 2 * self.num_layers?)\n    }\n\n    fn mlp_vram_per_tok(&self) -> Option<usize> {\n        // TODO handle quantization\n        // TODO This calculation depends on the actual implementation\n        let dtype_size = 2;\n        let mlp_size = self.intermediate_size?;\n        // calculation is overshooting here.\n        // Coming from here: https://github.com/vllm-project/vllm/blob/d1c2e15eb31ef12e688ce0cb71895f88eaf4cd4f/vllm/model_executor/layers/fused_moe/fused_moe.py#L618-L624\n        Some((mlp_size + mlp_size / 2) * self.num_experts * dtype_size * 3)\n    }\n\n    fn token_vram(&self) -> Option<usize> {\n        let kv = self.kv_vram_per_tok()?;\n        let mlp_intermediary = self.mlp_vram_per_tok()?;\n        let per_tok = kv + mlp_intermediary;\n        Some(per_tok)\n    }\n\n    fn model_vram(&self) -> Option<usize> {\n        let attn_vram = (self.num_heads? + 2 * self.num_kv_heads?) * self.get_head_dim()?;\n        let o_vram = self.num_heads? * self.get_head_dim()? * self.hidden_size?;\n        // gate + up + down = 3\n        let mlp_vram = 3 * self.intermediate_size? * self.num_experts * self.hidden_size?;\n        let layer_vram = mlp_vram + attn_vram + o_vram;\n        let vocab = self.hidden_size? * self.vocab_size?;\n        let params = layer_vram * self.num_layers? 
+ 2 * vocab;\n        let dtype_size = 2;\n        if self.quantize.is_some() {\n            // TODO handle quantization\n            return None;\n        }\n        Some(params * dtype_size)\n    }\n}\n\nimpl From<RawConfig> for Config {\n    fn from(other: RawConfig) -> Self {\n        let max_position_embeddings = other\n            .max_position_embeddings\n            .or(other.max_seq_len)\n            .or(other.n_positions);\n        let quantize = other.quantization_config.and_then(|q| q.quant_method);\n        let hidden_size = other.hidden_size.or(other.n_embd);\n        let head_dim = other\n            .head_dim\n            .or_else(|| match (hidden_size, other.num_attention_heads) {\n                (Some(hidden_size), Some(num_attention_heads))\n                    if hidden_size % num_attention_heads == 0 =>\n                {\n                    Some(hidden_size / num_attention_heads)\n                }\n                _ => None,\n            });\n        let num_heads = other.num_attention_heads;\n        let num_layers = other.num_hidden_layers;\n        let num_kv_heads = other.num_key_value_heads.or(other.num_attention_heads);\n        let intermediate_size = other.intermediate_size;\n        let model_type = other.model_type;\n        let text_config = other.text_config;\n        let vision_config = other.vision_config;\n        let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);\n        let num_experts_per_token = other.num_experts_per_token.unwrap_or(1);\n        let num_shared_experts = other.num_shared_experts.unwrap_or(0);\n        let num_experts = other.num_experts.unwrap_or(1);\n        let vocab_size = other.vocab_size;\n        Config {\n            max_position_embeddings,\n            quantize,\n            head_dim,\n            model_type,\n            text_config,\n            vision_config,\n            is_encoder_decoder,\n            hidden_size,\n            num_heads,\n            num_kv_heads,\n         
   intermediate_size,\n            num_layers,\n            num_experts_per_token,\n            num_shared_experts,\n            num_experts,\n            vocab_size,\n        }\n    }\n}\n\n#[derive(Clone, Copy, Debug, ValueEnum, Deserialize)]\n#[serde(rename_all = \"kebab-case\")]\nenum Quantization {\n    /// 4 bit quantization. Requires a specific AWQ quantized model:\n    ///   <https://hf.co/models?search=awq>.\n    /// Should replace GPTQ models wherever possible because of the better latency\n    Awq,\n    /// Compressed tensors, which can be a mixture of different quantization methods.\n    CompressedTensors,\n    /// 8 bit quantization, doesn't require specific model.\n    /// Should be a drop-in replacement to bitsandbytes with much better performance.\n    /// Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>\n    Eetq,\n    /// Variable bit quantization. Requires a specific EXL2 quantized model:\n    /// <https://hf.co/models?search=exl2>. Requires exllama2 kernels and does\n    /// not support tensor parallelism (num_shard > 1).\n    Exl2,\n    /// 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>.\n    /// text-generation-inference will use exllama (faster) kernels wherever possible, and use\n    /// triton kernel (wider support) when it's not.\n    /// AWQ has faster kernels.\n    Gptq,\n    /// 4 bit quantization. Requires a specific Marlin quantized model: <https://hf.co/models?search=marlin>.\n    Marlin,\n    /// Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half,\n    /// but it is known that the model will be much slower to run than the native f16.\n    // #[deprecated(\n    //     since = \"1.1.0\",\n    //     note = \"Use `eetq` instead, which provides better latencies overall and is drop-in in most cases\"\n    // )]\n    Bitsandbytes,\n    /// Bitsandbytes 4bit. 
Can be applied on any model, will cut the memory requirement by 4x,\n    /// but it is known that the model will be much slower to run than the native f16.\n    BitsandbytesNf4,\n    /// Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better\n    /// perplexity performance for you model\n    BitsandbytesFp4,\n    /// [FP8](https://developer.nvidia.com/blog/nvidia-arm-and-intel-publish-fp8-specification-for-standardization-as-an-interchange-format-for-ai/) (e4m3) works on H100 and above\n    /// This dtype has native ops should be the fastest if available.\n    /// This is currently not the fastest because of local unpacking + padding to satisfy matrix\n    /// multiplication limitations.\n    Fp8,\n}\n\nimpl std::fmt::Display for Quantization {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        // To keep in track with `server`.\n        match self {\n            #[allow(deprecated)]\n            // Use `eetq` instead, which provides better latencies overall and is drop-in in most cases\n            Quantization::Bitsandbytes => {\n                write!(f, \"bitsandbytes\")\n            }\n            Quantization::BitsandbytesNf4 => {\n                write!(f, \"bitsandbytes-nf4\")\n            }\n            Quantization::BitsandbytesFp4 => {\n                write!(f, \"bitsandbytes-fp4\")\n            }\n            Quantization::Exl2 => {\n                write!(f, \"exl2\")\n            }\n            Quantization::Gptq => {\n                write!(f, \"gptq\")\n            }\n            Quantization::Marlin => {\n                write!(f, \"marlin\")\n            }\n            Quantization::Awq => {\n                write!(f, \"awq\")\n            }\n            Quantization::CompressedTensors => {\n                write!(f, \"compressed-tensors\")\n            }\n            Quantization::Eetq => {\n                write!(f, \"eetq\")\n            }\n            Quantization::Fp8 => {\n     
           write!(f, \"fp8\")\n            }\n        }\n    }\n}\n\n#[derive(Clone, Copy, Debug, ValueEnum)]\nenum Dtype {\n    Float16,\n    #[clap(name = \"bfloat16\")]\n    BFloat16,\n}\n\nimpl std::fmt::Display for Dtype {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        // To keep in track with `server`.\n        match self {\n            Dtype::Float16 => {\n                write!(f, \"float16\")\n            }\n            Dtype::BFloat16 => {\n                write!(f, \"bfloat16\")\n            }\n        }\n    }\n}\n\n#[derive(Clone, Copy, Debug, ValueEnum)]\nenum KVCacheDtype {\n    #[clap(name = \"fp8_e4m3fn\")]\n    Fp8e4m3fn,\n\n    #[clap(name = \"fp8_e5m2\")]\n    Fp8e5m2,\n}\n\nimpl std::fmt::Display for KVCacheDtype {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        match self {\n            KVCacheDtype::Fp8e4m3fn => {\n                write!(f, \"fp8_e4m3fn\")\n            }\n            KVCacheDtype::Fp8e5m2 => {\n                write!(f, \"fp8_e5m2\")\n            }\n        }\n    }\n}\n\n#[derive(Clone, Copy, Debug, ValueEnum)]\nenum RopeScaling {\n    Linear,\n    Dynamic,\n}\n\nimpl std::fmt::Display for RopeScaling {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        // To keep in track with `server`.\n        match self {\n            RopeScaling::Linear => {\n                write!(f, \"linear\")\n            }\n            RopeScaling::Dynamic => {\n                write!(f, \"dynamic\")\n            }\n        }\n    }\n}\n\n#[derive(Clone, Copy, Debug, ValueEnum)]\npub enum UsageStatsLevel {\n    /// Default option, usage statistics are collected anonymously\n    On,\n    /// Disables all collection of usage statistics\n    Off,\n    /// Doesn't send the error stack trace or error type, but allows sending a crash event\n    NoStack,\n}\n\nimpl std::fmt::Display for UsageStatsLevel {\n    fn fmt(&self, f: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result {\n        // To keep in track with `server`.\n        match self {\n            UsageStatsLevel::On => {\n                write!(f, \"on\")\n            }\n            UsageStatsLevel::Off => {\n                write!(f, \"off\")\n            }\n            UsageStatsLevel::NoStack => {\n                write!(f, \"no-stack\")\n            }\n        }\n    }\n}\n\n/// App Configuration\n#[derive(Parser, Debug)]\n#[clap(author, version, about, long_about = None)]\nstruct Args {\n    /// The name of the model to load.\n    /// Can be a MODEL_ID as listed on <https://hf.co/models> like\n    /// `gpt2` or `OpenAssistant/oasst-sft-1-pythia-12b`.\n    /// Or it can be a local directory containing the necessary files\n    /// as saved by `save_pretrained(...)` methods of transformers\n    #[clap(default_value = \"bigscience/bloom-560m\", long, env)]\n    model_id: String,\n\n    /// The actual revision of the model if you're referring to a model\n    /// on the hub. You can use a specific commit id or a branch like `refs/pr/2`.\n    #[clap(long, env)]\n    revision: Option<String>,\n\n    /// The number of tokenizer workers used for payload validation and truncation inside the\n    /// router.\n    #[clap(default_value = \"2\", long, env)]\n    validation_workers: usize,\n\n    /// Whether to shard the model across multiple GPUs\n    /// By default text-generation-inference will use all available GPUs to run\n    /// the model. Setting it to `false` deactivates `num_shard`.\n    #[clap(long, env)]\n    sharded: Option<bool>,\n\n    /// The number of shards to use if you don't want to use all GPUs on a given machine.\n    /// You can use `CUDA_VISIBLE_DEVICES=0,1 text-generation-launcher... --num_shard 2`\n    /// and `CUDA_VISIBLE_DEVICES=2,3 text-generation-launcher... 
--num_shard 2` to\n    /// launch 2 copies with 2 shards each on a given machine with 4 GPUs for instance.\n    #[clap(long, env)]\n    num_shard: Option<usize>,\n\n    /// Quantization method to use for the model. It is not necessary to specify this option\n    /// for pre-quantized models, since the quantization method is read from the model\n    /// configuration.\n    ///\n    /// Marlin kernels will be used automatically for GPTQ/AWQ models.\n    #[clap(long, env, value_enum)]\n    quantize: Option<Quantization>,\n\n    /// The number of input_ids to speculate on\n    /// If using a medusa model, the heads will be picked up automatically\n    /// Otherwise, it will use n-gram speculation which is relatively free\n    /// in terms of compute, but the speedup heavily depends on the task.\n    #[clap(long, env)]\n    speculate: Option<usize>,\n\n    /// The dtype to be forced upon the model. This option cannot be used with `--quantize`.\n    #[clap(long, env, value_enum)]\n    dtype: Option<Dtype>,\n\n    /// Specify the dtype for the key-value cache. When this option is not provided,\n    /// the dtype of the model is used (typically `float16` or `bfloat16`). Currently\n    /// the only supported values are `fp8_e4m3fn` and `fp8_e5m2` on CUDA.\n    #[clap(long, env, value_enum)]\n    kv_cache_dtype: Option<KVCacheDtype>,\n\n    /// Whether you want to execute hub modelling code. 
Explicitly passing a `revision` is\n    /// encouraged when loading a model with custom code to ensure no malicious code has been\n    /// contributed in a newer revision.\n    #[clap(long, env, value_enum)]\n    trust_remote_code: bool,\n\n    /// The maximum amount of concurrent requests for this particular deployment.\n    /// Having a low limit will refuse clients requests instead of having them\n    /// wait for too long and is usually good to handle backpressure correctly.\n    #[clap(default_value = \"128\", long, env)]\n    max_concurrent_requests: usize,\n\n    /// This is the maximum allowed value for clients to set `best_of`.\n    /// Best of makes `n` generations at the same time, and return the best\n    /// in terms of overall log probability over the entire generated sequence\n    #[clap(default_value = \"2\", long, env)]\n    max_best_of: usize,\n\n    /// This is the maximum allowed value for clients to set `stop_sequences`.\n    /// Stop sequences are used to allow the model to stop on more than just\n    /// the EOS token, and enable more complex \"prompting\" where users can preprompt\n    /// the model in a specific way and define their \"own\" stop token aligned with\n    /// their prompt.\n    #[clap(default_value = \"4\", long, env)]\n    max_stop_sequences: usize,\n\n    /// This is the maximum allowed value for clients to set `top_n_tokens`.\n    /// `top_n_tokens` is used to return information about the `n` most likely\n    /// tokens at each generation step, instead of just the sampled token. This\n    /// information can be used for downstream tasks like for classification or\n    /// ranking.\n    #[clap(default_value = \"5\", long, env)]\n    max_top_n_tokens: u32,\n\n    /// This is the maximum allowed input length (expressed in number of tokens)\n    /// for users. 
The larger this value, the longer prompt users can send which\n    /// can impact the overall memory required to handle the load.\n    /// Please note that some models have a finite range of sequence they can handle.\n    /// Default to min(max_allocatable, max_position_embeddings) - 1\n    #[clap(long, env)]\n    max_input_tokens: Option<usize>,\n\n    /// Legacy version of [`Args::max_input_tokens`].\n    #[clap(long, env)]\n    max_input_length: Option<usize>,\n\n    /// This is the most important value to set as it defines the \"memory budget\"\n    /// of running clients requests.\n    /// Clients will send input sequences and ask to generate `max_new_tokens`\n    /// on top. with a value of `1512` users can send either a prompt of\n    /// `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for\n    /// `1511` max_new_tokens.\n    /// The larger this value, the larger amount each request will be in your RAM\n    /// and the less effective batching can be.\n    /// Default to min(max_allocatable, max_position_embeddings)\n    #[clap(long, env)]\n    max_total_tokens: Option<usize>,\n\n    /// This represents the ratio of waiting queries vs running queries where\n    /// you want to start considering pausing the running queries to include the waiting\n    /// ones into the same batch.\n    /// `waiting_served_ratio=1.2` Means when 12 queries are waiting and there's\n    /// only 10 queries left in the current batch we check if we can fit those 12\n    /// waiting queries into the batching strategy, and if yes, then batching happens\n    /// delaying the 10 running queries by a `prefill` run.\n    ///\n    /// This setting is only applied if there is room in the batch\n    /// as defined by `max_batch_total_tokens`.\n    #[clap(default_value = \"0.3\", long, env)]\n    waiting_served_ratio: f32,\n\n    /// Limits the number of tokens for the prefill operation.\n    /// Since this operation take the most memory and is compute bound, it is 
interesting\n    /// to limit the number of requests that can be sent.\n    /// Default to `max_input_tokens + 50` to give a bit of room.\n    #[clap(long, env)]\n    max_batch_prefill_tokens: Option<u32>,\n\n    /// **IMPORTANT** This is one critical control to allow maximum usage\n    /// of the available hardware.\n    ///\n    /// This represents the total amount of potential tokens within a batch.\n    /// When using padding (not recommended) this would be equivalent of\n    /// `batch_size` * `max_total_tokens`.\n    ///\n    /// However in the non-padded (flash attention) version this can be much finer.\n    ///\n    /// For `max_batch_total_tokens=1000`, you could fit `10` queries of `total_tokens=100`\n    /// or a single query of `1000` tokens.\n    ///\n    /// Overall this number should be the largest possible amount that fits the\n    /// remaining memory (after the model is loaded). Since the actual memory overhead\n    /// depends on other parameters like if you're using quantization, flash attention\n    /// or the model implementation, text-generation-inference infers this number automatically\n    /// if not provided ensuring that the value is as large as possible.\n    #[clap(long, env)]\n    max_batch_total_tokens: Option<u32>,\n\n    /// This setting defines how many tokens can be passed before forcing the waiting\n    /// queries to be put on the batch (if the size of the batch allows for it).\n    /// New queries require 1 `prefill` forward, which is different from `decode`\n    /// and therefore you need to pause the running batch in order to run `prefill`\n    /// to create the correct values for the waiting queries to be able to join the batch.\n    ///\n    /// With a value too small, queries will always \"steal\" the compute to run `prefill`\n    /// and running queries will be delayed by a lot.\n    ///\n    /// With a value too big, waiting queries could wait for a very long time\n    /// before being allowed a slot in the running 
batch. If your server is busy\n    /// that means that requests that could run in ~2s on an empty server could\n    /// end up running in ~20s because the query had to wait for 18s.\n    ///\n    /// This number is expressed in number of tokens to make it a bit more\n    /// \"model\" agnostic, but what should really matter is the overall latency\n    /// for end users.\n    #[clap(default_value = \"20\", long, env)]\n    max_waiting_tokens: usize,\n\n    /// Enforce a maximum number of requests per batch\n    /// Specific flag for hardware targets that do not support unpadded inference\n    #[clap(long, env)]\n    max_batch_size: Option<usize>,\n\n    /// Specify the batch sizes to compute cuda graphs for.\n    /// Use \"0\" to disable.\n    /// Default = \"1,2,4,8,16,32\"\n    #[clap(long, env, value_delimiter = ',')]\n    cuda_graphs: Option<Vec<usize>>,\n\n    /// The IP address to listen on\n    #[clap(default_value = \"0.0.0.0\", long, env)]\n    hostname: String,\n\n    /// The port to listen on.\n    #[clap(default_value = \"3000\", long, short, env)]\n    port: u16,\n\n    /// The Prometheus port to listen on.\n    #[clap(default_value = \"9000\", long, short, env)]\n    prometheus_port: u16,\n\n    /// The name of the socket for gRPC communication between the webserver\n    /// and the shards.\n    #[clap(default_value = \"/tmp/text-generation-server\", long, env)]\n    shard_uds_path: String,\n\n    /// The address the master shard will listen on. (setting used by torch distributed)\n    #[clap(default_value = \"localhost\", long, env)]\n    master_addr: String,\n\n    /// The address the master port will listen on. 
(setting used by torch distributed)\n    #[clap(default_value = \"29500\", long, env)]\n    master_port: usize,\n\n    /// The location of the huggingface hub cache.\n    /// Used to override the location if you want to provide a mounted disk for instance\n    #[clap(long, env)]\n    huggingface_hub_cache: Option<String>,\n\n    /// The location of the huggingface hub cache.\n    /// Used to override the location if you want to provide a mounted disk for instance\n    #[clap(long, env)]\n    weights_cache_override: Option<String>,\n\n    /// For some models (like bloom), text-generation-inference implemented custom\n    /// cuda kernels to speed up inference. Those kernels were only tested on A100.\n    /// Use this flag to disable them if you're running on different hardware and\n    /// encounter issues.\n    #[clap(long, env)]\n    disable_custom_kernels: bool,\n\n    /// Limit the CUDA available memory.\n    /// The allowed value equals the total visible memory multiplied by cuda-memory-fraction.\n    #[clap(default_value = \"1.0\", long, env)]\n    cuda_memory_fraction: f32,\n\n    /// Rope scaling will only be used for RoPE models\n    /// and allow rescaling the position rotary to accommodate for\n    /// larger prompts.\n    ///\n    /// Goes together with `rope_factor`.\n    ///\n    /// `--rope-factor 2.0` gives linear scaling with a factor of 2.0\n    /// `--rope-scaling dynamic` gives dynamic scaling with a factor of 1.0\n    /// `--rope-scaling linear` gives linear scaling with a factor of 1.0 (Nothing will be changed\n    /// basically)\n    ///\n    /// `--rope-scaling linear --rope-factor` fully describes the scaling you want\n    #[clap(long, env)]\n    rope_scaling: Option<RopeScaling>,\n\n    /// Rope scaling will only be used for RoPE models\n    /// See `rope_scaling`\n    #[clap(long, env)]\n    rope_factor: Option<f32>,\n\n    /// Outputs the logs in JSON format (useful for telemetry)\n    #[clap(long, env)]\n    json_output: bool,\n\n    
#[clap(long, env)]\n    otlp_endpoint: Option<String>,\n\n    #[clap(default_value = \"text-generation-inference.router\", long, env)]\n    otlp_service_name: String,\n\n    #[clap(long, env)]\n    cors_allow_origin: Vec<String>,\n\n    #[clap(long, env)]\n    api_key: Option<String>,\n\n    #[clap(long, env)]\n    watermark_gamma: Option<f32>,\n    #[clap(long, env)]\n    watermark_delta: Option<f32>,\n\n    /// Enable ngrok tunneling\n    #[clap(long, env)]\n    ngrok: bool,\n\n    /// ngrok authentication token\n    #[clap(long, env)]\n    ngrok_authtoken: Option<String>,\n\n    /// ngrok edge\n    #[clap(long, env)]\n    ngrok_edge: Option<String>,\n\n    /// The path to the tokenizer config file. This path is used to load the tokenizer configuration which may\n    /// include a `chat_template`. If not provided, the default config will be used from the model hub.\n    #[clap(long, env)]\n    tokenizer_config_path: Option<String>,\n\n    /// Disable outlines grammar constrained generation.\n    /// This is a feature that allows you to generate text that follows a specific grammar.\n    #[clap(long, env)]\n    disable_grammar_support: bool,\n\n    /// Display a lot of information about your runtime environment\n    #[clap(long, short, action)]\n    env: bool,\n\n    /// Control the maximum number of inputs that a client can send in a single request\n    #[clap(default_value = \"4\", long, env)]\n    max_client_batch_size: usize,\n\n    /// Lora Adapters a list of adapter ids i.e. 
`repo/adapter1,repo/adapter2` to load during\n    /// startup that will be available to callers via the `adapter_id` field in a request.\n    #[clap(long, env)]\n    lora_adapters: Option<String>,\n\n    /// Control if anonymous usage stats are collected.\n    /// Options are \"on\", \"off\" and \"no-stack\"\n    /// Default is on.\n    #[clap(default_value = \"on\", long, env)]\n    usage_stats: UsageStatsLevel,\n\n    /// Payload size limit in bytes\n    ///\n    /// Default is 2MB\n    #[clap(default_value = \"2000000\", long, env)]\n    payload_limit: usize,\n\n    /// Enables prefill logprobs\n    ///\n    /// Logprobs in the prompt are deactivated by default because they consume\n    /// a large amount of VRAM (especially for long prompts).\n    /// Using this flag reallows users to ask for them.\n    #[clap(long, env)]\n    enable_prefill_logprobs: bool,\n\n    /// Change timeout of graceful termination of the TGI server\n    #[clap(default_value = \"90\", long, short, env)]\n    graceful_termination_timeout: u64,\n}\n\n#[derive(Debug)]\nenum ShardStatus {\n    Ready,\n    Failed(usize),\n}\n\n#[allow(clippy::too_many_arguments)]\nfn shard_manager(\n    model_id: String,\n    revision: Option<String>,\n    quantize: Option<Quantization>,\n    speculate: Option<usize>,\n    dtype: Option<Dtype>,\n    kv_cache_dtype: Option<KVCacheDtype>,\n    trust_remote_code: bool,\n    uds_path: String,\n    rank: usize,\n    world_size: usize,\n    master_addr: String,\n    master_port: usize,\n    huggingface_hub_cache: Option<String>,\n    weights_cache_override: Option<String>,\n    disable_custom_kernels: bool,\n    watermark_gamma: Option<f32>,\n    watermark_delta: Option<f32>,\n    cuda_graphs: Vec<usize>,\n    cuda_memory_fraction: f32,\n    rope_scaling: Option<RopeScaling>,\n    rope_factor: Option<f32>,\n    max_total_tokens: Option<usize>,\n    max_batch_size: Option<usize>,\n    max_input_tokens: Option<usize>,\n    lora_adapters: Option<String>,\n    
enable_prefill_logprobs: bool,\n    otlp_endpoint: Option<String>,\n    otlp_service_name: String,\n    log_level: LevelFilter,\n    status_sender: mpsc::Sender<ShardStatus>,\n    shutdown: Arc<AtomicBool>,\n    graceful_termination_timeout: u64,\n    _shutdown_sender: mpsc::Sender<()>,\n) {\n    // Enter shard-manager tracing span\n    let _span = tracing::span!(tracing::Level::INFO, \"shard-manager\", rank = rank).entered();\n\n    // Get UDS path\n    let uds_string = format!(\"{uds_path}-{rank}\");\n    let uds = Path::new(&uds_string);\n    // Clean previous runs\n    if uds.exists() {\n        fs::remove_file(uds).unwrap();\n    }\n\n    // Process args\n    let mut shard_args = vec![\n        \"serve\".to_string(),\n        model_id,\n        \"--uds-path\".to_string(),\n        uds_path,\n        \"--logger-level\".to_string(),\n        log_level.to_string().to_uppercase(),\n        \"--json-output\".to_string(),\n    ];\n\n    // Activate trust remote code\n    if trust_remote_code {\n        shard_args.push(\"--trust-remote-code\".to_string());\n    }\n\n    // Activate tensor parallelism\n    if world_size > 1 {\n        shard_args.push(\"--sharded\".to_string());\n    }\n\n    if let Some(quantize) = quantize {\n        shard_args.push(\"--quantize\".to_string());\n        shard_args.push(quantize.to_string())\n    }\n\n    if let Some(speculate) = speculate {\n        shard_args.push(\"--speculate\".to_string());\n        shard_args.push(speculate.to_string())\n    }\n\n    if let Some(dtype) = dtype {\n        shard_args.push(\"--dtype\".to_string());\n        shard_args.push(dtype.to_string())\n    }\n\n    if let Some(kv_cache_dtype) = kv_cache_dtype {\n        shard_args.push(\"--kv-cache-dtype\".to_string());\n        shard_args.push(kv_cache_dtype.to_string())\n    }\n\n    // Model optional revision\n    if let Some(revision) = revision {\n        shard_args.push(\"--revision\".to_string());\n        shard_args.push(revision)\n    }\n\n    let 
rope = match (rope_scaling, rope_factor) {\n        (None, None) => None,\n        (Some(scaling), None) => Some((scaling, 1.0)),\n        (Some(scaling), Some(factor)) => Some((scaling, factor)),\n        (None, Some(factor)) => Some((RopeScaling::Linear, factor)),\n    };\n\n    // OpenTelemetry Endpoint\n    if let Some(otlp_endpoint) = otlp_endpoint {\n        shard_args.push(\"--otlp-endpoint\".to_string());\n        shard_args.push(otlp_endpoint);\n    }\n\n    // OpenTelemetry Service Name\n    shard_args.push(\"--otlp-service-name\".to_string());\n    shard_args.push(otlp_service_name);\n\n    // In case we use sliding window, we may ignore the sliding in flash for some backends depending on the parameter.\n    if let Some(max_input_tokens) = max_input_tokens {\n        shard_args.push(\"--max-input-tokens\".to_string());\n        shard_args.push(max_input_tokens.to_string());\n    }\n\n    // Copy current process env\n    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();\n\n    // Remove LOG_LEVEL if present\n    envs.retain(|(name, _)| name != \"LOG_LEVEL\");\n\n    // Torch Distributed Env vars\n    envs.push((\"RANK\".into(), rank.to_string().into()));\n    envs.push((\"WORLD_SIZE\".into(), world_size.to_string().into()));\n    envs.push((\"MASTER_ADDR\".into(), master_addr.into()));\n    envs.push((\"MASTER_PORT\".into(), master_port.to_string().into()));\n    envs.push((\"TORCH_NCCL_AVOID_RECORD_STREAMS\".into(), \"1\".into()));\n\n    // CUDA memory fraction\n    envs.push((\n        \"CUDA_MEMORY_FRACTION\".into(),\n        cuda_memory_fraction.to_string().into(),\n    ));\n\n    // Safetensors load fast\n    envs.push((\"SAFETENSORS_FAST_GPU\".into(), \"1\".into()));\n\n    // Disable progress bar\n    envs.push((\"HF_HUB_DISABLE_PROGRESS_BARS\".into(), \"1\".into()));\n\n    // Enable hf transfer for insane download speeds\n    let enable_hf_transfer = env::var(\"HF_HUB_ENABLE_HF_TRANSFER\").unwrap_or(\"1\".to_string());\n    
envs.push((\n        \"HF_HUB_ENABLE_HF_TRANSFER\".into(),\n        enable_hf_transfer.into(),\n    ));\n\n    // Parse Inference API token\n    if let Ok(api_token) = env::var(\"HF_API_TOKEN\") {\n        envs.push((\"HF_TOKEN\".into(), api_token.into()))\n    };\n\n    // Detect rope scaling\n    // Sending as env instead of CLI args to not bloat everything\n    // those only can be used by RoPE models, so passing information around\n    // for all models will complexify code unnecessarily\n    if let Some((scaling, factor)) = rope {\n        envs.push((\"ROPE_SCALING\".into(), scaling.to_string().into()));\n        envs.push((\"ROPE_FACTOR\".into(), factor.to_string().into()));\n    }\n\n    if let Some(max_total_tokens) = max_total_tokens {\n        envs.push((\n            \"MAX_TOTAL_TOKENS\".into(),\n            max_total_tokens.to_string().into(),\n        ));\n    }\n    if let Some(max_batch_size) = max_batch_size {\n        envs.push((\"MAX_BATCH_SIZE\".into(), max_batch_size.to_string().into()));\n    }\n\n    // Lora Adapters\n    if let Some(lora_adapters) = lora_adapters {\n        envs.push((\"LORA_ADAPTERS\".into(), lora_adapters.into()));\n    }\n\n    // Logprobs\n    if enable_prefill_logprobs {\n        envs.push((\"REQUEST_LOGPROBS\".into(), \"1\".into()));\n    }\n\n    // If huggingface_hub_cache is some, pass it to the shard\n    // Useful when running inside a docker container\n    if let Some(huggingface_hub_cache) = huggingface_hub_cache {\n        envs.push((\"HUGGINGFACE_HUB_CACHE\".into(), huggingface_hub_cache.into()));\n    };\n\n    // If weights_cache_override is some, pass it to the shard\n    // Useful when running inside a HuggingFace Inference Endpoint\n    if let Some(weights_cache_override) = weights_cache_override {\n        envs.push((\n            \"WEIGHTS_CACHE_OVERRIDE\".into(),\n            weights_cache_override.into(),\n        ));\n    };\n\n    // Enable experimental support for cuda graphs\n    if 
!cuda_graphs.is_empty() {\n        envs.push((\n            \"CUDA_GRAPHS\".into(),\n            cuda_graphs\n                .into_iter()\n                .map(|c| c.to_string())\n                .collect::<Vec<_>>()\n                .join(\",\")\n                .into(),\n        ));\n    }\n\n    // If disable_custom_kernels is true, pass it to the shard as an env var\n    if disable_custom_kernels {\n        envs.push((\"DISABLE_CUSTOM_KERNELS\".into(), \"True\".into()))\n    }\n\n    // Watermark Gamma\n    if let Some(watermark_gamma) = watermark_gamma {\n        envs.push((\"WATERMARK_GAMMA\".into(), watermark_gamma.to_string().into()))\n    }\n\n    // Watermark Delta\n    if let Some(watermark_delta) = watermark_delta {\n        envs.push((\"WATERMARK_DELTA\".into(), watermark_delta.to_string().into()))\n    }\n\n    // Start process\n    tracing::info!(\"Starting shard\");\n    let mut p = match Command::new(\"text-generation-server\")\n        .args(shard_args)\n        .env_clear()\n        .envs(envs)\n        .stdin(Stdio::piped())\n        .stdout(Stdio::piped())\n        .stderr(Stdio::piped())\n        .process_group(0)\n        .spawn()\n    {\n        Ok(p) => p,\n        Err(err) => {\n            if err.kind() == io::ErrorKind::NotFound {\n                tracing::error!(\"text-generation-server not found in PATH\");\n                tracing::error!(\"Please install it with `make install-server`\")\n            }\n            {\n                tracing::error!(\"{}\", err);\n            }\n\n            status_sender.send(ShardStatus::Failed(rank)).unwrap();\n            return;\n        }\n    };\n\n    // Redirect STDOUT to the console\n    let mut pstdin = p.stdin.take().unwrap();\n    let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());\n    let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());\n\n    //stdout tracing thread\n    thread::spawn(move || {\n        log_lines(shard_stdout_reader);\n    });\n    // 
We read stderr in another thread as it seems that lines() can block in some cases\n    let (err_sender, err_receiver) = mpsc::channel();\n    thread::spawn(move || {\n        for line in shard_stderr_reader.lines().map_while(Result::ok) {\n            err_sender.send(line).unwrap_or(());\n        }\n    });\n    // We read stdin in another thread as it seems that lines() can block in some cases\n    if LevelFilter::current() >= tracing::Level::DEBUG {\n        thread::spawn(move || {\n            let mut stdin = io::stdin(); // We get `Stdin` here.\n            loop {\n                let mut buffer = vec![0; 4096];\n                if let Ok(n) = stdin.read(&mut buffer) {\n                    if n > 0 {\n                        let _ = pstdin.write_all(&buffer[..n]);\n                    }\n                }\n            }\n        });\n    }\n\n    let mut ready = false;\n    let start_time = Instant::now();\n    let mut wait_time = Instant::now();\n    loop {\n        // Process exited\n        if let Some(exit_status) = p.try_wait().unwrap() {\n            let mut err = String::new();\n            while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {\n                err = err + \"\\n\" + &line;\n            }\n\n            tracing::error!(\"Shard complete standard error output:\\n{err}\");\n\n            if let Some(signal) = exit_status.signal() {\n                tracing::error!(\"Shard process was signaled to shutdown with signal {signal}\");\n            }\n\n            status_sender.send(ShardStatus::Failed(rank)).unwrap();\n            return;\n        }\n\n        // We received a shutdown signal\n        if shutdown.load(Ordering::SeqCst) {\n            terminate(\n                \"shard\",\n                p,\n                Duration::from_secs(graceful_termination_timeout),\n            )\n            .unwrap();\n            return;\n        }\n\n        // Shard is ready\n        if uds.exists() && !ready {\n            
tracing::info!(\"Shard ready in {:?}\", start_time.elapsed());\n            status_sender.send(ShardStatus::Ready).unwrap();\n            ready = true;\n        } else if !ready && wait_time.elapsed() > Duration::from_secs(10) {\n            tracing::info!(\"Waiting for shard to be ready...\");\n            wait_time = Instant::now();\n        }\n        sleep(Duration::from_millis(100));\n    }\n}\n\nfn shutdown_shards(shutdown: Arc<AtomicBool>, shutdown_receiver: &mpsc::Receiver<()>) {\n    tracing::info!(\"Shutting down shards\");\n    // Update shutdown value to true\n    // This will be picked up by the shard manager\n    shutdown.store(true, Ordering::SeqCst);\n\n    // Wait for shards to shutdown\n    // This will block till all shutdown_sender are dropped\n    let _ = shutdown_receiver.recv();\n}\n\nfn num_cuda_devices() -> Option<usize> {\n    let devices = match env::var(\"CUDA_VISIBLE_DEVICES\") {\n        Ok(devices) => devices,\n        Err(_) => match env::var(\"NVIDIA_VISIBLE_DEVICES\") {\n            Ok(devices) => {\n                // NVIDIA_VISIBLE_DEVICES is always set when not specified and the nvidia container runtime is\n                // in (jit-)cdi mode (since 1.14)\n                // nvidia container runtime default mode switched from legacy to cdi mode from 1.18 on\n                // Let's handle the void case as all here\n                // See: https://github.com/NVIDIA/nvidia-container-toolkit\n                if [\"all\", \"void\"].contains(&devices.trim())  {\n                    // Count the number of all GPUs via nvidia-smi\n                    let output = Command::new(\"nvidia-smi\")\n                        .args([\"--query-gpu=uuid\", \"--format=csv,noheader\"])\n                        .output()\n                        .ok()?;\n\n                    String::from_utf8_lossy(&output.stdout)\n                        .lines()\n                        .filter(|line| !line.trim().is_empty())\n                        
.collect::<Vec<_>>().join(\",\")\n                } else {\n                    devices\n                }\n            }\n            Err(_) => env::var(\"ZE_AFFINITY_MASK\").ok()?,\n        },\n    };\n    let n_devices = devices.split(',').count();\n    Some(n_devices)\n}\n\n#[derive(Deserialize)]\n#[serde(rename_all = \"UPPERCASE\")]\nenum PythonLogLevelEnum {\n    Trace,\n    Debug,\n    Info,\n    Success,\n    Warning,\n    Error,\n    Critical,\n}\n\n#[derive(Deserialize)]\nstruct PythonLogLevel {\n    name: PythonLogLevelEnum,\n}\n\n#[derive(Deserialize)]\nstruct PythonLogRecord {\n    level: PythonLogLevel,\n}\n\n#[derive(Deserialize)]\nstruct PythonLogMessage {\n    text: String,\n    record: PythonLogRecord,\n}\n\nimpl PythonLogMessage {\n    fn trace(&self) {\n        match self.record.level.name {\n            PythonLogLevelEnum::Trace => tracing::trace!(\"{}\", self.text.trim_end()),\n            PythonLogLevelEnum::Debug => tracing::debug!(\"{}\", self.text.trim_end()),\n            PythonLogLevelEnum::Info => tracing::info!(\"{}\", self.text.trim_end()),\n            PythonLogLevelEnum::Success => tracing::info!(\"{}\", self.text.trim_end()),\n            PythonLogLevelEnum::Warning => tracing::warn!(\"{}\", self.text.trim_end()),\n            PythonLogLevelEnum::Error => tracing::error!(\"{}\", self.text.trim_end()),\n            PythonLogLevelEnum::Critical => tracing::error!(\"{}\", self.text.trim_end()),\n        }\n    }\n}\n\nimpl TryFrom<&[u8]> for PythonLogMessage {\n    type Error = serde_json::Error;\n\n    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {\n        serde_json::from_slice::<Self>(value)\n    }\n}\n\nfn log_lines<R: Sized + Read>(mut bufread: BufReader<R>) {\n    let mut buffer = vec![0u8; 8 * 4096];\n    let mut stdout = std::io::stdout();\n    loop {\n        let n = bufread.read(&mut buffer);\n        if let Ok(n) = n {\n            if n > 0 {\n                let mut lines = buffer[..n].split(|i| *i == 
b'\\n').peekable();\n                while let Some(line) = lines.next() {\n                    match PythonLogMessage::try_from(line) {\n                        Ok(log) => log.trace(),\n                        // For interactive debugging ?\n                        Err(_) => {\n                            if LevelFilter::current() >= tracing::Level::DEBUG {\n                                stdout.write_all(line).unwrap();\n                                if lines.peek().is_some() {\n                                    stdout.write_all(b\"\\n\").unwrap();\n                                }\n                                stdout.flush().unwrap();\n                            }\n                        }\n                    }\n                }\n            } else {\n                break;\n            }\n        }\n    }\n}\n\nfn find_num_shards(\n    sharded: Option<bool>,\n    num_shard: Option<usize>,\n) -> Result<usize, LauncherError> {\n    // get the number of shards given `sharded` and `num_shard`\n    let num_shard = match (sharded, num_shard) {\n        (Some(true), None) => {\n            // try to default to the number of available GPUs\n            tracing::info!(\"Parsing num_shard from CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES/ZE_AFFINITY_MASK\");\n            let n_devices = num_cuda_devices()\n                .expect(\"--num-shard and CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES/ZE_AFFINITY_MASK are not set\");\n            if n_devices <= 1 {\n                return Err(LauncherError::NotEnoughCUDADevices(format!(\n                    \"`sharded` is true but only found {n_devices} CUDA devices\"\n                )));\n            }\n            n_devices\n        }\n        (Some(true), Some(num_shard)) => {\n            // we can't have only one shard while sharded\n            if num_shard <= 1 {\n                return Err(LauncherError::ArgumentValidation(\n                    \"`sharded` is true but `num_shard` <= 1\".to_string(),\n          
      ));\n            }\n            num_shard\n        }\n        (Some(false), Some(num_shard)) => num_shard,\n        (Some(false), None) => 1,\n        (None, None) => num_cuda_devices().unwrap_or(1),\n        (None, Some(num_shard)) => num_shard,\n    };\n    if num_shard < 1 {\n        return Err(LauncherError::ArgumentValidation(\n            \"`num_shard` cannot be < 1\".to_string(),\n        ));\n    }\n    Ok(num_shard)\n}\n\n#[derive(Debug, Error)]\nenum LauncherError {\n    #[error(\"Invalid argument: {0}\")]\n    ArgumentValidation(String),\n    #[error(\"not enough cuda devices: {0}\")]\n    NotEnoughCUDADevices(String),\n    #[error(\"Download error\")]\n    DownloadError,\n    #[error(\"Shard cannot start\")]\n    ShardCannotStart,\n    #[error(\"Shard disconnected\")]\n    ShardDisconnected,\n    #[error(\"Shard failed\")]\n    ShardFailed,\n    #[error(\"Webserver failed\")]\n    WebserverFailed,\n    #[error(\"Webserver cannot start\")]\n    WebserverCannotStart,\n}\n\nfn download_convert_model(\n    model_id: &str,\n    revision: Option<&str>,\n    trust_remote_code: bool,\n    huggingface_hub_cache: Option<&str>,\n    weights_cache_override: Option<&str>,\n    running: Arc<AtomicBool>,\n    merge_lora: bool,\n) -> Result<(), LauncherError> {\n    // Enter download tracing span\n    let _span = tracing::span!(tracing::Level::INFO, \"download\").entered();\n\n    let mut download_args = vec![\n        \"download-weights\".to_string(),\n        model_id.to_string(),\n        \"--extension\".to_string(),\n        \".safetensors\".to_string(),\n        \"--logger-level\".to_string(),\n        \"INFO\".to_string(),\n        \"--json-output\".to_string(),\n    ];\n\n    if merge_lora {\n        download_args.push(\"--merge-lora\".to_string());\n    }\n\n    // Model optional revision\n    if let Some(revision) = &revision {\n        download_args.push(\"--revision\".to_string());\n        download_args.push(revision.to_string())\n    }\n\n    // 
Trust remote code for automatic peft fusion\n    if trust_remote_code {\n        download_args.push(\"--trust-remote-code\".to_string());\n    }\n\n    // Copy current process env\n    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();\n\n    // Remove LOG_LEVEL if present\n    envs.retain(|(name, _)| name != \"LOG_LEVEL\");\n\n    // Disable progress bar\n    envs.push((\"HF_HUB_DISABLE_PROGRESS_BARS\".into(), \"1\".into()));\n\n    // If huggingface_hub_cache is set, pass it to the download process\n    // Useful when running inside a docker container\n    if let Some(ref huggingface_hub_cache) = huggingface_hub_cache {\n        envs.push((\"HUGGINGFACE_HUB_CACHE\".into(), huggingface_hub_cache.into()));\n    };\n\n    // Enable hf transfer for insane download speeds\n    let enable_hf_transfer = env::var(\"HF_HUB_ENABLE_HF_TRANSFER\").unwrap_or(\"1\".to_string());\n    envs.push((\n        \"HF_HUB_ENABLE_HF_TRANSFER\".into(),\n        enable_hf_transfer.into(),\n    ));\n\n    // Parse Inference API token\n    if let Ok(api_token) = env::var(\"HF_API_TOKEN\") {\n        envs.push((\"HF_TOKEN\".into(), api_token.into()))\n    };\n\n    // If args.weights_cache_override is some, pass it to the download process\n    // Useful when running inside a HuggingFace Inference Endpoint\n    if let Some(weights_cache_override) = &weights_cache_override {\n        envs.push((\n            \"WEIGHTS_CACHE_OVERRIDE\".into(),\n            weights_cache_override.into(),\n        ));\n    };\n\n    // Start process\n    tracing::info!(\"Starting check and download process for {model_id}\");\n    let mut download_process = match Command::new(\"text-generation-server\")\n        .args(download_args)\n        .env_clear()\n        .envs(envs)\n        .stdout(Stdio::piped())\n        .stderr(Stdio::piped())\n        .process_group(0)\n        .spawn()\n    {\n        Ok(p) => p,\n        Err(err) => {\n            if err.kind() == io::ErrorKind::NotFound {\n        
        tracing::error!(\"text-generation-server not found in PATH\");\n                tracing::error!(\"Please install it with `make install-server`\")\n            } else {\n                tracing::error!(\"{}\", err);\n            }\n\n            return Err(LauncherError::DownloadError);\n        }\n    };\n\n    let download_stdout = BufReader::new(download_process.stdout.take().unwrap());\n\n    thread::spawn(move || {\n        log_lines(download_stdout);\n    });\n\n    let download_stderr = BufReader::new(download_process.stderr.take().unwrap());\n\n    // We read stderr in another thread as it seems that lines() can block in some cases\n    let (err_sender, err_receiver) = mpsc::channel();\n    thread::spawn(move || {\n        for line in download_stderr.lines().map_while(Result::ok) {\n            err_sender.send(line).unwrap_or(());\n        }\n    });\n\n    loop {\n        if let Some(status) = download_process.try_wait().unwrap() {\n            if status.success() {\n                tracing::info!(\"Successfully downloaded weights for {model_id}\");\n                break;\n            }\n\n            let mut err = String::new();\n            while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {\n                err = err + \"\\n\" + &line;\n            }\n\n            if let Some(signal) = status.signal() {\n                tracing::error!(\n                    \"Download process was signaled to shutdown with signal {signal}: {err}\"\n                );\n            } else {\n                tracing::error!(\"Download encountered an error: {err}\");\n            }\n\n            return Err(LauncherError::DownloadError);\n        }\n        if !running.load(Ordering::SeqCst) {\n            terminate(\"download\", download_process, Duration::from_secs(10)).unwrap();\n            return Ok(());\n        }\n        sleep(Duration::from_millis(100));\n    }\n    Ok(())\n}\n\n#[allow(clippy::too_many_arguments)]\nfn 
spawn_shards(\n    num_shard: usize,\n    args: &Args,\n    cuda_graphs: Vec<usize>,\n    max_total_tokens: Option<usize>,\n    max_input_tokens: Option<usize>,\n    quantize: Option<Quantization>,\n    max_log_level: LevelFilter,\n    shutdown: Arc<AtomicBool>,\n    shutdown_receiver: &mpsc::Receiver<()>,\n    shutdown_sender: mpsc::Sender<()>,\n    status_receiver: &mpsc::Receiver<ShardStatus>,\n    status_sender: mpsc::Sender<ShardStatus>,\n    running: Arc<AtomicBool>,\n    graceful_termination_timeout: u64,\n) -> Result<(), LauncherError> {\n    // Start shard processes\n    for rank in 0..num_shard {\n        let model_id = args.model_id.clone();\n        let revision = args.revision.clone();\n        let uds_path = args.shard_uds_path.clone();\n        let master_addr = args.master_addr.clone();\n        let huggingface_hub_cache = args.huggingface_hub_cache.clone();\n        let weights_cache_override = args.weights_cache_override.clone();\n        let status_sender = status_sender.clone();\n        let shutdown = shutdown.clone();\n        let shutdown_sender = shutdown_sender.clone();\n        let otlp_endpoint = args.otlp_endpoint.clone();\n        let otlp_service_name = args.otlp_service_name.clone();\n        let speculate = args.speculate;\n        let dtype = args.dtype;\n        let kv_cache_dtype = args.kv_cache_dtype;\n        let trust_remote_code = args.trust_remote_code;\n        let master_port = args.master_port;\n        let disable_custom_kernels = args.disable_custom_kernels;\n        let watermark_gamma = args.watermark_gamma;\n        let watermark_delta = args.watermark_delta;\n        let cuda_graphs_clone = cuda_graphs.clone();\n        let cuda_memory_fraction = args.cuda_memory_fraction;\n        let rope_scaling = args.rope_scaling;\n        let rope_factor = args.rope_factor;\n        let max_batch_size = args.max_batch_size;\n        let lora_adapters = args.lora_adapters.clone();\n        let enable_prefill_logprobs = 
args.enable_prefill_logprobs;\n        thread::spawn(move || {\n            shard_manager(\n                model_id,\n                revision,\n                quantize,\n                speculate,\n                dtype,\n                kv_cache_dtype,\n                trust_remote_code,\n                uds_path,\n                rank,\n                num_shard,\n                master_addr,\n                master_port,\n                huggingface_hub_cache,\n                weights_cache_override,\n                disable_custom_kernels,\n                watermark_gamma,\n                watermark_delta,\n                cuda_graphs_clone,\n                cuda_memory_fraction,\n                rope_scaling,\n                rope_factor,\n                max_total_tokens,\n                max_batch_size,\n                max_input_tokens,\n                lora_adapters,\n                enable_prefill_logprobs,\n                otlp_endpoint,\n                otlp_service_name,\n                max_log_level,\n                status_sender,\n                shutdown,\n                graceful_termination_timeout,\n                shutdown_sender,\n            )\n        });\n    }\n    drop(shutdown_sender);\n\n    // Wait for shard to start\n    let mut shard_ready = 0;\n    while running.load(Ordering::SeqCst) {\n        match status_receiver.try_recv() {\n            Ok(ShardStatus::Ready) => {\n                shard_ready += 1;\n                if shard_ready == num_shard {\n                    break;\n                }\n            }\n            Err(TryRecvError::Empty) => {\n                sleep(Duration::from_millis(100));\n            }\n            Ok(ShardStatus::Failed(rank)) => {\n                tracing::error!(\"Shard {rank} failed to start\");\n                shutdown_shards(shutdown, shutdown_receiver);\n                return Err(LauncherError::ShardCannotStart);\n            }\n            Err(TryRecvError::Disconnected) => {\n         
       tracing::error!(\"Shard status channel disconnected\");\n                shutdown_shards(shutdown, shutdown_receiver);\n                return Err(LauncherError::ShardDisconnected);\n            }\n        }\n    }\n    Ok(())\n}\n\n#[derive(Debug)]\nenum Gpu {\n    RTX4090,\n    T4,\n    L4,\n    L40,\n    L40S,\n    A10G,\n    A40,\n    H100,\n    A100,\n    H200,\n    Unknown(String),\n}\n\n#[derive(Debug)]\nstruct ComputeType {\n    count: usize,\n    card: Gpu,\n}\n\nimpl From<&str> for Gpu {\n    fn from(value: &str) -> Self {\n        match value {\n            \"nvidia-4090\" => Gpu::RTX4090,\n            \"nvidia-t4\" => Gpu::T4,\n            \"nvidia-l4\" => Gpu::L4,\n            \"nvidia-l40\" => Gpu::L40,\n            \"nvidia-l40s\" => Gpu::L40S,\n            \"nvidia-a10g\" => Gpu::A10G,\n            \"nvidia-a40\" => Gpu::A40,\n            \"nvidia-h100-80gb-hbm3\" => Gpu::H100,\n            \"nvidia-h100-nvl\" => Gpu::H100,\n            \"nvidia-h100\" => Gpu::H100,\n            \"nvidia-a100-sxm4-80gb\" => Gpu::A100,\n            \"nvidia-a100-sxm4-40gb\" => Gpu::A100,\n            \"nvidia-a100-80gb-pcie\" => Gpu::A100,\n            \"nvidia-a100\" => Gpu::A100,\n            \"nvidia-h200\" => Gpu::H200,\n            card => Gpu::Unknown(card.to_string()),\n        }\n    }\n}\n\nimpl std::fmt::Display for Gpu {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        match self {\n            Gpu::RTX4090 => write!(f, \"nvidia-4090\"),\n            Gpu::T4 => write!(f, \"nvidia-t4\"),\n            Gpu::L4 => write!(f, \"nvidia-l4\"),\n            Gpu::L40 => write!(f, \"nvidia-l40\"),\n            Gpu::L40S => write!(f, \"nvidia-l40s\"),\n            Gpu::A10G => write!(f, \"nvidia-a10g\"),\n            Gpu::A40 => write!(f, \"nvidia-a40\"),\n            Gpu::H100 => write!(f, \"nvidia-h100-80fb-hbm3\"),\n            Gpu::A100 => write!(f, \"nvidia-a100-sxm4-80gb\"),\n            Gpu::H200 => write!(f, 
\"nvidia-h200\"),\n            Gpu::Unknown(card) => write!(f, \"{}\", card),\n        }\n    }\n}\n\nimpl ComputeType {\n    fn f16_flop(&self) -> Option<u64> {\n        let card_flop = match &self.card {\n            // https://www.nvidia.com/en-us/geforce/graphics-cards/40-series/rtx-4090/\n            // Specs are unclear https://www.itcreations.com/nvidia-gpu/nvidia-geforce-rtx-4090-gpu\n            Gpu::RTX4090 => Some(82 * 10u64.pow(12)),\n            // https://www.nvidia.com/en-us/data-center/tesla-t4/\n            Gpu::T4 => Some(65 * 10u64.pow(12)),\n            // https://www.nvidia.com/en-us/data-center/l4/\n            Gpu::L4 => Some(121 * 10u64.pow(12)),\n            // https://www.nvidia.com/en-us/data-center/l40/\n            Gpu::L40 => Some(181 * 10u64.pow(12)),\n            // https://www.nvidia.com/en-us/data-center/l40s/\n            Gpu::L40S => Some(363 * 10u64.pow(12)),\n            // https://www.nvidia.com/en-us/data-center/products/a10-gpu/\n            Gpu::A10G => Some(125 * 10u64.pow(12)),\n            // https://www.nvidia.com/en-us/data-center/a40/\n            // https://images.nvidia.com/content/Solutions/data-center/a40/nvidia-a40-datasheet.pdf\n            Gpu::A40 => Some(149 * 10u64.pow(12)),\n            // https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf\n            Gpu::A100 => Some(312 * 10u64.pow(12)),\n            // https://www.nvidia.com/en-us/data-center/h100/\n            // https://www.techpowerup.com/gpu-specs/docs/nvidia-gh100-architecture.pdf\n            Gpu::H100 => Some(900 * 10u64.pow(12)),\n            // https://www.nvidia.com/en-us/data-center/h200/\n            Gpu::H200 => Some(989 * 10u64.pow(12)),\n            Gpu::Unknown(card) => {\n                tracing::warn!(\"Unkown compute for card {card}\");\n                None\n            }\n        };\n        card_flop.map(|f| f * self.count as u64)\n    }\n\n    fn vram(&self, 
memory_fraction: f32) -> Option<usize> {\n        let output = Command::new(\"nvidia-smi\")\n            .args([\"--query-gpu=memory.total\", \"--format=csv\"])\n            .output()\n            .ok()?;\n        let output = String::from_utf8(output.stdout).ok()?;\n        let fullname = output.split('\\n').nth(1)?;\n        let mut tokens = fullname.split(' ');\n        let amount = tokens.next()?;\n        let unit = tokens.next()?;\n        if unit != \"MiB\" {\n            tracing::warn!(\"Unexpected memory unit {unit}, expected MiB\");\n            return None;\n        }\n        let amount: usize = amount.parse().ok()?;\n        let amount = amount * 2usize.pow(20);\n        let wiggle_room: f32 = env::var(\"TGI_WIGGLE_ROOM\")\n            .ok()\n            .and_then(|wiggle| wiggle.parse().ok())\n            .unwrap_or(0.95);\n        let total = amount * self.count;\n        let adjusted = ((total as f32) * memory_fraction * wiggle_room) as usize;\n        Some(adjusted)\n    }\n}\n\nimpl From<ComputeType> for OsString {\n    fn from(value: ComputeType) -> Self {\n        format!(\"{}-{}\", value.count, value.card).into()\n    }\n}\n\nfn compute_type(count: usize) -> Option<ComputeType> {\n    let output = Command::new(\"nvidia-smi\")\n        .args([\"--query-gpu=gpu_name\", \"--format=csv\"])\n        .output()\n        .ok()?;\n    let output = String::from_utf8(output.stdout).ok()?;\n    let fullname = output.split('\\n').nth(1)?;\n    let cardname = fullname.replace(' ', \"-\").to_lowercase();\n    let card = (&*cardname).into();\n    Some(ComputeType { count, card })\n}\n\nfn spawn_webserver(\n    num_shard: usize,\n    args: Args,\n    max_input_tokens: Option<usize>,\n    max_total_tokens: Option<usize>,\n    max_batch_prefill_tokens: u32,\n    shutdown: Arc<AtomicBool>,\n    shutdown_receiver: &mpsc::Receiver<()>,\n) -> Result<Child, LauncherError> {\n    // All shard started\n    // Start webserver\n    tracing::info!(\"Starting 
Webserver\");\n    let mut router_args = vec![\n        \"--max-client-batch-size\".to_string(),\n        args.max_client_batch_size.to_string(),\n        \"--max-concurrent-requests\".to_string(),\n        args.max_concurrent_requests.to_string(),\n        \"--max-best-of\".to_string(),\n        args.max_best_of.to_string(),\n        \"--max-stop-sequences\".to_string(),\n        args.max_stop_sequences.to_string(),\n        \"--max-top-n-tokens\".to_string(),\n        args.max_top_n_tokens.to_string(),\n        \"--max-batch-prefill-tokens\".to_string(),\n        max_batch_prefill_tokens.to_string(),\n        \"--waiting-served-ratio\".to_string(),\n        args.waiting_served_ratio.to_string(),\n        \"--max-waiting-tokens\".to_string(),\n        args.max_waiting_tokens.to_string(),\n        \"--validation-workers\".to_string(),\n        args.validation_workers.to_string(),\n        \"--hostname\".to_string(),\n        args.hostname.to_string(),\n        \"--port\".to_string(),\n        args.port.to_string(),\n        \"--prometheus-port\".to_string(),\n        args.prometheus_port.to_string(),\n        \"--master-shard-uds-path\".to_string(),\n        format!(\"{}-0\", args.shard_uds_path),\n        \"--tokenizer-name\".to_string(),\n        args.model_id,\n        \"--payload-limit\".to_string(),\n        args.payload_limit.to_string(),\n    ];\n    if let Some(max_input_tokens) = max_input_tokens {\n        router_args.extend_from_slice(&[\n            \"--max-input-tokens\".to_string(),\n            max_input_tokens.to_string(),\n        ]);\n    }\n    if let Some(max_total_tokens) = max_total_tokens {\n        router_args.extend_from_slice(&[\n            \"--max-total-tokens\".to_string(),\n            max_total_tokens.to_string(),\n        ]);\n    }\n\n    // Pass usage stats flags to router\n    router_args.push(\"--usage-stats\".to_string());\n    router_args.push(args.usage_stats.to_string());\n\n    // Grammar support\n    if 
args.disable_grammar_support {\n        router_args.push(\"--disable-grammar-support\".to_string());\n    }\n\n    // Tokenizer config path\n    if let Some(ref tokenizer_config_path) = args.tokenizer_config_path {\n        router_args.push(\"--tokenizer-config-path\".to_string());\n        router_args.push(tokenizer_config_path.to_string());\n    }\n\n    // Model optional max batch total tokens\n    if let Some(max_batch_total_tokens) = args.max_batch_total_tokens {\n        router_args.push(\"--max-batch-total-tokens\".to_string());\n        router_args.push(max_batch_total_tokens.to_string());\n    }\n\n    // Router optional max batch size\n    if let Some(max_batch_size) = args.max_batch_size {\n        router_args.push(\"--max-batch-size\".to_string());\n        router_args.push(max_batch_size.to_string());\n    }\n\n    // Model optional revision\n    if let Some(ref revision) = args.revision {\n        router_args.push(\"--revision\".to_string());\n        router_args.push(revision.to_string())\n    }\n\n    if args.trust_remote_code {\n        router_args.push(\"--trust-remote-code\".to_string());\n    }\n\n    if args.json_output {\n        router_args.push(\"--json-output\".to_string());\n    }\n\n    // OpenTelemetry\n    if let Some(otlp_endpoint) = args.otlp_endpoint {\n        router_args.push(\"--otlp-endpoint\".to_string());\n        router_args.push(otlp_endpoint);\n    }\n\n    // OpenTelemetry\n    let otlp_service_name = args.otlp_service_name;\n    router_args.push(\"--otlp-service-name\".to_string());\n    router_args.push(otlp_service_name);\n\n    // CORS origins\n    for origin in args.cors_allow_origin.into_iter() {\n        router_args.push(\"--cors-allow-origin\".to_string());\n        router_args.push(origin);\n    }\n\n    // API Key\n    if let Some(api_key) = args.api_key {\n        router_args.push(\"--api-key\".to_string());\n        router_args.push(api_key);\n    }\n    // Ngrok\n    if args.ngrok {\n        
router_args.push(\"--ngrok\".to_string());\n        router_args.push(\"--ngrok-authtoken\".to_string());\n        router_args.push(args.ngrok_authtoken.unwrap());\n        router_args.push(\"--ngrok-edge\".to_string());\n        router_args.push(args.ngrok_edge.unwrap());\n    }\n\n    // Copy current process env\n    let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();\n\n    // Parse Inference API token\n    if let Ok(api_token) = env::var(\"HF_API_TOKEN\") {\n        envs.push((\"HF_TOKEN\".into(), api_token.into()))\n    };\n\n    // Parse Compute type\n    if let Ok(compute_type) = env::var(\"COMPUTE_TYPE\") {\n        envs.push((\"COMPUTE_TYPE\".into(), compute_type.into()))\n    } else if let Some(compute_type) = compute_type(num_shard) {\n        envs.push((\"COMPUTE_TYPE\".into(), compute_type.into()))\n    }\n\n    let mut webserver = match Command::new(\"text-generation-router\")\n        .args(router_args)\n        .envs(envs)\n        .stdout(Stdio::piped())\n        .stderr(Stdio::piped())\n        .process_group(0)\n        .spawn()\n    {\n        Ok(p) => p,\n        Err(err) => {\n            tracing::error!(\"Failed to start webserver: {}\", err);\n            if err.kind() == io::ErrorKind::NotFound {\n                tracing::error!(\"text-generation-router not found in PATH\");\n                tracing::error!(\"Please install it with `make install-router`\")\n            } else {\n                tracing::error!(\"{}\", err);\n            }\n\n            shutdown_shards(shutdown, shutdown_receiver);\n            return Err(LauncherError::WebserverCannotStart);\n        }\n    };\n\n    // Redirect STDOUT and STDERR to the console\n    let webserver_stdout = webserver.stdout.take().unwrap();\n    let webserver_stderr = webserver.stderr.take().unwrap();\n\n    thread::spawn(move || {\n        let stdout = BufReader::new(webserver_stdout);\n        let stderr = BufReader::new(webserver_stderr);\n        for line in stdout.lines() 
{\n            println!(\"{}\", line.unwrap());\n        }\n        for line in stderr.lines() {\n            println!(\"{}\", line.unwrap());\n        }\n    });\n    Ok(webserver)\n}\n\nfn terminate(process_name: &str, mut process: Child, timeout: Duration) -> io::Result<ExitStatus> {\n    tracing::info!(\"Terminating {process_name}\");\n\n    let terminate_time = Instant::now();\n    signal::kill(Pid::from_raw(process.id() as i32), Signal::SIGTERM).unwrap();\n\n    tracing::info!(\"Waiting for {process_name} to gracefully shutdown\");\n    while terminate_time.elapsed() < timeout {\n        if let Some(status) = process.try_wait()? {\n            tracing::info!(\"{process_name} terminated\");\n            return Ok(status);\n        }\n        sleep(Duration::from_millis(100));\n    }\n    tracing::info!(\"Killing {process_name}\");\n\n    process.kill()?;\n    let exit_status = process.wait()?;\n\n    tracing::info!(\"{process_name} killed\");\n    Ok(exit_status)\n}\n\nfn main() -> Result<(), LauncherError> {\n    // Pattern match configuration\n    let args: Args = Args::parse();\n\n    let graceful_termination_timeout = args.graceful_termination_timeout;\n\n    // Filter events with LOG_LEVEL\n    let varname = \"LOG_LEVEL\";\n    let env_filter = if let Ok(log_level) = std::env::var(varname) {\n        // Override to avoid simple logs to be spammed with tokio level informations\n        let log_level = match &log_level[..] 
{\n            \"warn\" => \"text_generation_launcher=warn,text_generation_router=warn\",\n            \"info\" => \"text_generation_launcher=info,text_generation_router=info\",\n            \"debug\" => \"text_generation_launcher=debug,text_generation_router=debug\",\n            log_level => log_level,\n        };\n        EnvFilter::builder()\n            .with_default_directive(LevelFilter::INFO.into())\n            .parse_lossy(log_level)\n    } else {\n        EnvFilter::new(\"info\")\n    };\n    let max_log_level = env_filter.max_level_hint().unwrap_or(LevelFilter::INFO);\n\n    if args.json_output {\n        tracing_subscriber::fmt()\n            .with_env_filter(env_filter)\n            .json()\n            .init();\n    } else {\n        tracing_subscriber::fmt()\n            .with_env_filter(env_filter)\n            .compact()\n            .init();\n    }\n\n    if args.env {\n        let env_runtime = env_runtime::Env::new();\n        tracing::info!(\"{}\", env_runtime);\n    }\n\n    tracing::info!(\"{:#?}\", args);\n\n    let config: Option<Config> = get_config(&args.model_id, &args.revision).ok();\n    let quantize = config.as_ref().and_then(|c| c.quantize);\n    // Quantization usually means you're even more RAM constrained.\n\n    let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);\n    tracing::info!(\"Using attention {attention} - Prefix caching {prefix_caching}\");\n    std::env::set_var(\"PREFIX_CACHING\", prefix_caching);\n    std::env::set_var(\"ATTENTION\", attention);\n\n    let num_shard = find_num_shards(args.sharded, args.num_shard)?;\n    if num_shard > 1 {\n        if matches!(args.quantize, Some(Quantization::Exl2)) {\n            return Err(LauncherError::ArgumentValidation(\n                \"Sharding is currently not supported with `exl2` quantization\".into(),\n            ));\n        }\n        tracing::info!(\"Sharding model on {num_shard} processes\");\n    }\n\n    let max_input_tokens = {\n    
    match (args.max_input_tokens, args.max_input_length) {\n            (Some(max_input_tokens), Some(max_input_length)) => {\n                return Err(LauncherError::ArgumentValidation(\n                    format!(\"Both `max_input_tokens` ({max_input_tokens}) and `max_input_length` ({max_input_length}) are set. Please define only `max_input_tokens` as `max_input_length is deprecated for naming consistency.\",\n                )));\n            }\n            (Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => {\n                Some(max_input_tokens)\n            }\n            (None, None) => None,\n        }\n    };\n    let max_total_tokens = args.max_total_tokens;\n    let max_batch_prefill_tokens = {\n        match args.max_batch_prefill_tokens {\n            Some(max_batch_prefill_tokens) => max_batch_prefill_tokens,\n            None => {\n                let compute_type = compute_type(num_shard);\n                let compute_optimal = compute_optimal(config.as_ref(), compute_type.as_ref());\n                // TODO: remove this when we correctly esimate the flops for VLMs\n                // this is a short term temporary fix to enable vlms to avoid rejecting images\n                let default_optimal = match config {\n                    Some(ref config) => match config.model_type.as_deref() {\n                        Some(\"qwen2_vl\") | Some(\"qwen2_5_vl\") => 10_000,\n                        Some(\"gemma3\") => 8000,\n                        _ => 4096,\n                    },\n                    None => 4096,\n                };\n                let default = compute_optimal.unwrap_or(default_optimal);\n                let vram_maximum = vram_maximum(\n                    config.as_ref(),\n                    compute_type.as_ref(),\n                    args.cuda_memory_fraction,\n                );\n                let max_position_embeddings = config.and_then(|c| c.max_position_embeddings);\n                let value = if let 
Some(max_position_embeddings) = max_position_embeddings {\n                    default.min(max_position_embeddings)\n                } else {\n                    default\n                };\n                let value = if let Some(vram_maximum) = vram_maximum {\n                    if vram_maximum < value {\n                        tracing::warn!(\"Reducing the max batch prefill from {default} to {vram_maximum} because there is not enough VRAM to support it.\");\n                    }\n                    value.min(vram_maximum)\n                } else {\n                    value\n                };\n                tracing::info!(\"Default `max_batch_prefill_tokens` to {value}\");\n                value as u32\n            }\n        }\n    };\n\n    // Validate args\n    if let (Some(max_input_tokens), Some(max_total_tokens)) = (max_input_tokens, max_total_tokens) {\n        if max_input_tokens >= max_total_tokens {\n            return Err(LauncherError::ArgumentValidation(\n                    format!(\"`max_input_tokens`({max_input_tokens}) must be < `max_total_tokens`({max_total_tokens})\"),\n                ));\n        }\n    }\n\n    if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {\n        tracing::warn!(\"Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.\");\n    }\n    let quantize = args.quantize.or(quantize);\n    let cuda_graphs = match (&args.cuda_graphs, &quantize) {\n        (Some(cuda_graphs), _) => cuda_graphs.iter().cloned().filter(|&c| c > 0).collect(),\n        #[allow(deprecated)]\n        (None, Some(Quantization::Bitsandbytes)) => {\n            tracing::warn!(\"Bitsandbytes doesn't work with cuda graphs, deactivating them\");\n            vec![]\n        }\n        (None, Some(Quantization::Exl2)) => {\n            tracing::warn!(\"Exl2 doesn't work with cuda graphs, deactivating them\");\n            vec![]\n        }\n        _ => {\n            let 
cuda_graphs = vec![1, 2, 4, 8, 16, 32];\n            tracing::info!(\"Using default cuda graphs {cuda_graphs:?}\");\n            cuda_graphs\n        }\n    };\n\n    if args.validation_workers == 0 {\n        return Err(LauncherError::ArgumentValidation(\n            \"`validation_workers` must be > 0\".to_string(),\n        ));\n    }\n    if args.trust_remote_code {\n        tracing::warn!(\n            \"`trust_remote_code` is set. Trusting that model `{}` do not contain malicious code.\",\n            args.model_id\n        );\n    }\n\n    if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {\n        if let Some(max_total_tokens) = max_total_tokens {\n            if max_total_tokens as u32 > *max_batch_total_tokens {\n                return Err(LauncherError::ArgumentValidation(format!(\n                    \"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}\",\n                    max_total_tokens, max_batch_total_tokens\n                )));\n            }\n        }\n    }\n\n    if args.ngrok {\n        if args.ngrok_authtoken.is_none() {\n            return Err(LauncherError::ArgumentValidation(\n                \"`ngrok-authtoken` must be set when using ngrok tunneling\".to_string(),\n            ));\n        }\n\n        if args.ngrok_edge.is_none() {\n            return Err(LauncherError::ArgumentValidation(\n                \"`ngrok-edge` must be set when using ngrok tunneling\".to_string(),\n            ));\n        }\n    }\n\n    // Signal handler\n    let running = Arc::new(AtomicBool::new(true));\n    let r = running.clone();\n    ctrlc::set_handler(move || {\n        r.store(false, Ordering::SeqCst);\n    })\n    .expect(\"Error setting Ctrl-C handler\");\n\n    // Download and convert model weights\n    download_convert_model(\n        &args.model_id,\n        args.revision.as_deref(),\n        args.trust_remote_code,\n        args.huggingface_hub_cache.as_deref(),\n        
args.weights_cache_override.as_deref(),\n        running.clone(),\n        true, // if its only a lora model - we should merge the lora adapters\n    )?;\n\n    // Download and convert lora adapters if any\n    if let Some(lora_adapters) = &args.lora_adapters {\n        for adapter in lora_adapters.split(',') {\n            // skip download if a path is provided\n            if adapter.contains('=') {\n                continue;\n            }\n\n            let adapter = adapter.trim();\n\n            // check if adapter has more than 1 '@'\n            if adapter.matches('@').count() > 1 {\n                return Err(LauncherError::ArgumentValidation(format!(\n                    \"Invalid LoRA adapter format: {}\",\n                    adapter\n                )));\n            }\n\n            // capture adapter_id, path, revision in format of adapter_id=path@revision\n            // path is disabled beforehand.\n            let mut splits = adapter.split(\"@\");\n            let adapter_id = splits.next().ok_or_else(|| {\n                LauncherError::ArgumentValidation(\"Missing adapter id\".to_string())\n            })?;\n            let revision = splits.next();\n            download_convert_model(\n                adapter_id,\n                revision,\n                args.trust_remote_code,\n                args.huggingface_hub_cache.as_deref(),\n                args.weights_cache_override.as_deref(),\n                running.clone(),\n                false, // avoid merging lora adapters if using multi-lora\n            )?;\n        }\n    }\n\n    if !running.load(Ordering::SeqCst) {\n        // Launcher was asked to stop\n        return Ok(());\n    }\n\n    // Shared shutdown bool\n    let shutdown = Arc::new(AtomicBool::new(false));\n    // Shared shutdown channel\n    // When shutting down, the main thread will wait for all senders to be dropped\n    let (shutdown_sender, shutdown_receiver) = mpsc::channel();\n\n    // Shared channel to track shard 
status\n    let (status_sender, status_receiver) = mpsc::channel();\n\n    spawn_shards(\n        num_shard,\n        &args,\n        cuda_graphs,\n        max_total_tokens,\n        max_input_tokens,\n        quantize,\n        max_log_level,\n        shutdown.clone(),\n        &shutdown_receiver,\n        shutdown_sender,\n        &status_receiver,\n        status_sender,\n        running.clone(),\n        graceful_termination_timeout,\n    )?;\n\n    // We might have received a termination signal\n    if !running.load(Ordering::SeqCst) {\n        shutdown_shards(shutdown, &shutdown_receiver);\n        return Ok(());\n    }\n\n    let mut webserver = spawn_webserver(\n        num_shard,\n        args,\n        max_input_tokens,\n        max_total_tokens,\n        max_batch_prefill_tokens,\n        shutdown.clone(),\n        &shutdown_receiver,\n    )\n    .inspect_err(|_| {\n        shutdown_shards(shutdown.clone(), &shutdown_receiver);\n    })?;\n\n    // Default exit code\n    let mut exit_code = Ok(());\n\n    while running.load(Ordering::SeqCst) {\n        if let Ok(ShardStatus::Failed(rank)) = status_receiver.try_recv() {\n            tracing::error!(\"Shard {rank} crashed\");\n            exit_code = Err(LauncherError::ShardFailed);\n            break;\n        };\n\n        match webserver.try_wait().unwrap() {\n            Some(_) => {\n                tracing::error!(\"Webserver Crashed\");\n                shutdown_shards(shutdown, &shutdown_receiver);\n                return Err(LauncherError::WebserverFailed);\n            }\n            None => {\n                sleep(Duration::from_millis(100));\n            }\n        };\n    }\n\n    // Graceful termination\n    terminate(\n        \"webserver\",\n        webserver,\n        Duration::from_secs(graceful_termination_timeout),\n    )\n    .unwrap();\n    shutdown_shards(shutdown, &shutdown_receiver);\n\n    exit_code\n}\n"
  },
  {
    "path": "load_tests/Makefile",
    "content": "\nShareGPT_V3_unfiltered_cleaned_split.json:\n\twget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json\n\nprepare_share: ShareGPT_V3_unfiltered_cleaned_split.json\n\tpython filter.py\n\nprepare_orca:\n\tpython orca.py\n"
  },
  {
    "path": "load_tests/benchmarks.py",
    "content": "import argparse\nimport datetime\nimport json\nimport os\nimport traceback\nfrom typing import Dict, Tuple, List\n\nimport GPUtil\nimport docker\nfrom docker.models.containers import Container\nfrom loguru import logger\nimport pandas as pd\n\n\nclass InferenceEngineRunner:\n    def __init__(self, model: str):\n        self.model = model\n\n    def run(self, parameters: list[tuple], gpus: int = 0):\n        NotImplementedError(\"This method should be implemented by the subclass\")\n\n    def stop(self):\n        NotImplementedError(\"This method should be implemented by the subclass\")\n\n\nclass TGIDockerRunner(InferenceEngineRunner):\n    def __init__(\n        self,\n        model: str,\n        image: str = \"ghcr.io/huggingface/text-generation-inference:latest\",\n        volumes=None,\n    ):\n        super().__init__(model)\n        if volumes is None:\n            volumes = []\n        self.container = None\n        self.image = image\n        self.volumes = volumes\n\n    def run(self, parameters: list[tuple], gpus: int = 0):\n        params = f\"--model-id {self.model} --port 8080\"\n        for p in parameters:\n            params += f\" --{p[0]} {str(p[1])}\"\n        logger.info(f\"Running TGI with parameters: {params}\")\n        volumes = {}\n        for v in self.volumes:\n            volumes[v[0]] = {\"bind\": v[1], \"mode\": \"rw\"}\n        self.container = run_docker(\n            self.image,\n            params,\n            \"Connected\",\n            \"ERROR\",\n            volumes=volumes,\n            gpus=gpus,\n            ports={\"8080/tcp\": 8080},\n        )\n\n    def stop(self):\n        if self.container:\n            self.container.stop()\n\n\nclass BenchmarkRunner:\n    def __init__(\n        self,\n        image: str = \"ghcr.io/huggingface/text-generation-inference-benchmark:latest\",\n        volumes: List[Tuple[str, str]] = None,\n    ):\n        if volumes is None:\n            volumes = []\n        
self.container = None\n        self.image = image\n        self.volumes = volumes\n\n    def run(self, parameters: list[tuple], network_mode):\n        params = \"text-generation-inference-benchmark\"\n        for p in parameters:\n            params += f\" --{p[0]} {str(p[1])}\" if p[1] is not None else f\" --{p[0]}\"\n        logger.info(\n            f\"Running text-generation-inference-benchmarks with parameters: {params}\"\n        )\n        volumes = {}\n        for v in self.volumes:\n            volumes[v[0]] = {\"bind\": v[1], \"mode\": \"rw\"}\n        self.container = run_docker(\n            self.image,\n            params,\n            \"Benchmark finished\",\n            \"Fatal:\",\n            volumes=volumes,\n            extra_env={\n                \"RUST_LOG\": \"text_generation_inference_benchmark=info\",\n                \"RUST_BACKTRACE\": \"full\",\n            },\n            network_mode=network_mode,\n        )\n\n    def stop(self):\n        if self.container:\n            self.container.stop()\n\n\ndef run_docker(\n    image: str,\n    args: str,\n    success_sentinel: str,\n    error_sentinel: str,\n    ports: Dict[str, int] = None,\n    volumes=None,\n    network_mode: str = \"bridge\",\n    gpus: int = 0,\n    extra_env: Dict[str, str] = None,\n) -> Container:\n    if ports is None:\n        ports = {}\n    if volumes is None:\n        volumes = {}\n    if extra_env is None:\n        extra_env = {}\n    client = docker.from_env(timeout=300)\n    # retrieve the GPU devices from CUDA_VISIBLE_DEVICES\n    devices = [f\"{i}\" for i in range(get_num_gpus())][:gpus]\n    environment = {\"HF_TOKEN\": os.environ.get(\"HF_TOKEN\")}\n    environment.update(extra_env)\n    container = client.containers.run(\n        image,\n        args,\n        detach=True,\n        device_requests=(\n            [docker.types.DeviceRequest(device_ids=devices, capabilities=[[\"gpu\"]])]\n            if gpus > 0\n            else None\n        ),\n        
volumes=volumes,\n        shm_size=\"1g\",\n        ports=ports,\n        network_mode=network_mode,\n        environment=environment,\n    )\n    for line in container.logs(stream=True):\n        print(line.decode(\"utf-8\"), end=\"\")\n        if success_sentinel.encode(\"utf-8\") in line:\n            break\n        if error_sentinel.encode(\"utf-8\") in line:\n            container.stop()\n            raise Exception(f\"Error starting container: {line}\")\n    return container\n\n\ndef get_gpu_names() -> str:\n    gpus = GPUtil.getGPUs()\n    if len(gpus) == 0:\n        return \"\"\n    return f'{len(gpus)}x{gpus[0].name if gpus else \"No GPU available\"}'\n\n\ndef get_gpu_name() -> str:\n    gpus = GPUtil.getGPUs()\n    if len(gpus) == 0:\n        return \"\"\n    return gpus[0].name\n\n\ndef get_num_gpus() -> int:\n    return len(GPUtil.getGPUs())\n\n\ndef build_df(model: str, data_files: dict[str, str]) -> pd.DataFrame:\n    df = pd.DataFrame()\n    now = datetime.datetime.now(datetime.timezone.utc)\n    created_at = now.isoformat()  # '2024-10-02T11:53:17.026215+00:00'\n    # Load the results\n    for key, filename in data_files.items():\n        with open(filename, \"r\") as f:\n            data = json.load(f)\n            for result in data[\"results\"]:\n                entry = result\n                [config] = pd.json_normalize(result[\"config\"]).to_dict(orient=\"records\")\n                entry.update(config)\n                entry[\"engine\"] = data[\"config\"][\"meta\"][\"engine\"]\n                entry[\"tp\"] = data[\"config\"][\"meta\"][\"tp\"]\n                entry[\"version\"] = data[\"config\"][\"meta\"][\"version\"]\n                entry[\"model\"] = model\n                entry[\"created_at\"] = created_at\n                del entry[\"config\"]\n                df = pd.concat([df, pd.DataFrame(entry, index=[0])])\n    return df\n\n\ndef main(sha, results_file):\n    results_dir = \"results\"\n    # get absolute path\n    results_dir = 
os.path.join(os.path.dirname(__file__), results_dir)\n    logger.info(\"Starting benchmark\")\n    models = [\n        (\"meta-llama/Llama-3.1-8B-Instruct\", 1),\n        # ('meta-llama/Llama-3.1-70B-Instruct', 4),\n        # ('mistralai/Mixtral-8x7B-Instruct-v0.1', 2),\n    ]\n    success = True\n    for model in models:\n        tgi_runner = TGIDockerRunner(model[0])\n        # create results directory\n        model_dir = os.path.join(\n            results_dir, f'{model[0].replace(\"/\", \"_\").replace(\".\", \"_\")}'\n        )\n        os.makedirs(model_dir, exist_ok=True)\n        runner = BenchmarkRunner(\n            volumes=[(model_dir, \"/opt/text-generation-inference-benchmark/results\")]\n        )\n        try:\n            tgi_runner.run([(\"max-concurrent-requests\", 512)], gpus=model[1])\n            logger.info(f\"TGI started for model {model[0]}\")\n            parameters = [\n                (\"tokenizer-name\", model[0]),\n                (\"max-vus\", 800),\n                (\"url\", \"http://localhost:8080\"),\n                (\"duration\", \"120s\"),\n                (\"warmup\", \"30s\"),\n                (\"benchmark-kind\", \"rate\"),\n                (\n                    \"prompt-options\",\n                    \"num_tokens=200,max_tokens=220,min_tokens=180,variance=10\",\n                ),\n                (\n                    \"decode-options\",\n                    \"num_tokens=200,max_tokens=220,min_tokens=180,variance=10\",\n                ),\n                (\n                    \"extra-meta\",\n                    f'\"engine=TGI,tp={model[1]},version={sha},gpu={get_gpu_name()}\"',\n                ),\n                (\"no-console\", None),\n            ]\n            rates = [(\"rates\", f\"{r / 10.}\") for r in list(range(8, 248, 8))]\n            parameters.extend(rates)\n            runner.run(parameters, f\"container:{tgi_runner.container.id}\")\n        except Exception as e:\n            logger.error(f\"Error 
running benchmark for model {model[0]}: {e}\")\n            # print the stack trace\n            print(traceback.format_exc())\n            success = False\n        finally:\n            tgi_runner.stop()\n            runner.stop()\n    if not success:\n        logger.error(\"Some benchmarks failed\")\n        exit(1)\n\n    df = pd.DataFrame()\n    # list recursively directories\n    directories = [\n        f\"{results_dir}/{d}\"\n        for d in os.listdir(results_dir)\n        if os.path.isdir(f\"{results_dir}/{d}\")\n    ]\n    logger.info(f\"Found result directories: {directories}\")\n    for directory in directories:\n        data_files = {}\n        for filename in os.listdir(directory):\n            if filename.endswith(\".json\"):\n                data_files[filename.split(\".\")[-2]] = f\"{directory}/{filename}\"\n        logger.info(f\"Processing directory {directory}\")\n        df = pd.concat([df, build_df(directory.split(\"/\")[-1], data_files)])\n    df[\"device\"] = get_gpu_name()\n    df[\"error_rate\"] = (\n        df[\"failed_requests\"]\n        / (df[\"failed_requests\"] + df[\"successful_requests\"])\n        * 100.0\n    )\n    df.to_parquet(results_file)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--sha\", help=\"SHA of the commit to add to the results\", required=True\n    )\n    parser.add_argument(\n        \"--results-file\",\n        help=\"The file where to store the results, can be a local file or a s3 path\",\n    )\n    args = parser.parse_args()\n    if args.results_file is None:\n        results_file = f\"{args.sha}.parquet\"\n    else:\n        results_file = args.results_file\n\n    main(args.sha, results_file)\n"
  },
  {
    "path": "load_tests/common.js",
    "content": "import { check } from 'k6';\nimport { scenario } from 'k6/execution';\nimport http from 'k6/http';\nimport { Trend, Counter } from 'k6/metrics';\n\nconst host = __ENV.HOST;\nconst model_id = __ENV.MODEL_ID;\nconst timePerToken = new Trend('time_per_token', true);\nconst tokens = new Counter('tokens');\nconst new_tokens = new Counter('new_tokens');\nconst input_tokens = new Counter('input_tokens');\nconst max_new_tokens = 50;\n\n// const shareGPT = JSON.parse(open(\"ShareGPT_V3_unfiltered_cleaned_split.json\"))\nconst shareGPT = JSON.parse(open(\"small.json\"))\n\n\nexport function get_options() {\n    return {\n        thresholds: {\n            http_req_failed: ['rate==0'],\n            // time_per_token: [{\n            //     threshold: `p(50)<${5 * reference_latency_ms}`,\n            //     abortOnFail: true,\n            //     delayAbortEval: '10s'\n            // }],\n        },\n        scenarios: {\n            // single_user: {\n            //     executor: 'constant-arrival-rate',\n            //     duration: '60s',\n            //     preAllocatedVUs: 1,\n            //     rate: 20,\n            //     timeUnit: '1s',\n            // },\n            // load_test: {\n            //     executor: 'constant-arrival-rate',\n            //     duration: '60s',\n            //     preAllocatedVUs: 100,\n            //     rate: 1,\n            //     timeUnit: '1s',\n            // },\n            // breakpoint: {\n            //     executor: 'ramping-arrival-rate', //Assure load increase if the system slows\n            //     preAllocatedVUs: 300,\n            //     stages: [\n            //         { duration: '60s', target: 100 }, // just slowly ramp-up to a HUGE load\n            //     ],\n            // },\n            throughput: {\n                executor: 'shared-iterations',\n                vus: 100,\n                iterations: 200,\n                maxDuration: '40s',\n            },\n        },\n    };\n}\n\nfunction 
generate_payload(gpt, max_new_tokens) {\n    const input = gpt[\"conversations\"][0][\"value\"];\n    return { \"messages\": [{ \"role\": \"user\", \"content\": input }], \"temperature\": 0, \"model\": `${model_id}`, \"max_tokens\": max_new_tokens }\n}\n\nexport const options = get_options();\n\nexport default function run() {\n    const headers = { 'Content-Type': 'application/json' };\n    const query = shareGPT[scenario.iterationInTest % shareGPT.length];\n    const payload = JSON.stringify(generate_payload(query, max_new_tokens));\n    const res = http.post(`http://${host}/v1/chat/completions`, payload, {\n        headers,\n    });\n    if (res.status >= 400 && res.status < 500) {\n        return;\n    }\n\n\n    check(res, {\n        'Post status is 200': (res) => res.status === 200,\n    });\n    const duration = res.timings.duration;\n\n    if (res.status === 200) {\n        const body = res.json();\n        const completion_tokens = body.usage.completion_tokens;\n        const latency_ms_per_token = duration / completion_tokens;\n        timePerToken.add(latency_ms_per_token);\n        const prompt_tokens = body.usage.prompt_tokens;\n        input_tokens.add(prompt_tokens);\n        new_tokens.add(completion_tokens);\n        tokens.add(completion_tokens + prompt_tokens);\n    }\n}\n"
  },
  {
    "path": "load_tests/filter.py",
    "content": "import json\n\n\ndef main():\n    with open(\"./ShareGPT_V3_unfiltered_cleaned_split.json\", \"r\") as f:\n        data = json.load(f)\n\n    # Select only the first 2k conversations that start with a human.\n    max = 2000\n    conversations = []\n    for conversation in data:\n        conv = conversation.get(\"conversations\")\n        if conv and conv[0][\"from\"] == \"human\":\n            # Trim the rest of the output\n            conversation[\"conversations\"] = conversation[\"conversations\"][:1]\n            conversations.append(conversation)\n\n            if len(conversation) >= max:\n                break\n\n    with open(\"./small.json\", \"w\") as f:\n        data = json.dump(conversations, f, indent=4)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "load_tests/long.js",
    "content": "import { check } from 'k6';\nimport { scenario } from 'k6/execution';\nimport http from 'k6/http';\nimport { Trend, Counter } from 'k6/metrics';\n\nconst host = __ENV.HOST;\nconst model_id = __ENV.MODEL_ID;\nconst timePerToken = new Trend('time_per_token', true);\nconst tokens = new Counter('tokens');\nconst new_tokens = new Counter('new_tokens');\nconst input_tokens = new Counter('input_tokens');\nconst max_new_tokens = 50;\n\n// const shareGPT = JSON.parse(open(\"ShareGPT_V3_unfiltered_cleaned_split.json\"))\nconst shareGPT = JSON.parse(open(\"long.json\"))\n\n\nexport function get_options() {\n    return {\n        thresholds: {\n            http_req_failed: ['rate==0'],\n            // time_per_token: [{\n            //     threshold: `p(50)<${5 * reference_latency_ms}`,\n            //     abortOnFail: true,\n            //     delayAbortEval: '10s'\n            // }],\n        },\n        scenarios: {\n            // single_user: {\n            //     executor: 'constant-arrival-rate',\n            //     duration: '60s',\n            //     preAllocatedVUs: 1,\n            //     rate: 20,\n            //     timeUnit: '1s',\n            // },\n            // load_test: {\n            //     executor: 'constant-arrival-rate',\n            //     duration: '60s',\n            //     preAllocatedVUs: 100,\n            //     rate: 1,\n            //     timeUnit: '1s',\n            // },\n            // breakpoint: {\n            //     executor: 'ramping-arrival-rate', //Assure load increase if the system slows\n            //     preAllocatedVUs: 300,\n            //     stages: [\n            //         { duration: '60s', target: 100 }, // just slowly ramp-up to a HUGE load\n            //     ],\n            // },\n            throughput: {\n                executor: 'shared-iterations',\n                vus: 10,\n                iterations: 10,\n                maxDuration: '120s',\n            },\n        },\n    };\n}\n\nfunction 
generate_payload(gpt, max_new_tokens) {\n    const input = gpt[\"conversations\"][0][\"value\"];\n    return { \"messages\": [{ \"role\": \"user\", \"content\": input }], \"temperature\": 0, \"model\": `${model_id}`, \"max_tokens\": max_new_tokens }\n}\n\nexport const options = get_options();\n\nexport default function run() {\n    const headers = { 'Content-Type': 'application/json' };\n    const query = shareGPT[scenario.iterationInTest % shareGPT.length];\n    const payload = JSON.stringify(generate_payload(query, max_new_tokens));\n    const res = http.post(`http://${host}/v1/chat/completions`, payload, {\n        headers,\n    });\n    if (res.status >= 400 && res.status < 500) {\n        return;\n    }\n\n\n    check(res, {\n        'Post status is 200': (res) => res.status === 200,\n    });\n    const duration = res.timings.duration;\n\n    if (res.status === 200) {\n        const body = res.json();\n        const completion_tokens = body.usage.completion_tokens;\n        const latency_ms_per_token = duration / completion_tokens;\n        timePerToken.add(latency_ms_per_token);\n        const prompt_tokens = body.usage.prompt_tokens;\n        input_tokens.add(prompt_tokens);\n        new_tokens.add(completion_tokens);\n        tokens.add(completion_tokens + prompt_tokens);\n    }\n}\n"
  },
  {
    "path": "load_tests/long.py",
    "content": "import datasets\nimport json\n\n\ndataset = datasets.load_dataset(\"ccdv/govreport-summarization\")\nmax_new_tokens = 50\n\n\nconversations = []\n\nfor i, item in enumerate(dataset[\"test\"]):\n    report = item[\"report\"]\n\n    messages = [{\"from\": \"human\", \"value\": f\"Summarize this report: ```{report}```\"}]\n\n    conversations.append({\"conversations\": messages})\n\nwith open(\"long.json\", \"w\") as f:\n    json.dump(conversations, f, indent=4)\n"
  },
  {
    "path": "load_tests/long_prompt2.py",
    "content": "# https://www.gutenberg.org/cache/epub/103/pg103.txt\nfrom openai import OpenAI\nimport os\nimport requests\n\nif not os.path.exists(\"pg103.txt\"):\n    response = requests.get(\"https://www.gutenberg.org/cache/epub/103/pg103.txt\")\n    with open(\"pg103.txt\", \"w\") as f:\n        f.write(response.text)\n\n\nlength = 130000\nwith open(\"pg103.txt\", \"r\") as f:\n    data = f.read()\n\nmessages = [{\"role\": \"user\", \"content\": data[: length * 4]}]\n\nclient = OpenAI(base_url=\"http://localhost:8000/v1\", api_key=\"w\")\n\ncompletion = client.chat.completions.create(\n    model=\"meta-llama/Llama-3.1-8B-Instruct\", messages=messages, max_tokens=2\n)\n"
  },
  {
    "path": "load_tests/orca.py",
    "content": "import json\nimport datasets\nimport tqdm\n\n\ndef main():\n    dataset = datasets.load_dataset(\"Open-Orca/OpenOrca\", split=\"train\")\n    # Select only the first 2k conversations that start with a human.\n    max = min(2000, len(dataset))\n    conversations = []\n    for item in tqdm.tqdm(dataset, total=max):\n        conversation = {\n            \"conversations\": [\n                {\"from\": \"human\", \"value\": item[\"question\"]},\n            ],\n            \"id\": item[\"id\"],\n        }\n        conversations.append(conversation)\n        if len(conversations) >= max:\n            break\n\n    with open(\"./small.json\", \"w\") as f:\n        json.dump(conversations, f, indent=4)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "load_tests/pyproject.toml",
    "content": "[tool.poetry]\nname = \"text-generation-inference-benchmarks\"\nversion = \"0.1.0\"\ndescription = \"\"\nauthors = [\"Hugo Larcher <hugo.larcher@huggingface.co>\"]\nreadme = \"README.md\"\n\n[tool.poetry.dependencies]\npython = \"^3.11\"\ndocker = \"^7.1.0\"\nloguru = \"^0.7.2\"\npsutil = \"^6.0.0\"\ngputil = \"^1.4.0\"\npandas = \"^2.2.3\"\npyarrow = \"^17.0.0\"\n\n[build-system]\nrequires = [\"poetry-core\"]\nbuild-backend = \"poetry.core.masonry.api\"\n"
  },
  {
    "path": "nix/client.nix",
    "content": "{\n  buildPythonPackage,\n  poetry-core,\n  aiohttp,\n  huggingface-hub,\n  pydantic,\n}:\n\nbuildPythonPackage {\n  name = \"text-generation\";\n\n  src = ../clients/python;\n\n  pyproject = true;\n\n  build-system = [ poetry-core ];\n\n  dependencies = [\n    aiohttp\n    huggingface-hub\n    pydantic\n  ];\n}\n"
  },
  {
    "path": "nix/crate-overrides.nix",
    "content": "{ pkgs, nix-filter }:\n\nlet\n  filter = nix-filter.lib;\nin\nwith pkgs;\ndefaultCrateOverrides\n// {\n  aws-lc-rs = attrs: {\n    # aws-lc-rs does its own custom parsing of Cargo environment\n    # variables like DEP_.*_INCLUDE. However buildRustCrate does\n    # not use the version number, so the parsing fails.\n    postPatch = ''\n      substituteInPlace build.rs \\\n        --replace-fail \\\n        \"assert!(!selected.is_empty()\" \\\n        \"// assert!(!selected.is_empty()\"\n    '';\n  };\n  rav1e = attrs: { env.CARGO_ENCODED_RUSTFLAGS = \"-C target-feature=-crt-static\"; };\n\n  grpc-metadata = attrs: {\n    src = filter {\n      root = ../backends/grpc-metadata;\n      include = with filter; [\n        isDirectory\n        (matchExt \"rs\")\n      ];\n    };\n  };\n  pyo3-build-config = attrs: {\n    buildInputs = [ python3 ];\n  };\n  text-generation-benchmark = attrs: {\n    src = filter {\n      root = ../benchmark;\n      include = with filter; [\n        isDirectory\n        (matchExt \"rs\")\n      ];\n    };\n  };\n  text-generation-client = attrs: {\n    src = filter {\n      root = ../.;\n      include = with filter; [\n        isDirectory\n        (and (inDirectory \"backends/client\") (matchExt \"rs\"))\n        (and (inDirectory \"proto\") (matchExt \"proto\"))\n      ];\n    };\n    postPatch = \"cd backends/client\";\n    buildInputs = [ protobuf ];\n  };\n  text-generation-launcher = attrs: {\n    src = filter {\n      root = ../launcher;\n      include = with filter; [\n        isDirectory\n        (matchExt \"rs\")\n      ];\n    };\n  };\n  text-generation-router = attrs: {\n    src = filter {\n      root = ../router;\n      include = with filter; [\n        isDirectory\n        (matchExt \"rs\")\n      ];\n    };\n  };\n  text-generation-router-v3 = attrs: {\n    # We need to do the src/source root dance so that the build\n    # has access to the protobuf file.\n    src = filter {\n      root = ../.;\n      include = 
with filter; [\n        isDirectory\n        (and (inDirectory \"backends/v3\") (matchExt \"rs\"))\n        (and (inDirectory \"proto\") (matchExt \"proto\"))\n      ];\n    };\n    postPatch = \"cd backends/v3\";\n    buildInputs = [ protobuf ];\n  };\n}\n"
  },
  {
    "path": "nix/docker.nix",
    "content": "{\n  stdenv,\n  dockerTools,\n  cacert,\n  text-generation-inference,\n  stream ? false,\n}:\n\nlet\n  build = if stream then dockerTools.streamLayeredImage else dockerTools.buildLayeredImage;\nin\nbuild {\n  name = \"tgi-docker\";\n  tag = \"latest\";\n  compressor = \"zstd\";\n  config = {\n    EntryPoint = [ \"${text-generation-inference}/bin/text-generation-inference\" ];\n    Env = [\n      \"HF_HOME=/data\"\n      \"PORT=80\"\n      # The CUDA container toolkit will mount the driver shim into the\n      # container. We just have to ensure that the dynamic loader finds\n      # the libraries.\n      \"LD_LIBRARY_PATH=/usr/lib64\"\n    ];\n\n  };\n  extraCommands = ''\n    mkdir -p tmp\n    chmod -R 1777 tmp\n  '';\n  contents = [\n    cacert\n    stdenv.cc\n  ];\n}\n"
  },
  {
    "path": "nix/impure-shell.nix",
    "content": "{\n  lib,\n  mkShell,\n  black,\n  cmake,\n  isort,\n  ninja,\n  which,\n  cudaPackages,\n  openssl,\n  pkg-config,\n  poetry,\n  protobuf,\n  python3,\n  pyright,\n  redocly,\n  ruff,\n  rust-bin,\n  server,\n\n  # Enable dependencies for building CUDA packages. Useful for e.g.\n  # developing marlin/moe-kernels in-place.\n  withCuda ? false,\n}:\n\nmkShell {\n  nativeBuildInputs =\n    [\n      black\n      isort\n      pkg-config\n      poetry\n      (rust-bin.stable.latest.default.override {\n        extensions = [\n          \"rust-analyzer\"\n          \"rust-src\"\n        ];\n      })\n      protobuf\n      pyright\n      redocly\n      ruff\n    ]\n    ++ (lib.optionals withCuda [\n      cmake\n      ninja\n      which\n\n      # For most Torch-based extensions, setting CUDA_HOME is enough, but\n      # some custom CMake builds (e.g. vLLM) also need to have nvcc in PATH.\n      cudaPackages.cuda_nvcc\n    ]);\n  buildInputs =\n    [\n      openssl.dev\n    ]\n    ++ (with python3.pkgs; [\n      venvShellHook\n      docker\n      pip\n      ipdb\n      click\n      openai\n      pytest\n      pytest-asyncio\n      syrupy\n    ])\n    ++ (lib.optionals withCuda (\n      with cudaPackages;\n      [\n        cuda_cccl\n        cuda_cudart\n        cuda_nvrtc\n        cuda_nvtx\n        cuda_profiler_api\n        cudnn\n        libcublas\n        libcusolver\n        libcusparse\n      ]\n    ));\n\n  inputsFrom = [ server ];\n\n  env = lib.optionalAttrs withCuda {\n    CUDA_HOME = \"${lib.getDev cudaPackages.cuda_nvcc}\";\n    TORCH_CUDA_ARCH_LIST = lib.concatStringsSep \";\" python3.pkgs.torch.cudaCapabilities;\n  };\n\n  venvDir = \"./.venv\";\n\n  postVenvCreation = ''\n    unset SOURCE_DATE_EPOCH\n    ( cd server ; python -m pip install --no-build-isolation --no-dependencies -e . )\n    ( cd clients/python ; python -m pip install --no-dependencies -e . 
)\n  '';\n\n  postShellHook =\n    ''\n      unset SOURCE_DATE_EPOCH\n      export PATH=${cudaPackages.backendStdenv.cc}/bin:$PATH:~/.cargo/bin\n    ''\n    # At various points in time, the latest gcc supported by CUDA differs\n    # from the default version in nixpkgs. A lot of the dependencies in\n    # the impure environment pull in the default gcc from nixpkgs, so we\n    # end up with the CUDA-supported gcc and the nixpkgs default gcc in\n    # the path. To ensure that we can build CUDA kernels, put the CUDA\n    # first in the path. It's a hack, but it works.\n    + lib.optionalString withCuda ''\n      export PATH=${cudaPackages.backendStdenv.cc}/bin:$PATH\n    '';\n}\n"
  },
  {
    "path": "nix/overlay.nix",
    "content": "final: prev: {\n  # You can use this overlay to temporarily override packages for\n  # development. For permanent overrides, it's better to do this in\n  # our package flake:\n  #\n  # https://github.com/huggingface/text-generation-inference-nix\n  #\n  # Note that overriding packages that are in the transitive closure\n  # of many other packages (e.g. transformers) will require a large\n  # rebuild.\n\n  pythonPackagesExtensions = prev.pythonPackagesExtensions ++ [\n    (\n      python-self: python-super: with python-self; {\n        # Python package override example:\n        #transformers = python-super.transformers.overrideAttrs (\n        #  _: _: {\n        #    src = final.fetchFromGitHub {\n        #      owner = \"huggingface\";\n        #      repo = \"transformers\";\n        #      rev = \"v4.51.0\";\n        #      hash = \"sha256-dnVpc6fm1SYGcx7FegpwVVxUY6XRlsxLs5WOxYv11y8=\";\n        #    };\n        #  }\n        #);\n        #huggingface-hub = python-super.huggingface-hub.overrideAttrs (\n        #  _: _: {\n        #    src = final.fetchFromGitHub {\n        #      owner = \"huggingface\";\n        #      repo = \"huggingface_hub\";\n        #      rev = \"v0.30.0\";\n        #      hash = \"sha256-sz+n1uoWrSQPqJFiG/qCT6b4r08kD9MsoPZXbfWNB2o=\";\n        #    };\n        #  }\n        #);\n      }\n    )\n  ];\n\n  # Non-python package override example:\n  #\n  # ripgrep = prev.ripgrep.overrideAttrs (\n  #    _: _: {\n  #      src = final.fetchFromGitHub {\n  #      owner = \"BurntSushi\";\n  #      repo = \"ripgrep\";\n  #      rev = \"79cbe89deb1151e703f4d91b19af9cdcc128b765\";\n  #      hash = \"sha256-JPTM2KNmGMb+/jOfK3X7OM1wnN+3TU35SJOIcqmp3mg=\";\n  #   };\n  # });\n}\n"
  },
  {
    "path": "nix/server.nix",
    "content": "{\n  nix-filter,\n  buildPythonPackage,\n  poetry-core,\n  mypy-protobuf,\n  awq-inference-engine,\n  causal-conv1d,\n  compressed-tensors,\n  einops,\n  exllamav2,\n  flashinfer,\n  flash-attn,\n  flash-attn-layer-norm,\n  flash-attn-v1,\n  grpc-interceptor,\n  grpcio-reflection,\n  grpcio-status,\n  grpcio-tools,\n  hf-transfer,\n  hf-xet,\n  kernels,\n  loguru,\n  mamba-ssm,\n  moe,\n  opentelemetry-api,\n  opentelemetry-exporter-otlp,\n  opentelemetry-instrumentation-grpc,\n  opentelemetry-semantic-conventions,\n  outlines,\n  paged-attention,\n  peft,\n  pillow,\n  prometheus-client,\n  punica-sgmv,\n  py-cpuinfo,\n  pydantic,\n  quantization,\n  quantization-eetq,\n  rotary,\n  safetensors,\n  tokenizers,\n  torch,\n  sentencepiece,\n  transformers,\n  typer,\n}:\n\nlet\n  filter = nix-filter.lib;\nin\nbuildPythonPackage {\n  name = \"text-generation-server\";\n\n  src = filter {\n    root = ../.;\n    include = with filter; [\n      isDirectory\n      (and (inDirectory \"server\") (or_ (matchExt \"py\") (matchExt \"pyi\")))\n      \"server/pyproject.toml\"\n      (and (inDirectory \"proto/v3\") (matchExt \"proto\"))\n    ];\n  };\n\n  pyproject = true;\n\n  build-system = [ poetry-core ];\n\n  nativeBuildInputs = [ mypy-protobuf ];\n\n  pythonRelaxDeps = [\n    \"einops\"\n    \"huggingface-hub\"\n    \"loguru\"\n    \"opentelemetry-instrumentation-grpc\"\n    \"pillow\"\n    \"sentencepiece\"\n    \"typer\"\n  ];\n\n  pythonRemoveDeps = [ \"scipy\" ];\n\n  dependencies = [\n    awq-inference-engine\n    causal-conv1d\n    compressed-tensors\n    einops\n    exllamav2\n    flashinfer\n    flash-attn\n    flash-attn-layer-norm\n    grpc-interceptor\n    grpcio-reflection\n    grpcio-status\n    grpcio-tools\n    hf-transfer\n    hf-xet\n    kernels\n    loguru\n    mamba-ssm\n    moe\n    opentelemetry-api\n    opentelemetry-exporter-otlp\n    opentelemetry-instrumentation-grpc\n    opentelemetry-semantic-conventions\n    outlines\n    
paged-attention\n    peft\n    pillow\n    prometheus-client\n    punica-sgmv\n    py-cpuinfo\n    pydantic\n    quantization\n    quantization-eetq\n    rotary\n    safetensors\n    sentencepiece\n    tokenizers\n    transformers\n    typer\n  ];\n\n  prePatch = ''\n    python -m grpc_tools.protoc -Iproto/v3 --python_out=server/text_generation_server/pb \\\n           --grpc_python_out=server/text_generation_server/pb --mypy_out=server/text_generation_server/pb proto/v3/generate.proto\n    find server/text_generation_server/pb/ -type f -name \"*.py\" -print0 -exec sed -i -e 's/^\\(import.*pb2\\)/from . \\1/g' {} \\;\n    touch server/text_generation_server/pb/__init__.py\n    cd server\n  '';\n}\n"
  },
  {
    "path": "proto/generate.proto",
    "content": "syntax = \"proto3\";\n\npackage generate.v2;\n\nservice TextGenerationService {\n    /// Model Info\n    rpc Info (InfoRequest) returns (InfoResponse) {}\n    /// Service discovery\n    rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {}\n    /// Empties batch cache\n    rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);\n    /// Remove requests from a cached batch\n    rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse);\n    /// Warmup the model and compute max cache size\n    rpc Warmup (WarmupRequest) returns (WarmupResponse);\n    /// Prefill batch and decode first token\n    rpc Prefill (PrefillRequest) returns (PrefillResponse);\n    /// Decode token for a list of prefilled batches\n    rpc Decode (DecodeRequest) returns (DecodeResponse);\n    /// Health check\n    rpc Health (HealthRequest) returns (HealthResponse);\n}\n\nmessage HealthRequest {}\nmessage HealthResponse {}\n\n/// Empty request\nmessage InfoRequest {}\n\nmessage InfoResponse {\n    bool requires_padding = 1;\n    string dtype = 2;\n    string device_type = 3;\n    optional uint32 window_size = 4;\n    uint32 speculate = 5;\n}\n\n/// Empty request\nmessage ServiceDiscoveryRequest {}\n\nmessage ServiceDiscoveryResponse {\n    /// Other shards urls\n    repeated string urls = 1;\n}\n\nmessage ClearCacheRequest {\n    /// Optional batch id\n    optional uint64 id = 1;\n}\n\n/// Empty response\nmessage ClearCacheResponse {}\n\nenum GrammarType {\n    GRAMMAR_TYPE_NONE = 0;\n    GRAMMAR_TYPE_JSON = 1;\n    GRAMMAR_TYPE_REGEX = 2;\n}\n\nmessage NextTokenChooserParameters {\n    /// exponential scaling output probability distribution\n    float temperature = 1;\n    /// restricting to the k highest probability elements\n    uint32 top_k = 2;\n    /// restricting to top tokens summing to prob_cut_off <= prob_cut_off\n    float top_p = 3;\n    /// restricting to top tokens summing to prob_cut_off <= prob_cut_off\n    float 
typical_p = 4;\n    /// apply sampling on the logits\n    bool do_sample = 5;\n    /// random seed for sampling\n    uint64 seed = 6;\n    /// repetition penalty\n    float repetition_penalty = 7;\n    /// frequency penalty\n    float frequency_penalty = 9;\n    /// token watermarking using \"A Watermark for Large Language Models\"\n    bool watermark = 8;\n    /// grammar (applied if not empty)\n    string grammar = 10;\n    /// grammar type\n    GrammarType grammar_type = 11;\n}\n\nmessage StoppingCriteriaParameters {\n    /// Maximum number of generated tokens\n    uint32 max_new_tokens = 1;\n    /// Optional stopping sequences\n    repeated string stop_sequences = 2;\n    /// Ignore end of sequence token\n    /// used for benchmarking\n    bool ignore_eos_token = 3;\n}\n\nmessage Request {\n    /// Request ID\n    uint64 id = 1;\n    /// The generation context\n    string inputs = 2;\n    /// Context truncation\n    uint32 truncate = 3;\n    /// Next Token Chooser Parameters\n    NextTokenChooserParameters parameters = 4;\n    /// Stopping Criteria Parameters\n    StoppingCriteriaParameters stopping_parameters = 5;\n    /// Return prefill logprobs\n    bool prefill_logprobs = 6;\n    /// Return most likely n tokens\n    uint32 top_n_tokens = 7;\n}\n\nmessage Batch {\n    /// Batch ID\n    uint64 id = 1;\n    /// Individual requests\n    repeated Request requests = 2;\n    /// Batch size (==len(requests))\n    uint32 size = 3;\n    /// Maximum number of tokens this batch will grow to\n    uint32 max_tokens = 4;\n}\n\nmessage CachedBatch {\n    /// Batch ID\n    uint64 id = 1;\n    /// Individual requests ids\n    repeated uint64 request_ids = 2;\n    /// Batch size (==len(requests))\n    uint32 size = 3;\n    /// Maximum number of tokens this batch will grow to\n    uint32 max_tokens = 4;\n}\n\nenum FinishReason {\n    FINISH_REASON_LENGTH = 0;\n    FINISH_REASON_EOS_TOKEN = 1;\n    FINISH_REASON_STOP_SEQUENCE = 2;\n}\n\nmessage GeneratedText {\n    /// Output\n 
   string text = 1;\n    /// Number of generated tokens\n    uint32 generated_tokens = 2;\n    /// Finish reason\n    FinishReason finish_reason = 3;\n    /// Seed\n    optional uint64 seed = 4;\n}\n\nmessage Tokens {\n    /// Token IDs\n    repeated uint32 ids = 1;\n    /// Logprobs\n    repeated float logprobs = 2;\n    /// tokens\n    repeated string texts = 3;\n    /// special\n    repeated bool is_special = 4;\n}\n\nmessage Generation {\n    /// Request ID\n    uint64 request_id = 1;\n    /// Prefill tokens (optional)\n    Tokens prefill_tokens = 2;\n    Tokens tokens = 3;\n    /// Complete generated text\n    optional GeneratedText generated_text = 4;\n    /// Top tokens\n    repeated Tokens top_tokens = 5;\n}\n\nmessage FilterBatchRequest {\n    /// Batch ID\n    uint64 batch_id = 1;\n    /// Requests to keep\n    repeated uint64 request_ids = 2;\n}\n\nmessage FilterBatchResponse {\n    /// Filtered Batch (cached)\n    CachedBatch batch = 1;\n}\n\n\nmessage PrefillRequest {\n    /// Batch\n    Batch batch = 1;\n}\n\nmessage PrefillResponse {\n    /// Generation\n    repeated Generation generations = 1;\n    /// Next batch (cached)\n    optional CachedBatch batch = 2;\n    /// Forward elapsed time in nanoseconds\n    uint64 forward_ns = 3;\n    /// Decode elapsed time in nanoseconds\n    uint64 decode_ns = 4;\n    /// Total elapsed time in nanoseconds\n    uint64 total_ns = 5;\n}\n\nmessage DecodeRequest {\n    /// Cached batches\n    repeated CachedBatch batches = 1;\n}\n\nmessage DecodeResponse {\n    /// Decodes\n    repeated Generation generations = 1;\n    /// Next batch (cached)\n    optional CachedBatch batch = 2;\n    /// Forward elapsed time in nanoseconds\n    uint64 forward_ns = 3;\n    /// Decode elapsed time in nanoseconds\n    uint64 decode_ns = 4;\n    /// Total elapsed time in nanoseconds\n    uint64 total_ns = 5;\n    /// Concatenate elapsed time in nanoseconds\n    optional uint64 concat_ns = 6;\n}\n\nmessage WarmupRequest {\n    /// Batch 
to warmup on\n    Batch batch = 1;\n    uint32 max_input_length = 2;\n    uint32 max_prefill_tokens = 3;\n    uint32 max_total_tokens = 4;\n}\n\nmessage WarmupResponse {\n    /// Maximum number of tokens supported by the model\n    optional uint32 max_supported_total_tokens = 1;\n}\n"
  },
  {
    "path": "proto/v3/generate.proto",
    "content": "syntax = \"proto3\";\n\npackage generate.v3;\n\nservice TextGenerationService {\n  /// Model Info\n  rpc Info(InfoRequest) returns (InfoResponse) {}\n  /// Service discovery\n  rpc ServiceDiscovery(ServiceDiscoveryRequest)\n      returns (ServiceDiscoveryResponse) {}\n  /// Empties batch cache\n  rpc ClearCache(ClearCacheRequest) returns (ClearCacheResponse);\n  /// Remove requests from a cached batch\n  rpc FilterBatch(FilterBatchRequest) returns (FilterBatchResponse);\n  /// Warmup the model and compute max cache size\n  rpc Warmup(WarmupRequest) returns (WarmupResponse);\n  /// Prefill batch and decode first token\n  rpc Prefill(PrefillRequest) returns (PrefillResponse);\n  /// Decode token for a list of prefilled batches\n  rpc Decode(DecodeRequest) returns (DecodeResponse);\n  /// Health check\n  rpc Health(HealthRequest) returns (HealthResponse);\n}\n\nmessage HealthRequest {}\nmessage HealthResponse {}\n\n/// Empty request\nmessage InfoRequest {}\n\nmessage InfoResponse {\n  bool requires_padding = 1;\n  string dtype = 2;\n  string device_type = 3;\n  optional uint32 window_size = 4;\n  uint32 speculate = 5;\n  bool support_chunking = 6;\n  bool use_prefix_caching = 7;\n  string attention_impl = 8;\n  uint32 block_size = 9;\n}\n\n/// Empty request\nmessage ServiceDiscoveryRequest {}\n\nmessage ServiceDiscoveryResponse {\n  /// Other shards urls\n  repeated string urls = 1;\n}\n\nmessage ClearCacheRequest {\n  /// Optional batch id\n  optional uint64 id = 1;\n}\n\n/// Empty response\nmessage ClearCacheResponse {}\n\nmessage Image {\n  /// Binary image data.\n  bytes data = 1;\n\n  /// Image MIME type.\n  string mimetype = 2;\n}\n\nmessage InputChunk {\n  oneof chunk {\n    /// Plain text data\n    string text = 1;\n    /// Image data\n    Image image = 2;\n  }\n}\n\nmessage Input { repeated InputChunk chunks = 1; }\n\nenum GrammarType {\n  GRAMMAR_TYPE_NONE = 0;\n  GRAMMAR_TYPE_JSON = 1;\n  GRAMMAR_TYPE_REGEX = 2;\n}\n\nmessage 
NextTokenChooserParameters {\n  /// exponential scaling output probability distribution\n  float temperature = 1;\n  /// restricting to the k highest probability elements\n  uint32 top_k = 2;\n  /// restricting to top tokens summing to prob_cut_off <= prob_cut_off\n  float top_p = 3;\n  /// restricting to top tokens summing to prob_cut_off <= prob_cut_off\n  float typical_p = 4;\n  /// apply sampling on the logits\n  bool do_sample = 5;\n  /// random seed for sampling\n  uint64 seed = 6;\n  /// repetition penalty\n  float repetition_penalty = 7;\n  /// frequency penalty\n  float frequency_penalty = 9;\n  /// token watermarking using \"A Watermark for Large Language Models\"\n  bool watermark = 8;\n  /// grammar (applied if not empty)\n  string grammar = 10;\n  /// grammar type\n  GrammarType grammar_type = 11;\n}\n\nmessage StoppingCriteriaParameters {\n  /// Maximum number of generated tokens\n  uint32 max_new_tokens = 1;\n  /// Optional stopping sequences\n  repeated string stop_sequences = 2;\n  /// Ignore end of sequence token\n  /// used for benchmarking\n  bool ignore_eos_token = 3;\n}\n\nmessage Request {\n  /// Request ID\n  uint64 id = 1;\n  /// The generation context as chunks\n  Input input_chunks = 8;\n  /// The generation context, stringified input_chunks\n  string inputs = 2;\n  /// Context truncation\n  uint32 truncate = 3;\n  /// Next Token Chooser Parameters\n  NextTokenChooserParameters parameters = 4;\n  /// Stopping Criteria Parameters\n  StoppingCriteriaParameters stopping_parameters = 5;\n  /// Return prefill logprobs\n  bool prefill_logprobs = 6;\n  /// Return most likely n tokens\n  uint32 top_n_tokens = 7;\n  /// Paged attention blocks\n  repeated uint32 blocks = 9;\n  /// Paged attention slots\n  repeated uint32 slots = 10;\n  /// LORA adapter index\n  optional string adapter_id = 11;\n  /// Tokens that can be retrieved from the KV cache.\n  /// This value is set for the first prefill and never reset\n  uint32 cache_len = 12;\n  /// 
Context truncation\n  bool add_special_tokens = 13;\n  /// Chunk of tokens that must be computed for the first prefill\n  /// This value is set for the first prefill and never reset\n  optional uint32 chunk_len = 14;\n}\n\nmessage Batch {\n  /// Batch ID\n  uint64 id = 1;\n  /// Individual requests\n  repeated Request requests = 2;\n  /// Batch size (==len(requests))\n  uint32 size = 3;\n  /// Maximum number of tokens this batch will grow to\n  uint32 max_tokens = 4;\n  /// Maximum number of Paged Attention blocks\n  uint32 max_blocks = 5;\n}\n\nmessage CachedBatch {\n  /// Batch ID\n  uint64 id = 1;\n  /// Individual requests ids\n  repeated uint64 request_ids = 2;\n  /// Batch size (==len(requests))\n  uint32 size = 3;\n  /// Maximum number of tokens this batch will grow to\n  uint32 max_tokens = 4;\n  /// Number of tokens in the next forward\n  uint32 current_tokens = 5;\n}\n\nenum FinishReason {\n  FINISH_REASON_LENGTH = 0;\n  FINISH_REASON_EOS_TOKEN = 1;\n  FINISH_REASON_STOP_SEQUENCE = 2;\n}\n\nmessage GeneratedText {\n  /// Output\n  string text = 1;\n  /// Number of generated tokens\n  uint32 generated_tokens = 2;\n  /// Finish reason\n  FinishReason finish_reason = 3;\n  /// Seed\n  optional uint64 seed = 4;\n}\n\nmessage Tokens {\n  /// Token IDs\n  repeated uint32 ids = 1;\n  /// Logprobs\n  repeated float logprobs = 2;\n  /// tokens\n  repeated string texts = 3;\n  /// special\n  repeated bool is_special = 4;\n}\n\nmessage Generation {\n  /// Request ID\n  uint64 request_id = 1;\n  /// Prefill tokens (optional)\n  Tokens prefill_tokens = 2;\n  Tokens tokens = 3;\n  /// Complete generated text\n  optional GeneratedText generated_text = 4;\n  /// Top tokens\n  repeated Tokens top_tokens = 5;\n}\n\nmessage FilterBatchRequest {\n  /// Batch ID\n  uint64 batch_id = 1;\n  /// Requests to keep\n  repeated uint64 request_ids = 2;\n}\n\nmessage FilterBatchResponse {\n  /// Filtered Batch (cached)\n  CachedBatch batch = 1;\n}\n\nmessage PrefillRequest {\n  /// 
Batch\n  Batch batch = 1;\n  /// Optional cached batch\n  CachedBatch cached_batch = 2;\n}\n\nmessage PrefillResponse {\n  /// Generation\n  repeated Generation generations = 1;\n  /// Next batch (cached)\n  optional CachedBatch batch = 2;\n  /// Forward elapsed time in nanoseconds\n  uint64 forward_ns = 3;\n  /// Decode elapsed time in nanoseconds\n  uint64 decode_ns = 4;\n  /// Total elapsed time in nanoseconds\n  uint64 total_ns = 5;\n  /// Concatenate elapsed time in nanoseconds\n  optional uint64 concat_ns = 6;\n}\n\nmessage DecodeRequest {\n  /// Cached batches\n  repeated CachedBatch batches = 1;\n}\n\nmessage DecodeResponse {\n  /// Decodes\n  repeated Generation generations = 1;\n  /// Next batch (cached)\n  optional CachedBatch batch = 2;\n  /// Forward elapsed time in nanoseconds\n  uint64 forward_ns = 3;\n  /// Decode elapsed time in nanoseconds\n  uint64 decode_ns = 4;\n  /// Total elapsed time in nanoseconds\n  uint64 total_ns = 5;\n  /// Concatenate elapsed time in nanoseconds\n  optional uint64 concat_ns = 6;\n}\n\nmessage WarmupRequest {\n  /// Batch to warmup on\n  Batch batch = 1;\n  optional uint32 max_input_tokens = 2;\n  uint32 max_prefill_tokens = 3;\n  optional uint32 max_total_tokens = 4;\n}\n\nmessage WarmupResponse {\n  /// Maximum number of tokens supported by the model\n  optional uint32 max_supported_total_tokens = 1;\n  /// Maximum input tokens by clients should be equal to request value if it's set\n  /// Otherwise warmup automatically allocates a value here\n  uint32 max_input_tokens = 2;\n  /// Maximum total tokens by clients should be equal to request value if it's set\n  /// Otherwise warmup automatically allocates a value here\n  uint32 max_total_tokens = 3;\n}\n"
  },
  {
    "path": "router/Cargo.toml",
    "content": "[package]\nname = \"text-generation-router\"\ndescription = \"Text Generation Webserver\"\nbuild = \"build.rs\"\nversion.workspace = true\nedition.workspace = true\nauthors.workspace = true\nhomepage.workspace = true\n\n[dependencies]\nanyhow = \"1\"\nasync-trait = \"0.1.74\"\nasync-stream = \"0.3.5\"\naxum = { version = \"0.7\", features = [\"json\"] }\naxum-tracing-opentelemetry = \"0.16\"\nclap = { version = \"4.4.5\", features = [\"derive\", \"env\"] }\nfutures = \"0.3.28\"\nhf-hub = { workspace = true }\nitertools = \"0.10\"\njsonschema = { version = \"0.28.0\" }\nmetrics = { workspace = true }\nmetrics-exporter-prometheus = { workspace = true }\nnohash-hasher = \"0.2.0\"\nopentelemetry = { version = \"0.20.0\", features = [\"rt-tokio\"] }\nopentelemetry-otlp = \"0.13.0\"\noutlines-core = { git = \"https://github.com/dottxt-ai/outlines-core.git\", rev = \"ba10c619fc9bf3c487e43f49bdecb95a24bb465c\" }\nrand = \"0.8.5\"\nreqwest = { version = \"0.11.20\", features = [\"blocking\"] }\nserde = \"1.0.188\"\nserde_json = \"1.0.107\"\nthiserror = \"1.0.48\"\ntokenizers = { workspace = true }\ntokio = { version = \"1.32.0\", features = [\n  \"rt\",\n  \"rt-multi-thread\",\n  \"parking_lot\",\n  \"signal\",\n  \"sync\",\n] }\ntokio-stream = \"0.1.14\"\ntower-http = { version = \"0.5.1\", features = [\"cors\"] }\ntracing = \"0.1.40\"\ntracing-opentelemetry = \"0.21.0\"\ntracing-subscriber = { version = \"0.3.18\", features = [\"json\", \"env-filter\"] }\nutoipa = { version = \"4.2.0\", features = [\"axum_extras\"] }\nutoipa-swagger-ui = { version = \"6.0.0\", features = [\"axum\"] }\nngrok = { version = \"0.13.1\", features = [\"axum\"], optional = true }\ninit-tracing-opentelemetry = { version = \"0.14.1\", features = [\n  \"opentelemetry-otlp\",\n] }\nminijinja = { workspace = true, features = [\"loop_controls\"] }\nminijinja-contrib = { workspace = true }\nfutures-util = \"0.3.30\"\nregex = \"1.10.3\"\nonce_cell = \"1.19.0\"\nimage = \"0.25.1\"\nbase64 
= { workspace = true }\nsysinfo = \"0.30.13\"\nuuid = { version = \"1.9.1\", default-features = false, features = [\n  \"v4\",\n  \"fast-rng\",\n  \"macro-diagnostics\",\n] }\ncsv = \"1.3.0\"\nureq = \"=2.9\"\npyo3 = { workspace = true }\nchrono = \"0.4.39\"\n\n\n[build-dependencies]\nvergen = { version = \"8.2.5\", features = [\"build\", \"git\", \"gitcl\"] }\n\n[features]\ndefault = [\"ngrok\"]\nngrok = [\"dep:ngrok\"]\ngoogle = []\nkserve = []\n"
  },
  {
    "path": "router/README.md",
    "content": "# Router\n\nAlso named `webserver` throughout the docs.\n\nThis router is handling most of the logic to handle the \"batches\" tell\nwhen to pass new `prefill` requests and pausing `decode` requests, which ones etc...\n\nIt uses gRPC to communicate with the shards which can therefore be kept\nmuch simpler and focus on having the most efficient forward passes as possible.\n\n## Continuous batching\n\nOne important feature of `text-generation-inference` is enabled\nby this `router`.\n\nContinuous batching is the act of regularly running queries in the same\n`forward` step of the LLM (a \"batch\") and also removing them when they are\nfinished.\n\nIn order for continuous batching to be useful, you need to have more compute available\nwith respect to the memory requirements of your model. This is essentially true for\nLLMs and the larger the model, the truer it gets (since you have to pool multiple\nGPUs to load the model, you effectively have a lot of compute power at your hands).\n\n\nStatic batching is the act of doing several queries at the same time, but usually\nthis is controlled by the client, and therefore the amount of batching is decided\nbeforehand.\n\nFor text-generation, and LLMs which are memory bound we can try to be much more\nefficient with the available compute, by having client sending us single queries,\nand let the router mix&match queries into or out of batches to make the use the\ncompute the most efficiently. 
This is possible because for LLMs the total compute\nfor running the model is much bigger than doing mix&match of the batches themselves.\n\n\n### Simple continuous batching\n\ntext-generation works by feeding a prompt to a model, and iteratively calling\n`forward` on the model to produce new text, 1 token at a time.\n\nThe first idea is simple, when a query arrives, we start working on it directly.\nWhen new queries arrive, we simply wait for the current `forward` to be finished\nthen batch the current running prompt with the new query, and call `forward`.\n\nWhenever either query is finished: either the model produces an EOS (end of sequence) token\nor the query reached the allowed limit. We simply drop it from the batch, remove\nall the allocated memory and we can continue with the rest until nothing is left.\n\nThis simple idea generalizes very well and we could potentially stack many requests\nin the same batch.\n\nOne thing to note, is that queries can be potentially run with different parameters\nmeaning different ways to choose the next token (sampling, not sampling, temperature, top_k, etc.). This is not problematic for the proposed approach: we just need to do the sampling\nindependently on each member of the batch.\n\n### Prefill, decode and past key values\n\nIn order to make LLMs and text-generation efficient, there's actually a very powerful\ntrick that can be used, which is the \"caching\" of some attention matrices. [More on that\nin the first part of this blog](https://huggingface.co/blog/accelerated-inference#getting-to-the-first-10x-speedup)\n\nWhat this means, is that the first \"pass\" of a prompt is different from the subsequent\n\"forward\" passes. 
For the first one we have to compute the entire attention matrix, whereas the follow-ups only require computing the attention for the new token.\nThe first pass is called `prefill` throughout this codebase whereas the follow-ups are called `decode`.\n\nSince `prefill` is much more expensive than `decode` we don't want to do it all the time,\nbut a currently running query is probably doing `decode`. If we want to do the continuous\nbatching as explained previously we need to run `prefill` at some point in order to create\nthe attention matrix required to be able to join the `decode` group.\n\n`text-generation-inference` uses a bunch of different strategies and parameters in\norder to enable you to find the sweet spot between exploiting the hardware and perceived latency.\n\nWith no continuous batching at all, latency is going to be super good, but throughput (meaning\nthe total number of requests allowed in a given timeframe) is going to be super bad (since it's essentially 1).\n\nWith static batching, you can probably reach the maximum throughput (by using the maximum total batch size applicable to your hardware), but the latency is super bad since in order to have maximum throughput you need to wait for requests to come in before processing.\n\nWith continuous batching you can find a sweet spot. In general latency is the most critical\nparameter users care about. But a 2x latency slowdown for 10x more users on the same\nhardware is an acceptable tradeoff.\n\n## Token streaming\n\nThis is a very important aspect of client UX. As mentioned above, latency is the\nmost critical perceived quality of an LLM API.\n\nWith token streaming, the server can start answering after the first `prefill` pass\ndirectly, without waiting for all the generation to be done. For extremely long queries\nthis means clients can start to see something happening orders of magnitude before\nthe work is done. 
Seeing something in progress allows them to cut short if it's not\nwhat's wanted but also it \"feels\" better.\n"
  },
  {
    "path": "router/build.rs",
    "content": "use std::error::Error;\nuse vergen::EmitBuilder;\n\nfn main() -> Result<(), Box<dyn Error>> {\n    // Try to get the git sha from the local git repository\n    if EmitBuilder::builder()\n        .fail_on_error()\n        .git_sha(false)\n        .emit()\n        .is_err()\n    {\n        // Unable to get the git sha\n        if let Ok(sha) = std::env::var(\"GIT_SHA\") {\n            // Set it from an env var\n            println!(\"cargo:rustc-env=VERGEN_GIT_SHA={sha}\");\n        }\n    }\n\n    // Set docker label if present\n    if let Ok(label) = std::env::var(\"DOCKER_LABEL\") {\n        // Set it from an env var\n        println!(\"cargo:rustc-env=DOCKER_LABEL={label}\");\n    }\n\n    Ok(())\n}\n"
  },
  {
    "path": "router/src/chat.rs",
    "content": "use crate::{\n    infer::InferError, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,\n    ChatCompletionLogprobs, CompletionType, DeltaToolCall, Function, FunctionDefinition,\n    StreamOptions, StreamResponse, TextMessage, ToolCallDelta, Usage,\n};\nuse serde::Deserialize;\nuse serde_json::Value;\n\n#[derive(Debug, Deserialize)]\nstruct ToolCall {\n    _name: String,\n    #[serde(flatten, default)]\n    /// Using Map to preserve order\n    arguments: serde_json::Map<String, Value>,\n}\n#[derive(Debug, Deserialize)]\nstruct Call {\n    function: ToolCall,\n}\n\n#[cfg_attr(test, derive(Debug))]\npub(crate) enum ChatEvent {\n    NoTool,\n    Events(Vec<CompletionType>),\n}\n\n#[cfg_attr(test, derive(Debug))]\npub(crate) enum ChatChoice {\n    NoTool,\n    ToolCalls(Vec<crate::ToolCall>),\n}\n\npub(crate) fn parse_output(generated_text: &str) -> Result<ChatChoice, InferError> {\n    let call: Call = serde_json::from_str(generated_text).map_err(|e| {\n        InferError::ToolError(format!(\n            \"Failed to parse generated text: {} {:?}\",\n            e, generated_text\n        ))\n    })?;\n    let name = call.function._name;\n\n    match &name[..] 
{\n        \"no_tool\" => {\n            // parse the content message\n            Ok(ChatChoice::NoTool)\n        }\n        name => {\n            let tool_calls = vec![crate::ToolCall {\n                id: \"0\".to_string(),\n                r#type: \"function\".to_string(),\n                function: FunctionDefinition {\n                    description: None,\n                    name: name.to_string(),\n                    arguments: serde_json::to_value(call.function.arguments).map_err(|err| {\n                        InferError::ToolError(format!(\n                            \"Could not convert arguments to JSON map {err}\"\n                        ))\n                    })?,\n                },\n            }];\n            Ok(ChatChoice::ToolCalls(tool_calls))\n        }\n    }\n}\n\n/// Convert a StreamResponse into an Event to be sent over SSE\nfn create_event_from_stream_token(\n    stream_token: &StreamResponse,\n    logprobs: bool,\n    inner_using_tools: bool,\n    system_fingerprint: String,\n    model_id: String,\n    function_name: Option<String>,\n    id: String,\n) -> CompletionType {\n    let current_time = std::time::SystemTime::now()\n        .duration_since(std::time::UNIX_EPOCH)\n        .unwrap_or_else(|_| std::time::Duration::from_secs(0))\n        .as_secs();\n\n    let logprobs = logprobs.then(|| {\n        ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens.clone()))\n    });\n\n    // replace the content with the tool calls if grammar is present\n    let content = if !stream_token.token.special {\n        Some(stream_token.token.text.clone())\n    } else {\n        None\n    };\n    let (content, tool_calls) = if inner_using_tools {\n        // Cast into a vec\n        (None, content)\n    } else {\n        (content, None)\n    };\n    let finish_reason = stream_token\n        .details\n        .as_ref()\n        .map(|details| details.finish_reason.format(true));\n    let delta = match (content, 
tool_calls) {\n        (Some(delta), _) => ChatCompletionDelta::Chat(TextMessage {\n            role: \"assistant\".to_string(),\n            content: delta,\n            ..Default::default()\n        }),\n        (None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta {\n            role: \"assistant\".to_string(),\n            tool_calls: vec![DeltaToolCall {\n                index: 0,\n                id,\n                r#type: \"function\".to_string(),\n                function: Function {\n                    name: function_name,\n                    arguments: tool_calls,\n                },\n            }],\n        }),\n        (None, None) => ChatCompletionDelta::Chat(TextMessage {\n            role: \"assistant\".to_string(),\n            content: \"\".to_string(),\n            ..Default::default()\n        }),\n    };\n    let choices = vec![ChatCompletionChoice {\n        index: 0,\n        delta,\n        logprobs,\n        finish_reason,\n    }];\n    CompletionType::ChatCompletionChunk(ChatCompletionChunk::new(\n        model_id,\n        system_fingerprint,\n        current_time,\n        choices,\n        None,\n    ))\n}\n\n#[derive(Debug)]\nenum StreamState {\n    /// Before the tools was parsed\n    Buffering,\n    /// We detected a tool call here\n    Tool,\n    /// This is without tool calling\n    Content,\n}\n\npub struct ChatState {\n    state: StreamState,\n    text: String,\n    options: StreamOptions,\n    model_id: String,\n    fingerprint: String,\n    logprobs: bool,\n    id: String,\n}\n\nimpl ChatState {\n    pub fn new(\n        using_tools: bool,\n        options: StreamOptions,\n        fingerprint: String,\n        model_id: String,\n        logprobs: bool,\n        id: String,\n    ) -> Self {\n        let state = if using_tools {\n            StreamState::Buffering\n        } else {\n            StreamState::Content\n        };\n        let text = String::new();\n        Self {\n            state,\n            
text,\n            options,\n            fingerprint,\n            model_id,\n            logprobs,\n            id,\n        }\n    }\n\n    pub fn push(&mut self, mut stream_token: StreamResponse) -> ChatEvent {\n        let mut events = vec![];\n        let token_text = &stream_token.token.text;\n        match self.state {\n            StreamState::Buffering => {\n                self.text.push_str(token_text);\n                tracing::info!(\"Current text {:?}\", self.text);\n                let partial = &self.text;\n                let partial =\n                    partial.trim_end_matches(|c: char| c.is_whitespace() || c == ',' || c == '}');\n                if let Ok(call) = serde_json::from_str::<Call>(&format!(\"{}}}}}\", partial)) {\n                    // This can be no_tool before the content has been emitted\n                    if call.function._name != \"no_tool\" {\n                        stream_token.token.text = \"{\".to_string();\n                        let chat_complete = create_event_from_stream_token(\n                            &stream_token,\n                            self.logprobs,\n                            true,\n                            self.fingerprint.clone(),\n                            self.model_id.clone(),\n                            Some(call.function._name),\n                            self.id.clone(),\n                        );\n\n                        events.push(chat_complete);\n                        self.state = StreamState::Tool;\n                    } else {\n                        return ChatEvent::NoTool;\n                    }\n                }\n            }\n            StreamState::Tool => {\n                self.text.push_str(token_text);\n                if serde_json::from_str::<Call>(&self.text).is_ok() {\n                    self.state = StreamState::Buffering;\n                    let mut text = stream_token.token.text.trim_end();\n                    // Effectively trimming only the last 
closing brace\n                    if text.ends_with('}') {\n                        text = &text[..text.len() - 1];\n                    }\n                    stream_token.token.text = text.to_string();\n                    let chat_complete = create_event_from_stream_token(\n                        &stream_token,\n                        self.logprobs,\n                        true,\n                        self.fingerprint.clone(),\n                        self.model_id.clone(),\n                        None,\n                        self.id.clone(),\n                    );\n                    events.push(chat_complete);\n                } else {\n                    let chat_complete = create_event_from_stream_token(\n                        &stream_token,\n                        self.logprobs,\n                        true,\n                        self.fingerprint.clone(),\n                        self.model_id.clone(),\n                        None,\n                        self.id.clone(),\n                    );\n                    events.push(chat_complete);\n                }\n            }\n            StreamState::Content => {\n                let chat_complete = create_event_from_stream_token(\n                    &stream_token,\n                    self.logprobs,\n                    false,\n                    self.fingerprint.clone(),\n                    self.model_id.clone(),\n                    None,\n                    self.id.clone(),\n                );\n\n                events.push(chat_complete);\n            }\n        }\n\n        if self.options.include_usage {\n            if let Some(details) = stream_token.details {\n                let completion_tokens = details.generated_tokens;\n                let prompt_tokens = details.input_length;\n                let total_tokens = prompt_tokens + completion_tokens;\n\n                let usage = Usage {\n                    completion_tokens,\n                    prompt_tokens,\n     
               total_tokens,\n                };\n                let current_time = std::time::SystemTime::now()\n                    .duration_since(std::time::UNIX_EPOCH)\n                    .unwrap_or_else(|_| std::time::Duration::from_secs(0))\n                    .as_secs();\n\n                let chat_complete = CompletionType::ChatCompletionChunk(ChatCompletionChunk {\n                    id: String::new(),\n                    created: current_time,\n                    model: self.model_id.clone(),\n                    system_fingerprint: self.fingerprint.clone(),\n                    choices: vec![],\n                    usage: Some(Usage {\n                        prompt_tokens: usage.prompt_tokens,\n                        completion_tokens: usage.completion_tokens,\n                        total_tokens: usage.total_tokens,\n                    }),\n                });\n\n                events.push(chat_complete);\n            }\n        }\n        ChatEvent::Events(events)\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use crate::{\n        ChatCompletionChoice, ChatCompletionDelta, FinishReason, StreamDetails, TextMessage, Token,\n    };\n\n    use super::*;\n\n    fn get_tool_call_content(event: &CompletionType) -> (Option<&String>, &String) {\n        match event {\n            CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. }) => {\n                assert_eq!(choices.len(), 1);\n                if let ChatCompletionChoice {\n                    delta: ChatCompletionDelta::Tool(ToolCallDelta { tool_calls, .. 
}),\n                    ..\n                } = &choices[0]\n                {\n                    assert_eq!(tool_calls.len(), 1);\n                    let DeltaToolCall {\n                        index,\n                        id,\n                        r#type,\n                        function,\n                    } = &tool_calls[0];\n                    assert_eq!(*index, 0);\n                    assert_eq!(id, \"0\");\n                    assert_eq!(r#type, \"function\");\n                    (function.name.as_ref(), &function.arguments)\n                } else {\n                    panic!(\"Expected plain message\");\n                }\n            }\n            _ => panic!(\"Unexpected chunk\"),\n        }\n    }\n\n    #[test]\n    fn test_chat_stream() {\n        let mut chat_state = ChatState::new(\n            false,\n            StreamOptions {\n                include_usage: false,\n            },\n            \"fingerprint\".to_string(),\n            \"model_id\".to_string(),\n            false,\n            \"0\".to_string(),\n        );\n\n        let events = chat_state.push(StreamResponse {\n            generated_text: None,\n            token: Token {\n                id: 42,\n                text: \"Hi\".to_string(),\n                logprob: 0.0,\n                special: false,\n            },\n            top_tokens: vec![],\n            index: 0,\n            details: None,\n        });\n        if let ChatEvent::Events(events) = events {\n            assert_eq!(events.len(), 1);\n            match &events[0] {\n                CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. 
}) => {\n                    assert_eq!(\n                        choices,\n                        &[ChatCompletionChoice {\n                            index: 0,\n                            delta: ChatCompletionDelta::Chat(TextMessage {\n                                role: \"assistant\".to_string(),\n                                content: \"Hi\".to_string(),\n                                tool_call_id: None,\n                            }),\n                            logprobs: None,\n                            finish_reason: None,\n                        }]\n                    );\n                }\n                _ => panic!(\"Unexpected chunk\"),\n            }\n        } else {\n            panic!(\"Expected chat events\");\n        }\n    }\n\n    #[test]\n    fn test_chat_stream_usage() {\n        let mut chat_state = ChatState::new(\n            false,\n            StreamOptions {\n                include_usage: true,\n            },\n            \"fingerprint\".to_string(),\n            \"model_id\".to_string(),\n            false,\n            \"0\".to_string(),\n        );\n\n        let events = chat_state.push(StreamResponse {\n            generated_text: None,\n            token: Token {\n                id: 42,\n                text: \"Hi\".to_string(),\n                logprob: 0.0,\n                special: false,\n            },\n            top_tokens: vec![],\n            index: 0,\n            details: Some(StreamDetails {\n                input_length: 2,\n                generated_tokens: 10,\n                seed: None,\n                finish_reason: FinishReason::Length,\n            }),\n        });\n        if let ChatEvent::Events(events) = events {\n            assert_eq!(events.len(), 2);\n            match &events[0] {\n                CompletionType::ChatCompletionChunk(ChatCompletionChunk { choices, .. 
}) => {\n                    assert_eq!(\n                        choices,\n                        &[ChatCompletionChoice {\n                            index: 0,\n                            delta: ChatCompletionDelta::Chat(TextMessage {\n                                role: \"assistant\".to_string(),\n                                content: \"Hi\".to_string(),\n                                tool_call_id: None,\n                            }),\n                            logprobs: None,\n                            // HAS A FINISH REASON\n                            finish_reason: Some(\"length\".to_string()),\n                        }]\n                    );\n                }\n                _ => panic!(\"Unexpected chunk\"),\n            }\n            match &events[1] {\n                CompletionType::ChatCompletionChunk(ChatCompletionChunk { usage, .. }) => {\n                    assert_eq!(\n                        *usage,\n                        Some(Usage {\n                            prompt_tokens: 2,\n                            completion_tokens: 10,\n                            total_tokens: 12,\n                        })\n                    );\n                }\n                _ => panic!(\"Unexpected chunk\"),\n            }\n        } else {\n            panic!(\"Expected chat events\");\n        }\n    }\n\n    #[test]\n    fn test_chat_stream_tool_no_tool_simple() {\n        let mut chat_state = ChatState::new(\n            true,\n            StreamOptions {\n                include_usage: true,\n            },\n            \"fingerprint\".to_string(),\n            \"model_id\".to_string(),\n            false,\n            \"0\".to_string(),\n        );\n\n        let tokens = vec![\n            \"{\\\"\".to_string(),\n            \"function\".to_string(),\n            \"\\\":\".to_string(),\n            \" {\\\"\".to_string(),\n            \"_\".to_string(),\n            \"name\".to_string(),\n            \"\\\":\".to_string(),\n  
          \" \\\"\".to_string(),\n            \"no\".to_string(),\n            \"_tool\".to_string(),\n            \"\\\",\".to_string(),\n            \" \\\"\".to_string(),\n            \"content\".to_string(),\n            \"\\\":\".to_string(),\n            \" \\\"\".to_string(),        // Token 14\n            \"I\".to_string(),          // Event 1\n            \" am\".to_string(),        // Event 2\n            \" a\".to_string(),         // Event 3\n            \" helpful\".to_string(),   // Event 4\n            \" assistant\".to_string(), // Event 5\n            \"!\\\"\".to_string(),        // Event 6 (with trailing quore removed)\n            \"}\".to_string(),\n            \"}\".to_string(),\n        ];\n        let tokens: Vec<_> = tokens\n            .into_iter()\n            .map(|text| StreamResponse {\n                generated_text: None,\n                token: Token {\n                    id: 42,\n                    text: text.to_string(),\n                    logprob: 0.0,\n                    special: false,\n                },\n                top_tokens: vec![],\n                index: 0,\n                details: None,\n            })\n            .collect();\n\n        // Initial ignored output\n        for token in &tokens[..10] {\n            let events = chat_state.push(token.clone());\n            if let ChatEvent::Events(events) = events {\n                assert_eq!(events.len(), 0, \"{events:?}\");\n            } else {\n                panic!(\"Expected chat events\");\n            }\n        }\n\n        // No tool output\n        let events = chat_state.push(tokens[10].clone());\n        if let ChatEvent::NoTool = events {\n            assert!(true);\n        } else {\n            panic!(\"Expected chat events\");\n        }\n    }\n\n    #[test]\n    fn test_chat_stream_tool_no_tool_empty() {\n        let mut chat_state = ChatState::new(\n            true,\n            StreamOptions {\n                include_usage: true,\n       
     },\n            \"fingerprint\".to_string(),\n            \"model_id\".to_string(),\n            false,\n            \"0\".to_string(),\n        );\n\n        let tokens = vec![\n            \"{\\\"\".to_string(),\n            \"function\".to_string(),\n            \"\\\":\".to_string(),\n            \" {\\\"\".to_string(),\n            \"_\".to_string(),\n            \"name\".to_string(),\n            \"\\\":\".to_string(),\n            \" \\\"\".to_string(),\n            \"no\".to_string(),\n            \"_tool\".to_string(),\n            \"\\\",\".to_string(),\n            \" \\\"\".to_string(),\n            \"content\".to_string(),\n            \"\\\":\\\"\".to_string(),\n            \"\\\"}\".to_string(), // Token 13\n            \"}\".to_string(),   // Event 1\n        ];\n        let tokens: Vec<_> = tokens\n            .into_iter()\n            .map(|text| StreamResponse {\n                generated_text: None,\n                token: Token {\n                    id: 42,\n                    text: text.to_string(),\n                    logprob: 0.0,\n                    special: false,\n                },\n                top_tokens: vec![],\n                index: 0,\n                details: None,\n            })\n            .collect();\n\n        // Initial ignored output\n        for token in &tokens[..10] {\n            let events = chat_state.push(token.clone());\n            if let ChatEvent::Events(events) = events {\n                assert_eq!(events.len(), 0, \"{events:?}\");\n            } else {\n                panic!(\"Expected chat events\");\n            }\n        }\n\n        // No tool output\n        let events = chat_state.push(tokens[10].clone());\n        if let ChatEvent::NoTool = events {\n            assert!(true);\n        } else {\n            panic!(\"Expected chat events\");\n        }\n    }\n\n    #[test]\n    fn test_chat_stream_tool_get_weather() {\n        let mut chat_state = ChatState::new(\n            true,\n     
       StreamOptions {\n                include_usage: true,\n            },\n            \"fingerprint\".to_string(),\n            \"model_id\".to_string(),\n            false,\n            \"0\".to_string(),\n        );\n\n        let tokens = vec![\n            \"{\\\"\".to_string(),\n            \"function\".to_string(),\n            \"\\\":\".to_string(),\n            \" {\\\"\".to_string(),\n            \"_\".to_string(),\n            \"name\".to_string(),\n            \"\\\":\".to_string(),\n            \" \\\"\".to_string(),\n            \"get\".to_string(),\n            \"_current\".to_string(),\n            \"_weather\".to_string(),\n            \"\\\",\".to_string(),\n            // Event 1 is the function name\n            // Event 2 is the start of the arguments \"{\"\n            \" \\\"\".to_string(),        // Event 3\n            \"location\".to_string(),   // Event 4\n            \"\\\":\".to_string(),        // Event 5\n            \" \\\"\".to_string(),        // Event 6\n            \"San\".to_string(),        // Event 7\n            \" Francisco\".to_string(), // Event 8\n            \",\".to_string(),          // Event 9\n            \" CA\".to_string(),        // Event 10\n            \"\\\",\".to_string(),        // Event 11\n            \" \\\"\".to_string(),        // Event 12\n            \"format\".to_string(),     // Event 13\n            \"\\\":\".to_string(),        // Event 14\n            \" \\\"\".to_string(),        // Event 15\n            \"c\".to_string(),          // Event 16\n            \"elsius\".to_string(),     // Event 17\n            \"\\\"}}\".to_string(),       // Event 18 retained (trailing brace removed)\n        ];\n        let tokens: Vec<_> = tokens\n            .into_iter()\n            .map(|text| StreamResponse {\n                generated_text: None,\n                token: Token {\n                    id: 42,\n                    text: text.to_string(),\n                    logprob: 0.0,\n                   
 special: false,\n                },\n                top_tokens: vec![],\n                index: 0,\n                details: None,\n            })\n            .collect();\n\n        // Initial ignored output\n        for token in &tokens[..11] {\n            let events = chat_state.push(token.clone());\n            if let ChatEvent::Events(events) = events {\n                assert_eq!(events.len(), 0, \"{events:?}\");\n            } else {\n                panic!(\"Expected chat events\");\n            }\n        }\n\n        // No tool output\n        let mut output = String::new();\n        let mut output_name = String::new();\n        for token in &tokens[11..11 + 17] {\n            let events = chat_state.push(token.clone());\n            if let ChatEvent::Events(events) = events {\n                assert_eq!(events.len(), 1);\n                let (name, arguments) = get_tool_call_content(&events[0]);\n                if let Some(name) = name {\n                    assert_eq!(name, \"get_current_weather\");\n                    output_name.push_str(name);\n                }\n                output.push_str(arguments);\n            } else {\n                panic!(\"Expected chat events\");\n            }\n        }\n\n        assert_eq!(output_name, \"get_current_weather\");\n        assert_eq!(\n            output,\n            \"{ \\\"location\\\": \\\"San Francisco, CA\\\", \\\"format\\\": \\\"celsius\\\"}\"\n        );\n\n        // No tool finish\n        for token in &tokens[11 + 17..] {\n            let events = chat_state.push(token.clone());\n            if let ChatEvent::Events(events) = events {\n                assert_eq!(events.len(), 0, \"{events:?}\");\n            } else {\n                panic!(\"Expected chat events\");\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "router/src/config.rs",
    "content": "use serde::{Deserialize, Serialize};\nuse std::collections::{HashMap, HashSet};\n\n#[derive(Clone, Debug, Serialize, Deserialize)]\n#[serde(tag = \"model_type\")]\n#[serde(rename_all = \"snake_case\")]\npub struct LlavaNext {\n    pub(crate) text_config: TextConfig,\n    pub(crate) vision_config: VisionConfig,\n    pub(crate) image_grid_pinpoints: Vec<(usize, usize)>,\n}\n\nfn get_anyres_image_grid_shape(\n    height: usize,\n    width: usize,\n    grid_pinpoints: &[(usize, usize)],\n    patch_size: usize,\n) -> (usize, usize) {\n    let (height, width) = select_best_resolution(height, width, grid_pinpoints);\n    (height / patch_size, width / patch_size)\n}\n\n/// Selects the best resolution from a list of possible resolutions based on the original size.\n/// This is done by calculating the effective and wasted resolution for each possible resolution.\n/// The best fit resolution is the one that maximizes the effective resolution and minimizes the wasted resolution.\nfn select_best_resolution(\n    original_height: usize,\n    original_width: usize,\n    possible_resolutions: &[(usize, usize)],\n) -> (usize, usize) {\n    let mut best_fit = None;\n    let mut max_effective_resolution = 0;\n    let mut min_wasted_resolution = f32::NEG_INFINITY;\n\n    for (height, width) in possible_resolutions {\n        let wscale = *width as f32 / original_width as f32;\n        let hscale = *height as f32 / original_height as f32;\n        // f32 partial ord.\n        let scale = if wscale > hscale { hscale } else { wscale };\n        let downscaled_width = (*width as f32 * scale) as usize;\n        let downscaled_height = (*height as f32 * scale) as usize;\n        let effective_resolution = std::cmp::min(\n            downscaled_width * downscaled_height,\n            original_width * original_height,\n        );\n        let wasted_resolution = (width * height) - effective_resolution;\n\n        if effective_resolution > max_effective_resolution\n            
|| (effective_resolution == max_effective_resolution\n                && (wasted_resolution as f32) < min_wasted_resolution)\n        {\n            max_effective_resolution = effective_resolution;\n            min_wasted_resolution = wasted_resolution as f32;\n            best_fit = Some((*height, *width));\n        }\n    }\n\n    best_fit.unwrap_or((original_height, original_width))\n}\n\nfn get_unpadded_features(\n    height: usize,\n    width: usize,\n    npatches: usize,\n    num_patch_height: usize,\n    num_patch_width: usize,\n) -> (usize, usize) {\n    let current_height = npatches * num_patch_height;\n    let current_width = npatches * num_patch_width;\n\n    let aspect_ratio: f64 = width as f64 / height as f64;\n    let current_aspect_ratio: f64 = current_width as f64 / current_height as f64;\n    let (current_height, current_width) = if aspect_ratio > current_aspect_ratio {\n        let new_height = (height * current_width) / width;\n        let padding = (current_height - new_height) / 2;\n        (current_height - (2 * padding), current_width)\n    } else {\n        let new_width = (width * current_height) / height;\n        let padding = (current_width - new_width) / 2;\n        (current_height, current_width - (2 * padding))\n    };\n\n    let unpadded_features = current_height * current_width;\n    let newline_features = current_height;\n    (unpadded_features, newline_features)\n}\n\nimpl LlavaNext {\n    pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {\n        let image_size = self.vision_config.image_size;\n        let patch_size = self.vision_config.patch_size;\n        assert!(image_size % patch_size == 0);\n        let npatches = image_size / patch_size;\n        // Dimensions are intentionally swapped to be bug-compatible with\n        // upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59\n        let (num_patch_width, num_patch_height) =\n            get_anyres_image_grid_shape(height, width, 
&self.image_grid_pinpoints, image_size);\n\n        let (unpadded_features, newline_features) =\n            get_unpadded_features(height, width, npatches, num_patch_height, num_patch_width);\n        // The base patch covers the entire image\n        let base_features = npatches.pow(2);\n        unpadded_features + newline_features + base_features\n    }\n}\n\n#[derive(Clone, Debug, Serialize, Deserialize)]\n#[serde(rename_all = \"snake_case\")]\npub struct Llama4VisionConfig {\n    image_size: usize,\n    patch_size: usize,\n    pixel_shuffle_ratio: f64,\n}\n\n#[derive(Clone, Debug, Serialize, Deserialize)]\n#[serde(rename_all = \"snake_case\")]\npub struct Llama4 {\n    text_config: TextConfig,\n    vision_config: Llama4VisionConfig,\n}\n\nfn gcd(a: usize, b: usize) -> usize {\n    if b == 0 {\n        a\n    } else {\n        gcd(b, a % b)\n    }\n}\n\nfn get_factors(dividend: usize) -> HashSet<usize> {\n    let mut factors_set = HashSet::new();\n\n    for i in 1..=((dividend as f64).sqrt() as usize) {\n        if dividend % i == 0 {\n            factors_set.insert(i);\n            factors_set.insert(dividend / i);\n        }\n    }\n\n    factors_set\n}\n\nfn find_supported_resolutions(max_num_chunks: usize, height: usize) -> Vec<(usize, usize)> {\n    let patch_size = height;\n\n    let mut asp_dict: HashMap<(usize, usize), Vec<(usize, usize)>> = HashMap::new();\n\n    for chunk_size in (1..=max_num_chunks).rev() {\n        let mut _factors: Vec<_> = get_factors(chunk_size).into_iter().collect();\n        _factors.sort();\n        let _asp_ratios: Vec<(usize, usize)> =\n            _factors.iter().map(|&f| (f, chunk_size / f)).collect();\n\n        for (h, w) in _asp_ratios {\n            let divisor = gcd(h, w);\n            let key = (h / divisor, w / divisor); // reduced aspect ratio as key\n\n            asp_dict.entry(key).or_default().push((h, w));\n        }\n    }\n\n    let mut possible_resolutions = vec![];\n\n    for (_key, value) in asp_dict {\n   
     for (h, w) in value {\n            possible_resolutions.push((h * patch_size, w * patch_size));\n        }\n    }\n\n    possible_resolutions\n}\n\nfn get_best_fit(\n    original_height: usize,\n    original_width: usize,\n    possible_resolutions: &[(usize, usize)],\n    resize_to_max_canvas: bool,\n) -> (usize, usize) {\n    let orig_h = original_height as f32;\n    let orig_w = original_width as f32;\n\n    let mut scales = Vec::with_capacity(possible_resolutions.len());\n\n    for &(h, w) in possible_resolutions.iter() {\n        let scale_h = h as f32 / orig_h;\n        let scale_w = w as f32 / orig_w;\n        let scale = scale_h.min(scale_w);\n        scales.push(scale);\n    }\n\n    let upscaling_options: Vec<f32> = scales.iter().copied().filter(|&s| s >= 1.0).collect();\n    let selected_scale = if !upscaling_options.is_empty() {\n        if resize_to_max_canvas {\n            upscaling_options.into_iter().fold(f32::MIN, f32::max)\n        } else {\n            upscaling_options.into_iter().fold(f32::MAX, f32::min)\n        }\n    } else {\n        let downscaling_options: Vec<f32> = scales.iter().copied().filter(|&s| s < 1.0).collect();\n        downscaling_options.into_iter().fold(f32::MIN, f32::max)\n    };\n\n    let chosen_canvas: Vec<(usize, usize)> = possible_resolutions\n        .iter()\n        .zip(scales.iter())\n        .filter(|&(_, &s)| (s - selected_scale).abs() < f32::EPSILON)\n        .map(|(&(h, w), _)| (h, w))\n        .collect();\n\n    if chosen_canvas.len() > 1 {\n        chosen_canvas\n            .into_iter()\n            .min_by_key(|(h, w)| h * w)\n            .unwrap()\n    } else {\n        chosen_canvas[0]\n    }\n}\n\nimpl Llama4 {\n    pub fn image_size(&self) -> usize {\n        self.vision_config.image_size\n    }\n\n    pub fn patch_size(&self) -> usize {\n        self.vision_config.patch_size\n    }\n\n    pub fn pixel_shuffle_ratio(&self) -> f64 {\n        self.vision_config.pixel_shuffle_ratio\n    }\n    pub fn 
get_aspect_ratios(\n        &self,\n        height: usize,\n        width: usize,\n        max_chunks: usize,\n    ) -> (usize, usize) {\n        let patch_size = self.vision_config.image_size;\n        let supported = find_supported_resolutions(max_chunks, patch_size);\n        let (target_h, target_w) = get_best_fit(height, width, &supported, false);\n        (target_h / patch_size, target_w / patch_size)\n    }\n}\n\n#[derive(Clone, Debug, Serialize, Deserialize)]\n#[serde(rename_all = \"snake_case\")]\npub struct ClipVisionModel {\n    image_size: usize,\n    patch_size: usize,\n}\n\n#[derive(Clone, Debug, Serialize, Deserialize)]\n#[serde(rename_all = \"snake_case\")]\npub struct Idefics3 {}\n\nimpl Idefics3 {\n    pub fn get_max_longest_edge(&self) -> usize {\n        364\n    }\n\n    pub fn get_number_of_features(&self) -> usize {\n        169\n    }\n\n    pub fn get_max_longest_edge_for_image_resize(&self) -> usize {\n        1456\n    }\n\n    pub fn get_max_image_size(&self) -> usize {\n        4096\n    }\n}\n\n#[derive(Clone, Debug, Serialize, Deserialize)]\n#[serde(rename_all = \"snake_case\")]\npub struct Idefics2 {}\n\nimpl Idefics2 {\n    pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {\n        64\n    }\n}\n\n#[derive(Clone, Debug, Serialize, Deserialize)]\n#[serde(rename_all = \"snake_case\")]\npub struct PaliTextConfig {\n    pub(crate) num_image_tokens: usize,\n}\n\n#[derive(Clone, Debug, Serialize, Deserialize)]\n#[serde(rename_all = \"snake_case\")]\npub struct Paligemma {\n    pub(crate) text_config: PaliTextConfig,\n}\n\nimpl Paligemma {\n    pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {\n        self.text_config.num_image_tokens\n    }\n}\n\n#[derive(Clone, Debug, Serialize, Deserialize)]\n#[serde(rename_all = \"snake_case\")]\npub struct Qwen2VlVisionConfig {\n    pub(crate) depth: usize,\n    pub(crate) embed_dim: usize,\n    pub(crate) mlp_ratio: usize,\n    pub(crate) 
num_heads: usize,\n    pub(crate) in_chans: usize,\n    pub(crate) hidden_size: usize,\n    pub(crate) patch_size: usize,\n    pub(crate) spatial_merge_size: usize,\n    pub(crate) spatial_patch_size: usize,\n    pub(crate) temporal_patch_size: usize,\n}\n\n#[derive(Clone, Debug, Serialize, Deserialize)]\n#[serde(rename_all = \"snake_case\")]\npub struct Qwen2Vl {\n    pub(crate) vision_config: Qwen2VlVisionConfig,\n}\n\nimpl Qwen2Vl {\n    pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {\n        let num_pixels = height * width;\n        num_pixels / self.vision_config.patch_size.pow(2)\n    }\n}\n\n#[derive(Clone, Debug, Serialize, Deserialize)]\n#[serde(rename_all = \"snake_case\")]\npub struct Qwen2_5VlVisionConfig {\n    // pub(crate) depth: usize,\n    // pub(crate) hidden_act: String,\n    // pub(crate) hidden_size: usize,\n    // pub(crate) intermediate_size: usize,\n    // pub(crate) num_heads: usize,\n    // pub(crate) in_chans: usize,\n    // pub(crate) out_hidden_size: usize,\n    // pub(crate) patch_size: usize,\n    // pub(crate) spatial_merge_size: usize,\n    pub(crate) spatial_patch_size: usize,\n    // pub(crate) window_size: usize,\n    // pub(crate) fullatt_block_indexes: Vec<usize>,\n    // pub(crate) tokens_per_second: usize,\n    // pub(crate) temporal_patch_size: usize,\n}\n\n#[derive(Clone, Debug, Serialize, Deserialize)]\n#[serde(rename_all = \"snake_case\")]\npub struct Qwen2_5Vl {\n    pub(crate) vision_config: Qwen2_5VlVisionConfig,\n}\n\nimpl Qwen2_5Vl {\n    pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {\n        let num_pixels = height * width;\n        num_pixels / self.vision_config.spatial_patch_size.pow(2)\n    }\n}\n\n#[derive(Clone, Debug, Serialize, Deserialize)]\n#[serde(rename_all = \"snake_case\")]\npub struct Gemma3VisionConfig {\n    pub(crate) image_size: usize,\n    pub(crate) patch_size: usize,\n}\n\n#[derive(Clone, Debug, Serialize, 
Deserialize)]\n#[serde(rename_all = \"snake_case\")]\npub struct Gemma3 {\n    vision_config: Gemma3VisionConfig,\n}\n\n#[derive(Clone, Debug, Serialize, Deserialize)]\n#[serde(tag = \"model_type\")]\n#[serde(rename_all = \"snake_case\")]\npub enum Config {\n    Qwen2_5Vl(Qwen2_5Vl),\n    Qwen2Vl(Qwen2Vl),\n    LlavaNext(LlavaNext),\n    ClipVisionModel(ClipVisionModel),\n    Mistral,\n    Mamba,\n    Idefics,\n    Mllama,\n    Idefics2(Idefics2),\n    Idefics3(Idefics3),\n    Ssm,\n    GptBigcode,\n    Granite,\n    Santacoder,\n    Bloom,\n    Mpt,\n    Gpt2,\n    Gptj,\n    GptNeox,\n    Phi,\n    #[serde(rename = \"phi-msft\")]\n    PhiMsft,\n    Phi3,\n    Phimoe,\n    Llama,\n    Llama4(Llama4),\n    Baichuan,\n    Paligemma(Paligemma),\n    Gemma,\n    Gemma2,\n    Gemma3(Gemma3),\n    Gemma3Text,\n    Cohere,\n    Drbx,\n    Falcon,\n    Mixtral,\n    Starcoder2,\n    Qwen2,\n    Opt,\n    T5,\n    DeepseekV2,\n    DeepseekV3,\n    Qwen3,\n}\n\n#[derive(Clone, Debug, Serialize, Deserialize)]\n#[serde(rename_all = \"snake_case\")]\npub struct TextConfig {}\n\n#[derive(Clone, Debug, Serialize, Deserialize)]\n#[serde(rename_all = \"snake_case\")]\npub struct VisionConfig {\n    pub(crate) image_size: usize,\n    pub(crate) patch_size: usize,\n}\n\n#[cfg(test)]\nmod test {\n    use super::*;\n\n    #[test]\n    fn test_llava_next_features() {\n        let config = LlavaNext {\n            text_config: TextConfig {},\n            vision_config: VisionConfig {\n                image_size: 336,\n                patch_size: 14,\n            },\n            image_grid_pinpoints: vec![\n                (336, 672),\n                (672, 336),\n                (672, 672),\n                (1008, 336),\n                (336, 1008),\n            ],\n        };\n\n        let slots = config.get_number_of_features(20, 20);\n        assert_eq!(slots, 1176);\n        let slots = config.get_number_of_features(640, 640);\n        assert_eq!(slots, 2928);\n        let slots = 
config.get_number_of_features(480, 640);\n        assert_eq!(slots, 2340);\n        let slots = config.get_number_of_features(899, 1024);\n        assert_eq!(slots, 2634);\n        let slots = config.get_number_of_features(1024, 899);\n        assert_eq!(slots, 2640);\n        let slots = config.get_number_of_features(1067, 1600);\n        assert_eq!(slots, 2144);\n    }\n}\n"
  },
  {
    "path": "router/src/infer/chat_template.rs",
    "content": "use crate::infer::InferError;\nuse crate::{\n    ChatTemplateInputs, Message, MessageBody, MessageChunk, TextMessage, TokenizerConfigToken, Tool,\n};\nuse chrono::Local;\nuse minijinja::{Environment, ErrorKind, Template};\nuse minijinja_contrib::pycompat;\n\n/// Raise a exception (custom function) used in the chat templates\npub(crate) fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {\n    Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))\n}\n\n/// Get the current date in a specific format (custom function), similar to `datetime.now().strftime()` in Python\npub(crate) fn strftime_now(format_str: String) -> Result<String, minijinja::Error> {\n    Ok(Local::now().format(&format_str).to_string())\n}\n\n#[derive(Debug, Clone)]\npub(crate) struct ChatTemplate {\n    template: Template<'static, 'static>,\n    bos_token: Option<String>,\n    eos_token: Option<String>,\n    use_default_tool_template: bool,\n}\n\nimpl ChatTemplate {\n    pub(crate) fn new(\n        template: String,\n        bos_token: Option<TokenizerConfigToken>,\n        eos_token: Option<TokenizerConfigToken>,\n    ) -> Self {\n        let mut env = Box::new(Environment::new());\n        // enable things like .strip() or .capitalize()\n        env.set_unknown_method_callback(pycompat::unknown_method_callback);\n\n        // TODO: replace with better solution\n        // hack to adjust gemma3 template for debug\n        // replace 'messages[0]['content'][0]['text']' with 'messages[0]['content']'\n        let mutated_template = template.replace(\n            \"messages[0]['content'][0]['text']\",\n            \"messages[0]['content']\",\n        );\n        //  Hack to fix Qwen3 templating.\n        //  It uses python notation to reverse lists, which do not exist in minijinja\n        //  so we're using the reverse filter instead.\n        let mutated_template = mutated_template.replace(\"[::-1]\", \"|reverse\");\n        // TODO: replace with a better 
solution\n        // Hack to remove the {% generation %} and {% endgeneration %} statements from\n        // the Jinja2 chat templates if present, since those are only used for assistant\n        // masking during training, and should be ignored during inference\n        let mutated_template = mutated_template.replace(\"{% generation %}\", \"\");\n        let mutated_template = mutated_template.replace(\"{% endgeneration %}\", \"\");\n\n        let template_str = mutated_template.into_boxed_str();\n        env.add_function(\"raise_exception\", raise_exception);\n        env.add_function(\"strftime_now\", strftime_now);\n        tracing::debug!(\"Loading template: {}\", template_str);\n\n        // leaking env and template_str as read-only, static resources for performance.\n        let template = Box::leak(env)\n            .template_from_str(Box::leak(template_str))\n            .unwrap();\n\n        // get the list of variables that are used in the template\n        let variables = template.undeclared_variables(true);\n        // check if the `tools` variable is used in the template\n        let use_default_tool_template = !variables.contains(\"tools\");\n        tracing::debug!(\"Use default tool template: {}\", use_default_tool_template);\n\n        Self {\n            template,\n            bos_token: bos_token.map(|token| token.as_str().to_string()),\n            eos_token: eos_token.map(|token| token.as_str().to_string()),\n            use_default_tool_template,\n        }\n    }\n\n    pub(crate) fn apply(\n        &self,\n        mut messages: Vec<Message>,\n        tools_and_prompt: Option<(Vec<Tool>, String)>,\n    ) -> Result<String, InferError> {\n        let tools = match tools_and_prompt {\n            Some((tools, tool_prompt)) => {\n                // check if the `tools` variable is used in the template\n                // if not, we need to append the tools to the last message\n                let text = if self.use_default_tool_template {\n       
             match serde_json::to_string(&tools) {\n                        Ok(tools_str) => format!(\"\\n---\\n{}\\n{}\", tools_str, tool_prompt),\n                        Err(e) => return Err(InferError::ToolError(e.to_string())),\n                    }\n                } else {\n                    // if the `tools` variable is used in the template, we just append the tool_prompt\n                    format!(\"\\n---\\n{}\", tool_prompt)\n                };\n                if let Some(last_message) = messages.last_mut() {\n                    if let MessageBody::Content { content } = &mut last_message.body {\n                        content.push(MessageChunk::Text { text });\n                    }\n                }\n                Some(tools)\n            }\n            None => None,\n        };\n\n        let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();\n        let final_message = messages.last().cloned();\n        let mut rendered_template = self\n            .template\n            .render(ChatTemplateInputs {\n                messages,\n                bos_token: self.bos_token.as_deref(),\n                eos_token: self.eos_token.as_deref(),\n                add_generation_prompt: true,\n                tools,\n            })\n            .map_err(InferError::TemplateError)?;\n\n        // if the last message is from the assistant, continue the generation prompt\n        rendered_template = match final_message {\n            Some(msg) if msg.role == \"assistant\" => {\n                match rendered_template.rfind(msg.content.as_str()) {\n                    // implementation based on feature in transformers pipeline\n                    // https://github.com/huggingface/transformers/blob/1cf17077bf2d4affed31387c0943251a4ba8fab7/src/transformers/pipelines/text_generation.py#L418\n                    Some(index) => rendered_template[..index + msg.content.len()]\n                        .trim_end()\n                       
 .to_string(),\n                    None => rendered_template,\n                }\n            }\n            _ => rendered_template,\n        };\n\n        Ok(rendered_template)\n    }\n}\n\n// tests\n#[cfg(test)]\nmod tests {\n    use crate::infer::chat_template::{raise_exception, strftime_now};\n    use crate::infer::ChatTemplate;\n    use crate::{\n        ChatTemplateInputs, Message, MessageBody, MessageChunk, MessageContent, TextMessage,\n        TokenizerConfigToken, Tool, Url,\n    };\n    use chrono::Local;\n    use minijinja::Environment;\n\n    #[test]\n    fn test_chat_template() {\n        let env = Environment::new();\n\n        let source = r#\"\n        {% for message in messages %}\n            {% if message['role'] == 'system' %}\n                {% if message['content']%}\n                    {{'### System:\\n' + message['content']+'\\n\\n'}}\n                {% endif %}\n            {% elif message['role'] == 'user' %}\n                {{'### User:\\n' + message['content']+'\\n\\n'}}\n            {% elif message['role'] == 'assistant' %}\n                {{'### Assistant:\\n'  + message['content']}}\n            {% endif %}\n            {% if loop.last and add_generation_prompt %}\n                {{ '### Assistant:\\n' }}\n            {% endif %}\n        {% endfor %}\"#;\n\n        // trim all the whitespace\n        let source = source\n            .lines()\n            .map(|line| line.trim())\n            .collect::<Vec<&str>>()\n            .join(\"\");\n\n        let tmpl = env.template_from_str(&source);\n\n        let chat_template_inputs = ChatTemplateInputs {\n            messages: vec![\n                TextMessage {\n                    role: \"user\".to_string(),\n                    content: \"Hi!\".to_string(),\n                    ..Default::default()\n                },\n                TextMessage {\n                    role: \"assistant\".to_string(),\n                    content: \"Hello how can I help?\".to_string(),\n      
              ..Default::default()\n                },\n                TextMessage {\n                    role: \"user\".to_string(),\n                    content: \"What is Deep Learning?\".to_string(),\n                    ..Default::default()\n                },\n                TextMessage {\n                    role: \"assistant\".to_string(),\n                    content: \"magic!\".to_string(),\n                    ..Default::default()\n                },\n            ],\n            bos_token: Some(\"[BOS]\"),\n            eos_token: Some(\"[EOS]\"),\n            add_generation_prompt: true,\n            ..Default::default()\n        };\n\n        let result = tmpl.unwrap().render(chat_template_inputs).unwrap();\n\n        assert_eq!(\n            result,\n            \"### User:\\nHi!\\n\\n### Assistant:\\nHello how can I help?### User:\\nWhat is Deep Learning?\\n\\n### Assistant:\\nmagic!### Assistant:\\n\"\n        );\n    }\n\n    #[test]\n    fn test_chat_template_with_tool_response() {\n        let env = Environment::new();\n\n        // template modified from Llama-3.1-8B-Instruct\n        // https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/blob/0e9e39f249a16976918f6564b8830bc894c89659/tokenizer_config.json#L2053\n        // the main change is accessing `message.tool_call_id` from the messages\n        let source = r#\"\n        {{- bos_token }}\n        {%- if custom_tools is defined %}\n            {%- set tools = custom_tools %}\n        {%- endif %}\n        {%- if not tools_in_user_message is defined %}\n            {%- set tools_in_user_message = true %}\n        {%- endif %}\n        {%- if not date_string is defined %}\n            {%- set date_string = \"26 Jul 2024\" %}\n        {%- endif %}\n        {%- if not tools is defined %}\n            {%- set tools = none %}\n        {%- endif %}\n\n        {#- This block extracts the system message, so we can slot it into the right place. 
#}\n        {%- if messages[0]['role'] == 'system' %}\n            {%- set system_message = messages[0]['content']|trim %}\n            {%- set messages = messages[1:] %}\n        {%- else %}\n            {%- set system_message = \"\" %}\n        {%- endif %}\n\n        {#- System message + builtin tools #}\n        {{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n        {%- if builtin_tools is defined or tools is not none %}\n            {{- \"Environment: ipython\\n\" }}\n        {%- endif %}\n        {%- if builtin_tools is defined %}\n            {{- \"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}}\n        {%- endif %}\n        {{- \"Cutting Knowledge Date: December 2023\\n\" }}\n        {{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n        {%- if tools is not none and not tools_in_user_message %}\n            {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n            {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n            {{- \"Do not use variables.\\n\\n\" }}\n            {%- for t in tools %}\n                {{- t | tojson(indent=4) }}\n                {{- \"\\n\\n\" }}\n            {%- endfor %}\n        {%- endif %}\n        {{- system_message }}\n        {{- \"<|eot_id|>\" }}\n\n        {#- Custom tools are passed in a user message with some extra guidance #}\n        {%- if tools_in_user_message and not tools is none %}\n            {#- Extract the first user message so we can plug it in here #}\n            {%- if messages | length != 0 %}\n                {%- set first_user_message = messages[0]['content']|trim %}\n                {%- set messages = messages[1:] %}\n            {%- else %}\n                {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n        {%- endif %}\n            {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n            {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n            {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n            {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n            {{- \"Do not use variables.\\n\\n\" }}\n            {%- for t in tools %}\n                {{- t | tojson(indent=4) }}\n                {{- \"\\n\\n\" }}\n            {%- endfor %}\n            {{- first_user_message + \"<|eot_id|>\"}}\n        {%- endif %}\n\n        {%- for message in messages %}\n            {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n                {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n            {%- elif 'tool_calls' in message %}\n                {%- if not message.tool_calls|length == 1 %}\n                    {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n                {%- endif %}\n                {%- set tool_call = message.tool_calls[0].function %}\n                {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n                    {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n                    {{- \"<|python_tag|>\" + tool_call.name + \".call(\" }}\n                    {%- for arg_name, arg_val in tool_call.arguments | items %}\n                        {{- arg_name + '=\"' + arg_val + '\"' }}\n                        {%- if not loop.last %}\n                            {{- \", \" }}\n                        {%- endif %}\n                        {%- endfor %}\n                    {{- \")\" }}\n                {%- else  %}\n                    {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n                    {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n                    {{- '\"parameters\": ' }}\n                    {{- tool_call.arguments | tojson }}\n                    {{- \"}\" }}\n                {%- endif %}\n                {%- if builtin_tools is defined %}\n                    {#- This means we're in ipython mode #}\n                    {{- \"<|eom_id|>\" }}\n                {%- 
else %}\n                    {{- \"<|eot_id|>\" }}\n                {%- endif %}\n            {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n                {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n                    {{- \"TOOL CALL ID: \" + message.tool_call_id + \"\\n\\n\" }}\n                {%- if message.content is mapping or message.content is iterable %}\n                    {{- message.content | tojson }}\n                {%- else %}\n                    {{- message.content }}\n                {%- endif %}\n                {{- \"<|eot_id|>\" }}\n            {%- endif %}\n        {%- endfor %}\n        {%- if add_generation_prompt %}\n            {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n        {%- endif %}\n        \"#;\n\n        // trim all the whitespace\n        let source = source\n            .lines()\n            .map(|line| line.trim())\n            .collect::<Vec<&str>>()\n            .join(\"\");\n\n        let tmpl = env.template_from_str(&source);\n\n        let chat_template_inputs = ChatTemplateInputs {\n            messages: vec![\n                TextMessage {\n                    role: \"user\".to_string(),\n                    content: \"Hi!\".to_string(),\n                    ..Default::default()\n                },\n                TextMessage {\n                    role: \"assistant\".to_string(),\n                    content: r#\"[ { \"id\": \"0\", \"function\": { \"arguments\": '{\"longitude\": 2.2945, \"latitude\": 48.8567}', \"name\": \"get_weather\", \"description\": None, }, \"type\": \"function\", } ]\"#.to_string(),\n                    ..Default::default()\n                },\n                TextMessage {\n                    role: \"tool\".to_string(),\n                    content: \"6.7\".to_string(),\n                    tool_call_id: Some(\"0\".to_string()),\n                },\n            ],\n            bos_token: Some(\"[BOS]\"),\n            eos_token: 
Some(\"[EOS]\"),\n            add_generation_prompt: true,\n            ..Default::default()\n        };\n\n        let result = tmpl.unwrap().render(chat_template_inputs).unwrap();\n\n        assert_eq!(\n            result,\n            r#\"[BOS]<|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHi!<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n[ { \"id\": \"0\", \"function\": { \"arguments\": '{\"longitude\": 2.2945, \"latitude\": 48.8567}', \"name\": \"get_weather\", \"description\": None, }, \"type\": \"function\", } ]<|eot_id|><|start_header_id|>ipython<|end_header_id|>\n\nTOOL CALL ID: 0\n\n\"6.7\"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n\"#\n        );\n    }\n\n    #[test]\n    fn test_chat_template_loop_controls() {\n        // some chat templates, e.g. CohereForAI/c4ai-command-r7b-12-2024, contain `break`\n        // statements in their chat templates, so the feature `loop_controls` has been included\n        // in `minijinja`\n        let env = Environment::new();\n\n        let source = r#\"\n        {% set user_count = 0 %}\n        {% for message in messages %}\n            {% if message['role'] == 'user' %}\n                {{'### User:\\n' + message['content']+'\\n\\n'}}\n                {% set user_count = user_count + 1 %}\n                {% if user_count >= 2 %}\n                    {% break %}\n                {% endif %}\n            {% elif message['role'] == 'assistant' %}\n                {{'### Assistant:\\n'  + message['content']}}\n            {% endif %}\n        {% endfor %}\n        {% if add_generation_prompt %}\n            {{ '### Assistant:\\n' }}\n        {% endif %}\"#;\n\n        // trim all the whitespace\n        let source = source\n            .lines()\n            .map(|line| line.trim())\n            .collect::<Vec<&str>>()\n            .join(\"\");\n\n        let tmpl = 
env.template_from_str(&source);\n\n        let chat_template_inputs = ChatTemplateInputs {\n            messages: vec![\n                TextMessage {\n                    role: \"user\".to_string(),\n                    content: \"Hi!\".to_string(),\n                    ..Default::default()\n                },\n                TextMessage {\n                    role: \"assistant\".to_string(),\n                    content: \"Hello how can I help?\".to_string(),\n                    ..Default::default()\n                },\n                TextMessage {\n                    role: \"user\".to_string(),\n                    content: \"What is Deep Learning?\".to_string(),\n                    ..Default::default()\n                },\n                TextMessage {\n                    role: \"assistant\".to_string(),\n                    content: \"magic!\".to_string(),\n                    ..Default::default()\n                },\n            ],\n            bos_token: Some(\"[BOS]\"),\n            eos_token: Some(\"[EOS]\"),\n            add_generation_prompt: true,\n            ..Default::default()\n        };\n\n        let result = tmpl.unwrap().render(chat_template_inputs).unwrap();\n\n        assert_eq!(\n            result,\n            \"### User:\\nHi!\\n\\n### Assistant:\\nHello how can I help?### User:\\nWhat is Deep Learning?\\n\\n### Assistant:\\n\"\n        );\n    }\n\n    #[test]\n    fn test_chat_template_invalid_with_raise() {\n        let mut env = Environment::new();\n        env.add_function(\"raise_exception\", raise_exception);\n        env.add_function(\"strftime_now\", strftime_now);\n\n        let source = r#\"\n        {{ bos_token }}\n        {% for message in messages %}\n        {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n        {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}\n        {% endif %}\n        {% if message['role'] == 'user' %}\n        {{ '[INST] ' + 
message['content'] + ' [/INST]' }}\n        {% elif message['role'] == 'assistant' %}\n        {{ message['content'] + eos_token}}\n        {% else %}\n        {{ raise_exception('Only user and assistant roles are supported!') }}\n        {% endif %}\n        {% endfor %}\"#;\n\n        // trim all the whitespace\n        let source = source\n            .lines()\n            .map(|line| line.trim())\n            .collect::<Vec<&str>>()\n            .join(\"\");\n\n        let tmpl = env.template_from_str(&source);\n\n        let chat_template_inputs = ChatTemplateInputs {\n            messages: vec![\n                TextMessage {\n                    role: \"user\".to_string(),\n                    content: \"Hi!\".to_string(),\n                    ..Default::default()\n                },\n                TextMessage {\n                    role: \"user\".to_string(),\n                    content: \"Hi again!\".to_string(),\n                    ..Default::default()\n                },\n                TextMessage {\n                    role: \"assistant\".to_string(),\n                    content: \"Hello how can I help?\".to_string(),\n                    ..Default::default()\n                },\n                TextMessage {\n                    role: \"user\".to_string(),\n                    content: \"What is Deep Learning?\".to_string(),\n                    ..Default::default()\n                },\n                TextMessage {\n                    role: \"assistant\".to_string(),\n                    content: \"magic!\".to_string(),\n                    ..Default::default()\n                },\n            ],\n            bos_token: Some(\"[BOS]\"),\n            eos_token: Some(\"[EOS]\"),\n            add_generation_prompt: true,\n            ..Default::default()\n        };\n\n        let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap();\n\n        match result {\n            Ok(_) => panic!(\"Should have failed\"),\n            
Err(e) => {\n                assert_eq!(\n                    e.detail().unwrap(),\n                    \"Conversation roles must alternate user/assistant/user/assistant/...\"\n                );\n            }\n        }\n    }\n\n    #[test]\n    fn test_chat_template_valid_with_raise() {\n        let mut env = Environment::new();\n        env.add_function(\"raise_exception\", raise_exception);\n        env.add_function(\"strftime_now\", strftime_now);\n\n        let source = r#\"\n        {{ bos_token }}\n        {% for message in messages %}\n        {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n        {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}\n        {% endif %}\n        {% if message['role'] == 'user' %}\n        {{ '[INST] ' + message['content'] + ' [/INST]' }}\n        {% elif message['role'] == 'assistant' %}\n        {{ message['content'] + eos_token}}\n        {% else %}\n        {{ raise_exception('Only user and assistant roles are supported!') }}\n        {% endif %}\n        {% endfor %}\"#;\n\n        // trim all the whitespace\n        let source = source\n            .lines()\n            .map(|line| line.trim())\n            .collect::<Vec<&str>>()\n            .join(\"\");\n\n        let tmpl = env.template_from_str(&source);\n\n        let chat_template_inputs = ChatTemplateInputs {\n            messages: vec![\n                TextMessage {\n                    role: \"user\".to_string(),\n                    content: \"Hi!\".to_string(),\n                    ..Default::default()\n                },\n                TextMessage {\n                    role: \"assistant\".to_string(),\n                    content: \"Hello how can I help?\".to_string(),\n                    ..Default::default()\n                },\n                TextMessage {\n                    role: \"user\".to_string(),\n                    content: \"What is Deep Learning?\".to_string(),\n                
    ..Default::default()\n                },\n                TextMessage {\n                    role: \"assistant\".to_string(),\n                    content: \"magic!\".to_string(),\n                    ..Default::default()\n                },\n            ],\n            bos_token: Some(\"[BOS]\"),\n            eos_token: Some(\"[EOS]\"),\n            add_generation_prompt: true,\n            ..Default::default()\n        };\n\n        let result = tmpl.unwrap().render(chat_template_inputs).unwrap();\n        assert_eq!(result, \"[BOS][INST] Hi! [/INST]Hello how can I help?[EOS][INST] What is Deep Learning? [/INST]magic![EOS]\");\n    }\n\n    #[test]\n    fn test_chat_template_valid_with_strftime_now() {\n        let mut env = Environment::new();\n        env.add_function(\"raise_exception\", raise_exception);\n        env.add_function(\"strftime_now\", strftime_now);\n\n        let source = r#\"\n        {% set today = strftime_now(\"%Y-%m-%d\") %}\n        {% set default_system_message = \"The current date is \" + today + \".\" %}\n        {{ bos_token }}\n        {% if messages[0]['role'] == 'system' %}\n            {%- set system_message = messages[0]['content'] %}\n            {%- set loop_messages = messages[1:] %}\n        {% else %}\n            {%- set system_message = default_system_message %}\n            {%- set loop_messages = messages %}\n        {% endif %}\n        {{ '[SYSTEM_PROMPT]' + system_message + '[/SYSTEM_PROMPT]' }}\n        {% for message in loop_messages %}\n            {% if message['role'] == 'user' %}\n                {{ '[INST]' + message['content'] + '[/INST]' }}\n            {% elif message['role'] == 'assistant' %}\n                {{ message['content'] + eos_token }}\n            {% else %}\n                {{ raise_exception('Only user and assistant roles are supported!') }}\n            {% endif %}\n        {% endfor %}\n        \"#;\n\n        // trim all the whitespace\n        let source = source\n            .lines()\n    
        .map(|line| line.trim())\n            .collect::<Vec<&str>>()\n            .join(\"\");\n\n        let tmpl = env.template_from_str(&source);\n\n        let chat_template_inputs = ChatTemplateInputs {\n            messages: vec![\n                TextMessage {\n                    role: \"user\".to_string(),\n                    content: \"Hi!\".to_string(),\n                    ..Default::default()\n                },\n                TextMessage {\n                    role: \"assistant\".to_string(),\n                    content: \"Hello how can I help?\".to_string(),\n                    ..Default::default()\n                },\n                TextMessage {\n                    role: \"user\".to_string(),\n                    content: \"What is Deep Learning?\".to_string(),\n                    ..Default::default()\n                },\n                TextMessage {\n                    role: \"assistant\".to_string(),\n                    content: \"magic!\".to_string(),\n                    ..Default::default()\n                },\n            ],\n            bos_token: Some(\"[BOS]\"),\n            eos_token: Some(\"[EOS]\"),\n            add_generation_prompt: true,\n            ..Default::default()\n        };\n\n        let current_date = Local::now().format(\"%Y-%m-%d\").to_string();\n        let result = tmpl.unwrap().render(chat_template_inputs).unwrap();\n        assert_eq!(result, format!(\"[BOS][SYSTEM_PROMPT]The current date is {}.[/SYSTEM_PROMPT][INST]Hi![/INST]Hello how can I help?[EOS][INST]What is Deep Learning?[/INST]magic![EOS]\", current_date));\n    }\n\n    #[test]\n    fn test_chat_template_valid_with_add_generation_prompt() {\n        let mut env = Environment::new();\n        env.add_function(\"raise_exception\", raise_exception);\n        env.add_function(\"strftime_now\", strftime_now);\n\n        let source = r#\"\n        {% for message in messages %}\n        {{'<|im_start|>' + message['role'] + '\\n' + message['content'] + 
'<|im_end|>' + '\\n'}}\n        {% endfor %}\n        {% if add_generation_prompt %}\n            {{ '<|im_start|>assistant\\n' }}\n        {% endif %}\"#;\n\n        // trim all the whitespace\n        let source = source\n            .lines()\n            .map(|line| line.trim())\n            .collect::<Vec<&str>>()\n            .join(\"\");\n\n        let tmpl = env.template_from_str(&source);\n\n        let chat_template_inputs = ChatTemplateInputs {\n            messages: vec![\n                TextMessage {\n                    role: \"user\".to_string(),\n                    content: \"Hi!\".to_string(),\n                    ..Default::default()\n                },\n                TextMessage {\n                    role: \"assistant\".to_string(),\n                    content: \"Hello how can I help?\".to_string(),\n                    ..Default::default()\n                },\n                TextMessage {\n                    role: \"user\".to_string(),\n                    content: \"What is Deep Learning?\".to_string(),\n                    ..Default::default()\n                },\n                TextMessage {\n                    role: \"assistant\".to_string(),\n                    content: \"magic!\".to_string(),\n                    ..Default::default()\n                },\n            ],\n            bos_token: Some(\"[BOS]\"),\n            eos_token: Some(\"[EOS]\"),\n            add_generation_prompt: true,\n            ..Default::default()\n        };\n\n        let result = tmpl.unwrap().render(chat_template_inputs).unwrap();\n        assert_eq!(result, \"<|im_start|>user\\nHi!<|im_end|>\\n<|im_start|>assistant\\nHello how can I help?<|im_end|>\\n<|im_start|>user\\nWhat is Deep Learning?<|im_end|>\\n<|im_start|>assistant\\nmagic!<|im_end|>\\n<|im_start|>assistant\\n\");\n    }\n\n    struct ChatTemplateTestItem {\n        name: &'static str,\n        chat_template: &'static str,\n        input: ChatTemplateInputs<'static>,\n        target: 
&'static str,\n    }\n\n    #[test]\n    fn test_many_chat_templates() {\n        let example_chat = vec![\n            TextMessage {\n                role: \"user\".to_string(),\n                content: \"Hello, how are you?\".to_string(),\n                ..Default::default()\n            },\n            TextMessage {\n                role: \"assistant\".to_string(),\n                content: \"I'm doing great. How can I help you today?\".to_string(),\n                ..Default::default()\n            },\n            TextMessage {\n                role: \"user\".to_string(),\n                content: \"I'd like to show off how chat templating works!\".to_string(),\n                ..Default::default()\n            },\n        ];\n\n        let example_chat_with_system = [TextMessage {\n            role: \"system\".to_string(),\n            content: \"You are a friendly chatbot who always responds in the style of a pirate\"\n                .to_string(),\n            ..Default::default()\n        }]\n        .iter()\n        .chain(&example_chat)\n        .cloned()\n        .collect::<Vec<_>>();\n\n        let test_default_templates = vec![\n            ChatTemplateTestItem {\n                name: \"_base\",\n                chat_template: \"{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\\\n' + message['content'] + '<|im_end|>' + '\\\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\\\n' }}{% endif %}\",\n                input: ChatTemplateInputs {\n                    messages: example_chat.clone(),\n                    add_generation_prompt: false,\n                    bos_token: Some(\"\"),\n                    eos_token: Some(\"\"),\n                    ..Default::default()\n                },\n                target: \"<|im_start|>user\\nHello, how are you?<|im_end|>\\n<|im_start|>assistant\\nI'm doing great. 
How can I help you today?<|im_end|>\\n<|im_start|>user\\nI'd like to show off how chat templating works!<|im_end|>\\n\",\n            },\n            ChatTemplateTestItem {\n                name: \"blenderbot\",\n                chat_template: \"{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ '  ' }}{% endif %}{% endfor %}{{ eos_token }}\",\n                input: ChatTemplateInputs {\n                    messages: example_chat.clone(),\n                    add_generation_prompt: false,\n                    bos_token: Some(\"\"),\n                    eos_token: Some(\"</s>\"),\n                    ..Default::default()\n                },\n                target: \" Hello, how are you?  I'm doing great. How can I help you today?   I'd like to show off how chat templating works!</s>\",\n            },\n            ChatTemplateTestItem {\n                name: \"blenderbot_small\",\n                chat_template: \"{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ '  ' }}{% endif %}{% endfor %}{{ eos_token }}\",\n                input: ChatTemplateInputs {\n                    messages: example_chat.clone(),\n                    add_generation_prompt: false,\n                    bos_token: Some(\"\"),\n                    eos_token: Some(\"</s>\"),\n                    ..Default::default()\n                },\n                target: \" Hello, how are you?  I'm doing great. How can I help you today?   
I'd like to show off how chat templating works!</s>\",\n            },\n            ChatTemplateTestItem {\n                name: \"bloom\",\n                chat_template: \"{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}\",\n                input: ChatTemplateInputs {\n                    messages: example_chat.clone(),\n                    add_generation_prompt: false,\n                    bos_token: Some(\"\"),\n                    eos_token: Some(\"</s>\"),\n                    ..Default::default()\n                },\n                target: \"Hello, how are you?</s>I'm doing great. How can I help you today?</s>I'd like to show off how chat templating works!</s>\",\n            },\n            ChatTemplateTestItem {\n                name: \"gpt_neox\",\n                chat_template: \"{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}\",\n                input: ChatTemplateInputs {\n                    messages: example_chat.clone(),\n                    add_generation_prompt: false,\n                    bos_token: Some(\"\"),\n                    eos_token: Some(\"<|endoftext|>\"),\n                    ..Default::default()\n                },\n                target: \"Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>\",\n            },\n            ChatTemplateTestItem {\n                name: \"gpt2\",\n                chat_template: \"{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}\",\n                input: ChatTemplateInputs {\n                    messages: example_chat.clone(),\n                    add_generation_prompt: false,\n                    bos_token: Some(\"\"),\n                    eos_token: Some(\"<|endoftext|>\"),\n                    ..Default::default()\n                },\n                target: \"Hello, how are you?<|endoftext|>I'm doing great. 
How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>\",\n            },\n            ChatTemplateTestItem {\n                name: \"llama\",\n                // NOTE: the `.strip()` has been replaced with `| trim` in the following template\n                chat_template: \"{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\\\n' + system_message + '\\\\n<</SYS>>\\\\n\\\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token +'[INST] ' + content | trim + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\\\n' + content | trim + '\\\\n<</SYS>>\\\\n\\\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content | trim + ' ' + eos_token }}{% endif %}{% endfor %}\",\n                input: ChatTemplateInputs {\n                    messages: example_chat_with_system.clone(),\n                    add_generation_prompt: true,\n                    bos_token: Some(\"<s>\"),\n                    eos_token: Some(\"</s>\"),\n                    ..Default::default()\n                },\n                target: \"<s>[INST] <<SYS>>\\nYou are a friendly chatbot who always responds in the style of a pirate\\n<</SYS>>\\n\\nHello, how are you? [/INST] I'm doing great. How can I help you today? 
</s><s>[INST] I'd like to show off how chat templating works! [/INST]\",\n            },\n            ChatTemplateTestItem {\n                name: \"whisper\",\n                chat_template: \"{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}\",\n                input: ChatTemplateInputs {\n                    messages: example_chat.clone(),\n                    add_generation_prompt: true,\n                    bos_token: Some(\"\"),\n                    eos_token: Some(\"<|endoftext|>\"),\n                    ..Default::default()\n                },\n                target: \"Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>\",\n            },\n        ];\n\n        #[allow(unused_variables)] // name is unused\n        for ChatTemplateTestItem {\n            name,\n            chat_template,\n            input,\n            target,\n        } in test_default_templates\n        {\n            let mut env = Environment::new();\n            env.add_function(\"raise_exception\", raise_exception);\n            env.add_function(\"strftime_now\", strftime_now);\n            let tmpl = env.template_from_str(chat_template);\n            let result = tmpl.unwrap().render(input).unwrap();\n            assert_eq!(result, target);\n        }\n\n        let test_custom_templates = vec![\n            ChatTemplateTestItem {\n                name: \"HuggingFaceH4/zephyr-7b-beta (add_generation_prompt=false)\",\n                chat_template: \"{% for message in messages %}\\n{% if message['role'] == 'user' %}\\n{{ '<|user|>\\\\n' + message['content'] + eos_token }}\\n{% elif message['role'] == 'system' %}\\n{{ '<|system|>\\\\n' + message['content'] + eos_token }}\\n{% elif message['role'] == 'assistant' %}\\n{{ '<|assistant|>\\\\n'  + message['content'] + eos_token }}\\n{% endif %}\\n{% if loop.last and add_generation_prompt %}\\n{{ '<|assistant|>' 
}}\\n{% endif %}\\n{% endfor %}\",\n                input: ChatTemplateInputs {\n                    messages: example_chat_with_system.clone(),\n                    add_generation_prompt: false,\n                    bos_token: Some(\"\"),\n                    eos_token: Some(\"</s>\"),\n                    ..Default::default()\n                },\n                target: \"<|system|>\\nYou are a friendly chatbot who always responds in the style of a pirate</s><|user|>\\nHello, how are you?</s><|assistant|>\\nI'm doing great. How can I help you today?</s><|user|>\\nI'd like to show off how chat templating works!</s>\",\n            },\n            ChatTemplateTestItem {\n                name: \"HuggingFaceH4/zephyr-7b-beta (add_generation_prompt=true)\",\n                chat_template: \"{% for message in messages %}\\n{% if message['role'] == 'user' %}\\n{{ '<|user|>\\\\n' + message['content'] + eos_token }}\\n{% elif message['role'] == 'system' %}\\n{{ '<|system|>\\\\n' + message['content'] + eos_token }}\\n{% elif message['role'] == 'assistant' %}\\n{{ '<|assistant|>\\\\n'  + message['content'] + eos_token }}\\n{% endif %}\\n{% if loop.last and add_generation_prompt %}\\n{{ '<|assistant|>' }}\\n{% endif %}\\n{% endfor %}\",\n                input: ChatTemplateInputs {\n                    messages: vec![\n                        TextMessage {\n                            role: \"system\".to_string(),\n                            content: \"You are a friendly chatbot who always responds in the style of a pirate\".to_string(),\n                            ..Default::default()\n                        },\n                        TextMessage {\n                            role: \"user\".to_string(),\n                            content: \"How many helicopters can a human eat in one sitting?\".to_string(),\n                            ..Default::default()\n                        },\n                    ],\n                    add_generation_prompt: true,\n           
         bos_token: Some(\"\"),\n                    eos_token: Some(\"</s>\"),\n                    ..Default::default()\n                },\n                target: \"<|system|>\\nYou are a friendly chatbot who always responds in the style of a pirate</s><|user|>\\nHow many helicopters can a human eat in one sitting?</s><|assistant|>\",\n            },\n            ChatTemplateTestItem {\n                name: \"HuggingFaceH4/zephyr-7b-gemma-v0.1\",\n                chat_template: \"{% if messages[0]['role'] == 'user' or messages[0]['role'] == 'system' %}{{ bos_token }}{% endif %}{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\\\\n' + message['content'] + '<|im_end|>' + '\\\\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% elif messages[-1]['role'] == 'assistant' %}{{ eos_token }}{% endif %}\",\n                input: ChatTemplateInputs {\n                    messages: example_chat.clone(),\n                    add_generation_prompt: false,\n                    bos_token: Some(\"<bos>\"),\n                    eos_token: Some(\"<eos>\"),\n                    ..Default::default()\n                },\n                target: \"<bos><|im_start|>user\\nHello, how are you?<|im_end|>\\n<|im_start|>assistant\\nI'm doing great. 
How can I help you today?<|im_end|>\\n<|im_start|>user\\nI'd like to show off how chat templating works!<|im_end|>\\n\",\n            },\n            ChatTemplateTestItem {\n                name: \"mistralai/Mistral-7B-Instruct-v0.1\",\n                chat_template: \"{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}\",\n                input: ChatTemplateInputs {\n                    messages: example_chat.clone(),\n                    add_generation_prompt: false,\n                    bos_token: Some(\"<s>\"),\n                    eos_token: Some(\"</s>\"),\n                    ..Default::default()\n                },\n                target: \"<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s> [INST] I'd like to show off how chat templating works! 
[/INST]\",\n            },\n            ChatTemplateTestItem {\n                name: \"mistralai/Mixtral-8x7B-Instruct-v0.1\",\n                chat_template: \"{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}\",\n                input: ChatTemplateInputs {\n                    messages: example_chat.clone(),\n                    add_generation_prompt: false,\n                    bos_token: Some(\"<s>\"),\n                    eos_token: Some(\"</s>\"),\n                    ..Default::default()\n                },\n                target: \"<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s>[INST] I'd like to show off how chat templating works! 
[/INST]\",\n            },\n            ChatTemplateTestItem {\n                name: \"cognitivecomputations/dolphin-2.5-mixtral-8x7b\",\n                chat_template: \"{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\\\n' + message['content'] + '<|im_end|>' + '\\\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\\\n' }}{% endif %}\",\n                input: ChatTemplateInputs {\n                    messages: example_chat.clone(),\n                    add_generation_prompt: false,\n                    bos_token: Some(\"<s>\"),\n                    eos_token: Some(\"</s>\"),\n                    ..Default::default()\n                },\n                target: \"<|im_start|>user\\nHello, how are you?<|im_end|>\\n<|im_start|>assistant\\nI'm doing great. How can I help you today?<|im_end|>\\n<|im_start|>user\\nI'd like to show off how chat templating works!<|im_end|>\\n\",\n            },\n            ChatTemplateTestItem {\n                name: \"openchat/openchat-3.5-0106\",\n                // `.title()` has been replaced with `| title` in the following template\n                chat_template: \"{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + (message['role'] | title) + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}\",\n                input: ChatTemplateInputs {\n                    messages: example_chat.clone(),\n                    add_generation_prompt: false,\n                    bos_token: Some(\"<s>\"),\n                    eos_token: Some(\"</s>\"),\n                    ..Default::default()\n                },\n                target: \"<s>GPT4 Correct User: Hello, how are you?<|end_of_turn|>GPT4 Correct Assistant: I'm doing great. 
How can I help you today?<|end_of_turn|>GPT4 Correct User: I'd like to show off how chat templating works!<|end_of_turn|>\",\n            },\n            ChatTemplateTestItem {\n                name: \"upstage/SOLAR-10.7B-Instruct-v1.0\",\n                chat_template: \"{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}\",\n                input: ChatTemplateInputs {\n                    messages: example_chat.clone(),\n                    add_generation_prompt: false,\n                    bos_token: Some(\"<s>\"),\n                    eos_token: Some(\"</s>\"),\n                    ..Default::default()\n                },\n                target: \"Hello, how are you?</s>I'm doing great. How can I help you today?</s>I'd like to show off how chat templating works!</s>\",\n            },\n            ChatTemplateTestItem {\n                name: \"codellama/CodeLlama-70b-Instruct-hf\",\n                // NOTE: `.strip()` has been replaced with `| trim` in the following template\n                chat_template: \"{% if messages[0]['role'] == 'system' %}{% set user_index = 1 %}{% else %}{% set user_index = 0 %}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != ((loop.index0 + user_index) % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 %}{{ '<s>' }}{% endif %}{% set content = 'Source: ' + message['role'] + '\\\\n\\\\n ' + message['content'] | trim %}{{ content + ' <step> ' }}{% endfor %}{{'Source: assistant\\\\nDestination: user\\\\n\\\\n '}}\",\n                input: ChatTemplateInputs {\n                    messages: example_chat.clone(),\n                    add_generation_prompt: false,\n                    bos_token: Some(\"<s>\"),\n                    eos_token: Some(\"</s>\"),\n                    ..Default::default()\n                },\n                target: \"<s>Source: user\\n\\n Hello, how are you? 
<step> Source: assistant\\n\\n I'm doing great. How can I help you today? <step> Source: user\\n\\n I'd like to show off how chat templating works! <step> Source: assistant\\nDestination: user\\n\\n \",\n            },\n            ChatTemplateTestItem {\n                name: \"Deci/DeciLM-7B-instruct\",\n                chat_template: \"{% for message in messages %}\\n{% if message['role'] == 'user' %}\\n{{ '### User:\\\\n' + message['content'] }}\\n{% elif message['role'] == 'system' %}\\n{{ '### System:\\\\n' + message['content'] }}\\n{% elif message['role'] == 'assistant' %}\\n{{ '### Assistant:\\\\n'  + message['content'] }}\\n{% endif %}\\n{% if loop.last and add_generation_prompt %}\\n{{ '### Assistant:' }}\\n{% endif %}\\n{% endfor %}\",\n                input: ChatTemplateInputs {\n                    messages: example_chat.clone(),\n                    add_generation_prompt: false,\n                    bos_token: Some(\"<s>\"),\n                    eos_token: Some(\"</s>\"),\n                    ..Default::default()\n                },\n                target: \"### User:\\nHello, how are you?### Assistant:\\nI'm doing great. 
How can I help you today?### User:\\nI'd like to show off how chat templating works!\",\n            },\n            ChatTemplateTestItem {\n                name: \"Qwen/Qwen1.5-72B-Chat\",\n                chat_template: \"{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\\\\nYou are a helpful assistant<|im_end|>\\\\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\\\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\\\n'}}{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\\\n' }}{% endif %}\",\n                input: ChatTemplateInputs {\n                    messages: example_chat.clone(),\n                    add_generation_prompt: false,\n                    bos_token: Some(\"<s>\"),\n                    eos_token: Some(\"</s>\"),\n                    ..Default::default()\n                },\n                target: \"<|im_start|>system\\nYou are a helpful assistant<|im_end|>\\n<|im_start|>user\\nHello, how are you?<|im_end|>\\n<|im_start|>assistant\\nI'm doing great. 
How can I help you today?<|im_end|>\\n<|im_start|>user\\nI'd like to show off how chat templating works!\",\n            },\n            ChatTemplateTestItem {\n                name: \"deepseek-ai/deepseek-llm-7b-chat\",\n                chat_template: \"{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\\\\n\\\\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\\\\n\\\\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}\",\n                input: ChatTemplateInputs {\n                    messages: example_chat.clone(),\n                    add_generation_prompt: false,\n                    bos_token: Some(\"<｜begin▁of▁sentence｜>\"),\n                    eos_token: Some(\"<｜end▁of▁sentence｜>\"),\n                    ..Default::default()\n                },\n                target: \"<｜begin▁of▁sentence｜>User: Hello, how are you?\\n\\nAssistant: I'm doing great. 
How can I help you today?<｜end▁of▁sentence｜>User: I'd like to show off how chat templating works!\\n\\n\",\n            },\n            ChatTemplateTestItem {\n                name: \"h2oai/h2o-danube-1.8b-chat\",\n                chat_template: \"{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|prompt|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ '<|system|>' + message['content'] + eos_token }}{% elif message['role'] == 'assistant' %}{{ '<|answer|>'  + message['content'] + eos_token }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '<|answer|>' }}{% endif %}{% endfor %}\",\n                input: ChatTemplateInputs {\n                    messages: example_chat.clone(),\n                    add_generation_prompt: false,\n                    bos_token: Some(\"<s>\"),\n                    eos_token: Some(\"</s>\"),\n                    ..Default::default()\n                },\n                target: \"<|prompt|>Hello, how are you?</s><|answer|>I'm doing great. 
How can I help you today?</s><|prompt|>I'd like to show off how chat templating works!</s>\",\n            },\n            ChatTemplateTestItem {\n                name: \"internlm/internlm2-chat-7b\",\n                chat_template: \"{% if messages[0]['role'] == 'user' or messages[0]['role'] == 'system' %}{{ bos_token }}{% endif %}{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\\\\n' + message['content'] + '<|im_end|>' + '\\\\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\\\n' }}{% elif messages[-1]['role'] == 'assistant' %}{{ eos_token }}{% endif %}\",\n                input: ChatTemplateInputs {\n                    messages: example_chat.clone(),\n                    add_generation_prompt: false,\n                    bos_token: Some(\"<s>\"),\n                    eos_token: Some(\"</s>\"),\n                    ..Default::default()\n                },\n                target: \"<s><|im_start|>user\\nHello, how are you?<|im_end|>\\n<|im_start|>assistant\\nI'm doing great. How can I help you today?<|im_end|>\\n<|im_start|>user\\nI'd like to show off how chat templating works!<|im_end|>\\n\",\n            },\n            ChatTemplateTestItem {\n                name: \"TheBloke/deepseek-coder-33B-instruct-AWQ\",\n                chat_template: \"{%- set found_item = false -%}\\n{%- for message in messages -%}\\n    {%- if message['role'] == 'system' -%}\\n        {%- set found_item = true -%}\\n    {%- endif -%}\\n{%- endfor -%}\\n{%- if not found_item -%}\\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. 
For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\\\\n'}}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if message['role'] == 'system' %}\\n{{ message['content'] }}\\n    {%- else %}\\n        {%- if message['role'] == 'user' %}\\n{{'### Instruction:\\\\n' + message['content'] + '\\\\n'}}\\n        {%- else %}\\n{{'### Response:\\\\n' + message['content'] + '\\\\n<|EOT|>\\\\n'}}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{{'### Response:\\\\n'}}\\n\",\n                input: ChatTemplateInputs {\n                    messages: example_chat.clone(),\n                    add_generation_prompt: false,\n                    bos_token: Some(\"<｜begin▁of▁sentence｜>\"),\n                    eos_token: Some(\"<|EOT|>\"),\n                    ..Default::default()\n                },\n                target: \"You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\\n### Instruction:\\nHello, how are you?\\n### Response:\\nI'm doing great. 
How can I help you today?\\n<|EOT|>\\n### Instruction:\\nI'd like to show off how chat templating works!\\n### Response:\\n\",\n            },\n            ChatTemplateTestItem {\n                name: \"ericzzz/falcon-rw-1b-chat\",\n                // `.strip()` has been replaced with `| trim` in the following template\n                chat_template: \"{% for message in messages %}{% if loop.index > 1 and loop.previtem['role'] != 'assistant' %}{{ ' ' }}{% endif %}{% if message['role'] == 'system' %}{{ '[SYS] ' + message['content'] | trim }}{% elif message['role'] == 'user' %}{{ '[INST] ' + message['content'] | trim }}{% elif message['role'] == 'assistant' %}{{ '[RESP] '  + message['content'] + eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ ' [RESP] ' }}{% endif %}\",\n                input: ChatTemplateInputs {\n                    messages: example_chat.clone(),\n                    add_generation_prompt: false,\n                    bos_token: Some(\"<|endoftext|>\"),\n                    eos_token: Some(\"<|endoftext|>\"),\n                    ..Default::default()\n                },\n                target: \"[INST] Hello, how are you? [RESP] I'm doing great. 
How can I help you today?<|endoftext|>[INST] I'd like to show off how chat templating works!\",\n            },\n            ChatTemplateTestItem {\n                name: \"abacusai/Smaug-34B-v0.1\",\n                chat_template: \"{%- for idx in range(0, messages|length) -%}\\n{%- if messages[idx]['role'] == 'user' -%}\\n{%- if idx > 1 -%}\\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\\n{%- else -%}\\n{{- messages[idx]['content'] + ' [/INST]' -}}\\n{%- endif -%}\\n{% elif messages[idx]['role'] == 'system' %}\\n{{- '[INST] <<SYS>>\\\\n' + messages[idx]['content'] + '\\\\n<</SYS>>\\\\n\\\\n' -}}\\n{%- elif messages[idx]['role'] == 'assistant' -%}\\n{{- ' '  + messages[idx]['content'] + ' ' + eos_token -}}\\n{% endif %}\\n{% endfor %}\",\n                input: ChatTemplateInputs {\n                    messages: example_chat.clone(),\n                    add_generation_prompt: false,\n                    bos_token: Some(\"<s>\"),\n                    eos_token: Some(\"</s>\"),\n                    ..Default::default()\n                },\n                target: \"Hello, how are you? [/INST] I'm doing great. How can I help you today? </s><s>[INST] I'd like to show off how chat templating works! [/INST]\",\n            },\n            ChatTemplateTestItem {\n                name: \"maywell/Synatra-Mixtral-8x7B\",\n                chat_template: \"Below is an instruction that describes a task. 
Write a response that appropriately completes the request.\\n\\n{% for message in messages %}{% if message['role'] == 'user' %}### Instruction:\\n{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\\n{% elif message['role'] == 'assistant' %}### Response:\\n{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\\n{% elif message['role'] == 'system' %}{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\\n{% endif %}\\n{% endfor %}\\n{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}\\n### Response:\\n{% endif %}\",\n                input: ChatTemplateInputs {\n                    messages: example_chat.clone(),\n                    add_generation_prompt: false,\n                    bos_token: Some(\"<s>\"),\n                    eos_token: Some(\"</s>\"),\n                    ..Default::default()\n                },\n                target: \"Below is an instruction that describes a task. Write a response that appropriately completes the request.### Instruction:Hello, how are you?### Response:I'm doing great. How can I help you today?### Instruction:I'd like to show off how chat templating works!\",\n            },\n            ChatTemplateTestItem {\n                name: \"deepseek-ai/deepseek-coder-33b-instruct\",\n                chat_template: \"{% if not add_generation_prompt is defined %}\\n{% set add_generation_prompt = false %}\\n{% endif %}\\n{%- set ns = namespace(found=false) -%}\\n{%- for message in messages -%}\\n    {%- if message['role'] == 'system' -%}\\n        {%- set ns.found = true -%}\\n    {%- endif -%}\\n{%- endfor -%}\\n{{bos_token}}{%- if not ns.found -%}\\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. 
For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\\\n'}}\\n{%- endif %}\\n{%- for message in messages %}\\n    {%- if message['role'] == 'system' %}\\n{{ message['content'] }}\\n    {%- else %}\\n        {%- if message['role'] == 'user' %}\\n{{'### Instruction:\\\\n' + message['content'] + '\\\\n'}}\\n        {%- else %}\\n{{'### Response:\\\\n' + message['content'] + '\\\\n<|EOT|>\\\\n'}}\\n        {%- endif %}\\n    {%- endif %}\\n{%- endfor %}\\n{% if add_generation_prompt %}\\n{{'### Response:'}}\\n{% endif %}\",\n                input: ChatTemplateInputs {\n                    messages: example_chat.clone(),\n                    add_generation_prompt: false,\n                    bos_token: Some(\"<｜begin▁of▁sentence｜>\"),\n                    eos_token: Some(\"</EOT>\"),\n                    ..Default::default()\n                },\n                target: \"<｜begin▁of▁sentence｜>You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n### Instruction:\\nHello, how are you?\\n### Response:\\nI'm doing great. 
How can I help you today?\\n<|EOT|>\\n### Instruction:\\nI'd like to show off how chat templating works!\\n\",\n            },\n            // NOT INCLUDED\n            // - meetkai/functionary-medium-v3.2\n            // - fireworks-ai/firefunction-v1\n            // https://github\n            ChatTemplateTestItem {\n                name: \"maywell/PiVoT-MoE\",\n                chat_template: \"{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}{% for message in messages %}{% if message['role'] == 'system' %}{{ message['content']|trim }}{% elif message['role'] == 'user' %}### Instruction: {{ message['content']|trim }}{% elif message['role'] == 'assistant' %}### Response: {{ message['content']|trim }}{% elif message['role'] == 'user_context' %}### Input: {{ message['content']|trim }}{% endif %}{% if not loop.last %}\\n{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}### Response:{% endif %}\",\n                input: ChatTemplateInputs {\n                    messages: example_chat_with_system.clone(),\n                    add_generation_prompt: false,\n                    bos_token: Some(\"<s>\"),\n                    eos_token: Some(\"</s>\"),\n                    ..Default::default()\n                },\n                target: \"You are a friendly chatbot who always responds in the style of a pirateYou are a friendly chatbot who always responds in the style of a pirate### Instruction: Hello, how are you?### Response: I'm doing great. 
How can I help you today?### Instruction: I'd like to show off how chat templating works!\",\n            },\n        ];\n\n        #[allow(unused_variables)] // name is unused\n        for ChatTemplateTestItem {\n            name,\n            chat_template,\n            input,\n            target,\n        } in test_custom_templates\n        {\n            let mut env = Environment::new();\n            env.add_function(\"raise_exception\", raise_exception);\n            env.add_function(\"strftime_now\", strftime_now);\n            // trim all the whitespace\n            let chat_template = chat_template\n                .lines()\n                .map(|line| line.trim())\n                .collect::<Vec<&str>>()\n                .join(\"\");\n\n            let tmpl = env.template_from_str(&chat_template);\n            let result = tmpl.unwrap().render(input).unwrap();\n            assert_eq!(result, target);\n        }\n    }\n\n    #[test]\n    fn test_chat_template_with_default_tool_template() {\n        let ct = ChatTemplate::new(\n            \"{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}\".to_string(),\n            Some(TokenizerConfigToken::String(\"<s>\".to_string())),\n            Some(TokenizerConfigToken::String(\"</s>\".to_string())),\n        );\n\n        // convert TextMessage to Message\n        let msgs: Vec<Message> = vec![\n            Message {\n                name: None,\n                role: \"user\".to_string(),\n                body: MessageBody::Content {\n                    content: MessageContent::SingleText(\n 
                       \"I'd like to show off how chat templating works!\".to_string(),\n                    ),\n                },\n            },\n            Message {\n                name: None,\n                role: \"assistant\".to_string(),\n                body: MessageBody::Content {\n                    content: MessageContent::SingleText(\n                        \"Great! How can I help you today?\".to_string(),\n                    ),\n                },\n            },\n            Message {\n                name: None,\n                role: \"user\".to_string(),\n                body: MessageBody::Content {\n                    content: MessageContent::SingleText(\"Just testing\".to_string()),\n                },\n            },\n        ];\n        let tools_string = r#\"[{\"type\": \"function\",\"function\": {\"name\": \"get_current_weather\",\"description\": \"Get the current weather\",\"parameters\": {\"type\": \"object\",\"properties\": {\"location\": {\"type\": \"string\",\"description\": \"The city and state, e.g. San Francisco, CA\"},\"format\": {\"type\": \"string\",\"enum\": [\"celsius\", \"fahrenheit\"],\"description\": \"The temperature unit to use. Infer this from the users location.\"}},\"required\": [\"location\", \"format\"]}}}]\"#.to_string();\n        let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();\n        let tool_prompt = \"This default prompt will be used\".to_string();\n        let tools_and_prompt = Some((tools, tool_prompt));\n        let result = ct.apply(msgs, tools_and_prompt);\n        let expected = \"<s>[INST] I'd like to show off how chat templating works! [/INST]Great! 
How can I help you today?</s> [INST] Just testing\\n---\\n[{\\\"type\\\":\\\"function\\\",\\\"function\\\":{\\\"description\\\":\\\"Get the current weather\\\",\\\"name\\\":\\\"get_current_weather\\\",\\\"arguments\\\":\\\"{\\\\\\\"type\\\\\\\":\\\\\\\"object\\\\\\\",\\\\\\\"properties\\\\\\\":{\\\\\\\"location\\\\\\\":{\\\\\\\"type\\\\\\\":\\\\\\\"string\\\\\\\",\\\\\\\"description\\\\\\\":\\\\\\\"The city and state, e.g. San Francisco, CA\\\\\\\"},\\\\\\\"format\\\\\\\":{\\\\\\\"type\\\\\\\":\\\\\\\"string\\\\\\\",\\\\\\\"enum\\\\\\\":[\\\\\\\"celsius\\\\\\\",\\\\\\\"fahrenheit\\\\\\\"],\\\\\\\"description\\\\\\\":\\\\\\\"The temperature unit to use. Infer this from the users location.\\\\\\\"}},\\\\\\\"required\\\\\\\":[\\\\\\\"location\\\\\\\",\\\\\\\"format\\\\\\\"]}\\\"}}]\\nThis default prompt will be used [/INST]\".to_string();\n        assert_eq!(result.unwrap(), expected);\n    }\n\n    #[test]\n    fn test_chat_template_with_custom_tool_template() {\n        // chat template from meta-llama/Meta-Llama-3.1-8B-Instruct\n        let ct = ChatTemplate::new(\n            \"{{- bos_token }}\\n{%- if not tools_in_user_message is defined %}\\n    {%- set tools_in_user_message = true %}\\n{%- endif %}\\n{%- if not date_string is defined %}\\n    {%- set date_string = \\\"26 Jul 2024\\\" %}\\n{%- endif %}\\n{%- if not tools is defined %}\\n    {%- set tools = none %}\\n{%- endif %}\\n\\n{#- This block extracts the system message, so we can slot it into the right place. 
#}\\n{%- if messages[0]['role'] == 'system' %}\\n    {%- set system_message = messages[0]['content']|trim %}\\n    {%- set messages = messages[1:] %}\\n{%- else %}\\n    {%- set system_message = \\\"\\\" %}\\n{%- endif %}\\n\\n{#- System message + builtin tools #}\\n{{- \\\"<|start_header_id|>system<|end_header_id|>\\\\n\\\\n\\\" }}\\n{%- if builtin_tools is defined or tools is not none %}\\n    {{- \\\"Environment: ipython\\\\n\\\" }}\\n{%- endif %}\\n{%- if builtin_tools is defined %}\\n    {{- \\\"Tools: \\\" + builtin_tools | reject('equalto', 'code_interpreter') | join(\\\", \\\") + \\\"\\\\n\\\\n\\\"}}\\n{%- endif %}\\n{{- \\\"Cutting Knowledge Date: December 2023\\\\n\\\" }}\\n{{- \\\"Today Date: \\\" + date_string + \\\"\\\\n\\\\n\\\" }}\\n{%- if tools is not none and not tools_in_user_message %}\\n    {{- \\\"You have access to the following functions. To call a function, please respond with JSON for a function call.\\\" }}\\n    {{- 'Respond in the format {\\\"name\\\": function name, \\\"parameters\\\": dictionary of argument name and its value}.' 
}}\\n    {{- \\\"Do not use variables.\\\\n\\\\n\\\" }}\\n    {%- for t in tools %}\\n        {{- t | tojson(indent=4) }}\\n        {{- \\\"\\\\n\\\\n\\\" }}\\n    {%- endfor %}\\n{%- endif %}\\n{{- system_message }}\\n{{- \\\"<|eot_id|>\\\" }}\\n\\n{#- Custom tools are passed in a user message with some extra guidance #}\\n{%- if tools_in_user_message and not tools is none %}\\n    {#- Extract the first user message so we can plug it in here #}\\n    {%- if messages | length != 0 %}\\n        {%- set first_user_message = messages[0]['content']|trim %}\\n        {%- set messages = messages[1:] %}\\n    {%- else %}\\n        {{- raise_exception(\\\"Cannot put tools in the first user message when there's no first user message!\\\") }}\\n{%- endif %}\\n    {{- '<|start_header_id|>user<|end_header_id|>\\\\n\\\\n' -}}\\n    {{- \\\"Given the following functions, please respond with a JSON for a function call \\\" }}\\n    {{- \\\"with its proper arguments that best answers the given prompt.\\\\n\\\\n\\\" }}\\n    {{- 'Respond in the format {\\\"name\\\": function name, \\\"parameters\\\": dictionary of argument name and its value}.' 
}}\\n    {{- \\\"Do not use variables.\\\\n\\\\n\\\" }}\\n    {%- for t in tools %}\\n        {{- t | tojson(indent=4) }}\\n        {{- \\\"\\\\n\\\\n\\\" }}\\n    {%- endfor %}\\n    {{- first_user_message + \\\"<|eot_id|>\\\"}}\\n{%- endif %}\\n\\n{%- for message in messages %}\\n    {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\\n        {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\\\n\\\\n'+ message['content'] | trim + '<|eot_id|>' }}\\n    {%- elif 'tool_calls' in message %}\\n        {%- if not message.tool_calls|length == 1 %}\\n            {{- raise_exception(\\\"This model only supports single tool-calls at once!\\\") }}\\n        {%- endif %}\\n        {%- set tool_call = message.tool_calls[0].function %}\\n        {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\\n            {{- '<|start_header_id|>assistant<|end_header_id|>\\\\n\\\\n' -}}\\n            {{- \\\"<|python_tag|>\\\" + tool_call.name + \\\".call(\\\" }}\\n            {%- for arg_name, arg_val in tool_call.arguments | items %}\\n                {{- arg_name + '=\\\"' + arg_val + '\\\"' }}\\n                {%- if not loop.last %}\\n                    {{- \\\", \\\" }}\\n                {%- endif %}\\n                {%- endfor %}\\n            {{- \\\")\\\" }}\\n        {%- else  %}\\n            {{- '<|start_header_id|>assistant<|end_header_id|>\\\\n\\\\n' -}}\\n            {{- '{\\\"name\\\": \\\"' + tool_call.name + '\\\", ' }}\\n            {{- '\\\"parameters\\\": ' }}\\n            {{- tool_call.arguments | tojson }}\\n            {{- \\\"}\\\" }}\\n        {%- endif %}\\n        {%- if builtin_tools is defined %}\\n            {#- This means we're in ipython mode #}\\n            {{- \\\"<|eom_id|>\\\" }}\\n        {%- else %}\\n            {{- \\\"<|eot_id|>\\\" }}\\n        {%- endif %}\\n    {%- elif message.role == \\\"tool\\\" or message.role == \\\"ipython\\\" %}\\n        {{- 
\\\"<|start_header_id|>ipython<|end_header_id|>\\\\n\\\\n\\\" }}\\n        {%- if message.content is mapping or message.content is iterable %}\\n            {{- message.content | tojson }}\\n        {%- else %}\\n            {{- message.content }}\\n        {%- endif %}\\n        {{- \\\"<|eot_id|>\\\" }}\\n    {%- endif %}\\n{%- endfor %}\\n{%- if add_generation_prompt %}\\n    {{- '<|start_header_id|>assistant<|end_header_id|>\\\\n\\\\n' }}\\n{%- endif %}\\n\".to_string(),\n            Some(TokenizerConfigToken::String(\"<s>\".to_string())),\n            Some(TokenizerConfigToken::String(\"</s>\".to_string())),\n        );\n        let msgs: Vec<Message> = vec![\n            Message {\n                name: None,\n                role: \"system\".to_string(),\n                body: MessageBody::Content {\n                    content: MessageContent::SingleText(\n                        \"Youre a helpful assistant! Answer the users question best you can.\"\n                            .to_string(),\n                    ),\n                },\n            },\n            Message {\n                name: None,\n                role: \"user\".to_string(),\n                body: MessageBody::Content {\n                    content: MessageContent::SingleText(\n                        \"What is the weather like in Brooklyn, New York?\".to_string(),\n                    ),\n                },\n            },\n        ];\n        let tools_string = r#\"[{\"type\": \"function\",\"function\": {\"name\": \"get_current_weather\",\"description\": \"Get the current weather\",\"parameters\": {\"type\": \"object\",\"properties\": {\"location\": {\"type\": \"string\",\"description\": \"The city and state, e.g. San Francisco, CA\"},\"format\": {\"type\": \"string\",\"enum\": [\"celsius\", \"fahrenheit\"],\"description\": \"The temperature unit to use. 
Infer this from the users location.\"}},\"required\": [\"location\", \"format\"]}}}]\"#.to_string();\n        let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();\n        let tool_prompt = \"This default prompt will be used\".to_string();\n        let tools_and_prompt = Some((tools, tool_prompt));\n        let result = ct.apply(msgs, tools_and_prompt);\n        let expected = \"<s><|start_header_id|>system<|end_header_id|>\\n\\nEnvironment: ipython\\nCutting Knowledge Date: December 2023\\nToday Date: 26 Jul 2024\\n\\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\\n\\nRespond in the format {\\\"name\\\": function name, \\\"parameters\\\": dictionary of argument name and its value}.Do not use variables.\\n\\n{\\n    \\\"function\\\": {\\n        \\\"arguments\\\": \\\"{\\\\\\\"type\\\\\\\":\\\\\\\"object\\\\\\\",\\\\\\\"properties\\\\\\\":{\\\\\\\"location\\\\\\\":{\\\\\\\"type\\\\\\\":\\\\\\\"string\\\\\\\",\\\\\\\"description\\\\\\\":\\\\\\\"The city and state, e.g. San Francisco, CA\\\\\\\"},\\\\\\\"format\\\\\\\":{\\\\\\\"type\\\\\\\":\\\\\\\"string\\\\\\\",\\\\\\\"enum\\\\\\\":[\\\\\\\"celsius\\\\\\\",\\\\\\\"fahrenheit\\\\\\\"],\\\\\\\"description\\\\\\\":\\\\\\\"The temperature unit to use. 
Infer this from the users location.\\\\\\\"}},\\\\\\\"required\\\\\\\":[\\\\\\\"location\\\\\\\",\\\\\\\"format\\\\\\\"]}\\\",\\n        \\\"description\\\": \\\"Get the current weather\\\",\\n        \\\"name\\\": \\\"get_current_weather\\\"\\n    },\\n    \\\"type\\\": \\\"function\\\"\\n}\\n\\nWhat is the weather like in Brooklyn, New York?\\n---\\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\".to_string();\n        assert_eq!(result.unwrap(), expected);\n    }\n\n    #[test]\n    fn test_chat_template_with_special_system_prompt() {\n        // chat template from gemma3\n        let ct = ChatTemplate::new(\n            r#\"{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n    {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n    {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n    {%- set first_user_prefix = \"\" -%}\n    {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n        {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n    {%- endif -%}\n    {%- if (message['role'] == 'assistant') -%}\n        {%- set role = \"model\" -%}\n    {%- else -%}\n        {%- set role = message['role'] -%}\n    {%- endif -%}\n    {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n    {%- if message['content'] is string -%}\n        {{ message['content'] | trim }}\n    {%- elif message['content'] is iterable -%}\n        {%- for item in message['content'] -%}\n            {%- if item['type'] == 'image' -%}\n                {{ '<start_of_image>' }}\n            {%- elif item['type'] == 'text' -%}\n                {{ item['text'] | trim }}\n            {%- endif -%}\n        {%- endfor -%}\n    {%- else -%}\n        {{ raise_exception(\"Invalid content type\") }}\n    {%- endif -%}\n    {{ 
'<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n    {{'<start_of_turn>model\n'}}\n{%- endif -%}\n\"#\n            .to_string(),\n            Some(TokenizerConfigToken::String(\"<bos>\".to_string())),\n            Some(TokenizerConfigToken::String(\"</eos>\".to_string())),\n        );\n        let msgs: Vec<Message> = vec![\n            Message {\n                name: None,\n                role: \"system\".to_string(),\n                body: MessageBody::Content {\n                    content: MessageContent::MultipleChunks(vec![MessageChunk::Text {\n                        text: \"You are a helpful assistant.\".to_string(),\n                    }]),\n                },\n            },\n            Message {\n                name: None,\n                role: \"user\".to_string(),\n                body: MessageBody::Content {\n                    content: MessageContent::MultipleChunks(vec![\n                        MessageChunk::Text {\n                            text: \"I'm already using this supplement \".to_string(),\n                        },\n                        MessageChunk::ImageUrl {\n                            image_url: Url {\n                                url:  \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/IMG_3018.JPG\".to_string()\n                            },\n                        },\n                        MessageChunk::Text {\n                            text: \"and I want to use this one too \".to_string()\n                        },\n                        MessageChunk::ImageUrl {\n                            image_url: Url {\n                                url: \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/IMG_3015.jpg\".to_string()\n                            },\n                        },\n                        MessageChunk::Text {\n                            text: \" what are cautions?\".to_string()\n                        },\n                    
]),\n                },\n            },\n        ];\n\n        let result = ct.apply(msgs, None);\n        let expected = \"<bos><start_of_turn>user\\nYou are a helpful assistant.\\n\\nI'm already using this supplement ![](https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/IMG_3018.JPG)and I want to use this one too ![](https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/IMG_3015.jpg) what are cautions?<end_of_turn>\\n<start_of_turn>model\\n\".to_string();\n        assert_eq!(result.unwrap(), expected);\n    }\n}\n"
  },
  {
    "path": "router/src/infer/mod.rs",
    "content": "// pub(crate) mod v2;\nmod chat_template;\npub mod tool_grammar;\n\nuse crate::validation::{ValidGenerateRequest, Validation, ValidationError};\nuse crate::Tool;\nuse crate::{\n    ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, HubTokenizerConfig,\n    Message, PrefillToken, Token,\n};\nuse async_stream::stream;\nuse async_trait::async_trait;\nuse axum::response::sse::Event;\nuse chat_template::ChatTemplate;\nuse futures::future::try_join_all;\nuse futures::Stream;\nuse minijinja::ErrorKind;\nuse serde::Serialize;\nuse std::sync::atomic::{AtomicBool, Ordering};\nuse std::sync::Arc;\nuse thiserror::Error;\nuse tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};\nuse tokio::time::Instant;\nuse tokio_stream::wrappers::UnboundedReceiverStream;\nuse tokio_stream::StreamExt;\nuse tracing::instrument;\n\n#[async_trait]\npub trait Backend {\n    fn schedule(\n        &self,\n        request: ValidGenerateRequest,\n    ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError>;\n\n    async fn health(&self, current_health: bool) -> bool;\n\n    /// The state of the health on startup\n    /// Typically false, or true if the backend includes\n    /// a warmup phase.\n    fn start_health(&self) -> bool {\n        false\n    }\n\n    fn name(&self) -> &'static str;\n}\n\n/// Inference struct\n#[derive(Clone)]\npub struct Infer {\n    /// Validation\n    validation: Validation,\n    /// Request backend\n    backend: Arc<dyn Backend + Send + Sync>,\n    /// Chat template\n    pub(crate) chat_template: Option<ChatTemplate>,\n    /// Inference limit\n    limit_concurrent_requests: Arc<Semaphore>,\n    /// Backend health\n    backend_health: Arc<AtomicBool>,\n}\n\nimpl Infer {\n    #[allow(clippy::too_many_arguments)]\n    pub(crate) fn new(\n        backend: impl Backend + Send + Sync + 'static,\n        validation: Validation,\n        max_concurrent_requests: usize,\n        tokenizer_config: 
HubTokenizerConfig,\n        processor_config: HubProcessorConfig,\n    ) -> Self {\n        let chat_template = tokenizer_config\n            .chat_template\n            .or(processor_config.chat_template)\n            .and_then(|t| match t {\n                ChatTemplateVersions::Single(template) => Some(template),\n                ChatTemplateVersions::Multiple(templates) => templates\n                    .into_iter()\n                    .find(|t| t.name == \"default\")\n                    .map(|t| t.template),\n            })\n            .map(|t| ChatTemplate::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token));\n\n        // Inference limit with a semaphore\n        let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));\n\n        // Backend health\n        let backend_health = Arc::new(AtomicBool::new(backend.start_health()));\n\n        Self {\n            validation,\n            backend: Arc::new(backend),\n            chat_template,\n            limit_concurrent_requests: semaphore,\n            backend_health,\n        }\n    }\n\n    /// Add a new request to the queue and return a stream of InferStreamResponse\n    #[instrument(skip_all)]\n    pub(crate) async fn generate_stream<'a>(\n        &'a self,\n        request: GenerateRequest,\n    ) -> Result<\n        (\n            OwnedSemaphorePermit,\n            u32, // input_length\n            impl Stream<Item = Result<InferStreamResponse, InferError>> + 'a,\n        ),\n        InferError,\n    > {\n        // Limit concurrent requests by acquiring a permit from the semaphore\n        let permit = self\n            .clone()\n            .limit_concurrent_requests\n            .try_acquire_owned()\n            .map_err(|err| {\n                metrics::counter!(\"tgi_request_failure\", \"err\" => \"overloaded\").increment(1);\n                tracing::error!(\"{err}\");\n                err\n            })?;\n\n        // Validate request\n        let mut local_request = 
request.clone();\n        let valid_request = self.validation.validate(request).await.map_err(|err| {\n            metrics::counter!(\"tgi_request_failure\", \"err\" => \"validation\").increment(1);\n            tracing::error!(\"{err}\");\n            err\n        })?;\n\n        let seed = valid_request.parameters.seed;\n        local_request.parameters.seed = Some(seed);\n        let input_length = valid_request.input_length;\n        let max_total_new_tokens = valid_request.stopping_parameters.max_total_new_tokens;\n        let mut generation_stream = self.backend.schedule(valid_request)?;\n\n        // Wrap generation stream to update the backend health if the stream contains an error\n        let final_stream = stream! {\n            let mut total_generated_tokens = 0;\n            let mut first_start = None;\n            let mut first_queued = None;\n            let mut all_generated_text: Option<GeneratedText> = None;\n\n            while let Some(response) = generation_stream.next().await {\n                let response = response.inspect_err(|_err| {\n                    self.backend_health.store(false, Ordering::SeqCst);\n                })?;\n\n                match response {\n                    InferStreamResponse::Prefill(_) => yield Ok(response),\n                    InferStreamResponse::Intermediate { .. 
} => {\n                        total_generated_tokens += 1;\n                        yield Ok(response);\n                    }\n                    InferStreamResponse::End { token, top_tokens,generated_text, start, queued  } => {\n                        total_generated_tokens += 1;\n                        first_start = first_start.or(Some(start));\n                        first_queued = first_queued.or(Some(queued));\n                        if let Some(v) = all_generated_text.as_mut() {\n                                v.text.push_str(&generated_text.text);\n                                v.generated_tokens = total_generated_tokens;\n                                v.finish_reason = generated_text.finish_reason.clone();\n                        };\n\n                        if matches!(generated_text.finish_reason, FinishReason::Length) && total_generated_tokens < max_total_new_tokens {\n                            local_request.inputs.push_str(&generated_text.text);\n                            all_generated_text = all_generated_text.or(Some(generated_text));\n\n                            let valid_request = match self.validation.validate(local_request.clone()).await {\n                                Ok(valid_request) => valid_request,\n                                Err(err) => {\n                                    tracing::debug!(\"Failed to continue request: {err}\");\n                                    yield Ok(InferStreamResponse::End {token, top_tokens, generated_text: all_generated_text.unwrap(), start: first_start.unwrap(), queued: first_queued.unwrap() });\n                                    break;\n                                }\n                            };\n\n                            generation_stream = match self.backend.schedule(valid_request) {\n                                Ok(stream) => {\n                                    tracing::debug!(\"Continue request\");\n                                    yield 
Ok(InferStreamResponse::Intermediate { token, top_tokens } );\n                                    stream\n                                },\n                                Err(err) => {\n                                    tracing::debug!(\"Failed to continue request: {err}\");\n                                    yield Ok(InferStreamResponse::End {token, top_tokens, generated_text: all_generated_text.unwrap(), start: first_start.unwrap(), queued: first_queued.unwrap() });\n                                    break;\n                                }\n                            }\n                        } else {\n                            yield Ok(InferStreamResponse::End {token, top_tokens, generated_text: all_generated_text.unwrap_or(generated_text), start: first_start.unwrap(), queued: first_queued.unwrap() });\n                            break;\n                        }\n\n                    }\n                }\n            }\n        };\n\n        Ok((permit, input_length, final_stream))\n    }\n\n    /// Tokenizer the input\n    #[instrument(skip_all)]\n    pub(crate) async fn tokenize(\n        &self,\n        request: GenerateRequest,\n    ) -> Result<tokenizers::Encoding, InferError> {\n        // Tokenize request\n        let inputs = request.inputs;\n        let add_special_tokens = request.add_special_tokens;\n        let truncate = request.parameters.truncate;\n        let encoding = self\n            .validation\n            .tokenize(inputs, add_special_tokens, truncate)\n            .await\n            .map_err(|err| {\n                tracing::error!(\"Tokenization {err}\");\n                err\n            })?;\n\n        // Return Encoding\n        Ok(encoding.0)\n    }\n\n    /// Apply the chat template to the chat request\n    #[instrument(skip_all)]\n    pub(crate) fn apply_chat_template(\n        &self,\n        messages: Vec<Message>,\n        tools_and_prompt: Option<(Vec<Tool>, String)>,\n    ) -> Result<String, InferError> {\n 
       self.chat_template\n            .as_ref()\n            .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?\n            .apply(messages, tools_and_prompt)\n            .map_err(|e| {\n                metrics::counter!(\"tgi_request_failure\", \"err\" => \"template\").increment(1);\n                tracing::error!(\"{e}\");\n                e\n            })\n    }\n\n    /// Add a new request to the queue and return a InferResponse\n    #[instrument(skip_all)]\n    pub(crate) async fn generate(\n        &self,\n        request: GenerateRequest,\n    ) -> Result<InferResponse, InferError> {\n        let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);\n\n        // Create stream and keep semaphore permit as long as generate lives\n        let (_permit, _input_length, stream) = self.generate_stream(request).await?;\n\n        // Return values\n        let mut result_prefill = Vec::new();\n        let mut result_tokens = Vec::new();\n        let mut result_top_tokens = Vec::new();\n        let mut result_generated_text = None;\n        let mut result_start = None;\n        let mut result_queued = None;\n\n        let mut stream = Box::pin(stream);\n\n        // Iterate on stream\n        while let Some(response) = stream.next().await {\n            match response? 
{\n                // Add prefill tokens\n                InferStreamResponse::Prefill(prefill_tokens) => {\n                    result_prefill = prefill_tokens;\n                }\n                // Push last token\n                InferStreamResponse::Intermediate { token, top_tokens } => {\n                    result_tokens.push(token);\n                    result_top_tokens.push(top_tokens);\n                }\n                // Final message\n                // Set return values\n                InferStreamResponse::End {\n                    token,\n                    generated_text,\n                    start,\n                    queued,\n                    top_tokens,\n                } => {\n                    result_tokens.push(token);\n                    result_top_tokens.push(top_tokens);\n                    result_generated_text = Some(generated_text);\n                    result_start = Some(start);\n                    result_queued = Some(queued)\n                }\n            }\n        }\n\n        // Check that we received a `InferStreamResponse::End` message\n        if let (Some(generated_text), Some(queued), Some(start)) =\n            (result_generated_text, result_queued, result_start)\n        {\n            Ok(InferResponse {\n                prefill: result_prefill,\n                _input_length,\n                tokens: result_tokens,\n                generated_text,\n                queued,\n                start,\n                top_tokens: if use_top_tokens {\n                    result_top_tokens\n                } else {\n                    Vec::new()\n                },\n            })\n        } else {\n            let err = InferError::IncompleteGeneration;\n            metrics::counter!(\"tgi_request_failure\", \"err\" => \"incomplete\").increment(1);\n            tracing::error!(\"{err}\");\n            Err(err)\n        }\n    }\n    /// Add best_of new requests to the queue and return a InferResponse of the 
sequence with\n    /// the highest log probability per token\n    #[instrument(skip(self, request))]\n    pub(crate) async fn generate_best_of(\n        &self,\n        request: GenerateRequest,\n        best_of: usize,\n    ) -> Result<(InferResponse, Vec<InferResponse>), InferError> {\n        // validate  best_of parameter separately\n        let best_of = self.validation.validate_best_of(best_of)?;\n\n        // create multiple generate requests\n        let mut infer_responses: Vec<InferResponse> =\n            try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?;\n\n        // get the sequence with the highest log probability per token\n        let mut max_index = 0;\n        let mut max_logprob: f32 = f32::MIN;\n\n        for (i, response) in infer_responses.iter().enumerate() {\n            // mean logprobs of the generated tokens\n            let sequence_logprob = response\n                .tokens\n                .iter()\n                .map(|token| token.logprob)\n                .sum::<f32>()\n                / response.tokens.len() as f32;\n\n            // set best sequence\n            if sequence_logprob > max_logprob {\n                max_index = i;\n                max_logprob = sequence_logprob;\n            }\n        }\n        let best_response = infer_responses.remove(max_index);\n        Ok((best_response, infer_responses))\n    }\n\n    #[instrument(skip(self))]\n    pub(crate) async fn health(&self) -> bool {\n        let health = self\n            .backend\n            .health(self.backend_health.load(Ordering::SeqCst))\n            .await;\n        self.backend_health.store(health, Ordering::SeqCst);\n        health\n    }\n}\n\n#[derive(Debug)]\npub struct GeneratedText {\n    pub text: String,\n    pub generated_tokens: u32,\n    pub finish_reason: FinishReason,\n    pub seed: Option<u64>,\n}\n\n#[derive(Debug)]\npub enum InferStreamResponse {\n    // Optional first message\n    Prefill(Vec<PrefillToken>),\n    
// Intermediate messages\n    Intermediate {\n        token: Token,\n        top_tokens: Vec<Token>,\n    },\n    // Last message\n    End {\n        token: Token,\n        top_tokens: Vec<Token>,\n        generated_text: GeneratedText,\n        start: Instant,\n        queued: Instant,\n    },\n}\n\n#[derive(Debug)]\npub(crate) struct InferResponse {\n    /// input_length is the input as perceived by the rust tokenizer in the\n    /// validation pathway. It is redundant with prefill.len() but prefill\n    /// has data only if the user asked for it. This will always be filled.\n    pub(crate) _input_length: u32,\n    pub(crate) prefill: Vec<PrefillToken>,\n    pub(crate) tokens: Vec<Token>,\n    pub(crate) generated_text: GeneratedText,\n    pub(crate) queued: Instant,\n    pub(crate) start: Instant,\n    pub(crate) top_tokens: Vec<Vec<Token>>,\n}\n\n#[derive(Debug, Error)]\npub enum InferError {\n    #[error(\"Request failed during generation: {0}\")]\n    GenerationError(String),\n    #[error(\"Model is overloaded\")]\n    Overloaded(#[from] TryAcquireError),\n    #[error(\"Input validation error: {0}\")]\n    ValidationError(#[from] ValidationError),\n    #[error(\"Incomplete generation\")]\n    IncompleteGeneration,\n    #[error(\"Incomplete generation stream\")]\n    IncompleteGenerationStream,\n    #[error(\"Template error: {0}\")]\n    TemplateError(#[from] minijinja::Error),\n    #[error(\"Missing template vatiable: {0}\")]\n    MissingTemplateVariable(String),\n    #[error(\"Tool error: {0}\")]\n    ToolError(String),\n    #[error(\"Stream event serialization error\")]\n    StreamSerializationError(String),\n}\n\nimpl InferError {\n    pub(crate) fn error_type(&self) -> &str {\n        match self {\n            InferError::GenerationError(_) => \"generation\",\n            InferError::Overloaded(_) => \"overloaded\",\n            InferError::ValidationError(_) => \"validation\",\n            InferError::IncompleteGeneration => \"incomplete_generation\",\n  
          InferError::IncompleteGenerationStream => \"incomplete_generation_stream\",\n            InferError::TemplateError(_) => \"template_error\",\n            InferError::MissingTemplateVariable(_) => \"missing_template_variable\",\n            InferError::ToolError(_) => \"tool_error\",\n            InferError::StreamSerializationError(_) => \"stream_serialization_error\",\n        }\n    }\n\n    pub(crate) fn into_openai_event(self) -> Event {\n        Event::default()\n            .json_data(OpenaiErrorEvent {\n                error: APIError {\n                    message: self.to_string(),\n                    http_status_code: 422,\n                },\n            })\n            .unwrap()\n    }\n}\n\n#[derive(Serialize)]\npub struct APIError {\n    message: String,\n    http_status_code: usize,\n}\n\n#[derive(Serialize)]\npub struct OpenaiErrorEvent {\n    error: APIError,\n}\n"
  },
  {
    "path": "router/src/infer/tool_grammar.rs",
    "content": "use crate::infer::InferError;\nuse crate::{\n    FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice,\n};\nuse serde_json::{json, Map, Value};\nuse std::collections::HashMap;\n\npub(crate) struct ToolGrammar {}\n\nimpl ToolGrammar {\n    // find a tool by name\n    fn find_tool_by_name(tools: &[Tool], name: &str) -> Result<Tool, InferError> {\n        tools\n            .iter()\n            .find(|tool| tool.function.name == name)\n            .cloned()\n            .ok_or_else(|| InferError::ToolError(format!(\"Tool with name {} not found\", name)))\n    }\n\n    pub fn apply(\n        tools: Vec<Tool>,\n        tool_choice: ToolChoice,\n    ) -> Result<Option<(Vec<Tool>, JsonSchemaTool)>, InferError> {\n        let tools_to_use = match tool_choice {\n            ToolChoice::Function(function) => {\n                vec![Self::find_tool_by_name(&tools, &function.name)?]\n            }\n            ToolChoice::Required => tools,\n            ToolChoice::Auto => {\n                // only add the no_tool function if the user has selected the auto option\n                tools\n                    .iter()\n                    .cloned()\n                    .chain(std::iter::once(Tool {\n                        r#type: \"function\".to_string(),\n                        function: FunctionDefinition {\n                            name: \"no_tool\".to_string(),\n                            description: Some(\n                                \"Open ended response with no specific tool selected\".to_string(),\n                            ),\n                            arguments: json!({\n                                \"type\": \"object\",\n                                // \"properties\": {\n                                //     \"content\": {\n                                //         \"type\": \"string\",\n                                //         \"description\": \"The response content\",\n                    
            //     }\n                                // },\n                                // \"required\": [\"content\"]\n                            }),\n                        },\n                    }))\n                    .collect::<Vec<_>>()\n            }\n            ToolChoice::NoTool => vec![],\n        };\n\n        // if no tools are provided or if the user has selected the no_tool option, return None\n        if tools_to_use.is_empty() {\n            return Ok(None);\n        }\n\n        let functions: HashMap<String, serde_json::Value> = tools_to_use\n            .iter()\n            .map(|tool| {\n                let func = tool.function.clone();\n\n                let mut params = Map::new();\n\n                params.insert(\n                    \"description\".to_string(),\n                    Value::String(func.description.unwrap_or_default()),\n                );\n\n                let mut properties = Map::new();\n                let mut required = vec![Value::String(\"_name\".to_string())];\n\n                properties.insert(\n                    \"_name\".to_string(),\n                    json!({\n                        \"type\": \"string\",\n                        \"const\": func.name.clone(),\n                    }),\n                );\n\n                if let Value::Object(args) = func.arguments {\n                    if let Some(Value::Object(props)) = args.get(\"properties\") {\n                        properties.extend(props.clone());\n                    }\n                    if let Some(Value::Array(reqs)) = args.get(\"required\") {\n                        required.extend(reqs.clone());\n                    }\n                    params.insert(\n                        \"additionalProperties\".to_string(),\n                        Value::Bool(\n                            args.get(\"additionalProperties\").and_then(|v| v.as_str())\n                                == Some(\"true\"),\n                        ),\n            
        );\n                }\n\n                params.insert(\"properties\".to_string(), Value::Object(properties));\n                params.insert(\"required\".to_string(), Value::Array(required));\n\n                (func.name, Value::Object(params))\n            })\n            .collect();\n\n        let tool_schema = JsonSchemaTool {\n            functions_map: FunctionsMap { functions },\n            properties: Properties {\n                function: tools_to_use\n                    .iter()\n                    .map(|tool| FunctionRef {\n                        ref_path: format!(\"#/$functions/{}\", tool.function.name.clone()),\n                    })\n                    .collect(),\n            },\n        };\n\n        Ok(Some((tools_to_use, tool_schema)))\n    }\n}\n"
  },
  {
    "path": "router/src/kserve.rs",
    "content": "use crate::infer::Infer;\nuse crate::{\n    default_parameters,\n    server::{generate_internal, ComputeType},\n    Deserialize, ErrorResponse, GenerateParameters, GenerateRequest, Serialize, ToSchema,\n};\nuse axum::extract::{Extension, Path};\nuse axum::http::{HeaderMap, StatusCode};\nuse axum::response::IntoResponse;\nuse axum::Json;\nuse futures::stream::FuturesUnordered;\nuse futures::TryStreamExt;\n\n#[derive(Debug, Serialize, Deserialize, ToSchema)]\npub struct OutputChunk {\n    pub name: String,\n    pub shape: Vec<usize>,\n    pub datatype: String,\n    pub data: Vec<u8>,\n}\n\n#[derive(Debug, Serialize, Deserialize, ToSchema)]\npub struct InferenceOutput {\n    pub id: String,\n    pub outputs: Vec<OutputChunk>,\n}\n\n#[derive(Debug, Deserialize, ToSchema)]\npub(crate) struct InferenceRequest {\n    pub id: String,\n    #[serde(default = \"default_parameters\")]\n    pub parameters: GenerateParameters,\n    pub inputs: Vec<Input>,\n    pub outputs: Vec<Output>,\n}\n\n#[derive(Debug, Serialize, Deserialize, ToSchema)]\npub(crate) struct Input {\n    pub name: String,\n    pub shape: Vec<usize>,\n    pub datatype: String,\n    pub data: Vec<u8>,\n}\n\n#[derive(Debug, Serialize, Deserialize, ToSchema)]\npub(crate) struct Output {\n    pub name: String,\n}\n\n#[derive(Debug, Serialize, Deserialize, ToSchema)]\npub struct LiveResponse {\n    pub live: bool,\n}\n\n#[derive(Debug, Serialize, Deserialize, ToSchema)]\npub struct ReadyResponse {\n    pub live: bool,\n}\n\n#[derive(Debug, Serialize, Deserialize, ToSchema)]\npub struct MetadataServerResponse {\n    pub name: String,\n    pub version: String,\n    pub extensions: Vec<String>,\n}\n\n#[utoipa::path(\n    post,\n    tag = \"Text Generation Inference\",\n    path = \"/v2/health/live\",\n    responses(\n        (status = 200, description = \"Service is live\", body = LiveReponse),\n        (status = 404, description = \"Service not found\", body = ErrorResponse,\n            example = 
json!({\"error\": \"No response\"}))\n    )\n)]\npub async fn kserve_health_live() -> Json<LiveResponse> {\n    let data = LiveResponse { live: true };\n    Json(data)\n}\n\n#[utoipa::path(\n    get,\n    tag = \"Text Generation Inference\",\n    path = \"/v2/health/ready\",\n    responses(\n        (status = 200, description = \"Service is ready\", body = ReadyResponse),\n        (status = 404, description = \"Service not found\", body = ErrorResponse,\n            example = json!({\"error\": \"No response\"}))\n    )\n)]\npub async fn kserve_health_ready() -> Json<ReadyResponse> {\n    let data = ReadyResponse { live: true };\n    Json(data)\n}\n\n#[utoipa::path(\n    get,\n    tag = \"Text Generation Inference\",\n    path = \"/v2\",\n    responses(\n        (status = 200, description = \"Metadata retrieved\", body = MetadataServerResponse),\n        (status = 404, description = \"Service not found\", body = ErrorResponse,\n            example = json!({\"error\": \"No response\"}))\n    )\n)]\npub async fn kerve_server_metadata() -> Json<MetadataServerResponse> {\n    let data = MetadataServerResponse {\n        name: \"text-generation-inference\".to_string(),\n        version: env!(\"CARGO_PKG_VERSION\").to_string(),\n        extensions: vec![\n            \"health\".to_string(),\n            \"models\".to_string(),\n            \"metrics\".to_string(),\n        ],\n    };\n    Json(data)\n}\n\n#[utoipa::path(\n    get,\n    tag = \"Text Generation Inference\",\n    path = \"/v2/models/{model_name}/versions/{model_version}\",\n    responses(\n        (status = 200, description = \"Model version metadata retrieved\", body = MetadataServerResponse),\n        (status = 404, description = \"Model or version not found\", body = ErrorResponse,\n            example = json!({\"error\": \"No response\"}))\n    )\n)]\npub async fn kserve_model_metadata(\n    Path((model_name, model_version)): Path<(String, String)>,\n) -> Json<MetadataServerResponse> {\n    let data = 
MetadataServerResponse {\n        name: model_name,\n        version: model_version,\n        extensions: vec![\"infer\".to_string(), \"ready\".to_string()],\n    };\n    Json(data)\n}\n\n#[utoipa::path(\n    get,\n    tag = \"Text Generation Inference\",\n    path = \"/v2/models/{model_name}/versions/{model_version}/ready\",\n    responses(\n        (status = 200, description = \"Model version is ready\", body = ReadyResponse),\n        (status = 404, description = \"Model or version not found\", body = ErrorResponse,\n            example = json!({\"error\": \"No response\"}))\n    )\n)]\npub async fn kserve_model_metadata_ready(\n    Path((_model_name, _model_version)): Path<(String, String)>,\n) -> Json<ReadyResponse> {\n    let data = ReadyResponse { live: true };\n    Json(data)\n}\n\n#[utoipa::path(\n    post,\n    tag = \"Text Generation Inference\",\n    path = \"/v2/models/{model_name}/versions/{model_version}/infer\",\n    request_body = Json<InferenceRequest>,\n    responses(\n        (status = 200, description = \"Inference executed successfully\", body = InferenceOutput),\n        (status = 404, description = \"Model or version not found\", body = ErrorResponse,\n            example = json!({\"error\": \"No response\"}))\n    )\n)]\npub async fn kserve_model_infer(\n    infer: Extension<Infer>,\n    Extension(compute_type): Extension<ComputeType>,\n    Json(payload): Json<InferenceRequest>,\n) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {\n    let id = payload.id.clone();\n    let str_inputs = payload\n        .inputs\n        .iter()\n        .map(|input| {\n            std::str::from_utf8(&input.data).map_err(|e| {\n                (\n                    StatusCode::UNPROCESSABLE_ENTITY,\n                    Json(ErrorResponse {\n                        error: e.to_string(),\n                        error_type: \"utf8\".to_string(),\n                    }),\n                )\n            })\n        })\n        
.collect::<Result<Vec<_>, _>>()?;\n\n    if str_inputs.len() != payload.outputs.len() {\n        return Err((\n            StatusCode::UNPROCESSABLE_ENTITY,\n            Json(ErrorResponse {\n                error: \"Inputs and outputs length mismatch\".to_string(),\n                error_type: \"length mismatch\".to_string(),\n            }),\n        ));\n    }\n\n    let output_chunks = str_inputs\n        .iter()\n        .zip(&payload.outputs)\n        .map(|(str_input, output)| {\n            let generate_request = GenerateRequest {\n                inputs: str_input.to_string(),\n                parameters: payload.parameters.clone(),\n                add_special_tokens: true,\n            };\n            let infer = infer.clone();\n            let compute_type = compute_type.clone();\n            let span = tracing::Span::current();\n            async move {\n                generate_internal(infer, compute_type, Json(generate_request), span)\n                    .await\n                    .map(|(_, _, Json(generation))| {\n                        let generation_as_bytes = generation.generated_text.as_bytes().to_vec();\n                        OutputChunk {\n                            name: output.name.clone(),\n                            shape: vec![1, generation_as_bytes.len()],\n                            datatype: \"BYTES\".to_string(),\n                            data: generation_as_bytes,\n                        }\n                    })\n                    .map_err(|_| {\n                        (\n                            StatusCode::INTERNAL_SERVER_ERROR,\n                            Json(ErrorResponse {\n                                error: \"Incomplete generation\".into(),\n                                error_type: \"Incomplete generation\".into(),\n                            }),\n                        )\n                    })\n            }\n        })\n        .collect::<FuturesUnordered<_>>()\n        
.try_collect::<Vec<_>>()\n        .await?;\n\n    let inference_output = InferenceOutput {\n        id: id.clone(),\n        outputs: output_chunks,\n    };\n\n    Ok((HeaderMap::new(), Json(inference_output)))\n}\n"
  },
  {
    "path": "router/src/lib.rs",
    "content": "/// Text Generation Inference Webserver\npub mod config;\npub mod infer;\npub mod server;\npub mod validation;\n\n#[cfg(feature = \"kserve\")]\nmod kserve;\npub mod logging;\n\nmod chat;\nmod sagemaker;\npub mod usage_stats;\nmod vertex;\n\nuse crate::infer::tool_grammar::ToolGrammar;\nuse crate::infer::{Infer, InferError};\nuse pyo3::prelude::*;\nuse pyo3::types::IntoPyDict;\nuse serde::{Deserialize, Serialize};\nuse tokenizers::Encoding;\nuse tracing::warn;\nuse utoipa::ToSchema;\nuse uuid::Uuid;\nuse validation::Validation;\n\n#[allow(clippy::large_enum_variant)]\n#[derive(Clone)]\npub enum Tokenizer {\n    Python {\n        tokenizer_name: String,\n        revision: Option<String>,\n        trust_remote_code: bool,\n    },\n    Rust(tokenizers::Tokenizer),\n}\n\npub struct PyTokenizer<'a>(pyo3::Bound<'a, pyo3::PyAny>);\n\nimpl<'a> PyTokenizer<'a> {\n    fn from_py(\n        py: Python<'a>,\n        tokenizer_name: String,\n        revision: Option<String>,\n        trust_remote_code: bool,\n    ) -> PyResult<PyTokenizer<'a>> {\n        let transformers = py.import_bound(\"transformers\")?;\n        let auto = transformers.getattr(\"AutoTokenizer\")?;\n        let from_pretrained = auto.getattr(\"from_pretrained\")?;\n        let args = (tokenizer_name,);\n        let kwargs = if let Some(rev) = &revision {\n            [\n                (\"revision\", rev.to_string().into_py(py)),\n                (\"trust_remote_code\", trust_remote_code.into_py(py)),\n            ]\n            .into_py_dict_bound(py)\n        } else {\n            [(\"trust_remote_code\", trust_remote_code.into_py(py))].into_py_dict_bound(py)\n        };\n        let tokenizer = from_pretrained.call(args, Some(&kwargs))?;\n        tracing::info!(\"Loaded a python tokenizer\");\n        Ok(PyTokenizer(tokenizer))\n    }\n}\n\ntrait TokenizerTrait {\n    fn encode_trait(\n        &self,\n        query: String,\n        add_special_tokens: bool,\n    ) -> 
Result<tokenizers::Encoding, Box<dyn std::error::Error + Send + Sync>>;\n}\n\nimpl TokenizerTrait for tokenizers::Tokenizer {\n    fn encode_trait(\n        &self,\n        query: String,\n        add_special_tokens: bool,\n    ) -> Result<tokenizers::Encoding, Box<dyn std::error::Error + Send + Sync>> {\n        self.encode(query, add_special_tokens)\n    }\n}\n\nimpl TokenizerTrait for PyTokenizer<'_> {\n    fn encode_trait(\n        &self,\n        query: String,\n        add_special_tokens: bool,\n    ) -> Result<tokenizers::Encoding, Box<dyn std::error::Error + Send + Sync>> {\n        let py = self.0.py();\n        let kwargs = [\n            (\"text\", query.into_py(py)),\n            (\"add_special_tokens\", add_special_tokens.into_py(py)),\n        ]\n        .into_py_dict_bound(py);\n        let encode = self.0.getattr(\"encode\")?;\n        let input_ids: Vec<u32> = encode.call((), Some(&kwargs))?.extract()?;\n        Ok(Encoding::new(\n            input_ids,\n            vec![],                           // type ids\n            vec![],                           // tokens (strings)\n            vec![],                           // words\n            vec![],                           // offsets\n            vec![],                           // special_tokens_mask\n            vec![],                           // attention_mask\n            vec![],                           // overflowing\n            std::collections::HashMap::new(), //sequence_ranges\n        ))\n    }\n}\n\n/// Hub type\n#[derive(Clone, Debug, Deserialize)]\npub struct HubModelInfo {\n    #[serde(rename(deserialize = \"id\"))]\n    pub model_id: String,\n    pub sha: Option<String>,\n    pub pipeline_tag: Option<String>,\n}\n\n#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]\npub struct ChatTemplate {\n    name: String,\n    template: String,\n}\n\n#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]\n#[serde(untagged)]\npub enum ChatTemplateVersions {\n    
Single(String),\n    Multiple(Vec<ChatTemplate>),\n}\n\nuse std::path::Path;\n\n#[derive(Debug, Clone, Serialize, Deserialize, Default)]\npub struct HubTokenizerConfig {\n    pub chat_template: Option<ChatTemplateVersions>,\n    pub completion_template: Option<String>,\n    pub bos_token: Option<TokenizerConfigToken>,\n    pub eos_token: Option<TokenizerConfigToken>,\n    pub tokenizer_class: Option<String>,\n    pub add_bos_token: Option<bool>,\n    pub add_eos_token: Option<bool>,\n}\n\nimpl HubTokenizerConfig {\n    pub fn from_file<P: AsRef<Path>>(filename: P) -> Option<Self> {\n        std::fs::read_to_string(filename)\n            .ok()\n            .and_then(|content| serde_json::from_str(&content).ok())\n    }\n}\n\n#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]\npub struct ChatTemplateStandalone {\n    pub chat_template: ChatTemplateVersions,\n}\n\n#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]\n#[serde(untagged)]\npub enum TokenizerConfigToken {\n    String(String),\n    Object { content: String },\n}\n\nimpl TokenizerConfigToken {\n    pub fn as_str(&self) -> &str {\n        match self {\n            TokenizerConfigToken::String(s) => s,\n            TokenizerConfigToken::Object { content } => content,\n        }\n    }\n}\n\n#[derive(Debug, Clone, Serialize, Deserialize)]\n#[serde(tag = \"processor_class\")]\npub enum HubPreprocessorConfig {\n    Idefics2Processor(Idefics2Preprocessor),\n    Idefics3Processor(Idefics2Preprocessor),\n    Gemma3Processor(Gemma3Processor),\n    Llama4Processor(Llama4Processor),\n}\n\nimpl HubPreprocessorConfig {\n    pub fn from_file<P: AsRef<std::path::Path>>(filename: P) -> Option<Self> {\n        let content = std::fs::read_to_string(filename).ok()?;\n        serde_json::from_str(&content).ok()\n    }\n}\n\n#[derive(Clone, Debug, Serialize, Deserialize)]\npub struct Idefics2Preprocessor {\n    #[serde(default)]\n    do_image_splitting: bool,\n}\n\n#[derive(Clone, Debug, Serialize, 
Deserialize)]\npub struct Gemma3Processor {\n    #[serde(default)]\n    do_image_splitting: bool,\n}\n\n#[derive(Clone, Debug, Serialize, Deserialize)]\npub struct Llama4Processor {\n    #[serde(default)]\n    max_patches: usize,\n}\n\n#[derive(Debug, Clone, Deserialize, Default)]\npub struct HubProcessorConfig {\n    pub chat_template: Option<ChatTemplateVersions>,\n    pub image_seq_len: usize,\n    pub processor_class: Option<String>,\n}\n\nimpl HubProcessorConfig {\n    pub fn from_file<P: AsRef<Path>>(filename: P) -> Option<Self> {\n        std::fs::read_to_string(filename)\n            .ok()\n            .and_then(|content| serde_json::from_str(&content).ok())\n    }\n}\n\n#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]\n#[cfg_attr(test, derive(PartialEq))]\nstruct JsonSchemaConfig {\n    /// Optional name identifier for the schema\n    #[serde(skip_serializing_if = \"Option::is_none\")]\n    name: Option<String>,\n\n    /// The actual JSON schema definition\n    schema: serde_json::Value,\n}\n\n#[derive(Clone, Debug, Deserialize, ToSchema, Serialize)]\n#[cfg_attr(test, derive(PartialEq))]\n#[serde(tag = \"type\", content = \"value\")]\npub(crate) enum GrammarType {\n    /// A string that represents a [JSON Schema](https://json-schema.org/).\n    ///\n    /// JSON Schema is a declarative language that allows to annotate JSON documents\n    /// with types and descriptions.\n    #[serde(rename = \"json\")]\n    #[serde(alias = \"json_object\")]\n    #[schema(example = json ! ({\"properties\": {\"location\":{\"type\": \"string\"}}}))]\n    Json(serde_json::Value),\n\n    #[serde(rename = \"regex\")]\n    Regex(String),\n\n    /// A JSON Schema specification with additional metadata.\n    ///\n    /// Includes an optional name for the schema, an optional strict flag, and the required schema definition.\n    #[serde(rename = \"json_schema\")]\n    #[schema(example = json ! 
({\"schema\": {\"properties\": {\"name\": {\"type\": \"string\"}, \"age\": {\"type\": \"integer\"}}}, \"name\": \"person_info\", \"strict\": true}))]\n    JsonSchema(JsonSchemaConfig),\n}\n\n#[derive(Clone, Debug, Serialize, ToSchema)]\npub struct Info {\n    /// Model info\n    #[schema(example = \"bigscience/blomm-560m\")]\n    pub model_id: String,\n    #[schema(nullable = true, example = \"e985a63cdc139290c5f700ff1929f0b5942cced2\")]\n    pub model_sha: Option<String>,\n    // #[schema(example = \"torch.float16\")]\n    // pub model_dtype: String,\n    // #[schema(example = \"cuda\")]\n    // pub model_device_type: String,\n    #[schema(nullable = true, example = \"text-generation\")]\n    pub model_pipeline_tag: Option<String>,\n\n    /// Router Parameters\n    #[schema(example = \"128\")]\n    pub max_concurrent_requests: usize,\n    #[schema(example = \"2\")]\n    pub max_best_of: usize,\n    #[schema(example = \"4\")]\n    pub max_stop_sequences: usize,\n    #[schema(example = \"1024\")]\n    pub max_input_tokens: usize,\n    #[schema(example = \"2048\")]\n    pub max_total_tokens: usize,\n    #[schema(example = \"2\")]\n    pub validation_workers: usize,\n    #[schema(example = \"32\")]\n    pub max_client_batch_size: usize,\n\n    /// Router Info\n    #[schema(example = \"text-generation-router\")]\n    pub router: &'static str,\n    #[schema(example = \"0.5.0\")]\n    pub version: &'static str,\n    #[schema(nullable = true, example = \"null\")]\n    pub sha: Option<&'static str>,\n    #[schema(nullable = true, example = \"null\")]\n    pub docker_label: Option<&'static str>,\n}\n\n#[derive(Clone, Debug, Deserialize, ToSchema, Default)]\n#[cfg_attr(test, derive(PartialEq))]\npub(crate) struct GenerateParameters {\n    /// Generate best_of sequences and return the one if the highest token logprobs.\n    #[serde(default)]\n    #[schema(exclusive_minimum = 0, nullable = true, default = \"null\", example = 1)]\n    pub best_of: Option<usize>,\n\n    /// The 
value used to module the logits distribution.\n    #[serde(default)]\n    #[schema(\n        exclusive_minimum = 0.0,\n        nullable = true,\n        default = \"null\",\n        example = 0.5\n    )]\n    pub temperature: Option<f32>,\n\n    /// The parameter for repetition penalty. 1.0 means no penalty.\n    /// See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.\n    #[serde(default)]\n    #[schema(\n        exclusive_minimum = 0.0,\n        nullable = true,\n        default = \"null\",\n        example = 1.03\n    )]\n    pub repetition_penalty: Option<f32>,\n\n    /// The parameter for frequency penalty. 1.0 means no penalty\n    /// Penalize new tokens based on their existing frequency in the text so far,\n    /// decreasing the model's likelihood to repeat the same line verbatim.\n    #[serde(default)]\n    #[schema(\n        exclusive_minimum = -2.0,\n        nullable = true,\n        default = \"null\",\n        example = 0.1\n    )]\n    pub frequency_penalty: Option<f32>,\n\n    /// The number of highest probability vocabulary tokens to keep for top-k-filtering.\n    #[serde(default)]\n    #[schema(exclusive_minimum = 0, nullable = true, default = \"null\", example = 10)]\n    pub top_k: Option<i32>,\n\n    /// Top-p value for nucleus sampling.\n    #[serde(default)]\n    #[schema(\n        exclusive_minimum = 0.0,\n        maximum = 1.0,\n        nullable = true,\n        default = \"null\",\n        example = 0.95\n    )]\n    pub top_p: Option<f32>,\n\n    /// Typical Decoding mass\n    /// See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.\n    #[serde(default)]\n    #[schema(\n        exclusive_minimum = 0.0,\n        maximum = 1.0,\n        nullable = true,\n        default = \"null\",\n        example = 0.95\n    )]\n    pub typical_p: Option<f32>,\n\n    /// Activate logits sampling.\n    #[serde(default)]\n    #[schema(default = \"false\", example = true)]\n    
pub do_sample: bool,\n\n    /// Maximum number of tokens to generate.\n    #[serde(default)]\n    #[schema(nullable = true, default = \"1024\", example = \"20\")]\n    pub max_new_tokens: Option<u32>,\n\n    /// Whether to prepend the prompt to the generated text\n    #[serde(default)]\n    #[schema(nullable = true, default = \"null\", example = false)]\n    pub return_full_text: Option<bool>,\n\n    /// Stop generating tokens if a member of `stop` is generated.\n    #[serde(default)]\n    #[schema(inline, max_items = 4, example = json ! ([\"photographer\"]))]\n    pub stop: Vec<String>,\n\n    /// Truncate inputs tokens to the given size.\n    #[serde(default)]\n    #[schema(nullable = true, default = \"null\", example = \"null\")]\n    pub truncate: Option<usize>,\n\n    /// Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226).\n    #[serde(default)]\n    #[schema(default = \"false\", example = true)]\n    pub watermark: bool,\n\n    /// Whether to return generation details.\n    #[serde(default)]\n    #[schema(default = \"true\")]\n    pub details: bool,\n\n    /// Whether to return decoder input token logprobs and ids.\n    #[serde(default)]\n    #[schema(default = \"false\")]\n    pub decoder_input_details: bool,\n\n    /// Random sampling seed.\n    #[serde(default)]\n    #[schema(\n        exclusive_minimum = 0,\n        nullable = true,\n        default = \"null\",\n        example = \"null\"\n    )]\n    pub seed: Option<u64>,\n\n    /// The number of highest probability vocabulary tokens to keep for top-n-filtering.\n    #[serde(default)]\n    #[schema(exclusive_minimum = 0, nullable = true, default = \"null\", example = 5)]\n    pub top_n_tokens: Option<u32>,\n\n    /// Grammar constraints for the generation.\n    #[serde(default)]\n    #[schema(nullable = true, default = \"null\", example = \"null\")]\n    pub grammar: Option<GrammarType>,\n\n    /// Lora adapter id\n    #[serde(default)]\n    #[schema(nullable = 
true, default = \"null\", example = \"null\")]\n    pub adapter_id: Option<String>,\n}\n\nfn default_parameters() -> GenerateParameters {\n    GenerateParameters {\n        best_of: None,\n        temperature: None,\n        repetition_penalty: None,\n        frequency_penalty: None,\n        top_k: None,\n        top_p: None,\n        typical_p: None,\n        do_sample: true,\n        max_new_tokens: None,\n        return_full_text: None,\n        stop: Vec::new(),\n        truncate: None,\n        watermark: false,\n        details: false,\n        decoder_input_details: false,\n        seed: None,\n        top_n_tokens: None,\n        grammar: None,\n        adapter_id: None,\n    }\n}\n\n#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]\n#[serde(try_from = \"PromptDeserializer\")]\npub struct Prompt(pub Vec<String>);\n\n#[derive(Deserialize)]\n#[serde(untagged)]\nenum PromptDeserializer {\n    Single(String),\n    Multiple(Vec<String>),\n}\n\nimpl TryFrom<PromptDeserializer> for Prompt {\n    type Error = String;\n\n    fn try_from(value: PromptDeserializer) -> Result<Self, Self::Error> {\n        match value {\n            PromptDeserializer::Single(s) => Ok(Prompt(vec![s])),\n            PromptDeserializer::Multiple(v) => {\n                if v.is_empty() {\n                    Err(\n                        \"Empty array detected. Do not use an empty array for the prompt.\"\n                            .to_string(),\n                    )\n                } else {\n                    Ok(Prompt(v))\n                }\n            }\n        }\n    }\n}\n\n#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]\npub struct CompletionRequest {\n    /// UNUSED\n    #[schema(example = \"mistralai/Mistral-7B-Instruct-v0.2\")]\n    /// ID of the model to use. 
See the model endpoint compatibility table for details on which models work with the Chat API.\n    pub model: Option<String>,\n\n    /// The prompt to generate completions for.\n    #[schema(example = \"What is Deep Learning?\")]\n    pub prompt: Prompt,\n\n    /// The maximum number of tokens that can be generated in the chat completion.\n    #[serde(default)]\n    #[schema(default = \"1024\", example = \"32\")]\n    pub max_tokens: Option<u32>,\n\n    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while\n    /// lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or `top_p` but not both.\n    #[serde(default)]\n    #[schema(nullable = true, example = 1.0)]\n    pub temperature: Option<f32>,\n\n    /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\n    /// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n    #[serde(default)]\n    #[schema(nullable = true, example = 0.95)]\n    pub top_p: Option<f32>,\n\n    #[serde(default = \"bool::default\")]\n    pub stream: bool,\n\n    #[schema(nullable = true, example = 42)]\n    pub seed: Option<u64>,\n\n    /// The text to append to the prompt. This is useful for completing sentences or generating a paragraph of text.\n    /// please see the completion_template field in the model's tokenizer_config.json file for completion template.\n    #[serde(default)]\n    pub suffix: Option<String>,\n\n    #[serde(default)]\n    pub repetition_penalty: Option<f32>,\n\n    /// Number between -2.0 and 2.0. 
Positive values penalize new tokens based on their existing frequency in the text so far,\n    /// decreasing the model's likelihood to repeat the same line verbatim.\n    #[serde(default)]\n    #[schema(example = \"1.0\")]\n    pub frequency_penalty: Option<f32>,\n\n    /// Up to 4 sequences where the API will stop generating further tokens.\n    #[serde(default)]\n    #[schema(nullable = true, example = \"null\")]\n    pub stop: Option<Vec<String>>,\n}\n\n#[derive(Clone, Serialize, ToSchema)]\n#[serde(tag = \"object\")]\nenum Completion {\n    #[serde(rename = \"text_completion\")]\n    Chunk(Chunk),\n    #[serde(rename = \"text_completion\")]\n    Final(CompletionFinal),\n}\n\n#[derive(Clone, Deserialize, Serialize, ToSchema, Default)]\npub(crate) struct CompletionFinal {\n    pub id: String,\n    #[schema(example = \"1706270835\")]\n    pub created: u64,\n    #[schema(example = \"mistralai/Mistral-7B-Instruct-v0.2\")]\n    pub model: String,\n    pub system_fingerprint: String,\n    pub choices: Vec<CompletionComplete>,\n    pub usage: Usage,\n}\n\n#[derive(Clone, Deserialize, Serialize, ToSchema)]\npub(crate) struct CompletionComplete {\n    pub index: u32,\n    pub text: String,\n    pub logprobs: Option<Vec<f32>>,\n    pub finish_reason: String,\n}\n\n#[derive(Clone, Deserialize, Serialize, ToSchema)]\npub(crate) struct Chunk {\n    pub id: String,\n    pub created: u64,\n    pub choices: Vec<CompletionComplete>,\n    pub model: String,\n    pub system_fingerprint: String,\n}\n\n#[derive(Clone, Deserialize, Serialize, ToSchema)]\n#[cfg_attr(test, derive(Debug))]\npub(crate) struct ChatCompletion {\n    pub id: String,\n    #[schema(example = \"1706270835\")]\n    pub created: u64,\n    #[schema(example = \"mistralai/Mistral-7B-Instruct-v0.2\")]\n    pub model: String,\n    pub system_fingerprint: String,\n    pub choices: Vec<ChatCompletionComplete>,\n    pub usage: Usage,\n}\n\n#[derive(Clone, Deserialize, Serialize, ToSchema)]\n#[cfg_attr(test, 
derive(Debug))]\npub(crate) struct ChatCompletionComplete {\n    pub index: u32,\n    pub message: OutputMessage,\n    pub logprobs: Option<ChatCompletionLogprobs>,\n    pub finish_reason: String,\n}\n\n#[derive(Clone, Deserialize, Serialize, ToSchema)]\n#[cfg_attr(test, derive(Debug, PartialEq))]\npub(crate) struct ChatCompletionLogprobs {\n    content: Vec<ChatCompletionLogprob>,\n}\n\nimpl From<(Token, Vec<Token>)> for ChatCompletionLogprobs {\n    fn from(value: (Token, Vec<Token>)) -> Self {\n        let (token, top_tokens) = value;\n\n        Self {\n            content: vec![ChatCompletionLogprob {\n                token: token.text,\n                logprob: token.logprob,\n                top_logprobs: top_tokens\n                    .into_iter()\n                    .map(|t| ChatCompletionTopLogprob {\n                        token: t.text,\n                        logprob: t.logprob,\n                    })\n                    .collect(),\n            }],\n        }\n    }\n}\n\nimpl From<(Vec<Token>, Vec<Vec<Token>>)> for ChatCompletionLogprobs {\n    fn from(value: (Vec<Token>, Vec<Vec<Token>>)) -> Self {\n        let (tokens, top_tokens) = value;\n\n        // Create an iterator that produces None for top_tokens once it's exhausted\n        let top_tokens_iter = top_tokens\n            .into_iter()\n            .map(Some)\n            .chain(std::iter::repeat(None));\n\n        let content = tokens\n            .into_iter()\n            .zip(top_tokens_iter)\n            .map(|(t, top_t_option)| ChatCompletionLogprob {\n                token: t.text,\n                logprob: t.logprob,\n                top_logprobs: match top_t_option {\n                    Some(top_t) => top_t\n                        .into_iter()\n                        .map(|t| ChatCompletionTopLogprob {\n                            token: t.text,\n                            logprob: t.logprob,\n                        })\n                        .collect(),\n                   
 None => vec![], // Handle the case where there are no top tokens\n                },\n            })\n            .collect();\n\n        Self { content }\n    }\n}\n\n#[derive(Clone, Deserialize, Serialize, ToSchema)]\n#[cfg_attr(test, derive(Debug, PartialEq))]\npub(crate) struct ChatCompletionLogprob {\n    token: String,\n    logprob: f32,\n    top_logprobs: Vec<ChatCompletionTopLogprob>,\n}\n\n#[derive(Clone, Deserialize, Serialize, ToSchema)]\n#[cfg_attr(test, derive(Debug, PartialEq))]\npub(crate) struct ChatCompletionTopLogprob {\n    token: String,\n    logprob: f32,\n}\n\n#[derive(Clone, Deserialize, Serialize, ToSchema, Default)]\n#[cfg_attr(test, derive(Debug, PartialEq))]\npub(crate) struct Usage {\n    pub prompt_tokens: u32,\n    pub completion_tokens: u32,\n    pub total_tokens: u32,\n}\n\n#[derive(Clone, Serialize, ToSchema)]\n#[serde(tag = \"object\")]\n#[cfg_attr(test, derive(Debug))]\nenum CompletionType {\n    #[serde(rename = \"chat.completion.chunk\")]\n    ChatCompletionChunk(ChatCompletionChunk),\n    #[serde(rename = \"chat.completion\")]\n    ChatCompletion(ChatCompletion),\n}\n\nimpl ChatCompletion {\n    #[allow(clippy::too_many_arguments)]\n    pub(crate) fn new(\n        model: String,\n        system_fingerprint: String,\n        output: Option<String>,\n        created: u64,\n        details: Details,\n        return_logprobs: bool,\n        tool_calls: Option<Vec<ToolCall>>,\n        prompt_tokens: u32,\n    ) -> Self {\n        let message = match (output, tool_calls) {\n            (Some(content), None) => OutputMessage::ChatMessage(TextMessage {\n                role: \"assistant\".into(),\n                content,\n                ..Default::default()\n            }),\n            (None, Some(tool_calls)) => OutputMessage::ToolCall(ToolCallMessage {\n                role: \"assistant\".to_string(),\n                tool_calls,\n            }),\n            (Some(output), Some(_)) => {\n                warn!(\"Received both chat 
and tool call\");\n                OutputMessage::ChatMessage(TextMessage {\n                    role: \"assistant\".into(),\n                    content: output,\n                    ..Default::default()\n                })\n            }\n            (None, None) => {\n                warn!(\"Didn't receive an answer\");\n                OutputMessage::ChatMessage(TextMessage {\n                    role: \"assistant\".into(),\n                    content: \"\".to_string(),\n                    ..Default::default()\n                })\n            }\n        };\n        Self {\n            id: String::new(),\n            created,\n            model,\n            system_fingerprint,\n            choices: vec![ChatCompletionComplete {\n                index: 0,\n                message,\n                logprobs: return_logprobs\n                    .then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),\n                finish_reason: details.finish_reason.format(true),\n            }],\n            usage: Usage {\n                prompt_tokens,\n                completion_tokens: details.generated_tokens,\n                total_tokens: prompt_tokens + details.generated_tokens,\n            },\n        }\n    }\n}\n#[derive(Clone, Serialize, ToSchema)]\n#[cfg_attr(test, derive(Debug))]\npub(crate) struct ChatCompletionChunk {\n    pub id: String,\n    #[schema(example = \"1706270978\")]\n    pub created: u64,\n    #[schema(example = \"mistralai/Mistral-7B-Instruct-v0.2\")]\n    pub model: String,\n    pub system_fingerprint: String,\n    pub choices: Vec<ChatCompletionChoice>,\n    pub usage: Option<Usage>,\n}\n\n#[derive(Clone, Serialize, ToSchema)]\n#[cfg_attr(test, derive(Debug, PartialEq))]\npub(crate) struct ChatCompletionChoice {\n    pub index: u32,\n    pub delta: ChatCompletionDelta,\n    pub logprobs: Option<ChatCompletionLogprobs>,\n    pub finish_reason: Option<String>,\n}\n\n#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, 
PartialEq)]\npub struct ToolCallDelta {\n    #[schema(example = \"assistant\")]\n    role: String,\n    tool_calls: Vec<DeltaToolCall>,\n}\n\n#[derive(Clone, Debug, Serialize, ToSchema)]\n#[serde(untagged)]\n#[cfg_attr(test, derive(PartialEq))]\nenum ChatCompletionDelta {\n    Chat(TextMessage),\n    Tool(ToolCallDelta),\n}\n\n#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]\npub(crate) struct DeltaToolCall {\n    pub index: u32,\n    pub id: String,\n    pub r#type: String,\n    pub function: Function,\n}\n\n#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]\npub(crate) struct Function {\n    pub name: Option<String>,\n    pub arguments: String,\n}\n\n#[allow(clippy::too_many_arguments)]\nimpl ChatCompletionChunk {\n    pub(crate) fn new(\n        model: String,\n        system_fingerprint: String,\n        created: u64,\n        choices: Vec<ChatCompletionChoice>,\n        usage: Option<Usage>,\n    ) -> Self {\n        Self {\n            id: String::new(),\n            created,\n            model,\n            system_fingerprint,\n            choices,\n            usage,\n        }\n    }\n}\n\n#[derive(Clone, Deserialize, ToSchema, Serialize)]\n#[cfg_attr(test, derive(Debug, PartialEq, Default))]\npub(crate) struct ChatRequest {\n    #[schema(example = \"mistralai/Mistral-7B-Instruct-v0.2\")]\n    /// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.\n    pub model: Option<String>,\n\n    /// A list of messages comprising the conversation so far.\n    #[schema(example = \"[{\\\"role\\\": \\\"user\\\", \\\"content\\\": \\\"What is Deep Learning?\\\"}]\")]\n    pub messages: Vec<Message>,\n\n    /// Number between -2.0 and 2.0. 
Positive values penalize new tokens based on their existing frequency in the text so far,\n    /// decreasing the model's likelihood to repeat the same line verbatim.\n    #[serde(default)]\n    #[schema(example = \"1.0\")]\n    pub frequency_penalty: Option<f32>,\n\n    /// UNUSED\n    /// Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens\n    /// (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,\n    /// the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,\n    /// but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should\n    /// result in a ban or exclusive selection of the relevant token.\n    #[serde(default)]\n    pub logit_bias: Option<Vec<f32>>,\n\n    /// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each\n    /// output token returned in the content of message.\n    #[serde(default)]\n    #[schema(example = \"false\")]\n    pub logprobs: Option<bool>,\n\n    /// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with\n    /// an associated log probability. logprobs must be set to true if this parameter is used.\n    #[serde(default)]\n    #[schema(example = \"5\")]\n    pub top_logprobs: Option<u32>,\n\n    /// The maximum number of tokens that can be generated in the chat completion.\n    #[serde(default, alias = \"max_completion_tokens\")]\n    #[schema(default = \"1024\", example = \"32\")]\n    pub max_tokens: Option<u32>,\n\n    /// UNUSED\n    /// How many chat completion choices to generate for each input message. Note that you will be charged based on the\n    /// number of generated tokens across all of the choices. 
Keep n as 1 to minimize costs.\n    #[serde(default)]\n    #[schema(nullable = true, example = \"2\")]\n    pub n: Option<u32>,\n\n    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,\n    /// increasing the model's likelihood to talk about new topics\n    #[serde(default)]\n    #[schema(nullable = true, example = 0.1)]\n    pub presence_penalty: Option<f32>,\n\n    /// Up to 4 sequences where the API will stop generating further tokens.\n    #[serde(default)]\n    #[schema(nullable = true, example = \"null\")]\n    pub stop: Option<Vec<String>>,\n\n    #[serde(default = \"bool::default\")]\n    pub stream: bool,\n\n    #[schema(nullable = true, example = 42)]\n    pub seed: Option<u64>,\n\n    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while\n    /// lower values like 0.2 will make it more focused and deterministic.\n    ///\n    /// We generally recommend altering this or `top_p` but not both.\n    #[serde(default)]\n    #[schema(nullable = true, example = 1.0)]\n    pub temperature: Option<f32>,\n\n    /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\n    /// tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n    #[serde(default)]\n    #[schema(nullable = true, example = 0.95)]\n    pub top_p: Option<f32>,\n\n    /// A list of tools the model may call. Currently, only functions are supported as a tool. 
Use this to provide a list of\n    /// functions the model may generate JSON inputs for.\n    #[serde(default)]\n    #[schema(nullable = true, example = \"null\")]\n    pub tools: Option<Vec<Tool>>,\n\n    /// A prompt to be appended before the tools\n    #[serde(default)]\n    #[schema(\n        nullable = true,\n        example = \"Given the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\"\n    )]\n    pub tool_prompt: Option<String>,\n\n    /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.\n    #[serde(default)]\n    #[schema(nullable = true, default = \"auto\", example = \"auto\")]\n    pub tool_choice: ToolChoice,\n\n    /// Response format constraints for the generation.\n    ///\n    /// NOTE: A request can use `response_format` OR `tools` but not both.\n    #[serde(default)]\n    #[schema(nullable = true, default = \"null\", example = \"null\")]\n    pub response_format: Option<GrammarType>,\n\n    /// Options for streaming response. 
Only set this when you set stream: true.\n    #[serde(default)]\n    #[schema(nullable = true, example = \"null\")]\n    pub stream_options: StreamOptions,\n}\n\nimpl ChatRequest {\n    fn try_into_generate(self, infer: &Infer) -> Result<(GenerateRequest, bool), InferError> {\n        let ChatRequest {\n            model,\n            max_tokens,\n            messages,\n            seed,\n            stop,\n            tools,\n            tool_choice,\n            tool_prompt,\n            temperature,\n            response_format,\n            presence_penalty,\n            frequency_penalty,\n            top_p,\n            top_logprobs,\n            ..\n        } = self;\n\n        let repetition_penalty = presence_penalty.map(|x| x + 2.0);\n        let max_new_tokens = max_tokens;\n        let tool_prompt = tool_prompt\n            .filter(|s| !s.is_empty())\n            .unwrap_or_else(default_tool_prompt);\n        let stop = stop.unwrap_or_default();\n        // enable greedy only when temperature is 0\n        let (do_sample, temperature) = match temperature {\n            Some(0.0) => (false, None),\n            other => (true, other),\n        };\n\n        if response_format.is_some() && tools.is_some() {\n            return Err(InferError::ToolError(\n                \"Grammar and tools are mutually exclusive\".into(),\n            ));\n        }\n\n        let (inputs, grammar, using_tools) = match response_format {\n            Some(format) => {\n                let inputs = infer.apply_chat_template(messages, None)?;\n                (inputs, Some(format), false)\n            }\n            None => {\n                if let Some(tools) = tools {\n                    match ToolGrammar::apply(tools, tool_choice)? 
{\n                        Some((updated_tools, tool_schema)) => {\n                            let grammar = GrammarType::Json(serde_json::json!(tool_schema));\n                            let inputs: String = infer.apply_chat_template(\n                                messages,\n                                Some((updated_tools, tool_prompt)),\n                            )?;\n                            (inputs, Some(grammar), true)\n                        }\n                        None => {\n                            // same as if no response_format or tools are set\n                            let inputs = infer.apply_chat_template(messages, None)?;\n                            (inputs, None, false)\n                        }\n                    }\n                } else {\n                    // if no response_format or tools are set simply apply the chat template to generate inputs\n                    let inputs = infer.apply_chat_template(messages, None)?;\n                    (inputs, None, false)\n                }\n            }\n        };\n\n        Ok((\n            GenerateRequest {\n                inputs: inputs.to_string(),\n                add_special_tokens: false,\n                parameters: GenerateParameters {\n                    best_of: None,\n                    temperature,\n                    repetition_penalty,\n                    frequency_penalty,\n                    top_k: None,\n                    top_p,\n                    typical_p: None,\n                    do_sample,\n                    max_new_tokens,\n                    return_full_text: None,\n                    stop,\n                    truncate: None,\n                    watermark: false,\n                    details: true,\n                    decoder_input_details: false,\n                    seed,\n                    top_n_tokens: top_logprobs,\n                    grammar,\n                    adapter_id: model.filter(|m| *m != \"tgi\"),\n          
      },\n            },\n            using_tools,\n        ))\n    }\n\n    fn next_int_id(&self) -> Result<String, Box<dyn std::error::Error>> {\n        let mut id: usize = 0;\n        for message in &self.messages {\n            if let MessageBody::Tool { tool_calls } = &message.body {\n                for tool_call in tool_calls {\n                    let new_id: usize = tool_call.id.parse()?;\n                    id = std::cmp::max(id, new_id + 1);\n                }\n            }\n        }\n        Ok(id.to_string())\n    }\n\n    /// Try to have linearly increasing id\n    /// or resort to using Uuid if the initial\n    /// scheme is not understood\n    fn next_tool_call_id(&self) -> String {\n        self.next_int_id().unwrap_or_else(|_| {\n            let uid = Uuid::new_v4().to_string();\n            uid.to_string()\n        })\n    }\n}\n\n#[derive(Clone, Deserialize, ToSchema, Serialize, Default)]\n#[cfg_attr(test, derive(Debug, PartialEq))]\nstruct StreamOptions {\n    /// If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value.\n    #[schema(example = \"true\")]\n    #[serde(default)]\n    include_usage: bool,\n}\n\npub fn default_tool_prompt() -> String {\n    \"\\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. 
Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\\n\".to_string()\n}\n\n#[derive(Clone, Debug, Deserialize, ToSchema, PartialEq, Serialize)]\n#[serde(tag = \"type\")]\npub enum TypedChoice {\n    #[serde(rename = \"function\")]\n    Function { function: FunctionName },\n}\n\n#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)]\npub struct FunctionName {\n    pub name: String,\n}\n\n#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema, Default)]\n#[serde(from = \"ToolTypeDeserializer\")]\n#[serde(rename_all = \"snake_case\")]\n/// <https://platform.openai.com/docs/guides/function-calling/configuring-function-calling-behavior-using-the-tool_choice-parameter>\npub enum ToolChoice {\n    /// Means the model can pick between generating a message or calling one or more tools.\n    #[default]\n    Auto,\n    /// Means the model will not call any tool and instead generates a message.\n    #[serde(rename = \"none\")]\n    NoTool,\n    /// Means the model must call one or more tools.\n    Required,\n    /// Forces the model to call a specific tool. This structure aligns with the `OpenAI` API schema to force a specific tool.\n    Function(FunctionName),\n}\n\n#[derive(Deserialize, ToSchema)]\n#[serde(untagged)]\n/// Controls which (if any) tool is called by the model.\n/// - `none` means the model will not call any tool and instead generates a message.\n/// - `auto` means the model can pick between generating a message or calling one or more tools.\n/// - `required` means the model must call one or more tools.\n/// - Specifying a particular tool via `{\\\"type\\\": \\\"function\\\", \\\"function\\\": {\\\"name\\\": \\\"my_function\\\"}}` forces the model to call that tool.\n///\n/// `none` is the default when no tools are present. 
`auto` is the default if tools are present.\"\nenum ToolTypeDeserializer {\n    /// None means `null` was passed in the JSON, and the default choice is applied based on the presence of tools.\n    Null,\n\n    /// `auto` means the model can pick between generating a message or calling one or more tools.\n    #[schema(example = \"auto\")]\n    String(String),\n\n    /// Specifying a particular tool forces the model to call that tool, with structured function details.\n    #[schema(example = r#\"{\"type\": \"function\", \"function\": {\"name\": \"my_function\"}}\"#)]\n    TypedChoice(TypedChoice),\n}\n\nimpl From<ToolTypeDeserializer> for ToolChoice {\n    fn from(value: ToolTypeDeserializer) -> Self {\n        match value {\n            ToolTypeDeserializer::Null => ToolChoice::Auto,\n            ToolTypeDeserializer::String(s) => match s.as_str() {\n                \"none\" => ToolChoice::NoTool,\n                \"auto\" => ToolChoice::Auto,\n                \"required\" => ToolChoice::Required,\n                _ => ToolChoice::Function(FunctionName { name: s }),\n            },\n            ToolTypeDeserializer::TypedChoice(TypedChoice::Function { function }) => {\n                ToolChoice::Function(function)\n            }\n        }\n    }\n}\n\n#[derive(Debug, Deserialize, Serialize, ToSchema, PartialEq)]\npub struct JsonSchemaTool {\n    #[serde(flatten)]\n    functions_map: FunctionsMap,\n    properties: Properties,\n}\n\n#[derive(Debug, Serialize, Deserialize, ToSchema, PartialEq)]\nstruct FunctionsMap {\n    #[serde(rename = \"$functions\")]\n    functions: std::collections::HashMap<String, serde_json::Value>,\n}\n\n#[derive(Debug, Serialize, Deserialize, ToSchema, PartialEq)]\nstruct FunctionRef {\n    #[serde(rename = \"$ref\")]\n    ref_path: String,\n}\n\n#[derive(Debug, Serialize, Deserialize, ToSchema, PartialEq)]\nstruct Properties {\n    #[serde(serialize_with = \"serialize_function\")]\n    function: Vec<FunctionRef>,\n}\n\nfn 
serialize_function<S>(functions: &Vec<FunctionRef>, serializer: S) -> Result<S::Ok, S::Error>\nwhere\n    S: serde::Serializer,\n{\n    use serde::ser::SerializeStruct;\n    let mut state = serializer.serialize_struct(\"Function\", 1)?;\n    state.serialize_field(\"anyOf\", functions)?;\n    state.end()\n}\n\n#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default, PartialEq)]\npub struct FunctionDefinition {\n    #[serde(default)]\n    pub description: Option<String>,\n    pub name: String,\n    #[serde(alias = \"parameters\", serialize_with = \"serialize_as_string\")]\n    pub arguments: serde_json::Value,\n}\n\nfn serialize_as_string<S>(value: &serde_json::Value, serializer: S) -> Result<S::Ok, S::Error>\nwhere\n    S: serde::Serializer,\n{\n    serializer.serialize_str(&value.to_string())\n}\n\n#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]\n#[cfg_attr(test, derive(PartialEq))]\npub(crate) struct Tool {\n    // The type of the tool. Currently, only 'function' is supported.\n    #[schema(example = \"function\")]\n    pub r#type: String,\n    // Grab the tool as generic JSON for debugging purposes.\n    pub function: FunctionDefinition,\n}\n\n#[derive(Clone, Serialize, Deserialize, Default)]\npub(crate) struct ChatTemplateInputs<'a> {\n    messages: Vec<TextMessage>,\n    bos_token: Option<&'a str>,\n    eos_token: Option<&'a str>,\n    add_generation_prompt: bool,\n    tools: Option<Vec<Tool>>,\n}\n\n#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug, PartialEq)]\npub struct ToolCall {\n    pub id: String,\n    pub r#type: String,\n    pub function: FunctionDefinition,\n}\n\n#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]\npub struct Url {\n    url: String,\n}\n\n#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]\n#[serde(tag = \"type\")]\n#[serde(rename_all = \"snake_case\")]\npub enum MessageChunk {\n    Text { text: String },\n    ImageUrl { image_url: Url },\n}\n\n#[derive(Clone, 
Deserialize, Serialize, ToSchema, Debug, PartialEq)]\npub struct Message {\n    #[schema(example = \"user\")]\n    pub role: String,\n    #[serde(flatten)]\n    #[schema(example = \"My name is David and I\")]\n    pub body: MessageBody,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    #[schema(example = \"\\\"David\\\"\")]\n    pub name: Option<String>,\n}\n\n#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]\n#[serde(untagged)]\npub enum MessageBody {\n    // When a regular text message is provided.\n    Content {\n        #[serde(rename = \"content\")]\n        content: MessageContent,\n    },\n    // When tool calls are provided.\n    Tool {\n        #[serde(rename = \"tool_calls\")]\n        tool_calls: Vec<ToolCall>,\n    },\n}\n\n#[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]\n#[serde(untagged)]\npub enum MessageContent {\n    SingleText(String),\n    MultipleChunks(Vec<MessageChunk>),\n}\n\n// Pushing a chunk to a single text message will convert it to a multiple chunks message\nimpl MessageContent {\n    pub fn push(&mut self, chunk: MessageChunk) {\n        match self {\n            MessageContent::SingleText(text) => {\n                *self = MessageContent::MultipleChunks(vec![\n                    MessageChunk::Text { text: text.clone() },\n                    chunk,\n                ]);\n            }\n            MessageContent::MultipleChunks(chunks) => {\n                chunks.push(chunk);\n            }\n        }\n    }\n}\n\n#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq, Default)]\npub struct TextMessage {\n    #[schema(example = \"user\")]\n    pub role: String,\n    #[schema(example = \"My name is David and I\")]\n    pub content: String,\n    #[serde(default, skip_serializing_if = \"Option::is_none\")]\n    pub tool_call_id: Option<String>,\n}\n\nimpl From<Message> for TextMessage {\n    fn from(value: Message) -> Self {\n        let content = match value.body 
{\n            MessageBody::Content { content } => content,\n            MessageBody::Tool { tool_calls } => {\n                let content = serde_json::to_string(&tool_calls).unwrap_or_default();\n                MessageContent::SingleText(content)\n            }\n        };\n        TextMessage {\n            role: value.role,\n            content: match content {\n                MessageContent::SingleText(text) => text,\n                MessageContent::MultipleChunks(chunks) => chunks\n                    .into_iter()\n                    .map(|chunk| match chunk {\n                        MessageChunk::Text { text } => text,\n                        MessageChunk::ImageUrl { image_url } => format!(\"![]({})\", image_url.url),\n                    })\n                    .collect::<Vec<_>>()\n                    .join(\"\"),\n            },\n            ..Default::default()\n        }\n    }\n}\n\n#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]\npub struct ToolCallMessage {\n    #[schema(example = \"assistant\")]\n    role: String,\n    tool_calls: Vec<ToolCall>,\n}\n\n#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]\n#[serde(untagged)]\npub(crate) enum OutputMessage {\n    ChatMessage(TextMessage),\n    ToolCall(ToolCallMessage),\n}\n\n#[derive(Clone, Debug, Deserialize, ToSchema)]\n#[cfg_attr(test, derive(PartialEq))]\npub(crate) struct GenerateRequest {\n    #[schema(example = \"My name is Olivier and I\")]\n    pub inputs: String,\n    #[serde(default = \"default_parameters\")]\n    pub parameters: GenerateParameters,\n\n    /// This is used internally because some requests\n    /// already contain the templated input therefore\n    /// we shouldn't add the special tokens.\n    #[serde(default = \"default_true\", skip)]\n    pub add_special_tokens: bool,\n}\n\nfn default_true() -> bool {\n    true\n}\n\n#[derive(Clone, Debug, Deserialize, ToSchema)]\npub(crate) struct CompatGenerateRequest {\n    #[schema(example = 
\"My name is Olivier and I\")]\n    pub inputs: String,\n    #[serde(default = \"default_parameters\")]\n    pub parameters: GenerateParameters,\n    #[serde(default)]\n    #[schema(default = \"false\")]\n    pub stream: bool,\n}\n\nimpl From<CompatGenerateRequest> for GenerateRequest {\n    fn from(req: CompatGenerateRequest) -> Self {\n        Self {\n            inputs: req.inputs,\n            add_special_tokens: true,\n            parameters: req.parameters,\n        }\n    }\n}\n\n#[derive(Debug, Serialize, ToSchema)]\npub struct PrefillToken {\n    #[schema(example = 0)]\n    pub id: u32,\n    #[schema(example = \"test\")]\n    pub text: String,\n    #[schema(nullable = true, example = - 0.34)]\n    pub logprob: f32,\n}\n\n#[derive(Debug, Serialize, ToSchema, Clone)]\npub struct Token {\n    #[schema(example = 0)]\n    pub id: u32,\n    #[schema(example = \"test\")]\n    pub text: String,\n    #[schema(nullable = true, example = - 0.34)]\n    pub logprob: f32,\n    #[schema(example = \"false\")]\n    pub special: bool,\n}\n\n#[derive(Debug, Serialize, ToSchema)]\npub struct SimpleToken {\n    #[schema(example = 0)]\n    id: u32,\n    #[schema(example = \"test\")]\n    text: String,\n    #[schema(example = 0)]\n    start: usize,\n    #[schema(example = 2)]\n    stop: usize,\n}\n\n#[derive(Debug, Serialize, ToSchema, Clone)]\n#[serde(rename_all(serialize = \"snake_case\"))]\n#[schema(example = \"Length\")]\npub enum FinishReason {\n    #[schema(rename = \"length\")]\n    Length,\n    #[serde(rename = \"eos_token\")]\n    #[schema(rename = \"eos_token\")]\n    EndOfSequenceToken,\n    #[schema(rename = \"stop_sequence\")]\n    StopSequence,\n}\n\nimpl std::fmt::Display for FinishReason {\n    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n        match self {\n            FinishReason::Length => write!(f, \"length\"),\n            FinishReason::EndOfSequenceToken => write!(f, \"eos_token\"),\n            FinishReason::StopSequence => 
write!(f, \"stop_sequence\"),\n        }\n    }\n}\n\nimpl FinishReason {\n    pub fn format(&self, use_stop: bool) -> String {\n        match self {\n            FinishReason::EndOfSequenceToken if use_stop => \"stop\".to_string(),\n            _ => self.to_string(),\n        }\n    }\n}\n\n#[derive(Serialize, ToSchema)]\npub(crate) struct BestOfSequence {\n    #[schema(example = \"test\")]\n    pub generated_text: String,\n    #[schema(example = \"length\")]\n    pub finish_reason: FinishReason,\n    #[schema(example = 1)]\n    pub generated_tokens: u32,\n    #[schema(nullable = true, example = 42)]\n    pub seed: Option<u64>,\n    pub prefill: Vec<PrefillToken>,\n    pub tokens: Vec<Token>,\n    #[serde(skip_serializing_if = \"Vec::is_empty\")]\n    pub top_tokens: Vec<Vec<Token>>,\n}\n\n#[derive(Serialize, ToSchema)]\npub(crate) struct Details {\n    #[schema(example = \"length\")]\n    pub finish_reason: FinishReason,\n    #[schema(example = 1)]\n    pub generated_tokens: u32,\n    #[schema(nullable = true, example = 42)]\n    pub seed: Option<u64>,\n    pub prefill: Vec<PrefillToken>,\n    pub tokens: Vec<Token>,\n    #[serde(skip_serializing_if = \"Option::is_none\")]\n    pub best_of_sequences: Option<Vec<BestOfSequence>>,\n    #[serde(skip_serializing_if = \"Vec::is_empty\")]\n    pub top_tokens: Vec<Vec<Token>>,\n}\n\n#[derive(Serialize, ToSchema)]\npub(crate) struct GenerateResponse {\n    #[schema(example = \"test\")]\n    pub generated_text: String,\n    #[serde(skip_serializing_if = \"Option::is_none\")]\n    pub details: Option<Details>,\n}\n\n#[derive(Serialize, ToSchema)]\npub(crate) struct ChatTokenizeResponse {\n    pub(crate) tokenize_response: TokenizeResponse,\n    pub(crate) templated_text: String,\n}\n\n#[derive(Serialize, ToSchema)]\n#[serde(transparent)]\npub(crate) struct TokenizeResponse(Vec<SimpleToken>);\n\n#[derive(Serialize, ToSchema, Clone)]\npub(crate) struct StreamDetails {\n    #[schema(example = \"length\")]\n    pub 
finish_reason: FinishReason,\n    #[schema(example = 1)]\n    pub generated_tokens: u32,\n    #[schema(nullable = true, example = 42)]\n    pub seed: Option<u64>,\n    #[schema(example = 1)]\n    pub input_length: u32,\n}\n\n#[derive(Serialize, ToSchema, Clone)]\npub(crate) struct StreamResponse {\n    pub index: u32,\n    pub token: Token,\n    #[serde(skip_serializing_if = \"Vec::is_empty\")]\n    pub top_tokens: Vec<Token>,\n    #[schema(nullable = true, default = \"null\", example = \"test\")]\n    pub generated_text: Option<String>,\n    #[schema(nullable = true, default = \"null\")]\n    pub details: Option<StreamDetails>,\n}\n\n#[derive(Serialize, ToSchema)]\npub(crate) struct ErrorResponse {\n    pub error: String,\n    pub error_type: String,\n}\n\n#[derive(Serialize, Deserialize, ToSchema)]\npub(crate) struct ModelInfo {\n    #[schema(example = \"gpt2\")]\n    pub id: String,\n    #[schema(example = \"model\")]\n    pub object: String,\n    #[schema(example = 1686935002)]\n    pub created: u64,\n    #[schema(example = \"openai\")]\n    pub owned_by: String,\n}\n\n#[derive(Serialize, Deserialize, ToSchema)]\npub(crate) struct ModelsInfo {\n    #[schema(example = \"list\")]\n    pub object: String,\n    pub data: Vec<ModelInfo>,\n}\n\nimpl Default for ModelsInfo {\n    fn default() -> Self {\n        ModelsInfo {\n            object: \"list\".to_string(),\n            data: Vec::new(),\n        }\n    }\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use serde_json::json;\n\n    pub(crate) fn get_tokenizer() -> Tokenizer {\n        let api = hf_hub::api::sync::Api::new().unwrap();\n        let repo = api.model(\"gpt2\".to_string());\n        let filename = repo.get(\"tokenizer.json\").unwrap();\n        Tokenizer::Rust(tokenizers::Tokenizer::from_file(filename).unwrap())\n    }\n\n    #[test]\n    fn test_hub_nested_tokens_tokenizer_config() {\n        // this is a subset of the tokenizer.json file\n        // in this case we expect the tokens to be 
encoded as simple strings\n        let json_content = r#\"{\n            \"chat_template\": \"test\",\n            \"bos_token\": \"<｜begin▁of▁sentence｜>\",\n            \"eos_token\": \"<｜end▁of▁sentence｜>\"\n        }\"#;\n\n        let config: HubTokenizerConfig = serde_json::from_str(json_content).unwrap();\n\n        // check that we successfully parsed the tokens\n        assert_eq!(\n            config.chat_template,\n            Some(ChatTemplateVersions::Single(\"test\".to_string()))\n        );\n        assert_eq!(\n            config.bos_token,\n            Some(TokenizerConfigToken::String(\n                \"<｜begin▁of▁sentence｜>\".to_string()\n            ))\n        );\n        assert_eq!(\n            config.eos_token,\n            Some(TokenizerConfigToken::String(\n                \"<｜end▁of▁sentence｜>\".to_string()\n            ))\n        );\n\n        // in this case we expect the tokens to be encoded as structured tokens\n        // we want the content of the structured token\n        let json_content = r#\"{\n            \"chat_template\": \"test\",\n            \"bos_token\": {\n              \"__type\": \"AddedToken\",\n              \"content\": \"<｜begin▁of▁sentence｜>\",\n              \"lstrip\": false,\n              \"normalized\": true,\n              \"rstrip\": false,\n              \"single_word\": false\n            },\n            \"eos_token\": {\n              \"__type\": \"AddedToken\",\n              \"content\": \"<｜end▁of▁sentence｜>\",\n              \"lstrip\": false,\n              \"normalized\": true,\n              \"rstrip\": false,\n              \"single_word\": false\n            }\n        }\"#;\n\n        let config: HubTokenizerConfig = serde_json::from_str(json_content).unwrap();\n\n        // check that we successfully parsed the tokens\n        assert_eq!(\n            config.chat_template,\n            Some(ChatTemplateVersions::Single(\"test\".to_string()))\n        );\n        assert_eq!(\n            
config.bos_token,\n            Some(TokenizerConfigToken::Object {\n                content: \"<｜begin▁of▁sentence｜>\".to_string()\n            })\n        );\n        assert_eq!(\n            config.eos_token,\n            Some(TokenizerConfigToken::Object {\n                content: \"<｜end▁of▁sentence｜>\".to_string()\n            })\n        );\n    }\n\n    #[test]\n    fn test_chat_simple_string() {\n        let json = json!({\n            \"model\": \"\",\n            \"messages\": [{\n                \"role\": \"user\",\n                \"content\": \"What is Deep Learning?\"\n            }]\n        });\n        let request: ChatRequest = serde_json::from_str(json.to_string().as_str()).unwrap();\n\n        assert_eq!(\n            request.messages[0],\n            Message {\n                name: None,\n                role: \"user\".to_string(),\n                body: MessageBody::Content {\n                    content: MessageContent::SingleText(\"What is Deep Learning?\".to_string())\n                },\n            }\n        );\n    }\n\n    #[test]\n    fn test_message_content_append() {\n        let mut content = MessageContent::SingleText(\"Initial text\".to_string());\n        let chunk = MessageChunk::Text {\n            text: \"Additional text\".to_string(),\n        };\n\n        content.push(chunk);\n\n        match content {\n            MessageContent::MultipleChunks(chunks) => {\n                assert_eq!(chunks.len(), 2);\n                assert_eq!(\n                    chunks[0],\n                    MessageChunk::Text {\n                        text: \"Initial text\".to_string()\n                    }\n                );\n                assert_eq!(\n                    chunks[1],\n                    MessageChunk::Text {\n                        text: \"Additional text\".to_string()\n                    }\n                );\n            }\n            _ => panic!(\"Expected MultipleChunks, but got a different variant\"),\n        }\n  
  }\n\n    #[test]\n    fn test_chat_request() {\n        let json = json!({\n            \"model\": \"\",\n            \"messages\": [{\n                \"role\": \"user\",\n                \"content\": [\n                    {\"type\": \"text\", \"text\": \"Whats in this image?\"},\n                    {\"type\": \"image_url\", \"image_url\": {\"url\": \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png\"}},\n                ]\n            }]\n        });\n        let request: ChatRequest = serde_json::from_str(json.to_string().as_str()).unwrap();\n\n        assert_eq!(\n            request.messages[0],\n            Message {\n                name: None,\n                role: \"user\".to_string(),\n\n                body: MessageBody::Content {\n                    content: MessageContent::MultipleChunks(vec![\n                        MessageChunk::Text { text: \"Whats in this image?\".to_string() },\n                        MessageChunk::ImageUrl { image_url: Url { url: \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png\".to_string() }},\n                    ]),\n                },\n            }\n        );\n    }\n\n    #[test]\n    fn text_message_convert() {\n        let message = Message{\n            name: None,\n                role: \"user\".to_string(),\n                body: MessageBody::Content {\n                    content: MessageContent::MultipleChunks(vec![\n                        MessageChunk::Text { text: \"Whats in this image?\".to_string() },\n                        MessageChunk::ImageUrl { image_url: Url { url: \"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png\".to_string() } }\n                    ]),\n                }\n            };\n        let textmsg: TextMessage = message.into();\n        assert_eq!(textmsg.content, \"Whats in this 
image?![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)\");\n    }\n\n    #[test]\n    fn test_chat_stream_options() {\n        let json = json!({\n            \"model\": \"\",\n            \"stream_options\": {\"include_usage\": true},\n            \"messages\": [{\n                \"role\": \"user\",\n                \"content\": \"Hello\"\n            }]\n        });\n        let request: ChatRequest = serde_json::from_str(json.to_string().as_str()).unwrap();\n\n        assert!(matches!(\n            request.stream_options,\n            StreamOptions {\n                include_usage: true\n            }\n        ));\n\n        let json = json!({\n            \"model\": \"\",\n            \"messages\": [{\n                \"role\": \"user\",\n                \"content\": \"Hello\"\n            }]\n        });\n        let request: ChatRequest = serde_json::from_str(json.to_string().as_str()).unwrap();\n\n        assert!(matches!(\n            request.stream_options,\n            StreamOptions {\n                include_usage: false\n            }\n        ));\n    }\n\n    #[test]\n    fn openai_output() {\n        let message = OutputMessage::ChatMessage(TextMessage {\n            role: \"assistant\".to_string(),\n            content: \"This is the answer\".to_string(),\n            ..Default::default()\n        });\n        let serialized = serde_json::to_string(&message).unwrap();\n        assert_eq!(\n            serialized,\n            r#\"{\"role\":\"assistant\",\"content\":\"This is the answer\"}\"#\n        );\n\n        let message = OutputMessage::ToolCall(ToolCallMessage {\n            role: \"assistant\".to_string(),\n            tool_calls: vec![ToolCall {\n                id: \"0\".to_string(),\n                r#type: \"function\".to_string(),\n                function: FunctionDefinition {\n                    description: None,\n                    name: \"myfn\".to_string(),\n             
       arguments: json!({\n                        \"format\": \"csv\"\n                    }),\n                },\n            }],\n        });\n        let serialized = serde_json::to_string(&message).unwrap();\n        assert_eq!(\n            serialized,\n            r#\"{\"role\":\"assistant\",\"tool_calls\":[{\"id\":\"0\",\"type\":\"function\",\"function\":{\"description\":null,\"name\":\"myfn\",\"arguments\":\"{\\\"format\\\":\\\"csv\\\"}\"}}]}\"#\n        );\n    }\n\n    #[test]\n    fn tool_choice_formats() {\n        #[derive(Deserialize)]\n        struct TestRequest {\n            tool_choice: ToolChoice,\n        }\n\n        let de_none: TestRequest = serde_json::from_str(r#\"{\"tool_choice\":\"none\"}\"#).unwrap();\n        assert_eq!(de_none.tool_choice, ToolChoice::NoTool);\n\n        let de_auto: TestRequest = serde_json::from_str(r#\"{\"tool_choice\":\"auto\"}\"#).unwrap();\n        assert_eq!(de_auto.tool_choice, ToolChoice::Auto);\n\n        let de_required: TestRequest =\n            serde_json::from_str(r#\"{\"tool_choice\":\"required\"}\"#).unwrap();\n        assert_eq!(de_required.tool_choice, ToolChoice::Required);\n\n        let de_named: TestRequest = serde_json::from_str(r#\"{\"tool_choice\":\"myfn\"}\"#).unwrap();\n        assert_eq!(\n            de_named.tool_choice,\n            ToolChoice::Function(FunctionName {\n                name: \"myfn\".to_string(),\n            })\n        );\n\n        let de_openai_named: TestRequest = serde_json::from_str(\n            r#\"{\"tool_choice\":{\"type\":\"function\",\"function\":{\"name\":\"myfn\"}}}\"#,\n        )\n        .unwrap();\n        assert_eq!(\n            de_openai_named.tool_choice,\n            ToolChoice::Function(FunctionName {\n                name: \"myfn\".to_string(),\n            })\n        );\n    }\n}\n"
  },
  {
    "path": "router/src/logging.rs",
    "content": "use axum::{extract::Request, middleware::Next, response::Response};\nuse opentelemetry::sdk::propagation::TraceContextPropagator;\nuse opentelemetry::sdk::trace;\nuse opentelemetry::sdk::trace::Sampler;\nuse opentelemetry::sdk::Resource;\nuse opentelemetry::trace::{SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId};\nuse opentelemetry::Context;\nuse opentelemetry::{global, KeyValue};\nuse opentelemetry_otlp::WithExportConfig;\nuse tracing_subscriber::layer::SubscriberExt;\nuse tracing_subscriber::util::SubscriberInitExt;\nuse tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};\n\nstruct TraceParent {\n    #[allow(dead_code)]\n    version: u8,\n    trace_id: TraceId,\n    parent_id: SpanId,\n    trace_flags: TraceFlags,\n}\n\nfn parse_traceparent(header_value: &str) -> Option<TraceParent> {\n    let parts: Vec<&str> = header_value.split('-').collect();\n    if parts.len() != 4 {\n        return None;\n    }\n\n    let version = u8::from_str_radix(parts[0], 16).ok()?;\n    if version == 0xff {\n        return None;\n    }\n\n    let trace_id = TraceId::from_hex(parts[1]).ok()?;\n    let parent_id = SpanId::from_hex(parts[2]).ok()?;\n    let trace_flags = u8::from_str_radix(parts[3], 16).ok()?;\n\n    Some(TraceParent {\n        version,\n        trace_id,\n        parent_id,\n        trace_flags: TraceFlags::new(trace_flags),\n    })\n}\n\npub async fn trace_context_middleware(mut request: Request, next: Next) -> Response {\n    let context = request\n        .headers()\n        .get(\"traceparent\")\n        .and_then(|v| v.to_str().ok())\n        .and_then(parse_traceparent)\n        .map(|traceparent| {\n            Context::new().with_remote_span_context(SpanContext::new(\n                traceparent.trace_id,\n                traceparent.parent_id,\n                traceparent.trace_flags,\n                true,\n                Default::default(),\n            ))\n        });\n\n    request.extensions_mut().insert(context);\n\n  
  next.run(request).await\n}\n\n/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:\n///     - otlp_endpoint is an optional URL to an Open Telemetry collector\n///     - otlp_service_name service name to appear in APM\n///     - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)\n///     - LOG_FORMAT may be TEXT or JSON (default to TEXT)\n///     - LOG_COLORIZE may be \"false\" or \"true\" (default to \"true\" or ansi supported platforms)\npub fn init_logging(otlp_endpoint: Option<String>, otlp_service_name: String, json_output: bool) {\n    let mut layers = Vec::new();\n\n    // STDOUT/STDERR layer\n    let ansi = std::env::var(\"LOG_COLORIZE\") != Ok(\"1\".to_string());\n    let fmt_layer = tracing_subscriber::fmt::layer()\n        .with_file(true)\n        .with_ansi(ansi)\n        .with_line_number(true);\n\n    let fmt_layer = match json_output {\n        true => fmt_layer.json().flatten_event(true).boxed(),\n        false => fmt_layer.boxed(),\n    };\n    layers.push(fmt_layer);\n\n    // OpenTelemetry tracing layer\n    if let Some(otlp_endpoint) = otlp_endpoint {\n        global::set_text_map_propagator(TraceContextPropagator::new());\n\n        let tracer = opentelemetry_otlp::new_pipeline()\n            .tracing()\n            .with_exporter(\n                opentelemetry_otlp::new_exporter()\n                    .tonic()\n                    .with_endpoint(otlp_endpoint),\n            )\n            .with_trace_config(\n                trace::config()\n                    .with_resource(Resource::new(vec![KeyValue::new(\n                        \"service.name\",\n                        otlp_service_name,\n                    )]))\n                    .with_sampler(Sampler::AlwaysOn),\n            )\n            .install_batch(opentelemetry::runtime::Tokio);\n\n        if let Ok(tracer) = tracer {\n            layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());\n            
init_tracing_opentelemetry::init_propagator().unwrap();\n        };\n    }\n\n    // Filter events with LOG_LEVEL\n    let varname = \"LOG_LEVEL\";\n    let env_filter = if let Ok(log_level) = std::env::var(varname) {\n        // Override to avoid simple logs to be spammed with tokio level informations\n        let log_level = match &log_level[..] {\n            \"warn\" => \"text_generation_launcher=warn,text_generation_router=warn\",\n            \"info\" => \"text_generation_launcher=info,text_generation_router=info\",\n            \"debug\" => \"text_generation_launcher=debug,text_generation_router=debug\",\n            log_level => log_level,\n        };\n        EnvFilter::builder()\n            .with_default_directive(LevelFilter::INFO.into())\n            .parse_lossy(log_level)\n    } else {\n        EnvFilter::new(\"info\")\n    };\n\n    tracing_subscriber::registry()\n        .with(env_filter)\n        .with(layers)\n        .init();\n}\n"
  },
  {
    "path": "router/src/sagemaker.rs",
    "content": "use crate::infer::Infer;\nuse crate::server::{chat_completions, compat_generate, completions, ComputeType};\nuse crate::{\n    ChatCompletion, ChatCompletionChunk, ChatRequest, Chunk, CompatGenerateRequest,\n    CompletionFinal, CompletionRequest, ErrorResponse, GenerateResponse, Info, StreamResponse,\n};\nuse axum::extract::Extension;\nuse axum::http::StatusCode;\nuse axum::response::Response;\nuse axum::Json;\nuse serde::{Deserialize, Serialize};\nuse tracing::instrument;\nuse utoipa::ToSchema;\n\n#[derive(Clone, Deserialize, ToSchema)]\n#[serde(untagged)]\npub(crate) enum SagemakerRequest {\n    Generate(CompatGenerateRequest),\n    Chat(ChatRequest),\n    Completion(CompletionRequest),\n}\n\n// Used for OpenAPI specs\n#[allow(dead_code)]\n#[derive(Serialize, ToSchema)]\n#[serde(untagged)]\npub(crate) enum SagemakerResponse {\n    Generate(GenerateResponse),\n    Chat(ChatCompletion),\n    Completion(CompletionFinal),\n}\n\n// Used for OpenAPI specs\n#[allow(dead_code)]\n#[derive(Serialize, ToSchema)]\n#[serde(untagged)]\npub(crate) enum SagemakerStreamResponse {\n    Generate(StreamResponse),\n    Chat(ChatCompletionChunk),\n    Completion(Chunk),\n}\n\n/// Generate tokens from Sagemaker request\n#[utoipa::path(\npost,\ntag = \"Text Generation Inference\",\npath = \"/invocations\",\nrequest_body = SagemakerRequest,\nresponses(\n(status = 200, description = \"Generated Chat Completion\",\ncontent(\n(\"application/json\" = SagemakerResponse),\n(\"text/event-stream\" = SagemakerStreamResponse),\n)),\n(status = 424, description = \"Generation Error\", body = ErrorResponse,\nexample = json ! ({\"error\": \"Request failed during generation\", \"error_type\": \"generation\"})),\n(status = 429, description = \"Model is overloaded\", body = ErrorResponse,\nexample = json ! ({\"error\": \"Model is overloaded\", \"error_type\": \"overloaded\"})),\n(status = 422, description = \"Input validation error\", body = ErrorResponse,\nexample = json ! 
({\"error\": \"Input validation error\", \"error_type\": \"validation\"})),\n(status = 500, description = \"Incomplete generation\", body = ErrorResponse,\nexample = json ! ({\"error\": \"Incomplete generation\", \"error_type\": \"incomplete_generation\"})),\n)\n)]\n#[instrument(skip_all)]\npub(crate) async fn sagemaker_compatibility(\n    default_return_full_text: Extension<bool>,\n    infer: Extension<Infer>,\n    compute_type: Extension<ComputeType>,\n    context: Extension<Option<opentelemetry::Context>>,\n    info: Extension<Info>,\n    Json(req): Json<SagemakerRequest>,\n) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {\n    match req {\n        SagemakerRequest::Generate(req) => {\n            compat_generate(\n                default_return_full_text,\n                infer,\n                compute_type,\n                context,\n                Json(req),\n            )\n            .await\n        }\n        SagemakerRequest::Chat(req) => {\n            chat_completions(infer, compute_type, info, context, Json(req)).await\n        }\n        SagemakerRequest::Completion(req) => {\n            completions(infer, compute_type, info, context, Json(req)).await\n        }\n    }\n}\n"
  },
  {
    "path": "router/src/server.rs",
    "content": "use crate::chat::{ChatChoice, ChatEvent, ChatState};\n/// HTTP Server logic\nuse crate::config::Config;\nuse crate::infer::{Backend, Infer, InferError, InferResponse, InferStreamResponse};\n#[cfg(feature = \"kserve\")]\nuse crate::kserve::{\n    kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,\n    kserve_model_metadata, kserve_model_metadata_ready,\n};\nuse crate::logging::trace_context_middleware;\nuse crate::sagemaker::{\n    sagemaker_compatibility, SagemakerRequest, SagemakerResponse, SagemakerStreamResponse,\n    __path_sagemaker_compatibility,\n};\nuse crate::validation::ValidationError;\nuse crate::vertex::vertex_compatibility;\nuse crate::{\n    usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,\n    GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,\n    HubProcessorConfig, HubTokenizerConfig, Info, Message, MessageChunk, MessageContent,\n    OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamOptions, StreamResponse,\n    TextMessage, Token, TokenizeResponse, Tokenizer, ToolCallDelta, ToolCallMessage, Url, Usage,\n    Validation,\n};\nuse crate::{\n    ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,\n    ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,\n    ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,\n    CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,\n};\nuse crate::{ChatTokenizeResponse, JsonSchemaConfig};\nuse crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};\nuse crate::{MessageBody, ModelInfo, ModelsInfo};\nuse async_stream::__private::AsyncStream;\nuse axum::extract::{DefaultBodyLimit, Extension};\nuse axum::http::{HeaderMap, HeaderValue, Method, StatusCode};\nuse axum::response::sse::{Event, KeepAlive, Sse};\nuse axum::response::{IntoResponse, 
Response};\nuse axum::routing::{get, post};\nuse axum::{http, Json, Router};\nuse axum_tracing_opentelemetry::middleware::OtelAxumLayer;\nuse futures::stream::StreamExt;\nuse futures::stream::{FuturesOrdered, FuturesUnordered};\nuse futures::Stream;\nuse futures::TryStreamExt;\nuse hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};\nuse hf_hub::{Cache, Repo, RepoType};\nuse http::header::AUTHORIZATION;\nuse metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};\nuse pyo3::prelude::*;\nuse pyo3::types::IntoPyDict;\nuse std::convert::Infallible;\nuse std::fs::File;\nuse std::io::BufReader;\nuse std::net::{IpAddr, Ipv4Addr, SocketAddr};\nuse std::path::{Path, PathBuf};\nuse std::sync::atomic::{AtomicBool, Ordering};\nuse std::sync::Arc;\nuse std::time::Duration;\nuse thiserror::Error;\nuse tokio::select;\nuse tokio::signal;\nuse tokio::sync::oneshot;\nuse tokio::time::Instant;\nuse tower_http::cors::{AllowOrigin, CorsLayer};\nuse tracing::{info_span, instrument, Instrument};\nuse tracing_opentelemetry::OpenTelemetrySpanExt;\nuse utoipa::OpenApi;\nuse utoipa_swagger_ui::SwaggerUi;\n\nfn encoding_to_tokens(encoding: &tokenizers::Encoding, input: &str) -> Vec<SimpleToken> {\n    let offsets = encoding.get_offsets();\n    let input_ids = encoding.get_ids();\n    if offsets.len() == input_ids.len() {\n        input_ids\n            .iter()\n            .zip(offsets)\n            .map(|(&id, &(start, stop))| {\n                let text: Vec<u8> = input.bytes().skip(start).take(stop - start).collect();\n                let text: String = String::from_utf8_lossy(&text).to_string();\n                SimpleToken {\n                    id,\n                    text,\n                    start,\n                    stop,\n                }\n            })\n            .collect()\n    } else {\n        encoding\n            .get_ids()\n            .iter()\n            .map(|&id| SimpleToken {\n                id,\n                text: \"\".to_string(),\n        
        start: 0,\n                stop: 0,\n            })\n            .collect()\n    }\n}\n\n/// Generate tokens if `stream == false` or a stream of token if `stream == true`\n#[utoipa::path(\npost,\ntag = \"Text Generation Inference\",\npath = \"/\",\nrequest_body = CompatGenerateRequest,\nresponses(\n(status = 200, description = \"Generated Text\",\ncontent(\n(\"application/json\" = Vec<GenerateResponse>),\n(\"text/event-stream\" = StreamResponse),\n)),\n(status = 424, description = \"Generation Error\", body = ErrorResponse,\nexample = json ! ({\"error\": \"Request failed during generation\", \"error_type\": \"generation\"})),\n(status = 429, description = \"Model is overloaded\", body = ErrorResponse,\nexample = json ! ({\"error\": \"Model is overloaded\", \"error_type\": \"overloaded\"})),\n(status = 422, description = \"Input validation error\", body = ErrorResponse,\nexample = json ! ({\"error\": \"Input validation error\", \"error_type\": \"validation\"})),\n(status = 500, description = \"Incomplete generation\", body = ErrorResponse,\nexample = json ! 
({\"error\": \"Incomplete generation\", \"error_type\": \"incomplete_generation\"})),\n)\n)]\n#[instrument(skip(infer, req))]\npub(crate) async fn compat_generate(\n    Extension(default_return_full_text): Extension<bool>,\n    infer: Extension<Infer>,\n    compute_type: Extension<ComputeType>,\n    context: Extension<Option<opentelemetry::Context>>,\n    Json(mut req): Json<CompatGenerateRequest>,\n) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {\n    // default return_full_text given the pipeline_tag\n    if req.parameters.return_full_text.is_none() {\n        req.parameters.return_full_text = Some(default_return_full_text)\n    }\n\n    // switch on stream\n    if req.stream {\n        Ok(\n            generate_stream(infer, compute_type, context, Json(req.into()))\n                .await\n                .into_response(),\n        )\n    } else {\n        let (headers, Json(generation)) =\n            generate(infer, compute_type, context, Json(req.into())).await?;\n        // wrap generation inside a Vec to match api-inference\n        Ok((headers, Json(vec![generation])).into_response())\n    }\n}\n\n/// Text Generation Inference endpoint info\n#[utoipa::path(\nget,\ntag = \"Text Generation Inference\",\npath = \"/info\",\nresponses((status = 200, description = \"Served model info\", body = Info))\n)]\n#[instrument]\nasync fn get_model_info(info: Extension<Info>) -> Json<Info> {\n    Json(info.0)\n}\n\n#[utoipa::path(\nget,\ntag = \"Text Generation Inference\",\npath = \"/v1/models\",\nresponses(\n(status = 200, description = \"Served model info\", body = ModelInfo),\n(status = 404, description = \"Model not found\", body = ErrorResponse),\n)\n)]\n#[instrument(skip(info))]\n/// Get model info\nasync fn openai_get_model_info(info: Extension<Info>) -> Json<ModelsInfo> {\n    Json(ModelsInfo {\n        data: vec![ModelInfo {\n            id: info.0.model_id.clone(),\n            object: \"model\".to_string(),\n            created: 0, // TODO: determine 
how to get this\n            owned_by: info.0.model_id.clone(),\n        }],\n        ..Default::default()\n    })\n}\n\n/// Template and tokenize ChatRequest\n#[utoipa::path(\n    post,\n    tag = \"Text Generation Inference\",\n    path = \"/chat_tokenize\",\n    request_body = ChatRequest,\n    responses(\n    (status = 200, description = \"Templated and tokenized ChatRequest\", body = ChatTokenizeResponse),\n    (status = 404, description = \"Failed to tokenize ChatRequest\", body = ErrorResponse),\n    )\n)]\nasync fn get_chat_tokenize(\n    Extension(infer): Extension<Infer>,\n    Json(chat): Json<ChatRequest>,\n) -> Result<(HeaderMap, Json<ChatTokenizeResponse>), (StatusCode, Json<ErrorResponse>)> {\n    metrics::counter!(\"tgi_request_count\").increment(1);\n\n    let generate_request: GenerateRequest = chat.try_into_generate(&infer)?.0;\n    let input = generate_request.inputs.clone();\n    let encoding = infer.tokenize(generate_request).await?;\n\n    let tokens = encoding_to_tokens(&encoding, &input);\n\n    let resp = ChatTokenizeResponse {\n        tokenize_response: TokenizeResponse(tokens),\n        templated_text: input,\n    };\n    Ok((HeaderMap::new(), Json(resp)))\n}\n\n#[utoipa::path(\nget,\ntag = \"Text Generation Inference\",\npath = \"/health\",\nresponses(\n(status = 200, description = \"Everything is working fine\"),\n(status = 503, description = \"Text generation inference is down\", body = ErrorResponse,\nexample = json ! 
({\"error\": \"unhealthy\", \"error_type\": \"healthcheck\"})),\n)\n)]\n#[instrument(skip(infer))]\n/// Health check method\nasync fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {\n    match infer.health().await {\n        true => Ok(()),\n        false => Err((\n            StatusCode::SERVICE_UNAVAILABLE,\n            Json(ErrorResponse {\n                error: \"unhealthy\".to_string(),\n                error_type: \"healthcheck\".to_string(),\n            }),\n        )),\n    }\n}\n\n/// Generate tokens\n#[utoipa::path(\npost,\ntag = \"Text Generation Inference\",\npath = \"/generate\",\nrequest_body = GenerateRequest,\nresponses(\n(status = 200, description = \"Generated Text\", body = GenerateResponse),\n(status = 424, description = \"Generation Error\", body = ErrorResponse,\nexample = json ! ({\"error\": \"Request failed during generation\", \"error_type\": \"generation\"})),\n(status = 429, description = \"Model is overloaded\", body = ErrorResponse,\nexample = json ! ({\"error\": \"Model is overloaded\", \"error_type\": \"overloaded\"})),\n(status = 422, description = \"Input validation error\", body = ErrorResponse,\nexample = json ! ({\"error\": \"Input validation error\", \"error_type\": \"validation\"})),\n(status = 500, description = \"Incomplete generation\", body = ErrorResponse,\nexample = json ! ({\"error\": \"Incomplete generation\", \"error_type\": \"incomplete_generation\"})),\n)\n)]\n#[instrument(\nskip_all,\nfields(\nparameters = ? 
req.parameters,\ntotal_time,\nvalidation_time,\nqueue_time,\ninference_time,\ntime_per_token,\nseed,\n)\n)]\nasync fn generate(\n    infer: Extension<Infer>,\n    Extension(ComputeType(compute_type)): Extension<ComputeType>,\n    Extension(context): Extension<Option<opentelemetry::Context>>,\n    Json(req): Json<GenerateRequest>,\n) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {\n    let span = tracing::Span::current();\n    if let Some(context) = context {\n        span.set_parent(context);\n    }\n\n    let (headers, _, response) =\n        generate_internal(infer, ComputeType(compute_type), Json(req), span).await?;\n    Ok((headers, response))\n}\n\npub(crate) async fn generate_internal(\n    infer: Extension<Infer>,\n    ComputeType(compute_type): ComputeType,\n    Json(req): Json<GenerateRequest>,\n    span: tracing::Span,\n) -> Result<(HeaderMap, u32, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {\n    let start_time = Instant::now();\n    metrics::counter!(\"tgi_request_count\").increment(1);\n\n    // Do not long ultra long inputs, like image payloads.\n    tracing::debug!(\n        \"Input: {}\",\n        &req.inputs.chars().take(1000).collect::<String>()\n    );\n\n    let compute_characters = req.inputs.chars().count();\n    let mut add_prompt = None;\n    if req.parameters.return_full_text.unwrap_or(false) {\n        add_prompt = Some(req.inputs.clone());\n    }\n\n    let details: bool = req.parameters.details || req.parameters.decoder_input_details;\n\n    // Inference\n    let (response, best_of_responses) = match req.parameters.best_of {\n        Some(best_of) if best_of > 1 => {\n            let (response, best_of_responses) = infer.generate_best_of(req, best_of).await?;\n            (response, Some(best_of_responses))\n        }\n        _ => (infer.generate(req).await?, None),\n    };\n\n    // Token details\n    let input_length = response._input_length;\n    let details = match details {\n     
   true => {\n            // convert best_of_responses\n            let best_of_sequences = best_of_responses.map(|responses: Vec<InferResponse>| {\n                responses\n                    .into_iter()\n                    .map(|response: InferResponse| {\n                        // Add prompt if return_full_text\n                        let mut output_text = response.generated_text.text;\n                        if let Some(prompt) = &add_prompt {\n                            output_text = prompt.clone() + &output_text;\n                        }\n\n                        BestOfSequence {\n                            generated_text: output_text,\n                            finish_reason: response.generated_text.finish_reason,\n                            generated_tokens: response.generated_text.generated_tokens,\n                            prefill: response.prefill,\n                            tokens: response.tokens,\n                            top_tokens: response.top_tokens,\n                            seed: response.generated_text.seed,\n                        }\n                    })\n                    .collect()\n            });\n\n            Some(Details {\n                finish_reason: response.generated_text.finish_reason,\n                generated_tokens: response.generated_text.generated_tokens,\n                prefill: response.prefill,\n                tokens: response.tokens,\n                seed: response.generated_text.seed,\n                best_of_sequences,\n                top_tokens: response.top_tokens,\n            })\n        }\n        false => None,\n    };\n\n    // Timings\n    let total_time = start_time.elapsed();\n    let validation_time = response.queued - start_time;\n    let queue_time = response.start - response.queued;\n    let inference_time = Instant::now() - response.start;\n    let time_per_token = inference_time / response.generated_text.generated_tokens;\n\n    // Tracing metadata\n    
span.record(\"total_time\", format!(\"{total_time:?}\"));\n    span.record(\"validation_time\", format!(\"{validation_time:?}\"));\n    span.record(\"queue_time\", format!(\"{queue_time:?}\"));\n    span.record(\"inference_time\", format!(\"{inference_time:?}\"));\n    span.record(\"time_per_token\", format!(\"{time_per_token:?}\"));\n    span.record(\"seed\", format!(\"{:?}\", response.generated_text.seed));\n\n    // Headers\n    let mut headers = HeaderMap::new();\n    headers.insert(\"x-compute-type\", compute_type.parse().unwrap());\n    headers.insert(\n        \"x-compute-time\",\n        total_time.as_secs_f64().to_string().parse().unwrap(),\n    );\n    headers.insert(\n        \"x-compute-characters\",\n        compute_characters.to_string().parse().unwrap(),\n    );\n    headers.insert(\n        \"x-total-time\",\n        total_time.as_millis().to_string().parse().unwrap(),\n    );\n    headers.insert(\n        \"x-validation-time\",\n        validation_time.as_millis().to_string().parse().unwrap(),\n    );\n    headers.insert(\n        \"x-queue-time\",\n        queue_time.as_millis().to_string().parse().unwrap(),\n    );\n    headers.insert(\n        \"x-inference-time\",\n        inference_time.as_millis().to_string().parse().unwrap(),\n    );\n    headers.insert(\n        \"x-time-per-token\",\n        time_per_token.as_millis().to_string().parse().unwrap(),\n    );\n    headers.insert(\"x-prompt-tokens\", input_length.into());\n    headers.insert(\n        \"x-generated-tokens\",\n        response.generated_text.generated_tokens.into(),\n    );\n\n    // Metrics\n    metrics::counter!(\"tgi_request_success\").increment(1);\n    metrics::histogram!(\"tgi_request_duration\").record(total_time.as_secs_f64());\n    metrics::histogram!(\"tgi_request_validation_duration\").record(validation_time.as_secs_f64());\n    metrics::histogram!(\"tgi_request_queue_duration\").record(queue_time.as_secs_f64());\n    
metrics::histogram!(\"tgi_request_inference_duration\").record(inference_time.as_secs_f64());\n    metrics::histogram!(\"tgi_request_mean_time_per_token_duration\")\n        .record(time_per_token.as_secs_f64());\n    metrics::histogram!(\"tgi_request_generated_tokens\")\n        .record(response.generated_text.generated_tokens as f64);\n\n    // Send response\n    let mut output_text = response.generated_text.text;\n    if let Some(prompt) = add_prompt {\n        output_text = prompt + &output_text;\n    }\n\n    tracing::debug!(\"Output: {}\", output_text);\n    tracing::info!(\"Success\");\n\n    let response = GenerateResponse {\n        generated_text: output_text,\n        details,\n    };\n    Ok((headers, input_length, Json(response)))\n}\n\n/// Generate a stream of token using Server-Sent Events\n#[utoipa::path(\npost,\ntag = \"Text Generation Inference\",\npath = \"/generate_stream\",\nrequest_body = GenerateRequest,\nresponses(\n(status = 200, description = \"Generated Text\", body = StreamResponse,\ncontent_type = \"text/event-stream\"),\n(status = 424, description = \"Generation Error\", body = ErrorResponse,\nexample = json ! ({\"error\": \"Request failed during generation\", \"error_type\": \"generation\"}),\ncontent_type = \"text/event-stream\"),\n(status = 429, description = \"Model is overloaded\", body = ErrorResponse,\nexample = json ! ({\"error\": \"Model is overloaded\", \"error_type\": \"overloaded\"}),\ncontent_type = \"text/event-stream\"),\n(status = 422, description = \"Input validation error\", body = ErrorResponse,\nexample = json ! ({\"error\": \"Input validation error\", \"error_type\": \"validation\"}),\ncontent_type = \"text/event-stream\"),\n(status = 500, description = \"Incomplete generation\", body = ErrorResponse,\nexample = json ! ({\"error\": \"Incomplete generation\", \"error_type\": \"incomplete_generation\"}),\ncontent_type = \"text/event-stream\"),\n)\n)]\n#[instrument(\nskip_all,\nfields(\nparameters = ? 
req.parameters,\ntotal_time,\nvalidation_time,\nqueue_time,\ninference_time,\ntime_per_token,\nseed,\n)\n)]\nasync fn generate_stream(\n    Extension(infer): Extension<Infer>,\n    Extension(compute_type): Extension<ComputeType>,\n    Extension(context): Extension<Option<opentelemetry::Context>>,\n    Json(req): Json<GenerateRequest>,\n) -> (\n    HeaderMap,\n    Sse<impl Stream<Item = Result<Event, Infallible>>>,\n) {\n    let span = tracing::Span::current();\n    if let Some(context) = context {\n        span.set_parent(context);\n    }\n\n    let (headers, response_stream) =\n        generate_stream_internal(infer, compute_type, Json(req), span).await;\n\n    let response_stream = async_stream::stream! {\n        let mut response_stream = Box::pin(response_stream);\n        while let Some(raw_event) = response_stream.next().await {\n            yield Ok(raw_event.map_or_else(Event::from, |token| {\n                Event::default()\n                    .json_data(token)\n                    .unwrap_or_else(|e| InferError::StreamSerializationError(e.to_string()).into())\n            }));\n        }\n    };\n\n    let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());\n    (headers, sse)\n}\n\nasync fn generate_stream_internal(\n    infer: Infer,\n    ComputeType(compute_type): ComputeType,\n    Json(req): Json<GenerateRequest>,\n    span: tracing::Span,\n) -> (\n    HeaderMap,\n    impl Stream<Item = Result<StreamResponse, InferError>>,\n) {\n    let start_time = Instant::now();\n    metrics::counter!(\"tgi_request_count\").increment(1);\n\n    tracing::debug!(\"Input: {}\", req.inputs);\n\n    let compute_characters = req.inputs.chars().count();\n\n    let mut headers = HeaderMap::new();\n    headers.insert(\"x-compute-type\", compute_type.parse().unwrap());\n    headers.insert(\n        \"x-compute-characters\",\n        compute_characters.to_string().parse().unwrap(),\n    );\n    headers.insert(\"X-Accel-Buffering\", 
\"no\".parse().unwrap());\n\n    let stream = async_stream::stream! {\n        // Inference\n        let mut end_reached = false;\n        let mut error = false;\n\n        let mut add_prompt = None;\n        if req.parameters.return_full_text.unwrap_or(false) {\n            add_prompt = Some(req.inputs.clone());\n        }\n        let details = req.parameters.details;\n\n        let best_of = req.parameters.best_of.unwrap_or(1);\n        if best_of != 1 {\n            let err = InferError::from(ValidationError::BestOfStream);\n            metrics::counter!(\"tgi_request_failure\", \"err\" => \"validation\").increment(1);\n            tracing::error!(\"{err}\");\n            yield Err(err);\n        } else if req.parameters.decoder_input_details {\n            let err = InferError::from(ValidationError::PrefillDetailsStream);\n            metrics::counter!(\"tgi_request_failure\", \"err\" => \"validation\").increment(1);\n            tracing::error!(\"{err}\");\n            yield Err(err);\n        } else {\n            match infer.generate_stream(req).instrument(info_span!(parent: &span, \"async_stream\")).await {\n                // Keep permit as long as generate_stream lives\n                Ok((_permit, input_length, response_stream)) => {\n                    let mut index = 0;\n                    let mut response_stream = Box::pin(response_stream);\n                    // Server-Sent Event stream\n                    while let Some(response) = response_stream.next().await {\n                        index += 1;\n                        match response {\n                            Ok(response) => {\n                                match response {\n                                    // Prefill is ignored\n                                    InferStreamResponse::Prefill(_) => {}\n                                    // Yield event for every new token\n                                    InferStreamResponse::Intermediate{\n                                     
   token,\n                                        top_tokens,\n                                    } => {\n                                        tracing::debug!(parent: &span, \"Token: {:?}\", token);\n\n                                        // StreamResponse\n                                        let stream_token = StreamResponse {\n                                            index,\n                                            token,\n                                            top_tokens,\n                                            generated_text: None,\n                                            details: None,\n                                        };\n                                        yield Ok(stream_token);\n                                    }\n                                    // Yield event for last token and compute timings\n                                    InferStreamResponse::End {\n                                        token,\n                                        generated_text,\n                                        start,\n                                        queued,\n                                        top_tokens,\n                                    } => {\n                                        // Token details\n                                        let details = match details {\n                                            true => Some(StreamDetails {\n                                                finish_reason: generated_text.finish_reason,\n                                                generated_tokens: generated_text.generated_tokens,\n                                                seed: generated_text.seed,\n                                                input_length,\n                                            }),\n                                            false => None,\n                                        };\n\n                                        // Timings\n                            
            let total_time = start_time.elapsed();\n                                        let validation_time = queued - start_time;\n                                        let queue_time = start - queued;\n                                        let inference_time = Instant::now() - start;\n                                        let time_per_token = inference_time / generated_text.generated_tokens;\n\n                                        // Tracing metadata\n                                        span.record(\"total_time\", format!(\"{total_time:?}\"));\n                                        span.record(\"validation_time\", format!(\"{validation_time:?}\"));\n                                        span.record(\"queue_time\", format!(\"{queue_time:?}\"));\n                                        span.record(\"inference_time\", format!(\"{inference_time:?}\"));\n                                        span.record(\"time_per_token\", format!(\"{time_per_token:?}\"));\n                                        span.record(\"seed\", format!(\"{:?}\", generated_text.seed));\n\n                                        // Metrics\n                                        metrics::counter!(\"tgi_request_success\").increment(1);\n                                        metrics::histogram!(\"tgi_request_duration\").record(total_time.as_secs_f64());\n                                        metrics::histogram!(\"tgi_request_validation_duration\").record(validation_time.as_secs_f64());\n                                        metrics::histogram!(\"tgi_request_queue_duration\").record(queue_time.as_secs_f64());\n                                        metrics::histogram!(\"tgi_request_inference_duration\").record(inference_time.as_secs_f64());\n                                        metrics::histogram!(\"tgi_request_mean_time_per_token_duration\").record(time_per_token.as_secs_f64());\n                                        
metrics::histogram!(\"tgi_request_generated_tokens\").record(generated_text.generated_tokens as f64);\n\n                                        // StreamResponse\n                                        end_reached = true;\n\n                                        let mut output_text = generated_text.text;\n                                        if let Some(prompt) = add_prompt {\n                                            output_text = prompt + &output_text;\n                                        }\n\n                                        tracing::debug!(parent: &span, \"Output: {}\", output_text);\n                                        tracing::info!(parent: &span, \"Success\");\n\n                                        let stream_token = StreamResponse {\n                                            index,\n                                            token,\n                                            top_tokens,\n                                            generated_text: Some(output_text),\n                                            details\n                                        };\n\n                                        yield Ok(stream_token);\n                                        break;\n                                    }\n                                }\n                            }\n                            // yield error\n                            Err(err) => {\n                                error = true;\n                                yield Err(err);\n                                break;\n                            }\n                        }\n                    }\n                },\n                // yield error\n                Err(err) => {\n                    error = true;\n                    yield Err(err);\n                }\n            }\n            // Check if generation reached the end\n            // Skip if we already sent an error\n            if !end_reached && !error {\n                let err = 
InferError::IncompleteGenerationStream;\n                metrics::counter!(\"tgi_request_failure\", \"err\" => \"incomplete\").increment(1);\n                tracing::error!(\"{err}\");\n                yield Err(err);\n            }\n        }\n    };\n\n    (headers, stream)\n}\n\n/// Generate tokens\n#[utoipa::path(\npost,\ntag = \"Text Generation Inference\",\npath = \"/v1/completions\",\nrequest_body = CompletionRequest,\nresponses(\n(status = 200, description = \"Generated Chat Completion\",\ncontent(\n(\"application/json\" = CompletionFinal),\n(\"text/event-stream\" = Chunk),\n)),\n(status = 424, description = \"Generation Error\", body = ErrorResponse,\nexample = json ! ({\"error\": \"Request failed during generation\", \"error_type\": \"generation\"})),\n(status = 429, description = \"Model is overloaded\", body = ErrorResponse,\nexample = json ! ({\"error\": \"Model is overloaded\", \"error_type\": \"overloaded\"})),\n(status = 422, description = \"Input validation error\", body = ErrorResponse,\nexample = json ! ({\"error\": \"Input validation error\", \"error_type\": \"validation\"})),\n(status = 500, description = \"Incomplete generation\", body = ErrorResponse,\nexample = json ! ({\"error\": \"Incomplete generation\", \"error_type\": \"incomplete_generation\"})),\n)\n)]\n#[instrument(\nskip_all,\nfields(\n// parameters = ? 
req.parameters,\ntotal_time,\nvalidation_time,\nqueue_time,\ninference_time,\ntime_per_token,\nseed,\n)\n)]\npub(crate) async fn completions(\n    Extension(infer): Extension<Infer>,\n    Extension(compute_type): Extension<ComputeType>,\n    Extension(info): Extension<Info>,\n    Extension(context): Extension<Option<opentelemetry::Context>>,\n    Json(req): Json<CompletionRequest>,\n) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {\n    let span = tracing::Span::current();\n    if let Some(context) = context {\n        span.set_parent(context);\n    }\n\n    metrics::counter!(\"tgi_request_count\").increment(1);\n\n    let CompletionRequest {\n        model,\n        max_tokens,\n        seed,\n        stop,\n        stream,\n        temperature,\n        ..\n    } = req;\n\n    let max_new_tokens = max_tokens;\n    let stop = stop.unwrap_or_default();\n    // enable greedy only when temperature is 0\n    let (do_sample, temperature) = match temperature {\n        Some(0.0) => (false, None),\n        other => (true, other),\n    };\n\n    // if suffix is present throw an error\n    if req.suffix.is_some() {\n        metrics::counter!(\"tgi_request_failure\", \"err\" => \"validation\").increment(1);\n        return Err((\n            StatusCode::UNPROCESSABLE_ENTITY,\n            Json(ErrorResponse {\n                error: \"Suffix is not supported and can be achieved by preprocessing the prompt.\"\n                    .to_string(),\n                error_type: \"suffix not supported\".to_string(),\n            }),\n        ));\n    }\n\n    if req.prompt.0.len() > info.max_client_batch_size {\n        metrics::counter!(\"tgi_request_failure\", \"err\" => \"validation\").increment(1);\n        return Err((\n            StatusCode::UNPROCESSABLE_ENTITY,\n            Json(ErrorResponse {\n                error: format!(\n                    \"Number of prompts exceeds the maximum allowed batch size of {}\",\n                    info.max_client_batch_size\n   
             ),\n                error_type: \"batch size exceeded\".to_string(),\n            }),\n        ));\n    }\n\n    let generate_requests: Vec<GenerateRequest> = req\n        .prompt\n        .0\n        .iter()\n        .map(|prompt| GenerateRequest {\n            inputs: prompt.to_string(),\n            add_special_tokens: true,\n            parameters: GenerateParameters {\n                best_of: None,\n                temperature,\n                repetition_penalty: req.repetition_penalty,\n                frequency_penalty: req.frequency_penalty,\n                top_k: None,\n                top_p: req.top_p,\n                typical_p: None,\n                do_sample,\n                max_new_tokens,\n                return_full_text: None,\n                stop: stop.clone(),\n                truncate: None,\n                watermark: false,\n                details: true,\n                decoder_input_details: !stream,\n                seed,\n                top_n_tokens: None,\n                grammar: None,\n                adapter_id: model.as_ref().filter(|m| *m != \"tgi\").map(String::from),\n            },\n        })\n        .collect();\n\n    let mut x_compute_type = None;\n    let mut x_compute_characters = 0u32;\n    let mut x_accel_buffering = None;\n\n    if stream {\n        let mut response_streams = FuturesOrdered::new();\n        for (index, generate_request) in generate_requests.into_iter().enumerate() {\n            let model_id = info.model_id.clone();\n            let system_fingerprint =\n                format!(\"{}-{}\", info.version, info.docker_label.unwrap_or(\"native\"));\n            let infer_clone = infer.clone();\n            let compute_type_clone = compute_type.clone();\n            let span_clone = span.clone();\n\n            // Create a future for each generate_stream_internal call.\n            let generate_future = async move {\n                let (header_tx, header_rx) = oneshot::channel();\n         
       let (sse_tx, sse_rx) = tokio::sync::mpsc::unbounded_channel();\n\n                tokio::spawn(async move {\n                    let (headers, response_stream) = generate_stream_internal(\n                        infer_clone.clone(),\n                        compute_type_clone.clone(),\n                        Json(generate_request),\n                        span_clone.clone(),\n                    )\n                    .await;\n\n                    let response_stream = async_stream::stream! {\n                        let mut response_stream = Box::pin(response_stream);\n\n                        while let Some(stream_token) = response_stream.next().await {\n                            match stream_token {\n                                Ok(stream_token) => {\n                                    let event = Event::default();\n\n                                    let current_time = std::time::SystemTime::now()\n                                        .duration_since(std::time::UNIX_EPOCH)\n                                        .unwrap_or_else(|_| std::time::Duration::from_secs(0))\n                                        .as_secs();\n\n                                    let message = match stream_token.details {\n                                        Some(details) => {\n                                            let completion_tokens = details.generated_tokens;\n                                            let prompt_tokens = details.input_length;\n                                            let total_tokens = prompt_tokens + completion_tokens;\n\n                                            Completion::Final(CompletionFinal {\n                                                id: String::new(),\n                                                created: current_time,\n                                                model: model_id.clone(),\n                                                system_fingerprint: system_fingerprint.clone(),\n                  
                              choices: vec![CompletionComplete {\n                                                    finish_reason: details.finish_reason.to_string(),\n                                                    index: index as u32,\n                                                    logprobs: None,\n                                                    text: stream_token.token.text,\n                                                }],\n                                                usage: Usage {\n                                                    prompt_tokens,\n                                                    completion_tokens,\n                                                    total_tokens,\n                                                },\n                                            })\n                                        }\n                                        None => Completion::Chunk(Chunk {\n                                            id: String::new(),\n                                            created: current_time,\n                                            choices: vec![CompletionComplete {\n                                                finish_reason: String::new(),\n                                                index: index as u32,\n                                                logprobs: None,\n                                                text: stream_token.token.text,\n                                            }],\n                                            model: model_id.clone(),\n                                            system_fingerprint: system_fingerprint.clone(),\n                                        }),\n                                    };\n\n                                    let event = event\n                                        .json_data(message)\n                                        .unwrap_or_else(|_e| Event::default());\n\n                                    yield Ok(event);\n    
                            }\n                                Err(err) => yield Ok(err.into_openai_event()),\n                            }\n                        }\n                    };\n\n                    // send and dont wait for response\n                    let _ = header_tx.send(headers);\n\n                    // pin an emit messages to the sse_tx\n                    let mut sse = Box::pin(response_stream);\n                    while let Some(event) = sse.next().await {\n                        if sse_tx.send(event).is_err() {\n                            tracing::error!(\"Failed to send event. Receiver dropped.\");\n                            break;\n                        }\n                    }\n                });\n\n                (header_rx, sse_rx)\n            };\n            response_streams.push_back(generate_future);\n        }\n\n        let mut all_rxs = vec![];\n\n        while let Some((header_rx, sse_rx)) = response_streams.next().await {\n            all_rxs.push(sse_rx);\n\n            // get the headers from the first response of each stream\n            let headers = header_rx.await.map_err(|e| {\n                tracing::error!(\"Failed to get headers: {:?}\", e);\n                (\n                    StatusCode::INTERNAL_SERVER_ERROR,\n                    Json(ErrorResponse {\n                        error: \"Failed to get headers\".to_string(),\n                        error_type: \"headers\".to_string(),\n                    }),\n                )\n            })?;\n            if x_compute_type.is_none() {\n                x_compute_type = headers\n                    .get(\"x-compute-type\")\n                    .and_then(|v| v.to_str().ok())\n                    .map(|v| v.to_string());\n\n                x_accel_buffering = headers\n                    .get(\"x-accel-buffering\")\n                    .and_then(|v| v.to_str().ok())\n                    .map(|v| v.to_string());\n            }\n            
x_compute_characters += headers\n                .get(\"x-compute-characters\")\n                .and_then(|v| v.to_str().ok())\n                .and_then(|v| v.parse().ok())\n                .unwrap_or(0);\n        }\n\n        let mut headers = HeaderMap::new();\n        if let Some(x_compute_type) = x_compute_type {\n            headers.insert(\"x-compute-type\", x_compute_type.parse().unwrap());\n        }\n        headers.insert(\"x-compute-characters\", x_compute_characters.into());\n        if let Some(x_accel_buffering) = x_accel_buffering {\n            headers.insert(\"x-accel-buffering\", x_accel_buffering.parse().unwrap());\n        }\n\n        // now sink the sse streams into a single stream and remove the ones that are done\n        let stream: AsyncStream<Result<Event, Infallible>, _> = async_stream::stream! {\n            loop {\n                let mut i = 0;\n                while i < all_rxs.len() {\n                    let rx = &mut all_rxs[i];\n                    select! 
{\n                        Some(event) = rx.recv() => {\n                            yield event;\n                        }\n                        else => {\n                            all_rxs.remove(i);\n                            continue; // skip the increment to handle the next element at the same index\n                        }\n                    }\n                    i += 1; // only increment when no element was removed\n                }\n\n                if all_rxs.is_empty() {\n                    break;\n                }\n            }\n        };\n\n        let stream = stream.chain(futures::stream::once(async {\n            Ok(Event::default().data(\"[DONE]\"))\n        }));\n\n        let sse = Sse::new(stream).keep_alive(KeepAlive::default());\n        Ok((headers, sse).into_response())\n    } else {\n        let current_time = std::time::SystemTime::now()\n            .duration_since(std::time::UNIX_EPOCH)\n            .unwrap_or_else(|_| std::time::Duration::from_secs(0))\n            .as_secs();\n\n        let responses = FuturesUnordered::new();\n        for (index, generate_request) in generate_requests.into_iter().enumerate() {\n            let infer_clone = infer.clone();\n            let compute_type_clone = compute_type.clone();\n            let span_clone = span.clone();\n            let response_future = async move {\n                let result = generate_internal(\n                    Extension(infer_clone),\n                    compute_type_clone,\n                    Json(generate_request),\n                    span_clone,\n                )\n                .await;\n                result.map(|(headers, input_length, generation)| {\n                    (index, headers, input_length, generation)\n                })\n            };\n            responses.push(response_future);\n        }\n        let generate_responses = responses.try_collect::<Vec<_>>().await?;\n\n        let mut prompt_tokens = 0u32;\n        let mut 
completion_tokens = 0u32;\n        let mut total_tokens = 0u32;\n\n        let mut x_compute_time = 0u32;\n        let mut x_total_time = 0u32;\n        let mut x_validation_time = 0u32;\n        let mut x_queue_time = 0u32;\n        let mut x_inference_time = 0u32;\n        let mut x_time_per_token = 0u32;\n        let mut x_prompt_tokens = 0u32;\n        let mut x_generated_tokens = 0u32;\n\n        let choices = generate_responses\n            .into_iter()\n            .map(|(index, headers, input_length, Json(generation))| {\n                let details = generation.details.ok_or((\n                    // this should never happen but handle if details are missing unexpectedly\n                    StatusCode::INTERNAL_SERVER_ERROR,\n                    Json(ErrorResponse {\n                        error: \"No details in generation\".to_string(),\n                        error_type: \"no details\".to_string(),\n                    }),\n                ))?;\n\n                if x_compute_type.is_none() {\n                    x_compute_type = headers\n                        .get(\"x-compute-type\")\n                        .and_then(|v| v.to_str().ok())\n                        .map(|v| v.to_string());\n                }\n\n                // accumulate headers and usage from each response\n                x_compute_time += headers\n                    .get(\"x-compute-time\")\n                    .and_then(|v| v.to_str().ok()?.parse().ok())\n                    .unwrap_or(0);\n                x_compute_characters += headers\n                    .get(\"x-compute-characters\")\n                    .and_then(|v| v.to_str().ok()?.parse().ok())\n                    .unwrap_or(0);\n                x_total_time += headers\n                    .get(\"x-total-time\")\n                    .and_then(|v| v.to_str().ok()?.parse().ok())\n                    .unwrap_or(0);\n                x_validation_time += headers\n                    .get(\"x-validation-time\")\n          
          .and_then(|v| v.to_str().ok()?.parse().ok())\n                    .unwrap_or(0);\n                x_queue_time += headers\n                    .get(\"x-queue-time\")\n                    .and_then(|v| v.to_str().ok()?.parse().ok())\n                    .unwrap_or(0);\n                x_inference_time += headers\n                    .get(\"x-inference-time\")\n                    .and_then(|v| v.to_str().ok()?.parse().ok())\n                    .unwrap_or(0);\n                x_time_per_token += headers\n                    .get(\"x-time-per-token\")\n                    .and_then(|v| v.to_str().ok()?.parse().ok())\n                    .unwrap_or(0);\n                x_prompt_tokens += headers\n                    .get(\"x-prompt-tokens\")\n                    .and_then(|v| v.to_str().ok()?.parse().ok())\n                    .unwrap_or(0);\n                x_generated_tokens += headers\n                    .get(\"x-generated-tokens\")\n                    .and_then(|v| v.to_str().ok()?.parse().ok())\n                    .unwrap_or(0);\n\n                prompt_tokens += input_length;\n                completion_tokens += details.generated_tokens;\n                total_tokens += input_length + details.generated_tokens;\n\n                Ok(CompletionComplete {\n                    finish_reason: details.finish_reason.format(true),\n                    index: index as u32,\n                    logprobs: None,\n                    text: generation.generated_text,\n                })\n            })\n            .collect::<Result<Vec<_>, _>>()\n            .map_err(|(status, Json(err))| (status, Json(err)))?;\n\n        let response = Completion::Final(CompletionFinal {\n            id: \"\".to_string(),\n            created: current_time,\n            model: info.model_id.clone(),\n            system_fingerprint: format!(\n                \"{}-{}\",\n                info.version,\n                info.docker_label.unwrap_or(\"native\")\n            ),\n     
       choices,\n            usage: Usage {\n                prompt_tokens,\n                completion_tokens,\n                total_tokens,\n            },\n        });\n\n        // headers similar to `generate` but aggregated\n        let mut headers = HeaderMap::new();\n        if let Some(x_compute_type) = x_compute_type {\n            headers.insert(\"x-compute-type\", x_compute_type.parse().unwrap());\n        }\n        headers.insert(\"x-compute-characters\", x_compute_characters.into());\n        headers.insert(\"x-total-time\", x_total_time.into());\n        headers.insert(\"x-validation-time\", x_validation_time.into());\n        headers.insert(\"x-queue-time\", x_queue_time.into());\n        headers.insert(\"x-inference-time\", x_inference_time.into());\n        headers.insert(\"x-time-per-token\", x_time_per_token.into());\n        headers.insert(\"x-prompt-tokens\", x_prompt_tokens.into());\n        headers.insert(\"x-generated-tokens\", x_generated_tokens.into());\n        if let Some(x_accel_buffering) = x_accel_buffering {\n            headers.insert(\"x-accel-buffering\", x_accel_buffering.parse().unwrap());\n        }\n        Ok((headers, Json(response)).into_response())\n    }\n}\n\n/// Generate tokens\n#[utoipa::path(\npost,\ntag = \"Text Generation Inference\",\npath = \"/v1/chat/completions\",\nrequest_body = ChatRequest,\nresponses(\n(status = 200, description = \"Generated Chat Completion\",\ncontent(\n(\"application/json\" = ChatCompletion),\n(\"text/event-stream\" = ChatCompletionChunk),\n)),\n(status = 424, description = \"Generation Error\", body = ErrorResponse,\nexample = json ! ({\"error\": \"Request failed during generation\", \"error_type\": \"generation\"})),\n(status = 429, description = \"Model is overloaded\", body = ErrorResponse,\nexample = json ! ({\"error\": \"Model is overloaded\", \"error_type\": \"overloaded\"})),\n(status = 422, description = \"Input validation error\", body = ErrorResponse,\nexample = json ! 
({\"error\": \"Input validation error\", \"error_type\": \"validation\"})),\n(status = 500, description = \"Incomplete generation\", body = ErrorResponse,\nexample = json ! ({\"error\": \"Incomplete generation\", \"error_type\": \"incomplete_generation\"})),\n)\n)]\n#[instrument(\n    skip_all,\n    fields(\n        parameters,\n        total_time,\n        validation_time,\n        queue_time,\n        inference_time,\n        time_per_token,\n        seed,\n    )\n)]\npub(crate) async fn chat_completions(\n    Extension(infer): Extension<Infer>,\n    Extension(compute_type): Extension<ComputeType>,\n    Extension(info): Extension<Info>,\n    Extension(context): Extension<Option<opentelemetry::Context>>,\n    Json(mut chat): Json<ChatRequest>,\n) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {\n    let span = tracing::Span::current();\n    if let Some(context) = context {\n        span.set_parent(context);\n    }\n\n    metrics::counter!(\"tgi_request_count\").increment(1);\n    let ChatRequest {\n        model,\n        stream,\n        stream_options,\n        logprobs,\n        ..\n    } = chat.clone();\n\n    tracing::debug!(\"Got chat_template {:?}\", infer.chat_template);\n    let id = chat.next_tool_call_id();\n    let (generate_request, using_tools): (GenerateRequest, bool) =\n        chat.clone().try_into_generate(&infer)?;\n    span.record(\"parameters\", format!(\"{:?}\", generate_request.parameters));\n    let logprobs = logprobs.unwrap_or_default();\n\n    // extract model id from request if specified\n    let model_id = match model.as_deref() {\n        Some(\"tgi\") | None => info.model_id.clone(),\n        Some(m_id) => m_id.to_string(),\n    };\n    let system_fingerprint = format!(\"{}-{}\", info.version, info.docker_label.unwrap_or(\"native\"));\n    // switch on stream\n    if stream {\n        let (headers, response_stream) = generate_stream_internal(\n            infer.clone(),\n            compute_type.clone(),\n            
Json(generate_request),\n            span.clone(),\n        )\n        .await;\n\n        let response_stream = async_stream::stream! {\n            let mut response_stream = Box::pin(response_stream);\n            let mut state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());\n            while let Some(result) = response_stream.next().await {\n                match result{\n                Ok(stream_token) => {\n                    let events = state.push(stream_token);\n                    match events{\n                        ChatEvent::NoTool => {\n                            chat.tools = None;\n                            chat.response_format = None;\n                            let (generate_request, using_tools): (GenerateRequest, bool) =\n                                chat.clone().try_into_generate(&infer).unwrap();\n                            assert!(!using_tools);\n                            let (_headers, response_stream2) =\n                                generate_stream_internal(infer.clone(), compute_type.clone(), Json(generate_request), span.clone()).await;\n                            state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());\n                            response_stream = Box::pin(response_stream2);\n                        }\n                        ChatEvent::Events(events) => {\n                            for chat_complete in events{\n                                yield Ok(Event::default().json_data(chat_complete).unwrap_or_else(|e| {\n                                    tracing::error!(\"Failed to serialize ChatCompletionChunk: {:?}\", e);\n                                    Event::default()\n                                }));\n                            }\n                        }\n                    }\n                }\n                Err(err) => yield 
Ok(err.into_openai_event())\n                }\n            }\n            yield Ok::<Event, Infallible>(Event::default().data(\"[DONE]\"));\n        };\n\n        let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());\n        Ok((headers, sse).into_response())\n    } else {\n        let (mut headers, mut input_length, Json(generation)) = generate_internal(\n            Extension(infer.clone()),\n            compute_type.clone(),\n            Json(generate_request),\n            span.clone(),\n        )\n        .await?;\n\n        let current_time = std::time::SystemTime::now()\n            .duration_since(std::time::UNIX_EPOCH)\n            .unwrap_or_else(|_| std::time::Duration::from_secs(0))\n            .as_secs();\n\n        let (tool_calls, output) = if using_tools {\n            match crate::chat::parse_output(&generation.generated_text)? {\n                ChatChoice::NoTool => {\n                    chat.tools = None;\n                    chat.response_format = None;\n                    let (generate_request, using_tools): (GenerateRequest, bool) =\n                        chat.clone().try_into_generate(&infer)?;\n                    assert!(!using_tools);\n                    let (headers_final, input_length_final, Json(generation)) = generate_internal(\n                        Extension(infer),\n                        compute_type,\n                        Json(generate_request),\n                        span,\n                    )\n                    .await?;\n                    headers = headers_final;\n                    input_length = input_length_final;\n                    (None, Some(generation.generated_text))\n                }\n                ChatChoice::ToolCalls(tool_calls) => (Some(tool_calls), None),\n            }\n        } else {\n            (None, Some(generation.generated_text))\n        };\n        // build the complete response object with the full text\n        let response = 
CompletionType::ChatCompletion(ChatCompletion::new(\n            model_id,\n            system_fingerprint,\n            output,\n            current_time,\n            generation.details.unwrap(),\n            logprobs,\n            tool_calls,\n            input_length,\n        ));\n\n        // wrap generation inside a Vec to match api-inference\n        Ok((headers, Json(response)).into_response())\n    }\n}\n\n/// Tokenize inputs\n#[utoipa::path(\npost,\ntag = \"Text Generation Inference\",\npath = \"/tokenize\",\nrequest_body = GenerateRequest,\nresponses(\n(status = 200, description = \"Tokenized ids\", body = TokenizeResponse),\n(status = 404, description = \"No tokenizer found\", body = ErrorResponse,\nexample = json ! ({\"error\": \"No fast tokenizer available\"})),\n)\n)]\n#[instrument(skip_all)]\nasync fn tokenize(\n    Extension(infer): Extension<Infer>,\n    Json(req): Json<GenerateRequest>,\n) -> Result<Json<TokenizeResponse>, (StatusCode, Json<ErrorResponse>)> {\n    let input = req.inputs.clone();\n    let encoding = infer.tokenize(req).await?;\n    let tokens = encoding_to_tokens(&encoding, &input);\n    Ok(Json(TokenizeResponse(tokens)))\n}\n\n/// Prometheus metrics scrape endpoint\n#[utoipa::path(\n    get,\n    tag = \"Text Generation Inference\",\n    path = \"/metrics\",\n    responses((status = 200, description = \"Prometheus Metrics\", body = String))\n)]\nasync fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {\n    prom_handle.render()\n}\n\n#[derive(Clone, Debug)]\npub(crate) struct ComputeType(String);\n\n// OpenAPI 
documentation\n#[derive(OpenApi)]\n#[openapi(\npaths(\nhealth,\nget_model_info,\ncompat_generate,\ngenerate,\ngenerate_stream,\nchat_completions,\ncompletions,\ntokenize,\nmetrics,\nopenai_get_model_info,\nsagemaker_compatibility,\nget_chat_tokenize,\n),\ncomponents(\nschemas(\nInfo,\nCompatGenerateRequest,\nSagemakerRequest,\nGenerateRequest,\nGrammarType,\nJsonSchemaConfig,\nChatRequest,\nMessage,\nMessageContent,\nMessageChunk,\nUrl,\nFunctionName,\nOutputMessage,\nTextMessage,\nToolCallMessage,\nToolCallDelta,\nChatCompletionComplete,\nChatCompletionChoice,\nChatCompletionDelta,\nChatCompletionChunk,\nChatCompletionLogprob,\nChatCompletionLogprobs,\nChatCompletionTopLogprob,\nChatCompletion,\nCompletionRequest,\nCompletionComplete,\nSagemakerResponse,\nSagemakerStreamResponse,\nChunk,\nCompletion,\nCompletionFinal,\nPrompt,\nGenerateParameters,\nPrefillToken,\nToken,\nGenerateResponse,\nTokenizeResponse,\nSimpleToken,\nBestOfSequence,\nDetails,\nFinishReason,\nStreamResponse,\nStreamDetails,\nErrorResponse,\nGrammarType,\nUsage,\nStreamOptions,\nDeltaToolCall,\nTool,\nToolCall,\nFunction,\nFunctionDefinition,\nToolChoice,\nModelInfo,\nChatTokenizeResponse,\nMessageBody,\n)\n),\ntags(\n(name = \"Text Generation Inference\", description = \"Hugging Face Text Generation Inference API\")\n),\ninfo(\ntitle = \"Text Generation Inference\",\nlicense(\nname = \"Apache 2.0\",\nurl = \"https://www.apache.org/licenses/LICENSE-2.0\"\n)\n)\n)]\npub struct ApiDoc;\n\npub fn schema() -> ApiDoc {\n    ApiDoc\n}\n\npub fn py_resolve_tokenizer(\n    py: pyo3::Python,\n    tokenizer_name: &str,\n    revision: Option<&str>,\n    trust_remote_code: bool,\n) -> pyo3::PyResult<()> {\n    let transformers = py.import_bound(\"transformers\")?;\n    let auto = transformers.getattr(\"AutoTokenizer\")?;\n    let from_pretrained = auto.getattr(\"from_pretrained\")?;\n    let args = (tokenizer_name,);\n    let kwargs = if let Some(rev) = &revision {\n        [\n            (\"revision\", 
rev.to_string().into_py(py)),\n            (\"trust_remote_code\", trust_remote_code.into_py(py)),\n        ]\n        .into_py_dict_bound(py)\n    } else {\n        [(\"trust_remote_code\", trust_remote_code.into_py(py))].into_py_dict_bound(py)\n    };\n    let tokenizer = from_pretrained.call(args, Some(&kwargs))?;\n    let save = tokenizer.getattr(\"save_pretrained\")?;\n    let args = (\"out\".to_string(),);\n    save.call1(args)?;\n    Ok(())\n}\n\npub fn legacy_tokenizer_handle(config_filename: Option<&PathBuf>) -> Option<()> {\n    // XXX Legacy case for FasterDecoding/medusa-vicuna-7b-v1.3\n    // and state-spaces/mamba-130m\n    tracing::warn!(\"Odd tokenizer detected, falling back on legacy tokenization\");\n\n    #[derive(serde::Deserialize)]\n    struct FallbackConfig {\n        base_model_name_or_path: Option<String>,\n        model_type: Option<String>,\n        ssm_config: Option<serde_json::Value>,\n    }\n    config_filename.and_then(|filename| {\n        std::fs::read_to_string(filename)\n            .ok()\n            .as_ref()\n            .and_then(|c| {\n                let config: Result<FallbackConfig, _> = serde_json::from_str(c);\n                if let Ok(config) = config {\n                    if config.model_type.is_none() {\n                        if let Some(base) = config.base_model_name_or_path {\n                            pyo3::Python::with_gil(|py| -> PyResult<()> {\n                                py_resolve_tokenizer(py, &base, Some(\"main\"), false)\n                            })\n                            .ok()?;\n                        }\n                    }\n                    if config.ssm_config.is_some() {\n                        // XXX Legacy mamba\n                        pyo3::Python::with_gil(|py| -> PyResult<()> {\n                            py_resolve_tokenizer(py, \"EleutherAI/gpt-neox-20b\", Some(\"main\"), false)\n                        })\n                        .ok()?;\n                    }\n     
           }\n                Some(())\n            })\n    })\n}\n\n/// Serving method\n#[allow(clippy::too_many_arguments)]\npub async fn run(\n    backend: impl Backend + Send + Sync + 'static,\n    max_concurrent_requests: usize,\n    max_best_of: usize,\n    max_stop_sequences: usize,\n    max_top_n_tokens: u32,\n    max_input_tokens: usize,\n    max_total_tokens: usize,\n    validation_workers: usize,\n    api_key: Option<String>,\n    tokenizer_name: String,\n    tokenizer_config_path: Option<String>,\n    revision: Option<String>,\n    trust_remote_code: bool,\n    hostname: String,\n    port: u16,\n    cors_allow_origin: Option<Vec<String>>,\n    ngrok: bool,\n    _ngrok_authtoken: Option<String>,\n    _ngrok_edge: Option<String>,\n    disable_grammar_support: bool,\n    max_client_batch_size: usize,\n    usage_stats_level: usage_stats::UsageStatsLevel,\n    payload_limit: usize,\n    max_image_fetch_size: usize,\n    prometheus_port: u16,\n) -> Result<(), WebServerError> {\n    // CORS allowed origins\n    // map to go inside the option and then map to parse from String to HeaderValue\n    // Finally, convert to AllowOrigin\n    let allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {\n        AllowOrigin::list(\n            cors_allow_origin\n                .iter()\n                .map(|origin| origin.parse::<HeaderValue>().unwrap()),\n        )\n    });\n\n    // Parse Huggingface hub token\n    let authorization_token = std::env::var(\"HF_TOKEN\")\n        .or_else(|_| std::env::var(\"HUGGING_FACE_HUB_TOKEN\"))\n        .ok();\n\n    // Tokenizer instance\n    // This will only be used to validate payloads\n    let local_path = Path::new(&tokenizer_name);\n\n    // Shared API builder initialization\n    let api_builder = || {\n        let mut builder = ApiBuilder::from_env().with_progress(false);\n        if let Some(token) = authorization_token {\n            builder = builder.with_token(Some(token));\n        }\n\n        
if let Ok(cache_dir) = std::env::var(\"HUGGINGFACE_HUB_CACHE\") {\n            builder = builder.with_cache_dir(cache_dir.into());\n        }\n\n        if let Ok(origin) = std::env::var(\"HF_HUB_USER_AGENT_ORIGIN\") {\n            builder = builder.with_user_agent(\"origin\", origin.as_str());\n        }\n\n        builder\n    };\n\n    // Decide if we need to use the API based on the revision and local path\n    let use_api = revision.is_some() || !local_path.exists() || !local_path.is_dir();\n\n    // Initialize API if needed\n    #[derive(Clone)]\n    enum Type {\n        Api(Api),\n        Cache(Cache),\n        None,\n    }\n    let api = if use_api {\n        if std::env::var(\"HF_HUB_OFFLINE\") == Ok(\"1\".to_string()) {\n            let cache = std::env::var(\"HUGGINGFACE_HUB_CACHE\")\n                .map_err(|_| ())\n                .map(|cache_dir| Cache::new(cache_dir.into()))\n                .unwrap_or_else(|_| Cache::from_env());\n            tracing::warn!(\"Offline mode active using cache defaults\");\n            Type::Cache(cache)\n        } else {\n            tracing::info!(\"Using the Hugging Face API\");\n            match api_builder().build() {\n                Ok(api) => Type::Api(api),\n                Err(_) => {\n                    tracing::warn!(\"Unable to build the Hugging Face API\");\n                    Type::None\n                }\n            }\n        }\n    } else {\n        Type::None\n    };\n\n    // Load tokenizer and model info\n    let (\n        config_filename,\n        tokenizer_config_filename,\n        preprocessor_config_filename,\n        processor_config_filename,\n        chat_template_filename,\n        model_info,\n    ) = match api {\n        Type::None => (\n            Some(local_path.join(\"config.json\")),\n            Some(local_path.join(\"tokenizer_config.json\")),\n            Some(local_path.join(\"preprocessor_config.json\")),\n            Some(local_path.join(\"processor_config.json\")),\n     
       Some(local_path.join(\"chat_template.json\")),\n            None,\n        ),\n        Type::Api(api) => {\n            let api_repo = api.repo(Repo::with_revision(\n                tokenizer_name.to_string(),\n                RepoType::Model,\n                revision.clone().unwrap_or_else(|| \"main\".to_string()),\n            ));\n\n            let config_filename = api_repo.get(\"config.json\").await.ok();\n            let tokenizer_config_filename = api_repo.get(\"tokenizer_config.json\").await.ok();\n            let preprocessor_config_filename = api_repo.get(\"preprocessor_config.json\").await.ok();\n            let processor_config_filename = api_repo.get(\"processor_config.json\").await.ok();\n            let chat_template_filename = api_repo.get(\"chat_template.json\").await.ok();\n\n            let model_info = if let Some(model_info) = get_hub_model_info(&api_repo).await {\n                Some(model_info)\n            } else {\n                tracing::warn!(\"Could not retrieve model info from the Hugging Face hub.\");\n                None\n            };\n            (\n                config_filename,\n                tokenizer_config_filename,\n                preprocessor_config_filename,\n                processor_config_filename,\n                chat_template_filename,\n                model_info,\n            )\n        }\n        Type::Cache(cache) => {\n            tracing::info!(\"Cache {cache:?}\");\n            let repo = cache.repo(Repo::with_revision(\n                tokenizer_name.to_string(),\n                RepoType::Model,\n                revision.clone().unwrap_or_else(|| \"main\".to_string()),\n            ));\n            (\n                repo.get(\"config.json\"),\n                repo.get(\"tokenizer_config.json\"),\n                repo.get(\"preprocessor_config.json\"),\n                repo.get(\"processor_config.json\"),\n                repo.get(\"chat_template.json\"),\n                None,\n            )\n 
       }\n    };\n\n    // if chat_template_filename is present, load the chat template\n    let chat_template: Option<crate::ChatTemplateVersions> = chat_template_filename\n        .and_then(|f| std::fs::read_to_string(f).ok())\n        .and_then(|c| {\n            let res = serde_json::from_str::<crate::ChatTemplateStandalone>(&c);\n            if let Err(e) = &res {\n                tracing::warn!(\"Could not parse chat template {e:?}\");\n            }\n            res.ok().map(|t| t.chat_template)\n        });\n\n    // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.\n    tracing::warn!(\"Tokenizer_config {tokenizer_config_path:?} - {tokenizer_config_filename:?}\");\n    let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path\n    {\n        HubTokenizerConfig::from_file(filename)\n    } else {\n        tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)\n    };\n    let mut tokenizer_config = tokenizer_config.unwrap_or_else(|| {\n        tracing::warn!(\"Could not find tokenizer config locally and no API specified\");\n        HubTokenizerConfig::default()\n    });\n\n    if chat_template.is_some() {\n        tracing::info!(\"Using chat template from chat_template.json\");\n        tokenizer_config.chat_template = chat_template;\n    }\n\n    let tokenizer: Result<Tokenizer, WebServerError> = {\n        use pyo3::prelude::*;\n        Python::with_gil(|py| -> PyResult<()> {\n            py_resolve_tokenizer(py, &tokenizer_name, revision.as_deref(), trust_remote_code)?;\n            Ok(())\n        })\n        .inspect_err(|err| {\n            tracing::error!(\"Failed to import python tokenizer {err}\");\n        })\n        .or_else(|err| {\n            let out = legacy_tokenizer_handle(config_filename.as_ref());\n            out.ok_or(err)\n        })\n        .map_err(|_| WebServerError::Tokenizer(\"Unable to load tokenizer.\".to_string()))?;\n        let filename = 
\"out/tokenizer.json\";\n        if let Ok(tok) = tokenizers::Tokenizer::from_file(filename) {\n            Ok(Tokenizer::Rust(tok))\n        } else {\n            Ok(Tokenizer::Python {\n                tokenizer_name: tokenizer_name.clone(),\n                revision: revision.clone(),\n                trust_remote_code,\n            })\n        }\n    };\n\n    let config: Option<Config> = config_filename.and_then(|filename| {\n        std::fs::read_to_string(filename)\n            .ok()\n            .as_ref()\n            .and_then(|c| {\n                let config: Result<Config, _> = serde_json::from_str(c);\n                if let Err(err) = &config {\n                    tracing::warn!(\"Could not parse config {err:?}\");\n                }\n                config.ok()\n            })\n    });\n    let model_info = model_info.unwrap_or_else(|| HubModelInfo {\n        model_id: tokenizer_name.to_string(),\n        sha: None,\n        pipeline_tag: None,\n    });\n\n    let processor_config = processor_config_filename\n        .and_then(HubProcessorConfig::from_file)\n        .unwrap_or_default();\n\n    let preprocessor_config: Option<HubPreprocessorConfig> =\n        preprocessor_config_filename.and_then(HubPreprocessorConfig::from_file);\n\n    tracing::info!(\"Using config {config:?}\");\n\n    // Only send usage stats when TGI is run in container and the function returns Some\n    let is_container = matches!(usage_stats::is_container(), Ok(true));\n    // retrieve the huggingface_hub user agent origin if set, and add the origin to telemetry\n    let origin = std::env::var(\"HF_HUB_USER_AGENT_ORIGIN\").ok();\n    let user_agent = match (usage_stats_level, is_container) {\n        (usage_stats::UsageStatsLevel::On | usage_stats::UsageStatsLevel::NoStack, true) => {\n            let reduced_args = usage_stats::Args::new(\n                config.clone(),\n                tokenizer_config.tokenizer_class.clone(),\n                max_concurrent_requests,\n    
            max_best_of,\n                max_stop_sequences,\n                max_top_n_tokens,\n                max_input_tokens,\n                max_total_tokens,\n                // waiting_served_ratio,\n                // max_batch_prefill_tokens,\n                // max_batch_total_tokens,\n                // max_waiting_tokens,\n                // max_batch_size,\n                revision.clone(),\n                validation_workers,\n                disable_grammar_support,\n                max_client_batch_size,\n                usage_stats_level,\n                backend.name(),\n                origin,\n            );\n            Some(usage_stats::UserAgent::new(reduced_args))\n        }\n        _ => None,\n    };\n\n    let stop_usage_thread = Arc::new(AtomicBool::new(false));\n    let stop_usage_thread_clone = stop_usage_thread.clone();\n    if let Some(ua) = user_agent.clone() {\n        let start_event =\n            usage_stats::UsageStatsEvent::new(ua.clone(), usage_stats::EventType::Start, None);\n        tokio::spawn(async move {\n            // send start event\n            start_event.send().await;\n            let mut last_report = Instant::now();\n            while !stop_usage_thread_clone.load(Ordering::Relaxed) {\n                if last_report.elapsed() > Duration::from_secs(900) {\n                    let report_event = usage_stats::UsageStatsEvent::new(\n                        ua.clone(),\n                        usage_stats::EventType::Ping,\n                        None,\n                    );\n                    report_event.send().await;\n                    last_report = Instant::now();\n                }\n                tokio::time::sleep(Duration::from_secs(1)).await;\n            }\n        });\n    };\n    let compat_return_full_text = match &model_info.pipeline_tag {\n        None => {\n            tracing::warn!(\"no pipeline tag found for model {tokenizer_name}\");\n            true\n        }\n        
Some(pipeline_tag) => pipeline_tag.as_str() == \"text-generation\",\n    };\n    let result = start(\n        backend,\n        max_concurrent_requests,\n        max_best_of,\n        max_stop_sequences,\n        max_top_n_tokens,\n        max_input_tokens,\n        max_total_tokens,\n        validation_workers,\n        api_key,\n        config,\n        (tokenizer?, tokenizer_config),\n        (preprocessor_config, processor_config),\n        hostname,\n        port,\n        ngrok,\n        _ngrok_authtoken,\n        _ngrok_edge,\n        disable_grammar_support,\n        max_client_batch_size,\n        model_info,\n        compat_return_full_text,\n        allow_origin,\n        payload_limit,\n        max_image_fetch_size,\n        prometheus_port,\n    )\n    .await;\n\n    if let Some(ua) = user_agent {\n        stop_usage_thread.store(true, Ordering::Relaxed);\n        match result {\n            Ok(_) => {\n                let stop_event = usage_stats::UsageStatsEvent::new(\n                    ua.clone(),\n                    usage_stats::EventType::Stop,\n                    None,\n                );\n                stop_event.send().await;\n                Ok(())\n            }\n            Err(e) => {\n                let description = match usage_stats_level {\n                    usage_stats::UsageStatsLevel::On => Some(e.to_string()),\n                    usage_stats::UsageStatsLevel::NoStack => Some(\"unknow_error\".to_string()),\n                    _ => None,\n                };\n                let event = usage_stats::UsageStatsEvent::new(\n                    ua.clone(),\n                    usage_stats::EventType::Error,\n                    description,\n                );\n                event.send().await;\n\n                Err(e)\n            }\n        }\n    } else {\n        result\n    }\n}\n\n#[allow(clippy::too_many_arguments)]\nasync fn start(\n    backend: impl Backend + Send + Sync + 'static,\n    max_concurrent_requests: 
usize,\n    max_best_of: usize,\n    max_stop_sequences: usize,\n    max_top_n_tokens: u32,\n    max_input_tokens: usize,\n    max_total_tokens: usize,\n    validation_workers: usize,\n    api_key: Option<String>,\n    config: Option<Config>,\n    (tokenizer, tokenizer_config): (Tokenizer, HubTokenizerConfig),\n    (preprocessor_config, processor_config): (Option<HubPreprocessorConfig>, HubProcessorConfig),\n    hostname: String,\n    port: u16,\n    ngrok: bool,\n    _ngrok_authtoken: Option<String>,\n    _ngrok_edge: Option<String>,\n    disable_grammar_support: bool,\n    max_client_batch_size: usize,\n    model_info: HubModelInfo,\n    compat_return_full_text: bool,\n    allow_origin: Option<AllowOrigin>,\n    payload_limit: usize,\n    max_image_fetch_size: usize,\n    prometheus_port: u16,\n) -> Result<(), WebServerError> {\n    // Determine the server port based on the feature and environment variable.\n    let port = if cfg!(feature = \"google\") {\n        std::env::var(\"AIP_HTTP_PORT\")\n            .map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))\n            .unwrap_or(port)\n    } else {\n        port\n    };\n\n    let addr = match hostname.parse() {\n        Ok(ip) => SocketAddr::new(ip, port),\n        Err(_) => {\n            tracing::warn!(\"Invalid hostname, defaulting to 0.0.0.0\");\n            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)\n        }\n    };\n\n    // Create state\n    let validation = Validation::new(\n        validation_workers,\n        tokenizer,\n        config,\n        preprocessor_config,\n        max_best_of,\n        max_stop_sequences,\n        max_top_n_tokens,\n        max_input_tokens,\n        max_total_tokens,\n        disable_grammar_support,\n        max_image_fetch_size,\n    );\n\n    let infer = Infer::new(\n        backend,\n        validation,\n        max_concurrent_requests,\n        tokenizer_config,\n        processor_config,\n    );\n\n    // Duration buckets\n    
let duration_matcher = Matcher::Suffix(String::from(\"duration\"));\n    let n_duration_buckets = 35;\n    let mut duration_buckets = Vec::with_capacity(n_duration_buckets);\n    // Minimum duration in seconds\n    let mut value = 0.0001;\n    for _ in 0..n_duration_buckets {\n        // geometric sequence\n        value *= 1.5;\n        duration_buckets.push(value);\n    }\n    // Input Length buckets\n    let input_length_matcher = Matcher::Full(String::from(\"tgi_request_input_length\"));\n    let input_length_buckets: Vec<f64> = (0..100)\n        .map(|x| (max_input_tokens as f64 / 100.0) * (x + 1) as f64)\n        .collect();\n    // Generated tokens buckets\n    let generated_tokens_matcher = Matcher::Full(String::from(\"tgi_request_generated_tokens\"));\n    let generated_tokens_buckets: Vec<f64> = (0..100)\n        .map(|x| (max_total_tokens as f64 / 100.0) * (x + 1) as f64)\n        .collect();\n    // Input Length buckets\n    let max_new_tokens_matcher = Matcher::Full(String::from(\"tgi_request_max_new_tokens\"));\n    let max_new_tokens_buckets: Vec<f64> = (0..100)\n        .map(|x| (max_total_tokens as f64 / 100.0) * (x + 1) as f64)\n        .collect();\n    // Batch size buckets\n    let batch_size_matcher = Matcher::Full(String::from(\"tgi_batch_next_size\"));\n    let batch_size_buckets: Vec<f64> = (0..1024).map(|x| (x + 1) as f64).collect();\n    // Speculated tokens buckets\n    // let skipped_matcher = Matcher::Full(String::from(\"tgi_request_skipped_tokens\"));\n    // let skipped_buckets: Vec<f64> = (0..shard_info.speculate + 1).map(|x| x as f64).collect();\n\n    let mut p_addr = addr;\n    p_addr.set_port(prometheus_port);\n\n    // Prometheus handler\n    let builder = PrometheusBuilder::new()\n        .with_http_listener(p_addr)\n        .set_buckets_for_metric(duration_matcher, &duration_buckets)\n        .unwrap()\n        .set_buckets_for_metric(input_length_matcher, &input_length_buckets)\n        .unwrap()\n        
.set_buckets_for_metric(generated_tokens_matcher, &generated_tokens_buckets)\n        .unwrap()\n        .set_buckets_for_metric(max_new_tokens_matcher, &max_new_tokens_buckets)\n        .unwrap()\n        .set_buckets_for_metric(batch_size_matcher, &batch_size_buckets)\n        .unwrap();\n    // .set_buckets_for_metric(skipped_matcher, &skipped_buckets)\n    // .unwrap();\n    // See: https://github.com/metrics-rs/metrics/issues/467#issuecomment-2022755151\n    let (recorder, _) = builder\n        .build()\n        .expect(\"failed to build prometheus recorder\");\n    let prom_handle = recorder.handle();\n    metrics::set_global_recorder(recorder).expect(\"Failed to set global recorder\");\n\n    // Metrics descriptions\n    metrics::describe_counter!(\"tgi_request_success\", \"Number of successful requests\");\n    metrics::describe_histogram!(\n        \"tgi_request_duration\",\n        metrics::Unit::Seconds,\n        \"Request duration\"\n    );\n    metrics::describe_histogram!(\n        \"tgi_request_validation_duration\",\n        metrics::Unit::Seconds,\n        \"Request validation duration\"\n    );\n    metrics::describe_histogram!(\n        \"tgi_request_queue_duration\",\n        metrics::Unit::Seconds,\n        \"Request queue duration\"\n    );\n    metrics::describe_histogram!(\n        \"tgi_request_inference_duration\",\n        metrics::Unit::Seconds,\n        \"Request inference duration\"\n    );\n    metrics::describe_histogram!(\n        \"tgi_request_mean_time_per_token_duration\",\n        metrics::Unit::Seconds,\n        \"Mean time per token per request\"\n    );\n    metrics::describe_histogram!(\n        \"tgi_request_generated_tokens\",\n        metrics::Unit::Count,\n        \"Generated tokens per request\"\n    );\n    metrics::describe_counter!(\n        \"tgi_batch_inference_count\",\n        metrics::Unit::Count,\n        \"Inference calls per method (prefill or decode)\"\n    );\n    metrics::describe_counter!(\n        
\"tgi_request_count\",\n        metrics::Unit::Count,\n        \"Total number of requests\"\n    );\n    metrics::describe_counter!(\n        \"tgi_batch_inference_success\",\n        metrics::Unit::Count,\n        \"Number of successful inference calls per method (prefill or decode)\"\n    );\n    metrics::describe_gauge!(\n        \"tgi_batch_current_size\",\n        metrics::Unit::Count,\n        \"Current batch size\"\n    );\n    metrics::describe_gauge!(\"tgi_queue_size\", metrics::Unit::Count, \"Current queue size\");\n    metrics::describe_gauge!(\n        \"tgi_batch_current_max_tokens\",\n        metrics::Unit::Count,\n        \"Maximum tokens for the current batch\"\n    );\n    metrics::describe_gauge!(\n        \"tgi_batch_total_tokens\",\n        metrics::Unit::Count,\n        \"Maximum amount of tokens in total.\"\n    );\n    metrics::describe_histogram!(\n        \"tgi_request_max_new_tokens\",\n        metrics::Unit::Count,\n        \"Maximum new tokens per request\"\n    );\n    metrics::describe_histogram!(\n        \"tgi_batch_inference_duration\",\n        metrics::Unit::Seconds,\n        \"Batch inference duration\"\n    );\n    metrics::describe_histogram!(\n        \"tgi_batch_forward_duration\",\n        metrics::Unit::Seconds,\n        \"Batch forward duration per method (prefill or decode)\"\n    );\n    metrics::describe_histogram!(\n        \"tgi_request_skipped_tokens\",\n        metrics::Unit::Count,\n        \"Speculated tokens per request\"\n    );\n    metrics::describe_histogram!(\n        \"tgi_batch_filter_duration\",\n        metrics::Unit::Seconds,\n        \"Time spent filtering batches and sending generated tokens per method (prefill or decode)\"\n    );\n    metrics::describe_histogram!(\n        \"tgi_request_queue_duration\",\n        metrics::Unit::Seconds,\n        \"Time spent in the queue per request\"\n    );\n    metrics::describe_histogram!(\n        \"tgi_request_validation_duration\",\n        
metrics::Unit::Seconds,\n        \"Time spent validating the request\"\n    );\n    metrics::describe_histogram!(\n        \"tgi_request_duration\",\n        metrics::Unit::Seconds,\n        \"Total time spent processing the request\"\n    );\n    metrics::describe_histogram!(\n        \"tgi_batch_decode_duration\",\n        metrics::Unit::Seconds,\n        \"Time spent decoding a batch per method (prefill or decode)\"\n    );\n    metrics::describe_histogram!(\n        \"tgi_request_input_length\",\n        metrics::Unit::Count,\n        \"Input token length per request\"\n    );\n    metrics::describe_histogram!(\n        \"tgi_batch_next_size\",\n        metrics::Unit::Count,\n        \"Batch size of the next batch\"\n    );\n\n    // CORS layer\n    let allow_origin = allow_origin.unwrap_or(AllowOrigin::any());\n    let cors_layer = CorsLayer::new()\n        .allow_methods([Method::GET, Method::POST])\n        .allow_headers([http::header::CONTENT_TYPE])\n        .allow_origin(allow_origin);\n\n    // Endpoint info\n    let info = Info {\n        model_id: model_info.model_id,\n        model_sha: model_info.sha,\n        // model_dtype: shard_info.dtype,\n        // model_device_type: shard_info.device_type,\n        model_pipeline_tag: model_info.pipeline_tag,\n        max_concurrent_requests,\n        max_best_of,\n        max_stop_sequences,\n        max_input_tokens,\n        max_total_tokens,\n        // waiting_served_ratio,\n        // max_batch_total_tokens,\n        // max_waiting_tokens,\n        // max_batch_size,\n        validation_workers,\n        max_client_batch_size,\n        router: env!(\"CARGO_PKG_NAME\"),\n        version: env!(\"CARGO_PKG_VERSION\"),\n        sha: option_env!(\"VERGEN_GIT_SHA\"),\n        docker_label: option_env!(\"DOCKER_LABEL\"),\n    };\n\n    #[allow(unused_mut)] // mut is needed for conditional compilation\n    let mut doc = ApiDoc::openapi();\n\n    #[cfg(feature = \"google\")]\n    {\n        use 
crate::vertex::__path_vertex_compatibility;\n        use crate::vertex::{VertexInstance, VertexRequest, VertexResponse};\n\n        #[derive(OpenApi)]\n        #[openapi(\n            paths(vertex_compatibility),\n            components(schemas(VertexInstance, VertexRequest, VertexResponse))\n        )]\n        struct VertexApiDoc;\n\n        doc.merge(VertexApiDoc::openapi());\n    }\n\n    #[cfg(feature = \"kserve\")]\n    {\n        use crate::kserve::{\n            InferenceOutput, InferenceRequest, LiveResponse, MetadataServerResponse, OutputChunk,\n            ReadyResponse,\n        };\n        use crate::kserve::{\n            __path_kerve_server_metadata, __path_kserve_health_live, __path_kserve_health_ready,\n            __path_kserve_model_infer, __path_kserve_model_metadata,\n            __path_kserve_model_metadata_ready,\n        };\n\n        #[derive(OpenApi)]\n        #[openapi(\n            paths(\n                kserve_health_live,\n                kserve_health_ready,\n                kerve_server_metadata,\n                kserve_model_metadata,\n                kserve_model_metadata_ready,\n                kserve_model_infer,\n            ),\n            components(schemas(\n                InferenceOutput,\n                InferenceRequest,\n                LiveResponse,\n                MetadataServerResponse,\n                OutputChunk,\n                ReadyResponse,\n            ))\n        )]\n        struct KServeApiDoc;\n\n        doc.merge(KServeApiDoc::openapi());\n    }\n\n    // Configure Swagger UI\n    let swagger_ui = SwaggerUi::new(\"/docs\").url(\"/api-doc/openapi.json\", doc);\n\n    // Define base and health routes\n    let mut base_routes = Router::new()\n        .route(\"/\", post(compat_generate))\n        .route(\"/generate\", post(generate))\n        .route(\"/generate_stream\", post(generate_stream))\n        .route(\"/v1/chat/completions\", post(chat_completions))\n        .route(\"/v1/completions\", 
post(completions))\n        .route(\"/vertex\", post(vertex_compatibility))\n        .route(\"/invocations\", post(sagemaker_compatibility))\n        .route(\"/tokenize\", post(tokenize));\n\n    if let Some(api_key) = api_key {\n        let mut prefix = \"Bearer \".to_string();\n        prefix.push_str(&api_key);\n\n        // Leak to allow FnMut\n        let api_key: &'static str = prefix.leak();\n\n        let auth = move |headers: HeaderMap,\n                         request: axum::extract::Request,\n                         next: axum::middleware::Next| async move {\n            match headers.get(AUTHORIZATION) {\n                Some(token) => match token.to_str() {\n                    Ok(token_str) if token_str.to_lowercase() == api_key.to_lowercase() => {\n                        let response = next.run(request).await;\n                        Ok(response)\n                    }\n                    _ => Err(StatusCode::UNAUTHORIZED),\n                },\n                None => Err(StatusCode::UNAUTHORIZED),\n            }\n        };\n\n        base_routes = base_routes.layer(axum::middleware::from_fn(auth))\n    }\n    let info_routes = Router::new()\n        .route(\"/\", get(health))\n        .route(\"/chat_tokenize\", post(get_chat_tokenize))\n        .route(\"/info\", get(get_model_info))\n        .route(\"/health\", get(health))\n        .route(\"/ping\", get(health))\n        .route(\"/metrics\", get(metrics))\n        .route(\"/v1/models\", get(openai_get_model_info));\n\n    let compute_type =\n        ComputeType(std::env::var(\"COMPUTE_TYPE\").unwrap_or(\"gpu+optimized\".to_string()));\n\n    // Combine routes and layers\n    let mut app = Router::new()\n        .merge(swagger_ui)\n        .merge(base_routes)\n        .merge(info_routes);\n\n    #[cfg(feature = \"google\")]\n    {\n        tracing::info!(\"Built with `google` feature\");\n        tracing::info!(\n            \"Environment variables `AIP_PREDICT_ROUTE` and `AIP_HEALTH_ROUTE` 
will be respected.\"\n        );\n        if let Ok(env_predict_route) = std::env::var(\"AIP_PREDICT_ROUTE\") {\n            app = app.route(&env_predict_route, post(vertex_compatibility));\n        }\n        if let Ok(env_health_route) = std::env::var(\"AIP_HEALTH_ROUTE\") {\n            app = app.route(&env_health_route, get(health));\n        }\n    }\n\n    #[cfg(feature = \"kserve\")]\n    {\n        tracing::info!(\"Built with `kserve` feature\");\n        app = app\n            .route(\n                \"/v2/models/:model_name/versions/:model_version/infer\",\n                post(kserve_model_infer),\n            )\n            .route(\n                \"/v2/models/:model_name/versions/:model_version\",\n                get(kserve_model_metadata),\n            )\n            .route(\"/v2/health/ready\", get(kserve_health_ready))\n            .route(\"/v2/health/live\", get(kserve_health_live))\n            .route(\"/v2\", get(kerve_server_metadata))\n            .route(\n                \"/v2/models/:model_name/versions/:model_version/ready\",\n                get(kserve_model_metadata_ready),\n            );\n    }\n\n    // add layers after routes\n    app = app\n        .layer(Extension(info))\n        .layer(Extension(compat_return_full_text))\n        .layer(Extension(infer))\n        .layer(Extension(compute_type))\n        .layer(Extension(prom_handle.clone()))\n        .layer(OtelAxumLayer::default())\n        .layer(DefaultBodyLimit::max(payload_limit))\n        .layer(axum::middleware::from_fn(trace_context_middleware))\n        .layer(cors_layer);\n\n    tracing::info!(\"Connected\");\n\n    if ngrok {\n        #[cfg(feature = \"ngrok\")]\n        {\n            panic!(\"ngrok feature is not functional with axum=0.7 and hyper=1, waiting on https://github.com/ngrok/ngrok-rust/pull/137/files to re-enable.\");\n\n            // Run server\n        }\n        #[cfg(not(feature = \"ngrok\"))]\n        {\n            let _ngrok_authtoken = 
ngrok_authtoken;\n            let _ngrok_domain = ngrok_domain;\n            let _ngrok_username = ngrok_username;\n            let _ngrok_password = ngrok_password;\n\n            panic!(\"`text-generation-router` was compiled without the `ngrok` feature\");\n        }\n    } else {\n        // Run server\n        let listener = match tokio::net::TcpListener::bind(&addr).await {\n            Ok(listener) => listener,\n            Err(e) => {\n                tracing::error!(\"Failed to bind to {addr}: {e}\");\n                return Err(WebServerError::Axum(Box::new(e)));\n            }\n        };\n        axum::serve(listener, app)\n            .with_graceful_shutdown(shutdown_signal())\n            .await\n            .map_err(|err| WebServerError::Axum(Box::new(err)))?;\n    }\n    Ok(())\n}\n\n/// get model info from the Huggingface Hub\npub async fn get_hub_model_info(api: &ApiRepo) -> Option<HubModelInfo> {\n    let response = api.info_request().send().await.ok()?;\n\n    if response.status().is_success() {\n        let hub_model_info: HubModelInfo =\n            serde_json::from_str(&response.text().await.ok()?).ok()?;\n        if let Some(sha) = &hub_model_info.sha {\n            tracing::info!(\n                \"Serving revision {sha} of model {}\",\n                hub_model_info.model_id\n            );\n        }\n        Some(hub_model_info)\n    } else {\n        None\n    }\n}\n\n/// get tokenizer_config from the Huggingface Hub\npub async fn get_tokenizer_config(api_repo: &ApiRepo) -> Option<HubTokenizerConfig> {\n    let tokenizer_config_filename = api_repo.get(\"tokenizer_config.json\").await.ok()?;\n\n    // Open the file in read-only mode with buffer.\n    let file = File::open(tokenizer_config_filename).ok()?;\n    let reader = BufReader::new(file);\n\n    // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.\n    let tokenizer_config: HubTokenizerConfig = serde_json::from_reader(reader)\n        .map_err(|e| {\n      
      tracing::warn!(\"Unable to parse tokenizer config: {}\", e);\n            e\n        })\n        .ok()?;\n\n    Some(tokenizer_config)\n}\n\n/// Shutdown signal handler\nasync fn shutdown_signal() {\n    let ctrl_c = async {\n        signal::ctrl_c()\n            .await\n            .expect(\"failed to install Ctrl+C handler\");\n    };\n\n    #[cfg(unix)]\n    let terminate = async {\n        signal::unix::signal(signal::unix::SignalKind::terminate())\n            .expect(\"failed to install signal handler\")\n            .recv()\n            .await;\n    };\n\n    #[cfg(not(unix))]\n    let terminate = std::future::pending::<()>();\n\n    tokio::select! {\n        _ = ctrl_c => {},\n        _ = terminate => {},\n    }\n\n    tracing::info!(\"signal received, starting graceful shutdown\");\n    opentelemetry::global::shutdown_tracer_provider();\n}\n\n/// Convert to Axum supported formats\nimpl From<InferError> for (StatusCode, Json<ErrorResponse>) {\n    fn from(err: InferError) -> Self {\n        let status_code = match err {\n            InferError::GenerationError(_) => StatusCode::FAILED_DEPENDENCY,\n            InferError::Overloaded(_) => StatusCode::TOO_MANY_REQUESTS,\n            InferError::ValidationError(_) => StatusCode::UNPROCESSABLE_ENTITY,\n            InferError::IncompleteGeneration => StatusCode::INTERNAL_SERVER_ERROR,\n            InferError::IncompleteGenerationStream => StatusCode::INTERNAL_SERVER_ERROR,\n            InferError::TemplateError(_) => StatusCode::UNPROCESSABLE_ENTITY,\n            InferError::MissingTemplateVariable(_) => StatusCode::UNPROCESSABLE_ENTITY,\n            InferError::ToolError(_) => StatusCode::UNPROCESSABLE_ENTITY,\n            InferError::StreamSerializationError(_) => StatusCode::INTERNAL_SERVER_ERROR,\n        };\n\n        (\n            status_code,\n            Json(ErrorResponse {\n                error: err.to_string(),\n                error_type: err.error_type().to_string(),\n            }),\n       
 )\n    }\n}\n\nimpl From<InferError> for Event {\n    fn from(err: InferError) -> Self {\n        Event::default()\n            .json_data(ErrorResponse {\n                error: err.to_string(),\n                error_type: err.error_type().to_string(),\n            })\n            .unwrap()\n    }\n}\n\n#[derive(Debug, Error)]\npub enum WebServerError {\n    #[error(\"Axum error: {0}\")]\n    Axum(#[from] axum::BoxError),\n    #[error(\"Tokenizer error: {0}\")]\n    Tokenizer(String),\n}\n"
  },
  {
    "path": "router/src/usage_stats.rs",
    "content": "use crate::config::Config;\nuse clap::ValueEnum;\nuse csv::ReaderBuilder;\nuse reqwest::header::HeaderMap;\nuse serde::Serialize;\nuse std::{\n    fs::File,\n    io::{self, BufRead},\n    path::Path,\n    process::Command,\n    time::Duration,\n};\nuse uuid::Uuid;\n\nconst TELEMETRY_URL: &str = \"https://huggingface.co/api/telemetry/tgi\";\n\n#[derive(Copy, Clone, Debug, Serialize, ValueEnum)]\npub enum UsageStatsLevel {\n    On,\n    NoStack,\n    Off,\n}\n\n#[derive(Debug, Clone, Serialize)]\npub struct UserAgent {\n    pub uid: String,\n    pub args: Args,\n    pub env: Env,\n}\n\nimpl UserAgent {\n    pub fn new(reduced_args: Args) -> Self {\n        Self {\n            uid: Uuid::new_v4().to_string(),\n            args: reduced_args,\n            env: Env::new(),\n        }\n    }\n}\n\n#[derive(Serialize, Debug)]\npub enum EventType {\n    Start,\n    Stop,\n    Error,\n    Ping,\n}\n\n#[derive(Debug, Serialize)]\npub struct UsageStatsEvent {\n    user_agent: UserAgent,\n    event_type: EventType,\n    #[serde(skip_serializing_if = \"Option::is_none\")]\n    error_reason: Option<String>,\n}\n\nimpl UsageStatsEvent {\n    pub fn new(user_agent: UserAgent, event_type: EventType, error_reason: Option<String>) -> Self {\n        Self {\n            user_agent,\n            event_type,\n            error_reason,\n        }\n    }\n    pub async fn send(&self) {\n        let mut headers = HeaderMap::new();\n        headers.insert(\"Content-Type\", \"application/json\".parse().unwrap());\n        let body = serde_json::to_string(&self).unwrap();\n        let client = reqwest::Client::new();\n        let _ = client\n            .post(TELEMETRY_URL)\n            .headers(headers)\n            .body(body)\n            .timeout(Duration::from_secs(10))\n            .send()\n            .await;\n    }\n}\n\n#[derive(Debug, Clone, Serialize)]\npub struct Args {\n    model_config: Option<Config>,\n    tokenizer_class: Option<String>,\n    
max_concurrent_requests: usize,\n    max_best_of: usize,\n    max_stop_sequences: usize,\n    max_top_n_tokens: u32,\n    max_input_tokens: usize,\n    max_total_tokens: usize,\n    // waiting_served_ratio: f32,\n    // max_batch_prefill_tokens: u32,\n    // max_batch_total_tokens: Option<u32>,\n    // max_waiting_tokens: usize,\n    // max_batch_size: Option<usize>,\n    revision: Option<String>,\n    validation_workers: usize,\n    disable_grammar_support: bool,\n    max_client_batch_size: usize,\n    usage_stats_level: UsageStatsLevel,\n    backend_name: &'static str,\n    origin: Option<String>,\n}\n\nimpl Args {\n    #[allow(clippy::too_many_arguments)]\n    pub fn new(\n        model_config: Option<Config>,\n        tokenizer_class: Option<String>,\n        max_concurrent_requests: usize,\n        max_best_of: usize,\n        max_stop_sequences: usize,\n        max_top_n_tokens: u32,\n        max_input_tokens: usize,\n        max_total_tokens: usize,\n        // waiting_served_ratio: f32,\n        // max_batch_prefill_tokens: u32,\n        // max_batch_total_tokens: Option<u32>,\n        // max_waiting_tokens: usize,\n        // max_batch_size: Option<usize>,\n        revision: Option<String>,\n        validation_workers: usize,\n        disable_grammar_support: bool,\n        max_client_batch_size: usize,\n        usage_stats_level: UsageStatsLevel,\n        backend_name: &'static str,\n        origin: Option<String>,\n    ) -> Self {\n        Self {\n            model_config,\n            tokenizer_class,\n            max_concurrent_requests,\n            max_best_of,\n            max_stop_sequences,\n            max_top_n_tokens,\n            max_input_tokens,\n            max_total_tokens,\n            // waiting_served_ratio,\n            // max_batch_prefill_tokens,\n            // max_batch_total_tokens,\n            // max_waiting_tokens,\n            // max_batch_size,\n            revision,\n            validation_workers,\n            
disable_grammar_support,\n            max_client_batch_size,\n            usage_stats_level,\n            backend_name,\n            origin,\n        }\n    }\n}\n\n/// This is more or less a copy of the code from the `text-generation-launcher` crate to avoid a dependency\n#[derive(Serialize, Debug, Clone)]\npub struct Env {\n    git_sha: &'static str,\n    docker_label: &'static str,\n    nvidia_info: Option<Vec<NvidiaSmiInfo>>,\n    xpu_info: Option<Vec<XpuSmiInfo>>,\n    hpu_info: Option<Vec<HpuSmiInfo>>,\n    system_env: SystemInfo,\n}\n\n#[derive(Debug, Serialize, Clone)]\nstruct NvidiaSmiInfo {\n    name: String,\n    pci_bus_id: String,\n    driver_version: String,\n    pstate: String,\n    pcie_link_gen_max: String,\n    pcie_link_gen_current: String,\n    temperature_gpu: String,\n    utilization_gpu: String,\n    utilization_memory: String,\n    memory_total: String,\n    memory_free: String,\n    memory_used: String,\n    reset_status_reset_required: String,\n    reset_status_drain_and_reset_recommended: String,\n    compute_cap: String,\n    ecc_errors_corrected_volatile_total: String,\n    mig_mode_current: String,\n    power_draw_instant: String,\n    power_limit: String,\n}\n\nimpl NvidiaSmiInfo {\n    fn new() -> Option<Vec<NvidiaSmiInfo>> {\n        let output = Command::new(\"nvidia-smi\")\n            .args([\n                \"--query-gpu=name,pci.bus_id,driver_version,pstate,pcie.link.gen.max,pcie.link.gen.gpucurrent,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,reset_status.reset_required,reset_status.drain_and_reset_recommended,compute_cap,ecc.errors.corrected.volatile.total,mig.mode.current,power.draw.instant,power.limit\",\n                \"--format=csv\"\n            ])\n            .output()\n            .ok()?;\n\n        if !output.status.success() {\n            return None;\n        }\n\n        let stdout = String::from_utf8(output.stdout).ok()?;\n\n        let mut rdr = 
ReaderBuilder::new()\n            .has_headers(true)\n            .from_reader(stdout.as_bytes());\n\n        let mut infos = Vec::new();\n\n        for result in rdr.records() {\n            let record = result.ok()?;\n            infos.push(NvidiaSmiInfo {\n                name: record[0].to_string(),\n                pci_bus_id: record[1].to_string(),\n                driver_version: record[2].to_string(),\n                pstate: record[3].to_string(),\n                pcie_link_gen_max: record[4].to_string(),\n                pcie_link_gen_current: record[5].to_string(),\n                temperature_gpu: record[6].to_string(),\n                utilization_gpu: record[7].to_string(),\n                utilization_memory: record[8].to_string(),\n                memory_total: record[9].to_string(),\n                memory_free: record[10].to_string(),\n                memory_used: record[11].to_string(),\n                reset_status_reset_required: record[12].to_string(),\n                reset_status_drain_and_reset_recommended: record[13].to_string(),\n                compute_cap: record[14].to_string(),\n                ecc_errors_corrected_volatile_total: record[15].to_string(),\n                mig_mode_current: record[16].to_string(),\n                power_draw_instant: record[17].to_string(),\n                power_limit: record[18].to_string(),\n            });\n        }\n\n        Some(infos)\n    }\n}\n\n#[derive(Debug, Serialize, Clone)]\nstruct XpuSmiInfo {\n    device_id: usize,\n    gpu_utilization: f32,\n    gpu_power: f32,\n    gpu_core_temperature: f32,\n    gpu_memory_bandwidth_utilization: f32,\n}\n\nimpl XpuSmiInfo {\n    /// based on this https://github.com/intel/xpumanager/blob/master/doc/smi_user_guide.md#dump-the-device-statistics-in-csv-format\n    fn new() -> Option<Vec<XpuSmiInfo>> {\n        let output = Command::new(\"xpu-smi\")\n            .args([\n                \"dump\", \"-d\", \"-1\", \"-m\",\n                \"0,1,3,17\", // 
Metrics IDs: GPU Utilization, GPU Power, GPU Core Temperature, GPU Memory Bandwidth Utilization\n                \"-n\", \"1\", \"-j\",\n            ])\n            .output()\n            .ok()?;\n\n        if !output.status.success() {\n            return None;\n        }\n\n        let stdout = String::from_utf8(output.stdout).ok()?;\n        let mut infos = Vec::new();\n\n        let json_data: serde_json::Value = match serde_json::from_str(&stdout) {\n            Ok(data) => data,\n            Err(_) => return None,\n        };\n\n        if let Some(metrics_data) = json_data.as_array() {\n            for entry in metrics_data {\n                let device_id = entry[\"deviceId\"].as_u64()? as usize;\n                let gpu_utilization = entry[\"metrics\"][0].as_f64()? as f32;\n                let gpu_power = entry[\"metrics\"][1].as_f64()? as f32;\n                let gpu_core_temperature = entry[\"metrics\"][2].as_f64()? as f32;\n                let gpu_memory_bandwidth_utilization = entry[\"metrics\"][3].as_f64()? 
as f32;\n\n                infos.push(XpuSmiInfo {\n                    device_id,\n                    gpu_utilization,\n                    gpu_power,\n                    gpu_core_temperature,\n                    gpu_memory_bandwidth_utilization,\n                });\n            }\n        }\n\n        Some(infos)\n    }\n}\n\n#[derive(Debug, Serialize, Clone)]\nstruct HpuSmiInfo {\n    name: String,\n    pci_bus_id: String,\n    driver_version: String,\n    temperature: String,\n    utilization: String,\n    memory_total: String,\n    memory_free: String,\n    memory_used: String,\n    power_draw_instant: String,\n}\n\nimpl HpuSmiInfo {\n    fn new() -> Option<Vec<HpuSmiInfo>> {\n        let output = Command::new(\"hl-smi\")\n            .args([\n                \"--query-aip=name,bus_id,driver_version,temperature.aip,utilization.aip,memory.total,memory.free,memory.used,power.draw\",\n                \"--format=csv\"\n            ])\n            .output()\n            .ok()?;\n\n        if !output.status.success() {\n            return None;\n        }\n\n        let stdout = String::from_utf8(output.stdout).ok()?;\n\n        let mut rdr = ReaderBuilder::new()\n            .has_headers(true)\n            .from_reader(stdout.as_bytes());\n\n        let mut infos = Vec::new();\n\n        for result in rdr.records() {\n            let record = result.ok()?;\n            infos.push(HpuSmiInfo {\n                name: record[0].to_string(),\n                pci_bus_id: record[1].to_string(),\n                driver_version: record[2].to_string(),\n                temperature: record[3].to_string(),\n                utilization: record[4].to_string(),\n                memory_total: record[5].to_string(),\n                memory_free: record[6].to_string(),\n                memory_used: record[7].to_string(),\n                power_draw_instant: record[8].to_string(),\n            });\n        }\n\n        Some(infos)\n    }\n}\n\n#[derive(Serialize, Debug, 
Clone)]\npub struct SystemInfo {\n    cpu_count: usize,\n    cpu_type: String,\n    total_memory: u64,\n    architecture: String,\n    platform: String,\n}\n\nimpl SystemInfo {\n    fn new() -> Self {\n        let mut system = sysinfo::System::new_all();\n        system.refresh_all();\n\n        let cpu_count = system.cpus().len();\n        let cpu_type = system.cpus()[0].brand().to_string();\n        let total_memory = system.total_memory();\n        let architecture = std::env::consts::ARCH.to_string();\n        let platform = format!(\n            \"{}-{}-{}\",\n            std::env::consts::OS,\n            std::env::consts::FAMILY,\n            std::env::consts::ARCH\n        );\n        Self {\n            cpu_count,\n            cpu_type,\n            total_memory,\n            architecture,\n            platform,\n        }\n    }\n}\n\nimpl Default for Env {\n    fn default() -> Self {\n        Self::new()\n    }\n}\n\nimpl Env {\n    pub fn new() -> Self {\n        Self {\n            system_env: SystemInfo::new(),\n            nvidia_info: NvidiaSmiInfo::new(),\n            xpu_info: XpuSmiInfo::new(),\n            hpu_info: HpuSmiInfo::new(),\n            git_sha: option_env!(\"VERGEN_GIT_SHA\").unwrap_or(\"N/A\"),\n            docker_label: option_env!(\"DOCKER_LABEL\").unwrap_or(\"N/A\"),\n        }\n    }\n    pub fn is_hpu_device(&self) -> bool {\n        self.hpu_info.is_some()\n    }\n}\n\npub fn is_container() -> io::Result<bool> {\n    let path = Path::new(\"/proc/self/cgroup\");\n    let file = File::open(path)?;\n    let reader = io::BufReader::new(file);\n\n    for line in reader.lines() {\n        let line = line?;\n        // Check for common container runtimes\n        if line.contains(\"/docker/\")\n            || line.contains(\"/docker-\")\n            || line.contains(\"/kubepods/\")\n            || line.contains(\"/kubepods-\")\n            || line.contains(\"containerd\")\n            || line.contains(\"crio\")\n            || 
line.contains(\"podman\")\n        {\n            return Ok(true);\n        }\n    }\n    Ok(false)\n}\n"
  },
  {
    "path": "router/src/validation.rs",
    "content": "use crate::config::Config;\nuse crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};\nuse crate::{\n    GenerateParameters, GenerateRequest, GrammarType, HubPreprocessorConfig, Idefics2Preprocessor,\n    TokenizerTrait,\n};\nuse crate::{PyTokenizer, Tokenizer};\nuse base64::{engine::general_purpose::STANDARD, Engine};\nuse image::{ImageFormat, ImageReader};\nuse outlines_core::json_schema::to_regex as json_schema_to_regex;\nuse rand::{thread_rng, Rng};\nuse serde_json::Value;\n/// Payload validation logic\nuse std::cmp::min;\nuse std::io::{Cursor, Read};\nuse std::iter;\nuse std::sync::Arc;\nuse thiserror::Error;\nuse tokio::sync::mpsc;\nuse tokio::sync::oneshot;\nuse tracing::warn;\nuse tracing::{instrument, Span};\nuse {once_cell::sync::Lazy, regex::Regex};\n\nstatic DEFAULT_GENERATION_LENGTH: u32 = 1024;\n\n/// Validation\n#[derive(Debug, Clone)]\npub struct Validation {\n    /// Validation parameters\n    max_best_of: usize,\n    max_stop_sequences: usize,\n    max_top_n_tokens: u32,\n    max_input_length: usize,\n    max_total_tokens: usize,\n    disable_grammar_support: bool,\n    /// Channel to communicate with the background tokenization task\n    sender: mpsc::UnboundedSender<TokenizerRequest>,\n}\n\nimpl Validation {\n    #[allow(clippy::too_many_arguments)]\n    pub(crate) fn new(\n        workers: usize,\n        tokenizer: Tokenizer,\n        config: Option<Config>,\n        preprocessor_config: Option<HubPreprocessorConfig>,\n        max_best_of: usize,\n        max_stop_sequences: usize,\n        max_top_n_tokens: u32,\n        max_input_length: usize,\n        max_total_tokens: usize,\n        disable_grammar_support: bool,\n        max_image_fetch_size: usize,\n    ) -> Self {\n        let workers = if let Tokenizer::Python { .. 
} = &tokenizer {\n            1\n        } else {\n            workers\n        };\n        // If we have a fast tokenizer\n        let sender = {\n            // Create round robin channel\n            let (validation_sender, validation_round_robin_receiver) = mpsc::unbounded_channel();\n            let mut senders = Vec::with_capacity(workers);\n\n            // Create workers\n            for _ in 0..workers {\n                let tokenizer_clone = tokenizer.clone();\n                let config_clone = config.clone();\n                let preprocessor_config_clone = preprocessor_config.clone();\n                let (tokenizer_sender, tokenizer_receiver) = mpsc::unbounded_channel();\n                senders.push(tokenizer_sender);\n\n                // Spawn worker\n                tokio::task::spawn_blocking(move || {\n                    tokenizer_worker(\n                        tokenizer_clone,\n                        config_clone,\n                        preprocessor_config_clone,\n                        tokenizer_receiver,\n                        max_image_fetch_size,\n                    )\n                });\n            }\n\n            // Create tokenization round robin task\n            tokio::spawn(round_robin_task(validation_round_robin_receiver, senders));\n\n            validation_sender\n        };\n\n        Self {\n            max_best_of,\n            sender,\n            max_stop_sequences,\n            max_top_n_tokens,\n            max_input_length,\n            max_total_tokens,\n            disable_grammar_support,\n        }\n    }\n\n    #[instrument(skip(self, inputs))]\n    pub async fn tokenize(\n        &self,\n        inputs: String,\n        add_special_tokens: bool,\n        truncate: Option<usize>,\n    ) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {\n        // If we have a fast tokenizer\n        // Create response channel\n        let (response_sender, response_receiver) = oneshot::channel();\n        
// Send request to the background validation task\n        // Unwrap is safe here\n        let _ = &self\n            .sender\n            .send((\n                (inputs, add_special_tokens, truncate),\n                response_sender,\n                Span::current(),\n            ))\n            .unwrap();\n\n        // Await on response channel\n        // Unwrap is safe here\n        let encoding = response_receiver.await.unwrap()?;\n        Ok(encoding)\n    }\n\n    #[allow(clippy::type_complexity)]\n    #[instrument(skip(self, inputs))]\n    async fn validate_input(\n        &self,\n        inputs: String,\n        add_special_tokens: bool,\n        truncate: Option<usize>,\n        max_new_tokens: Option<u32>,\n    ) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32, u32), ValidationError> {\n        // If we have a fast tokenizer\n        let (encoding, inputs) = self\n            .tokenize(inputs.clone(), add_special_tokens, truncate)\n            .await?;\n        // Create response channel\n        let input_length = if let Some(truncate) = truncate {\n            std::cmp::min(encoding.len(), truncate)\n        } else {\n            encoding.len()\n        };\n\n        // Get total tokens\n        let (max_new_tokens, max_total_new_tokens) = if let Some(max_new_tokens) = max_new_tokens {\n            // Do not accept humongous max_new_tokens queries.\n            // We preallocate the default but we prevent a single user\n            // from taking up all the slots in a handful of queries that consume little\n            // amount of tokens. 
(You can have 10 token long query that creates a handful of token\n            // but the requested amount to be 120k.\n            let chunk_size = min(max_new_tokens, DEFAULT_GENERATION_LENGTH);\n            (chunk_size, max_new_tokens)\n        } else {\n            // Use the maximum possible number of tokens as default\n            // However, the system will re-queue the request everytime it completes\n            // `DEFAULT_GENERATION_LENGTH` tokens.\n            let max_new_tokens = self.max_total_tokens.saturating_sub(input_length) as u32;\n            (\n                min(max_new_tokens, DEFAULT_GENERATION_LENGTH),\n                max_new_tokens,\n            )\n        };\n        let total_tokens = input_length + max_new_tokens as usize;\n\n        // Validate MaxTotalTokens\n        if total_tokens > self.max_total_tokens {\n            return Err(ValidationError::MaxTotalTokens(\n                self.max_total_tokens,\n                input_length,\n                max_new_tokens,\n            ));\n        }\n\n        // Validate InputLength\n        if input_length > self.max_input_length {\n            return Err(ValidationError::InputLength(\n                self.max_input_length,\n                input_length,\n            ));\n        }\n\n        let ids = encoding.get_ids();\n        let input_ids = ids[ids.len().saturating_sub(input_length)..].to_owned();\n\n        metrics::histogram!(\"tgi_request_input_length\").record(input_length as f64);\n        Ok((\n            inputs,\n            Some(input_ids),\n            input_length,\n            max_new_tokens,\n            max_total_new_tokens,\n        ))\n    }\n\n    /// Validate a payload and get the number of tokens in the input\n    #[instrument(skip_all)]\n    pub(crate) async fn validate(\n        &self,\n        request: GenerateRequest,\n    ) -> Result<ValidGenerateRequest, ValidationError> {\n        let GenerateParameters {\n            best_of,\n            temperature,\n  
          repetition_penalty,\n            frequency_penalty,\n            top_k,\n            top_p,\n            typical_p,\n            do_sample,\n            max_new_tokens,\n            stop: stop_sequences,\n            truncate,\n            seed,\n            watermark,\n            decoder_input_details,\n            top_n_tokens,\n            grammar,\n            adapter_id,\n            ..\n        } = request.parameters;\n\n        // sampling must be true when best_of > 1\n        let best_of = best_of.unwrap_or(1);\n        let sampling = do_sample\n            || temperature.is_some()\n            || top_k.is_some()\n            || top_p.is_some()\n            || typical_p.is_some();\n\n        if best_of > 1 && !sampling {\n            return Err(BestOfSampling);\n        }\n\n        let temperature = temperature.unwrap_or(1.0);\n        if temperature <= 0.0 {\n            return Err(ValidationError::Temperature);\n        }\n\n        let repetition_penalty = repetition_penalty.unwrap_or(1.0);\n        if repetition_penalty <= 0.0 {\n            return Err(ValidationError::RepetitionPenalty);\n        }\n\n        let frequency_penalty = frequency_penalty.unwrap_or(0.0);\n        if !(-2.0..=2.0).contains(&frequency_penalty) {\n            return Err(ValidationError::FrequencyPenalty);\n        }\n\n        // Different because the proto default value is not a valid value\n        // for the user\n        let top_p = top_p\n            .map(|value| {\n                if value <= 0.0 || value >= 1.0 {\n                    return Err(ValidationError::TopP);\n                }\n                Ok(value)\n            })\n            .unwrap_or(Ok(1.0))?;\n\n        let typical_p = typical_p\n            .map(|value| {\n                if value <= 0.0 || value >= 1.0 {\n                    return Err(ValidationError::TypicalP);\n                }\n                Ok(value)\n            })\n            .unwrap_or(Ok(1.0))?;\n\n        let top_k: u32 
= top_k\n            .map(|value| {\n                if value <= 0 {\n                    return Err(ValidationError::TopK);\n                }\n                Ok(value as u32)\n            })\n            .unwrap_or(Ok(0))?;\n\n        if max_new_tokens == Some(0) {\n            return Err(ValidationError::NegativeMaxNewTokens);\n        }\n\n        if stop_sequences.len() > self.max_stop_sequences {\n            return Err(ValidationError::StopSequence(\n                self.max_stop_sequences,\n                stop_sequences.len(),\n            ));\n        }\n\n        // If seed is None, assign a random one\n        let seed = match seed {\n            None => thread_rng().gen(),\n            Some(seed) => {\n                if best_of > 1 {\n                    return Err(BestOfSeed);\n                }\n                seed\n            }\n        };\n\n        let top_n_tokens = top_n_tokens\n            .map(|value| {\n                if value > self.max_top_n_tokens {\n                    return Err(ValidationError::TopNTokens(self.max_top_n_tokens, value));\n                }\n                Ok(value)\n            })\n            .unwrap_or(Ok(0))?;\n\n        // Check if inputs is empty\n        if request.inputs.is_empty() {\n            return Err(EmptyInput);\n        }\n\n        // Check if truncate is strictly positive and less than max_input_length\n        let truncate = truncate\n            .map(|value| {\n                if value == 0 || value > self.max_input_length {\n                    return Err(ValidationError::Truncate(self.max_input_length, value));\n                }\n                Ok(Some(value))\n            })\n            .unwrap_or(Ok(None))?;\n\n        // Validate inputs\n        let (inputs, input_ids, input_length, max_new_tokens, max_total_new_tokens) = self\n            .validate_input(\n                request.inputs,\n                request.add_special_tokens,\n                truncate,\n                
max_new_tokens,\n            )\n            .await?;\n\n        // TODO: we should build the FSM here and pass the compiled FSM instead of the grammar\n        // NOTE: this is currently difficult because we need the tokenizer in Python to build\n        // the FSM and we'd have to load a copy of the tokenizer into our Pyo3 instance which\n        // may be slow and memory intensive. Best case is to have a Rust implementation of the FSM\n        // compiler and use that to build the FSM here.\n\n        // Validate grammar and unpack the grammar and type for the proto message\n        let grammar = match grammar {\n            Some(grammar) => {\n                // Ensure that grammar is not set if it's not supported\n                if self.disable_grammar_support {\n                    return Err(ValidationError::Grammar);\n                }\n                let valid_grammar = match grammar {\n                    GrammarType::Json(json) => {\n                        let json = match json {\n                            // if value is a string, we need to parse it again to make sure its\n                            // a valid json\n                            Value::String(s) => serde_json::from_str(&s)\n                                .map_err(|e| ValidationError::InvalidGrammar(e.to_string())),\n                            Value::Object(_) => Ok(json),\n                            _ => Err(ValidationError::Grammar),\n                        }?;\n\n                        // Check if the json is a valid JSONSchema\n                        jsonschema::draft202012::meta::validate(&json)\n                            .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;\n\n                        // The schema can be valid but lack properties.\n                        // We need properties for the grammar to be successfully parsed in Python.\n                        // Therefore, we must check and throw an error if properties are missing.\n                   
     json.get(\"properties\")\n                            .ok_or(ValidationError::InvalidGrammar(\n                                \"Grammar must have a 'properties' field\".to_string(),\n                            ))?;\n\n                        // Do compilation in the router for performance. In the future, we\n                        // should also move regex -> automaton compilation in the router,\n                        // but this is not yet supported in pure Rust by outlines-core.\n                        let grammar_regex = json_schema_to_regex(&json, None, &json)\n                            .map_err(ValidationError::RegexFromSchema)?;\n\n                        ValidGrammar::Regex(grammar_regex.to_string())\n                    }\n                    GrammarType::JsonSchema(schema_config) => {\n                        // Extract the actual schema for validation\n                        let json = &schema_config.schema;\n\n                        // Check if the json is a valid JSONSchema\n                        jsonschema::draft202012::meta::validate(json)\n                            .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;\n\n                        // The schema can be valid but lack properties.\n                        // We need properties for the grammar to be successfully parsed in Python.\n                        // Therefore, we must check and throw an error if properties are missing.\n                        json.get(\"properties\")\n                            .ok_or(ValidationError::InvalidGrammar(\n                                \"Grammar must have a 'properties' field\".to_string(),\n                            ))?;\n\n                        // Do compilation in the router for performance\n                        let grammar_regex = json_schema_to_regex(json, None, json)\n                            .map_err(ValidationError::RegexFromSchema)?;\n\n                        
ValidGrammar::Regex(grammar_regex.to_string())\n                    }\n                    GrammarType::Regex(regex) => ValidGrammar::Regex(regex),\n                };\n                Some(valid_grammar)\n            }\n            None => None,\n        };\n\n        let parameters = ValidParameters {\n            temperature,\n            repetition_penalty,\n            frequency_penalty,\n            top_k,\n            top_p,\n            typical_p,\n            do_sample,\n            seed,\n            watermark,\n            grammar,\n        };\n        let stopping_parameters = ValidStoppingParameters {\n            max_new_tokens,\n            max_total_new_tokens,\n            stop_sequences,\n            ignore_eos_token: false,\n        };\n\n        metrics::histogram!(\"tgi_request_max_new_tokens\").record(max_new_tokens as f64);\n\n        Ok(ValidGenerateRequest {\n            inputs,\n            input_ids: input_ids.map(Arc::new),\n            add_special_tokens: request.add_special_tokens,\n            decoder_input_details,\n            input_length: input_length as u32,\n            truncate: truncate.unwrap_or(self.max_input_length) as u32,\n            parameters,\n            stopping_parameters,\n            top_n_tokens,\n            adapter_id,\n        })\n    }\n\n    /// Validate the best_of parameter\n    #[instrument(skip_all)]\n    pub(crate) fn validate_best_of(&self, best_of: usize) -> Result<usize, ValidationError> {\n        if self.max_best_of == 1 && best_of != 1 {\n            return Err(ValidationError::BestOfDisabled);\n        }\n\n        if best_of > self.max_best_of {\n            return Err(ValidationError::BestOf(self.max_best_of, best_of));\n        }\n\n        Ok(best_of)\n    }\n}\n\n/// Round robin tokenization task\nasync fn round_robin_task(\n    mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,\n    senders: Vec<mpsc::UnboundedSender<TokenizerRequest>>,\n) {\n    loop {\n        for sender in 
&senders {\n            match receiver.recv().await {\n                None => return,\n                Some(request) => sender.send(request).unwrap(),\n            };\n        }\n    }\n}\n\n/// Start tokenization workers\nfn tokenizer_worker(\n    tokenizer: Tokenizer,\n    config: Option<Config>,\n    preprocessor_config: Option<HubPreprocessorConfig>,\n    mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,\n    max_image_fetch_size: usize,\n) {\n    match tokenizer {\n        Tokenizer::Python {\n            tokenizer_name,\n            revision,\n            trust_remote_code,\n        } => {\n            pyo3::Python::with_gil(|py| -> pyo3::PyResult<()> {\n                let tokenizer =\n                    PyTokenizer::from_py(py, tokenizer_name, revision, trust_remote_code)?;\n                // Loop over requests\n                while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =\n                    receiver.blocking_recv()\n                {\n                    parent_span.in_scope(|| {\n                        response_tx\n                            .send(prepare_input(\n                                inputs,\n                                truncate,\n                                add_special_tokens,\n                                &tokenizer,\n                                config.as_ref(),\n                                preprocessor_config.as_ref(),\n                                max_image_fetch_size,\n                            ))\n                            .unwrap_or(())\n                    })\n                }\n                Ok(())\n            })\n            .expect(\"Failure in python tokenizer worker\");\n        }\n        Tokenizer::Rust(tokenizer) => {\n            while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =\n                receiver.blocking_recv()\n            {\n                parent_span.in_scope(|| {\n                    response_tx\n   
                     .send(prepare_input(\n                            inputs,\n                            truncate,\n                            add_special_tokens,\n                            &tokenizer,\n                            config.as_ref(),\n                            preprocessor_config.as_ref(),\n                            max_image_fetch_size,\n                        ))\n                        .unwrap_or(())\n                })\n            }\n        }\n    }\n}\n\nfn format_from_mimetype(mimetype: &str) -> Option<ImageFormat> {\n    match mimetype {\n        \"image/png\" => Some(ImageFormat::Png),\n        \"image/jpeg\" => Some(ImageFormat::Jpeg),\n        \"image/jpg\" => Some(ImageFormat::Jpeg),\n        \"image/gif\" => Some(ImageFormat::Gif),\n        \"image/webp\" => Some(ImageFormat::WebP),\n        \"image/tiff\" => Some(ImageFormat::Tiff),\n        // \"image/pnm\"=>Some(ImageFormat::Pnm),\n        // \"image/tga\"=>Some(ImageFormat::Tga),\n        // \"image/dds\"=>Some(ImageFormat::Dds),\n        // \"image/bmp\"=>Some(ImageFormat::Bmp),\n        // \"image/ico\"=>Some(ImageFormat::Ico),\n        // \"image/x-exr\"=>Some(ImageFormat::OpenExr),\n        _ => None,\n    }\n}\n\nfn format_to_mimetype(format: ImageFormat) -> String {\n    match format {\n        ImageFormat::Png => \"image/png\",\n        ImageFormat::Jpeg => \"image/jpeg\",\n        ImageFormat::Gif => \"image/gif\",\n        ImageFormat::WebP => \"image/webp\",\n        ImageFormat::Tiff => \"image/tiff\",\n        _ => \"application/octet-stream\",\n    }\n    .to_string()\n}\n\nfn fetch_image(\n    input: &str,\n    max_image_fetch_size: usize,\n) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {\n    if input.starts_with(\"![](http://\") || input.starts_with(\"![](https://\") {\n        let url = &input[\"![](\".len()..input.len() - 1];\n        let response = reqwest::blocking::get(url)?;\n\n        // Check Content-Length header if present\n        
if let Some(content_length) = response.content_length() {\n            if content_length as usize > max_image_fetch_size {\n                return Err(ValidationError::ImageTooLarge(\n                    content_length as usize,\n                    max_image_fetch_size,\n                ));\n            }\n        }\n\n        // Read the body with size limit to prevent unbounded memory allocation\n        let mut data = Vec::new();\n        let mut limited_reader = response.take((max_image_fetch_size + 1) as u64);\n        limited_reader.read_to_end(&mut data)?;\n\n        if data.len() > max_image_fetch_size {\n            return Err(ValidationError::ImageTooLarge(\n                data.len(),\n                max_image_fetch_size,\n            ));\n        }\n\n        let format = image::guess_format(&data)?;\n        // TODO Remove this clone\n        let img = ImageReader::with_format(Cursor::new(data.clone()), format).decode()?;\n        let height: usize = img.height().try_into()?;\n        let width: usize = img.width().try_into()?;\n        let mimetype = format_to_mimetype(format);\n        Ok((data.to_vec(), mimetype, height, width))\n    } else if input.starts_with(\"![](data:\") {\n        // Remove ![](....)\n        let content = &input[\"![](data:\".len()..input.len() - 1];\n        let tokens: Vec<_> = content.split(';').collect();\n        if tokens.len() != 2 {\n            return Err(ValidationError::InvalidImageContent(content.to_string()));\n        }\n        let mimetype = tokens[0];\n        let content = tokens[1];\n\n        if !content.starts_with(\"base64,\") {\n            return Err(ValidationError::InvalidImageContent(content.to_string()));\n        }\n\n        let data = STANDARD.decode(&content[\"base64,\".len()..])?;\n        let img = if let Some(format) = format_from_mimetype(mimetype) {\n            ImageReader::with_format(Cursor::new(&data), format).decode()?\n        } else {\n            
ImageReader::new(Cursor::new(&data))\n                .with_guessed_format()\n                .map_err(|_io_error| ValidationError::InvalidImageContent(content.to_string()))?\n                .decode()?\n        };\n\n        let height: usize = img.height().try_into()?;\n        let width: usize = img.width().try_into()?;\n        Ok((data, mimetype.to_string(), height, width))\n    } else {\n        Err(ValidationError::InvalidImageContent(input.to_string()))\n    }\n}\n\nfn image_tokens(\n    config: &Config,\n    preprocessor_config: Option<&HubPreprocessorConfig>,\n    height: usize,\n    width: usize,\n) -> String {\n    use Config::*;\n    use HubPreprocessorConfig::*;\n    match config {\n        Idefics => \"<image>\".to_string(),\n        Mllama => \"<|image|>\".to_string(),\n        Idefics2(config) => {\n            const FAKE: &str = \"<fake_token_around_image>\";\n            const IMAGE: &str = \"<image>\";\n\n            let slots = config.get_number_of_features(height, width);\n\n            let mut image_string = String::with_capacity(2 * FAKE.len() + slots * IMAGE.len());\n            image_string.push_str(FAKE);\n            image_string.extend(iter::repeat_n(IMAGE, slots));\n            image_string.push_str(FAKE);\n\n            if matches!(\n                preprocessor_config,\n                Some(Idefics2Processor(Idefics2Preprocessor {\n                    do_image_splitting: true,\n                    ..\n                }))\n            ) {\n                image_string = image_string.repeat(5);\n            };\n\n            image_string\n        }\n        Idefics3(config) => {\n            const FAKE: &str = \"<fake_token_around_image>\";\n            const IMAGE: &str = \"<image>\";\n            const GLOBAL_IMG: &str = \"<global-img>\";\n\n            let max_longest_edge_for_image_resize = config.get_max_longest_edge_for_image_resize();\n            let max_image_size = config.get_max_image_size();\n\n            let (height, 
width) = {\n                let h = height as f32;\n                let w = width as f32;\n\n                // First resize to max_longest_edge (always scale to this size)\n                let scale1 = max_longest_edge_for_image_resize as f32 / h.max(w);\n                let (h, w) = (h * scale1, w * scale1);\n\n                // Ensure we dont exceed max_size (only scale down)\n                let scale2 = (max_image_size as f32 / h.max(w)).min(1.0);\n\n                ((h * scale2) as usize, (w * scale2) as usize)\n            };\n\n            let image_seq_len = config.get_number_of_features();\n            let max_edge = config.get_max_longest_edge();\n\n            let (image_rows, image_cols) = if height > max_edge || width > max_edge {\n                (\n                    (height as f32 / max_edge as f32).ceil() as usize,\n                    (width as f32 / max_edge as f32).ceil() as usize,\n                )\n            } else {\n                (0, 0)\n            };\n\n            let mut image_string = String::new();\n\n            if image_rows == 0 && image_cols == 0 {\n                // Single image case\n                image_string.push_str(FAKE);\n                image_string.push_str(GLOBAL_IMG);\n                image_string.push_str(&IMAGE.repeat(image_seq_len));\n                image_string.push_str(FAKE);\n            } else {\n                // Split image case\n                for n_h in 0..image_rows {\n                    for n_w in 0..image_cols {\n                        image_string.push_str(FAKE);\n                        image_string.push_str(&format!(\"<row_{}_col_{}>\", n_h + 1, n_w + 1));\n                        image_string.push_str(&IMAGE.repeat(image_seq_len));\n                    }\n                    image_string.push('\\n');\n                }\n\n                image_string.push('\\n');\n                image_string.push_str(FAKE);\n                image_string.push_str(GLOBAL_IMG);\n                
image_string.push_str(&IMAGE.repeat(image_seq_len));\n                image_string.push_str(FAKE);\n            }\n\n            image_string\n        }\n        Paligemma(config) => \"<image>\".repeat(config.get_number_of_features(height, width)),\n        LlavaNext(config) => \"<image>\".repeat(config.get_number_of_features(height, width)),\n        Llama4(config) => {\n            const IMAGE_START: &str = \"<|image_start|>\";\n            const IMAGE: &str = \"<|image|>\";\n            const IMAGE_END: &str = \"<|image_end|>\";\n            const PATCH: &str = \"<|patch|>\";\n            const TILE_X_SEP: &str = \"<|tile_x_separator|>\";\n            const TILE_Y_SEP: &str = \"<|tile_y_separator|>\";\n\n            let image_height = config.image_size();\n            let patch_size = config.patch_size();\n            let pixel_shuffle_ratio = config.pixel_shuffle_ratio();\n            let max_patches = match preprocessor_config {\n                Some(HubPreprocessorConfig::Llama4Processor(cfg)) => cfg.max_patches,\n                _ => panic!(\"Expected Llama4Processor in preprocessor_config\"),\n            };\n            let downsample_ratio =\n                (1.0 / (pixel_shuffle_ratio * pixel_shuffle_ratio)).round() as usize;\n\n            let (ratio_h, ratio_w) = config.get_aspect_ratios(height, width, max_patches);\n            let image_width = image_height; // Assuming pixel shape: [H][W][C]\n\n            let num_patches_per_chunk =\n                (image_height / patch_size) * (image_width / patch_size) / downsample_ratio;\n\n            let mut img_string = String::new();\n            img_string.push_str(IMAGE_START);\n\n            if ratio_h * ratio_w > 1 {\n                for _yy in 0..ratio_h {\n                    for xx in 0..ratio_w {\n                        img_string.push_str(&PATCH.repeat(num_patches_per_chunk));\n                        if xx < ratio_w - 1 {\n                            img_string.push_str(TILE_X_SEP);\n             
           }\n                    }\n                    img_string.push_str(TILE_Y_SEP);\n                }\n            }\n\n            img_string.push_str(IMAGE);\n            img_string.push_str(&PATCH.repeat(num_patches_per_chunk));\n            img_string.push_str(IMAGE_END);\n\n            img_string\n        }\n        Qwen2Vl(config) => format!(\n            \"<|vision_start|>{:?}<|vision_end|>\",\n            \"<|image_pad|>\".repeat(config.get_number_of_features(height, width))\n        ),\n        Qwen2_5Vl(config) => format!(\n            \"<|vision_start|>{:?}<|vision_end|>\",\n            \"<|image_pad|>\".repeat(config.get_number_of_features(height, width))\n        ),\n        Gemma3(_config) => {\n            // TODO: prefer using the config to determine the number of features\n            let num_mm_soft_tokens_per_image = 256;\n            format!(\n                \"\\n\\n<start_of_image>{}<end_of_image>\\n\\n\",\n                \"<image_soft_token>\".repeat(num_mm_soft_tokens_per_image)\n            )\n        }\n        _ => unimplemented!(\"Images tokens are not supported for this model configuration\"),\n    }\n}\n\nfn image_tokens_fixup(config: &Config, text: String) -> String {\n    match config {\n        Config::Idefics2(_) => {\n            const FAKE: &str = \"<fake_token_around_image>\";\n            text.replace(&format!(\"{FAKE}{FAKE}\"), FAKE)\n        }\n        _ => text,\n    }\n}\n\n/// Get input length and optionally truncate it\nfn prepare_input<T: TokenizerTrait>(\n    inputs: String,\n    _truncate: Option<usize>,\n    add_special_tokens: bool,\n    tokenizer: &T,\n    config: Option<&Config>,\n    preprocessor_config: Option<&HubPreprocessorConfig>,\n    max_image_fetch_size: usize,\n) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {\n    use Config::*;\n    static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r\"!\\[\\]\\([^\\)]*\\)\").unwrap());\n    let (tokenizer_query, input_chunks) = match config {\n   
     Some(\n            config @ (Idefics | Mllama | Idefics2(_) | Idefics3(_) | Gemma3(_) | Llama4(_)\n            | Paligemma(_) | LlavaNext(_) | Qwen2Vl(_) | Qwen2_5Vl(_)),\n        ) => {\n            let mut input_chunks = Vec::new();\n            let mut tokenizer_query = String::with_capacity(inputs.len());\n            let mut start = 0;\n            for chunk in RE.find_iter(&inputs) {\n                let chunk_start = chunk.start();\n                let chunk_end = chunk.end();\n                if chunk_start != start {\n                    input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));\n                    tokenizer_query.push_str(&inputs[start..chunk_start]);\n                }\n                let (data, mimetype, height, width) =\n                    fetch_image(&inputs[chunk_start..chunk_end], max_image_fetch_size)?;\n                input_chunks.push(Chunk::Image(Image { data, mimetype }));\n                tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));\n                start = chunk_end;\n            }\n            if start != inputs.len() {\n                input_chunks.push(Chunk::Text(inputs[start..].to_string()));\n                tokenizer_query.push_str(&inputs[start..]);\n            }\n\n            tokenizer_query = image_tokens_fixup(config, tokenizer_query);\n\n            (tokenizer_query, input_chunks)\n        }\n        _ => (inputs.clone(), vec![Chunk::Text(inputs)]),\n    };\n\n    // Get the number of tokens in the input\n    let encoding = tokenizer\n        .encode_trait(tokenizer_query, add_special_tokens)\n        .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;\n\n    Ok((encoding, input_chunks))\n}\n\ntype TokenizerRequest = (\n    (String, bool, Option<usize>),\n    oneshot::Sender<Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError>>,\n    Span,\n);\n\n#[derive(Debug, Clone, Eq, PartialEq)]\npub struct Image {\n    pub data: Vec<u8>,\n    
pub mimetype: String,\n}\n\n#[derive(Debug, Clone, Eq, PartialEq)]\npub enum Chunk {\n    Text(String),\n    Image(Image),\n}\n\n/// Convert input chunks to a stringly-typed input for backwards\n/// compat for backends that haven't implemented chunked inputs.\npub trait ChunksToString {\n    /// Convert chunks to string.\n    fn chunks_to_string(&self) -> String;\n}\n\nimpl ChunksToString for Vec<Chunk> {\n    fn chunks_to_string(&self) -> String {\n        let mut output = String::new();\n        self.iter().for_each(|c| match &c {\n            Chunk::Text(text) => output.push_str(text),\n            Chunk::Image(Image { data, mimetype }) => {\n                let encoded = STANDARD.encode(data);\n                output.push_str(&format!(\"![](data:{};base64,{})\", mimetype, encoded))\n            }\n        });\n        output\n    }\n}\n\n#[derive(Debug, Clone)]\npub enum ValidGrammar {\n    Json(String),\n    Regex(String),\n}\n\n#[derive(Debug, Clone)]\npub struct ValidParameters {\n    /// / exponential scaling output probability distribution\n    pub temperature: f32,\n    /// / restricting to the k highest probability elements\n    pub top_k: u32,\n    /// / restricting to top tokens summing to prob_cut_off <= prob_cut_off\n    pub top_p: f32,\n    /// / restricting to top tokens summing to prob_cut_off <= prob_cut_off\n    pub typical_p: f32,\n    /// / apply sampling on the logits\n    pub do_sample: bool,\n    /// / random seed for sampling\n    pub seed: u64,\n    /// / repetition penalty\n    pub repetition_penalty: f32,\n    /// / frequency penalty\n    pub frequency_penalty: f32,\n    /// / token watermarking using \"A Watermark for Large Language Models\"\n    pub watermark: bool,\n    /// / grammar (applied if not empty)\n    pub grammar: Option<ValidGrammar>,\n}\n\n#[derive(Debug, Clone)]\npub struct ValidStoppingParameters {\n    /// / Maximum number of generated tokens\n    pub max_new_tokens: u32,\n    /// Maximum number of generated tokens 
before being re-queued by the system\n    pub max_total_new_tokens: u32,\n    /// / Optional stopping sequences\n    pub stop_sequences: Vec<String>,\n    /// / Ignore end of sequence token\n    /// / used for benchmarking\n    pub ignore_eos_token: bool,\n}\n\n#[derive(Debug, Clone)]\npub struct ValidGenerateRequest {\n    pub inputs: Vec<Chunk>,\n    pub input_ids: Option<Arc<Vec<u32>>>,\n    pub input_length: u32,\n    pub truncate: u32,\n    pub add_special_tokens: bool,\n    pub decoder_input_details: bool,\n    pub parameters: ValidParameters,\n    pub stopping_parameters: ValidStoppingParameters,\n    pub top_n_tokens: u32,\n    pub adapter_id: Option<String>,\n}\n\n#[derive(Error, Debug)]\npub enum ValidationError {\n    #[error(\"`best_of` must be > 0 and <= {0}. Given: {1}\")]\n    BestOf(usize, usize),\n    #[error(\"`best_of` != 1 is not allowed for this endpoint\")]\n    BestOfDisabled,\n    #[error(\"you must use sampling when `best_of` is > 1\")]\n    BestOfSampling,\n    #[error(\"`seed` must not be set when `best_of` > 1\")]\n    BestOfSeed,\n    #[error(\"`best_of` != 1 is not supported when streaming tokens\")]\n    BestOfStream,\n    #[error(\"`top_n_tokens` must be >= 0 and <= {0}. Given: {1}\")]\n    TopNTokens(u32, u32),\n    #[error(\"`top_n_tokens` != 0 is not allowed for this endpoint\")]\n    TopNTokensDisabled,\n    #[error(\"`decoder_input_details` == true is not supported when streaming tokens\")]\n    PrefillDetailsStream,\n    #[error(\"`temperature` must be strictly positive\")]\n    Temperature,\n    #[error(\"`repetition_penalty` must be strictly positive\")]\n    RepetitionPenalty,\n    #[error(\"`frequency_penalty` must be >= -2.0 and <= 2.0\")]\n    FrequencyPenalty,\n    #[error(\"`top_p` must be > 0.0 and < 1.0\")]\n    TopP,\n    #[error(\"`top_k` must be strictly positive\")]\n    TopK,\n    #[error(\"`truncate` must be strictly positive and less than {0}. 
Given: {1}\")]\n    Truncate(usize, usize),\n    #[error(\"`typical_p` must be > 0.0 and < 1.0\")]\n    TypicalP,\n    #[error(\"one of `max_new_tokens` or `truncate` must be set if a fast tokenizer is not in use\")]\n    UnsetMaxNewTokens,\n    #[error(\"`max_new_tokens` must be strictly positive\")]\n    NegativeMaxNewTokens,\n    #[error(\"`max_new_tokens` must be <= {0}. Given: {1}\")]\n    MaxNewTokens(usize, u32),\n    #[error(\"`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`\")]\n    MaxTotalTokens(usize, usize, u32),\n    #[error(\"`inputs` must have less than {0} tokens. Given: {1}\")]\n    InputLength(usize, usize),\n    #[error(\"`inputs` cannot be empty\")]\n    EmptyInput,\n    #[error(\"`stop` supports up to {0} stop sequences. Given: {1}\")]\n    StopSequence(usize, usize),\n    #[error(\"tokenizer error {0}\")]\n    Tokenizer(String),\n    #[error(\"grammar is not supported\")]\n    Grammar,\n    #[error(\"grammar is not valid: {0}\")]\n    InvalidGrammar(String),\n    #[error(\"cannot compile regex from schema: {0}\")]\n    RegexFromSchema(anyhow::Error),\n    #[error(\"base64 encoding is invalid: {0}\")]\n    InvalidBase64(#[from] base64::DecodeError),\n    #[error(\"invalid image: {0}\")]\n    InvalidImage(#[from] image::ImageError),\n    #[error(\"invalid integer: {0}\")]\n    InvalidInt(#[from] core::num::TryFromIntError),\n    #[error(\"invalid image content: {0}\")]\n    InvalidImageContent(String),\n    #[error(\"Could not fetch image: {0}\")]\n    FailedFetchImage(#[from] reqwest::Error),\n    #[error(\"Image size {0} bytes exceeds maximum allowed size of {1} bytes\")]\n    ImageTooLarge(usize, usize),\n    #[error(\"Failed to read image data: {0}\")]\n    ImageReadError(#[from] std::io::Error),\n    #[error(\"{0} modality is not supported\")]\n    UnsupportedModality(&'static str),\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::config::{Idefics2, PaliTextConfig, 
Paligemma};\n    use crate::default_parameters;\n    use crate::tests::get_tokenizer;\n\n    #[tokio::test]\n    async fn test_validation_max_new_tokens() {\n        let tokenizer = get_tokenizer();\n        let max_best_of = 2;\n        let max_stop_sequence = 3;\n        let max_top_n_tokens = 4;\n        let max_input_length = 5;\n        let max_total_tokens = 6;\n        let workers = 1;\n        let disable_grammar_support = true;\n        let config = None;\n        let validation = Validation::new(\n            workers,\n            tokenizer,\n            config,\n            None,\n            max_best_of,\n            max_stop_sequence,\n            max_top_n_tokens,\n            max_input_length,\n            max_total_tokens,\n            disable_grammar_support,\n            1024 * 1024 * 1024, // 1GB\n        );\n\n        let max_new_tokens = 10;\n        match validation\n            .validate_input(\"Hello\".to_string(), true, None, Some(max_new_tokens))\n            .await\n        {\n            Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),\n            // Ok((_s, _, 0, 10)) => (),\n            r => panic!(\"Unexpected not max new tokens: {r:?}\"),\n        }\n    }\n\n    #[tokio::test]\n    async fn test_validation_input_length() {\n        let tokenizer = get_tokenizer();\n        let max_best_of = 2;\n        let max_stop_sequence = 3;\n        let max_top_n_tokens = 4;\n        let max_input_length = 5;\n        let max_total_tokens = 6;\n        let disable_grammar_support = true;\n        let workers = 1;\n        let config = None;\n        let validation = Validation::new(\n            workers,\n            tokenizer,\n            config,\n            None,\n            max_best_of,\n            max_stop_sequence,\n            max_top_n_tokens,\n            max_input_length,\n            max_total_tokens,\n            disable_grammar_support,\n            1024 * 1024 * 1024, // 1GB\n        );\n\n        let max_new_tokens = 
10;\n        match validation\n            .validate_input(\"Hello\".to_string(), true, None, Some(max_new_tokens))\n            .await\n        {\n            Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),\n            _ => panic!(\"Unexpected not max new tokens\"),\n        }\n    }\n\n    #[tokio::test]\n    async fn test_validation_best_of_sampling() {\n        let tokenizer = get_tokenizer();\n        let max_best_of = 2;\n        let max_stop_sequence = 3;\n        let max_top_n_tokens = 4;\n        let max_input_length = 5;\n        let max_total_tokens = 6;\n        let workers = 1;\n        let disable_grammar_support = true;\n        let config = None;\n        let validation = Validation::new(\n            workers,\n            tokenizer,\n            config,\n            None,\n            max_best_of,\n            max_stop_sequence,\n            max_top_n_tokens,\n            max_input_length,\n            max_total_tokens,\n            disable_grammar_support,\n            1024 * 1024 * 1024, // 1GB\n        );\n        match validation\n            .validate(GenerateRequest {\n                inputs: \"Hello\".to_string(),\n                add_special_tokens: true,\n                parameters: GenerateParameters {\n                    best_of: Some(2),\n                    do_sample: false,\n                    ..default_parameters()\n                },\n            })\n            .await\n        {\n            Err(ValidationError::BestOfSampling) => (),\n            _ => panic!(\"Unexpected not best of sampling\"),\n        }\n    }\n\n    #[tokio::test]\n    async fn test_validation_top_p() {\n        let tokenizer = get_tokenizer();\n        let max_best_of = 2;\n        let max_stop_sequence = 3;\n        let max_top_n_tokens = 4;\n        let max_input_length = 5;\n        let max_total_tokens = 106;\n        let workers = 1;\n        let disable_grammar_support = true;\n        let config = None;\n        let validation = 
Validation::new(\n            workers,\n            tokenizer,\n            config,\n            None,\n            max_best_of,\n            max_stop_sequence,\n            max_top_n_tokens,\n            max_input_length,\n            max_total_tokens,\n            disable_grammar_support,\n            1024 * 1024 * 1024, // 1GB\n        );\n        match validation\n            .validate(GenerateRequest {\n                inputs: \"Hello\".to_string(),\n                add_special_tokens: true,\n                parameters: GenerateParameters {\n                    top_p: Some(1.0),\n                    max_new_tokens: Some(5),\n                    ..default_parameters()\n                },\n            })\n            .await\n        {\n            Err(ValidationError::TopP) => (),\n            _ => panic!(\"Unexpected top_p\"),\n        }\n\n        match validation\n            .validate(GenerateRequest {\n                inputs: \"Hello\".to_string(),\n                add_special_tokens: true,\n                parameters: GenerateParameters {\n                    top_p: Some(0.99),\n                    max_new_tokens: Some(5),\n                    ..default_parameters()\n                },\n            })\n            .await\n        {\n            Ok(_) => (),\n            _ => panic!(\"Unexpected top_p error\"),\n        }\n\n        let valid_request = validation\n            .validate(GenerateRequest {\n                inputs: \"Hello\".to_string(),\n                add_special_tokens: true,\n                parameters: GenerateParameters {\n                    top_p: None,\n                    max_new_tokens: Some(5),\n                    ..default_parameters()\n                },\n            })\n            .await\n            .unwrap();\n        // top_p == 1.0 is invalid for users to ask for but it's the default resolved value.\n        assert_eq!(valid_request.parameters.top_p, 1.0);\n    }\n\n    #[tokio::test]\n    async fn 
test_validation_top_n_tokens() {\n        let tokenizer = get_tokenizer();\n        let max_best_of = 2;\n        let max_stop_sequences = 3;\n        let max_top_n_tokens = 4;\n        let max_input_length = 5;\n        let max_total_tokens = 106;\n        let workers = 1;\n        let disable_grammar_support = true;\n        let config = None;\n        let validation = Validation::new(\n            workers,\n            tokenizer,\n            config,\n            None,\n            max_best_of,\n            max_stop_sequences,\n            max_top_n_tokens,\n            max_input_length,\n            max_total_tokens,\n            disable_grammar_support,\n            1024 * 1024 * 1024, // 1GB\n        );\n        match validation\n            .validate(GenerateRequest {\n                inputs: \"Hello\".to_string(),\n                add_special_tokens: true,\n                parameters: GenerateParameters {\n                    top_n_tokens: Some(5),\n                    max_new_tokens: Some(5),\n                    ..default_parameters()\n                },\n            })\n            .await\n        {\n            Err(ValidationError::TopNTokens(4, 5)) => (),\n            _ => panic!(\"Unexpected top_n_tokens\"),\n        }\n\n        validation\n            .validate(GenerateRequest {\n                inputs: \"Hello\".to_string(),\n                add_special_tokens: true,\n                parameters: GenerateParameters {\n                    top_n_tokens: Some(4),\n                    max_new_tokens: Some(5),\n                    ..default_parameters()\n                },\n            })\n            .await\n            .unwrap();\n\n        validation\n            .validate(GenerateRequest {\n                inputs: \"Hello\".to_string(),\n                add_special_tokens: true,\n                parameters: GenerateParameters {\n                    top_n_tokens: Some(0),\n                    max_new_tokens: Some(5),\n                    
..default_parameters()\n                },\n            })\n            .await\n            .unwrap();\n\n        let valid_request = validation\n            .validate(GenerateRequest {\n                inputs: \"Hello\".to_string(),\n                add_special_tokens: true,\n                parameters: GenerateParameters {\n                    top_n_tokens: None,\n                    max_new_tokens: Some(5),\n                    ..default_parameters()\n                },\n            })\n            .await\n            .unwrap();\n\n        assert_eq!(valid_request.top_n_tokens, 0);\n    }\n\n    static PIXEL_GIF: &str = \"R0lGODdhAQABAIEAAP///wAAAAAAAAAAACwAAAAAAQABAAAIBAABBAQAOw==\";\n\n    #[tokio::test]\n    async fn test_prepare_input_chunks() {\n        let pixel_data = STANDARD.decode(PIXEL_GIF).unwrap();\n\n        let tokenizer = get_tokenizer();\n\n        let max_best_of = 2;\n        let max_stop_sequence = 3;\n        let max_top_n_tokens = 4;\n        let max_input_length = 5;\n        let max_total_tokens = 6;\n        let disable_grammar_support = true;\n        let workers = 1;\n        let config = Config::Paligemma(Paligemma {\n            text_config: PaliTextConfig {\n                num_image_tokens: 1,\n            },\n        });\n        let validation = Validation::new(\n            workers,\n            tokenizer,\n            Some(config),\n            None,\n            max_best_of,\n            max_stop_sequence,\n            max_top_n_tokens,\n            max_input_length,\n            max_total_tokens,\n            disable_grammar_support,\n            1024 * 1024 * 1024, // 1GB\n        );\n\n        let chunks = match validation\n            .tokenize(\n                format!(\"test![](data:image/gif;base64,{})\", PIXEL_GIF),\n                true,\n                None,\n            )\n            .await\n        {\n            Ok((_encoding, chunks)) => chunks,\n            _ => panic!(\"Unexpected tokenization failure\"),\n   
     };\n\n        assert!(\n            chunks\n                == vec![\n                    Chunk::Text(\"test\".to_string()),\n                    Chunk::Image(Image {\n                        data: pixel_data.clone(),\n                        mimetype: \"image/gif\".to_string()\n                    })\n                ],\n            \"Failed to process images\",\n        );\n    }\n\n    #[tokio::test]\n    async fn test_idefics2_correct_n_fake_tokens() {\n        let pixel_data = STANDARD.decode(PIXEL_GIF).unwrap();\n\n        let tokenizer = get_tokenizer();\n\n        let max_best_of = 2;\n        let max_stop_sequence = 3;\n        let max_top_n_tokens = 4;\n        let max_input_length = 5;\n        let max_total_tokens = 6;\n        let disable_grammar_support = true;\n        let workers = 1;\n        let config = Config::Idefics2(Idefics2 {});\n        let validation = Validation::new(\n            workers,\n            tokenizer,\n            Some(config),\n            Some(HubPreprocessorConfig::Idefics2Processor(\n                Idefics2Preprocessor {\n                    do_image_splitting: true,\n                },\n            )),\n            max_best_of,\n            max_stop_sequence,\n            max_top_n_tokens,\n            max_input_length,\n            max_total_tokens,\n            disable_grammar_support,\n            1024 * 1024 * 1024, // 1GB\n        );\n\n        let (encoding, chunks) = match validation\n            .tokenize(\n                format!(\n                    \"test![](data:image/gif;base64,{})![](data:image/gif;base64,{})\",\n                    PIXEL_GIF, PIXEL_GIF\n                ),\n                true,\n                None,\n            )\n            .await\n        {\n            Ok((encoding, chunks)) => (encoding, chunks),\n            _ => panic!(\"Unexpected tokenization failure\"),\n        };\n\n        assert!(\n            chunks\n                == vec![\n                    
Chunk::Text(\"test\".to_string()),\n                    Chunk::Image(Image {\n                        data: pixel_data.clone(),\n                        mimetype: \"image/gif\".to_string()\n                    }),\n                    Chunk::Image(Image {\n                        data: pixel_data.clone(),\n                        mimetype: \"image/gif\".to_string()\n                    })\n                ],\n            \"Failed to process images\",\n        );\n\n        // Verify the number of fake tokens:\n        //\n        // - Two images surrounded/separated by a fake token = 3.\n        // - Both are split in 5 subimages, separated by a fake token: 2 * 4\n        //\n        // Fake tokens get split up by the testing tokenizer, but we don't care.\n        assert_eq!(\n            encoding\n                .get_tokens()\n                .iter()\n                .filter(|t| *t == \"fake\")\n                .count(),\n            11\n        );\n    }\n}\n"
  },
  {
    "path": "router/src/vertex.rs",
    "content": "use crate::infer::Infer;\nuse crate::server::{generate_internal, ComputeType};\nuse crate::{ChatRequest, ErrorResponse, GenerateParameters, GenerateRequest};\nuse axum::extract::Extension;\nuse axum::http::{HeaderMap, StatusCode};\nuse axum::response::{IntoResponse, Response};\nuse axum::Json;\nuse serde::{Deserialize, Serialize};\nuse tracing::instrument;\nuse tracing_opentelemetry::OpenTelemetrySpanExt;\nuse utoipa::ToSchema;\n\n#[derive(Clone, Deserialize, ToSchema)]\n#[cfg_attr(test, derive(Debug, PartialEq))]\npub(crate) struct GenerateVertexInstance {\n    #[schema(example = \"What is Deep Learning?\")]\n    pub inputs: String,\n    #[schema(nullable = true, default = \"null\", example = \"null\")]\n    pub parameters: Option<GenerateParameters>,\n}\n\n#[derive(Clone, Deserialize, ToSchema)]\n#[cfg_attr(test, derive(Debug, PartialEq))]\n#[serde(untagged)]\npub(crate) enum VertexInstance {\n    Generate(GenerateVertexInstance),\n    Chat(ChatRequest),\n}\n\n#[derive(Deserialize, ToSchema)]\n#[cfg_attr(test, derive(Debug, PartialEq))]\npub(crate) struct VertexRequest {\n    #[serde(rename = \"instances\")]\n    pub instances: Vec<VertexInstance>,\n}\n\n#[derive(Clone, Deserialize, ToSchema, Serialize)]\npub(crate) struct VertexResponse {\n    pub predictions: Vec<String>,\n}\n\n/// Generate tokens from Vertex request\n#[utoipa::path(\npost,\ntag = \"Text Generation Inference\",\npath = \"/vertex\",\nrequest_body = VertexRequest,\nresponses(\n(status = 200, description = \"Generated Text\", body = VertexResponse),\n(status = 424, description = \"Generation Error\", body = ErrorResponse,\nexample = json ! ({\"error\": \"Request failed during generation\"})),\n(status = 429, description = \"Model is overloaded\", body = ErrorResponse,\nexample = json ! ({\"error\": \"Model is overloaded\"})),\n(status = 422, description = \"Input validation error\", body = ErrorResponse,\nexample = json ! 
({\"error\": \"Input validation error\"})),\n(status = 500, description = \"Incomplete generation\", body = ErrorResponse,\nexample = json ! ({\"error\": \"Incomplete generation\"})),\n)\n)]\n#[instrument(\n    skip_all,\n    fields(\n        total_time,\n        validation_time,\n        queue_time,\n        inference_time,\n        time_per_token,\n        seed,\n    )\n)]\npub(crate) async fn vertex_compatibility(\n    Extension(infer): Extension<Infer>,\n    Extension(compute_type): Extension<ComputeType>,\n    Extension(context): Extension<Option<opentelemetry::Context>>,\n    Json(req): Json<VertexRequest>,\n) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {\n    let span = tracing::Span::current();\n    if let Some(context) = context {\n        span.set_parent(context);\n    }\n\n    metrics::counter!(\"tgi_request_count\").increment(1);\n\n    // check that theres at least one instance\n    if req.instances.is_empty() {\n        return Err((\n            StatusCode::UNPROCESSABLE_ENTITY,\n            Json(ErrorResponse {\n                error: \"Input validation error\".to_string(),\n                error_type: \"Input validation error\".to_string(),\n            }),\n        ));\n    }\n\n    // Prepare futures for all instances\n    let mut futures = Vec::with_capacity(req.instances.len());\n\n    for instance in req.instances.into_iter() {\n        let generate_request = match instance {\n            VertexInstance::Generate(instance) => GenerateRequest {\n                inputs: instance.inputs.clone(),\n                add_special_tokens: true,\n                parameters: GenerateParameters {\n                    do_sample: true,\n                    max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),\n                    seed: instance.parameters.as_ref().and_then(|p| p.seed),\n                    details: true,\n                    decoder_input_details: true,\n                    ..Default::default()\n              
  },\n            },\n            VertexInstance::Chat(instance) => {\n                let (generate_request, _using_tools): (GenerateRequest, bool) =\n                    instance.try_into_generate(&infer)?;\n                generate_request\n            }\n        };\n\n        let infer_clone = infer.clone();\n        let compute_type_clone = compute_type.clone();\n        let span_clone = span.clone();\n\n        futures.push(async move {\n            generate_internal(\n                Extension(infer_clone),\n                compute_type_clone,\n                Json(generate_request),\n                span_clone,\n            )\n            .await\n            .map(|(_, _, Json(generation))| generation.generated_text)\n            .map_err(|_| {\n                (\n                    StatusCode::INTERNAL_SERVER_ERROR,\n                    Json(ErrorResponse {\n                        error: \"Incomplete generation\".into(),\n                        error_type: \"Incomplete generation\".into(),\n                    }),\n                )\n            })\n        });\n    }\n\n    // execute all futures in parallel, collect results, returning early if any error occurs\n    let results = futures::future::join_all(futures).await;\n    let predictions: Result<Vec<_>, _> = results.into_iter().collect();\n    let predictions = predictions?;\n\n    let response = VertexResponse { predictions };\n    Ok((HeaderMap::new(), Json(response)).into_response())\n}\n\n#[cfg(test)]\nmod tests {\n    use super::*;\n    use crate::{Message, MessageBody, MessageContent};\n\n    #[test]\n    fn vertex_deserialization() {\n        let string = serde_json::json!({\n\n        \"instances\": [\n            {\n                \"messages\": [{\"role\": \"user\", \"content\": \"What's Deep Learning?\"}],\n                \"max_tokens\": 128,\n                \"top_p\": 0.95,\n                \"temperature\": 0.7\n            }\n        ]\n\n        });\n        let request: 
VertexRequest = serde_json::from_value(string).expect(\"Can deserialize\");\n        assert_eq!(\n            request,\n            VertexRequest {\n                instances: vec![VertexInstance::Chat(ChatRequest {\n                    messages: vec![Message {\n                        name: None,\n                        role: \"user\".to_string(),\n                        body: MessageBody::Content {\n                            content: MessageContent::SingleText(\n                                \"What's Deep Learning?\".to_string()\n                            )\n                        },\n                    },],\n                    max_tokens: Some(128),\n                    top_p: Some(0.95),\n                    temperature: Some(0.7),\n                    ..Default::default()\n                })]\n            }\n        );\n    }\n}\n"
  },
  {
    "path": "rust-toolchain.toml",
    "content": "[toolchain]\n# Released on: 30 January, 2025\n# https://releases.rs/docs/1.84.1/\nchannel = \"1.85.1\"\ncomponents = [\"rustfmt\", \"clippy\"]\n"
  },
  {
    "path": "sagemaker-entrypoint.sh",
    "content": "#!/bin/bash\n\nif [[ -z \"${HF_MODEL_ID}\" ]]; then\n  echo \"HF_MODEL_ID must be set\"\n  exit 1\nfi\nexport MODEL_ID=\"${HF_MODEL_ID}\"\n\nif [[ -n \"${HF_MODEL_REVISION}\" ]]; then\n  export REVISION=\"${HF_MODEL_REVISION}\"\nfi\n\nif [[ -n \"${SM_NUM_GPUS}\" ]]; then\n  export NUM_SHARD=\"${SM_NUM_GPUS}\"\nfi\n\nif [[ -n \"${HF_MODEL_QUANTIZE}\" ]]; then\n  export QUANTIZE=\"${HF_MODEL_QUANTIZE}\"\nfi\n\nif [[ -n \"${HF_MODEL_TRUST_REMOTE_CODE}\" ]]; then\n  export TRUST_REMOTE_CODE=\"${HF_MODEL_TRUST_REMOTE_CODE}\"\nfi\n\ntext-generation-launcher --port 8080\n"
  },
  {
    "path": "server/.gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\ntext_generation_server/__pycache__/\ntext_generation_server/pb/__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#   
https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#pdm.lock\n#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it\n#   in version control.\n#   https://pdm.fming.dev/#use-with-ide\n.pdm.toml\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\ntransformers\nsafetensors\nflash-attention/\nflash-attention-v2/\nvllm/\nllm-awq/\neetq/\nmamba/\n"
  },
  {
    "path": "server/Makefile",
    "content": "include Makefile-flash-att\ninclude Makefile-flash-att-v2\ninclude Makefile-vllm\ninclude Makefile-awq\ninclude Makefile-selective-scan\ninclude Makefile-exllamav2\ninclude Makefile-flashinfer\n\nunit-tests:\n\tpip install -U pip uv\n\tuv pip install -e \".[dev]\"\n\tuv sync --inexact --extra dev --active\n\tpytest -s -vv -m \"not private\" tests\n\ngen-server:\n\t# Compile protos\n\tpip install -U pip uv\n\tuv pip install -r requirements_gen.txt\n\tmkdir text_generation_server/pb || true\n\tpython -m grpc_tools.protoc -I../proto/v3 --python_out=text_generation_server/pb \\\n\t\t--grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb ../proto/v3/generate.proto\n\tfind text_generation_server/pb/ -type f -name \"*.py\" -print0 -exec sed -i -e 's/^\\(import.*pb2\\)/from . \\1/g' {} \\;\n\ttouch text_generation_server/pb/__init__.py\n\ngen-server-raw:\n\tmkdir text_generation_server/pb || true\n\tpython -m grpc_tools.protoc -I../proto/v3 --python_out=text_generation_server/pb \\\n\t\t--grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb ../proto/v3/generate.proto\n\tfind text_generation_server/pb/ -type f -name \"*.py\" -print0 -exec sed -i -e 's/^\\(import.*pb2\\)/from . 
\\1/g' {} \\;\n\ttouch text_generation_server/pb/__init__.py\n\ninstall-server: gen-server\n\tuv sync --inexact --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines --active\n\n\ninstall: install-cuda\n\techo \"Installed server\"\n\ninstall-cuda: install-server install-flash-attention-v2-cuda install-flash-attention\n\tuv sync --inexact --extra attention --extra bnb --active\n\tuv pip install nvidia-nccl-cu12==2.22.3\n\tkernels download .\n\ninstall-rocm: install-server install-flash-attention-v2-rocm  install-vllm-rocm\n\nexport-requirements:\n\tuv pip compile pyproject.toml --extra gen -o requirements_gen.txt --python-version 3.11\n\tuv pip compile pyproject.toml --extra bnb --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines -o requirements_cuda.txt --python-version 3.11\n\tuv pip compile pyproject.toml --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines -o requirements_intel.txt --python-version 3.11\n\tuv pip compile pyproject.toml --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines -o requirements_rocm.txt --python-version 3.11\n"
  },
  {
    "path": "server/Makefile-awq",
    "content": "# Fork that adds only the correct stream to this kernel in order\n# to make cuda graphs work.\nawq_commit := bd1dc2d5254345cc76ab71894651fb821275bdd4\n\nawq:\n\trm -rf llm-awq\n\tgit clone https://github.com/huggingface/llm-awq\n\nbuild-awq: awq\n\tcd llm-awq/ && git fetch && git checkout $(awq_commit)\n\tcd llm-awq/awq/kernels && python setup.py build\n\ninstall-awq: build-awq\n\tpip uninstall awq_inference_engine -y || true\n\tcd llm-awq/awq/kernels && python setup.py install\n"
  },
  {
    "path": "server/Makefile-eetq",
    "content": "eetq_commit := 465e9726bf7ae30803a2d0dd9e5d4315aef17491\n\neetq:\n    # Clone eetq\n\tpip install packaging\n\tgit clone https://github.com/NetEase-FuXi/EETQ.git eetq\n\nbuild-eetq: eetq\n\tcd eetq && git fetch && git checkout $(eetq_commit) && git submodule update --init --recursive\n\tcd eetq && python setup.py build\n\ninstall-eetq: build-eetq\n\tcd eetq && python setup.py install\n"
  },
  {
    "path": "server/Makefile-exllamav2",
    "content": "exllamav2_commit := v0.1.8\n\nbuild-exllamav2:\n\tgit clone https://github.com/turboderp/exllamav2.git exllamav2 && \\\n\tcd exllamav2 && git fetch && git checkout $(exllamav2_commit)  && \\\n\tgit submodule update --init --recursive && \\\n\tpip install -r requirements.txt && \\\n\tCUDA_ARCH_LIST=\"8.0;9.0a\" NVCC_GENCODE=\"-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a\" TORCH_CUDA_ARCH_LIST=\"8.0;9.0a\" python setup.py build\n\ninstall-exllamav2: build-exllamav2\n\tcd exllamav2/ &&  \\\n\tCUDA_ARCH_LIST=\"8.0;9.0a\" NVCC_GENCODE=\"-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a\" TORCH_CUDA_ARCH_LIST=\"8.0;9.0a\" python setup.py install\n"
  },
  {
    "path": "server/Makefile-flash-att",
    "content": "flash_att_commit := ceee0de88c037ee6eda5e75c813a8648e4bcb1c9\n\nbuild-flash-attention:\n\tif [ ! -d 'flash-attention' ]; then \\\n\t\tpip install -U packaging ninja  --no-cache-dir && \\\n\t\tgit clone https://github.com/Narsil/flash-attention.git; \\\n\tfi\n\tcd flash-attention && git fetch && git checkout $(flash_att_commit) && \\\n\tMAX_JOBS=8 python setup.py build && cd csrc/layer_norm && python setup.py build && cd ../rotary && python setup.py build\n\ninstall-flash-attention: build-flash-attention\n\tcd flash-attention && git checkout $(flash_att_commit) && MAX_JOBS=8 python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install\n"
  },
  {
    "path": "server/Makefile-flash-att-v2",
    "content": "flash_att_v2_commit_cuda := v2.6.1\nflash_att_v2_commit_rocm := 47bd46e0204a95762ae48712fd1a3978827c77fd\n\nbuild-flash-attention-v2-cuda:\n\tpip install -U packaging wheel\n\tpip install --no-build-isolation flash-attn==$(flash_att_v2_commit_cuda)\n\ninstall-flash-attention-v2-cuda: build-flash-attention-v2-cuda\n\techo \"Flash v2 installed\"\n\nbuild-flash-attention-v2-rocm:\n\tif [ ! -d 'flash-attention-v2' ]; then \\\n\t\tpip install -U packaging ninja  --no-cache-dir && \\\n\t\tgit clone https://github.com/mht-sharma/flash-attention.git flash-attention-v2 && \\\n\t\tcd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm) && \\\n\t\tgit submodule update --init --recursive && GPU_ARCHS=\"gfx90a;gfx942\" PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\" python setup.py build; \\\n\tfi\n\ninstall-flash-attention-v2-rocm: build-flash-attention-v2-rocm\n\tcd flash-attention-v2 &&  \\\n\tGPU_ARCHS=\"gfx90a;gfx942\" PYTORCH_ROCM_ARCH=\"gfx90a;gfx942\" python setup.py install\n"
  },
  {
    "path": "server/Makefile-flashinfer",
    "content": "install-flashinfer:\n\t# We need fsspec as an additional dependency, but\n\t# `pip install flashinfer` cannot resolve it.\n\tuv pip install fsspec sympy==1.13.1 numpy\n\tuv pip install -U setuptools\n\tTORCH_CUDA_ARCH_LIST=\"7.5;8.0;8.6;8.9;9.0+PTX\" FLASHINFER_ENABLE_AOT=1 pip install git+https://github.com/flashinfer-ai/flashinfer.git@v0.2.0.post2#egg=flashinfer-python  --no-build-isolation\n"
  },
  {
    "path": "server/Makefile-selective-scan",
    "content": "selective_scan_commit := 2a3704fd47ba817b415627b06fd796b971fdc137\n\ncausal-conv1d:\n\trm -rf causal-conv1d\n\tgit clone https://github.com/Dao-AILab/causal-conv1d.git\n\nbuild-causal-conv1d: causal-conv1d\n\tcd causal-conv1d/ && git checkout v1.1.1 # known latest working version tag\n\tcd causal-conv1d/ && CAUSAL_CONV1D_FORCE_BUILD=TRUE python setup.py build\n\ninstall-causal-conv1d: build-causal-conv1d\n\tpip uninstall causal-conv1d -y || true\n\tcd causal-conv1d/ && pip install .\n\n# selective-scan depends on causal-conv1d\nselective-scan:\n\trm -rf mamba\n\tgit clone https://github.com/state-spaces/mamba.git mamba\n\nbuild-selective-scan: selective-scan\n\tcd mamba/ && git fetch && git checkout $(selective_scan_commit)\n\tcd mamba && python setup.py build\n\ninstall-selective-scan: install-causal-conv1d build-selective-scan\n\tpip uninstall selective-scan-cuda -y || true\n\tcd mamba && pip install .\n\nbuild-all: build-causal-conv1d build-selective-scan\n"
  },
  {
    "path": "server/Makefile-vllm",
    "content": "commit_rocm := de990cd12537f78f74e40b5c8ee1a62d63d734dd\n\nbuild-vllm-rocm:\n\tif [ ! -d 'vllm' ]; then \\\n\t\tpip install -U ninja packaging --no-cache-dir && \\\n\t\tgit clone https://github.com/mht-sharma/vllm.git vllm; \\\n\tfi\n\tcd vllm && git fetch && git checkout $(commit_rocm) &&  \\\n\tPYTORCH_ROCM_ARCH=\"gfx90a;gfx942\" python3 setup.py bdist_wheel --dist-dir=dist\n\ninstall-vllm-rocm: build-vllm-rocm\n\tcd vllm && git fetch && git checkout $(commit_rocm)\n"
  },
  {
    "path": "server/README.md",
    "content": "# Text Generation Inference Python gRPC Server\n\nA Python gRPC server for Text Generation Inference\n\n## Install\n\n```shell\nmake install\n```\n\n## Run\n\n```shell\nmake run-dev\n```\n"
  },
  {
    "path": "server/bounds-from-nix.py",
    "content": "#!/usr/bin/env python3\n\nimport json\nimport subprocess\nfrom typing import Dict, Union\nimport toml\n\n# Special cases that have download URLs.\nSKIP = {\"attention-kernels\", \"marlin-kernels\", \"moe-kernels\"}\n\n\ndef is_optional(info: Union[str, Dict[str, str]]) -> bool:\n    return isinstance(info, dict) and \"optional\" in info and info[\"optional\"]\n\n\nif __name__ == \"__main__\":\n    with open(\"pyproject.toml\") as f:\n        pyproject = toml.load(f)\n\n    nix_packages = json.loads(\n        subprocess.run(\n            [\"nix\", \"develop\", \".#server\", \"--command\", \"pip\", \"list\", \"--format=json\"],\n            stdout=subprocess.PIPE,\n        ).stdout\n    )\n\n    nix_packages = {pkg[\"name\"]: pkg[\"version\"] for pkg in nix_packages}\n\n    packages = []\n    optional_packages = []\n\n    for package, info in pyproject[\"tool\"][\"poetry\"][\"dependencies\"].items():\n        if package in nix_packages and package not in SKIP:\n            if is_optional(info):\n                optional_packages.append(f'\"{package}@^{nix_packages[package]}\"')\n            else:\n                packages.append(f'\"{package}@^{nix_packages[package]}\"')\n\n    print(f\"poetry add {' '.join(packages)}\")\n    print(f\"poetry add --optional {' '.join(optional_packages)}\")\n"
  },
  {
    "path": "server/custom_kernels/custom_kernels/fused_attention_cuda.cu",
    "content": "#include <ATen/Dispatch.h>\n#include <THC/THCAtomics.cuh>\n#include <ATen/ATen.h>\n#include <torch/torch.h>\n#include <vector>\n\n#include <optional>\n\n/**\n* Friendly reminder of how multithreading works in CUDA: https://developer.nvidia.com/blog/even-easier-introduction-cuda\n* Check example at https://github.com/thomasw21/LinearTransformers/blob/main/model/attention/fast_weight/fast_weight_cuda.cu\n**/\n\n// Available in pytorch main\n//#define DISPATCH_CASE_FLOATING_TYPES(...) \\\n//  at::AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \\\n//  at::AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \\\n//  at::AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \\\n//  at::AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \\\n\n/*\n* Forward passes\n*/\n\n/**\n* cast to fp32 if in fp16 + mask + softmax computation in fp32 + cast back to original dtype\n**/\ntemplate<typename attention_scores_scalar, int64_t min_kv_length_shard_size_per_thread>\n__global__ void forward_masked_softmax_kernel(\n    const torch::PackedTensorAccessor32<attention_scores_scalar, 2, torch::RestrictPtrTraits> attention_scores, // [B, KV]\n    const torch::PackedTensorAccessor32<bool, 2, torch::RestrictPtrTraits> mask, // [B, KV]\n    torch::PackedTensorAccessor32<attention_scores_scalar, 2, torch::RestrictPtrTraits> result, // [B, KV]\n    const int64_t effective_kv_length,\n    const dim3 blockDim,\n    const int64_t rows_per_block,\n    const int64_t kv_length,\n    const int64_t batch_size\n) {\n    const auto row_id = threadIdx.x / effective_kv_length;\n    const auto effective_kv_length_id = threadIdx.x % effective_kv_length;\n    const auto kv_length_start = effective_kv_length_id * min_kv_length_shard_size_per_thread;\n    auto kv_length_end_ = (effective_kv_length_id + 1) * min_kv_length_shard_size_per_thread;\n    kv_length_end_ = (kv_length_end_ > kv_length) ? 
kv_length : kv_length_end_;\n    const auto kv_length_end = kv_length_end_;\n\n    const auto batch_id = blockIdx.x * rows_per_block + row_id;\n\n    // We need 2 float storage for each row, one for max computation, the other for normalizing exponential\n    extern __shared__ float temp_storage[];\n    const auto row_id_mem_offset = row_id * 2;\n    if (effective_kv_length_id == 0) {\n        temp_storage[row_id_mem_offset] = -std::numeric_limits<float>::infinity();\n        temp_storage[row_id_mem_offset + 1] = 0;\n    }\n    __syncthreads();\n\n    // Compute mask and max\n    if (batch_id < batch_size) {\n        float thread_max = -std::numeric_limits<float>::infinity();\n        for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {\n            if (mask[batch_id][kv_length_id] == 0) {\n                const float candidate = attention_scores[batch_id][kv_length_id];\n                thread_max = (thread_max < candidate) ? candidate : thread_max;\n            }\n        }\n        if (thread_max != -std::numeric_limits<float>::infinity()) {\n            // TODO @thomasw21 with more memory we can probably compute a much faster `max-reduce` in parallel O(ln(n)) operations in each memory slot\n            gpuAtomicMax(&temp_storage[row_id_mem_offset], thread_max);\n        }\n    }\n\n    __syncthreads();\n\n    // Compute exp(elt - max) masked\n    float exponential[min_kv_length_shard_size_per_thread];\n    if (batch_id < batch_size) {\n        float thread_add = 0;\n        for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {\n            if (mask[batch_id][kv_length_id] == 0) {\n                exponential[kv_length_id - kv_length_start] = std::exp(static_cast<float>(attention_scores[batch_id][kv_length_id]) - temp_storage[row_id_mem_offset]);\n                thread_add = thread_add + exponential[kv_length_id - kv_length_start];\n            } else {\n                
exponential[kv_length_id - kv_length_start] = 0.;\n            }\n        }\n        if (thread_add > 0) {\n            // TODO @thomasw21 with more memory we can probably compute a much faster `sum-reduce` in parallel O(ln(n)) operations in each memory slot\n            gpuAtomicAdd(&temp_storage[row_id_mem_offset + 1], thread_add);\n        }\n    }\n\n    __syncthreads();\n\n    // Compute softmax\n    if (batch_id < batch_size) {\n        // If sum of all exponential is 0, we set the softmax values to 0\n        if (temp_storage[row_id_mem_offset + 1] == 0.) {\n            for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {\n                result[batch_id][kv_length_id] = 0.;\n            }\n        } else {\n            for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {\n                result[batch_id][kv_length_id] = static_cast<attention_scores_scalar>(exponential[kv_length_id - kv_length_start] / temp_storage[row_id_mem_offset + 1]);\n            }\n        }\n    }\n}\n\n#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\nstd::tuple<at::Tensor, std::optional<std::vector<at::Tensor>>, at::Tensor> forward(\n    const at::Tensor query,\n    const at::Tensor key,\n    const at::Tensor value,\n    const std::optional<std::vector<at::Tensor>> layer_past,\n    const at::Tensor attention_mask,\n    const std::optional<at::Tensor> head_mask,\n    const float inv_norm_factor,\n    const int num_heads,\n    const bool use_cache\n) {\n    auto query_layer = query;\n    auto key_layer = key;\n    auto value_layer = value;\n\n     if (layer_past) {\n        const auto past_key = (*layer_past).at(0);\n        const auto past_value = (*layer_past).at(1);\n        key_layer = at::cat({past_key, key_layer}, 2);\n     
   value_layer = at::cat({past_value, value_layer}, 2);\n    }\n\n    std::optional<std::vector<at::Tensor>> present;\n    if (use_cache) {\n        present = {key_layer, value_layer};\n    } else {\n        present = {};\n    }\n\n    const auto batch_size = query_layer.size(0);\n    const auto q_length = query_layer.size(2);\n    const auto attn_head_size = query_layer.size(3);\n    const auto batch_size_times_num_heads = batch_size * num_heads;\n    const auto kv_length = key_layer.size(2);\n\n    const auto query_view = query_layer.reshape({batch_size_times_num_heads, q_length, attn_head_size});\n    auto key_view = key_layer.reshape({batch_size_times_num_heads, kv_length, attn_head_size}).transpose(1, 2);\n    auto value_view = value_layer.reshape({batch_size_times_num_heads, kv_length, attn_head_size});\n\n    auto query_scaled = query_view * inv_norm_factor;\n    auto attention_scores = at::bmm(query_scaled, key_view);\n\n    // Computing `optionally_cast_fp16_to_fp32 + masked_fill + softmax + cast_to_intial_dtype`\n    at::Tensor attention_probs;\n    if (true) {\n        // TODO @thomasw21: it's easier to think of attention_scores as 2D tensors\n        const auto attention_scores_2d = attention_scores.view({batch_size_times_num_heads * q_length, kv_length});\n        const auto attention_mask_2d = attention_mask.view({batch_size_times_num_heads * q_length, kv_length});\n\n        // Custom kernel\n        attention_probs = at::empty_like(attention_scores_2d);\n\n        // Check that inputs and contiguous + cuda tensors\n        CHECK_INPUT(attention_scores_2d);\n        CHECK_INPUT(attention_mask_2d);\n\n        // TODO @thomas21: change by to this as it's cleaner when pytorch 1.13 comes out\n        // DISPATCH_CASE_FLOATING_TYPES(attention_scores.scalar_type(), \"masked_softmax\", [&] {\n        AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, attention_scores.scalar_type(), \"masked_softmax\", [&] {\n            /*\n     
       * Understanding how GPUs work: https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/\n            * A100 specifications: https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf\n            *  - SMs: 108\n            *  - TPCs: 56 (What's that?)\n            *  - Memory size: 40 GB\n            *  - L2 Cache size: 40960 KB (shared across all SMs)\n            *  - L1/Shared memory size: 192 KB (shared across all threads within a SM)\n            *  - Max Threads / SM: 2048\n            *  - Max Thread Blocks / SM: 32\n            */\n\n            /*\n            * We should split [batch_size_times_num_heads_block, q_length] in seperate blocks and [batch_size_times_num_heads_block_size, kv_length] a single block\n            * with multiple threads as we need to `sync_threads` to run exponential sum.\n            * We maximise the usage of threads within a single block\n            */\n            // TODO @thomasw21 figure out everything warp related:\n            //  - why do they have to be power of 2\n            // TODO @thomas21 check why everyone is setting 1024 when officially it's 2048\n            const auto MAX_THREADS_PER_SM = 1024;\n            // TODO @thomasw21 figure out how to have longer sequences, currently the maximum is `max_kv_length = MAX_THREADS_PER_SM * MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD`\n            const auto MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD = 4;\n            // `effective_kv_length = ceil(kv_length / MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD)`\n            const auto effective_kv_length = (kv_length - 1)/ MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD + 1;\n            const auto rows_per_block = MAX_THREADS_PER_SM / effective_kv_length;\n            const auto num_blocks = (batch_size_times_num_heads * q_length - 1) / rows_per_block + 1;\n\n            const dim3 gridDim(num_blocks); // Number of blocks that run\n            const dim3 blockDim(MAX_THREADS_PER_SM); // Number 
of threads that run per block\n            const int shared_mem_forward = rows_per_block * 2 * sizeof(float);\n\n            // 192 * 2 ** 10\n            // const auto MAX_L1_MEMORY = 196608;\n            // const auto MAX_SMs = 108;\n            // TORCH_CHECK(batch_size_times_num_heads * q_length <= MAX_L1_MEMORY, \"Shared memory exceeds 192KB limitation.\");\n            // TORCH_CHECK(gridDim.x * gridDim.y * gridDim.z <= MAX_SMs, \"A100s only have 108 SMs. Raising as require blocks is bigger.\");\n            // TORCH_CHECK(blockDim.x * blockDim.y * blockDim.z <= MAX_THREADS_PER_SM, \"A100s only have 2048 threads per block. Raising as require requested threads is higher.\");\n\n            forward_masked_softmax_kernel<scalar_t, MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD><<<gridDim, blockDim, shared_mem_forward>>>(\n                attention_scores_2d.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),\n                attention_mask_2d.packed_accessor32<bool, 2, torch::RestrictPtrTraits>(),\n                attention_probs.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),\n                effective_kv_length,\n                blockDim,\n                rows_per_block,\n                kv_length,\n                batch_size_times_num_heads * q_length\n            );\n        });\n        attention_probs = attention_probs.view({batch_size_times_num_heads, q_length, kv_length});\n    } else {\n        // Pytorch C++ API\n        auto input_dtype = attention_scores.scalar_type();\n        if (input_dtype == at::ScalarType::Float) {\n            attention_scores = attention_scores.to(at::ScalarType::Float);\n        };\n        // TODO @thomasw21 Figure out how to get minimum value\n        auto attn_weights = attention_scores.masked_fill_(attention_mask, -1e34);\n        attention_probs = attn_weights.softmax(-1, at::ScalarType::Float).to(input_dtype);\n    }\n\n    auto context_layer = attention_probs.bmm(value_view);\n\n    // `_merge_heads`\n    
context_layer = context_layer.view({batch_size, num_heads, q_length, attn_head_size});\n    context_layer = context_layer.permute({0, 2, 1, 3});\n    context_layer = context_layer.reshape({batch_size, q_length, attn_head_size * num_heads});\n\n    return std::make_tuple(context_layer, present, attention_probs);\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.def(\n        \"forward\",\n        &forward,\n        \"GPT-Neox attention mechanism forward (CUDA)\"\n    );\n}\n"
  },
  {
    "path": "server/custom_kernels/custom_kernels/fused_bloom_attention_cuda.cu",
    "content": "#include <ATen/Dispatch.h>\n#include <THC/THCAtomics.cuh>\n#include <ATen/ATen.h>\n#include <torch/torch.h>\n#include <vector>\n\n#include <optional>\n\n/**\n* Friendly reminder of how multithreading works in CUDA: https://developer.nvidia.com/blog/even-easier-introduction-cuda\n* Check example at https://github.com/thomasw21/LinearTransformers/blob/main/model/attention/fast_weight/fast_weight_cuda.cu\n**/\n\n// Available in pytorch main\n//#define DISPATCH_CASE_FLOATING_TYPES(...) \\\n//  at::AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \\\n//  at::AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \\\n//  at::AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \\\n//  at::AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \\\n\n/*\n* Forward passes\n*/\n\n/**\n* cast to fp32 if in fp16 + mask + softmax computation in fp32 + cast back to original dtype\n**/\ntemplate<typename attention_scores_scalar, int64_t min_kv_length_shard_size_per_thread>\n__global__ void forward_masked_softmax_kernel(\n    const torch::PackedTensorAccessor32<attention_scores_scalar, 2, torch::RestrictPtrTraits> attention_scores, // [B, KV]\n    const torch::PackedTensorAccessor32<bool, 2, torch::RestrictPtrTraits> mask, // [B, KV]\n    torch::PackedTensorAccessor32<attention_scores_scalar, 2, torch::RestrictPtrTraits> result, // [B, KV]\n    const int64_t effective_kv_length,\n    const dim3 blockDim,\n    const int64_t rows_per_block,\n    const int64_t kv_length,\n    const int64_t batch_size\n) {\n    const auto row_id = threadIdx.x / effective_kv_length;\n    const auto effective_kv_length_id = threadIdx.x % effective_kv_length;\n    const auto kv_length_start = effective_kv_length_id * min_kv_length_shard_size_per_thread;\n    auto kv_length_end_ = (effective_kv_length_id + 1) * min_kv_length_shard_size_per_thread;\n    kv_length_end_ = (kv_length_end_ > kv_length) ? 
kv_length : kv_length_end_;\n    const auto kv_length_end = kv_length_end_;\n\n    const auto batch_id = blockIdx.x * rows_per_block + row_id;\n\n    // We need 2 float storage for each row, one for max computation, the other for normalizing exponential\n    extern __shared__ float temp_storage[];\n    const auto row_id_mem_offset = row_id * 2;\n    if (effective_kv_length_id == 0) {\n        temp_storage[row_id_mem_offset] = -std::numeric_limits<float>::infinity();\n        temp_storage[row_id_mem_offset + 1] = 0;\n    }\n    __syncthreads();\n\n    // Compute mask and max\n    if (batch_id < batch_size) {\n        float thread_max = -std::numeric_limits<float>::infinity();\n        for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {\n            if (mask[batch_id][kv_length_id] == 0) {\n                const float candidate = attention_scores[batch_id][kv_length_id];\n                thread_max = (thread_max < candidate) ? candidate : thread_max;\n            }\n        }\n        if (thread_max != -std::numeric_limits<float>::infinity()) {\n            // TODO @thomasw21 with more memory we can probably compute a much faster `max-reduce` in parallel O(ln(n)) operations in each memory slot\n            gpuAtomicMax(&temp_storage[row_id_mem_offset], thread_max);\n        }\n    }\n\n    __syncthreads();\n\n    // Compute exp(elt - max) masked\n    float exponential[min_kv_length_shard_size_per_thread];\n    if (batch_id < batch_size) {\n        float thread_add = 0;\n        for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {\n            if (mask[batch_id][kv_length_id] == 0) {\n                exponential[kv_length_id - kv_length_start] = std::exp(static_cast<float>(attention_scores[batch_id][kv_length_id]) - temp_storage[row_id_mem_offset]);\n                thread_add = thread_add + exponential[kv_length_id - kv_length_start];\n            } else {\n                
exponential[kv_length_id - kv_length_start] = 0.;\n            }\n        }\n        if (thread_add > 0) {\n            // TODO @thomasw21 with more memory we can probably compute a much faster `sum-reduce` in parallel O(ln(n)) operations in each memory slot\n            gpuAtomicAdd(&temp_storage[row_id_mem_offset + 1], thread_add);\n        }\n    }\n\n    __syncthreads();\n\n    // Compute softmax\n    if (batch_id < batch_size) {\n        // If sum of all exponential is 0, we set the softmax values to 0\n        if (temp_storage[row_id_mem_offset + 1] == 0.) {\n            for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {\n                result[batch_id][kv_length_id] = 0.;\n            }\n        } else {\n            for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {\n                result[batch_id][kv_length_id] = static_cast<attention_scores_scalar>(exponential[kv_length_id - kv_length_start] / temp_storage[row_id_mem_offset + 1]);\n            }\n        }\n    }\n}\n\n#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\nstd::tuple<at::Tensor, std::optional<std::vector<at::Tensor>>, at::Tensor> forward(\n    const at::Tensor fused_qkv,\n    const std::optional<std::vector<at::Tensor>> layer_past,\n    const at::Tensor alibi,\n    const at::Tensor attention_mask,\n    const std::optional<at::Tensor> head_mask,\n    const float beta,\n    const float inv_norm_factor,\n    const int num_heads,\n    const bool use_cache\n) {\n    const auto batch_size = fused_qkv.size(0);\n    const auto q_length = fused_qkv.size(1);\n    const auto three_times_hidden_size = fused_qkv.size(2);\n    const auto head_dim = three_times_hidden_size / (3 * num_heads);\n    const auto batch_size_times_num_heads = batch_size * 
num_heads;\n\n    // `split_heads`\n    const auto fused_qkv_view = fused_qkv.view({batch_size, q_length, num_heads, 3 * head_dim});\n    const auto tensor_list = fused_qkv_view.split(head_dim, -1);\n    const auto query_layer = tensor_list[0].transpose(1, 2).reshape({batch_size_times_num_heads, q_length, head_dim});\n    auto key_layer = tensor_list[1].permute({0, 2, 3, 1}).reshape({batch_size_times_num_heads, head_dim, q_length});\n    auto value_layer = tensor_list[2].transpose(1, 2).reshape({batch_size_times_num_heads, q_length, head_dim});\n\n    if (layer_past) {\n        const auto past_key = (*layer_past).at(0);\n        const auto past_value = (*layer_past).at(1);\n        key_layer = at::cat({past_key, key_layer}, 2);\n        value_layer = at::cat({past_value, value_layer}, 1);\n    }\n\n    std::optional<std::vector<at::Tensor>> present;\n    if (use_cache) {\n        present = {key_layer, value_layer};\n    } else {\n        present = {};\n    }\n\n    auto attention_scores = alibi.baddbmm(query_layer, key_layer, beta, inv_norm_factor);\n\n    // Computing `optionally_cast_fp16_to_fp32 + masked_fill + softmax + cast_to_intial_dtype`\n    at::Tensor attention_probs;\n    if (true) {\n        const auto kv_length = key_layer.size(2);\n\n        // TODO @thomasw21: it's easier to think of attention_scores as 2D tensors\n        const auto attention_scores_2d = attention_scores.view({batch_size_times_num_heads * q_length, kv_length});\n        const auto attention_mask_2d = attention_mask.view({batch_size_times_num_heads * q_length, kv_length});\n\n        // Custom kernel\n        attention_probs = at::empty_like(attention_scores_2d);\n\n        // Check that inputs and contiguous + cuda tensors\n        CHECK_INPUT(attention_scores_2d);\n        CHECK_INPUT(attention_mask_2d);\n\n        // TODO @thomas21: change by to this as it's cleaner when pytorch 1.13 comes out\n        // DISPATCH_CASE_FLOATING_TYPES(attention_scores.scalar_type(), 
\"masked_softmax\", [&] {\n        AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, attention_scores.scalar_type(), \"masked_softmax\", [&] {\n            /*\n            * Understanding how GPUs work: https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/\n            * A100 specifications: https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf\n            *  - SMs: 108\n            *  - TPCs: 56 (What's that?)\n            *  - Memory size: 40 GB\n            *  - L2 Cache size: 40960 KB (shared across all SMs)\n            *  - L1/Shared memory size: 192 KB (shared across all threads within a SM)\n            *  - Max Threads / SM: 2048\n            *  - Max Thread Blocks / SM: 32\n            */\n\n            /*\n            * We should split [batch_size_times_num_heads_block, q_length] in seperate blocks and [batch_size_times_num_heads_block_size, kv_length] a single block\n            * with multiple threads as we need to `sync_threads` to run exponential sum.\n            * We maximise the usage of threads within a single block\n            */\n            // TODO @thomasw21 figure out everything warp related:\n            //  - why do they have to be power of 2\n            // TODO @thomas21 check why everyone is setting 1024 when officially it's 2048\n            const auto MAX_THREADS_PER_SM = 1024;\n            // TODO @thomasw21 figure out how to have longer sequences, currently the maximum is `max_kv_length = MAX_THREADS_PER_SM * MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD`\n            const auto MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD = 4;\n            // `effective_kv_length = ceil(kv_length / MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD)`\n            const auto effective_kv_length = (kv_length - 1)/ MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD + 1;\n            const auto rows_per_block = MAX_THREADS_PER_SM / effective_kv_length;\n            const auto num_blocks = 
(batch_size_times_num_heads * q_length - 1) / rows_per_block + 1;\n\n            const dim3 gridDim(num_blocks); // Number of blocks that run\n            const dim3 blockDim(MAX_THREADS_PER_SM); // Number of threads that run per block\n            const int shared_mem_forward = rows_per_block * 2 * sizeof(float);\n\n            // 192 * 2 ** 10\n            // const auto MAX_L1_MEMORY = 196608;\n            // const auto MAX_SMs = 108;\n            // TORCH_CHECK(batch_size_times_num_heads * q_length <= MAX_L1_MEMORY, \"Shared memory exceeds 192KB limitation.\");\n            // TORCH_CHECK(gridDim.x * gridDim.y * gridDim.z <= MAX_SMs, \"A100s only have 108 SMs. Raising as require blocks is bigger.\");\n            // TORCH_CHECK(blockDim.x * blockDim.y * blockDim.z <= MAX_THREADS_PER_SM, \"A100s only have 2048 threads per block. Raising as require requested threads is higher.\");\n\n            forward_masked_softmax_kernel<scalar_t, MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD><<<gridDim, blockDim, shared_mem_forward>>>(\n                attention_scores_2d.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),\n                attention_mask_2d.packed_accessor32<bool, 2, torch::RestrictPtrTraits>(),\n                attention_probs.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),\n                effective_kv_length,\n                blockDim,\n                rows_per_block,\n                kv_length,\n                batch_size_times_num_heads * q_length\n            );\n        });\n        attention_probs = attention_probs.view({batch_size_times_num_heads, q_length, kv_length});\n    } else {\n        // Pytorch C++ API\n        auto input_dtype = attention_scores.scalar_type();\n        if (input_dtype == at::ScalarType::Float) {\n            attention_scores = attention_scores.to(at::ScalarType::Float);\n        };\n        // TODO @thomasw21 Figure out how to get minimum value\n        auto attn_weights = 
attention_scores.masked_fill_(attention_mask, -1e34);\n        attention_probs = attn_weights.softmax(-1, at::ScalarType::Float).to(input_dtype);\n    }\n\n    auto context_layer = attention_probs.bmm(value_layer);\n\n    // `_merge_heads`\n    context_layer = context_layer.view({batch_size, num_heads, q_length, head_dim});\n    context_layer = context_layer.permute({0, 2, 1, 3});\n    context_layer = context_layer.reshape({batch_size, q_length, three_times_hidden_size / 3});\n\n    return std::make_tuple(context_layer, present, attention_probs);\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.def(\n        \"forward\",\n        &forward,\n        \"Bloom attention mechanism forward (CUDA)\"\n    );\n}\n"
  },
  {
    "path": "server/custom_kernels/setup.py",
    "content": "from setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\nextra_compile_args = [\"-std=c++17\"]\n\nsetup(\n    name=\"custom_kernels\",\n    ext_modules=[\n        CUDAExtension(\n            name=\"custom_kernels.fused_bloom_attention_cuda\",\n            sources=[\"custom_kernels/fused_bloom_attention_cuda.cu\"],\n            extra_compile_args=extra_compile_args,\n        ),\n        CUDAExtension(\n            name=\"custom_kernels.fused_attention_cuda\",\n            sources=[\"custom_kernels/fused_attention_cuda.cu\"],\n            extra_compile_args=extra_compile_args,\n        ),\n    ],\n    cmdclass={\"build_ext\": BuildExtension},\n)\n"
  },
  {
    "path": "server/exllama_kernels/exllama_kernels/cu_compat.cuh",
    "content": "// Adapted from turboderp exllama: https://github.com/turboderp/exllama\n\n#ifndef _cuda_compat_cuh\n#define _cuda_compat_cuh\n\n// atomicAdd for half types, to support CC < 7.x\n\n__device__ __forceinline__ void atomicAdd_half(half* address, half val)\n{\n    unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));\n    unsigned int old = *address_as_ui;\n    unsigned int assumed;\n\n    do\n    {\n        assumed = old;\n        __half_raw hsum;\n        hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);\n        half tmpres = __hadd(hsum, val);\n        hsum = __half_raw(tmpres);\n        old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;\n        old = atomicCAS(address_as_ui, assumed, old);\n    }\n    while (assumed != old);\n}\n\n// atomicAdd for half2 types\n\n__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)\n{\n    unsigned int* address_as_ui = (unsigned int*)address;\n    unsigned int old = *address_as_ui;\n    unsigned int assumed;\n    do\n    {\n        assumed = old;\n        half2 old_val = *((half2*)&old);\n        half2 new_val = __hadd2(old_val, val);\n        old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));\n    }\n    while (assumed != old);\n}\n\n//\n\n#if defined(__CUDA_ARCH__) || defined(USE_ROCM)\n#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)\n\n__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }\n\n#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)\n__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }\n#endif\n\n#endif\n#endif\n\n#endif\n"
  },
  {
    "path": "server/exllama_kernels/exllama_kernels/cuda_buffers.cu",
    "content": "// Adapted from turboderp exllama: https://github.com/turboderp/exllama\n\n#define _cuda_buffers_cu\n#include \"cuda_buffers.cuh\"\n\nCudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL};\n// __constant__ half2 q4_table[16][256];\n// half2 q4_table_host[16][256];\n// bool q4_table_init = false;\n\nCudaBuffers::CudaBuffers\n(\n    int _device,\n    half* _temp_state,\n    half* _temp_dq\n) :\n    device(_device),\n    temp_state(_temp_state),\n    temp_dq(_temp_dq)\n{\n    cudaSetDevice(_device);\n\n    cudaStreamCreate(&alt_stream_1);\n    cudaStreamCreate(&alt_stream_2);\n    cudaStreamCreate(&alt_stream_3);\n    cudaEventCreate(&alt_stream_1_done);\n    cudaEventCreate(&alt_stream_2_done);\n    cudaEventCreate(&alt_stream_3_done);\n}\n\nCudaBuffers::~CudaBuffers()\n{\n    cudaStreamDestroy(alt_stream_1);\n    cudaStreamDestroy(alt_stream_2);\n    cudaStreamDestroy(alt_stream_3);\n    cudaEventDestroy(alt_stream_1_done);\n    cudaEventDestroy(alt_stream_2_done);\n    cudaEventDestroy(alt_stream_3_done);\n}\n\nCudaBuffers* get_buffers(const int device_index)\n{\n    return g_buffers[device_index];\n}\n\nvoid prepare_buffers_cuda\n(\n    int _device,\n    half* _temp_state,\n    half* _temp_dq\n)\n{\n    CudaBuffers* buffers = new CudaBuffers\n    (\n        _device,\n        _temp_state,\n        _temp_dq\n    );\n\n    g_buffers[_device] = buffers;\n}\n\nvoid cleanup_buffers_cuda()\n{\n    for (int i = 0; i < CUDA_MAX_DEVICES; i++)\n    {\n        if (!g_buffers[i]) continue;\n        delete g_buffers[i];\n        g_buffers[i] = NULL;\n    }\n}\n"
  },
  {
    "path": "server/exllama_kernels/exllama_kernels/cuda_buffers.cuh",
    "content": "// Adapted from turboderp exllama: https://github.com/turboderp/exllama\n\n#ifndef _cuda_buffers_cuh\n#define _cuda_buffers_cuh\n\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n#include <cstdint>\n#include <cstdio>\n\nconst int CUDA_MAX_DEVICES = 16;\n\n// #ifndef _cuda_buffers_cu\n// extern __constant__ half2 q4_table[16][256];\n// #endif\n\nclass CudaBuffers\n{\npublic:\n    int device;\n\n    half* temp_state;           // [max_hidden_rows * intermediate_size]\n    half* temp_dq;              // size of largest quant tensor * 8\n\n    cudaStream_t alt_stream_1;\n    cudaStream_t alt_stream_2;\n    cudaStream_t alt_stream_3;\n    cudaEvent_t alt_stream_1_done;\n    cudaEvent_t alt_stream_2_done;\n    cudaEvent_t alt_stream_3_done;\n\n    CudaBuffers\n    (\n        int _device,\n        half* _temp_state,\n        half* _temp_dq\n    );\n    ~CudaBuffers();\n};\n\nCudaBuffers* get_buffers(const int device_index);\n\nvoid prepare_buffers_cuda\n(\n    int _device,\n    half* _temp_state,\n    half* _temp_dq\n);\n\nvoid cleanup_buffers_cuda();\n\n#endif\n"
  },
  {
    "path": "server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cu",
    "content": "// Adapted from turboderp exllama: https://github.com/turboderp/exllama\n\n#include \"column_remap.cuh\"\n#include \"../util.cuh\"\n\nconst int SHUF_BLOCKSIZE_X = 256;\nconst int SHUF_BLOCKSIZE_Y = 16;\n\n__global__ void column_remap_kernel\n(\n    const half* __restrict__ x,\n    half* __restrict__ x_new,\n    const int x_width,\n    const int x_height,\n    const uint32_t* x_map\n)\n{\n    int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;\n    int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y;\n\n    int x_stride = x_width;\n    int x_idx = x_row * x_stride + x_column;\n\n    int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height);\n    int x_idx_end = x_row_end * x_stride + x_column;\n\n    int s_column = x_map[x_column];\n    int s_idx = x_row * x_stride + s_column;\n\n    while (x_idx < x_idx_end)\n    {\n        x_new[x_idx] = x[s_idx];\n        x_idx += x_stride;\n        s_idx += x_stride;\n    }\n}\n\n// Remap columns in x to correspond to sequential group index before matmul\n//\n// perform x -> seq_x such that seq_x @ seq_w == x @ w\n\nvoid column_remap_cuda\n(\n    const half* x,\n    half* x_new,\n    const int x_height,\n    const int x_width,\n    const uint32_t* x_map\n)\n{\n    dim3 threads(SHUF_BLOCKSIZE_X, 1, 1);\n\n    dim3 blocks\n    (\n        (x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X,\n        (x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y,\n        1\n    );\n\n    column_remap_kernel<<<blocks, threads>>>(x, x_new, x_width, x_height, x_map);\n}\n"
  },
  {
    "path": "server/exllama_kernels/exllama_kernels/cuda_func/column_remap.cuh",
    "content": "// Adapted from turboderp exllama: https://github.com/turboderp/exllama\n\n#ifndef _column_remap_cuh\n#define _column_remap_cuh\n\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n#include <cstdint>\n\nvoid column_remap_cuda\n(\n    const half* x,\n    half* x_new,\n    const int x_height,\n    const int x_width,\n    const uint32_t* x_map\n);\n\n#endif\n"
  },
  {
    "path": "server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cu",
    "content": "#include \"q4_matmul.cuh\"\n#include \"column_remap.cuh\"\n#include <ATen/cuda/CUDAContext.h>\n#include \"../util.cuh\"\n#include \"../matrix.cuh\"\n#include \"../cu_compat.cuh\"\n#include \"../cuda_buffers.cuh\"\n#if defined(USE_ROCM)\n#include \"../hip_compat.cuh\"\n#endif\n\nconst int THREADS_X = 32;       // Block size and thread count along columns in w and out\nconst int THREADS_Y = 1;        // Block size and thread count along rows in x and out\n\ntypedef void (*fp_q4_matmul_kernel)\n(\n    const half*,\n    const uint32_t*,\n    half*,\n    const half*,\n    const uint32_t*,\n    const int,\n    const int,\n    const int,\n    const int,\n    const int,\n    const uint32_t*,\n    bool\n);\n\ntemplate<bool use_half2, bool use_groupsize, bool use_x_map>\n__global__ void q4_matmul_kernel\n(\n    const half* __restrict__ x,\n    const uint32_t* __restrict__ w,\n    half* __restrict__ out,\n    const half* __restrict__ w_scales,\n    const uint32_t* __restrict__ w_zeros,\n    const int height,\n    const int dim,\n    const int width,\n    const int groupsize,\n    const int block_size_z,\n    const uint32_t* __restrict__ x_map,\n    bool no_zero\n)\n{\n    // Start of block\n\n    int x_column = block_size_z * blockIdx.z;\n    int x_column_end = min(dim, block_size_z * (blockIdx.z + 1));\n\n    int w_column = THREADS_X * blockIdx.x + threadIdx.x;\n    int x_row = THREADS_Y * blockIdx.y + threadIdx.y;\n\n    int iterations = (x_column_end - x_column) / 8;\n\n    // Views\n\n    MatrixView_half x_(x, height, dim);\n    MatrixView_half w_scales_(w_scales, dim / groupsize, width);\n    MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width);\n    MatrixView_q4_column w_(w, dim, width);\n    MatrixView_half_rw out_(out, height, width);\n\n    // Zero output\n\n    if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0)\n    {\n        *((uint32_t*) out_.item_ptr(x_row, w_column)) = 0;\n        __syncthreads();\n    }\n\n    // Loop over part 
of x row (and w column)\n\n    half2 acc = {};\n    half acc_h = {};\n\n    if constexpr (use_groupsize)\n    {\n        // For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this\n        // could be slightly faster\n\n        for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize)\n        {\n            if constexpr (use_half2)\n            {\n                half2 w_scale = w_scales_.item_half2half2(group, w_column);\n                uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F;\n\n                if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);\n                else                     acc = dot_product_8      (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);\n            }\n            else\n            {\n                half w_scale = w_scales_.item(group, w_column);\n                uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F;\n\n                if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);\n                else                     acc_h = dot_product_8_h      (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);\n            }\n        }\n    }\n    else\n    {\n        // Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache\n\n        for (int k = x_column; k < x_column + iterations * 8; k += 8)\n        {\n            if constexpr (use_half2)\n            {\n                int group = k / groupsize;\n                half2 w_scale = w_scales_.item_half2half2(group, w_column);\n                uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F;\n\n                if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, 
x_map);\n                else                     acc = dot_product_8      (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);\n            }\n            else\n            {\n                int group = k / groupsize;\n                half w_scale = w_scales_.item(group, w_column);\n                uint32_t w_zero = (w_zeros_.item(group, w_column) + 1) & 0x0F;\n\n                if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);\n                else                     acc_h = dot_product_8_h      (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);\n            }\n        }\n    }\n\n    // Add to block result\n\n    if constexpr (use_half2)\n    {\n        half result = __hadd(__low2half(acc), __high2half(acc));\n        atomicAdd(out_.item_ptr(x_row, w_column), result);\n    }\n    else\n    {\n        atomicAdd(out_.item_ptr(x_row, w_column), acc_h);\n    }\n}\n\nfp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map)\n{\n    // <bool use_half2, bool use_groupsize, bool use_x_map>\n    if (tuningParams->matmul_no_half2) {\n        if (block_size_z % groupsize == 0) {\n            if (x_map) return q4_matmul_kernel<false, true,  true >;\n            else       return q4_matmul_kernel<false, true,  false>;\n        } else {\n            if (x_map) return q4_matmul_kernel<false, false, true >;\n            else       return q4_matmul_kernel<false, false, false>;\n        }\n    } else {\n        if (block_size_z % groupsize == 0)\n        {\n            if (x_map) return q4_matmul_kernel<true,  true,  true >;\n            else       return q4_matmul_kernel<true,  true,  false>;\n        } else {\n            if (x_map) return q4_matmul_kernel<true,  false, true >;\n            else       return q4_matmul_kernel<true,  false, false>;\n        }\n    }\n};\n\n// Compute y = x @ w\n\nvoid q4_matmul_cuda\n(\n    
ExLlamaTuning* tuningParams,\n    const half* x,\n    const int x_height,\n    const Q4Matrix* w,\n    half* out,\n    bool no_zero,\n    cudaStream_t alt_stream\n)\n{\n    int height = x_height;\n    int dim = w->height;\n    int width = w->width;\n\n    cudaSetDevice(w->device);\n\n    uint32_t* x_map = w->cuda_x_map;\n    const half* x_mapped = x;\n    if (x_map && !tuningParams->matmul_fused_remap && !alt_stream)\n    {\n        CudaBuffers* buffers = get_buffers(w->device);\n        column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);\n        x_mapped = buffers->temp_state;\n        x_map = NULL;\n    }\n\n    int block_size_z;\n    if (w->width == 4096) block_size_z = 384;           // 7B\n    else if (w->width == 11008) block_size_z = 256;\n    else if (w->width == 5120) block_size_z = 384;      // 13B\n    else if (w->width == 13824) block_size_z = 256;\n    else if (w->width == 6656) block_size_z = 256;      // 33B\n    else if (w->width == 17920) block_size_z = 128;\n    else block_size_z = 256;\n\n    //if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half));\n\n    dim3 threads(THREADS_X, THREADS_Y, 1);\n\n    dim3 blocks\n    (\n        (width + threads.x - 1) / threads.x,\n        (height + threads.y - 1) / threads.y,\n        (dim + block_size_z - 1) / block_size_z\n    );\n\n    fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map);\n\n    kernel<<<blocks, threads, 0, alt_stream>>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);\n}\n\nvoid q4_matmul_recons_cuda\n(\n    ExLlamaTuning* tuningParams,\n    const half* x,\n    const int x_height,\n    Q4Matrix* w,\n    half* out,\n    bool no_zero,\n    const cublasHandle_t handle\n)\n{\n    int height = x_height;\n    int dim = w->height;\n    int width = w->width;\n\n    cudaSetDevice(w->device);\n    CudaBuffers* buffers = 
get_buffers(w->device);\n\n    const half* x_mapped = x;\n    if (w->cuda_x_map)\n    {\n        column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);\n        x_mapped = buffers->temp_state;\n    }\n\n    w->reconstruct(buffers->temp_dq);\n\n    const half alpha = __float2half(1.0f);\n    const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f);\n    cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width);\n\n//     const float alpha = 1.0f;\n//     const float beta = no_zero ? 1.0f : 0.0f;\n//     cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width,\n//                 x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width);\n}\n"
  },
  {
    "path": "server/exllama_kernels/exllama_kernels/cuda_func/q4_matmul.cuh",
    "content": "// Adapted from turboderp exllama: https://github.com/turboderp/exllama\n\n#ifndef _q4_matmul_cuh\n#define _q4_matmul_cuh\n\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n#include <cstdint>\n#include <cstdio>\n#include <ATen/cuda/CUDAContext.h>\n\n#include \"q4_matrix.cuh\"\n#include \"../tuning.h\"\n\nvoid q4_matmul_cuda\n(\n    ExLlamaTuning* tuningParams,\n    const half* x,\n    const int x_height,\n    const Q4Matrix* w,\n    half* out,\n    bool no_zero,\n    cudaStream_t alt_stream\n);\n\nvoid q4_matmul_recons_cuda\n(\n    ExLlamaTuning* tuningParams,\n    const half* x,\n    const int x_height,\n    Q4Matrix* w,\n    half* out,\n    bool no_zero,\n    const cublasHandle_t handle\n);\n\n#endif\n"
  },
  {
    "path": "server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cu",
    "content": "// Adapted from turboderp exllama: https://github.com/turboderp/exllama\n\n#include <ATen/cuda/CUDAContext.h>\n#include \"q4_matrix.cuh\"\n#include <vector>\n#include \"../util.cuh\"\n#include \"../matrix.cuh\"\n\nusing namespace std;\n\nconst int UNSHUF_BLOCKSIZE_X = 64;\n\nconst int RECONS_THREADS_X = 64;      // Block size and thread count along columns in out, each thread converts 1 column\nconst int RECONS_THREADS_Y = 1;       // Block size and thread count along rows in x and out, each thread converts 8 rows\n\nvector<Q4Matrix*> g_q4_matrices;\n\nvoid g_q4_keep_matrix(Q4Matrix* m)\n{\n    g_q4_matrices.push_back(m);\n}\n\nvoid g_q4_free_matrices()\n{\n    for (const auto& m : g_q4_matrices) delete m;\n    g_q4_matrices.clear();\n}\n\nQ4Matrix::Q4Matrix\n(\n    const int _height,\n    const int _width,\n    const int _groups,\n\n    uint32_t* _qweight,\n    uint32_t* _qzeros,\n    half* _scales,\n    uint32_t* _g_idx,\n\n    const int _device\n) :\n    height(_height),\n    width(_width),\n    groups(_groups),\n    device(_device)\n{\n    cudaSetDevice(device);\n\n    cuda_qweight = _qweight;\n    cuda_qzeros = _qzeros;\n    cuda_scales = _scales;\n\n    groupsize = height / groups;\n\n    if (_g_idx) make_sequential(_g_idx);\n}\n\nQ4Matrix::~Q4Matrix()\n{\n}\n\n// Make sequential\n\n__global__ void make_sequential_kernel\n(\n    const uint32_t* __restrict__ w,\n    uint32_t* __restrict__ w_new,\n    const uint32_t* __restrict__ x_map,\n    const int w_height,\n    const int w_width\n)\n{\n    const uint64_t* w2 = (uint64_t*) w;\n    uint64_t* w_new2 = (uint64_t*) w_new;\n    int w2_stride = w_width >> 1;\n\n    int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;\n    int w_new2_row = blockIdx.y;\n\n    int x_map_idx = w_new2_row << 3;\n\n    uint64_t dst = 0;\n\n    #pragma unroll\n    for (int i = 0; i < 8; i++)\n    {\n        int source_row = x_map[x_map_idx++];\n\n        int w2_row = source_row >> 3;\n        int w2_subrow = 
source_row & 0x07;\n        int w2_row_shift = w2_subrow << 2;\n        int wnew2_row_shift = i << 2;\n\n    uint64_t src = w2[w2_row * w2_stride + w2_column];\n        src >>= w2_row_shift;\n        src &= 0x0000000f0000000f;\n        src <<= wnew2_row_shift;\n        dst |= src;\n    }\n\n    w_new2[w_new2_row * w2_stride + w2_column] = dst;\n}\n\nvoid Q4Matrix::make_sequential(const uint32_t* cpu_g_idx)\n{\n    uint32_t* cuda_new_qweight = NULL;\n    cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));\n    cudaMalloc(&cuda_x_map, height * sizeof(uint32_t));  // TODO: Should probably be allocated in PyTorch\n\n    uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));\n    uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));\n    uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));\n\n    // Group histogram\n\n    for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;\n\n    // Group map\n\n    for (int i = 0, acc = 0; i < groups; i++)\n    {\n        short tmp = cpu_g_idx_map[i];\n        cpu_g_idx_map[i] = acc;\n        acc += tmp;\n    }\n\n    // X map (inverse)\n\n    for (int row = 0; row < height; row++)\n    {\n        uint32_t target_group = cpu_g_idx[row];\n        uint32_t target_row = cpu_g_idx_map[target_group];\n        cpu_g_idx_map[target_group]++;\n        cpu_x_map_inv[row] = target_row;\n    }\n\n    // X map\n\n    for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;\n\n    // Move to CUDA\n\n    cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice);\n\n    // Rearrange rows in w\n\n    dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1);\n    dim3 blocks(width / UNSHUF_BLOCKSIZE_X / 2, height / 8, 1);\n\n    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n    make_sequential_kernel<<<blocks, threads, 0, stream>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);\n\n    // Replace 
qweights\n\n    cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);\n\n    // Cleanup\n\n    cudaDeviceSynchronize();\n    cudaFree(cuda_new_qweight);\n    free(cpu_g_idx_map);\n    free(cpu_x_map);\n    free(cpu_x_map_inv);\n}\n\n__global__ void reconstruct_kernel\n(\n    const uint32_t* __restrict__ w,\n    half* __restrict__ out,  // (y)\n    const half* __restrict__ w_scales,\n    const uint32_t* __restrict__ w_zeros,\n    const int height,\n    const int width,\n    const int groupsize\n)\n{\n    // Start of block\n\n    int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x;\n    int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8;\n\n    // Views\n\n    MatrixView_q4_column w_(w, height, width);\n    MatrixView_half_rw out_(out, height, width);\n    MatrixView_half w_scales_(w_scales, height / groupsize, width);\n    MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width);\n\n    // Groupsize version\n\n    int group = row / groupsize;\n\n    half w_scale = w_scales_.item(group, column);\n    uint32_t w_zero = (w_zeros_.item(group, column) + 1) & 0x0F;\n\n    uint32_t w_read = w_.item_uint32_t(row, column);\n    half* out_ptr = out_.item_ptr(row, column);\n\n    #pragma unroll\n    for (int s = 0; s < 32; s += 4)\n    {\n        half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale);\n        *out_ptr = w_item; out_ptr += out_.width;\n    }\n}\n\nvoid Q4Matrix::reconstruct(half* out)\n{\n    dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1);\n\n    dim3 blocks\n    (\n        (width + threads.x - 1) / threads.x,\n        (height / 8 + threads.y - 1) / threads.y,\n        1\n    );\n\n    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n    reconstruct_kernel<<<blocks, threads, 0, stream>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);\n}\n"
  },
  {
    "path": "server/exllama_kernels/exllama_kernels/cuda_func/q4_matrix.cuh",
    "content": "// Adapted from turboderp exllama: https://github.com/turboderp/exllama\n\n#ifndef _q4_matrix_cuh\n#define _q4_matrix_cuh\n\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n#include <cstdint>\n\nclass Q4Matrix\n{\npublic:\n\n    int device;\n\n    int height;\n    int width;\n    int groups;\n    int groupsize;\n\n    uint32_t* cuda_qweight = NULL;\n    uint32_t* cuda_qzeros = NULL;\n    half* cuda_scales = NULL;\n    uint32_t* cuda_x_map = NULL;\n\n    Q4Matrix\n    (\n        const int _height,\n        const int _width,\n        const int _groups,\n\n        uint32_t* _qweight,\n        uint32_t* _qzeros,\n        half* _scales,\n        uint32_t* _g_idx,\n\n        const int _device\n    );\n\n    ~Q4Matrix();\n\n    void reconstruct(half* out);\n\nprivate:\n\n    void make_sequential(const uint32_t* cpu_g_idx);\n\n};\n\nvoid g_q4_keep_matrix(Q4Matrix* m);\nvoid g_q4_free_matrices();\n\n#endif\n"
  },
  {
    "path": "server/exllama_kernels/exllama_kernels/exllama_ext.cpp",
    "content": "// Adapted from turboderp exllama: https://github.com/turboderp/exllama\n\n#include <torch/extension.h>\n#include <c10/cuda/CUDAGuard.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n#include <cstdint>\n#include <cstdio>\n#include \"util.cuh\"\n#include \"tuning.h\"\n#include \"cuda_buffers.cuh\"\n#include \"cuda_func/q4_matrix.cuh\"\n#include \"cuda_func/q4_matmul.cuh\"\n#include \"cuda_func/column_remap.cuh\"\n\n// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a\n// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of\n// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console.\n\nvoid check_cuda(cudaError_t ret)\n{\n    switch (ret)\n    {\n        case cudaSuccess:\n            break;\n\n        case cudaUnspecified:\n            printf(\" **** Unspecified error\\n\");\n            TORCH_CHECK(false, \"CUDA error\");\n            break;\n\n        default:\n            printf(\" **** CUDA error\\n\"); \\\n            printf(\" **** %s\\n\", cudaGetErrorString(ret)); \\\n            TORCH_CHECK(false, \"CUDA error\"); \\\n            break;\n    }\n}\n\n// Some decluttering macros\n\n#define STRINGIFY_(__x) #__x\n#define STRINGIFY(__x) STRINGIFY_(__x)\n#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x \" is incorrect datatype, must be \" #__dtype)\n#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x \" is incorrect datatype, must be \" #__dtype)\n#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x \" and \" #__y \" have incompatible shapes\")\n#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() 
|| (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x \" and \" #__y \" have incompatible shapes\")\n#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x \".shape[\" STRINGIFY(__dim_x) \"] must be a multiple of \" STRINGIFY(__mod))\n\n#define TORCH_CHECK_DEVICE_INDEX(__index) \\\ndo { \\\n    TORCH_CHECK(__index >= 0, \"no device index\"); \\\n    TORCH_CHECK(__index < CUDA_MAX_DEVICES, \"invalid device index\"); \\\n} while(0)\n\n#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \\\ndo { \\\n    TORCH_CHECK_DTYPE(__w, kInt); \\\n    TORCH_CHECK_DTYPE(__w_scales, kHalf); \\\n    TORCH_CHECK_DTYPE(__w_zeros, kInt); \\\n    TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \\\n    TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \\\n    TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \\\n    TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \\\n} while(0)\n\nint get_groupsize(torch::Tensor w, torch::Tensor w_zeros)\n{\n    int groupsize = w.size(0) * 8 / w_zeros.size(0);\n    TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, \"w.shape[-2] must be a multiple of zeros.shape[-2]\")\n    return groupsize;\n}\n\n\n// Tuning parameters\n\nExLlamaTuning tuningParams;\n\nvoid set_tuning_params\n(\n    int matmul_recons_thd,\n    bool matmul_fused_remap,\n    bool matmul_no_half2\n)\n{\n    tuningParams.matmul_recons_thd = matmul_recons_thd;\n    tuningParams.matmul_fused_remap = matmul_fused_remap;\n    tuningParams.matmul_no_half2 = matmul_no_half2;\n}\n\n\n// Release all unmanaged objects allocated by the extension\n\nvoid cleanup()\n{\n    cleanup_buffers_cuda();\n    g_q4_free_matrices();\n}\n\n\n// Prepare buffers for forward pass\n\nvoid prepare_buffers\n(\n    torch::Device device,\n    torch::Tensor temp_state,\n    torch::Tensor temp_dq\n)\n{\n    int device_index = device.index();\n    TORCH_CHECK_DEVICE_INDEX(device_index);\n    const at::cuda::OptionalCUDAGuard 
device_guard(device);\n\n    prepare_buffers_cuda\n    (\n        device_index,\n        (half*) temp_state.data_ptr(),\n        (half*) temp_dq.data_ptr()\n    );\n}\n\n\n// Create Q4Matrix, return handle\n\nuintptr_t make_q4\n(\n    torch::Tensor qweight,\n    torch::Tensor qzeros,\n    torch::Tensor scales,\n    torch::Tensor g_idx,\n    int device\n)\n{\n    TORCH_CHECK_DTYPE(qweight, kInt);\n    TORCH_CHECK_DTYPE(qzeros, kInt);\n    TORCH_CHECK_DTYPE(scales, kHalf);\n    TORCH_CHECK_DTYPE_OPT(g_idx, kInt);\n    TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8);\n    TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1);\n    TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1);\n\n    int width = qweight.size(1);\n    int height = qweight.size(0) * 8;\n    int groups = qzeros.size(0);\n\n    Q4Matrix* m = new Q4Matrix\n    (\n        height,\n        width,\n        groups,\n\n        (uint32_t*) qweight.data_ptr(),\n        (uint32_t*) qzeros.data_ptr(),\n        (half*) scales.data_ptr(),\n        g_idx.device().is_meta() ? 
NULL : (uint32_t*) g_idx.data_ptr(),\n\n        device\n    );\n\n    g_q4_keep_matrix(m);\n    return reinterpret_cast<uintptr_t> (m);\n}\n\n\n// Matmul half @ quant -> half\n\nvoid q4_matmul\n(\n    torch::Tensor x,\n    uintptr_t w,\n    torch::Tensor out\n)\n{\n    Q4Matrix* wm = reinterpret_cast<Q4Matrix*> (w);\n\n    TORCH_CHECK_DTYPE(x, kHalf);\n    TORCH_CHECK_DTYPE(out, kHalf);\n    TORCH_CHECK_SHAPES(x, 0, out, 0, 1);\n    TORCH_CHECK(wm->height == x.size(-1), \"x and w have incompatible shapes\")\n\n    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));\n\n    int x_height = x.size(0);\n\n    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n    if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)\n    {\n        q4_matmul_cuda\n        (\n            &tuningParams,\n            (half*) x.data_ptr(),\n            x_height,\n            wm,\n            (half*) out.data_ptr(),\n            false,\n            stream\n        );\n    }\n    else\n    {\n        q4_matmul_recons_cuda\n        (\n            &tuningParams,\n            (half*) x.data_ptr(),\n            x_height,\n            wm,\n            (half*) out.data_ptr(),\n            false,\n            at::cuda::getCurrentCUDABlasHandle()\n        );\n    }\n}\n\n\n// Remap columns in half tensor\n\nvoid column_remap\n(\n    torch::Tensor x,\n    torch::Tensor x_new,\n    torch::Tensor x_map\n)\n{\n    TORCH_CHECK_DTYPE(x, kHalf);\n    TORCH_CHECK_DTYPE(x_new, kHalf);\n    TORCH_CHECK_DTYPE(x_map, kInt);\n    TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1);\n\n    int height = x.size(0);\n    int width = x.size(1);\n\n    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));\n\n    column_remap_cuda\n    (\n        (half*) x.data_ptr(),\n        (half*) x_new.data_ptr(),\n        height,\n        width,\n        (uint32_t*) x_map.data_ptr()\n    );\n}\n\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m)\n{\n    m.def(\"set_tuning_params\", 
&set_tuning_params, \"set_tuning_params\");\n    m.def(\"prepare_buffers\", &prepare_buffers, \"prepare_buffers\");\n    m.def(\"cleanup\", &cleanup, \"cleanup\");\n    m.def(\"make_q4\", &make_q4, \"make_q4\");\n    m.def(\"q4_matmul\", &q4_matmul, \"q4_matmul\");\n}\n"
  },
  {
    "path": "server/exllama_kernels/exllama_kernels/hip_compat.cuh",
    "content": "// Adapted from turboderp exllama: https://github.com/turboderp/exllama\n\n#ifndef _hip_compat_cuh\n#define _hip_compat_cuh\n\n// Workaround for a bug in hipamd, backported from upstream, this is fixed in ROCm 5.6.\n__device__ __forceinline__ __half __compat_hrcp(__half x) {\n    return __half_raw{\n        static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))};\n}\n\n__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) {\n    return _Float16_2{\n        _Float16_2{static_cast<_Float16>(1.0f),\n            static_cast<_Float16>(1.0f)} / x.data};\n}\n\n#define hrcp __compat_hrcp\n#define h2rcp __compat_h2rcp\n\n// Automatic conversion of hipblasHgemm doesn't convert half to hipblasHalf.\n__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t    handle,\n                                                               hipblasOperation_t transA,\n                                                               hipblasOperation_t transB,\n                                                               int                m,\n                                                               int                n,\n                                                               int                k,\n                                                               const half*        alpha,\n                                                               const half*        AP,\n                                                               int                lda,\n                                                               const half*        BP,\n                                                               int                ldb,\n                                                               const half*        beta,\n                                                               half*              CP,\n                                                               int                ldc) {\n    return 
hipblasHgemm(handle, transA, transB, m, n, k,\n                        reinterpret_cast<const hipblasHalf *>(alpha),\n                        reinterpret_cast<const hipblasHalf *>(AP), lda,\n                        reinterpret_cast<const hipblasHalf *>(BP), ldb,\n                        reinterpret_cast<const hipblasHalf *>(beta),\n                        reinterpret_cast<hipblasHalf *>(CP), ldc);\n}\n#define hipblasHgemm __compat_hipblasHgemm\n\n// Previous versions of PyTorch were converting to rocBLAS instead of hipBLAS.\n#define rocblas_handle hipblasHandle_t\n#define rocblas_operation_none HIPBLAS_OP_N\n#define rocblas_get_stream hipblasGetStream\n#define rocblas_set_stream hipblasSetStream\n#define rocblas_hgemm __compat_hipblasHgemm\n\n#endif\n"
  },
  {
    "path": "server/exllama_kernels/exllama_kernels/matrix.cuh",
    "content": "// Adapted from turboderp exllama: https://github.com/turboderp/exllama\n\n#ifndef _matrix_cuh\n#define _matrix_cuh\n\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n\nclass MatrixView_half\n{\npublic:\n    const half* data;\n    const int height;\n    const int width;\n\n    __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)\n        : data(data), height(height), width(width)\n    { }\n\n    __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }\n    __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }\n    __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }\n    __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }\n};\n\nclass MatrixView_half_rw\n{\npublic:\n    half* data;\n    const int height;\n    const int width;\n\n    __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)\n        : data(data), height(height), width(width)\n    { }\n\n    __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }\n    __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }\n    __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }\n    __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }\n    __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }\n};\n\nclass MatrixView_q4_row\n{\npublic:\n    const uint32_t* data;\n    const int height;\n    const int width;\n\n    __device__ __forceinline__ MatrixView_q4_row(const uint32_t* 
data, const int height, const int width)\n        : data(data), height(height), width(width)\n    { }\n\n    __device__ __forceinline__ int item(int row, int column) const\n    {\n        int shift = (column & 0x07) * 4;\n        return (data[row * width / 8 + column / 8] >> shift) & 0x0f;\n    }\n};\n\nclass MatrixView_q4_column\n{\npublic:\n    const uint32_t* data;\n    const int height;\n    const int width;\n\n    __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)\n        : data(data), height(height), width(width)\n    { }\n\n    __device__ __forceinline__ int item(int row, int column) const\n    {\n        int shift = (row & 0x07) * 4;\n        return (data[row / 8 * width + column] >> shift) & 0x0f;\n    }\n\n    __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }\n    __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }\n};\n\n// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu\n\n// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale\n\n__device__ __forceinline__ half2 dot_product_8\n(\n    const half2 acc,\n    MatrixView_half& h_,\n    const int h_row,\n    const int h_column,                 // divisible by 8\n    MatrixView_q4_column& v_,\n    const int v_row,                    // divisible by 8\n    const int v_column,\n    const half2 v_scale_2,\n    const uint32_t v_zero,              // + 1 (!!)\n    const int count\n)\n{\n    const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column);\n    const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);\n    half2 result = acc;\n\n    for (int i = 0; i < count; i++)\n    {\n        uint32_t v_read = *v_ptr; v_ptr += v_.width;\n\n        half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) 
- v_zero);\n        half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - v_zero);\n        half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - v_zero);\n        half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);\n        half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);\n        half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);\n        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);\n        half v_7 = __int2half_rn((int)((v_read >> 28)       ) - v_zero);\n\n        half2 v_01 = __halves2half2(v_0, v_1);\n        half2 v_23 = __halves2half2(v_2, v_3);\n        half2 v_45 = __halves2half2(v_4, v_5);\n        half2 v_67 = __halves2half2(v_6, v_7);\n\n//         half2 v_01 = q4_table[v_zero - 1][(v_read      ) & 0xff]; // (constant memory is too slow apparently)\n//         half2 v_23 = q4_table[v_zero - 1][(v_read >>  8) & 0xff];\n//         half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff];\n//         half2 v_67 = q4_table[v_zero - 1][(v_read >> 24)       ];\n\n        half2 tmp = __hmul2(*h_ptr++, v_01);\n        tmp = __hfma2(*h_ptr++, v_23, tmp);\n        tmp = __hfma2(*h_ptr++, v_45, tmp);\n        tmp = __hfma2(*h_ptr++, v_67, tmp);\n        result = __hfma2(v_scale_2, tmp, result);\n    }\n\n    return result;\n}\n\n__device__ __forceinline__ half dot_product_8_h\n(\n    const half acc,\n    MatrixView_half& h_,\n    const int h_row,\n    const int h_column,                 // divisible by 8\n    MatrixView_q4_column& v_,\n    const int v_row,                    // divisible by 8\n    const int v_column,\n    const half v_scale,\n    const uint32_t v_zero,              // + 1 (!!)\n    const int count\n)\n{\n    const half* h_ptr = h_.item_ptr(h_row, h_column);\n    const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);\n    half result = acc;\n\n    for (int i = 0; i < count; i++)\n    {\n        uint32_t v_read = *v_ptr; v_ptr += v_.width;\n\n        half v_0 
= __int2half_rn((int)((v_read      ) & 0x0f) - v_zero);\n        half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - v_zero);\n        half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - v_zero);\n        half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);\n        half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);\n        half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);\n        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);\n        half v_7 = __int2half_rn((int)((v_read >> 28)       ) - v_zero);\n\n        half tmp = __hmul(*h_ptr++, v_0);\n        tmp = __hfma(*h_ptr++, v_1, tmp);\n        tmp = __hfma(*h_ptr++, v_2, tmp);\n        tmp = __hfma(*h_ptr++, v_3, tmp);\n        tmp = __hfma(*h_ptr++, v_4, tmp);\n        tmp = __hfma(*h_ptr++, v_5, tmp);\n        tmp = __hfma(*h_ptr++, v_6, tmp);\n        tmp = __hfma(*h_ptr++, v_7, tmp);\n        result = __hfma(v_scale, tmp, result);\n    }\n\n    return result;\n}\n\n// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map\n\n__device__ __forceinline__ half2 dot_product_8_x_map\n(\n    const half2 acc,\n    MatrixView_half& h_,\n    const int h_row,\n    const int h_column,                 // divisible by 8\n    MatrixView_q4_column& v_,\n    const int v_row,                    // divisible by 8\n    const int v_column,\n    const half2 v_scale_2,\n    const uint32_t v_zero,              // + 1 (!!)\n    const int count,\n    const uint32_t* x_map\n)\n{\n    const half* h_ptr = h_.item_ptr(h_row, 0);\n    const uint32_t* x_map_ptr = x_map + h_column;\n    const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);\n    half2 result = acc;\n\n    for (int i = 0; i < count; i++)\n    {\n        uint32_t v_read = *v_ptr; v_ptr += v_.width;\n\n        half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - v_zero);\n        half v_1 = __int2half_rn((int)((v_read 
>>  4) & 0x0f) - v_zero);\n        half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - v_zero);\n        half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);\n        half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);\n        half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);\n        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);\n        half v_7 = __int2half_rn((int)((v_read >> 28)       ) - v_zero);\n\n        half2 v_01 = __halves2half2(v_0, v_1);\n        half2 v_23 = __halves2half2(v_2, v_3);\n        half2 v_45 = __halves2half2(v_4, v_5);\n        half2 v_67 = __halves2half2(v_6, v_7);\n\n        half h_0 = h_ptr[*x_map_ptr++];\n        half h_1 = h_ptr[*x_map_ptr++];\n        half h_2 = h_ptr[*x_map_ptr++];\n        half h_3 = h_ptr[*x_map_ptr++];\n        half h_4 = h_ptr[*x_map_ptr++];\n        half h_5 = h_ptr[*x_map_ptr++];\n        half h_6 = h_ptr[*x_map_ptr++];\n        half h_7 = h_ptr[*x_map_ptr++];\n\n        half2 h_01 = __halves2half2(h_0, h_1);\n        half2 h_23 = __halves2half2(h_2, h_3);\n        half2 h_45 = __halves2half2(h_4, h_5);\n        half2 h_67 = __halves2half2(h_6, h_7);\n\n        half2 tmp = __hmul2(h_01, v_01);\n        tmp = __hfma2(h_23, v_23, tmp);\n        tmp = __hfma2(h_45, v_45, tmp);\n        tmp = __hfma2(h_67, v_67, tmp);\n        result = __hfma2(v_scale_2, tmp, result);\n    }\n\n    return result;\n}\n\n__device__ __forceinline__ half dot_product_8_x_map_h\n(\n    const half acc,\n    MatrixView_half& h_,\n    const int h_row,\n    const int h_column,                 // divisible by 8\n    MatrixView_q4_column& v_,\n    const int v_row,                    // divisible by 8\n    const int v_column,\n    const half v_scale,\n    const uint32_t v_zero,              // + 1 (!!)\n    const int count,\n    const uint32_t* x_map\n)\n{\n    const half* h_ptr = h_.item_ptr(h_row, 0);\n    const uint32_t* x_map_ptr = x_map + h_column;\n    const uint32_t* v_ptr 
= (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);\n    half result = acc;\n\n    for (int i = 0; i < count; i++)\n    {\n        uint32_t v_read = *v_ptr; v_ptr += v_.width;\n\n        half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - v_zero);\n        half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - v_zero);\n        half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - v_zero);\n        half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);\n        half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);\n        half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);\n        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);\n        half v_7 = __int2half_rn((int)((v_read >> 28)       ) - v_zero);\n\n        half tmp = __hmul(h_ptr[*x_map_ptr++], v_0);\n        tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp);\n        tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp);\n        tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp);\n        tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp);\n        tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp);\n        tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp);\n        tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp);\n        result = __hfma(v_scale, tmp, result);\n    }\n\n    return result;\n}\n\n#endif\n"
  },
  {
    "path": "server/exllama_kernels/exllama_kernels/tuning.h",
    "content": "// Adapted from turboderp exllama: https://github.com/turboderp/exllama\n\n#ifndef _tuning_h\n#define _tuning_h\n\nstruct ExLlamaTuning\n{\n    int matmul_recons_thd;\n    bool matmul_fused_remap;\n    bool matmul_no_half2;\n};\n\n#endif\n"
  },
  {
    "path": "server/exllama_kernels/exllama_kernels/util.cuh",
    "content": "// Adapted from turboderp exllama: https://github.com/turboderp/exllama\n\n#ifndef _util_cuh\n#define _util_cuh\n\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n#include <cstdint>\n#include <cstdio>\n\n#if defined(USE_ROCM)\n#define cudaUnspecified hipErrorUnknown\n#else\n#define cudaUnspecified cudaErrorApiFailureBase\n#endif\n\n// React to failure on return code != cudaSuccess\n\n#define _cuda_check(fn) \\\ndo { \\\n    {_cuda_err = fn;} \\\n    if (_cuda_err != cudaSuccess) goto _cuda_fail; \\\n} while(false)\n\n// React to failure on return code == 0\n\n#define _alloc_check(fn) \\\ndo { \\\n    if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \\\n    else _cuda_err = cudaSuccess; \\\n} while(false)\n\n#endif\n"
  },
  {
    "path": "server/exllama_kernels/setup.py",
    "content": "from setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\nimport torch\n\nextra_cuda_cflags = []\nextra_cflags = []\nif torch.version.hip:\n    extra_cflags = [\"-DLEGACY_HIPBLAS_DIRECT=ON\"]\n    extra_cuda_cflags = [\"-DLEGACY_HIPBLAS_DIRECT=ON\"]\n\nextra_compile_args = {\n    \"cxx\": extra_cflags,\n    \"nvcc\": extra_cuda_cflags,\n}\n\nsetup(\n    name=\"exllama_kernels\",\n    ext_modules=[\n        CUDAExtension(\n            name=\"exllama_kernels\",\n            sources=[\n                \"exllama_kernels/exllama_ext.cpp\",\n                \"exllama_kernels/cuda_buffers.cu\",\n                \"exllama_kernels/cuda_func/column_remap.cu\",\n                \"exllama_kernels/cuda_func/q4_matmul.cu\",\n                \"exllama_kernels/cuda_func/q4_matrix.cu\",\n            ],\n            extra_compile_args=extra_compile_args,\n        )\n    ],\n    cmdclass={\"build_ext\": BuildExtension},\n)\n"
  },
  {
    "path": "server/exllamav2_kernels/exllamav2_kernels/config.h",
    "content": "#ifndef _config_h\n#define _config_h\n\n#define MAX_Q_GEMM_ROWS 50\n#define MAX_Q_GEMM_WEIGHTS 4  // must be <= MAX_Q_GEMM_ROWS\n\n#define QMODE_2BIT 1\n#define QMODE_3BIT 1\n#define QMODE_4BIT 1\n#define QMODE_5BIT 1\n#define QMODE_6BIT 0\n#define QMODE_8BIT 0\n\n\n#endif\n"
  },
  {
    "path": "server/exllamav2_kernels/exllamav2_kernels/cpp/util.h",
    "content": "#ifndef _util_h\n#define _util_h\n\n#define DBGS(__x) printf(\"%s\\n\", __x)\n#define DBGI(__x) printf(\"%s: %i\\n\", #__x, __x)\n#define DBGI2(__x, __y) printf(\"%s, %s: %i, %i\\n\", #__x, #__y, __x, __y)\n#define DBGI3(__x, __y, __z) printf(\"%s, %s, %s: %i, %i, %i\\n\", #__x, #__y, #__z, __x, __y, __z)\n#define DBGF(__x) printf(\"%s: %f\\n\", #__x, __x)\n#define DBGF2(__x, __y) printf(\"%s, %s: %f, %f\\n\", #__x, #__y, __x, __y)\n#define DBGF3(__x, __y, __z) printf(\"%s, %s, %s: %f, %f, %f\\n\", #__x, #__y, #__z, __x, __y, __z)\n\n#endif\n"
  },
  {
    "path": "server/exllamav2_kernels/exllamav2_kernels/cuda/compat.cuh",
    "content": "#ifndef _compat_cuh\n#define _compat_cuh\n\n// atomicAdd for half types, to support CC < 7.x\n\n__device__ __forceinline__ void atomicAdd_half(half* address, half val)\n{\n    unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));\n    unsigned int old = *address_as_ui;\n    unsigned int assumed;\n\n    do\n    {\n        assumed = old;\n        __half_raw hsum;\n        hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);\n        half tmpres = __hadd(hsum, val);\n        hsum = __half_raw(tmpres);\n        old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;\n        old = atomicCAS(address_as_ui, assumed, old);\n    }\n    while (assumed != old);\n}\n\n// atomicAdd for half2 types\n\n__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)\n{\n    unsigned int* address_as_ui = (unsigned int*)address;\n    unsigned int old = *address_as_ui;\n    unsigned int assumed;\n    do\n    {\n        assumed = old;\n        half2 old_val = *((half2*)&old);\n        half2 new_val = __hadd2(old_val, val);\n        old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));\n    }\n    while (assumed != old);\n}\n\n//\n\n#if defined(__CUDA_ARCH__) || defined(USE_ROCM)\n#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)\n\n__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }\n\n#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)\n__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }\n#endif\n\n#endif\n#endif\n\n#endif\n"
  },
  {
    "path": "server/exllamav2_kernels/exllamav2_kernels/cuda/matrix_view.cuh",
    "content": "#ifndef _matrix_view_cuh\n#define _matrix_view_cuh\n\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n\n#include \"quant/qdq_util.cuh\"\n\nclass MatrixView_half\n{\npublic:\n    const half* data;\n    const int height;\n    const int width;\n\n    __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)\n        : data(data), height(height), width(width)\n    { }\n\n    __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }\n    __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }\n    __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }\n    __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }\n\n    __device__ __forceinline__ void item4(half (&items)[4], int row, int column) const\n    {\n        half2* ptr = (half2*) item_ptr(row, column);\n        half2 i01 = ptr[0];\n        half2 i23 = ptr[1];\n        items[0] = __low2half(i01);\n        items[1] = __high2half(i01);\n        items[2] = __low2half(i23);\n        items[3] = __high2half(i23);\n    }\n    __device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const\n    {\n        half2* ptr = (half2*)item_ptr(row, column);\n        half2 i01 = ptr[0];\n        half2 i23 = ptr[1];\n        items[0] = __half2float(__low2half(i01));\n        items[1] = __half2float(__high2half(i01));\n        items[2] = __half2float(__low2half(i23));\n        items[3] = __half2float(__high2half(i23));\n    }\n\n    __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const\n    {\n        half2* ptr = (half2*)item_ptr(row, column);\n        half2 i01 = ptr[0];\n        half2 i23 = ptr[1];\n        items[0] = __half2half2(__low2half(i01));\n        items[1] = 
__half2half2(__high2half(i01));\n        items[2] = __half2half2(__low2half(i23));\n        items[3] = __half2half2(__high2half(i23));\n    }\n};\n\nclass MatrixView_half_rw\n{\npublic:\n    half* data;\n    const int height;\n    const int width;\n\n    __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)\n        : data(data), height(height), width(width)\n    { }\n\n    __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }\n    __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }\n    __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }\n    __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }\n    __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }\n\n    __device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3)\n    {\n        half2 v01 = __halves2half2(v0, v1);\n        half2 v23 = __halves2half2(v2, v3);\n        half2* ptr = (half2*) item_ptr(row, column);\n        ptr[0] = v01;\n        ptr[1] = v23;\n    }\n};\n\nclass MatrixView_q4_row\n{\npublic:\n    const uint32_t* data;\n    const int height;\n    const int width;\n\n    __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)\n        : data(data), height(height), width(width)\n    { }\n\n    __device__ __forceinline__ int item(int row, int column) const\n    {\n        int shift = (column & 0x07) * 4;\n        return (data[row * width / 8 + column / 8] >> shift) & 0x0f;\n    }\n\n    __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const\n    {\n        int shift = (column & 0x07) * 4;\n        uint32_t d = data[row * width / 8 + column / 8] >> 
shift;\n        items[0] = d & 0x0f;\n        items[1] = (d >> 4) & 0x0f;\n    }\n\n    __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const\n    {\n        int shift = (column & 0x07) * 4;\n        uint32_t d = data[row * width / 8 + column / 8] >> shift;\n        items[0] = d & 0x0f;\n        items[1] = (d >> 4) & 0x0f;\n        items[2] = (d >> 8) & 0x0f;\n        items[3] = (d >> 12) & 0x0f;\n    }\n};\n\n#endif\n"
  },
  {
    "path": "server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu",
    "content": "#include \"q_gemm.cuh\"\n#include \"util.cuh\"\n#include \"matrix_view.cuh\"\n#include \"../config.h\"\n\n#include \"quant/qdq_2.cuh\"\n#include \"quant/qdq_3.cuh\"\n#include \"quant/qdq_4.cuh\"\n#include \"quant/qdq_5.cuh\"\n#include \"quant/qdq_6.cuh\"\n#include \"quant/qdq_8.cuh\"\n\n#define GPTQ_BLOCK_KN_SIZE 128\n#define GPTQ_BLOCK_M_SIZE_MAX 8\n#define GPTQ_MAX_GROUPS_IN_BLOCK (GPTQ_BLOCK_KN_SIZE / 32)\n\n#define EXL2_BLOCK_KN_SIZE 64\n#define EXL2_BLOCK_M_SIZE_MAX 8\n#define EXL2_MAX_GROUPS_IN_BLOCK (EXL2_BLOCK_KN_SIZE / 32)\n\n#define CLEAR_N_SIZE 256\n\n#include \"q_gemm_kernel.cuh\"\n#include \"q_gemm_kernel_gptq.cuh\"\n\nvoid gemm_half_q_half_cuda_part\n(\n    const half* a,\n    QMatrix* b,\n    half* c,\n    int size_m,\n    int size_n,\n    int size_k,\n    int m_count,\n    bool clear,\n    const half* r_weights,\n    int r_weights_stride,\n    bool mul_r_weights\n)\n{\n    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n    if (!b->is_gptq)\n    {\n        dim3 blockDim, gridDim;\n        blockDim.x = EXL2_BLOCK_KN_SIZE;\n        blockDim.y = 1;\n        blockDim.z = 1;\n        gridDim.x = DIVIDE(size_n, EXL2_BLOCK_KN_SIZE * 4);\n        gridDim.y = DIVIDE(size_m, m_count);\n        gridDim.z = DIVIDE(size_k, EXL2_BLOCK_KN_SIZE);\n\n        fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count, r_weights != NULL, mul_r_weights);\n\n        kernel<<<gridDim, blockDim, 0, stream>>>\n        (\n            a,\n            b->cuda_q_weight,\n            b->cuda_q_scale,\n            b->cuda_q_scale_max,\n            c,\n            size_m,\n            size_n,\n            size_k,\n            b->groups,\n            b->cuda_q_group_map,\n            b->cuda_q_perm,\n            b->rows_8,\n            b->rows_6,\n            b->rows_5,\n            b->rows_4,\n            b->rows_3,\n            b->rows_2,\n            clear,\n            r_weights,\n            r_weights_stride\n        );\n    }\n    
else\n    {\n        dim3 blockDim, gridDim;\n        blockDim.x = GPTQ_BLOCK_KN_SIZE;\n        blockDim.y = 1;\n        blockDim.z = 1;\n        gridDim.x = DIVIDE(size_n, GPTQ_BLOCK_KN_SIZE * 4);\n        gridDim.y = DIVIDE(size_m, m_count);\n        gridDim.z = DIVIDE(size_k, GPTQ_BLOCK_KN_SIZE);\n\n        fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(m_count, r_weights != NULL, mul_r_weights);\n\n//         DBGX((uint64_t) r_weights);\n//         if (r_weights)\n//             print_global_mem(r_weights, 1, 1, 1);\n//         DBGI(r_weights_stride);\n\n        kernel<<<gridDim, blockDim, 0, stream>>>\n        (\n            a,\n            b->cuda_q_weight,\n            b->cuda_gptq_qzeros,\n            b->cuda_gptq_scales,\n            c,\n            size_m,\n            size_n,\n            size_k,\n            b->groups,\n            b->gptq_groupsize,\n            b->cuda_q_perm,\n            b->rows_4,\n            clear,\n            r_weights,\n            r_weights_stride\n        );\n    }\n}\n\nvoid gemm_half_q_half_cuda\n(\n    cublasHandle_t cublas_handle,\n    const half* a,\n    QMatrix* b,\n    half* c,\n    int size_m,\n    int size_n,\n    int size_k,\n    bool clear,\n    half* temp_dq,\n    bool force_cuda,\n    const half* r_weights,\n    const int r_weights_stride,\n    bool mul_r_weights\n)\n{\n    if (size_m > MAX_Q_GEMM_ROWS && !force_cuda)\n    {\n        // Reconstruct FP16 matrix, then cuBLAS\n\n        if (!temp_dq) temp_dq = b->temp_dq;\n        b->reconstruct(temp_dq);\n\n        //cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH);\n\n        const half alpha = __float2half(1.0f);\n        const half beta = clear ? 
__float2half(0.0f) : __float2half(1.0f);\n        cublasHgemm(cublas_handle,\n                    CUBLAS_OP_N,\n                    CUBLAS_OP_N,\n                    size_n, size_m, size_k,\n                    &alpha, temp_dq, size_n,\n                            a,       size_k,\n                    &beta,  c,       size_n);\n\n        //const float alpha = 1.0f;\n        //const float beta = clear ? 0.0f : 1.0f;\n        //cublasSgemmEx(cublas_handle,\n        //             CUBLAS_OP_N,\n        //             CUBLAS_OP_N,\n        //             size_n, size_m, size_k,\n        //             &alpha, temp_dq, CUDA_R_16F, size_n,\n        //                     a,       CUDA_R_16F, size_k,\n        //             &beta,  c,       CUDA_R_16F, size_n);\n\n        //const float alpha = 1.0f;\n        //const float beta = clear ? 0.0f : 1.0f;\n        //cublasGemmEx(cublas_handle,\n        //             CUBLAS_OP_N, CUBLAS_OP_N,\n        //             size_n, size_m, size_k,\n        //             &alpha, temp_dq, CUDA_R_16F, size_n,\n        //                     a,       CUDA_R_16F, size_k,\n        //             &beta,  c,       CUDA_R_16F, size_n,\n        //             CUDA_R_16F, CUBLAS_GEMM_DFALT_TENSOR_OP);\n    }\n    else\n    {\n        // Quantized matmul\n\n        int block_m_size_max = b->is_gptq ? 
GPTQ_BLOCK_M_SIZE_MAX : EXL2_BLOCK_M_SIZE_MAX;\n        int max_chunks = size_m / block_m_size_max;\n        int last_chunk = max_chunks * block_m_size_max;\n        int last_chunk_size = size_m - last_chunk;\n\n        if (max_chunks)\n        {\n            gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, block_m_size_max, clear, r_weights, r_weights_stride, mul_r_weights);\n        }\n\n        if (last_chunk_size)\n        {\n            gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear, r_weights, r_weights_stride, mul_r_weights);\n        }\n    }\n}\n\n__global__ void clear_kernel\n(\n    half* __restrict__ c,\n    const int size_m,\n    const int size_n\n)\n{\n    int m = blockIdx.y;\n    int n = (blockIdx.x * CLEAR_N_SIZE + threadIdx.x) * 8;\n    if (n >= size_n) return;\n    int4* c_ptr = (int4*)(c + m * size_n + n);\n    *c_ptr = {};\n}\n\nvoid clear_tensor_cuda\n(\n    half* c,\n    int size_m,\n    int size_n\n)\n{\n//     dim3 blockDim, gridDim;\n//     blockDim.x = CLEAR_N_SIZE;\n//     blockDim.y = 1;\n//     gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);\n//     gridDim.y = size_m;\n//     clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);\n}\n"
  },
  {
    "path": "server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cuh",
    "content": "#ifndef _q_gemm_cuh\n#define _q_gemm_cuh\n\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n#include <cstdint>\n#include <cstdio>\n#include <ATen/cuda/CUDAContext.h>\n\n#include \"q_matrix.cuh\"\n\nvoid gemm_half_q_half_cuda\n(\n    cublasHandle_t cublas_handle,\n    const half* a,\n    QMatrix* b,\n    half* c,\n    int size_m,\n    int size_n,\n    int size_k,\n    bool clear = false,\n    half* reconstruct = NULL,\n    bool force_cuda = false,\n    const half* r_weights = NULL,\n    const int r_weights_stride = 0,\n    bool mul_r_weights = false\n);\n\nvoid clear_tensor_cuda\n(\n    half* c,\n    int size_m,\n    int size_n\n);\n\n#endif\n"
  },
  {
    "path": "server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel.cuh",
    "content": "#include \"compat.cuh\"\n\n__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)\n{\n    half2 result = {};\n    const half2* a2_ptr = (const half2*)a_ptr;\n    #pragma unroll\n    for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);\n    return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);\n}\n\n__forceinline__ __device__ half2 dot22_16(half2(&dq)[8], const half* a_ptr, const half2 g_result, const half qs_h)\n{\n    half2 result = {};\n    const half2* a2_ptr = (const half2*)a_ptr;\n    #pragma unroll\n    for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);\n    return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);\n}\n\n__forceinline__ __device__ half2 dot22_32(half2(&dq)[16], const half* a_ptr, const half2 g_result, const half qs_h)\n{\n    half2 result = {};\n    const half2* a2_ptr = (const half2*)a_ptr;\n    #pragma unroll\n    for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);\n    return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);\n}\n\n__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr, const float g_result, const float qs_f)\n{\n    half2 result = {};\n    const half2* a2_ptr = (const half2*)a_ptr;\n    #pragma unroll\n    for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);\n    float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));\n    return fma(result_f, qs_f, g_result);\n}\n\n__forceinline__ __device__ float dot22_16_f(half2(&dq)[8], const half* a_ptr, const float g_result, const float qs_f)\n{\n    half2 result = {};\n    const half2* a2_ptr = (const half2*)a_ptr;\n    #pragma unroll\n    for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);\n    float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));\n    return fma(result_f, qs_f, 
g_result);\n}\n\n__forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, const float g_result, const float qs_f)\n{\n    half2 result = {};\n    const half2* a2_ptr = (const half2*)a_ptr;\n    #pragma unroll\n    for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);\n    float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));\n    return fma(result_f, qs_f, g_result);\n}\n\n__forceinline__ __device__ half dot22_8_h(half2(&dq)[4], const half* a_ptr, const half g_result, const half qs_h)\n{\n    // Use FP32 accumulator to avoid potential overflow since unscaled weights are in the range -128..127\n\n    float result = {};\n    #pragma unroll\n    for (int i = 0; i < 4; i++)\n    {\n        half2 w01 = dq[i];\n        float w0 = __low2float(w01);\n        float w1 = __high2float(w01);\n        float x0 = __half2float(*a_ptr++);\n        float x1 = __half2float(*a_ptr++);\n        result = fma(w0, x0, result);\n        result = fma(w1, x1, result);\n    }\n    float qs = __half2float(qs_h);\n    result *= qs;\n    half result_h = __float2half_rn(result);\n    return __hadd(result_h, g_result);\n}\n\n__forceinline__ __device__ half dot22_16_h(half2(&dq)[8], const half* a_ptr, const half g_result, const half qs_h)\n{\n    half2 result = {};\n    const half2* a2_ptr = (const half2*)a_ptr;\n    #pragma unroll\n    for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);\n    half result_h = __hadd(__low2half(result), __high2half(result));\n    return __hfma(result_h, qs_h, g_result);\n}\n\n__forceinline__ __device__ half dot22_32_h(half2(&dq)[16], const half* a_ptr, const half g_result, const half qs_h)\n{\n    half2 result = {};\n    const half2* a2_ptr = (const half2*)a_ptr;\n    #pragma unroll\n    for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);\n    half result_h = __hadd(__low2half(result), __high2half(result));\n    return __hfma(result_h, qs_h, 
g_result);\n}\n\n\ntypedef void (*fp_gemm_half_q_half_kernel)\n(\n    const half*,\n    const uint32_t*,\n    const uint32_t*,\n    const half*,\n    half*,\n    const int,\n    const int,\n    const int,\n    const int,\n    const uint16_t*,\n    const uint16_t*,\n    const int,\n    const int,\n    const int,\n    const int,\n    const int,\n    const int,\n    const bool,\n    const half*,\n    const int\n);\n\ntemplate <int m_count, bool use_r_weights, bool mul_r_weights>\n__global__ void gemm_half_q_half_kernel\n(\n    const half*      __restrict__ a,\n    const uint32_t*  __restrict__ b_q_weight,\n    const uint32_t*  __restrict__ b_q_scale,\n    const half*      __restrict__ b_q_scale_max,\n    half*            __restrict__ c,\n    const int size_m,\n    const int size_n,\n    const int size_k,\n    const int groups,\n    const uint16_t* __restrict__ b_q_group_map,\n    const uint16_t* __restrict__ b_q_perm,\n    const int rows_8,\n    const int rows_6,\n    const int rows_5,\n    const int rows_4,\n    const int rows_3,\n    const int rows_2,\n    const bool clear,\n    const half* r_weights,\n    const int r_weights_stride\n)\n{\n    MatrixView_half a_(a, size_m, size_k);\n    MatrixView_half_rw c_(c, size_m, size_n);\n    MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);\n\n    int t = threadIdx.x;\n\n    // Block\n\n    int offset_n = blockIdx.x * EXL2_BLOCK_KN_SIZE * 4;\n    int offset_m = blockIdx.y * m_count;\n    int offset_k = blockIdx.z * EXL2_BLOCK_KN_SIZE;\n\n    int end_n = min(offset_n + EXL2_BLOCK_KN_SIZE * 4, size_n);\n    int end_m = min(offset_m + m_count, size_m);\n    int end_k = min(offset_k + EXL2_BLOCK_KN_SIZE, size_k);\n    int n = offset_n + t * 4;\n\n    // Read weights\n\n    half_uint16 weights[MAX_Q_GEMM_WEIGHTS];\n    if constexpr (use_r_weights)\n    {\n        uint16_t any_w = 0;\n        const half* w_ptr = r_weights;\n        for (int m = 0; m < m_count; ++m)\n        {\n            weights[m].as_half = *w_ptr;\n      
      w_ptr += r_weights_stride;\n            any_w |= weights[m].as_uint16;\n        }\n        if (!any_w) return;  // Early exit if all weights are zero -- does not zero output (!!!)\n    }\n\n    // Preload block_a\n\n    __shared__ half block_a[m_count][EXL2_BLOCK_KN_SIZE];\n\n    if (offset_k + t < end_k)\n    {\n        for (int m = 0; m < m_count; ++m)\n        {\n            const half* a_ptr = a_.item_ptr(offset_m + m, 0);\n            half* block_a_ptr = block_a[m];\n            half a0 = a_ptr[b_q_perm[offset_k + t]];\n//            half a0 = a_ptr[offset_k + t];\n            block_a_ptr[t] = a0;\n        }\n    }\n\n    // Clear\n\n    if (n >= size_n) return;\n\n    if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0)\n    {\n        for (int m = 0; m < m_count; m++)\n            *((uint64_t*) c_.item_ptr(offset_m + m, n)) = 0;\n    }\n\n    __syncthreads();\n\n    // Find initial group\n\n    //int group = offset_k / groupsize;\n    int group = b_q_group_map[offset_k * 2];\n\n//    if (offset_m == 0 && t == 0)\n//        DBGI2(offset_k, group);\n\n    // Preload scales\n\n    half scales[EXL2_MAX_GROUPS_IN_BLOCK][4];\n\n    //int groups_in_block = DIVIDE((end_k - offset_k), groupsize);\n    int temp_k = offset_k;\n    for (int g = 0; temp_k < end_k; g++)\n    {\n        int qscales[4];\n        b_q_scale_.item4(qscales, group + g, n);\n        qscales[0]++;\n        qscales[1]++;\n        qscales[2]++;\n        qscales[3]++;\n        half maxscale = b_q_scale_max[group + g];\n        scales[g][0] = __hmul(__int2half_rn(qscales[0] * qscales[0]), maxscale);\n        scales[g][1] = __hmul(__int2half_rn(qscales[1] * qscales[1]), maxscale);\n        scales[g][2] = __hmul(__int2half_rn(qscales[2] * qscales[2]), maxscale);\n        scales[g][3] = __hmul(__int2half_rn(qscales[3] * qscales[3]), maxscale);\n        temp_k += b_q_group_map[temp_k * 2 + 1];\n    }\n\n    // a, b offset\n\n    int pre_rows_8 = min(rows_8, offset_k);\n    int pre_rows_6 = 
offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;\n    int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;\n    int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;\n    int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;\n    int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;\n    int qk = 0;\n    qk += pre_rows_8 / 32 * 8;\n    qk += pre_rows_6 / 32 * 6;\n    qk += pre_rows_5 / 32 * 5;\n    qk += pre_rows_4 / 32 * 4;\n    qk += pre_rows_3 / 32 * 3;\n    qk += pre_rows_2 / 32 * 2;\n\n    const uint32_t* b_ptr = b_q_weight + qk * size_n + n;\n    const half* a_ptr = &block_a[0][0];\n    int a_stride = EXL2_BLOCK_KN_SIZE;\n\n    // Initial group\n\n    int scales_idx = 0;\n    half qs_h0 = scales[scales_idx][0];\n    half qs_h1 = scales[scales_idx][1];\n    half qs_h2 = scales[scales_idx][2];\n    half qs_h3 = scales[scales_idx][3];\n    int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1];\n\n    // Column result\n\n    half block_c[m_count][4] = {};\n\n    // Dequantize groups\n\n    int k = offset_k;\n\n    while (k < rows_8 && k < end_k)\n    {\n        if (k == nextgroup)\n        {\n            group++;\n            scales_idx++;\n            qs_h0 = scales[scales_idx][0];\n            qs_h1 = scales[scales_idx][1];\n            qs_h2 = scales[scales_idx][2];\n            qs_h3 = scales[scales_idx][3];\n            nextgroup += b_q_group_map[k * 2 + 1];\n        }\n\n        #pragma unroll\n        for (int j = 0; j < 4; j++)\n        {\n            int4 load_int4[2];\n            load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;\n            load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;\n\n            half2 dq[4][4];\n            dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n);\n            dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n);\n            dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n);\n            
dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n);\n\n            for (int m = 0; m < m_count; m++)\n            {\n                if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }\n                block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);\n                block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);\n                block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);\n                block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);\n            }\n            a_ptr += 8;\n        }\n        k += 32;\n    }\n\n    while (k < rows_6 && k < end_k)\n    {\n        if (k == nextgroup)\n        {\n            group++;\n            scales_idx++;\n            qs_h0 = scales[scales_idx][0];\n            qs_h1 = scales[scales_idx][1];\n            qs_h2 = scales[scales_idx][2];\n            qs_h3 = scales[scales_idx][3];\n            nextgroup += b_q_group_map[k * 2 + 1];\n        }\n\n        #pragma unroll\n        for (int j = 0; j < 2; j++)\n        {\n            int4 load_int4[3];\n            load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;\n            load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;\n            load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;\n\n            half2 dq[4][8];\n            dequant_6bit_16(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n);\n            dequant_6bit_16(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n);\n            dequant_6bit_16(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n);\n            dequant_6bit_16(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n);\n\n            for (int m = 0; m < m_count; m++)\n            {\n                if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }\n                block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);\n          
      block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);\n                block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);\n                block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);\n            }\n            a_ptr += 16;\n        }\n        k += 32;\n    }\n\n    while (k < rows_5 && k < end_k)\n    {\n        if (k == nextgroup)\n        {\n            group++;\n            scales_idx++;\n            qs_h0 = scales[scales_idx][0];\n            qs_h1 = scales[scales_idx][1];\n            qs_h2 = scales[scales_idx][2];\n            qs_h3 = scales[scales_idx][3];\n            nextgroup += b_q_group_map[k * 2 + 1];\n        }\n\n        #pragma unroll\n        for (int j = 0; j < 1; j++)\n        {\n            int4 load_int4[5];\n            load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;\n            load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;\n            load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;\n            load_int4[3] = *((int4*) b_ptr); b_ptr += size_n;\n            load_int4[4] = *((int4*) b_ptr); b_ptr += size_n;\n\n            half2 dq[4][16];\n            dequant_5bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, load_int4[3].x, load_int4[4].x, dq[0], size_n);\n            dequant_5bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, load_int4[3].y, load_int4[4].y, dq[1], size_n);\n            dequant_5bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, load_int4[3].z, load_int4[4].z, dq[2], size_n);\n            dequant_5bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, load_int4[3].w, load_int4[4].w, dq[3], size_n);\n\n            for (int m = 0; m < m_count; m++)\n            {\n                if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }\n                block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);\n                block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * 
a_stride, block_c[m][1], qs_h1);\n                block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);\n                block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);\n            }\n            a_ptr += 32;\n        }\n\n        k += 32;\n    }\n\n    while (k < rows_4 && k < end_k)\n    {\n        if (k == nextgroup)\n        {\n            group++;\n            scales_idx++;\n            qs_h0 = scales[scales_idx][0];\n            qs_h1 = scales[scales_idx][1];\n            qs_h2 = scales[scales_idx][2];\n            qs_h3 = scales[scales_idx][3];\n            nextgroup += b_q_group_map[k * 2 + 1];\n        }\n\n        #pragma unroll\n        for (int j = 0; j < 4; j++)\n        {\n            int4 load_int4[1];\n            load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;\n\n            half2 dq[4][4];\n            dequant_4bit_8(load_int4[0].x, dq[0], size_n);\n            dequant_4bit_8(load_int4[0].y, dq[1], size_n);\n            dequant_4bit_8(load_int4[0].z, dq[2], size_n);\n            dequant_4bit_8(load_int4[0].w, dq[3], size_n);\n\n            for (int m = 0; m < m_count; m++)\n            {\n                if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }\n                block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);\n                block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);\n                block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);\n                block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);\n            }\n            a_ptr += 8;\n        }\n        k += 32;\n    }\n\n    while (k < rows_3 && k < end_k)\n    {\n        if (k == nextgroup)\n        {\n            group++;\n            scales_idx++;\n            qs_h0 = scales[scales_idx][0];\n            qs_h1 = scales[scales_idx][1];\n            qs_h2 = scales[scales_idx][2];\n       
     qs_h3 = scales[scales_idx][3];\n            nextgroup += b_q_group_map[k * 2 + 1];\n        }\n\n        #pragma unroll\n        for (int j = 0; j < 1; j++)\n        {\n            int4 load_int4[3];\n            load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;\n            load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;\n            load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;\n\n            half2 dq[4][16];\n            dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n);\n            dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n);\n            dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n);\n            dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n);\n\n            for (int m = 0; m < m_count; m++)\n            {\n                if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }\n                block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);\n                block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);\n                block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);\n                block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);\n            }\n            a_ptr += 32;\n        }\n        k += 32;\n    }\n\n    while (k < rows_2 && k < end_k)\n    {\n        if (k == nextgroup)\n        {\n            group++;\n            scales_idx++;\n            qs_h0 = scales[scales_idx][0];\n            qs_h1 = scales[scales_idx][1];\n            qs_h2 = scales[scales_idx][2];\n            qs_h3 = scales[scales_idx][3];\n            nextgroup += b_q_group_map[k * 2 + 1];\n        }\n\n        #pragma unroll\n        for (int j = 0; j < 1; j++)\n        {\n            int4 load_int4[1];\n            load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;\n\n            half2 dq[4][8];\n            
dequant_2bit_16(load_int4[0].x, dq[0], size_n);\n            dequant_2bit_16(load_int4[0].y, dq[1], size_n);\n            dequant_2bit_16(load_int4[0].z, dq[2], size_n);\n            dequant_2bit_16(load_int4[0].w, dq[3], size_n);\n\n            for (int m = 0; m < m_count; m++)\n            {\n                if constexpr (use_r_weights) { if (!weights[m].as_uint16) continue; }\n                block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);\n                block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);\n                block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);\n                block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);\n            }\n\n            a_ptr += 16;\n        }\n        k += 16;\n    }\n\n    // Accumulate column sums in c\n\n    for (int m = 0; m < m_count; m++)\n    {\n        half2* out = (half2*)c_.item_ptr(offset_m + m, n);\n        half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]);\n        half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]);\n\n        if constexpr (mul_r_weights)\n        {\n            half2 w_mul2 = __half2half2(weights[m].as_half);\n            result01 = __hmul2(result01, w_mul2);\n            result23 = __hmul2(result23, w_mul2);\n        }\n\n        atomicAdd(out    , result01);\n        atomicAdd(out + 1, result23);\n//        *out = result01;\n//        *(out + 1) = result23;\n    }\n}\n\ntemplate <bool use_r_weights, bool mul_r_weights>\nstruct map_m_count_exl2 {\n    static constexpr fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count)\n    {\n        #if EXL2_BLOCK_M_SIZE_MAX >= 1\n        if (m_count == 1) return gemm_half_q_half_kernel<1, use_r_weights, mul_r_weights>;\n        #endif\n        #if EXL2_BLOCK_M_SIZE_MAX >= 2\n        if (m_count == 2) return gemm_half_q_half_kernel<2, use_r_weights, mul_r_weights>;\n        #endif\n      
  #if EXL2_BLOCK_M_SIZE_MAX >= 3\n        if (m_count == 3) return gemm_half_q_half_kernel<3, use_r_weights, mul_r_weights>;\n        #endif\n        #if EXL2_BLOCK_M_SIZE_MAX >= 4\n        if (m_count == 4) return gemm_half_q_half_kernel<4, use_r_weights, mul_r_weights>;\n        #endif\n        #if EXL2_BLOCK_M_SIZE_MAX >= 5\n        if (m_count == 5) return gemm_half_q_half_kernel<5, use_r_weights, mul_r_weights>;\n        #endif\n        #if EXL2_BLOCK_M_SIZE_MAX >= 6\n        if (m_count == 6) return gemm_half_q_half_kernel<6, use_r_weights, mul_r_weights>;\n        #endif\n        #if EXL2_BLOCK_M_SIZE_MAX >= 7\n        if (m_count == 7) return gemm_half_q_half_kernel<7, use_r_weights, mul_r_weights>;\n        #endif\n        #if EXL2_BLOCK_M_SIZE_MAX >= 8\n        if (m_count == 8) return gemm_half_q_half_kernel<8, use_r_weights, mul_r_weights>;\n        #endif\n        return NULL;\n    }\n};\n\nfp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count, bool r_weights, bool mul_r_weights)\n{\n    if (!r_weights && !mul_r_weights) return map_m_count_exl2<false, false>::pick_gemm_half_q_half_kernel(m_count);\n    if (!r_weights &&  mul_r_weights) return map_m_count_exl2<false,  true>::pick_gemm_half_q_half_kernel(m_count);\n    if ( r_weights && !mul_r_weights) return map_m_count_exl2< true, false>::pick_gemm_half_q_half_kernel(m_count);\n    if ( r_weights &&  mul_r_weights) return map_m_count_exl2< true,  true>::pick_gemm_half_q_half_kernel(m_count);\n    return NULL;\n}\n"
  },
  {
    "path": "server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm_kernel_gptq.cuh",
    "content": "#include \"compat.cuh\"\n\n__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result)\n{\n    half2 result = {};\n    const half2* a2_ptr = (const half2*)a_ptr;\n    #pragma unroll\n    for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);\n    return __hadd2(result, g_result);\n}\n\n__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)\n{\n    half2 result = {};\n    const half2* a2_ptr = (const half2*)a_ptr;\n    #pragma unroll\n    for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);\n    return __half2float(__low2half(result)) + __half2float(__high2half(result));\n}\n\n__forceinline__ __device__ half2 dot22_8_h2(half2(&dq)[4], const half* a_ptr)\n{\n    half2 result = {};\n    const half2* a2_ptr = (const half2*)a_ptr;\n    #pragma unroll\n    for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);\n    return result;\n}\n\ntypedef void (*fp_gemm_half_q_half_gptq_kernel)\n(\n    const half*,\n    const uint32_t*,\n    const uint32_t*,\n    const half*,\n    half*,\n    const int,\n    const int,\n    const int,\n    const int,\n    const int,\n    const uint16_t*,\n    const int,\n    const bool,\n    const half*,\n    const int\n);\n\ntemplate <int m_count, bool use_r_weights, bool mul_r_weights>\n__global__ void gemm_half_q_half_gptq_kernel\n(\n    const half* __restrict__ a,\n    const uint32_t* __restrict__ b_q_weight,\n    const uint32_t* __restrict__ b_gptq_qzeros,\n    const half* __restrict__ b_gptq_scales,\n    half* __restrict__ c,\n    const int size_m,\n    const int size_n,\n    const int size_k,\n    const int groups,\n    const int groupsize,\n    const uint16_t* __restrict__ b_q_perm,\n    const int rows_4,\n    const bool clear,\n    const half* r_weights,\n    const int r_weights_stride\n)\n{\n    MatrixView_half a_(a, size_m, size_k);\n    MatrixView_half_rw c_(c, size_m, size_n);\n    MatrixView_q4_row 
b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);\n    MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);\n\n    int t = threadIdx.x;\n\n    // Block\n\n    int offset_n = blockIdx.x * GPTQ_BLOCK_KN_SIZE * 4;\n    int offset_m = blockIdx.y * m_count;\n    int offset_k = blockIdx.z * GPTQ_BLOCK_KN_SIZE;\n\n    int end_n = min(offset_n + GPTQ_BLOCK_KN_SIZE * 4, size_n);\n    int end_m = min(offset_m + m_count, size_m);\n    int end_k = min(offset_k + GPTQ_BLOCK_KN_SIZE, size_k);\n\n    int n = offset_n + t * 4;\n\n    // Read weights\n\n    half_uint16 weights[MAX_Q_GEMM_WEIGHTS];\n    if constexpr (use_r_weights)\n    {\n        uint16_t any_w = 0;\n        const half* w_ptr = r_weights;\n        for (int m = 0; m < m_count; ++m)\n        {\n            weights[m].as_half = *w_ptr;\n            w_ptr += r_weights_stride;\n            any_w |= weights[m].as_uint16;\n        }\n        if (!any_w) return;  // Early exit if all weights are zero -- does not zero output (!!!)\n    }\n\n    // Preload block_a\n\n    __shared__ half block_a[m_count][GPTQ_BLOCK_KN_SIZE];\n\n    if (offset_k + t < end_k)\n    {\n        for (int m = 0; m < m_count; ++m)\n        {\n            const half* a_ptr = a_.item_ptr(offset_m + m, 0);\n            half* block_a_ptr = block_a[m];\n\n            half a0;\n            if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];\n            else a0 = a_ptr[offset_k + t];\n            block_a_ptr[t] = a0;\n        }\n    }\n\n    // Zero output\n\n    if (n >= size_n) return;\n\n    if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0)\n    {\n        for (int m = 0; m < m_count; m++)\n            *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;\n    }\n\n    __syncthreads();\n\n    // Find initial group\n\n    int group = offset_k / groupsize;\n    int nextgroup = offset_k + groupsize;\n\n    // a, b offset\n\n    int qk = offset_k / (32 / 4);\n\n    const uint32_t* b_ptr = b_q_weight + qk * size_n + n;\n    const half* a_ptr = 
&block_a[0][0];\n    int a_stride = GPTQ_BLOCK_KN_SIZE;\n\n    // Initial group\n\n    int zeros[4];\n    half2 scales[4];\n    half2 z1z16[4][2];\n    half2 y1y16[4][2];\n    b_gptq_qzeros_.item4(zeros, group, n);\n    b_gptq_scales_.item4_h2(scales, group, n);\n    dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]);\n    dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]);\n    dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]);\n    dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]);\n\n//    __syncthreads();\n\n    // Column result\n\n    half2 block_c[m_count][4] = {};\n\n    // Dequantize and multiply\n\n    int k = offset_k;\n    while (k < end_k)\n    {\n        if (k == nextgroup)\n        {\n            group++;\n            nextgroup += groupsize;\n            b_gptq_qzeros_.item4(zeros, group, n);\n            b_gptq_scales_.item4_h2(scales, group, n);\n            dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]);\n            dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]);\n            dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]);\n            dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]);\n        }\n\n        #pragma unroll\n        for (int j = 0; j < 4; j++)\n        {\n            const int4* b_ptr4 = (int4*) b_ptr;\n            int4 load_int4 = *b_ptr4;\n\n            half2 dq[4][4];\n            dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);\n            dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);\n            dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);\n            dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);\n\n            #pragma unroll\n            for (int m = 0; m < m_count; m++)\n            {\n                if constexpr (use_r_weights) { if (!weights[m].as_uint16) 
continue; }\n                block_c[m][0] = __hfma2(dot22_8_h2(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);\n                block_c[m][1] = __hfma2(dot22_8_h2(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);\n                block_c[m][2] = __hfma2(dot22_8_h2(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);\n                block_c[m][3] = __hfma2(dot22_8_h2(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);\n            }\n\n            b_ptr += size_n;\n            a_ptr += 8;\n        }\n\n        k += 32;\n    }\n\n    for (int m = 0; m < m_count; m++)\n    {\n        half2 *out = (half2*) c_.item_ptr(offset_m + m, n);\n        half result0 = __hadd(__low2half(block_c[m][0]), __high2half(block_c[m][0]));\n        half result1 = __hadd(__low2half(block_c[m][1]), __high2half(block_c[m][1]));\n        half result2 = __hadd(__low2half(block_c[m][2]), __high2half(block_c[m][2]));\n        half result3 = __hadd(__low2half(block_c[m][3]), __high2half(block_c[m][3]));\n        half2 result01 = __halves2half2(result0, result1);\n        half2 result23 = __halves2half2(result2, result3);\n\n        if constexpr (mul_r_weights)\n        {\n            half2 w_mul2 = __half2half2(weights[m].as_half);\n            result01 = __hmul2(result01, w_mul2);\n            result23 = __hmul2(result23, w_mul2);\n        }\n\n        atomicAdd(out    , result01);\n        atomicAdd(out + 1, result23);\n    }\n}\n\ntemplate <bool use_r_weights, bool mul_r_weights>\nstruct map_m_count_gptq {\n    static constexpr fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(int m_count)\n    {\n        #if GPTQ_BLOCK_M_SIZE_MAX >= 1\n        if (m_count == 1) return gemm_half_q_half_gptq_kernel<1, use_r_weights, mul_r_weights>;\n        #endif\n        #if GPTQ_BLOCK_M_SIZE_MAX >= 2\n        if (m_count == 2) return gemm_half_q_half_gptq_kernel<2, use_r_weights, mul_r_weights>;\n        #endif\n        #if GPTQ_BLOCK_M_SIZE_MAX >= 3\n        if 
(m_count == 3) return gemm_half_q_half_gptq_kernel<3, use_r_weights, mul_r_weights>;\n        #endif\n        #if GPTQ_BLOCK_M_SIZE_MAX >= 4\n        if (m_count == 4) return gemm_half_q_half_gptq_kernel<4, use_r_weights, mul_r_weights>;\n        #endif\n        #if GPTQ_BLOCK_M_SIZE_MAX >= 5\n        if (m_count == 5) return gemm_half_q_half_gptq_kernel<5, use_r_weights, mul_r_weights>;\n        #endif\n        #if GPTQ_BLOCK_M_SIZE_MAX >= 6\n        if (m_count == 6) return gemm_half_q_half_gptq_kernel<6, use_r_weights, mul_r_weights>;\n        #endif\n        #if GPTQ_BLOCK_M_SIZE_MAX >= 7\n        if (m_count == 7) return gemm_half_q_half_gptq_kernel<7, use_r_weights, mul_r_weights>;\n        #endif\n        #if GPTQ_BLOCK_M_SIZE_MAX >= 8\n        if (m_count == 8) return gemm_half_q_half_gptq_kernel<8, use_r_weights, mul_r_weights>;\n        #endif\n        return NULL;\n    }\n};\n\nfp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(const int m_count, bool r_weights, bool mul_r_weights)\n{\n    if (!r_weights && !mul_r_weights) return map_m_count_gptq<false, false>::pick_gemm_half_q_half_gptq_kernel(m_count);\n    if (!r_weights &&  mul_r_weights) return map_m_count_gptq<false,  true>::pick_gemm_half_q_half_gptq_kernel(m_count);\n    if ( r_weights && !mul_r_weights) return map_m_count_gptq< true, false>::pick_gemm_half_q_half_gptq_kernel(m_count);\n    if ( r_weights &&  mul_r_weights) return map_m_count_gptq< true,  true>::pick_gemm_half_q_half_gptq_kernel(m_count);\n    return NULL;\n}\n"
  },
  {
    "path": "server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cu",
    "content": "#include \"q_matrix.cuh\"\n#include \"matrix_view.cuh\"\n#include \"util.cuh\"\n\n#include \"quant/qdq_2.cuh\"\n#include \"quant/qdq_3.cuh\"\n#include \"quant/qdq_4.cuh\"\n#include \"quant/qdq_5.cuh\"\n#include \"quant/qdq_6.cuh\"\n#include \"quant/qdq_8.cuh\"\n\n#define BLOCK_KN_SIZE 128\n\n#define THREADS_X 32\n#define THREADS_Y 32\n\n// Shuffle quantized data on load\n\n__global__ void shuffle_kernel\n(\n    uint32_t* __restrict__ b_q_weight,\n    const int size_k,\n    const int size_n,\n    const int rows_8,\n    const int rows_6,\n    const int rows_5,\n    const int rows_4,\n    const int rows_3,\n    const int rows_2\n)\n{\n    int n = blockIdx.x * THREADS_X + threadIdx.x;\n    if (n >= size_n) return;\n    int k = 0;\n    uint32_t* b_ptr = b_q_weight + n;\n    while (k < rows_8) { shuffle_8bit_4 (b_ptr, size_n); b_ptr += 1 * size_n; k +=  4; }\n    while (k < rows_6) { shuffle_6bit_16(b_ptr, size_n); b_ptr += 3 * size_n; k += 16; }\n    while (k < rows_5) { shuffle_5bit_32(b_ptr, size_n); b_ptr += 5 * size_n; k += 32; }\n    while (k < rows_4) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k +=  8; }\n    while (k < rows_3) { shuffle_3bit_32(b_ptr, size_n); b_ptr += 3 * size_n; k += 32; }\n    while (k < rows_2) { shuffle_2bit_16(b_ptr, size_n); b_ptr += 1 * size_n; k += 16; }\n}\n\n\n// QMatrix constructor\n\nQMatrix::QMatrix\n(\n    const int _device,\n    const int _height,\n    const int _width,\n    const int _groups,\n\n    uint32_t* _q_weight,\n    uint16_t* _q_perm,\n    uint16_t* _q_invperm,\n    uint32_t* _q_scale,\n    half* _q_scale_max,\n    uint16_t* _q_groups,\n    uint16_t* _q_group_map,\n\n    uint32_t* _gptq_qzeros,\n    half* _gptq_scales,\n    uint32_t* _gptq_g_idx,\n\n    half* _temp_dq\n) :\n    device(_device),\n    height(_height),\n    width(_width),\n    groups(_groups),\n    temp_dq(_temp_dq)\n{\n    cudaSetDevice(device);\n\n    failed = false;\n\n    cuda_q_weight = _q_weight;\n    cuda_q_perm = 
_q_perm;\n    cuda_q_invperm = _q_invperm;\n    cuda_q_scale = _q_scale;\n    cuda_q_scale_max = _q_scale_max;\n    cuda_q_groups = _q_groups;\n    cuda_q_group_map = _q_group_map;\n    cuda_gptq_qzeros = _gptq_qzeros;\n    cuda_gptq_scales = _gptq_scales;\n\n    is_gptq = (_gptq_qzeros != NULL);\n\n    if (is_gptq)\n    {\n        gptq_groupsize = 1;\n        while (gptq_groupsize * groups < height) gptq_groupsize *= 2;\n    }\n\n    // Create group map\n\n    rows_8 = 0;\n    rows_6 = 0;\n    rows_5 = 0;\n    rows_4 = 0;\n    rows_3 = 0;\n    rows_2 = 0;\n\n    if (!is_gptq)\n    {\n        uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t));\n        cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost);\n\n        int row = 0;\n        for (int i = 0; i < groups; i++)\n        {\n            int bits = cpu_q_groups[i * 2];\n\n            int rows;\n            if (i < groups - 1)\n            {\n                int qrows = cpu_q_groups[i * 2 + 3] - cpu_q_groups[i * 2 + 1];\n                rows = qrows * 32 / bits;\n            }\n            else rows = height - row;\n\n            if (bits == 8) rows_8 += rows;\n            if (bits == 6) rows_6 += rows;\n            if (bits == 5) rows_5 += rows;\n            if (bits == 4) rows_4 += rows;\n            if (bits == 3) rows_3 += rows;\n            if (bits == 2) rows_2 += rows;\n            row += rows;\n        }\n\n        free(cpu_q_groups);\n\n        rows_6 += rows_8;\n        rows_5 += rows_6;\n        rows_4 += rows_5;\n        rows_3 += rows_4;\n        rows_2 += rows_3;\n    }\n    else\n    {\n        rows_4 = height;\n        rows_3 = height;\n        rows_2 = height;\n\n        if (_gptq_g_idx)\n        {\n            if (!make_sequential(_gptq_g_idx))\n            {\n                failed = true;\n                //printf(\"FAIL\\n\");\n                return;\n            }\n        }\n    }\n\n//     DBGI(rows_8);\n//     
DBGI(rows_6);\n//     DBGI(rows_5);\n//     DBGI(rows_4);\n//     DBGI(rows_3);\n//     DBGI(rows_2);\n\n    // Shuffle quantized data\n\n    dim3 blockDim, gridDim;\n    blockDim.x = THREADS_X;\n    blockDim.y = 1;\n    gridDim.x = DIVIDE(width, THREADS_X);\n    gridDim.y = 1;\n    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);\n}\n\nQMatrix::~QMatrix()\n{\n}\n\n// Reconstruct b[k,n] (GPTQ)\n\n__global__ void reconstruct_gptq_kernel\n(\n    const uint32_t* __restrict__ b_q_weight,\n    const uint16_t* __restrict__ b_q_perm,\n    const uint32_t* __restrict__ b_gptq_qzeros,\n    const half* __restrict__ b_gptq_scales,\n    //const uint16_t* __restrict__ b_q_groups,\n    const int size_k,\n    const int size_n,\n    const int groupsize,\n    const int groups,\n    half* __restrict__ b,\n    const int rows_4\n)\n{\n    MatrixView_half_rw b_(b, size_k, size_n);\n    MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);\n    MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);\n\n    int offset_k = BLOCK_KN_SIZE * blockIdx.y;\n    int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;\n\n    int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);\n\n    // Preload remapping table\n\n    __shared__ uint16_t perm[BLOCK_KN_SIZE];\n    int t = threadIdx.x;\n\n    if (b_q_perm)\n    {\n        if (offset_k + t < size_k)\n            perm[t] = b_q_perm[offset_k + t];\n    }\n\n    // Column\n\n    int n = offset_n + t * 4;\n    if (n >= size_n) return;\n\n    // Find initial group\n\n    int group = offset_k / groupsize;\n    int nextgroup = offset_k + groupsize;\n\n    // b offset\n\n    int qk = offset_k / (32 / 4);\n\n    const uint32_t* b_ptr = b_q_weight + qk * size_n + n;\n\n    // Initial zeros/scale\n\n    int zeros[4];\n    half2 scales[4];\n    half2 z1z16[4][2];\n    half2 y1y16[4][2];\n    
b_gptq_qzeros_.item4(zeros, group, n);\n    b_gptq_scales_.item4_h2(scales, group, n);\n    dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]);\n    dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]);\n    dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]);\n    dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]);\n\n    __syncthreads();\n\n    int k = offset_k;\n    int lk = 0;\n\n    while (k < end_k)\n    {\n        if (k == nextgroup)\n        {\n            group++;\n            nextgroup += groupsize;\n            b_gptq_qzeros_.item4(zeros, group, n);\n            b_gptq_scales_.item4_h2(scales, group, n);\n            dequant_4bit_8_prep_zero((zeros[0] + 1) & 0x0F, z1z16[0], y1y16[0]);\n            dequant_4bit_8_prep_zero((zeros[1] + 1) & 0x0F, z1z16[1], y1y16[1]);\n            dequant_4bit_8_prep_zero((zeros[2] + 1) & 0x0F, z1z16[2], y1y16[2]);\n            dequant_4bit_8_prep_zero((zeros[3] + 1) & 0x0F, z1z16[3], y1y16[3]);\n        }\n\n        for (int p = 0; p < 4; p++)\n        {\n            half2 dq[4][4];\n            const int4* b_ptr4 = (int4*) b_ptr;\n            int4 load_int4 = *b_ptr4;\n\n            dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);\n            dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);\n            dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);\n            dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);\n\n            b_ptr += size_n;\n            //half* dqh = (half*)dq;\n            if (b_q_perm)\n            {\n                for (int j = 0; j < 4; j++)\n                {\n                    for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);\n                    b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));\n                    b_.set4(perm[lk++], n, 
__high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));\n                }\n            }\n            else\n            {\n                for (int j = 0; j < 4; j++)\n                {\n                    for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);\n                    b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));\n                    b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));\n                }\n            }\n        }\n        k += 32;\n    }\n}\n\n\n// Reconstruct b[k,n]\n\n__global__ void reconstruct_kernel\n(\n    const uint32_t* __restrict__ b_q_weight,\n    const uint16_t* __restrict__ b_q_perm,\n    const uint32_t* __restrict__ b_q_scale,\n    const half* __restrict__ b_q_scale_max,\n    const uint16_t* __restrict__ b_q_group_map,\n    const int size_k,\n    const int size_n,\n    //const int groupsize,\n    const int groups,\n    half* __restrict__ b,\n    const int rows_8,\n    const int rows_6,\n    const int rows_5,\n    const int rows_4,\n    const int rows_3,\n    const int rows_2\n)\n{\n    MatrixView_half_rw b_(b, size_k, size_n);\n    MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);\n\n    int offset_k = BLOCK_KN_SIZE * blockIdx.y;\n    int offset_n = BLOCK_KN_SIZE * blockIdx.x;\n\n    // Preload remapping table\n\n    int t = threadIdx.x;\n    __shared__ uint16_t perm[BLOCK_KN_SIZE];\n    if (offset_k + t < size_k)\n        perm[t] = b_q_perm[offset_k + t];\n\n    // Column\n\n    int n = offset_n + t;\n    if (n >= size_n) return;\n\n    // Find initial group\n\n    // int group = offset_k / groupsize;\n    int group = b_q_group_map[offset_k * 2];\n\n    int pre_rows_8 = min(rows_8, offset_k);\n    int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;\n    int pre_rows_5 = offset_k > rows_6 ? 
min(rows_5, offset_k) - rows_6 : 0;\n    int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;\n    int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;\n    int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;\n    int qk = 0;\n    qk += pre_rows_8 / 32 * 8;\n    qk += pre_rows_6 / 32 * 6;\n    qk += pre_rows_5 / 32 * 5;\n    qk += pre_rows_4 / 32 * 4;\n    qk += pre_rows_3 / 32 * 3;\n    qk += pre_rows_2 / 32 * 2;\n\n    const uint32_t* b_ptr = b_q_weight + qk * size_n + n;\n\n    half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);\n    half2 qs_h2 = __halves2half2(qs_h, qs_h);\n    int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1];\n\n    int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);\n    int k = offset_k;\n    int lk = 0;\n\n    __syncthreads();\n\n    while (k < rows_8 && k < end_k)\n    {\n        if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }\n        for (int p = 0; p < 4; p++)\n        {\n            half2 dq[4];\n            uint32_t q_0 = *b_ptr; b_ptr += size_n;\n            uint32_t q_1 = *b_ptr; b_ptr += size_n;\n            dequant_8bit_8(q_0, q_1, dq, size_n);\n            for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);\n            half* dqh = (half*) dq;\n            for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);\n        }\n        k += 32;\n    }\n\n    while (k < rows_6 && k < end_k)\n    {\n        if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }\n        for (int p = 0; p < 2; p++)\n        {\n            half2 dq[8];\n            uint32_t q_0 = *b_ptr; b_ptr += size_n;\n            uint32_t q_1 = *b_ptr; b_ptr += size_n;\n            uint32_t q_2 = *b_ptr; b_ptr += size_n;\n        
    dequant_6bit_16(q_0, q_1, q_2, dq, size_n);\n            for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);\n            half* dqh = (half*) dq;\n            for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);\n        }\n        k += 32;\n    }\n\n    while (k < rows_5 && k < end_k)\n    {\n        if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }\n        for (int p = 0; p < 1; p++)\n        {\n            half2 dq[16];\n            uint32_t q_0 = *b_ptr; b_ptr += size_n;\n            uint32_t q_1 = *b_ptr; b_ptr += size_n;\n            uint32_t q_2 = *b_ptr; b_ptr += size_n;\n            uint32_t q_3 = *b_ptr; b_ptr += size_n;\n            uint32_t q_4 = *b_ptr; b_ptr += size_n;\n            dequant_5bit_32(q_0, q_1, q_2, q_3, q_4, dq, size_n);\n            for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);\n            half* dqh = (half*) dq;\n            for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);\n        }\n        k += 32;\n    }\n\n    while (k < rows_4 && k < end_k)\n    {\n        if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }\n        for (int p = 0; p < 4; p++)\n        {\n            half2 dq[4];\n            uint32_t q_0 = *b_ptr; b_ptr += size_n;\n            dequant_4bit_8(q_0, dq, size_n);\n            for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);\n            half* dqh = (half*) dq;\n            for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);\n        }\n        k += 32;\n    }\n\n    while (k < rows_3 && k < end_k)\n    {\n        if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }\n        for (int p = 0; p < 
1; p++)\n        {\n            half2 dq[16];\n            uint32_t q_0 = *b_ptr; b_ptr += size_n;\n            uint32_t q_1 = *b_ptr; b_ptr += size_n;\n            uint32_t q_2 = *b_ptr; b_ptr += size_n;\n            dequant_3bit_32(q_0, q_1, q_2, dq, size_n);\n            for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);\n            half* dqh = (half*) dq;\n            for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);\n        }\n        k += 32;\n    }\n\n    while (k < rows_2 && k < end_k)\n    {\n        if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += b_q_group_map[k * 2 + 1]; qs_h2 = __halves2half2(qs_h, qs_h); }\n        for (int p = 0; p < 1; p++)\n        {\n            half2 dq[8];\n            uint32_t q_0 = *b_ptr; b_ptr += size_n;\n            dequant_2bit_16(q_0, dq, size_n);\n            for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);\n            half* dqh = (half*) dq;\n            for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);\n        }\n        k += 16;\n    }\n}\n\nvoid QMatrix::reconstruct(half* out)\n{\n    dim3 blockDim, gridDim;\n    blockDim.x = BLOCK_KN_SIZE;\n    blockDim.y = 1;\n    gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);\n    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    if (!is_gptq)\n    {\n        gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);\n        reconstruct_kernel<<<gridDim, blockDim, 0, stream>>>\n        (\n            cuda_q_weight,\n            cuda_q_perm,\n            cuda_q_scale,\n            cuda_q_scale_max,\n            cuda_q_group_map,\n            height,\n            width,\n            //groupsize,\n            groups,\n            out,\n            rows_8,\n            rows_6,\n            rows_5,\n            rows_4,\n            rows_3,\n            rows_2\n        );\n    }\n    else\n    {\n        gridDim.x = DIVIDE(width, BLOCK_KN_SIZE * 4);\n        
reconstruct_gptq_kernel<<<gridDim, blockDim, 0, stream>>>\n        (\n            cuda_q_weight,\n            cuda_q_perm,\n            cuda_gptq_qzeros,\n            cuda_gptq_scales,\n            //const uint16_t* __restrict__ b_q_groups,\n            height,\n            width,\n            gptq_groupsize,\n            groups,\n            out,\n            rows_4\n        );\n    }\n}\n\n__global__ void make_sequential_kernel\n(\n    const uint32_t* __restrict__ w,\n    uint32_t* __restrict__ w_new,\n    const uint16_t* __restrict__ q_perm,\n    const int w_height,\n    const int w_width\n)\n{\n    const uint64_t* w2 = (uint64_t*) w;\n    uint64_t* w_new2 = (uint64_t*) w_new;\n    int w2_stride = w_width >> 1;\n\n    int w2_column = THREADS_X * blockIdx.x + threadIdx.x;\n    if (w2_column >= w2_stride) return;\n\n    int w_new2_row = blockIdx.y;\n\n    int q_perm_idx = w_new2_row << 3;\n\n    uint64_t dst = 0;\n\n    #pragma unroll\n    for (int i = 0; i < 8; i++)\n    {\n        int source_row = q_perm[q_perm_idx++];\n\n        int w2_row = source_row >> 3;\n        int w2_subrow = source_row & 0x07;\n        int w2_row_shift = w2_subrow << 2;\n        int wnew2_row_shift = i << 2;\n\n        uint64_t src = w2[w2_row * w2_stride + w2_column];\n        src >>= w2_row_shift;\n        src &= 0x0000000f0000000f;\n        src <<= wnew2_row_shift;\n        dst |= src;\n    }\n\n    w_new2[w_new2_row * w2_stride + w2_column] = dst;\n}\n\nbool QMatrix::make_sequential(const uint32_t* cpu_g_idx)\n{\n    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n    uint32_t* cuda_new_qweight = NULL;\n    cudaError_t err = cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));\n    if (err != cudaSuccess) {\n        cudaError_t cuda_status = cudaGetLastError(); // Clear error\n        return false;\n    }\n\n    uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));\n    uint32_t* cpu_x_map = (uint32_t*) malloc(height * 
sizeof(uint32_t));\n    uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));\n\n    // Group histogram\n\n    for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;\n\n    // Group map\n\n    for (int i = 0, acc = 0; i < groups; i++)\n    {\n        short tmp = cpu_g_idx_map[i];\n        cpu_g_idx_map[i] = acc;\n        acc += tmp;\n    }\n\n    // X map (inverse)\n\n    for (int row = 0; row < height; row++)\n    {\n        uint32_t target_group = cpu_g_idx[row];\n        uint32_t target_row = cpu_g_idx_map[target_group];\n        cpu_g_idx_map[target_group]++;\n        cpu_x_map_inv[row] = target_row;\n    }\n\n    // X map\n\n    for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;\n\n    // Reduce to uint16_t\n\n    uint16_t* cpu_x_map16 = (uint16_t*)cpu_x_map;\n    uint16_t* cpu_x_map_inv16 = (uint16_t*)cpu_x_map_inv;\n    for (int row = 0; row < height; row++) cpu_x_map16[row] = (uint16_t) cpu_x_map[row];\n    for (int row = 0; row < height; row++) cpu_x_map_inv16[row] = (uint16_t) cpu_x_map_inv[row];\n\n    // Move to CUDA\n\n    cudaMemcpyAsync(cuda_q_perm, cpu_x_map16, height * sizeof(uint16_t), cudaMemcpyHostToDevice);\n    cudaMemcpyAsync(cuda_q_invperm, cpu_x_map_inv16, height * sizeof(uint16_t), cudaMemcpyHostToDevice);\n\n    // Rearrange rows in w\n\n    dim3 blockDim, gridDim;\n    blockDim.x = THREADS_X;\n    blockDim.y = 1;\n    gridDim.x = DIVIDE(width, THREADS_X);\n    gridDim.y = height / 8;\n\n    make_sequential_kernel<<<gridDim, blockDim, 0, stream>>>\n    (\n        cuda_q_weight,\n        cuda_new_qweight,\n        cuda_q_perm,\n        height / 8,\n        width\n    );\n\n    // Replace qweights\n\n    cudaMemcpyAsync(cuda_q_weight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);\n\n    // Cleanup\n\n    cudaDeviceSynchronize();\n\n    cudaFree(cuda_new_qweight);\n    free(cpu_g_idx_map);\n    free(cpu_x_map);\n    free(cpu_x_map_inv);\n\n    return 
true;\n}\n"
  },
  {
    "path": "server/exllamav2_kernels/exllamav2_kernels/cuda/q_matrix.cuh",
    "content": "#ifndef _q_matrix_cuh\n#define _q_matrix_cuh\n\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n#include <cstdint>\n#include <cstdio>\n\n#define MAX_SUPERGROUPS 16\n\nclass QMatrix\n{\npublic:\n\n    int device;\n    bool is_gptq;\n\n    int height;\n    int width;\n    int groups;\n    int gptq_groupsize;\n\n    int rows_8;\n    int rows_6;\n    int rows_5;\n    int rows_4;\n    int rows_3;\n    int rows_2;\n\n    uint32_t* cuda_q_weight = NULL;\n    uint16_t* cuda_q_perm = NULL;\n    uint16_t* cuda_q_invperm = NULL;\n    uint32_t* cuda_q_scale = NULL;\n    half* cuda_q_scale_max = NULL;\n    uint16_t* cuda_q_groups = NULL;\n    uint16_t* cuda_q_group_map = NULL;\n    uint32_t* cuda_gptq_qzeros = NULL;\n    half* cuda_gptq_scales = NULL;\n\n    half* temp_dq;\n\n    bool failed;\n\n    QMatrix\n    (\n        const int _device,\n        const int _height,\n        const int _width,\n        const int _groups,\n\n        uint32_t* _q_weight,\n        uint16_t* _q_perm,\n        uint16_t* _q_invperm,\n        uint32_t* _q_scale,\n        half* _q_scale_max,\n        uint16_t* _q_groups,\n        uint16_t* _q_group_map,\n\n        uint32_t* _gptq_qzeros,\n        half* _gptq_scales,\n        uint32_t* _gptq_g_idx,\n\n        half* _temp_dq\n    );\n\n    ~QMatrix();\n\n    void reconstruct(half* out);\n    bool make_sequential(const uint32_t* cpu_g_idx);\n\nprivate:\n\n};\n\n#endif\n"
  },
  {
    "path": "server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_2.cuh",
    "content": "#ifndef _qdq_2_cuh\n#define _qdq_2_cuh\n\n#include \"qdq_util.cuh\"\n#include \"../../config.h\"\n\n#if QMODE_2BIT == 1\n\n// Permutation:\n//\n// ffddbb99 77553311  eeccaa88 66442200\n\n__forceinline__ __device__ void shuffle_2bit_16\n(\n    uint32_t* q,\n    int stride\n)\n{\n    uint32_t qa = q[0];\n    uint32_t qb = 0;\n\n    #pragma unroll\n    for (int i = 0; i < 8; i++)\n    {\n        uint32_t qa0 = qa & 0x03;\n        uint32_t qa1 = (qa & 0x0c) >> 2;\n        qa >>= 4;\n        qb |= (qa1 << (i * 2 + 16));\n        qb |= (qa0 << (i * 2));\n    }\n    q[0] = qb;\n}\n\n__forceinline__ __device__ void dequant_2bit_16\n(\n    const uint32_t q_0,\n    half2 (&dq)[8],\n    int stride\n)\n{\n    const uint32_t c0 = 0x64006400;\n    const half y4_  = __float2half_rn(1.0f /  4.0f);\n    const half y16_ = __float2half_rn(1.0f / 16.0f);\n    const half y64_ = __float2half_rn(1.0f / 64.0f);\n    const half2 y4  = __halves2half2(y4_,  y4_);\n    const half2 y16 = __halves2half2(y16_, y16_);\n    const half2 y64 = __halves2half2(y64_, y64_);\n    const half z1_  = __float2half_rn(-1024.0f         - 2.0f);\n    const half z4_  = __float2half_rn(-1024.0f /  4.0f - 2.0f);\n    const half z16_ = __float2half_rn(-1024.0f / 16.0f - 2.0f);\n    const half z64_ = __float2half_rn(-1024.0f / 64.0f - 2.0f);\n    const half2 z1  = __halves2half2(z1_,  z1_);\n    const half2 z4  = __halves2half2(z4_,  z4_);\n    const half2 z16 = __halves2half2(z16_, z16_);\n    const half2 z64 = __halves2half2(z64_, z64_);\n\n    uint32_t qa = q_0;\n    half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1])      + 1024\n    half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) *  4 + 1024\n    half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024\n    half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024\n    qa >>= 8;\n    half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 9])      + 1024\n    half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) *  4 + 1024\n    half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024\n    half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024\n\n    dq[0] = __hadd2(q0.as_half2, z1);\n    dq[1] = __hfma2(q1.as_half2, y4,  z4);\n    dq[2] = __hfma2(q2.as_half2, y16, z16);\n    dq[3] = __hfma2(q3.as_half2, y64, z64);\n    dq[4] = __hadd2(q4.as_half2, z1);\n    dq[5] = __hfma2(q5.as_half2, y4,  z4);\n    dq[6] = __hfma2(q6.as_half2, y16, z16);\n    dq[7] = __hfma2(q7.as_half2, y64, z64);\n}\n\n#else\n\n__forceinline__ __device__ void shuffle_2bit_16\n(\n    uint32_t* q,\n    int stride\n)\n{\n}\n\n__forceinline__ __device__ void dequant_2bit_16\n(\n    const uint32_t q_0,\n    half2 (&dq)[8],\n    int stride\n)\n{\n    half dqh[16];\n    for (int i = 0; i < 16; i++) dqh[i] = dq_ns(exb(q_0, i * 2, 0x03), 2);\n\n    for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);\n}\n\n#endif\n\n#endif\n"
  },
  {
    "path": "server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_3.cuh",
    "content": "#ifndef _qdq_3_cuh\n#define _qdq_3_cuh\n\n#include \"qdq_util.cuh\"\n#include \"../../config.h\"\n\n#if QMODE_3BIT == 1\n\n// Permutation:\n//\n// v9997775 55333111  u8886664 44222000  (u, v lsb)\n// vjjjhhhf ffdddbbb  uiiiggge eecccaaa\n// vtttrrrp ppnnnlll  usssqqqo oommmkkk\n\n__forceinline__ __device__ void shuffle_3bit_32\n(\n    uint32_t* q,\n    int stride\n)\n{\n    uint32_t qa = q[0 * stride];\n    uint32_t qb = q[1 * stride];\n    uint32_t qc = q[2 * stride];\n\n    // qa: aa999888 77766655  54443332 22111000\n    // qb: lkkkjjji iihhhggg  fffeeedd dcccbbba\n    // qc: vvvuuutt tsssrrrq  qqpppooo nnnmmmll\n\n    uint32_t qd = qc >> 26;\n    qc <<= 4;\n    qc |= qb >> 28;\n    qb <<= 2;\n    qb |= qa >> 30;\n\n    // qa: ..999888 77766655  54443332 22111000\n    // qb: ..jjjiii hhhgggff  feeedddc ccbbbaaa\n    // qc: ..tttsss rrrqqqpp  pooonnnm mmlllkkk\n    // qd:                               vvvuuu\n\n    uint32_t za = 0;\n    uint32_t zb = 0;\n    uint32_t zc = 0;\n\n    for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); }\n    for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); }\n    for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); }\n\n    // za:  9997775 55333111   8886664 44222000\n    // zb:  jjjhhhf ffdddbbb   iiiggge eecccaaa\n    // zc:  tttrrrp ppnnnlll   sssqqqo oommmkkk\n    // qd:                               vvvuuu\n\n    za |= ((qd & 0x01) >> 0) << 15;\n    zb |= ((qd & 0x02) >> 1) << 15;\n    zc |= ((qd & 0x04) >> 2) << 15;\n    za |= ((qd & 0x08) >> 3) << 31;\n    zb |= ((qd & 0x10) >> 4) << 31;\n    zc |= ((qd & 0x20) >> 5) << 31;\n\n    // za: v9997775 55333111  u8886664 44222000  (u, v lsb)\n    // zb: vjjjhhhf ffdddbbb  uiiiggge 
eecccaaa\n    // zc: vtttrrrp ppnnnlll  usssqqqo oommmkkk\n\n    q[0 * stride] = za;\n    q[1 * stride] = zb;\n    q[2 * stride] = zc;\n}\n\n__forceinline__ __device__ void dequant_3bit_32\n(\n    const uint32_t q_0,\n    const uint32_t q_1,\n    const uint32_t q_2,\n    half2 (&dq)[16],\n    int stride\n)\n{\n    const uint32_t c0 = 0x64006400;\n    const half y8_  = __float2half_rn(1.0f /  8.0f);\n    const half y64_ = __float2half_rn(1.0f / 64.0f);\n    const half2 y8  = __halves2half2(y8_,  y8_);\n    const half2 y64 = __halves2half2(y64_, y64_);\n    const half z1_  = __float2half_rn(-1024.0f         - 4.0f);\n    const half z8_  = __float2half_rn(-1024.0f /  8.0f - 4.0f);\n    const half z64_ = __float2half_rn(-1024.0f / 64.0f - 4.0f);\n    const half2 z1  = __halves2half2(z1_,  z1_);\n    const half2 z8  = __halves2half2(z8_,  z8_);\n    const half2 z64 = __halves2half2(z64_, z64_);\n\n    uint32_t qa = q_0;\n    uint32_t qb = q_1;\n    uint32_t qc = q_2;\n\n    half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1])      + 1024\n    half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) *  8 + 1024\n    qa >>= 6;\n    half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5])      + 1024\n    half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) *  8 + 1024\n    half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024\n    qa >>= 9;\n    qa &= 0x00010001;\n    half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11])      + 1024\n    half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) *  8 + 1024\n    qb >>= 6;\n    half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15])      + 1024\n    half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) *  8 + 1024\n    half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024\n    qb >>= 8;\n    qb &= 0x00020002;\n    half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21])      + 1024\n    half2_uint32 
q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) *  8 + 1024\n    qc >>= 6;\n    half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25])      + 1024\n    half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) *  8 + 1024\n    half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024\n    qc >>= 7;\n    qc &= 0x00040004;\n    half2_uint32 q15((qa | qb | qc) | c0);\n\n    dq[ 0] = __hadd2( q0.as_half2, z1);\n    dq[ 1] = __hfma2( q1.as_half2, y8,  z8);\n    dq[ 2] = __hadd2( q2.as_half2, z1);\n    dq[ 3] = __hfma2( q3.as_half2, y8,  z8);\n    dq[ 4] = __hfma2( q4.as_half2, y64, z64);\n    dq[ 5] = __hadd2( q5.as_half2, z1);\n    dq[ 6] = __hfma2( q6.as_half2, y8,  z8);\n    dq[ 7] = __hadd2( q7.as_half2, z1);\n    dq[ 8] = __hfma2( q8.as_half2, y8,  z8);\n    dq[ 9] = __hfma2( q9.as_half2, y64, z64);\n    dq[10] = __hadd2(q10.as_half2, z1);\n    dq[11] = __hfma2(q11.as_half2, y8,  z8);\n    dq[12] = __hadd2(q12.as_half2, z1);\n    dq[13] = __hfma2(q13.as_half2, y8,  z8);\n    dq[14] = __hfma2(q14.as_half2, y64, z64);\n    dq[15] = __hadd2(q15.as_half2, z1);\n}\n\n#else\n\n__forceinline__ __device__ void shuffle_3bit_32\n(\n    uint32_t* q,\n    int stride\n)\n{\n}\n\n__forceinline__ __device__ void dequant_3bit_32\n(\n    const uint32_t q_0,\n    const uint32_t q_1,\n    const uint32_t q_2,\n    half2 (&dq)[16],\n    int stride\n)\n{\n    half dqh[32];\n    for (int i = 0; i < 10; i++) dqh[     i] = dq_ns(exb(     q_0, i * 3    , 0x07), 4);\n                                 dqh[10    ] = dq_ns(exb(q_1, q_0,        30, 0x07), 4);\n    for (int i = 0; i < 10; i++) dqh[11 + i] = dq_ns(exb(     q_1, i * 3 + 1, 0x07), 4);\n                                 dqh[21    ] = dq_ns(exb(q_2, q_1,        31, 0x07), 4);\n    for (int i = 0; i < 10; i++) dqh[22 + i] = dq_ns(exb(     q_2, i * 3 + 2, 0x07), 4);\n\n    for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);\n}\n\n#endif\n\n#endif\n"
  },
  {
    "path": "server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_4.cuh",
    "content": "#ifndef _qdq_4_cuh\n#define _qdq_4_cuh\n\n#include \"qdq_util.cuh\"\n#include \"../../config.h\"\n\n#if QMODE_4BIT == 1\n\n// Permutation:\n//\n// 77775555 33331111  66664444 22220000\n\n__forceinline__ __device__ void shuffle_4bit_8\n(\n    uint32_t* q,\n    int stride\n)\n{\n    uint32_t qa = q[0];\n    uint32_t qb = 0;\n\n    #pragma unroll\n    for (int i = 0; i < 4; i++)\n    {\n        uint32_t qa0 = qa & 0x0f;\n        uint32_t qa1 = (qa & 0xf0) >> 4;\n        qa >>= 8;\n        qb |= (qa1 << (i * 4 + 16));\n        qb |= (qa0 << (i * 4));\n    }\n    q[0] = qb;\n}\n\n__forceinline__ __device__ void dequant_4bit_8\n(\n    const uint32_t q_0,\n    half2 (&dq)[4],\n    int stride\n)\n{\n    const uint32_t c0 = 0x64006400;\n    const half y16_ = __float2half_rn(1.0f / 16.0f);\n    const half2 y16 = __halves2half2(y16_, y16_);\n    const half z1_  = __float2half_rn(-1024.0f         - 8.0f);\n    const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);\n    const half2 z1  = __halves2half2(z1_,  z1_);\n    const half2 z16 = __halves2half2(z16_, z16_);\n\n    uint32_t qa = q_0;\n    half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1])      + 1024\n    half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024\n    qa >>= 8;\n    half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5])      + 1024\n    half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024\n\n    dq[0] = __hadd2(q0.as_half2, z1);\n    dq[1] = __hfma2(q1.as_half2, y16, z16);\n    dq[2] = __hadd2(q2.as_half2, z1);\n    dq[3] = __hfma2(q3.as_half2, y16, z16);\n}\n\n__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale\n(\n    const uint32_t zero,\n    const half scale,\n    half2 (&z1z16)[2],\n    half2 (&y1y16)[2]\n)\n{\n    half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);\n    half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));\n\n    half2 scale2 = __half2half2(scale);\n\n    z1z16[0] = 
__hmul2(scale2, __half2half2(z1.as_half));\n    z1z16[1] = __hmul2(scale2, __half2half2(z16));\n\n    const half y1 = __float2half_rn(1.0f);\n    const half y16 = __float2half_rn(1.0f / 16.0f);\n\n    y1y16[0] = __hmul2(scale2, __half2half2(y1));\n    y1y16[1] = __hmul2(scale2, __half2half2(y16));\n}\n\n__forceinline__ __device__ void dequant_4bit_8_prep_zero\n(\n    const uint32_t zero,\n    half2(&z1z16)[2],\n    half2(&y1y16)[2]\n)\n{\n    half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);\n    half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));\n\n    z1z16[0] = __half2half2(z1.as_half);\n    z1z16[1] = __half2half2(z16);\n\n    const half y1 = __float2half_rn(1.0f);\n    const half y16 = __float2half_rn(1.0f / 16.0f);\n\n    y1y16[0] = __half2half2(y1);\n    y1y16[1] = __half2half2(y16);\n}\n\n\n__forceinline__ __device__ void dequant_4bit_8_gptq\n(\n    const uint32_t q_0,\n    half2 (&dq)[4],\n    half2 (&z1z16)[2],\n    half2 (&y1y16)[2],\n    int stride,\n    bool scaled\n)\n{\n    const uint32_t c0 = 0x64006400;\n\n    uint32_t qa = q_0;\n    half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0]      + 1024, q[1]      + 1024 )\n    half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )\n    qa >>= 8;\n    half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4]      + 1024, q[5]      + 1024 )\n    half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )\n\n    if (scaled)\n    {\n        dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]);  // half2( q[0] * s - z * s, q[1] * s - z * s)\n        dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]);  // half2( q[2] * s - z * s, q[3] * s - z * s)\n        dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);\n        dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);\n    }\n    else\n    {\n        dq[0] = __hadd2(q0.as_half2,           z1z16[0]);  // half2( q[0] - z, q[1] - z )\n        dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]);  // 
half2( q[2] - z, q[3] - z )\n        dq[2] = __hadd2(q2.as_half2,           z1z16[0]);  // half2( q[4] - z, q[5] - z )\n        dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);  // half2( q[6] - z, q[7] - z )\n    }\n}\n\n#else\n\n__forceinline__ __device__ void shuffle_4bit_8\n(\n    uint32_t* q,\n    int stride\n)\n{\n}\n\n__forceinline__ __device__ void dequant_4bit_8\n(\n    const uint32_t q_0,\n    half2 (&dq)[4],\n    int stride\n)\n{\n    half dqh[8];\n    for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8);\n\n    for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);\n}\n\n__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale\n(\n    const uint32_t zero,\n    const half scale,\n    half2 (&z1)[2],\n    half2 (&y1)[2]\n)\n{\n    half z = __int2half_rn(-((int)zero));\n    z = __hmul(z, scale);\n    z1[0] = __half2half2(z);\n    y1[0] = __half2half2(scale);\n}\n\n__forceinline__ __device__ void dequant_4bit_8_prep_zero\n(\n    const uint32_t zero,\n    half2(&z1)[2],\n    half2(&y1)[2]\n)\n{\n    half z = __int2half_rn(-((int)zero));\n    z1[0] = __half2half2(z);\n}\n\n__forceinline__ __device__ void dequant_4bit_8_gptq\n(\n    const uint32_t q_0,\n    half2 (&dq)[4],\n    half2 (&z1)[2],\n    half2 (&y1)[2],\n    int stride,\n    bool scaled\n)\n{\n    half2 dqh2[8];\n\n    uint32_t qa = q_0;\n    for (int i = 0; i < 4; i++)\n    {\n        half d0 = __int2half_rn(qa & 0x0f); qa >>= 4;\n        half d1 = __int2half_rn(qa & 0x0f); qa >>= 4;\n        dqh2[i] = __halves2half2(d0, d1);\n    }\n\n    if (scaled)\n    {\n        dq[0] = __hfma2(dqh2[0], y1[0], z1[0]);\n        dq[1] = __hfma2(dqh2[1], y1[0], z1[0]);\n        dq[2] = __hfma2(dqh2[2], y1[0], z1[0]);\n        dq[3] = __hfma2(dqh2[3], y1[0], z1[0]);\n    }\n    else\n    {\n        dq[0] = __hadd2(dqh2[0], z1[0]);\n        dq[1] = __hadd2(dqh2[1], z1[0]);\n        dq[2] = __hadd2(dqh2[2], z1[0]);\n        dq[3] = __hadd2(dqh2[3], z1[0]);\n    
}\n}\n\n#endif\n\n#endif\n"
  },
  {
    "path": "server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_5.cuh",
    "content": "#ifndef _qdq_5_cuh\n#define _qdq_5_cuh\n\n#include \"qdq_util.cuh\"\n#include \"../../config.h\"\n\n#if QMODE_5BIT == 1\n\n// Permutation:\n//\n// v5555533 33311111  u4444422 22200000  (u, v lsb)\n// vbbbbb99 99977777  uaaaaa88 88866666\n// vhhhhhff fffddddd  ugggggee eeeccccc\n// vnnnnnll llljjjjj  ummmmmkk kkkiiiii\n// vtttttrr rrrppppp  usssssqq qqqooooo\n\n__forceinline__ __device__ void shuffle_5bit_32\n(\n    uint32_t* q,\n    int stride\n)\n{\n    uint32_t qa = q[0 * stride];\n    uint32_t qb = q[1 * stride];\n    uint32_t qc = q[2 * stride];\n    uint32_t qd = q[3 * stride];\n    uint32_t qe = q[4 * stride];\n\n    // qa: 66555554 44443333  32222211 11100000\n    // qb: ccccbbbb baaaaa99  99988888 77777666\n    // qc: jiiiiihh hhhggggg  fffffeee eedddddc\n    // qd: pppooooo nnnnnmmm  mmlllllk kkkkjjjj\n    // qe: vvvvvuuu uuttttts  ssssrrrr rqqqqqpp\n\n    uint32_t qf = qe >> 22;\n    qe <<= 8;\n    qe |= qd >> 24;\n    qd <<= 6;\n    qd |= qc >> 26;\n    qc <<= 4;\n    qc |= qb >> 28;\n    qb <<= 2;\n    qb |= qa >> 30;\n\n    // qa:   555554 44443333  32222211 11100000\n    // qb:   bbbbba aaaa9999  98888877 77766666\n    // qc:   hhhhhg ggggffff  feeeeedd dddccccc\n    // qd:   nnnnnm mmmmllll  lkkkkkjj jjjiiiii\n    // qe:   ttttts ssssrrrr  rqqqqqpp pppooooo\n    // qf:                          vv vvvuuuuu\n\n    uint32_t za = 0;\n    uint32_t zb = 0;\n    uint32_t zc = 0;\n    uint32_t zd = 0;\n    uint32_t ze = 0;\n\n    for (int i = 0; i < 3; i++) { uint32_t t0 = qa & 0x1f; uint32_t t1 = (qa & 0x3e0) >> 5; qa >>= 10; za |= (t0 << (i * 5)); za |= (t1 << (i * 5 + 16)); }\n    for (int i = 0; i < 3; i++) { uint32_t t0 = qb & 0x1f; uint32_t t1 = (qb & 0x3e0) >> 5; qb >>= 10; zb |= (t0 << (i * 5)); zb |= (t1 << (i * 5 + 16)); }\n    for (int i = 0; i < 3; i++) { uint32_t t0 = qc & 0x1f; uint32_t t1 = (qc & 0x3e0) >> 5; qc >>= 10; zc |= (t0 << (i * 5)); zc |= (t1 << (i * 5 + 16)); }\n    for (int i = 0; i < 3; i++) { uint32_t t0 = qd & 
0x1f; uint32_t t1 = (qd & 0x3e0) >> 5; qd >>= 10; zd |= (t0 << (i * 5)); zd |= (t1 << (i * 5 + 16)); }\n    for (int i = 0; i < 3; i++) { uint32_t t0 = qe & 0x1f; uint32_t t1 = (qe & 0x3e0) >> 5; qe >>= 10; ze |= (t0 << (i * 5)); ze |= (t1 << (i * 5 + 16)); }\n\n    // za:  5555533 33311111   4444422 22200000\n    // zb:  bbbbb99 99977777   aaaaa88 88866666\n    // zc:  hhhhhff fffddddd   gggggee eeeccccc\n    // zd:  nnnnnll llljjjjj   mmmmmkk kkkiiiii\n    // ze:  tttttrr rrrppppp   sssssqq qqqooooo\n    // qf:                          vv vvvuuuuu\n\n    za |= ((qf & 0x001) >> 0) << 15;\n    zb |= ((qf & 0x002) >> 1) << 15;\n    zc |= ((qf & 0x004) >> 2) << 15;\n    zd |= ((qf & 0x008) >> 3) << 15;\n    ze |= ((qf & 0x010) >> 4) << 15;\n    za |= ((qf & 0x020) >> 5) << 31;\n    zb |= ((qf & 0x040) >> 6) << 31;\n    zc |= ((qf & 0x080) >> 7) << 31;\n    zd |= ((qf & 0x100) >> 8) << 31;\n    ze |= ((qf & 0x200) >> 9) << 31;\n\n    // za: v5555533 33311111  u4444422 22200000  (u, v lsb)\n    // zb: vbbbbb99 99977777  uaaaaa88 88866666\n    // zc: vhhhhhff fffddddd  ugggggee eeeccccc\n    // zd: vnnnnnll llljjjjj  ummmmmkk kkkiiiii\n    // ze: vtttttrr rrrppppp  usssssqq qqqooooo\n\n    q[0 * stride] = za;\n    q[1 * stride] = zb;\n    q[2 * stride] = zc;\n    q[3 * stride] = zd;\n    q[4 * stride] = ze;\n}\n\n__forceinline__ __device__ void dequant_5bit_32\n(\n    const uint32_t q_0,\n    const uint32_t q_1,\n    const uint32_t q_2,\n    const uint32_t q_3,\n    const uint32_t q_4,\n    half2 (&dq)[16],\n    int stride\n)\n{\n    const uint32_t c0 = 0x64006400;\n    const half y32_ = __float2half_rn(1.0f / 32.0f);\n    const half2 y32 = __halves2half2(y32_, y32_);\n    const half z1_  = __float2half_rn(-1024.0f         - 16.0f);\n    const half z32_ = __float2half_rn(-1024.0f / 32.0f - 16.0f);\n    const half2 z1  = __halves2half2(z1_,  z1_);\n    const half2 z32 = __halves2half2(z32_, z32_);\n\n    uint32_t qa = q_0;\n    uint32_t qb = q_1;\n    uint32_t qc = 
q_2;\n    uint32_t qd = q_3;\n    uint32_t qe = q_4;\n\n    half2_uint32 q0 ((qa & 0x001f001f) | c0); // half2(q[ 0], q[ 1])      + 1024\n    half2_uint32 q1 ((qa & 0x03e003e0) | c0); // half2(q[ 2], q[ 3]) * 32 + 1024\n    qa >>= 10;\n    half2_uint32 q2 ((qa & 0x001f001f) | c0); // half2(q[ 4], q[ 5])      + 1024\n    qa >>= 5;\n    qa &= 0x00010001;\n    half2_uint32 q3 ((qb & 0x001f001f) | c0); // half2(q[ 6], q[ 7])      + 1024\n    half2_uint32 q4 ((qb & 0x03e003e0) | c0); // half2(q[ 8], q[ 9]) * 32 + 1024\n    qb >>= 10;\n    half2_uint32 q5 ((qb & 0x001f001f) | c0); // half2(q[10], q[11])      + 1024\n    qb >>= 4;\n    qb &= 0x00020002;\n    half2_uint32 q6 ((qc & 0x001f001f) | c0); // half2(q[12], q[13])      + 1024\n    half2_uint32 q7 ((qc & 0x03e003e0) | c0); // half2(q[14], q[15]) * 32 + 1024\n    qc >>= 10;\n    half2_uint32 q8 ((qc & 0x001f001f) | c0); // half2(q[16], q[17])      + 1024\n    qc >>= 3;\n    qc &= 0x00040004;\n    half2_uint32 q9 ((qd & 0x001f001f) | c0); // half2(q[18], q[19])      + 1024\n    half2_uint32 q10((qd & 0x03e003e0) | c0); // half2(q[20], q[21]) * 32 + 1024\n    qd >>= 10;\n    half2_uint32 q11((qd & 0x001f001f) | c0); // half2(q[22], q[23])      + 1024\n    qd >>= 2;\n    qd &= 0x00080008;\n    half2_uint32 q12((qe & 0x001f001f) | c0); // half2(q[24], q[25])      + 1024\n    half2_uint32 q13((qe & 0x03e003e0) | c0); // half2(q[26], q[27]) * 32 + 1024\n    qe >>= 10;\n    half2_uint32 q14((qe & 0x001f001f) | c0); // half2(q[28], q[29])      + 1024\n    qe >>= 1;\n    qe &= 0x00100010;\n    half2_uint32 q15((qa | qb | qc | qd | qe) | c0);\n\n    dq[ 0] = __hadd2( q0.as_half2, z1);\n    dq[ 1] = __hfma2( q1.as_half2, y32, z32);\n    dq[ 2] = __hadd2( q2.as_half2, z1);\n    dq[ 3] = __hadd2( q3.as_half2, z1);\n    dq[ 4] = __hfma2( q4.as_half2, y32, z32);\n    dq[ 5] = __hadd2( q5.as_half2, z1);\n    dq[ 6] = __hadd2( q6.as_half2, z1);\n    dq[ 7] = __hfma2( q7.as_half2, y32, z32);\n    dq[ 8] = __hadd2( q8.as_half2, z1);\n 
   dq[ 9] = __hadd2( q9.as_half2, z1);\n    dq[10] = __hfma2(q10.as_half2, y32, z32);\n    dq[11] = __hadd2(q11.as_half2, z1);\n    dq[12] = __hadd2(q12.as_half2, z1);\n    dq[13] = __hfma2(q13.as_half2, y32, z32);\n    dq[14] = __hadd2(q14.as_half2, z1);\n    dq[15] = __hadd2(q15.as_half2, z1);\n}\n\n#else\n\n__forceinline__ __device__ void shuffle_5bit_32\n(\n    uint32_t* q,\n    int stride\n)\n{\n}\n\n__forceinline__ __device__ void dequant_5bit_32\n(\n    const uint32_t q_0,\n    const uint32_t q_1,\n    const uint32_t q_2,\n    const uint32_t q_3,\n    const uint32_t q_4,\n    half2 (&dq)[16],\n    int stride\n)\n{\n    half dqh[32];\n    for (int i = 0; i <  6; i++) dqh[     i] = dq_ns(exb(     q_0, i * 5    , 0x1f), 16);\n                                 dqh[ 6    ] = dq_ns(exb(q_1, q_0,        30, 0x1f), 16);\n    for (int i = 0; i <  5; i++) dqh[ 7 + i] = dq_ns(exb(     q_1, i * 5 + 3, 0x1f), 16);\n                                 dqh[12    ] = dq_ns(exb(q_2, q_1,        28, 0x1f), 16);\n    for (int i = 0; i <  6; i++) dqh[13 + i] = dq_ns(exb(     q_2, i * 5 + 1, 0x1f), 16);\n                                 dqh[19    ] = dq_ns(exb(q_3, q_2,        31, 0x1f), 16);\n    for (int i = 0; i <  5; i++) dqh[20 + i] = dq_ns(exb(     q_3, i * 5 + 4, 0x1f), 16);\n                                 dqh[25    ] = dq_ns(exb(q_4, q_3,        29, 0x1f), 16);\n    for (int i = 0; i <  6; i++) dqh[26 + i] = dq_ns(exb(     q_4, i * 5 + 2, 0x1f), 16);\n\n    for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);\n}\n\n#endif\n\n#endif\n"
  },
  {
    "path": "server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_6.cuh",
    "content": "#ifndef _qdq_6_cuh\n#define _qdq_6_cuh\n\n#include \"qdq_util.cuh\"\n#include \"../../config.h\"\n\n#if QMODE_6BIT == 1\n\n  // Not implemented\n\n#else\n\n__forceinline__ __device__ void shuffle_6bit_16\n(\n    uint32_t* q,\n    int stride\n)\n{\n}\n\n__forceinline__ __device__ void dequant_6bit_16\n(\n    const uint32_t q_0,\n    const uint32_t q_1,\n    const uint32_t q_2,\n    half2 (&dq)[8],\n    int stride\n)\n{\n    half dqh[16];\n    for (int i = 0; i < 5; i++) dqh[     i] = dq_ns(exb(     q_0, i * 6    , 0x3f), 32);\n                                dqh[ 5    ] = dq_ns(exb(q_1, q_0,        30, 0x3f), 32);\n    for (int i = 0; i < 4; i++) dqh[ 6 + i] = dq_ns(exb(     q_1, i * 6 + 4, 0x3f), 32);\n                                dqh[10    ] = dq_ns(exb(q_2, q_1,        28, 0x3f), 32);\n    for (int i = 0; i < 5; i++) dqh[11 + i] = dq_ns(exb(     q_2, i * 6 + 2, 0x3f), 32);\n\n    for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);\n}\n\n#endif\n\n#endif\n"
  },
  {
    "path": "server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_8.cuh",
    "content": "#ifndef _qdq_8_cuh\n#define _qdq_8_cuh\n\n#include \"qdq_util.cuh\"\n#include \"../../config.h\"\n\n#if QMODE_8BIT == 1\n\n  // Not implemented\n\n#else\n\n__forceinline__ __device__ void shuffle_8bit_4\n(\n    uint32_t* q,\n    int stride\n)\n{\n}\n\n__forceinline__ __device__ void dequant_8bit_8\n(\n    const uint32_t q_0,\n    const uint32_t q_1,\n    half2 (&dq)[4],\n    int stride\n)\n{\n    half dqh[8];\n    for (int i = 0; i < 4; i++) dqh[i    ] = dq_ns(exb(q_0, i * 8, 0xff), 128);\n    for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), 128);\n\n    for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);\n}\n\n#endif\n\n#endif\n"
  },
  {
    "path": "server/exllamav2_kernels/exllamav2_kernels/cuda/quant/qdq_util.cuh",
    "content": "#ifndef _qdq_util_cuh\n#define _qdq_util_cuh\n\nunion half2_uint32\n{\n    uint32_t as_uint32;\n    half2 as_half2;\n    __device__ half2_uint32(uint32_t val) : as_uint32(val) {}\n    __device__ half2_uint32(half2 val) : as_half2(val) {}\n    __device__ half2_uint32() : as_uint32(0) {}\n};\n\nunion half_uint16\n{\n    uint16_t as_uint16;\n    half as_half;\n    __device__ half_uint16(uint16_t val) : as_uint16(val) {}\n    __device__ half_uint16(half val) : as_half(val) {}\n    __device__ half_uint16() : as_uint16(0) {}\n};\n\n// Max_scale premultiplied by 1/256\n\n__forceinline__ __device__ half dq_scale(const int qs, const half max_scale)\n{\n    int qs_i = qs + 1;\n    half qs_h = __int2half_rn(qs_i * qs_i);\n    qs_h = __hmul(qs_h, max_scale);\n    return qs_h;\n}\n\n__forceinline__ __device__ half dq(const int q, const int qzero, const half scale)\n{\n    return __hmul(__int2half_rn(q - qzero), scale);\n}\n\n__forceinline__ __device__ half dq_ns(const int q, const int qzero)\n{\n    //return __hsub(__int2half_rn(q), __int2half_rn(qzero));\n    return __int2half_rn(q - qzero);\n}\n\n__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask)\n{\n    return (int)((q >> shift) & mask);\n}\n\n__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask)\n{\n    return (int)(__funnelshift_rc(q0, q1, shift) & mask);\n}\n\n#endif\n"
  },
  {
    "path": "server/exllamav2_kernels/exllamav2_kernels/cuda/util.cuh",
    "content": "#ifndef _util_cuh\n#define _util_cuh\n\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n#include <cstdint>\n#include <cstdio>\n#include <ATen/cuda/CUDAContext.h>\n\n#define DIVIDE(x, size) (((x) + (size) - 1) / (size))\n\n#define DBGS(__x) printf(\"%s\\n\", __x)\n#define DBGI(__x) printf(\"%s: %i\\n\", #__x, __x)\n#define DBGI2(__x, __y) printf(\"%s, %s: %i, %i\\n\", #__x, #__y, __x, __y)\n#define DBGI3(__x, __y, __z) printf(\"%s, %s, %s: %i, %i, %i\\n\", #__x, #__y, #__z, __x, __y, __z)\n#define DBGX(__x) printf(\"%s: %x\\n\", #__x, __x)\n#define DBGX2(__x, __y) printf(\"%s, %s: %x, %x\\n\", #__x, #__y, __x, __y)\n#define DBGX3(__x, __y, __z) printf(\"%s, %s, %s: %x, %x, %x\\n\", #__x, #__y, #__z, __x, __y, __z)\n#define DBGF(__x) printf(\"%s: %f\\n\", #__x, __x)\n#define DBGF2(__x, __y) printf(\"%s, %s: %f, %f\\n\", #__x, #__y, __x, __y)\n#define DBGF3(__x, __y, __z) printf(\"%s, %s, %s: %f, %f, %f\\n\", #__x, #__y, #__z, __x, __y, __z)\n#define DBGH(__x) printf(\"%s: %f\\n\", #__x, __half2float(__x))\n#define DBGH2(__x, __y) printf(\"%s, %s: %f, %f\\n\", #__x, #__y, __half2float(__x), __half2float(__y))\n#define DBGH3(__x, __y, __z) printf(\"%s, %s, %s: %f, %f, %f\\n\", #__x, #__y, #__z, __half2float(__x), __half2float(__y), __half2float(__z))\n\n#define DBGIH(__x, __y) printf(\"%s, %s: %i, %f\\n\", #__x, #__y, __x, __half2float(__y))\n#define DBGIH2(__x, __y, __z) printf(\"%s, %s, %s: %i, %f, %f\\n\", #__x, #__y, #__z, __x, __half2float(__y), __half2float(__z))\n\n__forceinline__ __device__ half dq_scale_(const int qs, const half max_scale)\n{\n    half qs_h = __hmul(__int2half_rn(qs + 1), __float2half_rn(1.0f / 16.0f));\n    qs_h = __hmul(qs_h, qs_h);\n    qs_h = __hmul(qs_h, max_scale);\n    return qs_h;\n}\n\n__forceinline__ __device__ float clamp(float x, float a, float b)\n{\n    return fmaxf(a, fminf(b, x));\n}\n\n#define cuda_check(ans) { gpu_assert((ans), __FILE__, __LINE__); }\ninline void gpu_assert(cudaError_t code, const char 
*file, int line, bool abort=true)\n{\n   if (code != cudaSuccess)\n   {\n      fprintf(stderr,\"CUDA error: %s %s %d\\n\", cudaGetErrorString(code), file, line);\n      if (abort) exit(code);\n   }\n}\n\nvoid print_global_mem(const half* ptr, int rows, int columns, int stride);\n\n#endif\n"
  },
  {
    "path": "server/exllamav2_kernels/exllamav2_kernels/ext.cpp",
    "content": "#include <torch/extension.h>\n#include <c10/cuda/CUDAGuard.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n#include <cstdint>\n#include <cstdio>\n\n#include \"config.h\"\n\n#include \"cuda/q_matrix.cuh\"\n#include \"cuda/q_gemm.cuh\"\n\n#include \"cpp/util.h\"\n\n// Some decluttering macros\n\n#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x \" is incorrect datatype, must be \" #__dtype)\n#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x \" is incorrect datatype, must be \" #__dtype)\n#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x \" and \" #__y \" have incompatible shapes\")\n#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x \" and \" #__y \" have incompatible shapes\")\n\n\n// Quant matrix\n\nuintptr_t make_q_matrix\n(\n    torch::Tensor q_weight,\n    torch::Tensor q_perm,\n    torch::Tensor q_invperm,\n    torch::Tensor q_scale,\n    torch::Tensor q_scale_max,\n    torch::Tensor q_groups,\n    torch::Tensor q_group_map,\n    torch::Tensor gptq_qzeros,\n    torch::Tensor gptq_scales,\n    torch::Tensor gptq_g_idx,\n    torch::Tensor temp_dq\n)\n{\n    TORCH_CHECK_DTYPE(q_weight, kInt);\n    TORCH_CHECK_DTYPE_OPT(q_perm, kShort);\n    TORCH_CHECK_DTYPE_OPT(q_invperm, kShort);\n    TORCH_CHECK_DTYPE_OPT(q_scale, kInt);\n    TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf);\n    TORCH_CHECK_DTYPE_OPT(q_groups, kShort);\n    TORCH_CHECK_DTYPE_OPT(q_group_map, kShort);\n    TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt);\n    TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf);\n    TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt);\n\n    TORCH_CHECK_SHAPES(q_perm, 0, q_invperm, 0, 1);\n\n    int device = 
q_weight.device().index();\n    int width = q_weight.size(1);\n    int groups;\n    int height;\n\n    if (!q_scale.device().is_meta())\n    {\n        TORCH_CHECK_SHAPES(q_weight, 1, q_scale, 1, 8);\n        TORCH_CHECK_SHAPES(q_scale_max, 0, q_scale, 0, 1);\n        groups = q_scale.size(0);\n        height = q_invperm.size(0);\n    }\n    else\n    {\n        TORCH_CHECK_SHAPES(q_weight, 1, gptq_qzeros, 1, 8);\n        TORCH_CHECK_SHAPES(q_weight, 1, gptq_scales, 1, 1);\n        groups = gptq_qzeros.size(0);\n        height = q_weight.size(0) * 8;\n    }\n\n    TORCH_CHECK(temp_dq.size(0) >= width * height, \"Insufficient size of temp_dq buffer\")\n\n    QMatrix* m = new QMatrix\n    (\n        device,\n        height,\n        width,\n        groups,\n        (uint32_t*) q_weight.data_ptr(),\n        q_perm.device().is_meta() ? NULL : (uint16_t*) q_perm.data_ptr(),\n        q_invperm.device().is_meta() ? NULL : (uint16_t*) q_invperm.data_ptr(),\n        q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(),\n        q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(),\n        q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(),\n        q_group_map.device().is_meta() ? NULL : (uint16_t*) q_group_map.data_ptr(),\n        gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(),\n        gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(),\n        gptq_g_idx.device().is_meta() ? 
NULL : (uint32_t*) gptq_g_idx.data_ptr(),\n        (half*) temp_dq.data_ptr()\n    );\n\n    if (m->failed) throw std::runtime_error(\"CUDA out of memory\");\n\n    return reinterpret_cast<uintptr_t> (m);\n}\n\nvoid gemm_half_q_half\n(\n    torch::Tensor a,\n    uintptr_t b,\n    torch::Tensor c,\n    bool force_cuda\n)\n{\n    QMatrix* qm = reinterpret_cast<QMatrix*> (b);\n\n    TORCH_CHECK_DTYPE(a, kHalf);\n    TORCH_CHECK_DTYPE(c, kHalf);\n    TORCH_CHECK_SHAPES(a, 0, c, 0, 1);\n    TORCH_CHECK(qm->height == a.size(1), \"a and b have incompatible shapes\")\n    TORCH_CHECK(qm->width == c.size(1), \"b and c have incompatible shapes\")\n\n    const at::cuda::OptionalCUDAGuard device_guard(device_of(a));\n\n    gemm_half_q_half_cuda\n    (\n        at::cuda::getCurrentCUDABlasHandle(),\n        (const half*) a.data_ptr(),\n        qm,\n        (half*) c.data_ptr(),\n        c.size(0), // m\n        c.size(1), // n\n        a.size(1), // k\n        true,\n        NULL,\n        force_cuda\n    );\n}\n\n// Bindings\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m)\n{\n    m.def(\"make_q_matrix\", &make_q_matrix, \"make_q_matrix\");\n    m.def(\"gemm_half_q_half\", &gemm_half_q_half, \"gemm_half_q_half\");\n}\n"
  },
  {
    "path": "server/exllamav2_kernels/setup.py",
    "content": "from setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\nimport torch\n\nextra_cuda_cflags = [\"-lineinfo\", \"-O3\"]\nextra_cflags = []\nif torch.version.hip:\n    extra_cflags = [\"-DLEGACY_HIPBLAS_DIRECT=ON\"]\n    extra_cuda_cflags += [\"-DHIPBLAS_USE_HIP_HALF\", \"-DLEGACY_HIPBLAS_DIRECT=ON\"]\n\nextra_compile_args = {\n    \"cxx\": extra_cflags,\n    \"nvcc\": extra_cuda_cflags,\n}\n\nsetup(\n    name=\"exllamav2_kernels\",\n    ext_modules=[\n        CUDAExtension(\n            name=\"exllamav2_kernels\",\n            sources=[\n                \"exllamav2_kernels/ext.cpp\",\n                \"exllamav2_kernels/cuda/q_matrix.cu\",\n                \"exllamav2_kernels/cuda/q_gemm.cu\",\n            ],\n            extra_compile_args=extra_compile_args,\n        )\n    ],\n    cmdclass={\"build_ext\": BuildExtension},\n)\n"
  },
  {
    "path": "server/pyproject.toml",
    "content": "[project]\nname = \"text-generation-server\"\nversion = \"2.0.5-dev0\"\ndescription = \"Text Generation Inference Python gRPC Server\"\nreadme = \"README.md\"\nrequires-python = \">=3.9\"\nauthors = [\n  {name = \"Olivier Dehaene\", email = \"olivier@huggingface.co\"},\n  {name = \"Nicolas Patry\", email = \"nicolas@huggingface.co\"},\n]\ndependencies = [\n    # Remove explicit click dependency once typer/click are compatible again.\n    \"click<8.2.0\",\n    \"einops>=0.8.0\",\n    \"grpc-interceptor>=0.15.4\",\n    \"grpcio>=1.67.0\",\n    \"grpcio-reflection>=1.67.0\",\n    \"grpcio-status>=1.67.0\",\n    \"kernels>=0.2.1\",\n    \"hf-transfer>=0.1.8\",\n    \"loguru>=0.7.3\",\n    \"numpy>=1.26,<3\",\n    \"opentelemetry-api>=1.27.0\",\n    \"opentelemetry-exporter-otlp>=1.27.0\",\n    \"opentelemetry-instrumentation-grpc>=0.50b0\",\n    \"pillow>=11.1.0\",\n    \"prometheus-client>=0.21.0\",\n    \"protobuf>=5.28.3\",\n    \"py-cpuinfo>=9.0.0\",\n    \"rich>=13.8.1\",\n    \"safetensors>=0.4.5\",\n    \"scipy>=1.13.1\",\n    \"sentencepiece>=0.2.0\",\n    \"tokenizers>=0.20.3\",\n    \"typer>=0.15.1\",\n    \"transformers>=4.51.0\",\n    \"huggingface-hub>=0.30.1\",\n    \"hf-xet>=1.0.0\",\n]\n\n[[tool.uv.index]]\nname = \"pytorch-cu128\"\nurl = \"https://download.pytorch.org/whl/cu128\"\nexplicit = true\n\n[tool.uv.sources]\ntorch = [\n  { index = \"pytorch-cu128\", marker = \"sys_platform == 'linux' or sys_platform == 'win32'\" },\n]\ntorchvision = [\n  { index = \"pytorch-cu128\", marker = \"sys_platform == 'linux' or sys_platform == 'win32'\" },\n]\n\n[build-system]\nrequires = [\"kernels>=0.1.7\", \"setuptools\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[tool.kernels.dependencies]\n\"kernels-community/paged-attention\" = \">=0.0.2\"\n\"kernels-community/moe\" = \">=0.1.1\"\n\"kernels-community/punica-sgmv\" = \">=0.0.1\"\n\"kernels-community/quantization\" = \">=0.0.3\"\n\"kernels-community/quantization-eetq\" = 
\">=0.0.1\"\n\"kernels-community/rotary\" = \">=0.0.1\"\n\n[project.scripts]\ntext-generation-server = \"text_generation_server.cli:app\"\n\n[project.optional-dependencies]\naccelerate = [\n    \"accelerate>=1.2.1,<2\",\n]\nbnb = [\n    \"bitsandbytes>=0.45.0\",\n]\ncompressed-tensors = [\n    \"compressed-tensors>=0.9.0\",\n]\npeft = [\n    \"peft>=0.14.0\",\n]\noutlines = [\n    \"outlines>=0.1.13,<1.0\",\n]\ndev = [\n    \"grpcio-tools>=1.51.1,<2.0\",\n    \"pytest>=7.3.0,<8\"\n]\nquantize = [\n    \"texttable>=1.6.7,<2\",\n    \"datasets>=2.21,<3\",\n]\ngen = [\n    \"grpcio-tools>=1.69.0\",\n    \"mypy-protobuf>=3.6.0\",\n]\ntorch = [\n    \"torch==2.7.0\",\n    \"torchvision==0.22.0\",\n]\n\n[tool.pytest.ini_options]\nmarkers = [\"private: marks tests as requiring an admin hf token (deselect with '-m \\\"not private\\\"')\"]\n\n[tool.isort]\nprofile = \"black\"\n\n[tool.uv]\npackage = true\n\n[tool.setuptools.packages.find]\ninclude = [\"text_generation_server*\"]\n"
  },
  {
    "path": "server/req.txt",
    "content": "# This file was autogenerated by uv via the following command:\n#    uv pip compile pyproject.toml --extra attention --extra bnb -o req.txt\nattention-kernels @ https://github.com/danieldk/attention-kernels/releases/download/v0.2.0.post2/attention_kernels-0.2.0.post2+cu123torch2.5-cp39-abi3-linux_x86_64.whl\n    # via text-generation-server (pyproject.toml)\nbitsandbytes==0.45.1\n    # via text-generation-server (pyproject.toml)\ncertifi==2025.1.31\n    # via requests\ncharset-normalizer==3.4.1\n    # via requests\nclick==8.1.8\n    # via typer\ndeprecated==1.2.18\n    # via\n    #   opentelemetry-api\n    #   opentelemetry-exporter-otlp-proto-grpc\n    #   opentelemetry-exporter-otlp-proto-http\n    #   opentelemetry-semantic-conventions\neinops==0.8.0\n    # via text-generation-server (pyproject.toml)\nfilelock==3.17.0\n    # via\n    #   huggingface-hub\n    #   torch\nfsspec==2025.2.0\n    # via\n    #   huggingface-hub\n    #   torch\ngoogleapis-common-protos==1.66.0\n    # via\n    #   grpcio-status\n    #   opentelemetry-exporter-otlp-proto-grpc\n    #   opentelemetry-exporter-otlp-proto-http\ngrpc-interceptor==0.15.4\n    # via text-generation-server (pyproject.toml)\ngrpcio==1.70.0\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   grpc-interceptor\n    #   grpcio-reflection\n    #   grpcio-status\n    #   opentelemetry-exporter-otlp-proto-grpc\ngrpcio-reflection==1.70.0\n    # via text-generation-server (pyproject.toml)\ngrpcio-status==1.70.0\n    # via text-generation-server (pyproject.toml)\nhf-transfer==0.1.9\n    # via text-generation-server (pyproject.toml)\nhuggingface-hub==0.28.1\n    # via tokenizers\nidna==3.10\n    # via requests\nimportlib-metadata==8.5.0\n    # via opentelemetry-api\njinja2==3.1.5\n    # via torch\nloguru==0.7.3\n    # via text-generation-server (pyproject.toml)\nmarkdown-it-py==3.0.0\n    # via rich\nmarkupsafe==3.0.2\n    # via jinja2\nmdurl==0.1.2\n    # via markdown-it-py\nmpmath==1.3.0\n  
  # via sympy\nnetworkx==3.4.2\n    # via torch\nnumpy==2.2.2\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   bitsandbytes\n    #   scipy\nnvidia-cublas-cu12==12.4.5.8\n    # via\n    #   nvidia-cudnn-cu12\n    #   nvidia-cusolver-cu12\n    #   torch\nnvidia-cuda-cupti-cu12==12.4.127\n    # via torch\nnvidia-cuda-nvrtc-cu12==12.4.127\n    # via torch\nnvidia-cuda-runtime-cu12==12.4.127\n    # via torch\nnvidia-cudnn-cu12==9.1.0.70\n    # via torch\nnvidia-cufft-cu12==11.2.1.3\n    # via torch\nnvidia-curand-cu12==10.3.5.147\n    # via torch\nnvidia-cusolver-cu12==11.6.1.9\n    # via torch\nnvidia-cusparse-cu12==12.3.1.170\n    # via\n    #   nvidia-cusolver-cu12\n    #   torch\nnvidia-cusparselt-cu12==0.6.2\n    # via torch\nnvidia-nccl-cu12==2.21.5\n    # via torch\nnvidia-nvjitlink-cu12==12.4.127\n    # via\n    #   nvidia-cusolver-cu12\n    #   nvidia-cusparse-cu12\n    #   torch\nnvidia-nvtx-cu12==12.4.127\n    # via torch\nopentelemetry-api==1.30.0\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   opentelemetry-exporter-otlp-proto-grpc\n    #   opentelemetry-exporter-otlp-proto-http\n    #   opentelemetry-instrumentation\n    #   opentelemetry-instrumentation-grpc\n    #   opentelemetry-sdk\n    #   opentelemetry-semantic-conventions\nopentelemetry-exporter-otlp==1.30.0\n    # via text-generation-server (pyproject.toml)\nopentelemetry-exporter-otlp-proto-common==1.30.0\n    # via\n    #   opentelemetry-exporter-otlp-proto-grpc\n    #   opentelemetry-exporter-otlp-proto-http\nopentelemetry-exporter-otlp-proto-grpc==1.30.0\n    # via opentelemetry-exporter-otlp\nopentelemetry-exporter-otlp-proto-http==1.30.0\n    # via opentelemetry-exporter-otlp\nopentelemetry-instrumentation==0.51b0\n    # via opentelemetry-instrumentation-grpc\nopentelemetry-instrumentation-grpc==0.51b0\n    # via text-generation-server (pyproject.toml)\nopentelemetry-proto==1.30.0\n    # via\n    #   opentelemetry-exporter-otlp-proto-common\n    #   
opentelemetry-exporter-otlp-proto-grpc\n    #   opentelemetry-exporter-otlp-proto-http\nopentelemetry-sdk==1.30.0\n    # via\n    #   opentelemetry-exporter-otlp-proto-grpc\n    #   opentelemetry-exporter-otlp-proto-http\nopentelemetry-semantic-conventions==0.51b0\n    # via\n    #   opentelemetry-instrumentation\n    #   opentelemetry-instrumentation-grpc\n    #   opentelemetry-sdk\npackaging==24.2\n    # via\n    #   huggingface-hub\n    #   opentelemetry-instrumentation\npillow==11.1.0\n    # via text-generation-server (pyproject.toml)\nprometheus-client==0.21.1\n    # via text-generation-server (pyproject.toml)\nprotobuf==5.29.3\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   googleapis-common-protos\n    #   grpcio-reflection\n    #   grpcio-status\n    #   opentelemetry-proto\npy-cpuinfo==9.0.0\n    # via text-generation-server (pyproject.toml)\npygments==2.19.1\n    # via rich\npyyaml==6.0.2\n    # via huggingface-hub\nrequests==2.32.3\n    # via\n    #   huggingface-hub\n    #   opentelemetry-exporter-otlp-proto-http\nrich==13.9.4\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   typer\nsafetensors==0.5.2\n    # via text-generation-server (pyproject.toml)\nscipy==1.15.1\n    # via text-generation-server (pyproject.toml)\nsentencepiece==0.2.0\n    # via text-generation-server (pyproject.toml)\nsetuptools==75.8.0\n    # via torch\nshellingham==1.5.4\n    # via typer\nsympy==1.13.1\n    # via torch\ntokenizers==0.21.0\n    # via text-generation-server (pyproject.toml)\ntorch==2.6.0\n    # via\n    #   attention-kernels\n    #   bitsandbytes\ntqdm==4.67.1\n    # via huggingface-hub\ntriton==3.2.0\n    # via torch\ntyper==0.15.1\n    # via text-generation-server (pyproject.toml)\ntyping-extensions==4.12.2\n    # via\n    #   huggingface-hub\n    #   opentelemetry-sdk\n    #   torch\n    #   typer\nurllib3==2.3.0\n    # via requests\nwrapt==1.17.2\n    # via\n    #   deprecated\n    #   opentelemetry-instrumentation\n    #  
 opentelemetry-instrumentation-grpc\nzipp==3.21.0\n    # via importlib-metadata\n"
  },
  {
    "path": "server/requirements_cuda.txt",
    "content": "# This file was autogenerated by uv via the following command:\n#    uv pip compile pyproject.toml --extra bnb --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines -o requirements_cuda.txt --python-version 3.11\naccelerate==1.6.0\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   peft\naiohappyeyeballs==2.6.1\n    # via aiohttp\naiohttp==3.11.18\n    # via\n    #   datasets\n    #   fsspec\naiosignal==1.3.2\n    # via aiohttp\nairportsdata==20250224\n    # via outlines\nannotated-types==0.7.0\n    # via pydantic\nattrs==25.3.0\n    # via\n    #   aiohttp\n    #   jsonschema\n    #   referencing\nbitsandbytes==0.45.5\n    # via text-generation-server (pyproject.toml)\ncertifi==2025.4.26\n    # via requests\ncharset-normalizer==3.4.2\n    # via requests\nclick==8.1.8\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   typer\ncloudpickle==3.1.1\n    # via outlines\ncompressed-tensors==0.9.4\n    # via text-generation-server (pyproject.toml)\ndatasets==2.21.0\n    # via text-generation-server (pyproject.toml)\ndeprecated==1.2.18\n    # via\n    #   opentelemetry-api\n    #   opentelemetry-exporter-otlp-proto-grpc\n    #   opentelemetry-exporter-otlp-proto-http\n    #   opentelemetry-semantic-conventions\ndill==0.3.8\n    # via\n    #   datasets\n    #   multiprocess\ndiskcache==5.6.3\n    # via outlines\neinops==0.8.1\n    # via text-generation-server (pyproject.toml)\nfilelock==3.18.0\n    # via\n    #   datasets\n    #   huggingface-hub\n    #   torch\n    #   transformers\nfrozenlist==1.6.0\n    # via\n    #   aiohttp\n    #   aiosignal\nfsspec==2024.6.1\n    # via\n    #   datasets\n    #   huggingface-hub\n    #   torch\ngenson==1.3.0\n    # via outlines\ngoogleapis-common-protos==1.70.0\n    # via\n    #   grpcio-status\n    #   opentelemetry-exporter-otlp-proto-grpc\n    #   opentelemetry-exporter-otlp-proto-http\ngrpc-interceptor==0.15.4\n    # via text-generation-server 
(pyproject.toml)\ngrpcio==1.71.0\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   grpc-interceptor\n    #   grpcio-reflection\n    #   grpcio-status\n    #   opentelemetry-exporter-otlp-proto-grpc\ngrpcio-reflection==1.71.0\n    # via text-generation-server (pyproject.toml)\ngrpcio-status==1.71.0\n    # via text-generation-server (pyproject.toml)\nhf-transfer==0.1.9\n    # via text-generation-server (pyproject.toml)\nhf-xet==1.1.0\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   huggingface-hub\nhuggingface-hub==0.31.1\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   accelerate\n    #   datasets\n    #   kernels\n    #   peft\n    #   tokenizers\n    #   transformers\nidna==3.10\n    # via\n    #   requests\n    #   yarl\nimportlib-metadata==8.6.1\n    # via opentelemetry-api\ninteregular==0.3.3\n    # via\n    #   outlines\n    #   outlines-core\niso3166==2.1.1\n    # via outlines\njinja2==3.1.6\n    # via\n    #   outlines\n    #   torch\njsonschema==4.23.0\n    # via\n    #   outlines\n    #   outlines-core\njsonschema-specifications==2025.4.1\n    # via jsonschema\nkernels==0.5.0\n    # via text-generation-server (pyproject.toml)\nlark==1.2.2\n    # via outlines\nloguru==0.7.3\n    # via text-generation-server (pyproject.toml)\nmarkdown-it-py==3.0.0\n    # via rich\nmarkupsafe==3.0.2\n    # via jinja2\nmdurl==0.1.2\n    # via markdown-it-py\nmpmath==1.3.0\n    # via sympy\nmultidict==6.4.3\n    # via\n    #   aiohttp\n    #   yarl\nmultiprocess==0.70.16\n    # via datasets\nnest-asyncio==1.6.0\n    # via outlines\nnetworkx==3.4.2\n    # via torch\nnumpy==2.2.5\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   accelerate\n    #   bitsandbytes\n    #   datasets\n    #   outlines\n    #   pandas\n    #   peft\n    #   scipy\n    #   transformers\nnvidia-cublas-cu12==12.6.4.1\n    # via\n    #   nvidia-cudnn-cu12\n    #   nvidia-cusolver-cu12\n    #   
torch\nnvidia-cuda-cupti-cu12==12.6.80\n    # via torch\nnvidia-cuda-nvrtc-cu12==12.6.77\n    # via torch\nnvidia-cuda-runtime-cu12==12.6.77\n    # via torch\nnvidia-cudnn-cu12==9.5.1.17\n    # via torch\nnvidia-cufft-cu12==11.3.0.4\n    # via torch\nnvidia-cufile-cu12==1.11.1.6\n    # via torch\nnvidia-curand-cu12==10.3.7.77\n    # via torch\nnvidia-cusolver-cu12==11.7.1.2\n    # via torch\nnvidia-cusparse-cu12==12.5.4.2\n    # via\n    #   nvidia-cusolver-cu12\n    #   torch\nnvidia-cusparselt-cu12==0.6.3\n    # via torch\nnvidia-nccl-cu12==2.26.2\n    # via torch\nnvidia-nvjitlink-cu12==12.6.85\n    # via\n    #   nvidia-cufft-cu12\n    #   nvidia-cusolver-cu12\n    #   nvidia-cusparse-cu12\n    #   torch\nnvidia-nvtx-cu12==12.6.77\n    # via torch\nopentelemetry-api==1.33.0\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   opentelemetry-exporter-otlp-proto-grpc\n    #   opentelemetry-exporter-otlp-proto-http\n    #   opentelemetry-instrumentation\n    #   opentelemetry-instrumentation-grpc\n    #   opentelemetry-sdk\n    #   opentelemetry-semantic-conventions\nopentelemetry-exporter-otlp==1.33.0\n    # via text-generation-server (pyproject.toml)\nopentelemetry-exporter-otlp-proto-common==1.33.0\n    # via\n    #   opentelemetry-exporter-otlp-proto-grpc\n    #   opentelemetry-exporter-otlp-proto-http\nopentelemetry-exporter-otlp-proto-grpc==1.33.0\n    # via opentelemetry-exporter-otlp\nopentelemetry-exporter-otlp-proto-http==1.33.0\n    # via opentelemetry-exporter-otlp\nopentelemetry-instrumentation==0.54b0\n    # via opentelemetry-instrumentation-grpc\nopentelemetry-instrumentation-grpc==0.54b0\n    # via text-generation-server (pyproject.toml)\nopentelemetry-proto==1.33.0\n    # via\n    #   opentelemetry-exporter-otlp-proto-common\n    #   opentelemetry-exporter-otlp-proto-grpc\n    #   opentelemetry-exporter-otlp-proto-http\nopentelemetry-sdk==1.33.0\n    # via\n    #   opentelemetry-exporter-otlp-proto-grpc\n    #   
opentelemetry-exporter-otlp-proto-http\nopentelemetry-semantic-conventions==0.54b0\n    # via\n    #   opentelemetry-instrumentation\n    #   opentelemetry-instrumentation-grpc\n    #   opentelemetry-sdk\noutlines==0.2.3\n    # via text-generation-server (pyproject.toml)\noutlines-core==0.1.26\n    # via outlines\npackaging==25.0\n    # via\n    #   accelerate\n    #   datasets\n    #   huggingface-hub\n    #   kernels\n    #   opentelemetry-instrumentation\n    #   peft\n    #   transformers\npandas==2.2.3\n    # via datasets\npeft==0.15.2\n    # via text-generation-server (pyproject.toml)\npillow==11.2.1\n    # via text-generation-server (pyproject.toml)\nprometheus-client==0.21.1\n    # via text-generation-server (pyproject.toml)\npropcache==0.3.1\n    # via\n    #   aiohttp\n    #   yarl\nprotobuf==5.29.4\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   googleapis-common-protos\n    #   grpcio-reflection\n    #   grpcio-status\n    #   opentelemetry-proto\npsutil==7.0.0\n    # via\n    #   accelerate\n    #   peft\npy-cpuinfo==9.0.0\n    # via text-generation-server (pyproject.toml)\npyarrow==20.0.0\n    # via datasets\npydantic==2.11.4\n    # via\n    #   compressed-tensors\n    #   outlines\npydantic-core==2.33.2\n    # via pydantic\npygments==2.19.1\n    # via rich\npython-dateutil==2.9.0.post0\n    # via pandas\npytz==2025.2\n    # via pandas\npyyaml==6.0.2\n    # via\n    #   accelerate\n    #   datasets\n    #   huggingface-hub\n    #   peft\n    #   transformers\nreferencing==0.36.2\n    # via\n    #   jsonschema\n    #   jsonschema-specifications\n    #   outlines\nregex==2024.11.6\n    # via transformers\nrequests==2.32.3\n    # via\n    #   datasets\n    #   huggingface-hub\n    #   opentelemetry-exporter-otlp-proto-http\n    #   outlines\n    #   transformers\nrich==14.0.0\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   typer\nrpds-py==0.24.0\n    # via\n    #   jsonschema\n    #   
referencing\nsafetensors==0.5.3\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   accelerate\n    #   peft\n    #   transformers\nscipy==1.15.3\n    # via text-generation-server (pyproject.toml)\nsentencepiece==0.2.0\n    # via text-generation-server (pyproject.toml)\nsetuptools==80.4.0\n    # via triton\nshellingham==1.5.4\n    # via typer\nsix==1.17.0\n    # via python-dateutil\nsympy==1.14.0\n    # via torch\ntexttable==1.7.0\n    # via text-generation-server (pyproject.toml)\ntokenizers==0.21.1\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   transformers\ntorch==2.7.0\n    # via\n    #   accelerate\n    #   bitsandbytes\n    #   compressed-tensors\n    #   outlines\n    #   peft\ntqdm==4.67.1\n    # via\n    #   datasets\n    #   huggingface-hub\n    #   outlines\n    #   peft\n    #   transformers\ntransformers==4.51.3\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   compressed-tensors\n    #   peft\ntriton==3.3.0\n    # via torch\ntyper==0.15.3\n    # via text-generation-server (pyproject.toml)\ntyping-extensions==4.13.2\n    # via\n    #   huggingface-hub\n    #   opentelemetry-sdk\n    #   outlines\n    #   pydantic\n    #   pydantic-core\n    #   referencing\n    #   torch\n    #   typer\n    #   typing-inspection\ntyping-inspection==0.4.0\n    # via pydantic\ntzdata==2025.2\n    # via pandas\nurllib3==2.4.0\n    # via requests\nwrapt==1.17.2\n    # via\n    #   deprecated\n    #   opentelemetry-instrumentation\n    #   opentelemetry-instrumentation-grpc\nxxhash==3.5.0\n    # via datasets\nyarl==1.20.0\n    # via aiohttp\nzipp==3.21.0\n    # via importlib-metadata\n"
  },
  {
    "path": "server/requirements_gen.txt",
    "content": "# This file was autogenerated by uv via the following command:\n#    uv pip compile pyproject.toml --extra gen -o requirements_gen.txt --python-version 3.11\ncertifi==2025.4.26\n    # via requests\ncharset-normalizer==3.4.2\n    # via requests\nclick==8.1.8\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   typer\ndeprecated==1.2.18\n    # via\n    #   opentelemetry-api\n    #   opentelemetry-exporter-otlp-proto-grpc\n    #   opentelemetry-exporter-otlp-proto-http\n    #   opentelemetry-semantic-conventions\neinops==0.8.1\n    # via text-generation-server (pyproject.toml)\nfilelock==3.18.0\n    # via\n    #   huggingface-hub\n    #   transformers\nfsspec==2025.3.2\n    # via huggingface-hub\ngoogleapis-common-protos==1.70.0\n    # via\n    #   grpcio-status\n    #   opentelemetry-exporter-otlp-proto-grpc\n    #   opentelemetry-exporter-otlp-proto-http\ngrpc-interceptor==0.15.4\n    # via text-generation-server (pyproject.toml)\ngrpcio==1.71.0\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   grpc-interceptor\n    #   grpcio-reflection\n    #   grpcio-status\n    #   grpcio-tools\n    #   opentelemetry-exporter-otlp-proto-grpc\ngrpcio-reflection==1.71.0\n    # via text-generation-server (pyproject.toml)\ngrpcio-status==1.71.0\n    # via text-generation-server (pyproject.toml)\ngrpcio-tools==1.71.0\n    # via text-generation-server (pyproject.toml)\nhf-transfer==0.1.9\n    # via text-generation-server (pyproject.toml)\nhf-xet==1.1.0\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   huggingface-hub\nhuggingface-hub==0.31.1\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   kernels\n    #   tokenizers\n    #   transformers\nidna==3.10\n    # via requests\nimportlib-metadata==8.6.1\n    # via opentelemetry-api\nkernels==0.5.0\n    # via text-generation-server (pyproject.toml)\nloguru==0.7.3\n    # via text-generation-server (pyproject.toml)\nmarkdown-it-py==3.0.0\n    # via 
rich\nmdurl==0.1.2\n    # via markdown-it-py\nmypy-protobuf==3.6.0\n    # via text-generation-server (pyproject.toml)\nnumpy==2.2.5\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   scipy\n    #   transformers\nopentelemetry-api==1.33.0\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   opentelemetry-exporter-otlp-proto-grpc\n    #   opentelemetry-exporter-otlp-proto-http\n    #   opentelemetry-instrumentation\n    #   opentelemetry-instrumentation-grpc\n    #   opentelemetry-sdk\n    #   opentelemetry-semantic-conventions\nopentelemetry-exporter-otlp==1.33.0\n    # via text-generation-server (pyproject.toml)\nopentelemetry-exporter-otlp-proto-common==1.33.0\n    # via\n    #   opentelemetry-exporter-otlp-proto-grpc\n    #   opentelemetry-exporter-otlp-proto-http\nopentelemetry-exporter-otlp-proto-grpc==1.33.0\n    # via opentelemetry-exporter-otlp\nopentelemetry-exporter-otlp-proto-http==1.33.0\n    # via opentelemetry-exporter-otlp\nopentelemetry-instrumentation==0.54b0\n    # via opentelemetry-instrumentation-grpc\nopentelemetry-instrumentation-grpc==0.54b0\n    # via text-generation-server (pyproject.toml)\nopentelemetry-proto==1.33.0\n    # via\n    #   opentelemetry-exporter-otlp-proto-common\n    #   opentelemetry-exporter-otlp-proto-grpc\n    #   opentelemetry-exporter-otlp-proto-http\nopentelemetry-sdk==1.33.0\n    # via\n    #   opentelemetry-exporter-otlp-proto-grpc\n    #   opentelemetry-exporter-otlp-proto-http\nopentelemetry-semantic-conventions==0.54b0\n    # via\n    #   opentelemetry-instrumentation\n    #   opentelemetry-instrumentation-grpc\n    #   opentelemetry-sdk\npackaging==25.0\n    # via\n    #   huggingface-hub\n    #   kernels\n    #   opentelemetry-instrumentation\n    #   transformers\npillow==11.2.1\n    # via text-generation-server (pyproject.toml)\nprometheus-client==0.21.1\n    # via text-generation-server (pyproject.toml)\nprotobuf==5.29.4\n    # via\n    #   text-generation-server 
(pyproject.toml)\n    #   googleapis-common-protos\n    #   grpcio-reflection\n    #   grpcio-status\n    #   grpcio-tools\n    #   mypy-protobuf\n    #   opentelemetry-proto\npy-cpuinfo==9.0.0\n    # via text-generation-server (pyproject.toml)\npygments==2.19.1\n    # via rich\npyyaml==6.0.2\n    # via\n    #   huggingface-hub\n    #   transformers\nregex==2024.11.6\n    # via transformers\nrequests==2.32.3\n    # via\n    #   huggingface-hub\n    #   opentelemetry-exporter-otlp-proto-http\n    #   transformers\nrich==14.0.0\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   typer\nsafetensors==0.5.3\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   transformers\nscipy==1.15.3\n    # via text-generation-server (pyproject.toml)\nsentencepiece==0.2.0\n    # via text-generation-server (pyproject.toml)\nsetuptools==80.4.0\n    # via grpcio-tools\nshellingham==1.5.4\n    # via typer\ntokenizers==0.21.1\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   transformers\ntqdm==4.67.1\n    # via\n    #   huggingface-hub\n    #   transformers\ntransformers==4.51.3\n    # via text-generation-server (pyproject.toml)\ntyper==0.15.3\n    # via text-generation-server (pyproject.toml)\ntypes-protobuf==6.30.2.20250506\n    # via mypy-protobuf\ntyping-extensions==4.13.2\n    # via\n    #   huggingface-hub\n    #   opentelemetry-sdk\n    #   typer\nurllib3==2.4.0\n    # via requests\nwrapt==1.17.2\n    # via\n    #   deprecated\n    #   opentelemetry-instrumentation\n    #   opentelemetry-instrumentation-grpc\nzipp==3.21.0\n    # via importlib-metadata\n"
  },
  {
    "path": "server/requirements_intel.txt",
    "content": "# This file was autogenerated by uv via the following command:\n#    uv pip compile pyproject.toml --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines -o requirements_intel.txt --python-version 3.11\naccelerate==1.6.0\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   peft\naiohappyeyeballs==2.6.1\n    # via aiohttp\naiohttp==3.11.18\n    # via\n    #   datasets\n    #   fsspec\naiosignal==1.3.2\n    # via aiohttp\nairportsdata==20250224\n    # via outlines\nannotated-types==0.7.0\n    # via pydantic\nattrs==25.3.0\n    # via\n    #   aiohttp\n    #   jsonschema\n    #   referencing\ncertifi==2025.4.26\n    # via requests\ncharset-normalizer==3.4.2\n    # via requests\nclick==8.1.8\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   typer\ncloudpickle==3.1.1\n    # via outlines\ncompressed-tensors==0.9.4\n    # via text-generation-server (pyproject.toml)\ndatasets==2.21.0\n    # via text-generation-server (pyproject.toml)\ndeprecated==1.2.18\n    # via\n    #   opentelemetry-api\n    #   opentelemetry-exporter-otlp-proto-grpc\n    #   opentelemetry-exporter-otlp-proto-http\n    #   opentelemetry-semantic-conventions\ndill==0.3.8\n    # via\n    #   datasets\n    #   multiprocess\ndiskcache==5.6.3\n    # via outlines\neinops==0.8.1\n    # via text-generation-server (pyproject.toml)\nfilelock==3.18.0\n    # via\n    #   datasets\n    #   huggingface-hub\n    #   torch\n    #   transformers\nfrozenlist==1.6.0\n    # via\n    #   aiohttp\n    #   aiosignal\nfsspec==2024.6.1\n    # via\n    #   datasets\n    #   huggingface-hub\n    #   torch\ngenson==1.3.0\n    # via outlines\ngoogleapis-common-protos==1.70.0\n    # via\n    #   grpcio-status\n    #   opentelemetry-exporter-otlp-proto-grpc\n    #   opentelemetry-exporter-otlp-proto-http\ngrpc-interceptor==0.15.4\n    # via text-generation-server (pyproject.toml)\ngrpcio==1.71.0\n    # via\n    #   text-generation-server 
(pyproject.toml)\n    #   grpc-interceptor\n    #   grpcio-reflection\n    #   grpcio-status\n    #   opentelemetry-exporter-otlp-proto-grpc\ngrpcio-reflection==1.71.0\n    # via text-generation-server (pyproject.toml)\ngrpcio-status==1.71.0\n    # via text-generation-server (pyproject.toml)\nhf-transfer==0.1.9\n    # via text-generation-server (pyproject.toml)\nhf-xet==1.1.0\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   huggingface-hub\nhuggingface-hub==0.31.1\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   accelerate\n    #   datasets\n    #   kernels\n    #   peft\n    #   tokenizers\n    #   transformers\nidna==3.10\n    # via\n    #   requests\n    #   yarl\nimportlib-metadata==8.6.1\n    # via opentelemetry-api\ninteregular==0.3.3\n    # via\n    #   outlines\n    #   outlines-core\niso3166==2.1.1\n    # via outlines\njinja2==3.1.6\n    # via\n    #   outlines\n    #   torch\njsonschema==4.23.0\n    # via\n    #   outlines\n    #   outlines-core\njsonschema-specifications==2025.4.1\n    # via jsonschema\nkernels==0.5.0\n    # via text-generation-server (pyproject.toml)\nlark==1.2.2\n    # via outlines\nloguru==0.7.3\n    # via text-generation-server (pyproject.toml)\nmarkdown-it-py==3.0.0\n    # via rich\nmarkupsafe==3.0.2\n    # via jinja2\nmdurl==0.1.2\n    # via markdown-it-py\nmpmath==1.3.0\n    # via sympy\nmultidict==6.4.3\n    # via\n    #   aiohttp\n    #   yarl\nmultiprocess==0.70.16\n    # via datasets\nnest-asyncio==1.6.0\n    # via outlines\nnetworkx==3.4.2\n    # via torch\nnumpy==2.2.5\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   accelerate\n    #   datasets\n    #   outlines\n    #   pandas\n    #   peft\n    #   scipy\n    #   transformers\nnvidia-cublas-cu12==12.6.4.1\n    # via\n    #   nvidia-cudnn-cu12\n    #   nvidia-cusolver-cu12\n    #   torch\nnvidia-cuda-cupti-cu12==12.6.80\n    # via torch\nnvidia-cuda-nvrtc-cu12==12.6.77\n    # via 
torch\nnvidia-cuda-runtime-cu12==12.6.77\n    # via torch\nnvidia-cudnn-cu12==9.5.1.17\n    # via torch\nnvidia-cufft-cu12==11.3.0.4\n    # via torch\nnvidia-cufile-cu12==1.11.1.6\n    # via torch\nnvidia-curand-cu12==10.3.7.77\n    # via torch\nnvidia-cusolver-cu12==11.7.1.2\n    # via torch\nnvidia-cusparse-cu12==12.5.4.2\n    # via\n    #   nvidia-cusolver-cu12\n    #   torch\nnvidia-cusparselt-cu12==0.6.3\n    # via torch\nnvidia-nccl-cu12==2.26.2\n    # via torch\nnvidia-nvjitlink-cu12==12.6.85\n    # via\n    #   nvidia-cufft-cu12\n    #   nvidia-cusolver-cu12\n    #   nvidia-cusparse-cu12\n    #   torch\nnvidia-nvtx-cu12==12.6.77\n    # via torch\nopentelemetry-api==1.33.0\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   opentelemetry-exporter-otlp-proto-grpc\n    #   opentelemetry-exporter-otlp-proto-http\n    #   opentelemetry-instrumentation\n    #   opentelemetry-instrumentation-grpc\n    #   opentelemetry-sdk\n    #   opentelemetry-semantic-conventions\nopentelemetry-exporter-otlp==1.33.0\n    # via text-generation-server (pyproject.toml)\nopentelemetry-exporter-otlp-proto-common==1.33.0\n    # via\n    #   opentelemetry-exporter-otlp-proto-grpc\n    #   opentelemetry-exporter-otlp-proto-http\nopentelemetry-exporter-otlp-proto-grpc==1.33.0\n    # via opentelemetry-exporter-otlp\nopentelemetry-exporter-otlp-proto-http==1.33.0\n    # via opentelemetry-exporter-otlp\nopentelemetry-instrumentation==0.54b0\n    # via opentelemetry-instrumentation-grpc\nopentelemetry-instrumentation-grpc==0.54b0\n    # via text-generation-server (pyproject.toml)\nopentelemetry-proto==1.33.0\n    # via\n    #   opentelemetry-exporter-otlp-proto-common\n    #   opentelemetry-exporter-otlp-proto-grpc\n    #   opentelemetry-exporter-otlp-proto-http\nopentelemetry-sdk==1.33.0\n    # via\n    #   opentelemetry-exporter-otlp-proto-grpc\n    #   opentelemetry-exporter-otlp-proto-http\nopentelemetry-semantic-conventions==0.54b0\n    # via\n    #   
opentelemetry-instrumentation\n    #   opentelemetry-instrumentation-grpc\n    #   opentelemetry-sdk\noutlines==0.2.3\n    # via text-generation-server (pyproject.toml)\noutlines-core==0.1.26\n    # via outlines\npackaging==25.0\n    # via\n    #   accelerate\n    #   datasets\n    #   huggingface-hub\n    #   kernels\n    #   opentelemetry-instrumentation\n    #   peft\n    #   transformers\npandas==2.2.3\n    # via datasets\npeft==0.15.2\n    # via text-generation-server (pyproject.toml)\npillow==11.2.1\n    # via text-generation-server (pyproject.toml)\nprometheus-client==0.21.1\n    # via text-generation-server (pyproject.toml)\npropcache==0.3.1\n    # via\n    #   aiohttp\n    #   yarl\nprotobuf==5.29.4\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   googleapis-common-protos\n    #   grpcio-reflection\n    #   grpcio-status\n    #   opentelemetry-proto\npsutil==7.0.0\n    # via\n    #   accelerate\n    #   peft\npy-cpuinfo==9.0.0\n    # via text-generation-server (pyproject.toml)\npyarrow==20.0.0\n    # via datasets\npydantic==2.11.4\n    # via\n    #   compressed-tensors\n    #   outlines\npydantic-core==2.33.2\n    # via pydantic\npygments==2.19.1\n    # via rich\npython-dateutil==2.9.0.post0\n    # via pandas\npytz==2025.2\n    # via pandas\npyyaml==6.0.2\n    # via\n    #   accelerate\n    #   datasets\n    #   huggingface-hub\n    #   peft\n    #   transformers\nreferencing==0.36.2\n    # via\n    #   jsonschema\n    #   jsonschema-specifications\n    #   outlines\nregex==2024.11.6\n    # via transformers\nrequests==2.32.3\n    # via\n    #   datasets\n    #   huggingface-hub\n    #   opentelemetry-exporter-otlp-proto-http\n    #   outlines\n    #   transformers\nrich==14.0.0\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   typer\nrpds-py==0.24.0\n    # via\n    #   jsonschema\n    #   referencing\nsafetensors==0.5.3\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   accelerate\n    #   peft\n    
#   transformers\nscipy==1.15.3\n    # via text-generation-server (pyproject.toml)\nsentencepiece==0.2.0\n    # via text-generation-server (pyproject.toml)\nsetuptools==80.4.0\n    # via triton\nshellingham==1.5.4\n    # via typer\nsix==1.17.0\n    # via python-dateutil\nsympy==1.14.0\n    # via torch\ntexttable==1.7.0\n    # via text-generation-server (pyproject.toml)\ntokenizers==0.21.1\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   transformers\ntorch==2.7.0\n    # via\n    #   accelerate\n    #   compressed-tensors\n    #   outlines\n    #   peft\ntqdm==4.67.1\n    # via\n    #   datasets\n    #   huggingface-hub\n    #   outlines\n    #   peft\n    #   transformers\ntransformers==4.51.3\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   compressed-tensors\n    #   peft\ntriton==3.3.0\n    # via torch\ntyper==0.15.3\n    # via text-generation-server (pyproject.toml)\ntyping-extensions==4.13.2\n    # via\n    #   huggingface-hub\n    #   opentelemetry-sdk\n    #   outlines\n    #   pydantic\n    #   pydantic-core\n    #   referencing\n    #   torch\n    #   typer\n    #   typing-inspection\ntyping-inspection==0.4.0\n    # via pydantic\ntzdata==2025.2\n    # via pandas\nurllib3==2.4.0\n    # via requests\nwrapt==1.17.2\n    # via\n    #   deprecated\n    #   opentelemetry-instrumentation\n    #   opentelemetry-instrumentation-grpc\nxxhash==3.5.0\n    # via datasets\nyarl==1.20.0\n    # via aiohttp\nzipp==3.21.0\n    # via importlib-metadata\n"
  },
  {
    "path": "server/requirements_rocm.txt",
    "content": "# This file was autogenerated by uv via the following command:\n#    uv pip compile pyproject.toml --extra accelerate --extra compressed-tensors --extra quantize --extra peft --extra outlines -o requirements_rocm.txt --python-version 3.11\naccelerate==1.6.0\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   peft\naiohappyeyeballs==2.6.1\n    # via aiohttp\naiohttp==3.11.18\n    # via\n    #   datasets\n    #   fsspec\naiosignal==1.3.2\n    # via aiohttp\nairportsdata==20250224\n    # via outlines\nannotated-types==0.7.0\n    # via pydantic\nattrs==25.3.0\n    # via\n    #   aiohttp\n    #   jsonschema\n    #   referencing\ncertifi==2025.4.26\n    # via requests\ncharset-normalizer==3.4.2\n    # via requests\nclick==8.1.8\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   typer\ncloudpickle==3.1.1\n    # via outlines\ncompressed-tensors==0.9.4\n    # via text-generation-server (pyproject.toml)\ndatasets==2.21.0\n    # via text-generation-server (pyproject.toml)\ndeprecated==1.2.18\n    # via\n    #   opentelemetry-api\n    #   opentelemetry-exporter-otlp-proto-grpc\n    #   opentelemetry-exporter-otlp-proto-http\n    #   opentelemetry-semantic-conventions\ndill==0.3.8\n    # via\n    #   datasets\n    #   multiprocess\ndiskcache==5.6.3\n    # via outlines\neinops==0.8.1\n    # via text-generation-server (pyproject.toml)\nfilelock==3.18.0\n    # via\n    #   datasets\n    #   huggingface-hub\n    #   torch\n    #   transformers\nfrozenlist==1.6.0\n    # via\n    #   aiohttp\n    #   aiosignal\nfsspec==2024.6.1\n    # via\n    #   datasets\n    #   huggingface-hub\n    #   torch\ngenson==1.3.0\n    # via outlines\ngoogleapis-common-protos==1.70.0\n    # via\n    #   grpcio-status\n    #   opentelemetry-exporter-otlp-proto-grpc\n    #   opentelemetry-exporter-otlp-proto-http\ngrpc-interceptor==0.15.4\n    # via text-generation-server (pyproject.toml)\ngrpcio==1.71.0\n    # via\n    #   text-generation-server 
(pyproject.toml)\n    #   grpc-interceptor\n    #   grpcio-reflection\n    #   grpcio-status\n    #   opentelemetry-exporter-otlp-proto-grpc\ngrpcio-reflection==1.71.0\n    # via text-generation-server (pyproject.toml)\ngrpcio-status==1.71.0\n    # via text-generation-server (pyproject.toml)\nhf-transfer==0.1.9\n    # via text-generation-server (pyproject.toml)\nhf-xet==1.1.0\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   huggingface-hub\nhuggingface-hub==0.31.1\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   accelerate\n    #   datasets\n    #   kernels\n    #   peft\n    #   tokenizers\n    #   transformers\nidna==3.10\n    # via\n    #   requests\n    #   yarl\nimportlib-metadata==8.6.1\n    # via opentelemetry-api\ninteregular==0.3.3\n    # via\n    #   outlines\n    #   outlines-core\niso3166==2.1.1\n    # via outlines\njinja2==3.1.6\n    # via\n    #   outlines\n    #   torch\njsonschema==4.23.0\n    # via\n    #   outlines\n    #   outlines-core\njsonschema-specifications==2025.4.1\n    # via jsonschema\nkernels==0.5.0\n    # via text-generation-server (pyproject.toml)\nlark==1.2.2\n    # via outlines\nloguru==0.7.3\n    # via text-generation-server (pyproject.toml)\nmarkdown-it-py==3.0.0\n    # via rich\nmarkupsafe==3.0.2\n    # via jinja2\nmdurl==0.1.2\n    # via markdown-it-py\nmpmath==1.3.0\n    # via sympy\nmultidict==6.4.3\n    # via\n    #   aiohttp\n    #   yarl\nmultiprocess==0.70.16\n    # via datasets\nnest-asyncio==1.6.0\n    # via outlines\nnetworkx==3.4.2\n    # via torch\nnumpy==2.2.5\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   accelerate\n    #   datasets\n    #   outlines\n    #   pandas\n    #   peft\n    #   scipy\n    #   transformers\nnvidia-cublas-cu12==12.6.4.1\n    # via\n    #   nvidia-cudnn-cu12\n    #   nvidia-cusolver-cu12\n    #   torch\nnvidia-cuda-cupti-cu12==12.6.80\n    # via torch\nnvidia-cuda-nvrtc-cu12==12.6.77\n    # via 
torch\nnvidia-cuda-runtime-cu12==12.6.77\n    # via torch\nnvidia-cudnn-cu12==9.5.1.17\n    # via torch\nnvidia-cufft-cu12==11.3.0.4\n    # via torch\nnvidia-cufile-cu12==1.11.1.6\n    # via torch\nnvidia-curand-cu12==10.3.7.77\n    # via torch\nnvidia-cusolver-cu12==11.7.1.2\n    # via torch\nnvidia-cusparse-cu12==12.5.4.2\n    # via\n    #   nvidia-cusolver-cu12\n    #   torch\nnvidia-cusparselt-cu12==0.6.3\n    # via torch\nnvidia-nccl-cu12==2.26.2\n    # via torch\nnvidia-nvjitlink-cu12==12.6.85\n    # via\n    #   nvidia-cufft-cu12\n    #   nvidia-cusolver-cu12\n    #   nvidia-cusparse-cu12\n    #   torch\nnvidia-nvtx-cu12==12.6.77\n    # via torch\nopentelemetry-api==1.33.0\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   opentelemetry-exporter-otlp-proto-grpc\n    #   opentelemetry-exporter-otlp-proto-http\n    #   opentelemetry-instrumentation\n    #   opentelemetry-instrumentation-grpc\n    #   opentelemetry-sdk\n    #   opentelemetry-semantic-conventions\nopentelemetry-exporter-otlp==1.33.0\n    # via text-generation-server (pyproject.toml)\nopentelemetry-exporter-otlp-proto-common==1.33.0\n    # via\n    #   opentelemetry-exporter-otlp-proto-grpc\n    #   opentelemetry-exporter-otlp-proto-http\nopentelemetry-exporter-otlp-proto-grpc==1.33.0\n    # via opentelemetry-exporter-otlp\nopentelemetry-exporter-otlp-proto-http==1.33.0\n    # via opentelemetry-exporter-otlp\nopentelemetry-instrumentation==0.54b0\n    # via opentelemetry-instrumentation-grpc\nopentelemetry-instrumentation-grpc==0.54b0\n    # via text-generation-server (pyproject.toml)\nopentelemetry-proto==1.33.0\n    # via\n    #   opentelemetry-exporter-otlp-proto-common\n    #   opentelemetry-exporter-otlp-proto-grpc\n    #   opentelemetry-exporter-otlp-proto-http\nopentelemetry-sdk==1.33.0\n    # via\n    #   opentelemetry-exporter-otlp-proto-grpc\n    #   opentelemetry-exporter-otlp-proto-http\nopentelemetry-semantic-conventions==0.54b0\n    # via\n    #   
opentelemetry-instrumentation\n    #   opentelemetry-instrumentation-grpc\n    #   opentelemetry-sdk\noutlines==0.2.3\n    # via text-generation-server (pyproject.toml)\noutlines-core==0.1.26\n    # via outlines\npackaging==25.0\n    # via\n    #   accelerate\n    #   datasets\n    #   huggingface-hub\n    #   kernels\n    #   opentelemetry-instrumentation\n    #   peft\n    #   transformers\npandas==2.2.3\n    # via datasets\npeft==0.15.2\n    # via text-generation-server (pyproject.toml)\npillow==11.2.1\n    # via text-generation-server (pyproject.toml)\nprometheus-client==0.21.1\n    # via text-generation-server (pyproject.toml)\npropcache==0.3.1\n    # via\n    #   aiohttp\n    #   yarl\nprotobuf==5.29.4\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   googleapis-common-protos\n    #   grpcio-reflection\n    #   grpcio-status\n    #   opentelemetry-proto\npsutil==7.0.0\n    # via\n    #   accelerate\n    #   peft\npy-cpuinfo==9.0.0\n    # via text-generation-server (pyproject.toml)\npyarrow==20.0.0\n    # via datasets\npydantic==2.11.4\n    # via\n    #   compressed-tensors\n    #   outlines\npydantic-core==2.33.2\n    # via pydantic\npygments==2.19.1\n    # via rich\npython-dateutil==2.9.0.post0\n    # via pandas\npytz==2025.2\n    # via pandas\npyyaml==6.0.2\n    # via\n    #   accelerate\n    #   datasets\n    #   huggingface-hub\n    #   peft\n    #   transformers\nreferencing==0.36.2\n    # via\n    #   jsonschema\n    #   jsonschema-specifications\n    #   outlines\nregex==2024.11.6\n    # via transformers\nrequests==2.32.3\n    # via\n    #   datasets\n    #   huggingface-hub\n    #   opentelemetry-exporter-otlp-proto-http\n    #   outlines\n    #   transformers\nrich==14.0.0\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   typer\nrpds-py==0.24.0\n    # via\n    #   jsonschema\n    #   referencing\nsafetensors==0.5.3\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   accelerate\n    #   peft\n    
#   transformers\nscipy==1.15.3\n    # via text-generation-server (pyproject.toml)\nsentencepiece==0.2.0\n    # via text-generation-server (pyproject.toml)\nsetuptools==80.4.0\n    # via triton\nshellingham==1.5.4\n    # via typer\nsix==1.17.0\n    # via python-dateutil\nsympy==1.14.0\n    # via torch\ntexttable==1.7.0\n    # via text-generation-server (pyproject.toml)\ntokenizers==0.21.1\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   transformers\ntorch==2.7.0\n    # via\n    #   accelerate\n    #   compressed-tensors\n    #   outlines\n    #   peft\ntqdm==4.67.1\n    # via\n    #   datasets\n    #   huggingface-hub\n    #   outlines\n    #   peft\n    #   transformers\ntransformers==4.51.3\n    # via\n    #   text-generation-server (pyproject.toml)\n    #   compressed-tensors\n    #   peft\ntriton==3.3.0\n    # via torch\ntyper==0.15.3\n    # via text-generation-server (pyproject.toml)\ntyping-extensions==4.13.2\n    # via\n    #   huggingface-hub\n    #   opentelemetry-sdk\n    #   outlines\n    #   pydantic\n    #   pydantic-core\n    #   referencing\n    #   torch\n    #   typer\n    #   typing-inspection\ntyping-inspection==0.4.0\n    # via pydantic\ntzdata==2025.2\n    # via pandas\nurllib3==2.4.0\n    # via requests\nwrapt==1.17.2\n    # via\n    #   deprecated\n    #   opentelemetry-instrumentation\n    #   opentelemetry-instrumentation-grpc\nxxhash==3.5.0\n    # via datasets\nyarl==1.20.0\n    # via aiohttp\nzipp==3.21.0\n    # via importlib-metadata\n"
  },
  {
    "path": "server/tests/conftest.py",
    "content": "import pytest\nimport os\nfrom text_generation_server.pb import generate_pb2\n\nos.environ[\"PREFIX_CACHING\"] = \"1\"\nos.environ[\"ATTENTION\"] = \"flashinfer\"\n\n\n@pytest.fixture\ndef default_pb_parameters():\n    return generate_pb2.NextTokenChooserParameters(\n        temperature=1.0,\n        repetition_penalty=1.0,\n        top_k=0,\n        top_p=1.0,\n        typical_p=1.0,\n        do_sample=False,\n    )\n\n\n@pytest.fixture\ndef default_pb_stop_parameters():\n    return generate_pb2.StoppingCriteriaParameters(stop_sequences=[], max_new_tokens=10)\n"
  },
  {
    "path": "server/tests/models/test_bloom.py",
    "content": "import pytest\nimport torch\n\nfrom copy import copy\nfrom transformers import AutoTokenizer\n\nfrom text_generation_server.pb import generate_pb2\nfrom text_generation_server.models.causal_lm import CausalLMBatch\nfrom text_generation_server.utils import weight_hub_files, download_weights\nfrom text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded\nfrom text_generation_server.models.custom_modeling.bloom_modeling import (\n    BloomForCausalLM,\n)\n\n\n@pytest.fixture(scope=\"session\")\ndef default_bloom():\n    model_id = \"bigscience/bloom-560m\"\n    revision = \"main\"\n    filenames = weight_hub_files(model_id, revision, \".safetensors\")\n    download_weights(filenames, model_id, revision)\n    return BLOOMSharded(\n        model_id,\n        model_class=BloomForCausalLM,\n    )\n\n\n@pytest.fixture(scope=\"session\")\ndef bloom_560m_tokenizer():\n    return AutoTokenizer.from_pretrained(\"bigscience/bloom-560m\", padding_side=\"left\")\n\n\n@pytest.fixture\ndef default_pb_request(default_pb_parameters, default_pb_stop_parameters):\n    return generate_pb2.Request(\n        id=0,\n        inputs=\"Test\",\n        input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text=\"Test\")]),\n        prefill_logprobs=True,\n        truncate=100,\n        parameters=default_pb_parameters,\n        stopping_parameters=default_pb_stop_parameters,\n    )\n\n\n@pytest.fixture\ndef default_pb_batch(default_pb_request):\n    return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)\n\n\n@pytest.fixture\ndef default_bloom_batch(default_pb_batch, bloom_560m_tokenizer):\n    return BloomCausalLMBatch.from_pb(\n        default_pb_batch, bloom_560m_tokenizer, torch.float32, torch.device(\"cpu\")\n    )\n\n\n@pytest.fixture\ndef default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer):\n    req_0 = copy(default_pb_request)\n    req_0.id = 1\n    req_1 = default_pb_request\n    req_1.id = 2\n    
req_1.stopping_parameters.max_new_tokens = 5\n\n    batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)\n    return BloomCausalLMBatch.from_pb(\n        batch_pb, bloom_560m_tokenizer, torch.float32, torch.device(\"cpu\")\n    )\n\n\ndef test_batch_from_pb(default_pb_batch, default_bloom_batch):\n    batch = default_bloom_batch\n\n    assert batch.batch_id == default_pb_batch.id\n    assert batch.requests == default_pb_batch.requests\n\n    assert len(batch.input_ids) == default_pb_batch.size\n    assert batch.input_ids[0][-1] == 10264\n    assert torch.all(batch.input_ids[0][:-1] == 3)\n\n    assert batch.attention_mask[0][0] == 1\n    assert torch.all(batch.attention_mask[0][1:] == 0)\n\n    assert batch.past_key_values is None\n\n    assert all(\n        [\n            torch.equal(input_ids, all_input_ids[:, 0])\n            for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)\n        ]\n    )\n\n    assert batch.input_lengths == [1]\n\n    assert len(batch) == default_pb_batch.size\n    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)\n\n    assert batch.max_input_length == batch.input_lengths[0]\n\n\ndef test_batch_concatenate_no_prefill(default_bloom_batch):\n    with pytest.raises(ValueError):\n        BloomCausalLMBatch.concatenate([default_bloom_batch, default_bloom_batch])\n\n\ndef test_causal_lm_batch_type(default_bloom):\n    assert default_bloom.batch_type == BloomCausalLMBatch\n\n\ndef test_causal_lm_generate_token(default_bloom, default_bloom_batch):\n    sequence_length = len(default_bloom_batch.all_input_ids[0])\n    generations, next_batch, _ = default_bloom.generate_token(default_bloom_batch)\n\n    assert len(generations) == len(default_bloom_batch)\n    assert isinstance(next_batch, CausalLMBatch)\n    assert not next_batch.keys_head_dim_last\n\n    assert len(next_batch.all_input_ids) == len(next_batch)\n    assert len(next_batch.all_input_ids[0]) == sequence_length + 
1\n    assert len(next_batch.attention_mask[0]) == 11\n    assert torch.all(next_batch.all_input_ids[0][-2:] == 10264)\n    assert torch.all(next_batch.all_input_ids[0][:-2] == 3)\n\n    assert torch.all(next_batch.attention_mask[0][:2] == 1)\n    assert torch.all(next_batch.attention_mask[0][2:] == 0)\n\n    assert next_batch.input_ids.shape == (len(next_batch), 1)\n    assert next_batch.input_ids[0, 0] == 10264\n\n    assert next_batch.input_lengths == [2]\n    assert next_batch.max_input_length == next_batch.input_lengths[0]\n\n    assert next_batch.past_key_values is not None\n    assert all(\n        [p[0].shape == (16, 64, sequence_length) for p in next_batch.past_key_values]\n    )\n    assert all(\n        [p[1].shape == (16, sequence_length, 64) for p in next_batch.past_key_values]\n    )\n    assert all([generation.generated_text is None for generation in generations])\n    assert all([len(generation.prefill_tokens) == 1 for generation in generations])\n    assert all(\n        [\n            token_id.item() == 10264\n            for generation in generations\n            for token_id in generation.tokens.token_ids\n        ]\n    )\n    assert all(\n        [\n            token_text == \"Test\"\n            for generation in generations\n            for token_text in generation.tokens.texts\n        ]\n    )\n    assert generations[0].request_id == 0\n\n\ndef test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch):\n    next_batch = default_bloom_batch\n    for _ in range(default_bloom_batch.stopping_criterias[0].max_new_tokens - 1):\n        generations, next_batch, _ = default_bloom.generate_token(next_batch)\n        assert len(generations) == len(default_bloom_batch)\n\n    generations, next_batch, _ = default_bloom.generate_token(next_batch)\n    assert next_batch is None\n\n    assert len(generations) == 1\n    assert (\n        generations[0].generated_text.text == \"TestTestTestTestTestTestTestTestTestTest\"\n    )\n    
assert generations[0].request_id == default_bloom_batch.requests[0].id\n    assert (\n        generations[0].generated_text.generated_tokens\n        == default_bloom_batch.stopping_criterias[0].max_new_tokens\n    )\n\n\ndef test_causal_lm_generate_token_completion_multi(\n    default_bloom, default_multi_requests_bloom_batch\n):\n    next_batch = default_multi_requests_bloom_batch\n\n    for i in range(\n        default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1\n    ):\n        generations, next_batch, _ = default_bloom.generate_token(next_batch)\n        assert len(generations) == len(default_multi_requests_bloom_batch)\n\n    generations, next_batch, _ = default_bloom.generate_token(next_batch)\n    assert next_batch is not None\n\n    assert len(generations) == 2\n    assert generations[1].generated_text.text == \"TestTestTestTestTest\"\n    assert (\n        generations[1].request_id == default_multi_requests_bloom_batch.requests[1].id\n    )\n    assert (\n        generations[1].generated_text.generated_tokens\n        == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens\n    )\n    # Copy stopping_criterias before filtering\n    stopping_criterias = default_multi_requests_bloom_batch.stopping_criterias.copy()\n\n    next_batch = next_batch.filter([next_batch.requests[0].id])\n\n    for _ in range(\n        stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1\n    ):\n        generations, next_batch, _ = default_bloom.generate_token(next_batch)\n        assert len(generations) == len(next_batch)\n\n    generations, next_batch, _ = default_bloom.generate_token(next_batch)\n    assert next_batch is None\n\n    assert len(generations) == 1\n    assert (\n        generations[0].generated_text.text == \"TestTestTestTestTestTestTestTestTestTest\"\n    )\n    assert (\n        generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id\n    )\n    assert (\n        
generations[0].generated_text.generated_tokens\n        == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens\n    )\n\n\ndef test_batch_concatenate(\n    default_bloom, default_bloom_batch, default_multi_requests_bloom_batch\n):\n    next_batch_0 = default_bloom_batch\n    _, next_batch_0, _ = default_bloom.generate_token(next_batch_0)\n    _, next_batch_0, _ = default_bloom.generate_token(next_batch_0)\n\n    next_batch_1 = default_multi_requests_bloom_batch\n    _, next_batch_1, _ = default_bloom.generate_token(next_batch_1)\n\n    # Clone past_key_values before concatenating to compare after,\n    # because they are removed from the concatenated batches\n    next_batch_0_past_key_values = [\n        (k.clone(), v.clone()) for (k, v) in next_batch_0.past_key_values\n    ]\n    next_batch_1_past_key_values = [\n        (k.clone(), v.clone()) for (k, v) in next_batch_1.past_key_values\n    ]\n\n    next_batch = BloomCausalLMBatch.concatenate([next_batch_0, next_batch_1])\n\n    assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0])\n    assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0])\n    assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1])\n\n    assert torch.all(\n        next_batch.attention_mask[0, : -next_batch.padding_right_offset] == 1\n    )\n    assert torch.all(\n        next_batch.attention_mask[1:, 1 : -next_batch.padding_right_offset] == 1\n    )\n    assert torch.all(next_batch.attention_mask[1:, 3:] == 0)\n\n    assert next_batch.batch_id == 0\n    assert torch.all(next_batch.input_ids == 10264)\n\n    assert next_batch.input_lengths == [3, 2, 2]\n    assert next_batch.max_input_length == 3\n\n    assert next_batch.requests[0] == next_batch_0.requests[0]\n    assert next_batch.requests[1:] == list(next_batch_1.requests)\n\n    assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]\n    assert 
next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers\n\n    assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0]\n    assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias\n\n    assert next_batch.past_key_values is not None\n    assert all([p[0].shape == (3, 16, 64, 2) for p in next_batch.past_key_values])\n    assert all([p[1].shape == (3, 16, 2, 64) for p in next_batch.past_key_values])\n\n    for i, past in enumerate(next_batch.past_key_values):\n        assert torch.equal(next_batch_0_past_key_values[i][0][:, :, -2:], past[0][0])\n        assert torch.equal(\n            next_batch_1_past_key_values[i][0][:, :, -1:],\n            past[0][1:, :, :, -1].reshape(-1, 64, 1),\n        )\n\n        assert torch.equal(next_batch_0_past_key_values[i][1][:, -2:, :], past[1][0])\n        assert torch.equal(\n            next_batch_1_past_key_values[i][1][:, -1:, :],\n            past[1][1:, :, -1, :].reshape(-1, 1, 64),\n        )\n\n    for _ in range(\n        default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2\n    ):\n        generations, next_batch, _ = default_bloom.generate_token(next_batch)\n        assert len(generations) == len(next_batch)\n\n    generations, next_batch, _ = default_bloom.generate_token(next_batch)\n    assert next_batch is not None\n\n    assert len(generations) == 3\n    assert generations[2].generated_text.text == \"TestTestTestTestTest\"\n    assert (\n        generations[2].request_id == default_multi_requests_bloom_batch.requests[1].id\n    )\n    assert (\n        generations[2].generated_text.generated_tokens\n        == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens\n    )\n\n    next_batch = next_batch.filter(\n        [next_batch.requests[0].id, next_batch.requests[1].id]\n    )\n\n    for _ in range(\n        default_bloom_batch.stopping_criterias[0].max_new_tokens\n        - 
default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens\n        - 2\n    ):\n        generations, next_batch, _ = default_bloom.generate_token(next_batch)\n        assert len(generations) == len(next_batch)\n\n    generations, next_batch, _ = default_bloom.generate_token(next_batch)\n    assert next_batch is not None\n\n    assert len(generations) == 2\n    assert (\n        generations[0].generated_text.text == \"TestTestTestTestTestTestTestTestTestTest\"\n    )\n    assert generations[0].request_id == default_bloom_batch.requests[0].id\n    assert (\n        generations[0].generated_text.generated_tokens\n        == default_bloom_batch.stopping_criterias[0].max_new_tokens\n    )\n\n    next_batch = next_batch.filter([next_batch.requests[1].id])\n\n    for _ in range(\n        default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens\n        - default_bloom_batch.stopping_criterias[0].max_new_tokens\n        - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens\n        - 4\n    ):\n        generations, next_batch, _ = default_bloom.generate_token(next_batch)\n        assert len(generations) == len(next_batch)\n\n    generations, next_batch, _ = default_bloom.generate_token(next_batch)\n    assert next_batch is None\n\n    assert len(generations) == 1\n    assert (\n        generations[0].generated_text.text == \"TestTestTestTestTestTestTestTestTestTest\"\n    )\n    assert (\n        generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id\n    )\n    assert (\n        generations[0].generated_text.generated_tokens\n        == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens\n    )\n"
  },
  {
    "path": "server/tests/models/test_causal_lm.py",
    "content": "import pytest\nimport torch\n\nfrom copy import copy\nfrom transformers import AutoTokenizer\n\nfrom text_generation_server.pb import generate_pb2\nfrom text_generation_server.models.causal_lm import CausalLM, CausalLMBatch\n\n\n@pytest.fixture(scope=\"session\")\ndef default_causal_lm():\n    return CausalLM.fallback(\"gpt2\")\n\n\n@pytest.fixture(scope=\"session\")\ndef gpt2_tokenizer():\n    tokenizer = AutoTokenizer.from_pretrained(\"gpt2\", padding_side=\"left\")\n    tokenizer.pad_token_id = 50256\n    return tokenizer\n\n\n@pytest.fixture\ndef default_pb_request(default_pb_parameters, default_pb_stop_parameters):\n    return generate_pb2.Request(\n        id=0,\n        inputs=\"Test\",\n        input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text=\"Test\")]),\n        prefill_logprobs=True,\n        truncate=100,\n        parameters=default_pb_parameters,\n        stopping_parameters=default_pb_stop_parameters,\n    )\n\n\n@pytest.fixture\ndef default_pb_batch(default_pb_request):\n    return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)\n\n\n@pytest.fixture\ndef default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):\n    return CausalLMBatch.from_pb(\n        default_pb_batch, gpt2_tokenizer, torch.float32, torch.device(\"cpu\")\n    )\n\n\n@pytest.fixture\ndef default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):\n    req_0 = copy(default_pb_request)\n    req_0.id = 1\n    req_1 = default_pb_request\n    req_1.id = 2\n    req_1.stopping_parameters.max_new_tokens = 5\n\n    batch_pb = generate_pb2.Batch(id=1, requests=[req_0, req_1], size=2)\n    return CausalLMBatch.from_pb(\n        batch_pb, gpt2_tokenizer, torch.float32, torch.device(\"cpu\")\n    )\n\n\ndef test_batch_from_pb(default_pb_batch, default_causal_lm_batch):\n    batch = default_causal_lm_batch\n\n    assert batch.batch_id == default_pb_batch.id\n    assert batch.requests == default_pb_batch.requests\n\n    assert 
len(batch.input_ids) == default_pb_batch.size\n    assert batch.input_ids[0][-1] == 14402\n    assert torch.all(batch.input_ids[0][:-1] == 50256)\n\n    assert batch.attention_mask[0, 0] == 1\n    assert torch.all(batch.attention_mask[0, 1:] == 0)\n\n    assert batch.past_key_values is None\n\n    assert all(\n        [\n            torch.equal(input_ids, all_input_ids[:, 0])\n            for input_ids, all_input_ids in zip(batch.input_ids, batch.all_input_ids)\n        ]\n    )\n\n    assert batch.input_lengths == [1]\n\n    assert len(batch) == default_pb_batch.size\n    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)\n\n    assert batch.max_input_length == batch.input_lengths[0]\n\n\ndef test_batch_concatenate_no_prefill(default_causal_lm_batch):\n    with pytest.raises(ValueError):\n        CausalLMBatch.concatenate([default_causal_lm_batch, default_causal_lm_batch])\n\n\ndef test_causal_lm_batch_type(default_causal_lm):\n    assert default_causal_lm.batch_type == CausalLMBatch\n\n\ndef test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):\n    sequence_length = len(default_causal_lm_batch.all_input_ids[0])\n    generations, next_batch, _ = default_causal_lm.generate_token(\n        default_causal_lm_batch\n    )\n\n    assert len(generations) == len(next_batch)\n    assert isinstance(next_batch, CausalLMBatch)\n\n    assert len(next_batch.all_input_ids) == len(next_batch)\n    assert len(next_batch.all_input_ids[0]) == sequence_length + 1\n    assert len(next_batch.attention_mask[0]) == 11\n    assert next_batch.all_input_ids[0][-1] == 13\n    assert next_batch.all_input_ids[0][-2] == 14402\n    assert torch.all(next_batch.all_input_ids[0][:-2] == 50256)\n\n    assert torch.all(next_batch.attention_mask[0][0:2] == 1)\n    assert torch.all(next_batch.attention_mask[0][2:] == 0)\n\n    assert next_batch.input_ids.shape == (len(next_batch), 1)\n    assert next_batch.input_ids[0, 0] == 13\n\n    assert 
next_batch.input_lengths == [2]\n    assert next_batch.max_input_length == next_batch.input_lengths[0]\n\n    assert next_batch.past_key_values is not None\n    assert all(\n        [p[0].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values]\n    )\n    assert all(\n        [p[1].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values]\n    )\n    assert all([generation.generated_text is None for generation in generations])\n    assert all([len(generation.prefill_tokens) == 1 for generation in generations])\n    assert all(\n        [\n            token_id.item() == 13\n            for generation in generations\n            for token_id in generation.tokens.token_ids\n        ]\n    )\n    assert all(\n        [\n            token_text == \".\"\n            for generation in generations\n            for token_text in generation.tokens.texts\n        ]\n    )\n    assert generations[0].request_id == 0\n\n\ndef test_causal_lm_generate_token_completion(\n    default_causal_lm, default_causal_lm_batch\n):\n    next_batch = default_causal_lm_batch\n    for _ in range(default_causal_lm_batch.stopping_criterias[0].max_new_tokens - 1):\n        generations, next_batch, _ = default_causal_lm.generate_token(next_batch)\n        assert len(generations) == len(next_batch)\n\n    generations, next_batch, _ = default_causal_lm.generate_token(next_batch)\n    assert next_batch is None\n\n    assert len(generations) == 1\n    assert generations[0].generated_text.text == \".java:784) at net.minecraft.\"\n    assert generations[0].request_id == default_causal_lm_batch.requests[0].id\n    assert (\n        generations[0].generated_text.generated_tokens\n        == default_causal_lm_batch.stopping_criterias[0].max_new_tokens\n    )\n\n\ndef test_causal_lm_generate_token_completion_multi(\n    default_causal_lm, default_multi_requests_causal_lm_batch\n):\n    next_batch = default_multi_requests_causal_lm_batch\n\n    for i in range(\n        
default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1\n    ):\n        generations, next_batch, _ = default_causal_lm.generate_token(next_batch)\n        assert len(generations) == len(next_batch)\n\n    generations, next_batch, _ = default_causal_lm.generate_token(next_batch)\n    assert next_batch is not None\n\n    assert len(generations) == 2\n    assert generations[1].generated_text.text == \".java:784)\"\n    assert (\n        generations[1].request_id\n        == default_multi_requests_causal_lm_batch.requests[1].id\n    )\n    assert (\n        generations[1].generated_text.generated_tokens\n        == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens\n    )\n    # Copy stopping_criterias before filtering\n    stopping_criterias = (\n        default_multi_requests_causal_lm_batch.stopping_criterias.copy()\n    )\n\n    next_batch = next_batch.filter([next_batch.requests[0].id])\n\n    for _ in range(\n        stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1\n    ):\n        generations, next_batch, _ = default_causal_lm.generate_token(next_batch)\n        assert len(generations) == len(next_batch)\n\n    generations, next_batch, _ = default_causal_lm.generate_token(next_batch)\n    assert next_batch is None\n\n    assert len(generations) == 1\n    assert generations[0].generated_text.text == \".java:784) at net.minecraft.\"\n    assert (\n        generations[0].request_id\n        == default_multi_requests_causal_lm_batch.requests[0].id\n    )\n    assert (\n        generations[0].generated_text.generated_tokens\n        == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens\n    )\n\n\ndef test_batch_concatenate(\n    default_causal_lm, default_causal_lm_batch, default_multi_requests_causal_lm_batch\n):\n    next_batch_0 = default_causal_lm_batch\n    _, next_batch_0, _ = default_causal_lm.generate_token(next_batch_0)\n    _, next_batch_0, _ = 
default_causal_lm.generate_token(next_batch_0)\n\n    next_batch_1 = default_multi_requests_causal_lm_batch\n    _, next_batch_1, _ = default_causal_lm.generate_token(next_batch_1)\n\n    # Clone past_key_values before concatenating to compare after,\n    # because they are removed from the concatenated batches\n    next_batch_0_past_key_values = [\n        (k.clone(), v.clone()) for (k, v) in next_batch_0.past_key_values\n    ]\n    next_batch_1_past_key_values = [\n        (k.clone(), v.clone()) for (k, v) in next_batch_1.past_key_values\n    ]\n\n    next_batch = CausalLMBatch.concatenate([next_batch_0, next_batch_1])\n\n    assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0])\n    assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0])\n    assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1])\n\n    assert torch.all(\n        next_batch.attention_mask[0, : -next_batch.padding_right_offset] == 1\n    )\n    assert torch.all(\n        next_batch.attention_mask[1:, 1 : -next_batch.padding_right_offset] == 1\n    )\n    assert torch.all(next_batch.attention_mask[1:, 3:] == 0)\n\n    assert next_batch.batch_id == 0\n    assert next_batch.input_ids[0, 0] == 12355\n    assert torch.all(next_batch.input_ids[1:] == 13)\n\n    assert next_batch.input_lengths == [3, 2, 2]\n    assert next_batch.max_input_length == 3\n\n    assert next_batch.requests[0] == next_batch_0.requests[0]\n    assert next_batch.requests[1:] == list(next_batch_1.requests)\n\n    assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]\n    assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers\n\n    assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0]\n    assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias\n\n    assert next_batch.past_key_values is not None\n    assert all([p[0].shape == (3, 12, 2, 64) for p in 
next_batch.past_key_values])\n    assert all([p[1].shape == (3, 12, 2, 64) for p in next_batch.past_key_values])\n\n    for i, past in enumerate(next_batch.past_key_values):\n        assert torch.equal(next_batch_0_past_key_values[i][0][0, :, -2:], past[0][0])\n        assert torch.equal(\n            next_batch_1_past_key_values[i][0][:, :, -1:], past[0][1:, :, -1:, :]\n        )\n\n        assert torch.equal(next_batch_0_past_key_values[i][1][0, :, -2:], past[1][0])\n        assert torch.equal(\n            next_batch_1_past_key_values[i][1][:, :, -1:], past[1][1:, :, -1:, :]\n        )\n\n    for _ in range(\n        default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2\n    ):\n        generations, next_batch, _ = default_causal_lm.generate_token(next_batch)\n        assert len(generations) == len(next_batch)\n\n    generations, next_batch, _ = default_causal_lm.generate_token(next_batch)\n    assert next_batch is not None\n\n    assert len(generations) == 3\n    assert generations[2].generated_text.text == \".java:784)\"\n    assert (\n        generations[2].request_id\n        == default_multi_requests_causal_lm_batch.requests[1].id\n    )\n    assert (\n        generations[2].generated_text.generated_tokens\n        == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens\n    )\n\n    next_batch = next_batch.filter(\n        [next_batch.requests[0].id, next_batch.requests[1].id]\n    )\n\n    for _ in range(\n        default_causal_lm_batch.stopping_criterias[0].max_new_tokens\n        - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens\n        - 2\n    ):\n        generations, next_batch, _ = default_causal_lm.generate_token(next_batch)\n        assert len(generations) == len(next_batch)\n\n    generations, next_batch, _ = default_causal_lm.generate_token(next_batch)\n    assert next_batch is not None\n\n    assert len(generations) == 2\n    assert generations[0].generated_text.text 
== \".java:784) at net.minecraft.\"\n    assert generations[0].request_id == default_causal_lm_batch.requests[0].id\n    assert (\n        generations[0].generated_text.generated_tokens\n        == default_causal_lm_batch.stopping_criterias[0].max_new_tokens\n    )\n\n    next_batch = next_batch.filter([next_batch.requests[1].id])\n\n    for _ in range(\n        default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens\n        - default_causal_lm_batch.stopping_criterias[0].max_new_tokens\n        - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens\n        - 4\n    ):\n        generations, next_batch, _ = default_causal_lm.generate_token(next_batch)\n        assert len(generations) == len(next_batch)\n\n    generations, next_batch, _ = default_causal_lm.generate_token(next_batch)\n    assert next_batch is None\n\n    assert len(generations) == 1\n    assert generations[0].generated_text.text == \".java:784) at net.minecraft.\"\n    assert (\n        generations[0].request_id\n        == default_multi_requests_causal_lm_batch.requests[0].id\n    )\n    assert (\n        generations[0].generated_text.generated_tokens\n        == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens\n    )\n"
  },
  {
    "path": "server/tests/models/test_model.py",
    "content": "import pytest\nimport torch\n\nfrom transformers import AutoTokenizer\n\nfrom text_generation_server.models import Model\n\n\ndef get_test_model():\n    class TestModel(Model):\n        def batch_type(self):\n            raise NotImplementedError\n\n        def generate_token(self, batch):\n            raise NotImplementedError\n\n    tokenizer = AutoTokenizer.from_pretrained(\"huggyllama/llama-7b\")\n\n    model = TestModel(\n        \"test_model_id\",\n        torch.nn.Linear(1, 1),\n        tokenizer,\n        False,\n        torch.float32,\n        torch.device(\"cpu\"),\n    )\n    return model\n\n\n@pytest.mark.private\ndef test_decode_streaming_english_spaces():\n    model = get_test_model()\n    truth = \"Hello here, this is a simple test\"\n    all_input_ids = [15043, 1244, 29892, 445, 338, 263, 2560, 1243]\n    assert (\n        all_input_ids == model.tokenizer(truth, add_special_tokens=False)[\"input_ids\"]\n    )\n\n    decoded_text = \"\"\n    offset = 0\n    token_offset = 0\n    for i in range(len(all_input_ids)):\n        text, offset, token_offset = model.decode_token(\n            all_input_ids[: i + 1], offset, token_offset\n        )\n        decoded_text += text\n\n    assert decoded_text == truth\n\n\n@pytest.mark.private\ndef test_decode_streaming_chinese_utf8():\n    model = get_test_model()\n    truth = \"我很感谢你的热情\"\n    all_input_ids = [\n        30672,\n        232,\n        193,\n        139,\n        233,\n        135,\n        162,\n        235,\n        179,\n        165,\n        30919,\n        30210,\n        234,\n        134,\n        176,\n        30993,\n    ]\n\n    decoded_text = \"\"\n    offset = 0\n    token_offset = 0\n    for i in range(len(all_input_ids)):\n        text, offset, token_offset = model.decode_token(\n            all_input_ids[: i + 1], offset, token_offset\n        )\n        decoded_text += text\n\n    assert decoded_text == truth\n"
  },
  {
    "path": "server/tests/models/test_santacoder.py",
    "content": "import pytest\n\nfrom text_generation_server.pb import generate_pb2\nfrom text_generation_server.models.causal_lm import CausalLMBatch, CausalLM\n\n\n@pytest.fixture(scope=\"session\")\ndef default_santacoder():\n    return CausalLM.fallback(model_id=\"bigcode/santacoder\")\n\n\n@pytest.fixture\ndef default_pb_request(default_pb_parameters, default_pb_stop_parameters):\n    return generate_pb2.Request(\n        id=0,\n        inputs=\"def\",\n        input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text=\"def\")]),\n        prefill_logprobs=True,\n        truncate=100,\n        parameters=default_pb_parameters,\n        stopping_parameters=default_pb_stop_parameters,\n    )\n\n\n@pytest.fixture\ndef default_pb_batch(default_pb_request):\n    return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)\n\n\n@pytest.fixture\ndef default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):\n    return generate_pb2.Request(\n        id=0,\n        inputs=\"<fim-prefix>def<fim-suffix>world<fim-middle>\",\n        input_chunks=generate_pb2.Input(\n            chunks=[\n                generate_pb2.InputChunk(\n                    text=\"<fim-prefix>def<fim-suffix>world<fim-middle>\"\n                )\n            ]\n        ),\n        prefill_logprobs=True,\n        truncate=100,\n        parameters=default_pb_parameters,\n        stopping_parameters=default_pb_stop_parameters,\n    )\n\n\n@pytest.fixture\ndef default_fim_pb_batch(default_fim_pb_request):\n    return generate_pb2.Batch(id=0, requests=[default_fim_pb_request], size=1)\n\n\n@pytest.mark.skip\ndef test_santacoder_generate_token_completion(default_santacoder, default_pb_batch):\n    batch = CausalLMBatch.from_pb(\n        default_pb_batch,\n        default_santacoder.tokenizer,\n        default_santacoder.dtype,\n        default_santacoder.device,\n    )\n    next_batch = batch\n\n    for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):\n    
    generations, next_batch, _ = default_santacoder.generate_token(next_batch)\n        assert len(generations) == len(next_batch)\n\n    generations, next_batch, _ = default_santacoder.generate_token(next_batch)\n    assert next_batch is None\n\n    assert len(generations) == 1\n    assert generations[0].generated_text.text == \" test_get_all_users_with_\"\n    assert generations[0].request_id == batch.requests[0].id\n    assert (\n        generations[0].generated_text.generated_tokens\n        == batch.stopping_criterias[0].max_new_tokens\n    )\n\n\n@pytest.mark.skip\ndef test_fim_santacoder_generate_token_completion(\n    default_santacoder, default_fim_pb_batch\n):\n    batch = CausalLMBatch.from_pb(\n        default_fim_pb_batch,\n        default_santacoder.tokenizer,\n        default_santacoder.dtype,\n        default_santacoder.device,\n    )\n    next_batch = batch\n\n    for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):\n        generations, next_batch, _ = default_santacoder.generate_token(next_batch)\n        assert len(generations) == len(next_batch)\n\n    generations, next_batch, _ = default_santacoder.generate_token(next_batch)\n    assert next_batch is None\n\n    assert len(generations) == 1\n    assert (\n        generations[0].generated_text.text\n        == \"\"\"ineProperty(exports, \"__esModule\", { value\"\"\"\n    )\n    assert generations[0].request_id == batch.requests[0].id\n    assert (\n        generations[0].generated_text.generated_tokens\n        == batch.stopping_criterias[0].max_new_tokens\n    )\n"
  },
  {
    "path": "server/tests/models/test_seq2seq_lm.py",
    "content": "import pytest\nimport torch\n\nfrom copy import copy\n\nfrom transformers import AutoTokenizer\n\nfrom text_generation_server.pb import generate_pb2\nfrom text_generation_server.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch\n\n\n@pytest.fixture(scope=\"session\")\ndef mt0_small_tokenizer():\n    tokenizer = AutoTokenizer.from_pretrained(\n        \"bigscience/mt0-small\", padding_side=\"left\"\n    )\n    tokenizer.bos_token_id = 0\n    return tokenizer\n\n\n@pytest.fixture(scope=\"session\")\ndef default_seq2seq_lm():\n    return Seq2SeqLM.fallback(\"bigscience/mt0-small\")\n\n\n@pytest.fixture\ndef default_pb_request(default_pb_parameters, default_pb_stop_parameters):\n    return generate_pb2.Request(\n        id=0,\n        inputs=\"Test\",\n        input_chunks=generate_pb2.Input(chunks=[generate_pb2.InputChunk(text=\"Test\")]),\n        prefill_logprobs=True,\n        truncate=100,\n        parameters=default_pb_parameters,\n        stopping_parameters=default_pb_stop_parameters,\n    )\n\n\n@pytest.fixture\ndef default_pb_batch(default_pb_request):\n    return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)\n\n\n@pytest.fixture\ndef default_seq2seq_lm_batch(default_pb_batch, mt0_small_tokenizer):\n    return Seq2SeqLMBatch.from_pb(\n        default_pb_batch, mt0_small_tokenizer, torch.float32, torch.device(\"cpu\")\n    )\n\n\n@pytest.fixture\ndef default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokenizer):\n    req_0 = copy(default_pb_request)\n    req_0.id = 1\n    req_1 = default_pb_request\n    req_1.id = 2\n    req_1.stopping_parameters.max_new_tokens = 5\n\n    batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)\n    return Seq2SeqLMBatch.from_pb(\n        batch_pb, mt0_small_tokenizer, torch.float32, torch.device(\"cpu\")\n    )\n\n\ndef test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):\n    batch = default_seq2seq_lm_batch\n    sequence_length = 
len(default_seq2seq_lm_batch.input_ids[0])\n\n    assert batch.batch_id == default_pb_batch.id\n    assert batch.requests == default_pb_batch.requests\n\n    assert batch.input_ids.shape == (default_pb_batch.size, sequence_length)\n    assert batch.input_ids[0][-2] == 4268\n    assert batch.input_ids[0][-1] == 1\n    assert torch.all(batch.input_ids[0][:-2] == 0)\n\n    assert torch.all(batch.attention_mask[0][-2:] == 1)\n    assert torch.all(batch.attention_mask[0][:-2] == 0)\n\n    assert len(batch.decoder_input_ids) == default_pb_batch.size\n    assert batch.decoder_attention_mask is None\n    assert batch.encoder_last_hidden_state is None\n\n    assert batch.past_key_values is None\n\n    assert batch.input_lengths == [2]\n    assert batch.decoder_input_lengths == [1]\n\n    assert len(batch) == default_pb_batch.size\n    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == len(batch)\n\n    assert batch.max_input_length == batch.input_lengths[0]\n    assert batch.max_decoder_input_length == batch.decoder_input_lengths[0]\n\n\ndef test_batch_concatenate_no_prefill(default_seq2seq_lm_batch):\n    with pytest.raises(ValueError):\n        Seq2SeqLMBatch.concatenate([default_seq2seq_lm_batch, default_seq2seq_lm_batch])\n\n\ndef test_seq2seq_lm_batch_type(default_seq2seq_lm):\n    assert default_seq2seq_lm.batch_type == Seq2SeqLMBatch\n\n\ndef test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):\n    sequence_length = len(default_seq2seq_lm_batch.input_ids[0])\n    generations, next_batch, _ = default_seq2seq_lm.generate_token(\n        default_seq2seq_lm_batch\n    )\n\n    assert len(generations) == len(next_batch)\n    assert isinstance(next_batch, Seq2SeqLMBatch)\n\n    assert next_batch.input_ids is None\n    assert torch.equal(\n        next_batch.attention_mask, default_seq2seq_lm_batch.attention_mask\n    )\n    assert next_batch.input_lengths == default_seq2seq_lm_batch.input_lengths\n    assert 
next_batch.max_input_length == default_seq2seq_lm_batch.max_input_length\n    assert (\n        next_batch.next_token_choosers == default_seq2seq_lm_batch.next_token_choosers\n    )\n    assert next_batch.stopping_criterias == default_seq2seq_lm_batch.stopping_criterias\n\n    assert len(next_batch.decoder_input_ids) == len(next_batch)\n    assert next_batch.all_decoder_input_ids[0][0] == 0\n    assert next_batch.all_decoder_input_ids[0][1] == 259\n    assert next_batch.decoder_attention_mask is None\n    assert next_batch.encoder_last_hidden_state.shape == (1, sequence_length, 512)\n\n    assert next_batch.decoder_input_lengths == [2]\n    assert next_batch.max_decoder_input_length == 2\n\n    assert next_batch.past_key_values is not None\n    assert all(\n        [p[0].shape == (len(next_batch), 6, 1, 64) for p in next_batch.past_key_values]\n    )\n    assert all(\n        [p[1].shape == (len(next_batch), 6, 1, 64) for p in next_batch.past_key_values]\n    )\n    assert all(\n        [\n            p[2].shape == (len(next_batch), 6, sequence_length, 64)\n            for p in next_batch.past_key_values\n        ]\n    )\n    assert all(\n        [\n            p[3].shape == (len(next_batch), 6, sequence_length, 64)\n            for p in next_batch.past_key_values\n        ]\n    )\n    assert all([generation.generated_text is None for generation in generations])\n    assert all([len(generation.prefill_tokens) == 1 for generation in generations])\n    assert all(\n        [\n            token_id.item() == 259\n            for generation in generations\n            for token_id in generation.tokens.token_ids\n        ]\n    )\n    assert all(\n        [\n            token_text == \" \"\n            for generation in generations\n            for token_text in generation.tokens.texts\n        ]\n    )\n    assert generations[0].request_id == 0\n\n\ndef test_seq2seq_lm_generate_token_completion(\n    default_seq2seq_lm, default_seq2seq_lm_batch\n):\n    next_batch = 
default_seq2seq_lm_batch\n    for _ in range(6):\n        generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)\n        assert len(generations) == len(next_batch)\n\n    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)\n    assert next_batch is None\n\n    assert len(generations) == 1\n    assert generations[0].generated_text.text == \"a few weeks\"\n    assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id\n    assert generations[0].generated_text.generated_tokens == 7\n\n\ndef test_seq2seq_lm_generate_token_completion_multi(\n    default_seq2seq_lm, default_multi_requests_seq2seq_lm_batch\n):\n    next_batch = default_multi_requests_seq2seq_lm_batch\n\n    for i in range(4):\n        generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)\n        assert len(generations) == len(next_batch)\n\n    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)\n    assert next_batch is not None\n\n    assert len(generations) == 2\n    assert generations[1].generated_text.text == \"a few \"\n    assert (\n        generations[1].request_id\n        == default_multi_requests_seq2seq_lm_batch.requests[1].id\n    )\n    assert generations[1].generated_text.generated_tokens == 5\n\n    next_batch = next_batch.filter([next_batch.requests[0].id])\n\n    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)\n    assert len(generations) == len(next_batch)\n\n    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)\n    assert next_batch is None\n\n    assert len(generations) == 1\n    assert generations[0].generated_text.text == \"a few weeks\"\n    assert (\n        generations[0].request_id\n        == default_multi_requests_seq2seq_lm_batch.requests[0].id\n    )\n    assert generations[0].generated_text.generated_tokens == 7\n\n\ndef test_batch_concatenate(\n    default_seq2seq_lm,\n    default_seq2seq_lm_batch,\n    
default_multi_requests_seq2seq_lm_batch,\n):\n    next_batch_0 = default_seq2seq_lm_batch\n    _, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0)\n    _, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0)\n\n    next_batch_1 = default_multi_requests_seq2seq_lm_batch\n    _, next_batch_1, _ = default_seq2seq_lm.generate_token(next_batch_1)\n\n    # Copy hidden state because it is removed from the concatenated branches\n    next_batch_0_encoder_last_hidden_state = next_batch_0.encoder_last_hidden_state\n    next_batch_1_encoder_last_hidden_state = next_batch_1.encoder_last_hidden_state\n\n    # Clone past_key_values before concatenating to compare after,\n    # because they are removed from the concatenated batches\n    next_batch_0_past_key_values = [\n        [t.clone() for t in layer] for layer in next_batch_0.past_key_values\n    ]\n    next_batch_1_past_key_values = [\n        [t.clone() for t in layer] for layer in next_batch_1.past_key_values\n    ]\n\n    next_batch = Seq2SeqLMBatch.concatenate([next_batch_0, next_batch_1])\n\n    assert next_batch.batch_id == 0\n\n    assert torch.equal(\n        next_batch.decoder_input_ids[0], next_batch_0.decoder_input_ids[0]\n    )\n    assert next_batch.all_decoder_input_ids[1][0] == 0\n    assert next_batch.all_decoder_input_ids[2][0] == 0\n    assert torch.equal(\n        next_batch.decoder_input_ids[1:, -2:], next_batch_1.decoder_input_ids\n    )\n\n    assert torch.all(next_batch.decoder_attention_mask[0, :3] == 1)\n    assert torch.all(next_batch.decoder_attention_mask[0, 3:] == 0)\n    assert torch.all(next_batch.decoder_attention_mask[1:, 0] == 0)\n    assert torch.all(next_batch.decoder_attention_mask[1:, 1:3] == 1)\n\n    assert torch.equal(\n        next_batch.encoder_last_hidden_state[0],\n        next_batch_0_encoder_last_hidden_state[0, -2:],\n    )\n    assert torch.equal(\n        next_batch.encoder_last_hidden_state[1:],\n        next_batch_1_encoder_last_hidden_state[:, 
-2:],\n    )\n\n    assert next_batch.input_lengths == [2, 2, 2]\n    assert next_batch.decoder_input_lengths == [3, 2, 2]\n    assert next_batch.max_input_length == 2\n    assert next_batch.max_decoder_input_length == 3\n\n    assert next_batch.requests[0] == next_batch_0.requests[0]\n    assert next_batch.requests[1:] == list(next_batch_1.requests)\n\n    assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]\n    assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers\n\n    assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0]\n    assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias\n\n    assert next_batch.past_key_values is not None\n    assert all(\n        [p[0].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]\n    )\n    assert all(\n        [p[1].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]\n    )\n    assert all(\n        [p[2].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]\n    )\n    assert all(\n        [p[3].shape == (len(next_batch), 6, 2, 64) for p in next_batch.past_key_values]\n    )\n\n    for i, past in enumerate(next_batch.past_key_values):\n        assert torch.equal(next_batch_0_past_key_values[i][0][0, :, -2:, :], past[0][0])\n        assert torch.equal(\n            next_batch_1_past_key_values[i][0][:, :, -1:, :], past[0][1:, :, -1:, :]\n        )\n\n        assert torch.equal(next_batch_0_past_key_values[i][1][0, :, -2:, :], past[1][0])\n        assert torch.equal(\n            next_batch_1_past_key_values[i][1][:, :, -1:, :], past[1][1:, :, -1:, :]\n        )\n\n        assert torch.equal(next_batch_0_past_key_values[i][2][0, :, -2:, :], past[2][0])\n        assert torch.equal(\n            next_batch_1_past_key_values[i][2][:, :, -2:, :], past[2][1:]\n        )\n\n        assert torch.equal(next_batch_0_past_key_values[i][3][0, :, -2:, :], past[3][0])\n        
assert torch.equal(\n            next_batch_1_past_key_values[i][3][:, :, -2:, :], past[3][1:]\n        )\n\n    for _ in range(3):\n        generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)\n        assert len(generations) == len(next_batch)\n\n    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)\n    assert next_batch is not None\n\n    assert len(generations) == 3\n    assert generations[2].generated_text.text == \"a few \"\n    assert (\n        generations[2].request_id\n        == default_multi_requests_seq2seq_lm_batch.requests[1].id\n    )\n    assert generations[2].generated_text.generated_tokens == 5\n\n    next_batch = next_batch.filter(\n        [next_batch.requests[0].id, next_batch.requests[1].id]\n    )\n\n    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)\n    assert next_batch is not None\n\n    assert len(generations) == 2\n    assert generations[0].generated_text.text == \"a few weeks\"\n    assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id\n    assert generations[0].generated_text.generated_tokens == 7\n\n    next_batch = next_batch.filter([next_batch.requests[1].id])\n\n    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)\n    assert next_batch is None\n\n    assert len(generations) == 1\n    assert generations[0].generated_text.text == \"a few weeks\"\n    assert (\n        generations[0].request_id\n        == default_multi_requests_seq2seq_lm_batch.requests[0].id\n    )\n    assert generations[0].generated_text.generated_tokens == 7\n"
  },
  {
    "path": "server/tests/utils/test_adapter.py",
    "content": "import pytest\nfrom unittest.mock import Mock\nfrom text_generation_server.utils.adapter import (\n    get_attn_weights,\n    get_mlp_weights,\n    parse_lora_adapters,\n    AdapterInfo,\n)\n\n\ndef test_parse_lora_adapters_empty():\n    assert parse_lora_adapters(None) == []\n    assert parse_lora_adapters(\"\") == []\n\n\ndef test_parse_lora_adapters_single():\n    result = parse_lora_adapters(\"adapter1\")\n    assert result == [AdapterInfo(id=\"adapter1\", path=None, revision=None)]\n\n\ndef test_parse_lora_adapters_with_path():\n    result = parse_lora_adapters(\"adapter1=path/to/adapter1\")\n    assert result == [\n        AdapterInfo(id=\"adapter1\", path=\"path/to/adapter1\", revision=None)\n    ]\n\n\ndef test_parse_lora_adapters_with_path_and_revision():\n    result = parse_lora_adapters(\"adapter1=path/to/adapter1@main\")\n    assert result == [\n        AdapterInfo(id=\"adapter1\", path=\"path/to/adapter1\", revision=\"main\")\n    ]\n\n\ndef test_parse_lora_adapters_multiple():\n    result = parse_lora_adapters(\n        \"adapter1,adapter2=path/to/adapter2,adapter3=path/to/adapter3@dev\"\n    )\n    assert result == [\n        AdapterInfo(id=\"adapter1\", path=None, revision=None),\n        AdapterInfo(id=\"adapter2\", path=\"path/to/adapter2\", revision=None),\n        AdapterInfo(id=\"adapter3\", path=\"path/to/adapter3\", revision=\"dev\"),\n    ]\n\n\ndef test_parse_lora_adapters_invalid_format():\n    try:\n        parse_lora_adapters(\"adapter1,invalid=format=test,adapter3\")\n        assert False, \"Should have raised ValueError\"\n    except ValueError as e:\n        assert str(e) == \"Invalid LoRA adapter format: invalid=format=test\"\n\n\ndef test_get_attn_weights():\n    # create a mock layer\n    mock_layer = Mock()\n    mock_layer.self_attn.query_key_value = Mock()\n    mock_layer.self_attn.o_proj = Mock()\n\n    # call the function\n    result = get_attn_weights(2, mock_layer)\n\n    # assert the result\n    expected = 
{\n        (2, \"q_proj\"): (\n            \"model.layers.2.self_attn.q_proj\",\n            mock_layer.self_attn.query_key_value,\n        ),\n        (2, \"k_proj\"): (\n            \"model.layers.2.self_attn.k_proj\",\n            mock_layer.self_attn.query_key_value,\n        ),\n        (2, \"qkv_proj\"): (\n            \"model.layers.2.self_attn.qkv_proj\",\n            mock_layer.self_attn.query_key_value,\n        ),\n        (2, \"v_proj\"): (\n            \"model.layers.2.self_attn.v_proj\",\n            mock_layer.self_attn.query_key_value,\n        ),\n        (2, \"o_proj\"): (\"model.layers.2.self_attn.o_proj\", mock_layer.self_attn.o_proj),\n    }\n    assert result == expected\n\n\ndef test_get_mlp_weights_with_gate_up_proj():\n    # create a mock layer with gate_up_proj\n    mock_layer = Mock()\n    mock_layer.mlp.gate_up_proj = Mock()\n    mock_layer.mlp.down_proj = Mock()\n\n    # call the function\n    result = get_mlp_weights(3, mock_layer)\n\n    # assert the result\n    expected = {\n        (3, \"c_fc\"): (\"model.layers.3.mlp.c_fc\", mock_layer.mlp.c_fc),\n        (3, \"c_proj\"): (\"model.layers.3.mlp.c_proj\", mock_layer.mlp.c_proj),\n        (3, \"gate_proj\"): (\"model.layers.3.mlp.gate_proj\", mock_layer.mlp.gate_up_proj),\n        (3, \"up_proj\"): (\"model.layers.3.mlp.up_proj\", mock_layer.mlp.gate_up_proj),\n        (3, \"down_proj\"): (\"model.layers.3.mlp.down_proj\", mock_layer.mlp.down_proj),\n    }\n    assert result == expected\n\n\ndef test_get_mlp_weights_without_gate_up_proj():\n    # create a mock layer without gate_up_proj\n    mock_layer = Mock()\n    mock_layer.mlp = Mock(spec=[])\n\n    # call the function\n    result = get_mlp_weights(1, mock_layer)\n\n    # assert the result\n    assert result == {}\n\n\n@pytest.mark.parametrize(\"layer_index\", [0, 1, 5])\ndef test_get_attn_weights_different_layers(layer_index):\n    mock_layer = Mock()\n    mock_layer.self_attn.query_key_value = Mock()\n    
mock_layer.self_attn.o_proj = Mock()\n\n    result = get_attn_weights(layer_index, mock_layer)\n\n    for k in [\"q\", \"k\", \"v\"]:\n        assert (layer_index, f\"{k}_proj\") in result\n        assert (\n            result[(layer_index, f\"{k}_proj\")][0]\n            == f\"model.layers.{layer_index}.self_attn.{k}_proj\"\n        )\n\n    assert (layer_index, \"o_proj\") in result\n    assert (\n        result[(layer_index, \"o_proj\")][0]\n        == f\"model.layers.{layer_index}.self_attn.o_proj\"\n    )\n\n\n@pytest.mark.parametrize(\"layer_index\", [0, 1, 5])\ndef test_get_mlp_weights_different_layers(layer_index):\n    mock_layer = Mock()\n    mock_layer.mlp.gate_up_proj = Mock()\n    mock_layer.mlp.down_proj = Mock()\n\n    result = get_mlp_weights(layer_index, mock_layer)\n\n    for k in [\"gate\", \"up\", \"down\"]:\n        assert (layer_index, f\"{k}_proj\") in result\n        assert (\n            result[(layer_index, f\"{k}_proj\")][0]\n            == f\"model.layers.{layer_index}.mlp.{k}_proj\"\n        )\n\n\ndef test_get_attn_weights_llama_compatibility():\n    mock_layer = Mock()\n    mock_layer.self_attn.query_key_value = Mock()\n    mock_layer.self_attn.o_proj = Mock()\n\n    result = get_attn_weights(2, mock_layer)\n\n    expected = {\n        (2, \"q_proj\"): (\n            \"model.layers.2.self_attn.q_proj\",\n            mock_layer.self_attn.query_key_value,\n        ),\n        (2, \"k_proj\"): (\n            \"model.layers.2.self_attn.k_proj\",\n            mock_layer.self_attn.query_key_value,\n        ),\n        (2, \"qkv_proj\"): (\n            \"model.layers.2.self_attn.qkv_proj\",\n            mock_layer.self_attn.query_key_value,\n        ),\n        (2, \"v_proj\"): (\n            \"model.layers.2.self_attn.v_proj\",\n            mock_layer.self_attn.query_key_value,\n        ),\n        (2, \"o_proj\"): (\"model.layers.2.self_attn.o_proj\", mock_layer.self_attn.o_proj),\n    }\n    assert result == expected\n\n\ndef 
test_get_mlp_weights_llama_compatibility():\n    mock_layer = Mock()\n    mock_layer.mlp.gate_up_proj = Mock()\n    mock_layer.mlp.down_proj = Mock()\n\n    result = get_mlp_weights(3, mock_layer)\n\n    expected = {\n        (3, \"c_fc\"): (\"model.layers.3.mlp.c_fc\", mock_layer.mlp.c_fc),\n        (3, \"c_proj\"): (\"model.layers.3.mlp.c_proj\", mock_layer.mlp.c_proj),\n        (3, \"gate_proj\"): (\"model.layers.3.mlp.gate_proj\", mock_layer.mlp.gate_up_proj),\n        (3, \"up_proj\"): (\"model.layers.3.mlp.up_proj\", mock_layer.mlp.gate_up_proj),\n        (3, \"down_proj\"): (\"model.layers.3.mlp.down_proj\", mock_layer.mlp.down_proj),\n    }\n    assert result == expected\n\n\ndef test_get_attn_weights_gemma_compatibility():\n    mock_layer = Mock()\n    mock_layer.self_attn.query_key_value = Mock()\n    mock_layer.self_attn.o_proj = Mock()\n\n    result = get_attn_weights(2, mock_layer)\n\n    expected = {\n        (2, \"q_proj\"): (\n            \"model.layers.2.self_attn.q_proj\",\n            mock_layer.self_attn.query_key_value,\n        ),\n        (2, \"k_proj\"): (\n            \"model.layers.2.self_attn.k_proj\",\n            mock_layer.self_attn.query_key_value,\n        ),\n        (2, \"qkv_proj\"): (\n            \"model.layers.2.self_attn.qkv_proj\",\n            mock_layer.self_attn.query_key_value,\n        ),\n        (2, \"v_proj\"): (\n            \"model.layers.2.self_attn.v_proj\",\n            mock_layer.self_attn.query_key_value,\n        ),\n        (2, \"o_proj\"): (\"model.layers.2.self_attn.o_proj\", mock_layer.self_attn.o_proj),\n    }\n    assert result == expected\n\n\ndef test_get_mlp_weights_gemma_compatibility():\n    mock_layer = Mock()\n    mock_layer.mlp.gate_proj = Mock()\n    mock_layer.mlp.up_proj = Mock()\n    mock_layer.mlp.down_proj = Mock()\n\n    # ensure that the mock_layer.mlp.gate_up_proj attribute does not exist.\n    # This is necessary because the use of `Mock` automatically creates any\n    # attributes that 
are accessed, even if they don't exist in the actual\n    # implementation. If `gate_up_proj` were created, `get_mlp_weights` might\n    # follow the wrong execution path and return an incorrect result.\n    del mock_layer.mlp.gate_up_proj\n\n    result = get_mlp_weights(3, mock_layer)\n\n    expected = {\n        (3, \"c_fc\"): (\"model.layers.3.mlp.c_fc\", mock_layer.mlp.c_fc),\n        (3, \"c_proj\"): (\"model.layers.3.mlp.c_proj\", mock_layer.mlp.c_proj),\n        (3, \"gate_proj\"): (\"model.layers.3.mlp.gate_proj\", mock_layer.mlp.gate_proj),\n        (3, \"up_proj\"): (\"model.layers.3.mlp.up_proj\", mock_layer.mlp.up_proj),\n        (3, \"down_proj\"): (\"model.layers.3.mlp.down_proj\", mock_layer.mlp.down_proj),\n    }\n    assert result == expected\n"
  },
  {
    "path": "server/tests/utils/test_convert.py",
    "content": "from text_generation_server.utils.hub import (\n    download_weights,\n    weight_hub_files,\n    weight_files,\n)\n\nfrom text_generation_server.utils.convert import convert_files\n\n\ndef test_convert_files():\n    model_id = \"bigscience/bloom-560m\"\n    pt_filenames = weight_hub_files(model_id, extension=\".bin\")\n    local_pt_files = download_weights(pt_filenames, model_id)\n    local_st_files = [\n        p.parent / f\"{p.stem.lstrip('pytorch_')}.safetensors\" for p in local_pt_files\n    ]\n    convert_files(local_pt_files, local_st_files, discard_names=[])\n\n    found_st_files = weight_files(model_id)\n\n    assert all([p in found_st_files for p in local_st_files])\n"
  },
  {
    "path": "server/tests/utils/test_hub.py",
    "content": "import os\nimport tempfile\n\nimport pytest\n\nimport huggingface_hub.constants\n\nimport text_generation_server.utils.hub\nfrom text_generation_server.utils.hub import (\n    weight_hub_files,\n    download_weights,\n    weight_files,\n    EntryNotFoundError,\n    LocalEntryNotFoundError,\n    RevisionNotFoundError,\n)\n\n\n@pytest.fixture()\ndef offline():\n    current_value = text_generation_server.utils.hub.HF_HUB_OFFLINE\n    text_generation_server.utils.hub.HF_HUB_OFFLINE = True\n    yield \"offline\"\n    text_generation_server.utils.hub.HF_HUB_OFFLINE = current_value\n\n\n@pytest.fixture()\ndef fresh_cache():\n    with tempfile.TemporaryDirectory() as d:\n        current_value = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE\n        huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = d\n        text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = d\n        os.environ[\"HUGGINGFACE_HUB_CACHE\"] = d\n        yield\n        huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = current_value\n        os.environ[\"HUGGINGFACE_HUB_CACHE\"] = current_value\n        text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = current_value\n\n\n@pytest.fixture()\ndef prefetched():\n    model_id = \"bert-base-uncased\"\n    huggingface_hub.snapshot_download(\n        repo_id=model_id,\n        revision=\"main\",\n        local_files_only=False,\n        repo_type=\"model\",\n        allow_patterns=[\"*.safetensors\"],\n    )\n    yield model_id\n\n\ndef test_weight_hub_files_offline_error(offline, fresh_cache):\n    # If the model is not prefetched then it will raise an error\n    with pytest.raises(EntryNotFoundError):\n        weight_hub_files(\"gpt2\")\n\n\ndef test_weight_hub_files_offline_ok(prefetched, offline):\n    # If the model is prefetched then we should be able to get the weight files from local cache\n    filenames = weight_hub_files(prefetched)\n    root = None\n    assert len(filenames) == 1\n    for f in filenames:\n        curroot, 
filename = os.path.split(f)\n        if root is None:\n            root = curroot\n        else:\n            assert root == curroot\n        assert filename == \"model.safetensors\"\n\n\ndef test_weight_hub_files():\n    filenames = weight_hub_files(\"bigscience/bloom-560m\")\n    assert filenames == [\"model.safetensors\"]\n\n\ndef test_weight_hub_files_llm():\n    filenames = weight_hub_files(\"bigscience/bloom\")\n    assert filenames == [f\"model_{i:05d}-of-00072.safetensors\" for i in range(1, 73)]\n\n\ndef test_weight_hub_files_empty():\n    with pytest.raises(EntryNotFoundError):\n        weight_hub_files(\"bigscience/bloom\", extension=\".errors\")\n\n\ndef test_download_weights():\n    model_id = \"bigscience/bloom-560m\"\n    filenames = weight_hub_files(model_id)\n    files = download_weights(filenames, model_id)\n    local_files = weight_files(\"bigscience/bloom-560m\")\n    assert files == local_files\n\n\ndef test_weight_files_revision_error():\n    with pytest.raises(RevisionNotFoundError):\n        weight_files(\"bigscience/bloom-560m\", revision=\"error\")\n\n\ndef test_weight_files_not_cached_error(fresh_cache):\n    with pytest.raises(LocalEntryNotFoundError):\n        weight_files(\"bert-base-uncased\")\n"
  },
  {
    "path": "server/tests/utils/test_layers.py",
    "content": "import torch\nfrom text_generation_server.layers import (\n    TensorParallelEmbedding,\n)\n\n\nclass ProcessGroup:\n    def __init__(self, rank: int, world_size: int):\n        self._rank = rank\n        self.world_size = world_size\n\n    def size(self) -> int:\n        return self.world_size\n\n    def rank(self) -> int:\n        return self._rank\n\n\nclass Weights:\n    def __init__(self, rank: int, world_size: int, vocab_size: int, hidden_dim: int):\n        self.weight = (\n            torch.arange(vocab_size * hidden_dim).float().view(vocab_size, hidden_dim)\n        )\n        self.process_group = ProcessGroup(rank, world_size)\n\n    def get_partial_sharded(self, name: str, dim: int):\n        assert dim == 0\n\n        rank = self.process_group.rank()\n        world_size = self.process_group.size()\n        size = self.weight.shape[dim]\n\n        block_size = (size + world_size - 1) // world_size\n        start = rank * block_size\n        stop = (rank + 1) * block_size\n        return self.weight[start:stop]\n\n    def get_shape(self, name: str):\n        return self.weight.shape\n\n\ndef test_weight_hub_files_offline_error():\n\n    vocab_size = 17\n    weights = Weights(\n        rank=0,\n        world_size=1,\n        vocab_size=vocab_size,\n        hidden_dim=256,\n    )\n    embeddings = TensorParallelEmbedding(\"\", weights)\n\n    input_ids = torch.arange(vocab_size)\n    output = embeddings.forward(input_ids)\n    assert embeddings.min_id == 0\n    assert embeddings.max_id == 17\n    torch.testing.assert_close(output, torch.arange(256 * 17).float().view(17, 256))\n\n    weights_0_2 = Weights(rank=0, world_size=2, vocab_size=vocab_size, hidden_dim=256)\n    weights_1_2 = Weights(rank=1, world_size=2, vocab_size=vocab_size, hidden_dim=256)\n    embeddings_0_2 = TensorParallelEmbedding(\"\", weights_0_2, reduce=False)\n    assert embeddings_0_2.min_id == 0\n    assert embeddings_0_2.max_id == 9\n    torch.testing.assert_close(\n    
    embeddings_0_2.weight,\n        torch.cat([torch.arange(9 * 256), torch.zeros(256)], dim=0)\n        .view(10, 256)\n        .float(),\n    )\n    embeddings_1_2 = TensorParallelEmbedding(\"\", weights_1_2, reduce=False)\n    assert embeddings_1_2.min_id == 9\n    assert embeddings_1_2.max_id == 17\n    torch.testing.assert_close(\n        embeddings_1_2.weight,\n        torch.cat([torch.arange(8 * 256) + 9 * 256, torch.zeros(256)], dim=0)\n        .view(9, 256)\n        .float(),\n    )\n    output_tp_0 = embeddings_0_2.forward(input_ids)\n    output_tp_1 = embeddings_1_2.forward(input_ids)\n\n    torch.testing.assert_close(output, output_tp_0 + output_tp_1)\n"
  },
  {
    "path": "server/tests/utils/test_tokens.py",
    "content": "import torch\nfrom text_generation_server.utils.tokens import (\n    StopSequenceCriteria,\n    StoppingCriteria,\n    FinishReason,\n    batch_top_tokens,\n)\n\n\ndef test_stop_sequence_criteria():\n    criteria = StopSequenceCriteria(\"/test;\")\n\n    assert not criteria(\"/\")\n    assert not criteria(\"/test\")\n    assert criteria(\"/test;\")\n    assert not criteria(\"/test; \")\n\n\ndef test_stop_sequence_criteria_escape():\n    criteria = StopSequenceCriteria(\"<|stop|>\")\n\n    assert not criteria(\"<\")\n    assert not criteria(\"<|stop\")\n    assert criteria(\"<|stop|>\")\n    assert not criteria(\"<|stop|> \")\n\n\ndef test_stopping_criteria():\n    criteria = StoppingCriteria(0, [StopSequenceCriteria(\"/test;\")], max_new_tokens=5)\n    assert criteria(65827, \"/test\") == (False, None)\n    assert criteria(30, \";\") == (True, FinishReason.FINISH_REASON_STOP_SEQUENCE)\n\n\ndef test_stopping_criteria_eos():\n    criteria = StoppingCriteria(0, [StopSequenceCriteria(\"/test;\")], max_new_tokens=5)\n    assert criteria(1, \"\") == (False, None)\n    assert criteria(0, \"\") == (True, FinishReason.FINISH_REASON_EOS_TOKEN)\n\n\ndef test_stopping_criteria_max():\n    criteria = StoppingCriteria(0, [StopSequenceCriteria(\"/test;\")], max_new_tokens=5)\n    assert criteria(1, \"\") == (False, None)\n    assert criteria(1, \"\") == (False, None)\n    assert criteria(1, \"\") == (False, None)\n    assert criteria(1, \"\") == (False, None)\n    assert criteria(1, \"\") == (True, FinishReason.FINISH_REASON_LENGTH)\n\n\ndef test_batch_top_tokens():\n    top_n_tokens = [0, 2, 3, 4, 5]\n    top_n_tokens_tensor = torch.tensor(top_n_tokens)\n    inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5)\n    accepted_ids = torch.ones_like(top_n_tokens_tensor)\n\n    topn_tok_ids, topn_tok_logprobs = batch_top_tokens(\n        top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids\n    )\n\n    assert topn_tok_ids[0] == [[]]\n    assert 
topn_tok_ids[1] == [[0, 3]]\n    assert topn_tok_ids[2] == [[0, 3, 1, 4]]\n    assert topn_tok_ids[3] == [[0, 3, 1, 4]]\n    assert topn_tok_ids[4] == [[0, 3, 1, 4, 2]]\n\n    assert topn_tok_logprobs[0] == [[]]\n    assert topn_tok_logprobs[1] == [[-1, -2]]\n    assert topn_tok_logprobs[2] == [[-1, -2, -3, -3]]\n    assert topn_tok_logprobs[3] == [[-1, -2, -3, -3]]\n    assert topn_tok_logprobs[4] == [[-1, -2, -3, -3, -4]]\n\n    # Now let's make second member of the batch be speculated\n    inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5 * 2)\n    accepted_ids[1] = 2\n    topn_tok_ids, topn_tok_logprobs = batch_top_tokens(\n        top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids\n    )\n\n    assert topn_tok_ids[0] == [[]]\n    assert topn_tok_ids[1] == [[0, 3], [0, 3]]\n    assert topn_tok_ids[2] == [[0, 3, 1, 4]]\n    assert topn_tok_ids[3] == [[0, 3, 1, 4]]\n    assert topn_tok_ids[4] == [[0, 3, 1, 4, 2]]\n\n    assert topn_tok_logprobs[0] == [[]]\n    assert topn_tok_logprobs[1] == [[-1, -2], [-1, -2]]\n    assert topn_tok_logprobs[2] == [[-1, -2, -3, -3]]\n    assert topn_tok_logprobs[3] == [[-1, -2, -3, -3]]\n    assert topn_tok_logprobs[4] == [[-1, -2, -3, -3, -4]]\n"
  },
  {
    "path": "server/tests/utils/test_watermark.py",
    "content": "# test_watermark_logits_processor.py\n\nimport os\nimport numpy as np\nimport torch\nfrom text_generation_server.utils.watermark import WatermarkLogitsProcessor\n\n\nGAMMA = os.getenv(\"WATERMARK_GAMMA\", 0.5)\nDELTA = os.getenv(\"WATERMARK_DELTA\", 2.0)\n\n\ndef test_seed_rng():\n    input_ids = [101, 2036, 3731, 102, 2003, 103]\n    processor = WatermarkLogitsProcessor()\n    processor._seed_rng(input_ids)\n    assert isinstance(processor.rng, torch.Generator)\n\n\ndef test_get_greenlist_ids():\n    input_ids = [101, 2036, 3731, 102, 2003, 103]\n    processor = WatermarkLogitsProcessor()\n    result = processor._get_greenlist_ids(input_ids, 10, torch.device(\"cpu\"))\n    assert max(result) <= 10\n    assert len(result) == int(10 * 0.5)\n\n\ndef test_calc_greenlist_mask():\n    processor = WatermarkLogitsProcessor()\n    scores = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]])\n    greenlist_token_ids = torch.tensor([2, 3])\n    result = processor._calc_greenlist_mask(scores, greenlist_token_ids)\n    assert result.tolist() == [[False, False, False, False], [False, False, True, True]]\n    assert result.shape == scores.shape\n\n\ndef test_bias_greenlist_logits():\n    processor = WatermarkLogitsProcessor()\n    scores = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]])\n    green_tokens_mask = torch.tensor(\n        [[False, False, True, True], [False, False, False, True]]\n    )\n    greenlist_bias = 2.0\n    result = processor._bias_greenlist_logits(scores, green_tokens_mask, greenlist_bias)\n    assert np.allclose(result.tolist(), [[0.5, 0.3, 2.2, 2.8], [0.1, 0.2, 0.7, 2.9]])\n    assert result.shape == scores.shape\n\n\ndef test_call():\n    input_ids = [101, 2036, 3731, 102, 2003, 103]\n    processor = WatermarkLogitsProcessor()\n    scores = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]])\n    result = processor(input_ids, scores)\n    assert result.shape == scores.shape\n"
  },
  {
    "path": "server/tests/utils/test_weights.py",
    "content": "import pytest\nimport torch\nfrom text_generation_server.utils.weights import (\n    DefaultWeightsLoader,\n    Weights,\n    WeightsLoader,\n)\nfrom text_generation_server.layers.gptq import GPTQWeight, GPTQWeightsLoader\nfrom text_generation_server.layers.exl2 import Exl2Weight, Exl2WeightsLoader\nfrom text_generation_server.layers.marlin.marlin import (\n    MarlinWeight,\n    MarlinWeightsLoader,\n)\nfrom types import SimpleNamespace\nfrom typing import List, Optional, Dict, Union\nfrom pathlib import Path\n\n\n@pytest.fixture\ndef gptq_weights_loader():\n    return GPTQWeightsLoader(\n        bits=4,\n        groupsize=-1,\n        desc_act=False,\n        quant_method=\"gptq\",\n        quantize=\"gptq\",\n        sym=True,\n        modules_to_not_convert=[],\n    )\n\n\n@pytest.fixture\ndef gptq_weights_loader_awq():\n    return GPTQWeightsLoader(\n        bits=4,\n        groupsize=-1,\n        desc_act=False,\n        quant_method=\"awq\",\n        quantize=\"awq\",\n        sym=True,\n        modules_to_not_convert=[],\n    )\n\n\n@pytest.fixture\ndef marlin_weights_loader():\n    return MarlinWeightsLoader(bits=4, is_marlin_24=False)\n\n\ndummy_file_system = {\n    \"test_weights\": {\n        \"layer.0.weight\": torch.tensor(\n            [\n                [1, 2],\n                [3, 4],\n            ],\n            dtype=torch.float32,\n        ),\n    },\n    \"test_weights_2\": {\n        \"layer.1337.weight\": torch.tensor(\n            [\n                [1, 2, 3, 4],\n                [5, 6, 7, 8],\n            ],\n            dtype=torch.float32,\n        ),\n    },\n    \"test_get_weights_col_packed\": {\n        \"weight.weight\": torch.tensor(\n            [\n                [1, 2],\n                [3, 4],\n                [5, 6],\n                [7, 8],\n            ],\n            dtype=torch.float32,\n        ),\n    },\n    \"test_get_multi_weights_col\": {\n        \"weight.weight\": torch.tensor(\n            [\n       
         [1, 2],\n                [3, 4],\n                [5, 6],\n                [7, 8],\n            ],\n            dtype=torch.float32,\n        ),\n    },\n    \"test_get_weights_row\": {\n        \"weight.weight\": torch.tensor(\n            [\n                [1, 2],\n                [3, 4],\n                [5, 6],\n                [7, 8],\n            ],\n            dtype=torch.float32,\n        ),\n    },\n    \"test_get_weights_col_gptq\": {\n        \"weight.qweight\": torch.tensor(\n            [\n                [1, 2],\n                [3, 4],\n                [5, 6],\n                [7, 8],\n            ],\n            dtype=torch.float32,\n        ),\n        \"weight.g_idx\": torch.tensor([0, 1, 0, 1], dtype=torch.int32),\n        \"weight.qzeros\": torch.tensor(\n            [\n                [0, 1],\n                [1, 0],\n            ],\n            dtype=torch.int32,\n        ),\n        \"weight.scales\": torch.tensor(\n            [\n                [100.0, 100.0],\n                [100.0, 100.0],\n            ],\n            dtype=torch.float16,\n        ),\n        \"gptq_bits\": torch.tensor([8], dtype=torch.float32),\n        \"gptq_groupsize\": torch.tensor([2], dtype=torch.float32),\n    },\n    \"test_get_weights_col_marlin\": {\n        \"weight.B\": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),\n        \"weight.s\": torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),\n    },\n    \"test_get_weights_row_gptq\": {\n        \"weight.qweight\": torch.tensor(\n            [\n                [1, 2],\n                [3, 4],\n                [5, 6],\n                [7, 8],\n            ],\n            dtype=torch.int32,\n        ),\n        \"weight.g_idx\": torch.tensor([0, 1, 0, 1], dtype=torch.int32),\n        \"weight.qzeros\": torch.tensor(\n            [\n                [0, 1],\n                [1, 0],\n            ],\n            dtype=torch.int32,\n        ),\n        \"weight.scales\": torch.tensor(\n           
 [\n                [100.0, 100.0],\n                [100.0, 100.0],\n            ],\n            dtype=torch.float16,\n        ),\n        \"gptq_bits\": torch.tensor([8], dtype=torch.float32),\n        \"gptq_groupsize\": torch.tensor([2], dtype=torch.float32),\n    },\n    \"test_get_multi_weights_col_gptq\": {\n        \"weight.qweight\": torch.tensor(\n            [\n                [1, 2],\n                [3, 4],\n                [5, 6],\n                [7, 8],\n            ],\n            dtype=torch.int32,\n        ),\n        \"weight.g_idx\": torch.tensor([0, 1, 0, 1], dtype=torch.int32),\n        \"weight.qzeros\": torch.tensor(\n            [\n                [0, 1],\n                [1, 0],\n            ],\n            dtype=torch.int32,\n        ),\n        \"weight.scales\": torch.tensor(\n            [\n                [100.0, 100.0],\n                [100.0, 100.0],\n            ],\n            dtype=torch.float16,\n        ),\n        \"gptq_bits\": torch.tensor([8], dtype=torch.float32),\n        \"gptq_groupsize\": torch.tensor([2], dtype=torch.float32),\n    },\n    \"test_get_weights_col_packed_gptq\": {\n        \"weight.qweight\": torch.tensor(\n            [\n                [1, 2],\n                [3, 4],\n                [5, 6],\n                [7, 8],\n            ],\n            dtype=torch.int32,\n        ),\n        \"weight.g_idx\": torch.tensor([0, 1, 0, 1], dtype=torch.int32),\n        \"weight.qzeros\": torch.tensor(\n            [\n                [0, 1],\n                [1, 0],\n            ],\n            dtype=torch.int32,\n        ),\n        \"weight.scales\": torch.tensor(\n            [\n                [100.0, 100.0],\n                [100.0, 100.0],\n            ],\n            dtype=torch.float16,\n        ),\n        \"gptq_bits\": torch.tensor([8], dtype=torch.float32),\n        \"gptq_groupsize\": torch.tensor([2], dtype=torch.float32),\n    },\n    \"test_get_weights_col_packed_exl2\": {\n        
\"weight.q_weight\": torch.tensor(\n            [\n                [1, 2],\n                [3, 4],\n                [5, 6],\n                [7, 8],\n            ],\n            dtype=torch.int32,\n        ),\n        \"weight.q_scale\": torch.tensor([8], dtype=torch.int32),\n        \"weight.q_invperm\": torch.tensor([1, 0, 3, 2], dtype=torch.int32),\n        \"weight.q_scale_max\": torch.tensor([100], dtype=torch.float16),\n        \"weight.q_groups\": torch.tensor([4], dtype=torch.int16),\n    },\n    \"test_get_weights_row_exl2\": {\n        \"weight.q_weight\": torch.tensor(\n            [\n                [1, 2],\n                [3, 4],\n                [5, 6],\n                [7, 8],\n            ],\n            dtype=torch.int32,\n        ),\n        \"weight.q_scale\": torch.tensor([8], dtype=torch.int32),\n        \"weight.q_invperm\": torch.tensor([1, 0, 3, 2], dtype=torch.int32),\n        \"weight.q_scale_max\": torch.tensor([100], dtype=torch.float16),\n        \"weight.q_groups\": torch.tensor([4], dtype=torch.int16),\n    },\n    \"test_get_multi_weights_col_exl2\": {\n        \"weight.q_weight\": torch.tensor(\n            [\n                [1, 2],\n                [3, 4],\n                [5, 6],\n                [7, 8],\n            ],\n            dtype=torch.int32,\n        ),\n        \"weight.q_scale\": torch.tensor([8], dtype=torch.int32),\n        \"weight.q_invperm\": torch.tensor([1, 0, 3, 2], dtype=torch.int32),\n        \"weight.q_scale_max\": torch.tensor([100], dtype=torch.float16),\n        \"weight.q_groups\": torch.tensor([4], dtype=torch.int16),\n    },\n    \"test_get_weights_col_exl2\": {\n        \"weight.q_weight\": torch.tensor(\n            [\n                [1, 2],\n                [3, 4],\n                [5, 6],\n                [7, 8],\n            ],\n            dtype=torch.int32,\n        ),\n        \"weight.q_scale\": torch.tensor([8], dtype=torch.int32),\n        \"weight.q_invperm\": torch.tensor([1, 0, 3, 2], 
dtype=torch.int32),\n        \"weight.q_scale_max\": torch.tensor([100], dtype=torch.float16),\n        \"weight.q_groups\": torch.tensor([4], dtype=torch.int16),\n    },\n    \"test_get_weights_row_marlin\": {\n        \"weight.B\": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),\n        \"weight.s\": torch.tensor([[0.5], [0.25]], dtype=torch.float16),\n    },\n    \"test_get_multi_weights_col_marlin\": {\n        \"weight.B\": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),\n        \"weight.s\": torch.tensor([[0.5], [0.25]], dtype=torch.float16),\n    },\n    \"test_get_weights_col_packed_marlin\": {\n        \"weight.B\": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),\n        \"weight.s\": torch.tensor([[0.5], [0.25]], dtype=torch.float16),\n    },\n}\n\n\nclass MockSlice:\n    def __init__(self, tensor):\n        self.tensor = tensor\n\n    def get_shape(self):\n        return self.tensor.shape\n\n    def __getitem__(self, idx):\n        return self.tensor[idx]\n\n\ndef mock_get_slice(tensor_name, filename):\n    tensor = dummy_file_system[filename][tensor_name]\n    return MockSlice(tensor)\n\n\ndef mock_handle(filename, device, dtype):\n    return SimpleNamespace(\n        get_slice=lambda tensor_name: mock_get_slice(tensor_name, filename)\n    )\n\n\nclass MockSafeOpen:\n    def __init__(self, filename, framework, dummy_fs):\n        self.filename = filename\n        self.framework = framework\n        self.dummy_fs = dummy_fs\n\n    def keys(self):\n        return list(self.dummy_fs[self.filename].keys())\n\n    def __enter__(self):\n        return self\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        pass\n\n\nclass MockWeights(Weights):\n    def __init__(\n        self,\n        filenames: List[Union[Path, str]],\n        device,\n        dtype,\n        process_group,\n        dummy_fs,\n        aliases: Optional[Dict[str, List[str]]] = None,\n        prefix: Optional[str] = None,\n        weights_loader: 
Optional[WeightsLoader] = None,\n    ):\n        routing = {}\n        self.dummy_fs = dummy_fs\n        for filename in filenames:\n            with MockSafeOpen(filename, framework=\"pytorch\", dummy_fs=dummy_fs) as f:\n                for k in f.keys():\n                    if k in routing:\n                        raise RuntimeError(\n                            f\"Key {k} was found in multiple files: {filename} and {routing[k]}\"\n                        )\n                    routing[k] = filename\n        if aliases is None:\n            aliases = {}\n        self.aliases = aliases\n        self.routing = routing\n        self.device = device\n        self.dtype = dtype\n        self.process_group = process_group\n        self.prefix = prefix\n        self.weights_loader = (\n            # We don't need to get linear layers, so just wrap raw tensors.\n            DefaultWeightsLoader(lambda x: x)\n            if weights_loader is None\n            else weights_loader\n        )\n        self._handles = {}\n\n    def _get_handle(self, filename: Union[Path, str]):\n        if filename in self._handles:\n            return self._handles[filename]\n        else:\n            handle = mock_handle(filename, self.device, self.dtype)\n            self._handles[filename] = handle\n            return handle\n\n    def get_shape(self, tensor_name: str):\n        filename, _ = self.get_filename(tensor_name)\n        handle = self._get_handle(filename)\n        return handle.get_slice(tensor_name).get_shape()\n\n    def get_tensor(self, tensor_name: str):\n        filename, _ = self.get_filename(tensor_name)\n        handle = self._get_handle(filename)\n        return handle.get_slice(tensor_name).tensor\n\n\ndummy_process_group = SimpleNamespace(rank=lambda: 0, size=lambda: 1)\n\n\ndef test_weights():\n    weights = MockWeights(\n        [\n            \"test_weights\",\n            \"test_weights_2\",\n        ],\n        device=\"cpu\",\n        dtype=torch.float32,\n 
       process_group=dummy_process_group,\n        dummy_fs=dummy_file_system,\n    )\n    assert weights.get_shape(\"layer.0.weight\") == (2, 2)\n    assert weights.get_tensor(\"layer.1337.weight\").shape == (2, 4)\n\n\ndef test_get_tensor():\n    weights = MockWeights(\n        [\n            \"test_weights\",\n            \"test_weights_2\",\n        ],\n        device=\"cpu\",\n        dtype=torch.float32,\n        process_group=dummy_process_group,\n        dummy_fs=dummy_file_system,\n    )\n    assert torch.allclose(\n        weights.get_tensor(\"layer.0.weight\"),\n        torch.tensor(\n            [\n                [1, 2],\n                [3, 4],\n            ],\n            dtype=torch.float32,\n        ),\n    )\n    assert torch.allclose(\n        weights.get_tensor(\"layer.1337.weight\"),\n        torch.tensor(\n            [\n                [1, 2, 3, 4],\n                [5, 6, 7, 8],\n            ],\n            dtype=torch.float32,\n        ),\n    )\n\n\ndef test_get_weights_col_packed():\n\n    weights = MockWeights(\n        [\n            \"test_get_weights_col_packed\",\n        ],\n        device=\"cpu\",\n        dtype=torch.float32,\n        process_group=dummy_process_group,\n        dummy_fs=dummy_file_system,\n    )\n\n    prefix = \"weight\"\n    block_sizes = 1\n\n    w = weights.get_weights_col_packed(\n        prefix=prefix,\n        block_sizes=block_sizes,\n    )\n\n    assert torch.allclose(\n        w,\n        torch.tensor(\n            [\n                [1, 2],\n                [3, 4],\n                [5, 6],\n                [7, 8],\n            ],\n            dtype=torch.float32,\n        ),\n    )\n\n\ndef test_get_weights_col_packed_block_size():\n\n    weights = MockWeights(\n        [\n            \"test_get_weights_col_packed\",\n        ],\n        device=\"cpu\",\n        dtype=torch.float32,\n        process_group=dummy_process_group,\n        dummy_fs=dummy_file_system,\n    )\n\n    prefix = \"weight\"\n    
block_sizes = 2\n\n    w = weights.get_weights_col_packed(\n        prefix=prefix,\n        block_sizes=block_sizes,\n    )\n\n    assert torch.allclose(\n        w,\n        torch.tensor(\n            [\n                [1, 2],\n                [3, 4],\n                [5, 6],\n                [7, 8],\n            ],\n            dtype=torch.float32,\n        ),\n    )\n\n\ndef test_get_weights_col_packed_block_size_arr():\n\n    weights = MockWeights(\n        [\n            \"test_get_weights_col_packed\",\n        ],\n        device=\"cpu\",\n        dtype=torch.float32,\n        process_group=dummy_process_group,\n        dummy_fs=dummy_file_system,\n    )\n\n    prefix = \"weight\"\n    block_sizes = [1, 1]\n\n    w = weights.get_weights_col_packed(\n        prefix=prefix,\n        block_sizes=block_sizes,\n    )\n\n    assert torch.allclose(\n        w,\n        torch.tensor(\n            [\n                [1, 2],\n                [3, 4],\n                [5, 6],\n                [7, 8],\n            ],\n            dtype=torch.float32,\n        ),\n    )\n\n\ndef test_get_multi_weights_col():\n    weights = MockWeights(\n        [\n            \"test_get_multi_weights_col\",\n        ],\n        device=\"cpu\",\n        dtype=torch.float32,\n        process_group=dummy_process_group,\n        dummy_fs=dummy_file_system,\n    )\n\n    prefixes = [\"weight\", \"weight\"]\n\n    w = weights.get_multi_weights_col(\n        prefixes=prefixes,\n        dim=0,\n    )\n\n    assert torch.allclose(\n        w,\n        torch.tensor(\n            [\n                [1, 2],\n                [3, 4],\n                [5, 6],\n                [7, 8],\n                [1, 2],\n                [3, 4],\n                [5, 6],\n                [7, 8],\n            ],\n            dtype=torch.float32,\n        ),\n    )\n\n\ndef test_get_weights_row():\n    weights = MockWeights(\n        [\n            \"test_get_weights_row\",\n        ],\n        device=\"cpu\",\n        
dtype=torch.float32,\n        process_group=dummy_process_group,\n        dummy_fs=dummy_file_system,\n    )\n\n    prefix = \"weight\"\n\n    w = weights.get_weights_row(\n        prefix=prefix,\n    )\n\n    assert torch.allclose(\n        w,\n        torch.tensor(\n            [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]],\n            dtype=torch.float32,\n        ),\n    )\n\n\n# test_get_weights_col\n\n\ndef test_get_weights_col_awq(gptq_weights_loader_awq):\n    weights = MockWeights(\n        [\n            \"test_get_weights_col_gptq\",\n        ],\n        device=\"cpu\",\n        dtype=torch.float32,\n        process_group=dummy_process_group,\n        dummy_fs=dummy_file_system,\n        weights_loader=gptq_weights_loader_awq,\n    )\n\n    prefix = \"weight\"\n\n    w = weights.get_weights_col(\n        prefix=prefix,\n    )\n\n    expected_weight = GPTQWeight(\n        qweight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]),\n        qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),\n        scales=torch.tensor(\n            [[100.0, 100.0], [100.0, 100.0]],\n            dtype=torch.float16,\n        ),\n        g_idx=None,\n        bits=8.0,\n        groupsize=2.0,\n        use_awq_kernel=True,\n        use_exllama=False,\n    )\n\n    assert torch.allclose(w.qweight, expected_weight.qweight), \"qweight mismatch\"\n    assert torch.allclose(w.qzeros, expected_weight.qzeros), \"qzeros mismatch\"\n    assert torch.allclose(w.scales, expected_weight.scales), \"scales mismatch\"\n    assert w.g_idx == expected_weight.g_idx, \"g_idx mismatch\"\n    assert w.bits == expected_weight.bits, \"bits mismatch\"\n    assert w.groupsize == expected_weight.groupsize, \"groupsize mismatch\"\n    assert w.use_awq_kernel == expected_weight.use_awq_kernel, \"use_awq_kernel mismatch\"\n    assert w.use_exllama == expected_weight.use_exllama, \"use_exllama mismatch\"\n\n\ndef test_get_weights_col_gtpq(gptq_weights_loader):\n    weights = 
MockWeights(\n        [\n            \"test_get_weights_col_gptq\",\n        ],\n        device=\"cpu\",\n        dtype=torch.float32,\n        process_group=dummy_process_group,\n        dummy_fs=dummy_file_system,\n        weights_loader=gptq_weights_loader,\n    )\n\n    prefix = \"weight\"\n\n    w = weights.get_weights_col(\n        prefix=prefix,\n    )\n\n    expected_weight = GPTQWeight(\n        qweight=torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]),\n        qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),\n        scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16),\n        g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),\n        bits=8.0,\n        groupsize=2.0,\n        use_awq_kernel=False,\n        use_exllama=False,\n    )\n\n    assert torch.allclose(w.qweight, expected_weight.qweight), \"qweight mismatch\"\n    assert torch.allclose(w.qzeros, expected_weight.qzeros), \"qzeros mismatch\"\n    assert torch.allclose(w.scales, expected_weight.scales), \"scales mismatch\"\n    assert torch.allclose(w.g_idx, expected_weight.g_idx), \"g_idx mismatch\"\n    assert w.bits == expected_weight.bits, \"bits mismatch\"\n    assert w.groupsize == expected_weight.groupsize, \"groupsize mismatch\"\n    assert w.use_awq_kernel == expected_weight.use_awq_kernel, \"use_awq_kernel mismatch\"\n    assert w.use_exllama == expected_weight.use_exllama, \"use_exllama mismatch\"\n\n\ndef test_get_weights_col_exl2():\n    weights = MockWeights(\n        [\n            \"test_get_weights_col_exl2\",\n        ],\n        device=\"cpu\",\n        dtype=torch.float32,\n        process_group=dummy_process_group,\n        dummy_fs=dummy_file_system,\n        weights_loader=Exl2WeightsLoader(),\n    )\n\n    prefix = \"weight\"\n\n    w = weights.get_weights_col(\n        prefix=prefix,\n    )\n\n    scaled_scale_max = 0.3906 * 256\n    expected_weight = Exl2Weight(\n        q_weight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 
8]], dtype=torch.int32),\n        q_scale=torch.tensor([8], dtype=torch.int32),\n        q_invperm=torch.tensor([1, 0, 3, 2], dtype=torch.int16),\n        q_scale_max=torch.tensor([scaled_scale_max], dtype=torch.float16),\n        q_groups=torch.tensor([4], dtype=torch.int16),\n    )\n\n    assert torch.allclose(w.q_weight, expected_weight.q_weight), \"q_weight mismatch\"\n    assert torch.allclose(w.q_scale, expected_weight.q_scale), \"q_scale mismatch\"\n    assert torch.allclose(w.q_invperm, expected_weight.q_invperm), \"q_invperm mismatch\"\n    assert torch.allclose(\n        w.q_scale_max, expected_weight.q_scale_max\n    ), \"q_scale_max mismatch\"\n    assert torch.allclose(w.q_groups, expected_weight.q_groups), \"q_groups mismatch\"\n\n\ndef test_get_weights_col_marlin(marlin_weights_loader):\n    weights = MockWeights(\n        [\n            \"test_get_weights_col_marlin\",\n        ],\n        device=\"cpu\",\n        dtype=torch.float16,\n        process_group=dummy_process_group,\n        dummy_fs=dummy_file_system,\n        weights_loader=marlin_weights_loader,\n    )\n\n    prefix = \"weight\"\n\n    w = weights.get_weights_col(\n        prefix=prefix,\n    )\n\n    expected_weight = MarlinWeight(\n        B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),\n        s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),\n    )\n\n    assert torch.allclose(w.B, expected_weight.B), \"B mismatch\"\n    assert torch.allclose(w.s, expected_weight.s), \"s mismatch\"\n\n\n# test_get_weights_col_packed\n\n\ndef test_get_weights_col_packed_awq(gptq_weights_loader_awq):\n    weights = MockWeights(\n        [\n            \"test_get_weights_col_packed_gptq\",\n        ],\n        device=\"cpu\",\n        dtype=torch.float32,\n        process_group=dummy_process_group,\n        dummy_fs=dummy_file_system,\n        weights_loader=gptq_weights_loader_awq,\n    )\n\n    prefix = \"weight\"\n    block_sizes = 1\n\n    w = weights.get_weights_col_packed(\n      
  prefix=prefix,\n        block_sizes=block_sizes,\n    )\n\n    expected_weight = GPTQWeight(\n        qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),\n        qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),\n        scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16),\n        g_idx=None,\n        bits=8.0,\n        groupsize=2.0,\n        use_awq_kernel=True,\n        use_exllama=False,\n    )\n\n    assert torch.allclose(w.qweight, expected_weight.qweight), \"qweight mismatch\"\n    assert torch.allclose(w.qzeros, expected_weight.qzeros), \"qzeros mismatch\"\n    assert torch.allclose(w.scales, expected_weight.scales), \"scales mismatch\"\n    assert w.g_idx == expected_weight.g_idx, \"g_idx mismatch\"\n    assert w.bits == expected_weight.bits, \"bits mismatch\"\n    assert w.groupsize == expected_weight.groupsize, \"groupsize mismatch\"\n    assert w.use_awq_kernel == expected_weight.use_awq_kernel, \"use_awq_kernel mismatch\"\n    assert w.use_exllama == expected_weight.use_exllama, \"use_exllama mismatch\"\n\n\n@pytest.mark.skip(reason=\"Review expected functionality\")\ndef test_get_weights_col_packed_exl2():\n    weights = MockWeights(\n        [\n            \"test_get_weights_col_packed_exl2\",\n        ],\n        device=\"cpu\",\n        dtype=torch.float32,\n        process_group=dummy_process_group,\n        dummy_fs=dummy_file_system,\n        weights_loader=Exl2WeightsLoader(),\n    )\n\n    prefix = \"weight\"\n    block_sizes = 1\n\n    w = weights.get_weights_col_packed(\n        prefix=prefix,\n        block_sizes=block_sizes,\n    )\n\n    scaled_scale_max = 0.3906 * 256\n    expected_weight = Exl2Weight(\n        q_weight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),\n        q_scale=torch.tensor([8], dtype=torch.int32),\n        q_invperm=torch.tensor([1], dtype=torch.int16),\n        q_scale_max=torch.tensor([scaled_scale_max], dtype=torch.float16),\n        
q_groups=torch.tensor([4], dtype=torch.int16),\n    )\n\n    assert torch.allclose(w.q_weight, expected_weight.q_weight), \"q_weight mismatch\"\n    assert torch.allclose(w.q_scale, expected_weight.q_scale), \"q_scale mismatch\"\n    assert torch.allclose(w.q_invperm, expected_weight.q_invperm), \"q_invperm mismatch\"\n    assert torch.allclose(\n        w.q_scale_max, expected_weight.q_scale_max\n    ), \"q_scale_max mismatch\"\n    assert torch.allclose(w.q_groups, expected_weight.q_groups), \"q_groups mismatch\"\n\n\ndef test_get_weights_col_packed_gptq(gptq_weights_loader):\n    weights = MockWeights(\n        [\n            \"test_get_weights_col_packed_gptq\",\n        ],\n        device=\"cpu\",\n        dtype=torch.float32,\n        process_group=dummy_process_group,\n        dummy_fs=dummy_file_system,\n        weights_loader=gptq_weights_loader,\n    )\n\n    prefixes = [\"weight\"]\n\n    w = weights.get_multi_weights_col(\n        prefixes=prefixes,\n        dim=0,\n    )\n\n    expected_weight = GPTQWeight(\n        qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),\n        qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),\n        scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16),\n        g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),\n        bits=8.0,\n        groupsize=2.0,\n        use_awq_kernel=False,\n        use_exllama=False,\n    )\n\n    assert torch.allclose(w.qweight, expected_weight.qweight), \"qweight mismatch\"\n    assert torch.allclose(w.qzeros, expected_weight.qzeros), \"qzeros mismatch\"\n    assert torch.allclose(w.scales, expected_weight.scales), \"scales mismatch\"\n    assert torch.allclose(w.g_idx, expected_weight.g_idx), \"g_idx mismatch\"\n    assert w.bits == expected_weight.bits, \"bits mismatch\"\n    assert w.groupsize == expected_weight.groupsize, \"groupsize mismatch\"\n    assert w.use_awq_kernel == expected_weight.use_awq_kernel, \"use_awq_kernel 
mismatch\"\n    assert w.use_exllama == expected_weight.use_exllama, \"use_exllama mismatch\"\n\n\ndef test_get_weights_col_packed_marlin(marlin_weights_loader):\n    weights = MockWeights(\n        [\n            \"test_get_weights_col_packed_marlin\",\n        ],\n        device=\"cpu\",\n        dtype=torch.float16,\n        process_group=dummy_process_group,\n        dummy_fs=dummy_file_system,\n        weights_loader=marlin_weights_loader,\n    )\n\n    prefix = \"weight\"\n\n    w = weights.get_multi_weights_col(\n        prefixes=[prefix],\n        dim=0,\n    )\n\n    expected_weight = MarlinWeight(\n        B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),\n        s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),\n    )\n\n    print(expected_weight)\n\n    assert torch.allclose(w.B, expected_weight.B), \"B mismatch\"\n    assert torch.allclose(w.s, expected_weight.s), \"s mismatch\"\n\n\n# test_get_multi_weights_col\n\n\ndef test_get_multi_weights_col_awq(gptq_weights_loader_awq):\n    weights = MockWeights(\n        [\n            \"test_get_multi_weights_col_gptq\",\n        ],\n        device=\"cpu\",\n        dtype=torch.float32,\n        process_group=dummy_process_group,\n        dummy_fs=dummy_file_system,\n        weights_loader=gptq_weights_loader_awq,\n    )\n\n    prefixes = [\"weight\"]\n\n    w = weights.get_multi_weights_col(\n        prefixes=prefixes,\n        dim=0,\n    )\n\n    expected_weight = GPTQWeight(\n        qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),\n        qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),\n        scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16),\n        g_idx=None,\n        bits=8.0,\n        groupsize=2.0,\n        use_awq_kernel=True,\n        use_exllama=False,\n    )\n\n    assert torch.allclose(w.qweight, expected_weight.qweight), \"qweight mismatch\"\n    assert torch.allclose(w.qzeros, expected_weight.qzeros), \"qzeros 
mismatch\"\n    assert torch.allclose(w.scales, expected_weight.scales), \"scales mismatch\"\n    assert w.g_idx == expected_weight.g_idx, \"g_idx mismatch\"\n    assert w.bits == expected_weight.bits, \"bits mismatch\"\n    assert w.groupsize == expected_weight.groupsize, \"groupsize mismatch\"\n    assert w.use_awq_kernel == expected_weight.use_awq_kernel, \"use_awq_kernel mismatch\"\n    assert w.use_exllama == expected_weight.use_exllama, \"use_exllama mismatch\"\n\n\ndef test_get_multi_weights_col_exl2():\n    weights = MockWeights(\n        [\n            \"test_get_multi_weights_col_exl2\",\n        ],\n        device=\"cpu\",\n        dtype=torch.float32,\n        process_group=dummy_process_group,\n        dummy_fs=dummy_file_system,\n        weights_loader=Exl2WeightsLoader(),\n    )\n\n    prefix = \"weight\"\n\n    try:\n        weights.get_multi_weights_col(\n            prefixes=[prefix],\n            dim=0,\n        )\n    except ValueError as e:\n        assert e.args[0] == \"get_multi_weights_col is not supported for exl2\"\n\n\ndef test_get_multi_weights_col_gptq(gptq_weights_loader):\n    weights = MockWeights(\n        [\n            \"test_get_multi_weights_col_gptq\",\n        ],\n        device=\"cpu\",\n        dtype=torch.float32,\n        process_group=dummy_process_group,\n        dummy_fs=dummy_file_system,\n        weights_loader=gptq_weights_loader,\n    )\n\n    prefixes = [\"weight\"]\n\n    w = weights.get_multi_weights_col(\n        prefixes=prefixes,\n        dim=0,\n    )\n\n    expected_weight = GPTQWeight(\n        qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),\n        qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),\n        scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16),\n        g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),\n        bits=8.0,\n        groupsize=2.0,\n        use_awq_kernel=False,\n        use_exllama=False,\n    )\n\n    assert 
torch.allclose(w.qweight, expected_weight.qweight), \"qweight mismatch\"\n    assert torch.allclose(w.qzeros, expected_weight.qzeros), \"qzeros mismatch\"\n    assert torch.allclose(w.scales, expected_weight.scales), \"scales mismatch\"\n    assert torch.allclose(w.g_idx, expected_weight.g_idx), \"g_idx mismatch\"\n    assert w.bits == expected_weight.bits, \"bits mismatch\"\n    assert w.groupsize == expected_weight.groupsize, \"groupsize mismatch\"\n    assert w.use_awq_kernel == expected_weight.use_awq_kernel, \"use_awq_kernel mismatch\"\n    assert w.use_exllama == expected_weight.use_exllama, \"use_exllama mismatch\"\n\n\ndef test_get_multi_weights_col_marlin(marlin_weights_loader):\n    weights = MockWeights(\n        [\n            \"test_get_multi_weights_col_marlin\",\n        ],\n        device=\"cpu\",\n        dtype=torch.float16,\n        process_group=dummy_process_group,\n        dummy_fs=dummy_file_system,\n        weights_loader=marlin_weights_loader,\n    )\n\n    prefix = \"weight\"\n\n    w = weights.get_multi_weights_col(\n        prefixes=[prefix],\n        dim=0,\n    )\n\n    expected_weight = MarlinWeight(\n        B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),\n        s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),\n    )\n\n    assert torch.allclose(w.B, expected_weight.B), \"B mismatch\"\n    assert torch.allclose(w.s, expected_weight.s), \"s mismatch\"\n\n\n# test_get_weights_row\n\n\ndef test_get_weights_row_awq(gptq_weights_loader_awq):\n    weights = MockWeights(\n        [\n            \"test_get_weights_row_gptq\",\n        ],\n        device=\"cpu\",\n        dtype=torch.float32,\n        process_group=dummy_process_group,\n        dummy_fs=dummy_file_system,\n        weights_loader=gptq_weights_loader_awq,\n    )\n\n    prefix = \"weight\"\n\n    w = weights.get_weights_row(\n        prefix=prefix,\n    )\n\n    expected_weight = GPTQWeight(\n        qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], 
dtype=torch.int32),\n        qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),\n        scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16),\n        g_idx=None,\n        bits=8.0,\n        groupsize=2.0,\n        use_awq_kernel=True,\n        use_exllama=False,\n    )\n\n    assert torch.allclose(w.qweight, expected_weight.qweight), \"qweight mismatch\"\n    assert torch.allclose(w.qzeros, expected_weight.qzeros), \"qzeros mismatch\"\n    assert torch.allclose(w.scales, expected_weight.scales), \"scales mismatch\"\n    assert w.g_idx == expected_weight.g_idx, \"g_idx mismatch\"\n    assert w.bits == expected_weight.bits, \"bits mismatch\"\n    assert w.groupsize == expected_weight.groupsize, \"groupsize mismatch\"\n    assert w.use_awq_kernel == expected_weight.use_awq_kernel, \"use_awq_kernel mismatch\"\n    assert w.use_exllama == expected_weight.use_exllama, \"use_exllama mismatch\"\n\n\ndef test_get_weights_row_exl2():\n    weights = MockWeights(\n        [\n            \"test_get_weights_row_exl2\",\n        ],\n        device=\"cpu\",\n        dtype=torch.float32,\n        process_group=dummy_process_group,\n        dummy_fs=dummy_file_system,\n        weights_loader=Exl2WeightsLoader(),\n    )\n\n    prefix = \"weight\"\n\n    w = weights.get_weights_row(\n        prefix=prefix,\n    )\n    print(w)\n\n    scaled_scale_max = 0.3906 * 256\n    expected_weight = Exl2Weight(\n        q_weight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),\n        q_scale=torch.tensor([8], dtype=torch.int32),\n        q_invperm=torch.tensor([1, 0, 3, 2], dtype=torch.int16),\n        q_scale_max=torch.tensor([scaled_scale_max], dtype=torch.float16),\n        q_groups=torch.tensor([4], dtype=torch.int16),\n    )\n\n    assert torch.allclose(w.q_weight, expected_weight.q_weight), \"q_weight mismatch\"\n    assert torch.allclose(w.q_scale, expected_weight.q_scale), \"q_scale mismatch\"\n    assert torch.allclose(w.q_invperm, 
expected_weight.q_invperm), \"q_invperm mismatch\"\n    assert torch.allclose(\n        w.q_scale_max, expected_weight.q_scale_max\n    ), \"q_scale_max mismatch\"\n    assert torch.allclose(w.q_groups, expected_weight.q_groups), \"q_groups mismatch\"\n\n\ndef test_get_weights_row_gptq(gptq_weights_loader):\n    weights = MockWeights(\n        [\n            \"test_get_weights_row_gptq\",\n        ],\n        device=\"cpu\",\n        dtype=torch.float32,\n        process_group=dummy_process_group,\n        dummy_fs=dummy_file_system,\n        weights_loader=gptq_weights_loader,\n    )\n\n    prefix = \"weight\"\n\n    w = weights.get_weights_row(\n        prefix=prefix,\n    )\n\n    expected_weight = GPTQWeight(\n        qweight=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.int32),\n        qzeros=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),\n        scales=torch.tensor([[100.0, 100.0], [100.0, 100.0]], dtype=torch.float16),\n        g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32),\n        bits=8.0,\n        groupsize=2.0,\n        use_awq_kernel=False,\n        use_exllama=False,\n    )\n\n    assert torch.allclose(w.qweight, expected_weight.qweight), \"qweight mismatch\"\n    assert torch.allclose(w.qzeros, expected_weight.qzeros), \"qzeros mismatch\"\n    assert torch.allclose(w.scales, expected_weight.scales), \"scales mismatch\"\n    assert torch.allclose(w.g_idx, expected_weight.g_idx), \"g_idx mismatch\"\n    assert w.bits == expected_weight.bits, \"bits mismatch\"\n    assert w.groupsize == expected_weight.groupsize, \"groupsize mismatch\"\n    assert w.use_awq_kernel == expected_weight.use_awq_kernel, \"use_awq_kernel mismatch\"\n    assert w.use_exllama == expected_weight.use_exllama, \"use_exllama mismatch\"\n\n\ndef test_get_weights_row_marlin(marlin_weights_loader):\n    weights = MockWeights(\n        [\n            \"test_get_weights_row_marlin\",\n        ],\n        device=\"cpu\",\n        dtype=torch.float16,\n        
process_group=dummy_process_group,\n        dummy_fs=dummy_file_system,\n        weights_loader=marlin_weights_loader,\n    )\n\n    prefix = \"weight\"\n\n    w = weights.get_weights_row(\n        prefix=prefix,\n    )\n\n    expected_weight = MarlinWeight(\n        B=torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),\n        s=torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),\n    )\n\n    assert torch.allclose(w.B, expected_weight.B), \"B mismatch\"\n    assert torch.allclose(w.s, expected_weight.s), \"s mismatch\"\n"
  },
  {
    "path": "server/text_generation_server/__init__.py",
    "content": ""
  },
  {
    "path": "server/text_generation_server/adapters/__init__.py",
    "content": "# Origin:   https://github.com/predibase/lorax\n# Path:     lorax/server/lorax_server/adapters/__init__.py\n# License:  Apache License Version 2.0, January 2004\n\nfrom text_generation_server.adapters.weights import (\n    AdapterBatchData,\n    AdapterBatchMetadata,\n)\n\n__all__ = [\n    \"AdapterBatchData\",\n    \"AdapterBatchMetadata\",\n]\n"
  },
  {
    "path": "server/text_generation_server/adapters/config.py",
    "content": "# Origin:   https://github.com/predibase/lorax\n# Path:     lorax/server/lorax_server/adapters/config.py\n# License:  Apache License Version 2.0, January 2004\n\nfrom abc import ABC, abstractmethod\nfrom dataclasses import dataclass\nfrom typing import Dict, Set, Tuple\n\nimport torch\n\nfrom text_generation_server.adapters.weights import AdapterWeights\n\n\n@dataclass\nclass ModuleMap:\n    module_name: str\n    module_weights: Dict[str, Tuple[torch.Tensor, str]]\n\n\n@dataclass\nclass AdapterConfig(ABC):\n    base_model_name_or_path: str\n\n    @abstractmethod\n    def map_weights_for_model(\n        self,\n        adapter_weights: Dict[int, AdapterWeights],\n        weight_names: Tuple[str],\n    ) -> Tuple[ModuleMap, Set[str]]:\n        pass\n"
  },
  {
    "path": "server/text_generation_server/adapters/lora.py",
    "content": "# Origin:   https://github.com/predibase/lorax\n# Path:     lorax/server/lorax_server/adapters/lora.py\n# License:  Apache License Version 2.0, January 2004\n\nfrom collections import defaultdict\nfrom dataclasses import dataclass\nfrom typing import Dict, List, Optional, Set, Tuple, Type, Union\n\nfrom loguru import logger\nimport torch\nfrom peft import LoraConfig as _LoraConfig\nfrom torch.distributed import ProcessGroup\nfrom text_generation_server.utils.log import log_master\nfrom text_generation_server.utils.import_utils import SYSTEM\nfrom text_generation_server.adapters.config import AdapterConfig, ModuleMap\nfrom text_generation_server.utils.kernels import load_kernel\nfrom text_generation_server.adapters.weights import (\n    AdapterBatchMetadata,\n    AdapterWeights,\n    BatchAdapterWeights,\n)\n\nif SYSTEM == \"cuda\":\n    punica_sgmv = load_kernel(\n        module=\"punica_sgmv\", repo_id=\"kernels-community/punica-sgmv\"\n    )\nelse:\n    punica_sgmv = None\n\n\ndef get_start_stop_idxs_for_rank(offset, size, rank, world_size):\n    block_size = size // world_size\n    start = offset + rank * block_size\n    stop = offset + (rank + 1) * block_size\n    return start, stop\n\n\ndef shard_on_dim(\n    t: torch.Tensor, dim: int, process_group: torch.distributed.ProcessGroup\n):\n    world_size = process_group.size()\n    rank = process_group.rank()\n\n    size = t.shape[dim]\n    start, stop = get_start_stop_idxs_for_rank(0, size, rank, world_size)\n\n    if dim == 0:\n        tensor = t[start:stop]\n    elif dim == 1:\n        tensor = t[:, start:stop]\n    else:\n        raise NotImplementedError(\"Let's make that generic when needed\")\n\n    return tensor\n\n\ndef shard_lora_weights(\n    weights_a: List[torch.Tensor],\n    weights_b: List[torch.Tensor],\n    split_dim: int,\n    process_group: ProcessGroup,\n) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:\n    # [hidden_size, r]\n    weights_a = [\n        shard_on_dim(w, 
dim=split_dim, process_group=process_group) for w in weights_a\n    ]\n\n    # [r, hidden_size]\n    weights_b = [shard_on_dim(w, dim=1, process_group=process_group) for w in weights_b]\n\n    return weights_a, weights_b\n\n\n@dataclass\nclass LoraConfig(AdapterConfig):\n    r: int\n    target_modules: Optional[Union[List[str], str]]\n    fan_in_fan_out: bool\n    lora_alpha: int\n    use_rslora: bool\n\n    def map_weights_for_model(\n        self,\n        adapter_weights: Dict[int, AdapterWeights],\n        weight_names: Tuple[str],\n    ) -> Tuple[ModuleMap, Set[str]]:\n        adapter_weight_names = set()\n        module_map = {}\n        for weight_name in weight_names:\n            lora_a_name = f\"base_model.model.{weight_name}.lora_A.weight\"\n            lora_b_name = f\"base_model.model.{weight_name}.lora_B.weight\"\n            if lora_a_name not in adapter_weights or lora_b_name not in adapter_weights:\n                continue\n\n            module_map[weight_name] = {\n                \"lora_A\": (adapter_weights[lora_a_name], lora_a_name),\n                \"lora_B\": (adapter_weights[lora_b_name], lora_b_name),\n            }\n            adapter_weight_names.add(lora_a_name)\n            adapter_weight_names.add(lora_b_name)\n        return module_map, adapter_weight_names\n\n    @classmethod\n    def load(cls, adapter_id: str, api_token: str) -> \"LoraConfig\":\n        hf_config = _LoraConfig.from_pretrained(adapter_id, token=api_token)\n        return cls(\n            base_model_name_or_path=hf_config.base_model_name_or_path,\n            r=hf_config.r,\n            target_modules=hf_config.target_modules,\n            fan_in_fan_out=hf_config.fan_in_fan_out,\n            lora_alpha=hf_config.lora_alpha,\n            use_rslora=(\n                hf_config.use_rslora if hasattr(hf_config, \"use_rslora\") else False\n            ),\n        )\n\n\nclass LoraWeights(AdapterWeights):\n    \"\"\"LoRA weights for a single adapter merged across all 
layers.\"\"\"\n\n    def __init__(\n        self,\n        weights_a: List[torch.Tensor],\n        weights_b: List[torch.Tensor],\n        adapter_config: LoraConfig,\n    ):\n        self.lora_a_r = weights_a[0].size(1) if len(weights_a) > 0 else 1\n        self.lora_b_r = weights_b[0].size(0) if len(weights_a) > 0 else 1\n\n        self._is_transposed = False\n        if SYSTEM == \"ipex\":\n            self._use_cutlass_shrink = False\n            # [num_layers, r, hidden_size]\n            weights_a = [w.transpose(0, 1).contiguous() for w in weights_a]\n            self._weights_a = torch.stack(weights_a)\n\n            # [num_layers, hidden_size, r]\n            weights_b = [w.transpose(0, 1).contiguous() for w in weights_b]\n            self._weights_b = torch.stack(weights_b)\n        else:\n            self._use_cutlass_shrink = punica_sgmv.use_cutlass_shrink(self.lora_a_r)\n            # [num_layers, hidden_size, r]\n            weights_a = [\n                punica_sgmv.orient_for_rank(w, w.size(1)).contiguous()\n                for w in weights_a\n            ]\n            self._weights_a = torch.stack(weights_a)\n\n            # [num_layers, r, hidden_size]\n            self._weights_b = torch.stack(weights_b)\n\n        self.adapter_config = adapter_config\n\n    @property\n    def weights_a(self) -> torch.Tensor:\n        if self._is_transposed:\n            self._transpose_weights()\n        return self._weights_a\n\n    @property\n    def weights_b(self) -> torch.Tensor:\n        if self._is_transposed:\n            self._transpose_weights()\n        return self._weights_b\n\n    @property\n    def weights_a_t(self) -> torch.Tensor:\n        if not self._is_transposed:\n            self._transpose_weights()\n        return self._weights_a\n\n    @property\n    def weights_b_t(self) -> torch.Tensor:\n        if not self._is_transposed:\n            self._transpose_weights()\n        return self._weights_b\n\n    def _transpose_weights(self):\n       
 if self._use_cutlass_shrink:\n            # If we're not using the cutlass shrink, then both SGMV and BGMV use the same orientation\n            self._weights_a = self._weights_a.transpose(1, 2).contiguous()\n        self._weights_b = self._weights_b.transpose(1, 2).contiguous()\n        self._is_transposed = not self._is_transposed\n\n    @classmethod\n    def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]:\n        if SYSTEM == \"ipex\":\n            return [IPEXBatchLoraWeights]\n        else:\n            return [BatchLoraWeights]\n\n    # prepare pre-loaded lora weights for use in the model.\n    #\n    # this method processes and organizes lora weights for a specific layer type across all layers:\n    # - uses `config` (LoraConfig) to apply lora-specific settings like scaling factor.\n    # - retrieves weights from `module_map` based on the `layer_type`.\n    # - processes `nlayers` number of layers.\n    # - converts weights to the specified `dtype`.\n    # - shards weights across `world_size` number of processes using the `process_group`.\n    # - maps weights to specific layers using `target_to_layer`.\n    # - tracks `unused_weight_names` to identify any unused weights.\n    #\n    # the method handles weight transposition, scaling, and padding to ensure compatibility\n    # with SGMV or BGMV operations.\n    @classmethod\n    def prepare_weights(\n        cls,\n        config: LoraConfig,\n        module_map: Dict[str, Dict],\n        layer_type: str,\n        unused_weight_names: Set[str],\n        nlayers: int,\n        dtype: torch.dtype,\n        world_size: int,\n        process_group: ProcessGroup,\n        target_to_layer: Dict[str, Tuple[str, torch.Tensor]],\n    ) -> Optional[AdapterWeights]:\n        lora_a_list = [None] * nlayers\n        lora_b_list = [None] * nlayers\n\n        for layer_id in range(nlayers):\n            key = (layer_id, layer_type)\n            if key not in target_to_layer:\n                # There is no layer 
of this type in the model\n                log_master(\n                    logger.warning,\n                    f\"Key specified in lora weights but not found in base model: {key}\",\n                )\n                return None\n\n            weight_name, layer = target_to_layer[key]\n            base_weight = layer.base_layer.linear.weight\n            base_device = base_weight.device\n\n            if weight_name not in module_map:\n                # There is no LoRA weight for this layer type in the adapter\n                return None\n\n            lora_a, lora_a_name = module_map[weight_name][\"lora_A\"]\n            lora_a = lora_a.to(base_device, dtype)\n\n            lora_b, lora_b_name = module_map[weight_name][\"lora_B\"]\n            lora_b = lora_b.to(base_device, dtype)\n\n            scale = get_scaling_factor(\n                config.lora_alpha,\n                config.r,\n                uses_rslora=config.use_rslora,\n            )\n\n            unused_weight_names.discard(lora_a_name)\n            unused_weight_names.discard(lora_b_name)\n\n            # Merge scaling factor into lora_b due to associativity of matrix multiplication:\n            # (A * B) * C = A * (B * C)\n            lora_a_list[layer_id] = lora_a.transpose(0, 1)\n            lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale\n\n        # pad lora ranks to be compatible with sgmv\n        if SYSTEM != \"ipex\":\n            lora_a_list = [\n                punica_sgmv.pad_rank(w, dim=1, world_size=world_size)\n                for w in lora_a_list\n            ]\n            lora_b_list = [\n                punica_sgmv.pad_rank(w, dim=0, world_size=world_size)\n                for w in lora_b_list\n            ]\n\n            if lora_a_list:\n                # update rank if it was padded\n                padded_rank = lora_a_list[0].size(1)\n                config.r = padded_rank\n\n        return LoraWeights(\n            *shard_lora_weights(\n                
weights_a=lora_a_list,\n                weights_b=lora_b_list,\n                split_dim=0 if layer_type in {\"o_proj\", \"down_proj\", \"lm_head\"} else 1,\n                process_group=process_group,\n            ),\n            config,\n        )\n\n\n@dataclass\nclass RankSegments:\n    rank: int\n\n    lora_a_ptr: torch.Tensor\n    lora_b_ptr: torch.Tensor\n\n    # prefill (sgmv)\n    tmp_shrink: torch.Tensor\n    tmp_expand: torch.Tensor\n    segment_starts: torch.Tensor\n    segment_ends: torch.Tensor\n\n    # decode (bgmv)\n    indices: torch.Tensor\n\n\n@dataclass\nclass BatchLoraWeights(BatchAdapterWeights):\n    lora_a: Dict[int, torch.Tensor]\n    lora_b: Dict[int, torch.Tensor]\n    adapter_index_configs: Dict[int, LoraConfig]\n    rank_data: Dict[int, RankSegments]\n    use_sgmv: bool\n\n    def has_adapter(self, adapter_index: int) -> bool:\n        return adapter_index in self.adapter_index_configs\n\n    def can_vectorize(self, pg: ProcessGroup) -> bool:\n        return all(\n            rank_data.rank // pg.size() <= punica_sgmv.MAX_RANK_CUSTOM\n            for rank_data in self.rank_data.values()\n        )\n\n    @classmethod\n    def load(\n        self,\n        adapter_weights: Dict[int, AdapterWeights],\n        meta: AdapterBatchMetadata,\n        prefill: bool,\n        prefill_head_indices: Optional[torch.Tensor],\n    ) -> Optional[\"BatchLoraWeights\"]:\n        adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}\n        adapter_weights = {\n            k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)\n        }\n        if not adapter_weights:\n            return None\n\n        first_weights = next(iter(adapter_weights.values()))\n        device = first_weights.weights_a.device\n        segment_indices = meta.segment_indices\n\n        lora_a = {\n            idx: adapter_weights[idx].weights_a\n            for idx in segment_indices\n            if idx in adapter_weights\n        }\n   
     lora_b = {\n            idx: adapter_weights[idx].weights_b\n            for idx in segment_indices\n            if idx in adapter_weights\n        }\n\n        max_rank = max(\n            (\n                adapter_weights[idx].lora_a_r\n                for idx in segment_indices\n                if idx in adapter_weights\n            ),\n            default=0,\n        )\n\n        use_sgmv = False\n        if prefill or max_rank > punica_sgmv.BGMV_MAX_RANK:\n            if punica_sgmv is not None:\n                use_sgmv = True\n            lora_a_ptr = torch.tensor(\n                [\n                    (\n                        adapter_weights[idx].weights_a.data_ptr()\n                        if idx in adapter_weights\n                        else 0\n                    )\n                    for idx in segment_indices\n                ],\n                dtype=torch.int64,\n                device=device,\n            )\n            lora_b_ptr = torch.tensor(\n                [\n                    (\n                        adapter_weights[idx].weights_b.data_ptr()\n                        if idx in adapter_weights\n                        else 0\n                    )\n                    for idx in segment_indices\n                ],\n                dtype=torch.int64,\n                device=device,\n            )\n        else:\n            lora_a_ptr = torch.tensor(\n                [\n                    (\n                        adapter_weights[idx].weights_a_t.data_ptr()\n                        if idx in adapter_weights\n                        else 0\n                    )\n                    for idx in segment_indices\n                ],\n                dtype=torch.int64,\n                device=device,\n            )\n            lora_b_ptr = torch.tensor(\n                [\n                    (\n                        adapter_weights[idx].weights_b_t.data_ptr()\n                        if idx in adapter_weights\n                 
       else 0\n                    )\n                    for idx in segment_indices\n                ],\n                dtype=torch.int64,\n                device=device,\n            )\n\n        adapter_index_configs = {\n            idx: adapter_weights[idx].adapter_config\n            for idx in segment_indices\n            if idx in adapter_weights\n        }\n\n        adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}\n\n        rank_indices = defaultdict(list)\n        for segment_idx, adapter_idx in enumerate(segment_indices):\n            if adapter_idx not in adapter_weights:\n                continue\n            rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)\n\n        if prefill_head_indices is not None:\n            j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]\n            for head_index in prefill_head_indices:\n                # j cannot go out of bounds as that would mean there are tokens without corresponding adapters\n                if head_index < meta.adapter_segments[j]:\n                    prefill_head_segment_ends[-1] += 1\n                else:\n                    prefill_head_segment_starts.append(prefill_head_segment_ends[-1])\n                    prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)\n                    j += 1\n\n        rank_data = {}\n        for rank, indices in rank_indices.items():\n            tmp_shrink = None\n            tmp_expand = None\n            segment_starts = None\n            segment_ends = None\n            batch_indices = None\n\n            if use_sgmv:\n                lora_a_ptr_indices = lora_a_ptr[indices]\n                tmp_shrink, tmp_expand = punica_sgmv.get_tmp_tensors(\n                    lora_a_ptr_indices.size(0), rank, device\n                )\n                segment_starts = meta.adapter_segments[indices]\n                segment_ends = meta.adapter_segments[[i + 1 for i in indices]]\n           
     if prefill_head_indices is not None:\n                    for i, segment_index in enumerate(indices):\n                        segment_starts[i] = prefill_head_segment_starts[segment_index]\n                        segment_ends[i] = prefill_head_segment_ends[segment_index]\n            else:\n                rank_indices = set(indices)\n                batch_indices = [\n                    adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()\n                ]\n                batch_indices = [\n                    idx if idx in rank_indices else -1 for idx in batch_indices\n                ]\n                batch_indices = torch.tensor(\n                    batch_indices, dtype=torch.int64, device=device\n                )\n\n            rank_data[rank] = RankSegments(\n                rank=rank,\n                tmp_shrink=tmp_shrink,\n                tmp_expand=tmp_expand,\n                lora_a_ptr=lora_a_ptr[indices],\n                lora_b_ptr=lora_b_ptr[indices],\n                segment_starts=segment_starts,\n                segment_ends=segment_ends,\n                indices=batch_indices,\n            )\n\n        return BatchLoraWeights(\n            lora_a=lora_a,\n            lora_b=lora_b,\n            adapter_index_configs=adapter_index_configs,\n            rank_data=rank_data,\n            use_sgmv=use_sgmv,\n        )\n\n\n@dataclass\nclass IPEXBatchLoraWeights(BatchLoraWeights):\n    @classmethod\n    def load(\n        self,\n        adapter_weights: Dict[int, AdapterWeights],\n        meta: AdapterBatchMetadata,\n        prefill: bool,\n        prefill_head_indices: Optional[torch.Tensor],\n    ) -> Optional[\"BatchLoraWeights\"]:\n        adapter_weights = {k: _convert_lora(v) for k, v in adapter_weights.items()}\n        adapter_weights = {\n            k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)\n        }\n        if not adapter_weights:\n            return None\n\n        first_weights = 
next(iter(adapter_weights.values()))\n        device = first_weights.weights_a.device\n        segment_indices = meta.segment_indices\n\n        lora_a = {\n            idx: adapter_weights[idx].weights_a\n            for idx in segment_indices\n            if idx in adapter_weights\n        }\n        lora_b = {\n            idx: adapter_weights[idx].weights_b\n            for idx in segment_indices\n            if idx in adapter_weights\n        }\n        adapter_index_configs = {\n            idx: adapter_weights[idx].adapter_config\n            for idx in segment_indices\n            if idx in adapter_weights\n        }\n        if len(lora_a) != 0:\n            lora_a_ptr = torch.stack(list(lora_a.values()))\n        if len(lora_b) != 0:\n            lora_b_ptr = torch.stack(list(lora_b.values()))\n\n        use_sgmv = True if prefill else False\n\n        adapter_to_segment = {v: k for k, v in enumerate(segment_indices)}\n\n        rank_indices = defaultdict(list)\n        for segment_idx, adapter_idx in enumerate(segment_indices):\n            if adapter_idx not in adapter_weights:\n                continue\n            rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)\n\n        if prefill_head_indices is not None:\n            j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]\n            for head_index in prefill_head_indices:\n                # j cannot go out of bounds as that would mean there are tokens without corresponding adapters\n                if head_index < meta.adapter_segments[j]:\n                    prefill_head_segment_ends[-1] += 1\n                else:\n                    prefill_head_segment_starts.append(prefill_head_segment_ends[-1])\n                    prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)\n                    j += 1\n\n        rank_data = {}\n        segment_starts = None\n        segment_ends = None\n        if use_sgmv:\n            segment_starts = 
meta.adapter_segments[:-1]\n            segment_ends = meta.adapter_segments[1:]\n            if prefill_head_indices is not None:\n                segment_starts = prefill_head_segment_starts[:-1]\n                segment_ends = prefill_head_segment_ends[1:]\n        batch_indices = [\n            adapter_to_segment[idx] for idx in meta.adapter_indices.tolist()\n        ]\n        for rank, indices in rank_indices.items():\n            adapters_indices = []\n            lora_a_keys = list(lora_a.keys())\n            for segment_idx in batch_indices:\n                if segment_idx in indices:\n                    adapters_indices.append(\n                        lora_a_keys.index(segment_indices[segment_idx])\n                    )\n                else:\n                    adapters_indices.append(-1)\n            adapters_indices = torch.tensor(\n                adapters_indices, dtype=torch.int64, device=device\n            )\n            if use_sgmv:\n                adapters_indices = adapters_indices[segment_starts]\n            rank_data[rank] = RankSegments(\n                rank=rank,\n                tmp_shrink=None,\n                tmp_expand=None,\n                lora_a_ptr=lora_a_ptr,\n                lora_b_ptr=lora_b_ptr,\n                segment_starts=segment_starts,\n                segment_ends=segment_ends,\n                indices=adapters_indices,\n            )\n\n        return BatchLoraWeights(\n            lora_a=lora_a,\n            lora_b=lora_b,\n            adapter_index_configs=adapter_index_configs,\n            rank_data=rank_data,\n            use_sgmv=use_sgmv,\n        )\n\n\ndef get_scaling_factor(\n    lora_alpha: int,\n    r: int,\n    uses_rslora: bool = False,\n) -> float:\n    \"\"\"Computes the scaling factor for the lora weights.\"\"\"\n    if uses_rslora:\n        return lora_alpha / (r**0.5)\n    return lora_alpha / r\n\n\ndef _convert_lora(v: AdapterWeights) -> AdapterWeights:\n    if hasattr(v, \"lora_weights\"):\n 
       return v.lora_weights\n    return v\n"
  },
  {
    "path": "server/text_generation_server/adapters/weights.py",
    "content": "# Origin:   https://github.com/predibase/lorax\n# Path:     lorax/server/lorax_server/adapters/weights.py\n# License:  Apache License Version 2.0, January 2004\n\nfrom abc import ABC, abstractclassmethod\nfrom collections import defaultdict\nfrom dataclasses import dataclass\nfrom typing import Dict, List, Optional, Set, Type\n\nimport torch\n\n\n@dataclass\nclass AdapterBatchMetadata:\n    # [batch_size]\n    adapter_indices: torch.Tensor\n\n    # [num_adapters]\n    adapter_set: Set[int]\n\n    # [num_segments + 1]\n    adapter_segments: torch.Tensor\n\n    # [num_segments]\n    # maps from segment index to adapter index, i.e.:\n    # segment_indices[s] == adapter_indices[i]\n    segment_indices: List[int]\n\n\nclass AdapterWeights(ABC):\n    @abstractclassmethod\n    def get_batch_types(cls) -> List[Type[\"BatchAdapterWeights\"]]:\n        pass\n\n    @property\n    def speculative_tokens(self) -> int:\n        return 0\n\n\nclass BatchAdapterWeights(ABC):\n    @abstractclassmethod\n    def has_adapter(self, adapter_index: int) -> bool:\n        pass\n\n    @abstractclassmethod\n    def load(\n        cls,\n        adapter_weights: Dict[int, AdapterWeights],\n        meta: \"AdapterBatchMetadata\",\n        prefill: bool,\n        prefill_head_indices: torch.Tensor,\n    ) -> Optional[\"BatchAdapterWeights\"]:\n        pass\n\n\nclass LayerAdapterWeights:\n    \"\"\"Adapter weights that apply to a particular layer.\"\"\"\n\n    def __init__(self):\n        self.adapter_weights: Dict[int, AdapterWeights] = {}\n\n    def add_adapter(self, adapter_idx: int, weights: AdapterWeights):\n        self.adapter_weights[adapter_idx] = weights\n\n    def remove_adapter(self, adapter_idx: int):\n        if adapter_idx not in self.adapter_weights:\n            return\n        del self.adapter_weights[adapter_idx]\n\n    def is_empty(self) -> bool:\n        return len(self.adapter_weights) == 0\n\n    def get_data(\n        self,\n        meta: 
AdapterBatchMetadata,\n        prefill: bool,\n        prefill_head_indices: Optional[torch.Tensor],\n    ) -> Dict[str, BatchAdapterWeights]:\n        # bucket adapters by batch class\n        adapter_batch_types: Dict[\n            Type[BatchAdapterWeights], Dict[int, AdapterWeights]\n        ] = defaultdict(dict)\n        for adapter_index, adapter_weights in self.adapter_weights.items():\n            for batch_type in adapter_weights.get_batch_types():\n                adapter_batch_types[batch_type][adapter_index] = adapter_weights\n\n        batch_data = {}\n        for batch_type, adapter_weights in adapter_batch_types.items():\n            batched_weights = batch_type.load(\n                adapter_weights, meta, prefill, prefill_head_indices\n            )\n            if batched_weights is not None:\n                batch_data = batched_weights\n        return batch_data\n\n\n@dataclass\nclass AdapterBatchData:\n    meta: AdapterBatchMetadata\n\n    # layer type -> adapter type -> batch weight data\n    data: Dict[str, Dict[str, BatchAdapterWeights]]\n\n    prefill: bool\n\n    @staticmethod\n    def from_meta(\n        meta: AdapterBatchMetadata,\n        weights: Dict[str, LayerAdapterWeights],\n        prefill: bool,\n        prefill_head_indices: Optional[torch.Tensor],\n    ) -> \"AdapterBatchData\":\n        data = {}\n        for k, v in weights.items():\n            if v.is_empty():\n                continue\n            data[k] = v.get_data(\n                meta, prefill, prefill_head_indices if k == \"lm_head\" else None\n            )\n        return AdapterBatchData(meta=meta, data=data, prefill=prefill)\n\n    def ranks(self) -> Set[int]:\n        # TODO(travis): refactor to be less coupled to lora implementation\n        ranks = set()\n        for lora_data in self.data.values():\n            if lora_data is None:\n                continue\n\n            for rank_data in lora_data.rank_data.values():\n                
ranks.add(rank_data.rank)\n\n        return ranks\n\n    def layer_names(self) -> Set[str]:\n        return set(self.data.keys())\n\n    def adapter_keys(self) -> Set[str]:\n        adapter_keys = set()\n        for layer_data in self.data.values():\n            adapter_keys.update(layer_data.keys())\n        return adapter_keys\n\n    @property\n    def max_rank(self) -> int:\n        ranks = self.ranks()\n        return max(ranks) if len(ranks) > 0 else 0\n"
  },
  {
    "path": "server/text_generation_server/cache.py",
    "content": "import torch\n\nfrom typing import Dict, Optional, TypeVar\n\nfrom text_generation_server.models.types import Batch\n\nB = TypeVar(\"B\", bound=Batch)\n\n\nclass Cache:\n    def __init__(self):\n        self.cache: Dict[int, B] = {}\n\n    def pop(self, batch_id: int) -> Optional[B]:\n        return self.cache.pop(batch_id, None)\n\n    def set(self, entry: B):\n        if entry is not None:\n            self.cache[entry.batch_id] = entry\n\n    def delete(self, batch_id: int):\n        batch = self.pop(batch_id)\n        if batch is not None:\n            del batch\n        if torch.cuda.is_available():\n            torch.cuda.empty_cache()\n\n    def clear(self):\n        keys = list(self.cache.keys())\n        for k in keys:\n            self.delete(k)\n\n    def __len__(self):\n        return len(self.cache.keys())\n"
  },
  {
    "path": "server/text_generation_server/cli.py",
    "content": "import os\nimport sys\nimport typer\n\nfrom pathlib import Path\nfrom loguru import logger\nfrom typing import Optional\nfrom enum import Enum\nfrom huggingface_hub import hf_hub_download\nfrom text_generation_server.utils.adapter import parse_lora_adapters\n\n# Dummy change should cache hit.\n\n\napp = typer.Typer()\n\n\nclass Quantization(str, Enum):\n    bitsandbytes = \"bitsandbytes\"\n    bitsandbytes_nf4 = \"bitsandbytes-nf4\"\n    bitsandbytes_fp4 = \"bitsandbytes-fp4\"\n    gptq = \"gptq\"\n    awq = \"awq\"\n    compressed_tensors = \"compressed-tensors\"\n    eetq = \"eetq\"\n    exl2 = \"exl2\"\n    fp8 = \"fp8\"\n    marlin = \"marlin\"\n\n\nclass Dtype(str, Enum):\n    float16 = \"float16\"\n    bloat16 = \"bfloat16\"\n\n\nclass KVCacheDtype(str, Enum):\n    fp8_e4m3fn = \"fp8_e4m3fn\"\n    fp8_e5m2 = \"fp8_e5m2\"\n\n\n@app.command()\ndef serve(\n    model_id: str,\n    revision: Optional[str] = None,\n    sharded: bool = False,\n    quantize: Optional[Quantization] = None,\n    speculate: Optional[int] = None,\n    dtype: Optional[Dtype] = None,\n    kv_cache_dtype: Optional[KVCacheDtype] = None,\n    trust_remote_code: bool = False,\n    uds_path: Path = \"/tmp/text-generation-server\",\n    logger_level: str = \"INFO\",\n    json_output: bool = False,\n    otlp_endpoint: Optional[str] = None,\n    otlp_service_name: str = \"text-generation-inference.server\",\n    max_input_tokens: Optional[int] = None,\n):\n    if sharded:\n        assert (\n            os.getenv(\"RANK\", None) is not None\n        ), \"RANK must be set when sharded is True\"\n        assert (\n            os.getenv(\"WORLD_SIZE\", None) is not None\n        ), \"WORLD_SIZE must be set when sharded is True\"\n        assert (\n            os.getenv(\"MASTER_ADDR\", None) is not None\n        ), \"MASTER_ADDR must be set when sharded is True\"\n        assert (\n            os.getenv(\"MASTER_PORT\", None) is not None\n        ), \"MASTER_PORT must be set when 
sharded is True\"\n\n    # Remove default handler\n    logger.remove()\n    logger.add(\n        sys.stdout,\n        format=\"{message}\",\n        filter=\"text_generation_server\",\n        level=logger_level,\n        serialize=json_output,\n        backtrace=True,\n        diagnose=False,\n    )\n\n    # Import here after the logger is added to log potential import exceptions\n    from text_generation_server import server\n    from text_generation_server.tracing import setup_tracing\n\n    # Setup OpenTelemetry distributed tracing\n    if otlp_endpoint is not None:\n        setup_tracing(otlp_service_name=otlp_service_name, otlp_endpoint=otlp_endpoint)\n\n    lora_adapters = parse_lora_adapters(os.getenv(\"LORA_ADAPTERS\"))\n\n    # TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled\n    # and warn the user\n    if lora_adapters:\n        logger.warning(\"LoRA adapters enabled (experimental feature).\")\n\n        if \"CUDA_GRAPHS\" in os.environ:\n            logger.warning(\n                \"LoRA adapters incompatible with CUDA Graphs. 
Disabling CUDA Graphs.\"\n            )\n            global CUDA_GRAPHS\n            CUDA_GRAPHS = None\n\n    # Downgrade enum into str for easier management later on\n    quantize = None if quantize is None else quantize.value\n    dtype = None if dtype is None else dtype.value\n    kv_cache_dtype = None if kv_cache_dtype is None else kv_cache_dtype.value\n    if dtype is not None and quantize not in {\n        None,\n        \"bitsandbytes\",\n        \"bitsandbytes-nf4\",\n        \"bitsandbytes-fp4\",\n    }:\n        raise RuntimeError(\n            \"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model.\"\n        )\n    server.serve(\n        model_id,\n        lora_adapters,\n        revision,\n        sharded,\n        quantize,\n        speculate,\n        dtype,\n        kv_cache_dtype,\n        trust_remote_code,\n        uds_path,\n        max_input_tokens,\n    )\n\n\n@app.command()\ndef download_weights(\n    model_id: str,\n    revision: Optional[str] = None,\n    extension: str = \".safetensors\",\n    auto_convert: bool = True,\n    logger_level: str = \"INFO\",\n    json_output: bool = False,\n    trust_remote_code: bool = False,\n    merge_lora: bool = False,\n):\n    # Remove default handler\n    logger.remove()\n    logger.add(\n        sys.stdout,\n        format=\"{message}\",\n        filter=\"text_generation_server\",\n        level=logger_level,\n        serialize=json_output,\n        backtrace=True,\n        diagnose=False,\n    )\n\n    # Import here after the logger is added to log potential import exceptions\n    from text_generation_server import utils\n\n    # Test if files were already download\n    try:\n        utils.weight_files(model_id, revision, extension)\n        logger.info(\"Files are already present on the host. 
\" \"Skipping download.\")\n        return\n    # Local files not found\n    except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):\n        pass\n\n    is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(\n        \"WEIGHTS_CACHE_OVERRIDE\", None\n    ) is not None\n\n    if not is_local_model:\n        # TODO: maybe reverse the default value of merge_lora?\n        # currently by default we don't merge the weights with the base model\n        if merge_lora:\n            try:\n                hf_hub_download(\n                    model_id, revision=revision, filename=\"adapter_config.json\"\n                )\n                utils.download_and_unload_peft(\n                    model_id, revision, trust_remote_code=trust_remote_code\n                )\n                is_local_model = True\n                utils.weight_files(model_id, revision, extension)\n                return\n            except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):\n                pass\n        else:\n            try:\n                utils.peft.download_peft(\n                    model_id, revision, trust_remote_code=trust_remote_code\n                )\n            except Exception:\n                pass\n\n        try:\n            import json\n\n            config = hf_hub_download(\n                model_id, revision=revision, filename=\"config.json\"\n            )\n            with open(config, \"r\") as f:\n                config = json.load(f)\n\n            base_model_id = config.get(\"base_model_name_or_path\", None)\n            if base_model_id and base_model_id != model_id:\n                try:\n                    logger.info(f\"Downloading parent model {base_model_id}\")\n                    download_weights(\n                        model_id=base_model_id,\n                        revision=\"main\",\n                        extension=extension,\n                        
auto_convert=auto_convert,\n                        logger_level=logger_level,\n                        json_output=json_output,\n                        trust_remote_code=trust_remote_code,\n                    )\n                except Exception:\n                    pass\n        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):\n            pass\n\n        # Try to download weights from the hub\n        try:\n            filenames = utils.weight_hub_files(model_id, revision, extension)\n            utils.download_weights(filenames, model_id, revision)\n            # Successfully downloaded weights\n            return\n\n        # No weights found on the hub with this extension\n        except utils.EntryNotFoundError as e:\n            # Check if we want to automatically convert to safetensors or if we can use .bin weights instead\n            if not extension == \".safetensors\" or not auto_convert:\n                raise e\n\n    elif (Path(model_id) / \"adapter_config.json\").exists():\n        # Try to load as a local PEFT model\n        try:\n            utils.download_and_unload_peft(\n                model_id, revision, trust_remote_code=trust_remote_code\n            )\n            utils.weight_files(model_id, revision, extension)\n            return\n        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):\n            pass\n    elif (Path(model_id) / \"config.json\").exists():\n        # Try to load as a local Medusa model\n        try:\n            import json\n\n            config = Path(model_id) / \"config.json\"\n            with open(config, \"r\") as f:\n                config = json.load(f)\n\n            base_model_id = config.get(\"base_model_name_or_path\", None)\n            if base_model_id:\n                try:\n                    logger.info(f\"Downloading parent model {base_model_id}\")\n                    download_weights(\n                        model_id=base_model_id,\n                        
revision=\"main\",\n                        extension=extension,\n                        auto_convert=auto_convert,\n                        logger_level=logger_level,\n                        json_output=json_output,\n                        trust_remote_code=trust_remote_code,\n                    )\n                except Exception:\n                    pass\n        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):\n            pass\n\n    # Try to see if there are local pytorch weights\n    try:\n        # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE\n        try:\n            local_pt_files = utils.weight_files(model_id, revision, \".bin\")\n        except Exception:\n            local_pt_files = utils.weight_files(model_id, revision, \".pt\")\n\n    # No local pytorch weights\n    except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):\n        if extension == \".safetensors\":\n            logger.warning(\n                f\"No safetensors weights found for model {model_id} at revision {revision}. \"\n                f\"Downloading PyTorch weights.\"\n            )\n\n        # Try to see if there are pytorch weights on the hub\n        pt_filenames = utils.weight_hub_files(model_id, revision, \".bin\")\n        # Download pytorch weights\n        local_pt_files = utils.download_weights(pt_filenames, model_id, revision)\n\n    if auto_convert:\n        if not trust_remote_code:\n            logger.warning(\n                \"🚨🚨BREAKING CHANGE in 2.0🚨🚨: Safetensors conversion is disabled without `--trust-remote-code` because \"\n                \"Pickle files are unsafe and can essentially contain remote code execution!\"\n                \"Please check for more information here: https://huggingface.co/docs/text-generation-inference/basic_tutorials/safety\",\n            )\n\n        logger.warning(\n            f\"No safetensors weights found for model {model_id} at revision {revision}. 
\"\n            f\"Converting PyTorch weights to safetensors.\"\n        )\n\n        # Safetensors final filenames\n        local_st_files = [\n            p.parent / f\"{p.stem.lstrip('pytorch_')}.safetensors\"\n            for p in local_pt_files\n        ]\n        try:\n            import transformers\n            import json\n\n            if is_local_model:\n                config_filename = os.path.join(model_id, \"config.json\")\n            else:\n                config_filename = hf_hub_download(\n                    model_id, revision=revision, filename=\"config.json\"\n                )\n            with open(config_filename, \"r\") as f:\n                config = json.load(f)\n            architecture = config[\"architectures\"][0]\n\n            class_ = getattr(transformers, architecture)\n\n            # Name for this varible depends on transformers version.\n            discard_names = getattr(class_, \"_tied_weights_keys\", [])\n\n        except Exception:\n            discard_names = []\n        # Convert pytorch weights to safetensors\n        utils.convert_files(local_pt_files, local_st_files, discard_names)\n\n\n@app.command()\ndef quantize(\n    model_id: str,\n    output_dir: str,\n    revision: Optional[str] = None,\n    logger_level: str = \"INFO\",\n    json_output: bool = False,\n    trust_remote_code: bool = False,\n    upload_to_model_id: Optional[str] = None,\n    percdamp: float = 0.01,\n    act_order: bool = False,\n    groupsize: int = 128,\n):\n    if revision is None:\n        revision = \"main\"\n    download_weights(\n        model_id=model_id,\n        revision=revision,\n        logger_level=logger_level,\n        json_output=json_output,\n    )\n    from text_generation_server.layers.gptq.quantize import quantize\n\n    quantize(\n        model_id=model_id,\n        bits=4,\n        groupsize=groupsize,\n        output_dir=output_dir,\n        revision=revision,\n        trust_remote_code=trust_remote_code,\n        
upload_to_model_id=upload_to_model_id,\n        percdamp=percdamp,\n        act_order=act_order,\n        sym=True,\n    )\n\n\nif __name__ == \"__main__\":\n    app()\n"
  },
  {
    "path": "server/text_generation_server/interceptor.py",
    "content": "import torch\nimport grpc\n\nfrom google.rpc import status_pb2, code_pb2\nfrom grpc_status import rpc_status\nfrom grpc_interceptor.server import AsyncServerInterceptor\nfrom loguru import logger\nfrom typing import Callable, Any\n\n\nclass ExceptionInterceptor(AsyncServerInterceptor):\n    def __init__(self, shutdown_callback):\n        self.shutdown_callback = shutdown_callback\n\n    async def intercept(\n        self,\n        method: Callable,\n        request_or_iterator: Any,\n        context: grpc.ServicerContext,\n        method_name: str,\n    ) -> Any:\n        try:\n            response = method(request_or_iterator, context)\n            return await response\n        except Exception as err:\n            method_name = method_name.split(\"/\")[-1]\n            logger.exception(f\"Method {method_name} encountered an error.\")\n\n            # Runtime Error cannot be recovered from\n            if isinstance(err, RuntimeError):\n                self.shutdown_callback()\n\n            if torch.cuda.is_available():\n                torch.cuda.empty_cache()\n\n            await context.abort_with_status(\n                rpc_status.to_status(\n                    status_pb2.Status(code=code_pb2.INTERNAL, message=str(err))\n                )\n            )\n"
  },
  {
    "path": "server/text_generation_server/layers/__init__.py",
    "content": "from text_generation_server.layers.tensor_parallel import (\n    TensorParallelColumnLinear,\n    TensorParallelRowLinear,\n    TensorParallelEmbedding,\n)\nfrom text_generation_server.layers.linear import (\n    get_linear,\n    FastLinear,\n)\nfrom text_generation_server.layers.speculative import SpeculativeHead\n\n# Just to add the `load` methods.\nfrom text_generation_server.layers.layernorm import load_layer_norm\nfrom text_generation_server.layers.conv import load_conv2d\n\nfrom text_generation_server.layers.lora import (\n    LoraLinear,\n    TensorParallelMultiAdapterLinear,\n    TensorParallelAdapterRowLinear,\n)\n\n__all__ = [\n    \"get_linear\",\n    \"FastLinear\",\n    \"TensorParallelColumnLinear\",\n    \"TensorParallelRowLinear\",\n    \"TensorParallelEmbedding\",\n    \"SpeculativeHead\",\n    \"LoraLinear\",\n    \"TensorParallelMultiAdapterLinear\",\n    \"TensorParallelAdapterRowLinear\",\n    \"load_layer_norm\",\n    \"load_conv2d\",\n]\n"
  },
  {
    "path": "server/text_generation_server/layers/attention/__init__.py",
    "content": "import os\n\nfrom text_generation_server.utils.import_utils import SYSTEM\n\nfrom .common import Seqlen\n\nif os.getenv(\"USE_FLASH_ATTENTION\", \"\").lower() == \"false\":\n    raise ImportError(\"`USE_FLASH_ATTENTION` is false.\")\nif SYSTEM == \"cuda\":\n    from .cuda import (\n        SUPPORTS_WINDOWING,\n        attention,\n        paged_attention,\n    )\nelif SYSTEM == \"rocm\":\n    from .rocm import (\n        SUPPORTS_WINDOWING,\n        attention,\n        paged_attention,\n    )\nelif SYSTEM == \"ipex\":\n    from .ipex import (\n        SUPPORTS_WINDOWING,\n        attention,\n        paged_attention,\n    )\nelse:\n    raise ImportError(f\"System {SYSTEM} doesn't support flash/paged attention\")\n\n# KVCache needs `reshape_and_cache`, so ensure that it is defined already.\nfrom .kv_cache import KVCache, get_kv_scales\n\n__all__ = [\n    \"attention\",\n    \"get_kv_scales\",\n    \"paged_attention\",\n    \"SUPPORTS_WINDOWING\",\n    \"KVCache\",\n    \"Seqlen\",\n]\n"
  },
  {
    "path": "server/text_generation_server/layers/attention/common.py",
    "content": "from dataclasses import dataclass\nimport torch\nfrom typing import Optional\n\n\n@dataclass\nclass Seqlen:\n    input_lengths: torch.Tensor\n    cache_lengths: torch.Tensor\n    cu_seqlen_q: Optional[torch.Tensor]\n    cu_seqlen_k: Optional[torch.Tensor]\n    max_q: int\n    max_k: int\n\n    def __init__(\n        self,\n        input_lengths,\n        cache_lengths,\n        cu_seqlen_q=None,\n        max_q=None,\n        max_k=None,\n    ):\n        self.input_lengths = input_lengths\n        self.cache_lengths = cache_lengths\n        device = self.input_lengths.device\n        shape = self.input_lengths.shape\n        if cu_seqlen_q is None:\n            cu_seqlen_q = torch.arange(\n                shape[0] + 1,\n                device=device,\n                dtype=torch.int32,\n            )\n            max_q = 1\n        else:\n            assert max_q is not None\n        assert max_k is not None\n        cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)\n\n        # cuda graphs don't like this and this is necessary to clamp within mistral\n        # Although FA2 might not want the clamping\n        # cu_seqlen_k[0] = 0\n        total = self.input_lengths + self.cache_lengths\n        torch.cumsum(total, -1, out=cu_seqlen_k[1:])\n\n        self.cu_seqlen_q = cu_seqlen_q\n        self.cu_seqlen_k = cu_seqlen_k\n        self.max_q = max_q\n        self.max_k = max_k\n\n    def clamp(self, max):\n        # Flash decoding doesn't need to clamp\n        return self\n"
  },
  {
    "path": "server/text_generation_server/layers/attention/cuda.py",
    "content": "import torch\nfrom text_generation_server.layers.attention.kv_cache import KVCache, KVScales\nfrom text_generation_server.utils.import_utils import SYSTEM\nfrom text_generation_server.utils.kernels import load_kernel\nfrom text_generation_server.models.globals import (\n    ATTENTION,\n    BLOCK_SIZE,\n)\nfrom text_generation_server.layers.attention import Seqlen\nfrom typing import Optional\n\n\nmajor, minor = torch.cuda.get_device_capability()\nis_sm75 = major == 7 and minor == 5\n_PARTITION_SIZE = 512\n\nif SYSTEM == \"cuda\":\n    try:\n        paged_attention_kernels = load_kernel(\n            module=\"paged_attention\", repo_id=\"kernels-community/paged-attention\"\n        )\n    except Exception as e:\n        raise ImportError(\n            f\"Could not import attention kernels. Make sure your installation is correct. Complete error: {e}\"\n        )\nelse:\n    paged_attention_kernels = None\n\n\ndef paged_attention(\n    query: torch.Tensor,\n    kv_cache: KVCache,\n    kv_head_mapping: torch.Tensor,\n    softmax_scale: float,\n    block_tables: torch.Tensor,\n    seqlen: Seqlen,\n    max_s: int,\n    *,\n    kv_scales: KVScales,\n    softcap: Optional[float] = None,\n    window_size_left: Optional[int] = -1,\n):\n    # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py\n    # Copyright 2023 The vLLM team. 
All rights\n    # reserved.\n    #\n    # Licensed under the Apache License, Version 2.0 (the \"License\");\n    # you may not use this file except in compliance with the License.\n    # You may obtain a copy of the License at\n    #\n    #     http://www.apache.org/licenses/LICENSE-2.0\n    #\n    # Unless required by applicable law or agreed to in writing, software\n    # distributed under the License is distributed on an \"AS IS\" BASIS,\n    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n    # See the License for the specific language governing permissions and\n    # limitations under the License.\n    #\n\n    # value_cache => [num_blocks, num_heads, head_size, block_size]\n    # block_size = value_cache.shape[3]\n    block_size = BLOCK_SIZE\n    num_seqs, num_heads, head_size = query.shape\n    max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE\n\n    can_scale = kv_cache.can_scale(kv_scales)\n\n    # NOTE(woosuk): We use a simple heuristic to decide whether to use\n    # PagedAttention V1 or V2. If the number of partitions is 1, we use\n    # V1 to avoid the overhead of reduction. 
Also, if the number of\n    # sequences or heads is large, we use V1 since there is enough work\n    # to parallelize.\n    if ATTENTION == \"flashinfer\":\n        from text_generation_server.layers.attention.flashinfer import decode_state\n\n        return decode_state.get().forward(\n            query,\n            paged_kv_cache=(kv_cache.key, kv_cache.value),\n            logits_soft_cap=softcap,\n            sm_scale=softmax_scale,\n            k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,\n            v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,\n            window_left=window_size_left,\n        )\n    elif ATTENTION == \"flashdecoding\":\n        max_q = 1\n        max_k = max_s\n        import flash_attn_2_cuda\n\n        window_size_right = -1 if window_size_left == -1 else 0\n\n        # TODO fixme when flash contains the fix.\n        # Number of splits is not correctly handled\n        # by the current path\n        # https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577\n        # This fails because we're using causal, therefore window_right is set to 0 and the split logic is never applied.\n        if softcap is None:\n            softcap = 0.0\n        out = flash_attn_2_cuda.varlen_fwd(\n            query,\n            kv_cache.key,\n            kv_cache.value,\n            None,\n            seqlen.cu_seqlen_q,\n            seqlen.cu_seqlen_k,\n            None,  # pad_k\n            None,\n            block_tables,\n            None,\n            max_q,\n            max_k,\n            0.0,  # dropout\n            softmax_scale,\n            False,  # zero_tensors\n            True,  # causal\n            window_size_left,  # Window_left\n            window_size_right,  # Window right\n            softcap,\n            False,  # return softmax\n            None,  # generator\n        )\n        return out[0]\n    else:\n        if softcap is not None:\n  
          raise RuntimeError(\"Paged attention doesn't support softcapping\")\n        input_lengths = seqlen.input_lengths + seqlen.cache_lengths\n\n        out = torch.empty_like(query)\n\n        kv_cache_dtype = \"fp8\" if kv_cache.dtype == torch.float8_e4m3fn else \"auto\"\n\n        use_v1 = max_s <= 8192 and (\n            max_num_partitions == 1 or num_seqs * num_heads > 512\n        )\n        if use_v1:\n            paged_attention_kernels.paged_attention_v1(\n                out,\n                query,\n                kv_cache.key,\n                kv_cache.value,\n                kv_cache.key.shape[1],\n                softmax_scale,\n                block_tables,\n                input_lengths,\n                block_size,\n                max_s,\n                None,\n                kv_cache_dtype,\n                torch.tensor(kv_scales.key_scale_cpu if can_scale else 1.0),\n                torch.tensor(kv_scales.value_scale_cpu if can_scale else 1.0),\n            )\n        else:\n            # Run PagedAttention V2.\n            assert _PARTITION_SIZE % block_size == 0\n            tmp_output = torch.empty(\n                size=(num_seqs, num_heads, max_num_partitions, head_size),\n                dtype=out.dtype,\n                device=out.device,\n            )\n            exp_sums = torch.empty(\n                size=(num_seqs, num_heads, max_num_partitions),\n                dtype=torch.float32,\n                device=out.device,\n            )\n            max_logits = torch.empty_like(exp_sums)\n\n            paged_attention_kernels.paged_attention_v2(\n                out,\n                exp_sums,\n                max_logits,\n                tmp_output,\n                query,\n                kv_cache.key,\n                kv_cache.value,\n                kv_cache.key.shape[1],\n                softmax_scale,\n                block_tables,\n                input_lengths,\n                block_size,\n                max_s,\n     
           None,\n                kv_cache_dtype,\n                torch.tensor(kv_scales.key_scale_cpu if can_scale else 1.0),\n                torch.tensor(kv_scales.value_scale_cpu if can_scale else 1.0),\n            )\n    return out\n\n\ntry:\n    is_ampere_or_newer = major >= 8 and minor >= 0\n    if not is_ampere_or_newer:\n        raise ImportError(\"FlashAttention only supports Ampere GPUs or newer.\")\n\n    import flash_attn_2_cuda\n\n    V2 = True\nexcept ImportError:\n    try:\n        import flash_attn_cuda\n\n        V2 = False\n    except ImportError as e:\n        if major >= 8:\n            architecture_suffix = f\"-{SYSTEM}\"\n            raise ImportError(\n                \"Flash Attention V2 is not installed.\\n\"\n                \"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) \"\n                f\"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`\"\n            )\n        elif is_sm75:\n            raise ImportError(\n                \"Flash Attention is not installed.\\n\"\n                \"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) \"\n                \"or install flash attention with `cd server && make install install-flash-attention`\"\n            ) from e\n        else:\n            raise ImportError(\n                f\"GPU with CUDA capability {major} {minor} is not supported\"\n            ) from e\n\n\nif ATTENTION == \"flashdecoding\" and not V2:\n    raise ValueError(\"Flash decoding requires Flash Attention V2\")\n\nSUPPORTS_WINDOWING = V2\n\n\ndef attention(\n    *,\n    query: torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    kv_cache: KVCache,\n    kv_scales: KVScales,\n    seqlen: Seqlen,\n    block_tables: torch.Tensor,\n    softmax_scale: float,\n    window_size_left: int = -1,\n    causal: bool = True,\n    softcap: Optional[float] = None,\n):\n    can_scale = 
kv_cache.can_scale(kv_scales)\n\n    if ATTENTION == \"flashinfer\":\n        from text_generation_server.layers.attention.flashinfer import (\n            prefill_with_paged_kv_state,\n        )\n\n        if softcap is None:\n            softcap = 0.0\n\n        return prefill_with_paged_kv_state.get().forward(\n            query,\n            causal=causal,\n            paged_kv_cache=(kv_cache.key, kv_cache.value),\n            logits_soft_cap=softcap,\n            sm_scale=softmax_scale,\n            k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,\n            v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,\n            window_left=window_size_left,\n        )\n\n    # If we are using flashdecoding or paged, we always use flash-attn for\n    # the prefill. We have to branch on whether we use flash-attn v1 or v2.\n    elif V2:\n        out = torch.empty_like(query)\n        if window_size_left <= 0 and window_size_left != -1:\n            raise ValueError(\"`window_size_left` must be > 0 or -1\")\n\n        if softcap is None:\n            softcap = 0.0\n\n        return flash_attn_2_cuda.varlen_fwd(\n            query,\n            # flashdecoding: pass the KV caches, paged: pass the KV.\n            kv_cache.key if ATTENTION == \"flashdecoding\" else key,\n            kv_cache.value if ATTENTION == \"flashdecoding\" else value,\n            out,\n            seqlen.cu_seqlen_q,\n            seqlen.cu_seqlen_k,\n            None,\n            None,\n            block_tables if ATTENTION == \"flashdecoding\" else None,\n            None,\n            seqlen.max_q,\n            seqlen.max_k,\n            0.0,\n            softmax_scale,\n            False,\n            causal,\n            window_size_left,\n            0,\n            softcap,\n            False,\n            None,\n        )[0]\n\n    else:\n        if window_size_left != -1:\n            raise NotImplementedError(\n                \"window_size_left is only available with 
flash attn v2\"\n            )\n        if softcap is not None:\n            raise NotImplementedError(\"softcap is not available in flash attn v1\")\n\n        # Flash attention v1 requires q, k and v to have the same number of heads\n        if key.shape[1] != query.shape[1]:\n            # MQA expand\n            if key.shape[1] == 1:\n                key = key.expand(-1, query.shape[1], -1)\n            # Grouped attention reshape\n            else:\n                original_shape = key.shape\n                key = (\n                    key.unsqueeze(2)\n                    .expand(-1, -1, query.shape[1] // key.shape[1], -1)\n                    .reshape(original_shape[0], -1, original_shape[2])\n                )\n        if value.shape[1] != query.shape[1]:\n            # MQA expand\n            if value.shape[1] == 1:\n                value = value.expand(-1, query.shape[1], -1)\n            # Grouped attention reshape\n            else:\n                original_shape = value.shape\n                value = (\n                    value.unsqueeze(2)\n                    .expand(-1, -1, query.shape[1] // value.shape[1], -1)\n                    .reshape(original_shape[0], -1, original_shape[2])\n                )\n\n        out = torch.empty_like(query)\n        flash_attn_cuda.fwd(\n            query,\n            key,\n            value,\n            out,\n            seqlen.cu_seqlen_q,\n            seqlen.cu_seqlen_q,\n            seqlen.max_q,\n            seqlen.max_k,\n            0.0,\n            softmax_scale,\n            False,\n            causal,\n            False,\n            0,\n            None,\n        )\n        return out\n\n\n__all__ = [\n    \"SUPPORTS_WINDOWING\",\n    \"attention\",\n    \"paged_attention\",\n]\n"
  },
  {
    "path": "server/text_generation_server/layers/attention/flash_attn_triton.py",
    "content": "#!/usr/bin/env python\n\"\"\"\nFused Attention\n===============\n\nThis is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao\n(https://tridao.me/publications/flash2/flash2.pdf)\nCredits: OpenAI kernel team, AMD ML Frameworks Triton team\n\nFeatures supported:\n\n1) Fwd with causal masking\n2) Any sequence lengths without padding (currently fwd kernel only)\n3) Support for different sequence lengths for q and k\n4) Nested tensor API currently does not support dropout or bias.\n\nNot currently supported:\n\n1) Non power of two head dims\n\n\"\"\"\n\nimport torch\nimport triton\nimport triton.language as tl\n\ntorch_dtype: tl.constexpr = torch.float16\n\n\n@triton.jit\ndef cdiv_fn(x, y):\n    return (x + y - 1) // y\n\n\n@triton.jit\ndef max_fn(x, y):\n    return tl.math.max(x, y)\n\n\n@triton.jit\ndef dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):\n    ms = tl.arange(0, m)\n    ns = tl.arange(0, n)\n    return philox_offset + ms[:, None] * stride + ns[None, :]\n\n\n@triton.jit\ndef dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):\n    rng_offsets = dropout_offsets(\n        philox_seed, philox_offset, dropout_p, m, n, stride\n    ).to(tl.uint32)\n    # TODO: use tl.randint for better performance\n    return tl.rand(philox_seed, rng_offsets)\n\n\n@triton.jit\ndef dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):\n    rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride)\n    rng_keep = rng_output > dropout_p\n    return rng_keep\n\n\n@triton.jit\ndef load_fn(block_ptr, first, second, pad):\n    if first and second:\n        tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)\n    elif first:\n        tensor = tl.load(block_ptr, boundary_check=(0,), padding_option=pad)\n    elif second:\n        tensor = tl.load(block_ptr, boundary_check=(1,), padding_option=pad)\n    else:\n        tensor = tl.load(block_ptr)\n    return 
tensor\n\n\n@triton.jit\ndef _attn_fwd_inner(\n    acc,\n    l_i,\n    m_i,\n    q,\n    K_block_ptr,\n    V_block_ptr,\n    start_m,\n    actual_seqlen_k,\n    dropout_p,\n    philox_seed,\n    batch_philox_offset,\n    encoded_softmax_block_ptr,\n    block_min,\n    block_max,\n    offs_n_causal,\n    masked_blocks,\n    n_extra_tokens,\n    bias_ptr,\n    IS_CAUSAL: tl.constexpr,\n    BLOCK_M: tl.constexpr,\n    BLOCK_DMODEL: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    OFFS_M: tl.constexpr,\n    OFFS_N: tl.constexpr,\n    PRE_LOAD_V: tl.constexpr,\n    MASK_STEPS: tl.constexpr,\n    ENABLE_DROPOUT: tl.constexpr,\n    RETURN_ENCODED_SOFTMAX: tl.constexpr,\n    PADDED_HEAD: tl.constexpr,\n):\n    # loop over k, v, and update accumulator\n    for start_n in range(block_min, block_max, BLOCK_N):\n        # For padded blocks, we will overrun the tensor size if\n        # we load all BLOCK_N. For others, the blocks are all within range.\n        k = load_fn(\n            K_block_ptr,\n            PADDED_HEAD,\n            MASK_STEPS and (n_extra_tokens != 0),\n            \"zero\",\n        )\n        if PRE_LOAD_V:\n            v = load_fn(\n                V_block_ptr,\n                MASK_STEPS and (n_extra_tokens != 0),\n                PADDED_HEAD,\n                \"zero\",\n            )\n        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n        # We start from end of seqlen_k so only the first iteration would need\n        # to be checked for padding if it is not a multiple of block_n\n        # TODO: This can be optimized to only be true for the padded block.\n        if MASK_STEPS:  # noqa: SIM102\n            # If this is the last block / iteration, we want to\n            # mask if the sequence length is not a multiple of block size\n            # a solution is to always do BLOCK_M // BLOCK_N + 1 steps\n            # if not is_modulo_mn. 
last step might get wasted but that is okay.\n            # check if this masking works for that case.\n            if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):\n                boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32)\n                size_n = start_n + OFFS_N[None, :]\n                mask = size_n < boundary_m[:, None]\n                qk = tl.where(mask, qk, float(\"-inf\"))\n        if IS_CAUSAL:\n            causal_boundary = start_n + offs_n_causal\n            causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]\n            qk = tl.where(causal_mask, qk, float(\"-inf\"))\n        # -- compute qk ----\n        qk += tl.dot(q, k)\n        if bias_ptr is not None:\n            bias = load_fn(\n                bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), \"zero\"\n            )\n            # While bias is added after multiplying qk with sm_scale, our\n            # optimization to use 2^x instead of e^x results in an additional\n            # scale factor of log2(e) which we must also multiply the bias with.\n            qk += bias * 1.44269504089\n        m_ij = tl.maximum(m_i, tl.max(qk, 1))\n        qk = qk - m_ij[:, None]\n        p = tl.math.exp2(qk)\n\n        # CAVEAT: Must update l_ij before applying dropout\n        l_ij = tl.sum(p, 1)\n        if ENABLE_DROPOUT:\n            philox_offset = (\n                batch_philox_offset\n                + start_m * BLOCK_M * actual_seqlen_k\n                + start_n\n                - BLOCK_N\n            )\n            keep = dropout_mask(\n                philox_seed,\n                philox_offset,\n                dropout_p,\n                BLOCK_M,\n                BLOCK_N,\n                actual_seqlen_k,\n            )\n            if RETURN_ENCODED_SOFTMAX:\n                tl.store(\n                    encoded_softmax_block_ptr,\n                    tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty),\n                )\n  
          p = tl.where(keep, p, 0.0)\n        elif RETURN_ENCODED_SOFTMAX:\n            tl.store(\n                encoded_softmax_block_ptr,\n                p.to(encoded_softmax_block_ptr.type.element_ty),\n            )\n        # -- update output accumulator --\n        alpha = tl.math.exp2(m_i - m_ij)\n        acc = acc * alpha[:, None]\n        if not PRE_LOAD_V:\n            v = load_fn(\n                V_block_ptr,\n                MASK_STEPS and (n_extra_tokens != 0),\n                PADDED_HEAD,\n                \"zero\",\n            )\n        # -- update m_i and l_i\n        l_i = l_i * alpha + l_ij\n        # update m_i and l_i\n        m_i = m_ij\n        acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)\n        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))\n        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n        if bias_ptr is not None:\n            bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))\n        if RETURN_ENCODED_SOFTMAX:\n            encoded_softmax_block_ptr = tl.advance(\n                encoded_softmax_block_ptr, (0, BLOCK_N)\n            )\n    return acc, l_i, m_i\n\n\n@triton.autotune(\n    configs=[\n        triton.Config(\n            {\n                \"BLOCK_M\": 256,\n                \"BLOCK_N\": 64,\n                \"waves_per_eu\": 2,\n                \"PRE_LOAD_V\": False,\n            },\n            num_stages=1,\n            num_warps=8,\n        ),\n        triton.Config(\n            {\n                \"BLOCK_M\": 128,\n                \"BLOCK_N\": 128,\n                \"waves_per_eu\": 2,\n                \"PRE_LOAD_V\": False,\n            },\n            num_stages=1,\n            num_warps=4,\n        ),\n        triton.Config(\n            {\n                \"BLOCK_M\": 256,\n                \"BLOCK_N\": 128,\n                \"waves_per_eu\": 2,\n                \"PRE_LOAD_V\": False,\n            },\n            num_stages=1,\n            num_warps=8,\n        ),\n        
triton.Config(\n            {\n                \"BLOCK_M\": 128,\n                \"BLOCK_N\": 64,\n                \"waves_per_eu\": 3,\n                \"PRE_LOAD_V\": True,\n            },\n            num_stages=1,\n            num_warps=4,\n        ),\n        triton.Config(\n            {\n                \"BLOCK_M\": 128,\n                \"BLOCK_N\": 64,\n                \"waves_per_eu\": 3,\n                \"PRE_LOAD_V\": False,\n            },\n            num_stages=1,\n            num_warps=4,\n        ),\n        triton.Config(\n            {\n                \"BLOCK_M\": 64,\n                \"BLOCK_N\": 64,\n                \"waves_per_eu\": 4,\n                \"PRE_LOAD_V\": False,\n            },\n            num_stages=1,\n            num_warps=8,\n        ),\n        triton.Config(\n            {\n                \"BLOCK_M\": 32,\n                \"BLOCK_N\": 32,\n                \"waves_per_eu\": 4,\n                \"PRE_LOAD_V\": False,\n            },\n            num_stages=1,\n            num_warps=8,\n        ),\n        # TODO: This config fails with head_size not pow2 with data mismatches.\n        #    triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,\n        #                   'PRE_LOAD_V': False}, num_stages=1, num_warps=4),\n        triton.Config(\n            {\n                \"BLOCK_M\": 16,\n                \"BLOCK_N\": 16,\n                \"waves_per_eu\": 1,\n                \"PRE_LOAD_V\": False,\n            },\n            num_stages=1,\n            num_warps=4,\n        ),\n        triton.Config(\n            {\n                \"BLOCK_M\": 128,\n                \"BLOCK_N\": 64,\n                \"waves_per_eu\": 1,\n                \"PRE_LOAD_V\": False,\n            },\n            num_stages=1,\n            num_warps=4,\n        ),\n    ],\n    key=[\"IS_CAUSAL\", \"dropout_p\", \"BLOCK_DMODEL\"],\n)\n@triton.jit\ndef attn_fwd(\n    Q,\n    K,\n    V,\n    bias,\n    sm_scale,\n    L,\n    Out,\n    
stride_qz,\n    stride_qh,\n    stride_qm,\n    stride_qk,\n    stride_kz,\n    stride_kh,\n    stride_kn,\n    stride_kk,\n    stride_vz,\n    stride_vh,\n    stride_vk,\n    stride_vn,\n    stride_oz,\n    stride_oh,\n    stride_om,\n    stride_on,\n    stride_bz,\n    stride_bh,\n    stride_bm,\n    stride_bn,\n    cu_seqlens_q,\n    cu_seqlens_k,\n    dropout_p,\n    philox_seed,\n    philox_offset_base,\n    encoded_softmax,\n    HQ: tl.constexpr,\n    HK: tl.constexpr,\n    ACTUAL_BLOCK_DMODEL: tl.constexpr,\n    MAX_SEQLENS_Q: tl.constexpr,\n    MAX_SEQLENS_K: tl.constexpr,\n    VARLEN: tl.constexpr,\n    IS_CAUSAL: tl.constexpr,\n    BLOCK_M: tl.constexpr,\n    BLOCK_DMODEL: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    PRE_LOAD_V: tl.constexpr,\n    BIAS_TYPE: tl.constexpr,\n    ENABLE_DROPOUT: tl.constexpr,\n    RETURN_ENCODED_SOFTMAX: tl.constexpr,\n):\n    start_m = tl.program_id(0)\n    off_h_q = tl.program_id(1)\n    off_z = tl.program_id(2)\n    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n    offs_n = tl.arange(0, BLOCK_N)\n    if VARLEN:\n        cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)\n        cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)\n        seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start\n        # We have a one-size-fits-all grid in id(0). 
Some seqlens might be too\n        # small for all start_m so for those we return early.\n        if start_m * BLOCK_M > seqlen_q:\n            return\n        cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)\n        cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)\n        seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start\n    else:\n        cu_seqlens_q_start = 0\n        cu_seqlens_k_start = 0\n        seqlen_q = MAX_SEQLENS_Q\n        seqlen_k = MAX_SEQLENS_K\n\n    # Now we compute whether we need to exit early due to causal masking.\n    # This is because for seqlen_q > seqlen_k, M rows of the attn scores\n    # are completely masked, resulting in 0s written to the output, and\n    # inf written to LSE. We don't need to do any GEMMs in this case.\n    # This block of code determines what N is, and if this WG is operating\n    # on those M rows.\n    n_blocks = cdiv_fn(seqlen_k, BLOCK_N)\n    if IS_CAUSAL:\n        # If seqlen_q == seqlen_k, the attn scores are a square matrix.\n        # If seqlen_q != seqlen_k, attn scores are rectangular which means\n        # the causal mask boundary is bottom right aligned, and ends at either\n        # the top edge (seqlen_q < seqlen_k) or left edge.\n        # This captures the decrease in n_blocks if we have a rectangular attn\n        # matrix\n        n_blocks_seqlen = cdiv_fn(\n            (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N\n        )\n        # This is what adjusts the block_max for the current WG, only\n        # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks\n        n_blocks = min(n_blocks, n_blocks_seqlen)\n        # If we have no blocks after adjusting for seqlen deltas, this WG is\n        # part of the blocks that are all 0. 
We exit early.\n        if n_blocks <= 0:\n            o_offset = (\n                off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh\n            )\n            O_block_ptr = tl.make_block_ptr(\n                base=Out + o_offset,\n                shape=(seqlen_q, BLOCK_DMODEL),\n                strides=(stride_om, stride_on),\n                offsets=(start_m * BLOCK_M, 0),\n                block_shape=(BLOCK_M, BLOCK_DMODEL),\n                order=(1, 0),\n            )\n            acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)\n            # We still need to write 0s to the result\n            # tl.store(O_block_ptr,\n            # acc.to(Out.type.element_ty), boundary_check=(0,1))\n            # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q\n            #          + offs_m\n            # We store inf to LSE, not -inf because in the bwd pass,\n            # we subtract this\n            # from qk which makes it -inf, such that exp(qk - inf) = 0\n            # for these masked blocks.\n            # l = tl.full([BLOCK_M], value=float(\"inf\"), dtype=tl.float32)\n            # tl.store(l_ptrs, l)\n            # TODO: Should dropout and return encoded softmax be handled here?\n            return\n\n    # If MQA / GQA, set the K and V head offsets appropriately.\n    GROUP_SIZE: tl.constexpr = HQ // HK\n    if GROUP_SIZE != 1:\n        off_h_k = off_h_q // GROUP_SIZE\n    else:\n        off_h_k = off_h_q\n\n    n_extra_tokens = 0\n    if seqlen_k < BLOCK_N:\n        n_extra_tokens = BLOCK_N - seqlen_k\n    elif seqlen_k % BLOCK_N:\n        n_extra_tokens = seqlen_k % BLOCK_N\n    PADDED_HEAD: tl.constexpr = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL\n\n    # Compute pointers for all the tensors used in this kernel.\n    q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm\n    Q_block_ptr = tl.make_block_ptr(\n        base=Q + q_offset,\n        shape=(seqlen_q, 
ACTUAL_BLOCK_DMODEL),\n        strides=(stride_qm, stride_qk),\n        offsets=(start_m * BLOCK_M, 0),\n        block_shape=(BLOCK_M, BLOCK_DMODEL),\n        order=(1, 0),\n    )\n    k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn\n    K_block_ptr = tl.make_block_ptr(\n        base=K + k_offset,\n        shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),\n        strides=(stride_kk, stride_kn),\n        offsets=(0, 0),\n        block_shape=(BLOCK_DMODEL, BLOCK_N),\n        order=(0, 1),\n    )\n    v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk\n    V_block_ptr = tl.make_block_ptr(\n        base=V + v_offset,\n        shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),\n        strides=(stride_vk, stride_vn),\n        offsets=(0, 0),\n        block_shape=(BLOCK_N, BLOCK_DMODEL),\n        order=(1, 0),\n    )\n    if BIAS_TYPE != 0:\n        bias_ptr = tl.make_block_ptr(\n            base=bias + off_h_q * stride_bh,\n            shape=(seqlen_q, seqlen_k),\n            strides=(stride_bm, stride_bn),\n            offsets=(start_m * BLOCK_M, 0),\n            block_shape=(BLOCK_M, BLOCK_N),\n            order=(1, 0),\n        )\n    else:\n        bias_ptr = None\n    if ENABLE_DROPOUT:\n        batch_philox_offset = (\n            philox_offset_base + (off_z * HQ + off_h_q) * seqlen_q * seqlen_k\n        )\n    else:\n        batch_philox_offset = 0\n    # We can ask to return the dropout mask without actually doing any dropout.\n    # In this case, we return an invalid pointer to indicate the mask is not\n    # valid.\n    # TODO: Fix encoded softmax. 
It currently uses just h_q in the base offset.\n    if RETURN_ENCODED_SOFTMAX:\n        encoded_softmax_block_ptr = tl.make_block_ptr(\n            base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,\n            shape=(seqlen_q, seqlen_k),\n            strides=(seqlen_k, 1),\n            offsets=(start_m * BLOCK_M, 0),\n            block_shape=(BLOCK_M, BLOCK_N),\n            order=(1, 0),\n        )\n    else:\n        encoded_softmax_block_ptr = 0\n    # initialize pointer to m and l\n    m_i = tl.full([BLOCK_M], float(\"-inf\"), dtype=tl.float32)\n    l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)\n    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n    # scale sm_scale by log_2(e) and use 2^x in the loop as we do not\n    # have native e^x support in HW.\n    qk_scale = sm_scale * 1.44269504089\n    # Q is loaded once at the beginning and shared by all N blocks.\n    q = load_fn(Q_block_ptr, True, PADDED_HEAD, \"zero\")\n    q = (q * qk_scale).to(Q_block_ptr.type.element_ty)\n\n    # Here we compute how many full and masked blocks we have.\n    padded_block_k = n_extra_tokens != 0\n    is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)\n    if IS_CAUSAL:\n        # There are always at least BLOCK_M // BLOCK_N masked blocks.\n        # Additionally there might be one more due to dissimilar seqlens.\n        masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)\n    else:\n        # Padding on Q does not need to be masked in the FA loop.\n        masked_blocks = padded_block_k\n    # if IS_CAUSAL, not is_modulo_mn does not always result in an additional\n    # block. In this case we might exceed n_blocks so pick the min.\n    masked_blocks = min(masked_blocks, n_blocks)\n    n_full_blocks = n_blocks - masked_blocks\n    block_min = 0\n    block_max = n_blocks * BLOCK_N\n    # Compute for full blocks. Here we set causal to false regardless of its\n    # value because there is no masking. 
Similarly we do not need padding.\n    if n_full_blocks > 0:\n        block_max = (n_blocks - masked_blocks) * BLOCK_N\n        acc, l_i, m_i = _attn_fwd_inner(\n            acc,\n            l_i,\n            m_i,\n            q,\n            K_block_ptr,\n            V_block_ptr,\n            start_m,\n            seqlen_k,\n            dropout_p,\n            philox_seed,\n            batch_philox_offset,\n            encoded_softmax_block_ptr,\n            # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _\n            block_min,\n            block_max,\n            0,\n            0,\n            0,\n            bias_ptr,\n            # IS_CAUSAL, ....\n            False,\n            BLOCK_M,\n            BLOCK_DMODEL,\n            BLOCK_N,\n            offs_m,\n            offs_n,\n            # _, MASK_STEPS, ...\n            PRE_LOAD_V,\n            False,\n            ENABLE_DROPOUT,\n            RETURN_ENCODED_SOFTMAX,\n            PADDED_HEAD,\n        )\n        block_min = block_max\n        block_max = n_blocks * BLOCK_N\n\n    tl.debug_barrier()\n    # Remaining blocks, if any, are the masked ones (processed with MASK_STEPS=True).\n    if masked_blocks > 0:\n        offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0\n        K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N))\n        V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0))\n        if bias_ptr is not None:\n            bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))\n        if RETURN_ENCODED_SOFTMAX:\n            encoded_softmax_block_ptr = tl.advance(\n                encoded_softmax_block_ptr, (0, n_full_blocks)\n            )\n        acc, l_i, m_i = _attn_fwd_inner(\n            acc,\n            l_i,\n            m_i,\n            q,\n            K_block_ptr,\n            V_block_ptr,\n            start_m,\n            seqlen_k,\n            dropout_p,\n            philox_seed,\n            batch_philox_offset,\n            
encoded_softmax_block_ptr,\n            block_min,\n            block_max,\n            offs_n_causal,\n            masked_blocks,\n            n_extra_tokens,\n            bias_ptr,\n            IS_CAUSAL,\n            BLOCK_M,\n            BLOCK_DMODEL,\n            BLOCK_N,\n            offs_m,\n            offs_n,\n            # _, MASK_STEPS, ...\n            PRE_LOAD_V,\n            True,\n            ENABLE_DROPOUT,\n            RETURN_ENCODED_SOFTMAX,\n            PADDED_HEAD,\n        )\n    # epilogue\n    acc = acc / l_i[:, None]\n    if ENABLE_DROPOUT:\n        acc = acc / (1 - dropout_p)\n    # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,\n    # then we have one block with a row of all NaNs which come from computing\n    # softmax over a row of all -infs (-inf - inf = NaN). We check for that here\n    # and store 0s where there are NaNs as these rows should've been zeroed out.\n    end_m_idx = (start_m + 1) * BLOCK_M\n    start_m_idx = start_m * BLOCK_M\n    causal_start_idx = seqlen_q - seqlen_k\n    acc = acc.to(Out.type.element_ty)\n    if IS_CAUSAL:  # noqa: SIM102\n        if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:\n            out_mask_boundary = tl.full(\n                (BLOCK_DMODEL,), causal_start_idx, dtype=tl.int32\n            )\n            mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)\n            out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :]\n            z = 0.0\n            acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))\n    # write back LSE\n    # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m\n    # If seqlen_q not multiple of BLOCK_M, we need to mask out the last\n    # few rows. This is only true for the last M block. 
For others,\n    # overflow_size will be -ve\n    # overflow_size = end_m_idx - seqlen_q\n    # if overflow_size > 0:\n    #    boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)\n    #    # This is a > check because mask being 0 blocks the store.\n    #    l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)\n    #    tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)\n    # else:\n    #    tl.store(l_ptrs, m_i + tl.math.log2(l_i))\n\n    # write back O\n    o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh\n    O_block_ptr = tl.make_block_ptr(\n        base=Out + o_offset,\n        shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),\n        strides=(stride_om, stride_on),\n        offsets=(start_m * BLOCK_M, 0),\n        block_shape=(BLOCK_M, BLOCK_DMODEL),\n        order=(1, 0),\n    )\n    # Need boundary check on this to make sure the padding from the\n    # Q and KV tensors in both dims are not part of what we store back.\n    # TODO: Do the boundary check optionally.\n    tl.store(O_block_ptr, acc, boundary_check=(0, 1))\n\n\ndef check_args(\n    q,\n    k,\n    v,\n    o,\n    varlen=True,\n    max_seqlens=None,\n    cu_seqlens_q=None,\n    cu_seqlens_k=None,\n):\n    assert q.dim() == k.dim() and q.dim() == v.dim()\n    if varlen:\n        assert q.dim() == 3\n        total_q, nheads_q, head_size = q.shape\n        total_k, nheads_k, _ = k.shape\n        assert cu_seqlens_q is not None\n        assert cu_seqlens_k is not None\n        assert len(cu_seqlens_q) == len(cu_seqlens_k)\n    else:\n        assert q.dim() == 4\n        batch, nheads_q, seqlen_q, head_size = q.shape\n        _, nheads_k, seqlen_k, _ = k.shape\n        assert max_seqlens > 0\n    assert k.shape == v.shape\n    assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]\n    # TODO: Change assert if we support qkl f8 and v f16\n    assert q.dtype == k.dtype and q.dtype == v.dtype\n    # TODO: Fix assert to check head size <=256 once 
supported\n    assert head_size <= 128\n    assert o.shape == q.shape\n    assert (nheads_q % nheads_k) == 0\n\n\nclass _attention(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        q,\n        k,\n        v,\n        o,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        max_seqlens_q,\n        max_seqlens_k,\n        causal=False,\n        sm_scale=1.0,\n        bias=None,\n    ):\n        if o is None:\n            o = torch.empty_like(q, dtype=v.dtype)\n\n        check_args(\n            q,\n            k,\n            v,\n            o,\n            varlen=True,\n            cu_seqlens_q=cu_seqlens_q,\n            cu_seqlens_k=cu_seqlens_k,\n        )\n        if True:  # varlen\n            total_q, nheads_q, head_size = q.shape\n            total_k, nheads_k, _ = k.shape\n            batch = len(cu_seqlens_q) - 1\n            q_strides = (0, q.stride(1), q.stride(0), q.stride(2))\n            k_strides = (0, k.stride(1), k.stride(0), k.stride(2))\n            v_strides = (0, v.stride(1), v.stride(0), v.stride(2))\n            o_strides = (0, o.stride(1), o.stride(0), o.stride(2))\n        else:\n            batch, seqlen_q, nheads_q, head_size = q.shape\n            _, seqlen_k, nheads_k, _ = k.shape\n            q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))\n            k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))\n            v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))\n            o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))\n\n        # Round head_size up to the next power of 2, with a minimum of 16.\n        padded_d_model = 1 << (head_size - 1).bit_length()\n        padded_d_model = max(padded_d_model, 16)\n\n        def grid(META):\n            return triton.cdiv(max_seqlens_q, META[\"BLOCK_M\"]), nheads_q, batch\n\n        encoded_softmax = None\n\n        # Seed the RNG so we get reproducible results for testing.\n        philox_seed = 0x1BF52\n        
philox_offset = 0x1D4B42\n\n        if bias is not None:\n            bias_strides = (\n                bias.stride(0),\n                bias.stride(1),\n                bias.stride(2),\n                bias.stride(3),\n            )\n        else:\n            bias_strides = (0, 0, 0, 0)\n\n        attn_fwd[grid](\n            q,\n            k,\n            v,\n            bias,\n            sm_scale,\n            None,\n            o,\n            *q_strides,\n            *k_strides,\n            *v_strides,\n            *o_strides,\n            *bias_strides,\n            cu_seqlens_q,\n            cu_seqlens_k,\n            dropout_p=0.0,\n            philox_seed=philox_seed,\n            philox_offset_base=philox_offset,\n            encoded_softmax=encoded_softmax,\n            HQ=nheads_q,\n            HK=nheads_k,\n            ACTUAL_BLOCK_DMODEL=head_size,\n            MAX_SEQLENS_Q=max_seqlens_q,\n            MAX_SEQLENS_K=max_seqlens_k,\n            IS_CAUSAL=causal,\n            VARLEN=True,\n            BLOCK_DMODEL=padded_d_model,\n            BIAS_TYPE=0 if bias is None else 1,\n            ENABLE_DROPOUT=False,\n            RETURN_ENCODED_SOFTMAX=False,\n        )\n\n        ctx.grid = grid\n        ctx.sm_scale = sm_scale\n        ctx.BLOCK_DMODEL = head_size\n        ctx.causal = causal\n        ctx.dropout_p = 0.0\n        ctx.philox_seed = philox_seed\n        ctx.philox_offset = philox_offset\n        ctx.encoded_softmax = encoded_softmax\n        ctx.return_encoded_softmax = False\n        return o, encoded_softmax\n\n\ntriton_attention = _attention.apply\n"
  },
  {
    "path": "server/text_generation_server/layers/attention/flashinfer.py",
    "content": "from typing import Optional\nfrom contextvars import ContextVar\nfrom contextlib import contextmanager\nimport math\n\nimport flashinfer\nimport torch\n\nprefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = ContextVar(\n    \"prefill_state\"\n)\n\nprefill_with_paged_kv_state: ContextVar[\n    flashinfer.BatchPrefillWithPagedKVCacheWrapper\n] = ContextVar(\"prefill_with_paged_kv_state\")\n\ndecode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar(\n    \"decode_state\"\n)\n\nworkspace: Optional[torch.Tensor] = None\n\n\ndef unpad_2d_mask(\n    attention_mask: torch.Tensor, seq_lengths: torch.Tensor\n) -> torch.Tensor:\n    # Like torch unpad_sequence, but for 2D masks.\n    unpadded_tensors = []\n    for i, length in enumerate(seq_lengths):\n        unpadded_matrix = attention_mask[i, :length, :length]\n        unpadded_tensors.append(unpadded_matrix.flatten())\n\n    packed_tensor = torch.cat(unpadded_tensors)\n\n    return packed_tensor\n\n\ndef get_workspace(device):\n    \"\"\"Get shared flashinfer workspace.\"\"\"\n    global workspace\n    if workspace is None:\n        workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)\n    return workspace\n\n\ndef create_prefill_with_paged_kv_state(\n    *,\n    device: torch.device,\n):\n    \"\"\"Create a prefill state that uses the KV cache.\"\"\"\n    workspace_buffer = get_workspace(device)\n    return flashinfer.BatchPrefillWithPagedKVCacheWrapper(\n        workspace_buffer, kv_layout=\"NHD\", use_cuda_graph=False\n    )\n\n\n@contextmanager\ndef use_prefill_with_paged_kv_state(\n    *,\n    state: flashinfer.BatchPrefillWithPagedKVCacheWrapper,\n    block_tables: torch.Tensor,\n    cu_seqlens: torch.Tensor,\n    custom_mask: Optional[torch.Tensor],\n    input_lengths: torch.Tensor,\n    num_heads: int,\n    num_kv_heads: int,\n    head_size: int,\n    page_size: int,\n    kv_dtype: torch.dtype,\n    q_dtype: 
torch.dtype,\n):\n    \"\"\"\n    Context manager to set the active flashinfer prefill state to the given\n    `state` and parameters. This state will be used by all calls to the\n    `attention` function while the context manager is active.\n    \"\"\"\n\n    indptr = torch.zeros(\n        input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32\n    )\n    # Round up to page size and then calculate the cumulative sum to get\n    # the indices into the block table.\n    torch.add(input_lengths, page_size - 1, out=indptr[1:])\n    indptr[1:].div_(page_size, rounding_mode=\"floor\")\n    indptr[1:].cumsum_(-1)\n\n    # Get the lengths of the last page in a block.\n    if page_size == 1:\n        last_page_len = torch.ones(\n            input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device\n        )\n    else:\n        last_page_len = torch.empty(\n            input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device\n        )\n        torch.sub(input_lengths, 1, out=last_page_len)\n        last_page_len.remainder_(page_size)\n        last_page_len += 1\n\n    token = prefill_with_paged_kv_state.set(state)\n\n    # Attention masks are padded, unpad.\n    if custom_mask is not None:\n        bs = input_lengths.shape[0]\n        seq_len = math.isqrt(custom_mask.numel() // bs)\n        custom_mask = unpad_2d_mask(\n            custom_mask.reshape(bs, seq_len, seq_len), input_lengths\n        )\n\n    try:\n        state.plan(\n            qo_indptr=cu_seqlens,\n            paged_kv_indptr=indptr,\n            paged_kv_indices=block_tables,\n            paged_kv_last_page_len=last_page_len,\n            custom_mask=custom_mask,\n            num_qo_heads=num_heads,\n            num_kv_heads=num_kv_heads,\n            head_dim=head_size,\n            kv_data_type=kv_dtype,\n            q_data_type=q_dtype,\n            page_size=page_size,\n        )\n        yield\n    finally:\n        if token is not None:\n            
prefill_with_paged_kv_state.reset(token)\n\n\ndef create_prefill_state(\n    *,\n    device: torch.device,\n):\n    \"\"\"Create a prefill state.\"\"\"\n    workspace_buffer = get_workspace(device)\n    return flashinfer.BatchPrefillWithRaggedKVCacheWrapper(\n        workspace_buffer, kv_layout=\"NHD\", use_cuda_graph=False\n    )\n\n\ndef create_decode_state(\n    *,\n    device: torch.device,\n    num_heads: int,\n    num_kv_heads: int,\n):\n    \"\"\"Create a decode state.\"\"\"\n    workspace_buffer = get_workspace(device)\n    num_groups = num_heads // num_kv_heads\n    return flashinfer.BatchDecodeWithPagedKVCacheWrapper(\n        workspace_buffer,\n        kv_layout=\"NHD\",\n        use_cuda_graph=False,\n        # Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60\n        use_tensor_cores=num_groups not in [1, 2, 4, 8],\n    )\n\n\ndef create_decode_state_cuda_graphs(\n    *,\n    device: torch.device,\n    block_tables: torch.Tensor,\n    block_tables_ptr: torch.Tensor,\n    last_page_len: torch.Tensor,\n    num_heads: int,\n    num_kv_heads: int,\n):\n    \"\"\"\n    Create a decode state for use with CUDA Graphs. 
`block_tables`,\n    `block_tables_ptr`, and `last_page_len` are used in CUDA Graphs and are\n    therefore stored as part of the state.\n    \"\"\"\n    workspace_buffer = get_workspace(device)\n    num_groups = num_heads // num_kv_heads\n    return flashinfer.BatchDecodeWithPagedKVCacheWrapper(\n        workspace_buffer,\n        kv_layout=\"NHD\",\n        use_cuda_graph=True,\n        paged_kv_indices_buffer=block_tables,\n        paged_kv_indptr_buffer=block_tables_ptr,\n        paged_kv_last_page_len_buffer=last_page_len,\n        # Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60\n        use_tensor_cores=num_groups not in [1, 2, 4, 8],\n    )\n\n\n@contextmanager\ndef use_decode_state(\n    *,\n    state: flashinfer.BatchDecodeWithPagedKVCacheWrapper,\n    input_lengths: torch.Tensor,\n    block_tables: torch.Tensor,\n    num_heads: int,\n    num_kv_heads: int,\n    head_size: int,\n    page_size: int,\n    kv_cache_dtype: torch.dtype,\n    q_dtype: torch.dtype,\n):\n    \"\"\"\n    Context manager to set the active flashinfer decoding state to the given\n    `state` and parameters. 
This state will be used by all calls to the\n    `paged_attention` function while the context manager is active.\n    \"\"\"\n    indptr = torch.zeros(\n        input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32\n    )\n    # Round up to page size and then calculate the cumulative sum to get\n    # the indices into the block table.\n    torch.add(input_lengths, page_size - 1, out=indptr[1:])\n    indptr[1:].div_(page_size, rounding_mode=\"floor\")\n    indptr[1:].cumsum_(-1)\n\n    # Get the lengths of the last page in a block.\n    last_page_len = torch.empty(\n        input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device\n    )\n    torch.sub(input_lengths, 1, out=last_page_len)\n    last_page_len.remainder_(page_size)\n    last_page_len += 1\n\n    token = decode_state.set(state)\n\n    try:\n        state.plan(\n            indptr=indptr,\n            indices=block_tables,\n            last_page_len=last_page_len,\n            num_qo_heads=num_heads,\n            num_kv_heads=num_kv_heads,\n            head_dim=head_size,\n            page_size=page_size,\n            data_type=kv_cache_dtype,\n            q_data_type=q_dtype,\n        )\n        yield\n    finally:\n        if token is not None:\n            decode_state.reset(token)\n"
  },
  {
    "path": "server/text_generation_server/layers/attention/ipex.py",
    "content": "import intel_extension_for_pytorch as ipex\nimport torch\nfrom text_generation_server.layers.attention.kv_cache import KVCache, KVScales\nfrom text_generation_server.layers.attention import Seqlen\nfrom typing import Optional\nfrom text_generation_server.models.globals import (\n    ATTENTION,\n    BLOCK_SIZE,\n)\n\nif ATTENTION == \"flashdecoding-ipex\":\n    SUPPORTS_WINDOWING = True\nelse:\n    SUPPORTS_WINDOWING = False\n\n\ndef attention(\n    *,\n    query: torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    kv_cache: KVCache,\n    kv_scales: KVScales,\n    seqlen: Seqlen,\n    block_tables: torch.Tensor,\n    softmax_scale: float,\n    window_size_left: int = -1,\n    causal: bool = True,\n    softcap: Optional[float] = None,\n):\n\n    out = torch.empty_like(query)\n    kv_cache_dtype = \"auto\"\n    if kv_cache.key.dtype == torch.float8_e5m2:\n        kv_cache_dtype = \"fp8_e5m2\"\n    if kv_cache.key.dtype == torch.float8_e4m3fn:\n        kv_cache_dtype = \"fp8_e4m3\"\n\n    # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.\n    if ATTENTION == \"flashdecoding-ipex\":\n        window_size_right = -1 if window_size_left == -1 else 0\n        if softcap is None:\n            softcap = -1.0\n        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(\n            out,\n            query.contiguous() if query.device.type == \"xpu\" else query,\n            kv_cache.key,\n            kv_cache.value,\n            seqlen.cu_seqlen_q,\n            seqlen.cu_seqlen_k,\n            seqlen.max_q,\n            seqlen.max_k,\n            softmax_scale,\n            causal,\n            block_tables,\n            None,\n            window_size_left=window_size_left,\n            window_size_right=window_size_right,\n            kv_cache_dtype=kv_cache_dtype,\n            k_scale=kv_scales.key_scale_cpu,\n            v_scale=kv_scales.value_scale_cpu,\n            
softcap=softcap,\n        )\n    else:\n        if softcap is not None:\n            raise NotImplementedError(\n                \"softcap is not available in IPEX paged attention\"\n            )\n        ipex.llm.functional.varlen_attention(\n            query.contiguous() if query.device.type == \"xpu\" else query,\n            key.contiguous() if key.device.type == \"xpu\" else key,\n            value.contiguous() if value.device.type == \"xpu\" else value,\n            out,\n            seqlen.cu_seqlen_q,\n            seqlen.cu_seqlen_q,\n            seqlen.max_q,\n            seqlen.max_q,\n            0.0,\n            softmax_scale,\n            False,\n            causal,\n            False,\n            None,\n        )\n\n    return out\n\n\ndef paged_attention(\n    query: torch.Tensor,\n    kv_cache: KVCache,\n    kv_head_mapping: torch.Tensor,\n    softmax_scale: float,\n    block_tables: torch.Tensor,\n    seqlen: Seqlen,\n    max_s: int,\n    *,\n    kv_scales: KVScales,\n    softcap: Optional[float] = None,\n    window_size_left: Optional[int] = -1,\n):\n    out = torch.empty_like(query)\n    kv_cache_dtype = \"auto\"\n    if kv_cache.key.dtype == torch.float8_e5m2:\n        kv_cache_dtype = \"fp8_e5m2\"\n    if kv_cache.key.dtype == torch.float8_e4m3fn:\n        kv_cache_dtype = \"fp8_e4m3\"\n    if ATTENTION == \"flashdecoding-ipex\":\n        window_size_right = -1 if window_size_left == -1 else 0\n        if softcap is None:\n            softcap = -1.0\n        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(\n            out,\n            query.contiguous() if query.device.type == \"xpu\" else query,\n            kv_cache.key,\n            kv_cache.value,\n            seqlen.cu_seqlen_q,\n            seqlen.cu_seqlen_k,\n            seqlen.max_q,\n            seqlen.max_k,\n            softmax_scale,\n            True,\n            block_tables,\n            None,\n            window_size_left=window_size_left,\n            
window_size_right=window_size_right,\n            kv_cache_dtype=kv_cache_dtype,\n            k_scale=kv_scales.key_scale_cpu,\n            v_scale=kv_scales.value_scale_cpu,\n            softcap=softcap,\n        )\n    else:\n        input_lengths = seqlen.input_lengths + seqlen.cache_lengths\n        if softcap is not None:\n            raise NotImplementedError(\n                \"softcap is not available in IPEX paged attention\"\n            )\n        ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(\n            out,\n            query,\n            kv_cache.key,\n            kv_cache.value,\n            kv_head_mapping,\n            softmax_scale,\n            block_tables,\n            input_lengths,\n            BLOCK_SIZE,\n            max_s,\n            None,\n            k_scale=kv_scales.key_scale_cpu,\n            v_scale=kv_scales.value_scale_cpu,\n        )\n    return out\n\n\n__all__ = [\n    \"SUPPORTS_WINDOWING\",\n    \"attention\",\n    \"paged_attention\",\n]\n"
  },
  {
    "path": "server/text_generation_server/layers/attention/kv_cache.py",
    "content": "from typing import Tuple\nfrom dataclasses import dataclass, field\n\nfrom loguru import logger\nimport torch\n\nfrom text_generation_server.layers.fp8 import fp8_quantize\nfrom text_generation_server.models.globals import ATTENTION, BLOCK_SIZE\nfrom text_generation_server.utils.import_utils import SYSTEM\nfrom text_generation_server.utils.kernels import load_kernel\nfrom text_generation_server.utils.log import log_once\nfrom text_generation_server.utils.weights import Weights\n\nif SYSTEM == \"cuda\":\n    try:\n        paged_attention = load_kernel(\n            module=\"paged_attention\", repo_id=\"kernels-community/paged-attention\"\n        )\n    except Exception as e:\n        raise ImportError(\n            f\"Could not import attention kernels. Make sure your installation is correct. Complete error: {e}\"\n        )\nelse:\n    paged_attention = None\n\n\n@dataclass\nclass KVScales:\n    \"\"\"\n    Key-value scales for FP8 KV cache.\n\n    This data class stores key and value scales both as a GPU tensor and\n    as a CPU float. This inconvenience is necessary because some functions\n    (e.g. scaling kernels) take scales as a GPU tensor, whereas others\n    (e.g. 
flashinfer) take scales as a CPU scalar.\n    \"\"\"\n\n    key_scale: torch.Tensor\n    value_scale: torch.Tensor\n    key_scale_cpu: float = field(init=False)\n    value_scale_cpu: float = field(init=False)\n\n    def __post_init__(self):\n        if self.key_scale.numel() != 1 or self.value_scale.numel() != 1:\n            raise ValueError(\"Key and value scales must be scalar tensors.\")\n\n        self.key_scale_cpu = self.key_scale.item()\n        self.value_scale_cpu = self.value_scale.item()\n\n\nclass KVCache:\n    \"\"\"\n    Key-value cache for attention layers.\n    \"\"\"\n\n    kv_cache: Tuple[torch.Tensor, torch.Tensor]\n\n    def __init__(\n        self,\n        *,\n        num_blocks: int,\n        num_heads: int,\n        head_size: int,\n        dtype: torch.dtype,\n        device: torch.device,\n    ):\n        \"\"\"Construct the key-value cache for a layer.\"\"\"\n        if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:\n            if not (\n                (ATTENTION == \"flashinfer\" and SYSTEM == \"cuda\")\n                or (ATTENTION == \"paged\" and SYSTEM in (\"cuda\", \"rocm\", \"ipex\"))\n                or (ATTENTION == \"flashdecoding-ipex\")\n            ):\n                raise ValueError(\n                    \"FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on CUDA, ROCm and INTEL IPEX and flashdecoding in Intel IPEX \"\n                )\n            if SYSTEM == \"rocm\" and dtype == torch.float8_e5m2:\n                raise ValueError(\n                    \"float8_e5m2 FP8 KV cache is not supported on AMD ROCm\"\n                )\n            if device.type == \"cpu\" and dtype == torch.float8_e4m3fn:\n                raise ValueError(\n                    \"float8_e4m3fn FP8 KV cache is not supported on Intel IPEX CPU\"\n                )\n\n        element_size = torch.tensor([], dtype=dtype).element_size()\n        if SYSTEM == \"ipex\" and device.type == \"xpu\":\n            
x = 1\n        else:\n            x = BLOCK_SIZE // element_size\n\n        if ATTENTION in {\"flashdecoding\", \"flashinfer\"} or (\n            ATTENTION == \"flashdecoding-ipex\" and device.type == \"xpu\"\n        ):\n            self.kv_cache = (\n                torch.empty(\n                    (num_blocks, BLOCK_SIZE, num_heads, head_size),\n                    dtype=dtype,\n                    device=device,\n                ),\n                torch.empty(\n                    (num_blocks, BLOCK_SIZE, num_heads, head_size),\n                    dtype=dtype,\n                    device=device,\n                ),\n            )\n        elif SYSTEM == \"ipex\" and device == torch.device(\"cpu\"):\n            # ipex cpu flashdecoding kernel and paged attention kernel share same layout\n            self.kv_cache = (\n                torch.empty(\n                    (num_blocks, num_heads, BLOCK_SIZE, head_size),\n                    dtype=dtype,\n                    device=device,\n                ),\n                torch.empty(\n                    (num_blocks, num_heads, BLOCK_SIZE, head_size),\n                    dtype=dtype,\n                    device=device,\n                ),\n            )\n        else:\n            self.kv_cache = (\n                torch.zeros(\n                    (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),\n                    dtype=dtype,\n                    device=device,\n                ),\n                torch.zeros(\n                    (num_blocks, num_heads, head_size, BLOCK_SIZE),\n                    dtype=dtype,\n                    device=device,\n                ),\n            )\n\n    def can_scale(self, kv_scales: KVScales) -> bool:\n        \"\"\"Check if the cache can be scaled by the given scales.\"\"\"\n        if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0:\n            return False\n        elif self.dtype == torch.float8_e4m3fn and (\n            (ATTENTION in 
(\"paged\", \"flashinfer\") and SYSTEM == \"cuda\")\n            or (ATTENTION == \"paged\" and SYSTEM in [\"rocm\", \"ipex\"])\n            or (ATTENTION == \"flashdecoding-ipex\")\n        ):\n            log_once(logger.info, \"Using FP8 KV cache scales\")\n            return True\n        else:\n            # We have scales, but not the correct FP8 cache type, so warn once.\n            log_once(\n                logger.info,\n                \"Ignoring FP8 KV cache scales, supported only for float8_e4m3fn KV cache with flashinfer on CUDA and paged attention on ROCm/IPEX and flashdecoding on IPEX\",\n            )\n            return False\n\n    @property\n    def dtype(self):\n        \"\"\"Get the data type of the cache.\"\"\"\n        return self.kv_cache[0].dtype\n\n    @property\n    def key(self):\n        \"\"\"Get the key cache.\"\"\"\n\n        return self.kv_cache[0]\n\n    @property\n    def value(self):\n        \"\"\"Get the value cache.\"\"\"\n\n        return self.kv_cache[1]\n\n    def store(\n        self,\n        *,\n        key: torch.Tensor,\n        value: torch.Tensor,\n        slots: torch.Tensor,\n        kv_scales: KVScales,\n    ):\n        \"\"\"Store the key and value at the given slots.\"\"\"\n\n        key_cache = self.kv_cache[0]\n        value_cache = self.kv_cache[1]\n\n        if self.can_scale(kv_scales) and SYSTEM == \"cuda\":\n            if kv_scales.key_scale_cpu != 1.0:\n                key = fp8_quantize(\n                    key.float(),\n                    scale=kv_scales.key_scale,\n                    qdtype=self.dtype,\n                    scalar=True,\n                )[0]\n            if kv_scales.value_scale_cpu != 1.0:\n                value = fp8_quantize(\n                    value.float(),\n                    scale=kv_scales.value_scale,\n                    qdtype=self.dtype,\n                    scalar=True,\n                )[0]\n\n        if ATTENTION in {\"flashdecoding\", \"flashinfer\"}:\n          
  key = key.to(key_cache.dtype)\n            value = value.to(value_cache.dtype)\n            if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:\n                # Torch index_put does not support float8_{e5m2,e4m3fn} yet, so\n                # put as raw data instead.\n                key_cache = key_cache.view(torch.uint8)\n                value_cache = value_cache.view(torch.uint8)\n                key = key.view(torch.uint8)\n                value = value.view(torch.uint8)\n            shape = key_cache.shape\n            key_cache.view(-1, shape[-2], shape[-1])[slots] = key\n            value_cache.view(-1, shape[-2], shape[-1])[slots] = value\n        elif ATTENTION == \"flashdecoding-ipex\" and key.device.type == \"xpu\":\n            import intel_extension_for_pytorch as ipex\n\n            kv_cache_dtype = \"auto\"\n            if key_cache.dtype == torch.float8_e5m2:\n                kv_cache_dtype = \"fp8_e5m2\"\n            if key_cache.dtype == torch.float8_e4m3fn:\n                kv_cache_dtype = \"fp8_e4m3\"\n            ipex.llm.modules.PagedAttention.reshape_and_cache_flash(\n                key,\n                value,\n                key_cache,\n                value_cache,\n                slots,\n                kv_cache_dtype=kv_cache_dtype,\n                k_scale=kv_scales.key_scale_cpu,\n                v_scale=kv_scales.value_scale_cpu,\n            )\n        else:\n            paged_reshape_and_cache(\n                key,\n                value,\n                key_cache,\n                value_cache,\n                slots,\n                kv_scales.key_scale_cpu,\n                kv_scales.value_scale_cpu,\n            )\n\n\ndef paged_reshape_and_cache(\n    key: torch.Tensor,\n    value: torch.Tensor,\n    key_cache: torch.Tensor,\n    value_cache: torch.Tensor,\n    slots: torch.Tensor,\n    k_scale: float = 1.0,\n    v_scale: float = 1.0,\n):\n\n    if SYSTEM == \"cuda\":\n        kv_cache_dtype = \"auto\"\n      
  if key_cache.dtype == torch.float8_e4m3fn:\n            kv_cache_dtype = \"fp8\"\n\n        paged_attention.reshape_and_cache(\n            key,\n            value,\n            key_cache,\n            value_cache,\n            slots,\n            kv_cache_dtype,\n            torch.tensor(k_scale),\n            torch.tensor(v_scale),\n        )\n    elif SYSTEM == \"rocm\":\n        try:\n            import vllm._custom_ops as ops\n        except Exception as e:\n            raise ImportError(\n                f\"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}\"\n            )\n\n        kv_cache_dtype = \"auto\"\n        if key_cache.dtype == torch.float8_e4m3fn:\n            key_cache = key_cache.view(torch.uint8)\n            value_cache = value_cache.view(torch.uint8)\n            kv_cache_dtype = \"fp8\"\n\n        ops.reshape_and_cache(\n            key, value, key_cache, value_cache, slots, kv_cache_dtype, k_scale, v_scale\n        )\n    elif SYSTEM == \"ipex\":\n        import intel_extension_for_pytorch as ipex\n\n        kv_cache_dtype = \"auto\"\n        if key_cache.dtype == torch.float8_e5m2:\n            kv_cache_dtype = \"fp8_e5m2\"\n        if key_cache.dtype == torch.float8_e4m3fn:\n            kv_cache_dtype = \"fp8_e4m3\"\n\n        ipex.llm.modules.PagedAttention.reshape_and_cache(\n            key,\n            value,\n            key_cache,\n            value_cache,\n            slots,\n            kv_cache_dtype=kv_cache_dtype,\n            k_scale=k_scale,\n            v_scale=v_scale,\n        )\n    else:\n        raise NotImplementedError(\n            f\"Cannot reshape and cache for paged attention, system '{SYSTEM}' not supported\"\n        )\n\n\ndef get_kv_scales(weights: Weights, prefix: str) -> KVScales:\n    \"\"\"Load KV cache scales.\"\"\"\n\n    key_scale = torch.tensor(1.0, dtype=torch.float32, device=weights.device)\n    value_scale = key_scale\n    if 
weights.has_tensor(f\"{prefix}.k_scale\") and weights.has_tensor(\n        f\"{prefix}.v_scale\"\n    ):\n        key_scale = weights.get_tensor(f\"{prefix}.k_scale\", to_dtype=False).float()\n        value_scale = weights.get_tensor(f\"{prefix}.v_scale\", to_dtype=False).float()\n    elif weights.has_tensor(f\"{prefix}.kv_scale\"):\n        # Fall back to older more coarse-grained scale when available.\n        key_scale = weights.get_tensor(f\"{prefix}.kv_scale\").float()\n        value_scale = key_scale\n\n    return KVScales(key_scale=key_scale, value_scale=value_scale)\n"
  },
  {
    "path": "server/text_generation_server/layers/attention/rocm.py",
    "content": "import os\nfrom typing import Optional\nimport torch\nfrom text_generation_server.layers.attention.kv_cache import KVCache, KVScales\nfrom text_generation_server.utils.import_utils import SYSTEM\nfrom text_generation_server.layers.attention import Seqlen\nfrom text_generation_server.utils.log import log_master\nfrom text_generation_server.models.globals import (\n    ATTENTION,\n    BLOCK_SIZE,\n)\nfrom loguru import logger\nimport vllm._custom_ops as ops\n\nmajor, minor = torch.cuda.get_device_capability()\nis_sm75 = major == 7 and minor == 5\n\n_PARTITION_SIZE_V1V2 = 1024\n_PARTITION_SIZE_CUSTOM = 256\n\n_GPU_ARCH = torch.cuda.get_device_properties(\"cuda\").gcnArchName\n_ON_MI250_MI300 = any(\n    arch in _GPU_ARCH for arch in [\"gfx90a\", \"gfx940\", \"gfx941\", \"gfx942\"]\n)\n\nuse_triton = os.getenv(\"ROCM_USE_FLASH_ATTN_V2_TRITON\", \"\").lower() in {\"true\", \"1\"}\nENGINE = \"triton\" if use_triton else \"ck\"\n\nuse_rocm_custom_paged_attn = os.getenv(\"ROCM_USE_CUSTOM_PAGED_ATTN\", \"1\") != \"0\"\n\n\ndef _use_rocm_custom_paged_attention(\n    qtype: torch.dtype,\n    head_size: int,\n    block_size: int,\n    gqa_ratio: int,\n    max_seq_len: int,\n) -> bool:\n    # rocm custom page attention not support on navi (gfx1*)\n    return (\n        use_rocm_custom_paged_attn\n        and _ON_MI250_MI300\n        and (qtype == torch.half or qtype == torch.bfloat16)\n        and (head_size == 64 or head_size == 128)\n        and (block_size == 16 or block_size == 32)\n        and (gqa_ratio >= 1 and gqa_ratio <= 16)\n        and max_seq_len <= 131072\n    )\n\n\ndef paged_attention(\n    query: torch.Tensor,\n    kv_cache: KVCache,\n    kv_head_mapping: torch.Tensor,\n    softmax_scale: float,\n    block_tables: torch.Tensor,\n    seqlen: Seqlen,\n    max_s: int,\n    *,\n    kv_scales: KVScales,\n    softcap: Optional[float] = None,\n    window_size_left: Optional[int] = -1,\n):\n    # Adapted from: 
https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py\n    # Copyright 2023 The vLLM team. All rights\n    # reserved.\n    #\n    # Licensed under the Apache License, Version 2.0 (the \"License\");\n    # you may not use this file except in compliance with the License.\n    # You may obtain a copy of the License at\n    #\n    #     http://www.apache.org/licenses/LICENSE-2.0\n    #\n    # Unless required by applicable law or agreed to in writing, software\n    # distributed under the License is distributed on an \"AS IS\" BASIS,\n    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n    # See the License for the specific language governing permissions and\n    # limitations under the License.\n    #\n\n    if ATTENTION == \"flashdecoding\":\n        max_q = 1\n        max_k = max_s\n        import flash_attn_2_cuda\n\n        window_size_right = -1 if window_size_left == -1 else 0\n\n        if softcap is None:\n            softcap = 0.0\n        out = flash_attn_2_cuda.varlen_fwd(\n            query,\n            kv_cache.key,\n            kv_cache.value,\n            None,\n            seqlen.cu_seqlen_q,\n            seqlen.cu_seqlen_k,\n            None,  # pad_k\n            None,\n            block_tables,\n            None,\n            max_q,\n            max_k,\n            0.0,  # dropout\n            softmax_scale,\n            False,  # zero_tensors\n            True,  # causal\n            window_size_left,  # Window_left\n            window_size_right,  # Window right\n            softcap,\n            False,  # return softmax\n            None,  # generator\n        )\n        return out[0]\n\n    if softcap is not None:\n        raise RuntimeError(\"Paged attention doesn't support softcapping\")\n\n    # value_cache => [num_blocks, num_heads, head_size, block_size]\n    # block_size = kv_cache.value.shape[3]\n    block_size = BLOCK_SIZE\n    num_seqs, 
num_heads, head_size = query.shape\n\n    num_kv_heads = kv_cache.key.shape[1]\n    gqa_ratio = num_heads // num_kv_heads\n    use_custom = _use_rocm_custom_paged_attention(\n        query.dtype, head_size, block_size, gqa_ratio, max_s\n    )\n\n    if not use_custom:\n        _PARTITION_SIZE = _PARTITION_SIZE_V1V2\n    else:\n        _PARTITION_SIZE = _PARTITION_SIZE_CUSTOM\n\n    max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE\n    input_lengths = seqlen.input_lengths + seqlen.cache_lengths\n\n    out = torch.empty_like(query)\n\n    if kv_cache.dtype == torch.float8_e4m3fn:\n        key = kv_cache.key.view(torch.uint8)\n        value = kv_cache.value.view(torch.uint8)\n        kv_cache_dtype = \"fp8\"\n    else:\n        key = kv_cache.key\n        value = kv_cache.value\n        kv_cache_dtype = \"auto\"\n\n    # NOTE(woosuk): We use a simple heuristic to decide whether to use\n    # PagedAttention V1 or V2. If the number of partitions is 1, we use\n    # V1 to avoid the overhead of reduction. 
Also, if the number of\n    # sequences or heads is large, we use V1 since there is enough work\n    # to parallelize.\n    use_v1 = (\n        max_s <= 8192\n        and (max_num_partitions == 1 or num_seqs * num_heads > 512)\n        and not use_custom\n    )\n    if use_v1:\n        ops.paged_attention_v1(\n            out,\n            query,\n            key,\n            value,\n            num_kv_heads,\n            softmax_scale,\n            block_tables,\n            input_lengths,\n            block_size,\n            max_s,\n            None,\n            kv_cache_dtype,\n            kv_scales.key_scale_cpu,\n            kv_scales.value_scale_cpu,\n        )\n    else:\n        # Run PagedAttention V2.\n        assert _PARTITION_SIZE % block_size == 0\n        tmp_output = torch.zeros(\n            size=(num_seqs, num_heads, max_num_partitions, head_size),\n            dtype=out.dtype,\n            device=out.device,\n        )\n        exp_sums = torch.zeros(\n            size=(num_seqs, num_heads, max_num_partitions),\n            dtype=torch.float32,\n            device=out.device,\n        )\n        max_logits = torch.zeros_like(exp_sums)\n\n        if not use_custom:\n            ops.paged_attention_v2(\n                out,\n                exp_sums,\n                max_logits,\n                tmp_output,\n                query,\n                key,\n                value,\n                num_kv_heads,\n                softmax_scale,\n                block_tables,\n                input_lengths,\n                block_size,\n                max_s,\n                None,\n                kv_cache_dtype,\n                kv_scales.key_scale_cpu,\n                kv_scales.value_scale_cpu,\n            )\n        else:\n            ops.paged_attention_rocm(\n                out,\n                exp_sums,\n                max_logits,\n                tmp_output,\n                query,\n                key,\n                value,\n              
  num_kv_heads,\n                softmax_scale,\n                block_tables,\n                input_lengths,\n                block_size,\n                max_s,\n                None,\n                kv_cache_dtype,\n                kv_scales.key_scale_cpu,\n                kv_scales.value_scale_cpu,\n                None,\n                _PARTITION_SIZE,\n            )\n\n    return out\n\n\nif ENGINE != \"triton\":\n    try:\n        import flash_attn_2_cuda\n\n        log_master(\n            logger.info,\n            \"ROCm: using Flash Attention 2 Composable Kernel implementation.\",\n        )\n    except ImportError as e:\n        if major >= 8:\n            architecture_suffix = f\"-{SYSTEM}\"\n            raise ImportError(\n                \"Flash Attention V2 is not installed.\\n\"\n                \"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) \"\n                f\"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`\"\n            )\n        elif is_sm75:\n            raise ImportError(\n                \"Flash Attention is not installed.\\n\"\n                \"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) \"\n                \"or install flash attention with `cd server && make install install-flash-attention`\"\n            ) from e\n        else:\n            for idx in range(torch.cuda.device_count()):\n                name = torch.cuda.get_device_name(idx)\n                if \"MI210\" not in name and \"MI250\" not in name:\n                    raise ImportError(\n                        f\"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention\"\n                    )\n            raise ImportError(\n                f\"AMD GPU with ROCm capability {major} {minor} is not supported\"\n            ) from e\n\n\nSUPPORTS_WINDOWING = False\n\n\ndef attention(\n    *,\n    query: 
torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    kv_cache: KVCache,\n    kv_scales: KVScales,\n    seqlen: Seqlen,\n    block_tables: torch.Tensor,\n    softmax_scale: float,\n    window_size_left: int = -1,\n    causal: bool = True,\n    softcap: Optional[float] = None,\n):\n    if ENGINE == \"ck\":\n        if window_size_left <= 0 and window_size_left != -1:\n            raise ValueError(\"`window_size_left` must be > 0 or -1\")\n\n        out = torch.empty_like(query)\n\n        if softcap is None:\n            softcap = 0.0\n\n        # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.\n        return flash_attn_2_cuda.varlen_fwd(\n            query,\n            # flashdecoding: pass the KV caches, paged: pass the KV.\n            kv_cache.key if ATTENTION == \"flashdecoding\" else key,\n            kv_cache.value if ATTENTION == \"flashdecoding\" else value,\n            out,\n            seqlen.cu_seqlen_q,\n            seqlen.cu_seqlen_k,\n            None,\n            None,\n            block_tables if ATTENTION == \"flashdecoding\" else None,\n            None,\n            seqlen.max_q,\n            seqlen.max_k,\n            0.0,\n            softmax_scale,\n            False,\n            causal,\n            window_size_left,\n            0,\n            softcap,\n            False,\n            None,\n        )[0]\n\n    elif ENGINE == \"triton\":\n        from .flash_attn_triton import triton_attention\n\n        if softcap is not None:\n            raise NotImplementedError(\"softcap is only available with CK flash attn\")\n\n        out = torch.empty_like(query)\n\n        # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.\n        output, _ = triton_attention(\n            query,\n            key,\n            value,\n            out,\n            seqlen.cu_seqlen_q,\n            
seqlen.cu_seqlen_q,\n            seqlen.max_q,\n            seqlen.max_k,\n            causal,\n            softmax_scale,\n        )\n        return output\n\n    else:\n        raise RuntimeError(f\"Unknown attention engine {ENGINE}\")\n\n\n__all__ = [\n    \"SUPPORTS_WINDOWING\",\n    \"attention\",\n    \"paged_attention\",\n]\n"
  },
  {
    "path": "server/text_generation_server/layers/awq/conversion_utils.py",
    "content": "import torch\nfrom typing import List\n\n\nAWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]\nREVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]\n\n\ndef pack(imatrix: torch.Tensor, direction: str = \"column\"):\n    \"\"\"\n    Packs a 4-bit integer matrix into a packed 32-bit integer matrix.\n    Args:\n        imatrix (torch.Tensor): matrix of integers\n        direction (str): direction of packing, either \"column\" or \"row\"\n    Returns:\n        qmatrix (torch.Tensor): packed matrix of integers\n    \"\"\"\n    shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=imatrix.device)\n\n    imatrix = imatrix.to(torch.int8) & 0x0F  # eventually correct overflow\n\n    if direction == \"column\":\n        imatrix = imatrix.view(-1, imatrix.shape[1] // (32 // 4), (32 // 4))\n        qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1)\n\n    elif direction == \"row\":\n        imatrix = imatrix.view(imatrix.shape[0] // (32 // 4), (32 // 4), -1)\n        qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, :, None]).sum(dim=1)\n\n    qmatrix = qmatrix.to(torch.int32)\n\n    return qmatrix\n\n\ndef unpack(qmatrix: torch.Tensor, direction: str = \"column\"):\n    \"\"\"\n    Unpacks a 32-bit packed integer matrix into a 4-bit integer matrix.\n    Args:\n        qmatrix (torch.Tensor): matrix of packed integers\n        direction (str): direction of unpacking, either \"column\" or \"row\"\n    Returns:\n        imatrix (torch.Tensor): matrix of integers\n    \"\"\"\n    shifts = torch.arange(0, 32, 4, device=qmatrix.device)\n\n    if direction == \"column\":\n        imatrix = torch.bitwise_right_shift(\n            qmatrix[:, :, None], shifts[None, None, :]\n        ).view(qmatrix.shape[0], -1)\n\n    elif direction == \"row\":\n        imatrix = torch.bitwise_right_shift(\n            qmatrix[:, None, :], shifts[None, :, None]\n        ).view(-1, qmatrix.shape[-1])\n\n    imatrix = imatrix.to(torch.int8) & 0x0F  # eventually 
correct overflow\n\n    return imatrix\n\n\ndef apply_order(\n    imatrix: torch.Tensor,\n    direction: str = \"column\",\n    order: List[int] = AWQ_PACK_ORDER,\n):\n    \"\"\"\n    Applies the order to a 4-bit integer matrix.\n    Args:\n        imatrix (torch.Tensor): matrix of integers\n        direction (str): direction of applying order, either \"column\" or \"row\"\n        order (List[int]): order to apply, default is AWQ_PACK_ORDER\n    Returns:\n        imatrix (torch.Tensor): matrix of integers\n    \"\"\"\n    if direction == \"column\":\n        imatrix = imatrix.view(-1, (32 // 4))[:, order].view(imatrix.shape)\n    elif direction == \"row\":\n        imatrix = imatrix.view((32 // 4), -1)[order, :].view(imatrix.shape)\n\n    return imatrix\n\n\ndef fast_awq_to_gptq(qweight, qzeros):\n    # awq uses column packing for both weights and zeros\n    izeros = unpack(qzeros, direction=\"column\")\n    iweights = unpack(qweight, direction=\"column\")\n\n    # Reverse the order of the iweight and izeros tensors\n    izeros = apply_order(izeros, direction=\"column\", order=REVERSE_AWQ_PACK_ORDER)\n    iweights = apply_order(iweights, direction=\"column\", order=REVERSE_AWQ_PACK_ORDER)\n    # Subtract 1 from the izeros tensor (gptq adds 1 to the zeros)\n    izeros = izeros - 1\n    # exllama uses row packing for weights and column packing for zeros\n    qzeros = pack(izeros, direction=\"column\")\n    qweight = pack(iweights, direction=\"row\")\n\n    return qweight, qzeros\n"
  },
  {
    "path": "server/text_generation_server/layers/awq/quantize/__init__.py",
    "content": "from text_generation_server.utils.import_utils import SYSTEM\n\nif SYSTEM == \"ipex\":\n    from .ipex import WQLinear\nelif SYSTEM == \"cuda\":\n    from .cuda import WQLinear\n\n__all__ = [\"WQLinear\"]\n"
  },
  {
    "path": "server/text_generation_server/layers/awq/quantize/cuda.py",
    "content": "# Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py\n\nfrom typing import Optional\nimport torch\nimport torch.nn as nn\nimport awq_inference_engine  # with CUDA kernels\n\n\n# class ScaledActivation(nn.Module):\n#     def __init__(self, module, scales):\n#         super().__init__()\n#         self.act = module\n#         self.scales = nn.Parameter(scales.data)\n#\n#     def forward(self, x):\n#         return self.act(x) / self.scales.view(1, 1, -1).to(x.device)\n\n\nclass WQLinear(nn.Module):\n    def __init__(\n        self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor]\n    ):\n        super().__init__()\n\n        if w_bit not in [4]:\n            raise NotImplementedError(\"Only 4-bit are supported for now.\")\n\n        self.in_features = qweight.shape[0]\n        self.out_features = qweight.shape[1] * 32 // w_bit\n\n        self.w_bit = w_bit\n        self.group_size = group_size if group_size != -1 else self.in_features\n        # quick sanity check (ensure alignment)\n        assert self.in_features % self.group_size == 0\n        assert self.out_features % (32 // self.w_bit) == 0\n\n        self.qweight = qweight\n        self.qzeros = qzeros\n        self.scales = scales\n        self.bias = bias\n\n    @torch.no_grad()\n    def forward(self, x):\n        out_shape = x.shape[:-1] + (self.out_features,)\n        out = awq_inference_engine.gemm_forward_cuda(\n            x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8\n        )\n        out = out + self.bias if self.bias is not None else out\n        return out.reshape(out_shape)\n"
  },
  {
    "path": "server/text_generation_server/layers/awq/quantize/ipex.py",
    "content": "from typing import Optional\nimport torch\nimport torch.nn as nn\nimport intel_extension_for_pytorch as ipex\n\n\nclass WQLinear(nn.Module):\n    def __init__(\n        self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor]\n    ):\n        super().__init__()\n\n        if w_bit not in [4]:\n            raise NotImplementedError(\"Only 4-bit are supported for now.\")\n\n        self.in_features = qweight.shape[0]\n        self.out_features = qweight.shape[1] * 32 // w_bit\n\n        self.w_bit = w_bit\n        self.group_size = group_size if group_size != -1 else self.in_features\n        # quick sanity check (ensure alignment)\n        assert self.in_features % self.group_size == 0\n        assert self.out_features % (32 // self.w_bit) == 0\n\n        self.qweight = qweight\n        self.qzeros = qzeros\n        self.scales = scales\n        self.bias = bias\n        self.woq_linear = (\n            ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight(\n                self.qweight,\n                self.scales,\n                self.qzeros,\n                self.in_features,\n                self.out_features,\n                bias=self.bias,\n                group_size=self.group_size,\n                quant_method=ipex.llm.quantization.QuantMethod.AWQ_GEMM,\n                dtype=ipex.llm.quantization.QuantDtype.INT4,\n            )\n        )\n\n    @torch.no_grad()\n    def forward(self, x):\n        out_shape = x.shape[:-1] + (self.out_features,)\n        out = self.woq_linear(x.reshape(-1, x.shape[-1]))\n        return out.reshape(out_shape)\n"
  },
  {
    "path": "server/text_generation_server/layers/bnb.py",
    "content": "from dataclasses import dataclass\n\nimport bitsandbytes as bnb\nimport torch\nfrom bitsandbytes.nn import Int8Params, Params4bit\nfrom text_generation_server.utils.weights import UnquantizedWeight\n\n\n@dataclass\nclass BNBWeight(UnquantizedWeight):\n    weight: torch.Tensor\n\n    def get_linear(self, bias: torch.Tensor):\n        return Linear8bitLt(self.weight, bias, has_fp16_weights=False, threshold=6.0)\n\n\nclass Linear8bitLt(torch.nn.Module):\n    def __init__(\n        self,\n        weight,\n        bias,\n        has_fp16_weights=True,\n        memory_efficient_backward=False,\n        threshold=0.0,\n        index=None,\n    ):\n        super().__init__()\n        assert (\n            not memory_efficient_backward\n        ), \"memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0\"\n        self.state = bnb.MatmulLtState()\n        self.index = index\n\n        # Necessary for stacked layers\n        self.state.threshold = threshold\n        self.state.has_fp16_weights = has_fp16_weights\n        self.state.memory_efficient_backward = memory_efficient_backward\n        if threshold > 0.0 and not has_fp16_weights:\n            self.state.use_pool = True\n\n        self.weight = Int8Params(\n            weight.data,\n            has_fp16_weights=has_fp16_weights,\n            requires_grad=has_fp16_weights,\n        )\n        self.weight.cuda(weight.device)\n        self.bias = bias\n\n    def init_8bit_state(self):\n        self.state.CB = self.weight.CB\n        self.state.SCB = self.weight.SCB\n        self.weight.CB = None\n        self.weight.SCB = None\n\n    def forward(self, x: torch.Tensor):\n        self.state.is_training = self.training\n        if self.weight.CB is not None:\n            self.init_8bit_state()\n\n        # weights are cast automatically as Int8Params, but the bias has to be cast manually\n        if self.bias is not None and self.bias.dtype != 
x.dtype:\n            self.bias.data = self.bias.data.to(x.dtype)\n\n        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)\n\n        if not self.state.has_fp16_weights:\n            if self.state.CB is not None and self.state.CxB is not None:\n                # we converted 8-bit row major to turing/ampere format in the first inference pass\n                # we no longer need the row-major weight\n                del self.state.CB\n                self.weight.data = self.state.CxB\n        return out\n\n\n@dataclass\nclass BNBFP4Weight(UnquantizedWeight):\n    weight: torch.Tensor\n\n    def get_linear(self, bias: torch.Tensor):\n        return Linear4bit(self.weight, bias, quant_type=\"fp4\")\n\n\n@dataclass\nclass BNBNF4Weight(UnquantizedWeight):\n    weight: torch.Tensor\n\n    def get_linear(self, bias: torch.Tensor):\n        return Linear4bit(self.weight, bias, quant_type=\"nf4\")\n\n\nclass Linear4bit(torch.nn.Module):\n    def __init__(self, weight, bias, quant_type):\n        super().__init__()\n        self.weight = Params4bit(\n            weight.data,\n            requires_grad=False,\n            compress_statistics=True,\n            quant_type=quant_type,\n        )\n        self.compute_dtype = None\n        self.weight.cuda(weight.device)\n        self.bias = bias\n\n    def forward(self, x: torch.Tensor):\n        # weights are cast automatically as Int8Params, but the bias has to be cast manually\n        if self.bias is not None and self.bias.dtype != x.dtype:\n            self.bias.data = self.bias.data.to(x.dtype)\n\n        if getattr(self.weight, \"quant_state\", None) is None:\n            print(\n                \"FP4 quantization state not initialized. 
Please call .cuda() or .to(device) on the LinearFP4 layer first.\"\n            )\n        inp_dtype = x.dtype\n        if self.compute_dtype is not None:\n            x = x.to(self.compute_dtype)\n\n        bias = None if self.bias is None else self.bias.to(self.compute_dtype)\n        out = bnb.matmul_4bit(\n            x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state\n        )\n\n        out = out.to(inp_dtype)\n\n        return out\n"
  },
  {
    "path": "server/text_generation_server/layers/compressed_tensors/__init__.py",
    "content": "from .loader import CompressedTensorsLoader\n\n__all__ = [\"CompressedTensorsLoader\"]\n"
  },
  {
    "path": "server/text_generation_server/layers/compressed_tensors/loader.py",
    "content": "from typing import Any, Dict, List, Union\n\nfrom compressed_tensors import QuantizationConfig, QuantizationStatus\nfrom compressed_tensors.config import CompressionFormat\nfrom compressed_tensors.quantization import (\n    QuantizationScheme,\n    QuantizationType,\n    find_name_or_class_matches,\n)\nfrom loguru import logger\nfrom pydantic import ValidationError\nfrom torch import nn\n\nfrom text_generation_server.layers.compressed_tensors.w8an_fp import W8ANFpLoader\nfrom text_generation_server.layers.compressed_tensors.w8a8_int import W8A8IntLoader\nfrom text_generation_server.layers.compressed_tensors.wna16_int_24 import (\n    WNA16Int24Loader,\n)\nfrom text_generation_server.layers.compressed_tensors.wna16_int import WNA16IntLoader\nfrom text_generation_server.utils.log import log_once\nfrom text_generation_server.utils.weights import (\n    DefaultWeightsLoader,\n    UnquantizedWeight,\n    Weights,\n    WeightsLoader,\n)\n\n# compressed-tensors can match modules as quantization targets. However,\n# they need to be objects rather than classes or class names. 
Since we\n# need to match `Linear` targets, make an instance that can be re-used.\n_EMPTY_LINEAR: nn.Module = nn.Linear(0, 0)\n\n\nclass CompressedTensorsLoader(WeightsLoader):\n    \"\"\"Loader for checkpoints stored in the compressed-tensors format.\"\"\"\n\n    def __init__(self, config: Dict[str, Any]):\n        quantization_config_raw = config.get(\"quantization_config\")\n        if quantization_config_raw is None:\n            # `compression_config` was renamed to `quantization_config`; support\n            # retained for backward compatibility.\n            quantization_config_raw = config.get(\"compression_config\")\n        if quantization_config_raw is None:\n            raise ValueError(\n                \"Checkpoint does not have compressed-tensors configuration\"\n            )\n\n        try:\n            quantization_config = QuantizationConfig.model_validate(\n                quantization_config_raw\n            )\n        except ValidationError as e:\n            raise ValueError(\"Cannot parse compressed-tensors configuration\") from e\n\n        if quantization_config.quantization_status not in (\n            QuantizationStatus.COMPRESSED,\n            QuantizationStatus.FROZEN,\n        ):\n            raise ValueError(\n                f\"Model quantization was not finished, status was: {quantization_config.quantization_status}\"\n            )\n\n        self.ignore = (\n            quantization_config.ignore if quantization_config.ignore is not None else []\n        )\n        self.loaders = self._get_target_loaders(quantization_config)\n\n        for target, loader in self.loaders.items():\n            log_once(\n                logger.info,\n                f\"Using {loader} for compressed-tensors target '{target}'\",\n            )\n\n    def get_weights(self, weights: Weights, prefix: str):\n        loader = self._lookup_loader(prefix)\n        return loader.get_weights(weights, prefix)\n\n    def get_weights_col_packed(\n        self,\n 
       weights: \"Weights\",\n        prefix: str,\n        block_sizes: Union[int, List[int]],\n    ):\n        loader = self._lookup_loader(prefix)\n        return loader.get_weights_col_packed(weights, prefix, block_sizes)\n\n    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):\n        loader = self._lookup_loader(prefixes[0])\n        return loader.get_multi_weights_col(weights, prefixes, dim)\n\n    def get_weights_row(self, weights: Weights, prefix: str):\n        loader = self._lookup_loader(prefix)\n        return loader.get_weights_row(weights, prefix)\n\n    def _get_target_loaders(\n        self, quantization_config: QuantizationConfig\n    ) -> Dict[str, WeightsLoader]:\n        \"\"\"\n        A compressed-tensors checkpoint can use different quantizations\n        for different targets. This method returns a dictionary with a\n        loader per target.\n        \"\"\"\n\n        loaders: Dict[str, WeightsLoader] = {}\n\n        format = quantization_config.format\n\n        for group_name, group in quantization_config.config_groups.items():\n            # The group configuration can be a string, but does that ever\n            # happen in a serialized quantization config?\n            assert isinstance(group, QuantizationScheme)\n\n            loader = self._create_loader_for_group(format, group_name, group)\n\n            # A quantized parameter group can have multiple targets, add the\n            # loader for all the targets.\n            for target in group.targets:\n                if target in loaders:\n                    raise ValueError(\n                        f\"Target '{target}' has multiple configured loaders\"\n                    )\n                loaders[target] = loader\n\n        return loaders\n\n    def _create_loader_for_group(\n        self, format: str, group_name: str, group: QuantizationScheme\n    ) -> WeightsLoader:\n        \"\"\"\n        Find and create a loader for the group with the 
given quantization\n        scheme.\n        \"\"\"\n        # NOTE: we ignore group.output_activations because we don't support\n        #       output quantization yet.\n\n        input_activations = group.input_activations\n        weights = group.weights\n        if (\n            format\n            in {\n                CompressionFormat.float_quantized.value,\n                CompressionFormat.naive_quantized.value,\n            }\n            and weights is not None\n            and weights.type == QuantizationType.FLOAT\n            and weights.num_bits == 8\n        ):\n            # FP W8A8 or W8A16.\n            return W8ANFpLoader(input_activations=input_activations, weights=weights)\n        elif (\n            format == CompressionFormat.pack_quantized.value\n            and weights is not None\n            and weights.type == QuantizationType.INT\n            and weights.num_bits in (4, 8)\n        ):\n            # INT W4A16 or W8A16 (GPTQ/AWQ-like).\n            return WNA16IntLoader(weights)\n        elif (\n            format == CompressionFormat.marlin_24.value\n            and weights is not None\n            and weights.type == QuantizationType.INT\n            and weights.num_bits in (4, 8)\n        ):\n            return WNA16Int24Loader(weights)\n        elif (\n            format\n            in {\n                CompressionFormat.int_quantized.value,\n                CompressionFormat.naive_quantized.value,\n            }\n            and weights is not None\n            and weights.type == QuantizationType.INT\n            and weights.num_bits == 8\n        ):\n            return W8A8IntLoader(input_args=input_activations, weight_args=weights)\n        else:\n            raise ValueError(\n                f\"Group '{group_name}' has unsupported compressed-tensors configuration\"\n            )\n\n    def _lookup_loader(self, prefix: str) -> WeightsLoader:\n        \"\"\"\n        Look up the loader to use for a given parameter name 
(prefix).\n        \"\"\"\n\n        if len(find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.ignore)) > 0:\n            return DefaultWeightsLoader(UnquantizedWeight)\n\n        # We currently only handle linear layers, so unconditionally pass\n        # a `Linear` instance.\n        targets = find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.loaders.keys())\n        if len(targets) == 0:\n            raise ValueError(\n                f\"Cannot find compressed-tensors target for prefix: {prefix}\"\n            )\n        return self.loaders[targets[0]]\n"
  },
  {
    "path": "server/text_generation_server/layers/compressed_tensors/w8a8_int.py",
    "content": "from typing import List, Optional, Union, TypeVar\nfrom dataclasses import dataclass\n\nfrom loguru import logger\nimport torch\nfrom compressed_tensors.quantization import QuantizationArgs, QuantizationType\n\nfrom text_generation_server.layers.fp8 import _load_scalar_or_matrix_scale\nfrom text_generation_server.utils.import_utils import SYSTEM\nfrom text_generation_server.utils.kernels import load_kernel\nfrom text_generation_server.utils.log import log_once\nfrom text_generation_server.utils.weights import Weight, Weights, WeightsLoader\n\nif SYSTEM == \"cuda\":\n    quantization = load_kernel(\n        module=\"quantization\", repo_id=\"kernels-community/quantization\"\n    )\nelse:\n    quantization = None\n\n\nclass W8A8IntLoader(WeightsLoader):\n    \"\"\"\n    Loader for w8a8 integer compressed-tensors parameters.\n    \"\"\"\n\n    def __init__(\n        self,\n        *,\n        input_args: Optional[QuantizationArgs],\n        weight_args: QuantizationArgs,\n    ):\n        if weight_args.type != QuantizationType.INT and weight_args.num_bits != 8:\n            raise ValueError(\n                f\"{type(self).__name__} only supports w8a8 int checkpoints\"\n            )\n\n        if not weight_args.symmetric:\n            raise ValueError(\"Checkpoints with asymmetric weights are not supported\")\n\n        self.load_weight_scale = not weight_args.dynamic\n\n        if input_args is not None:\n            self.input_symmetric = input_args.symmetric\n\n            if not input_args.dynamic:\n                log_once(\n                    logger.warning,\n                    \"Forcing dynamic input quantization for compressed_tensors w8a8 int checkpoint (for better accuracy).\",\n                )\n        else:\n            self.input_symmetric = True\n\n    def __str__(self) -> str:\n        def scale_to_str(scale):\n            return \"static\" if scale else \"dynamic\"\n\n        def symmetric_to_str(symmetric):\n            return 
\"symmetric\" if symmetric else \"asymmetric\"\n\n        return f\"{self.__class__.__name__} (w8a8 int, input: dynamic/{symmetric_to_str(self.input_symmetric)}, weight: {scale_to_str(self.load_weight_scale)}/symmetric))\"\n\n    def get_weights(self, weights: \"Weights\", prefix: str):\n        w = weights.get_tensor(f\"{prefix}.weight\", to_dtype=False)\n\n        weight_scale = None\n        if self.load_weight_scale:\n            weight_scale = weights.get_tensor(\n                f\"{prefix}.weight_scale\", to_dtype=False\n            ).reshape(-1)\n\n        return Int8Weight(\n            input_symmetric=self.input_symmetric,\n            weight=w,\n            weight_scale=weight_scale,\n        )\n\n    def get_weights_col_packed(\n        self,\n        weights: Weights,\n        prefix: str,\n        block_sizes: Union[int, List[int]],\n    ):\n        w = weights.get_packed_sharded(\n            f\"{prefix}.weight\", dim=0, block_sizes=block_sizes, to_dtype=False\n        )\n\n        weight_scale = None\n        if self.load_weight_scale:\n            weight_scale = weights.get_tensor(f\"{prefix}.weight_scale\", to_dtype=False)\n            if weight_scale.numel() > 1:\n                weight_scale = weights.get_packed_sharded(\n                    f\"{prefix}.weight_scale\",\n                    dim=0,\n                    block_sizes=block_sizes,\n                    to_dtype=False,\n                )\n            weight_scale = weight_scale.reshape(-1)\n\n        return Int8Weight(\n            input_symmetric=self.input_symmetric,\n            weight=w,\n            weight_scale=weight_scale,\n        )\n\n    def get_multi_weights_col(self, weights: \"Weights\", prefixes: List[str], dim: int):\n        w = [\n            weights.get_sharded(f\"{p}.weight\", dim=0, to_dtype=False) for p in prefixes\n        ]\n        shapes = [x.shape for x in w]\n\n        w = torch.cat(w, dim=dim)\n\n        weight_scale = None\n        if 
self.load_weight_scale:\n            weight_scale = [\n                _load_scalar_or_matrix_scale(weights, f\"{p}.weight_scale\", shape)\n                for p, shape in zip(prefixes, shapes)\n            ]\n            weight_scale = torch.cat(weight_scale, dim=0).reshape(-1, 1)\n\n        return Int8Weight(\n            input_symmetric=self.input_symmetric,\n            weight=w,\n            weight_scale=weight_scale,\n        )\n\n    def get_weights_row(self, weights: \"Weights\", prefix: str):\n        w = weights.get_sharded(f\"{prefix}.weight\", dim=1, to_dtype=False)\n\n        weight_scale = None\n        if self.load_weight_scale:\n            weight_scale = weights.get_tensor(\n                f\"{prefix}.weight_scale\", to_dtype=False\n            ).reshape(-1)\n\n        return Int8Weight(\n            input_symmetric=self.input_symmetric,\n            weight=w,\n            weight_scale=weight_scale,\n        )\n\n\nOtherT = TypeVar(\"OtherT\")\n\n\ndef _get_tensor_or_else(\n    weights: Weights, prefix: str, other: OtherT\n) -> Union[torch.Tensor, OtherT]:\n    # Even if a checkpoint uses e.g. 
zero-points, they can be elided:\n    # https://github.com/neuralmagic/compressed-tensors/blob/db6ccb25b265e8370813ecab5e95714a6728b5a6/src/compressed_tensors/compressors/quantized_compressors/base.py#L105\n    if weights.has_tensor(prefix):\n        return weights.get_tensor(prefix, to_dtype=False)\n    else:\n        return other\n\n\n@dataclass\nclass Int8Weight(Weight):\n    input_symmetric: bool\n    weight: torch.Tensor\n    weight_scale: Optional[torch.Tensor]\n\n    def get_linear(self, bias: torch.Tensor):\n        if self.weight_scale is None:\n            assert quantization is not None\n            qweight, weight_scale, _ = quantization.scaled_int8_quant(self.weight)\n            return W8A8IntLinear(\n                bias=bias,\n                input_symmetric=self.input_symmetric,\n                weight=qweight,\n                weight_scale=weight_scale,\n            )\n        else:\n            return W8A8IntLinear(\n                bias=bias,\n                input_symmetric=self.input_symmetric,\n                weight=self.weight,\n                weight_scale=self.weight_scale,\n            )\n\n\nclass W8A8IntLinear(torch.nn.Module):\n    def __init__(\n        self,\n        *,\n        bias: Optional[torch.Tensor],\n        input_symmetric: bool,\n        weight: torch.Tensor,\n        weight_scale: torch.Tensor,\n    ):\n        super().__init__()\n\n        weight_scale = weight_scale.to(torch.float32)\n\n        self.bias = bias\n        self.input_symmetric = input_symmetric\n        # cutlass kernels require transposed weights.\n        self.weight = weight.t()\n        self.weight_scale = weight_scale\n\n        if input_symmetric:\n            self.zero_point_adj = None\n        else:\n            # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md#scaledepilogueazp\n            self.zero_point_adj = self.weight.sum(\n                dim=0, keepdim=True, 
dtype=torch.int32\n            )\n\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        assert quantization is not None\n\n        qinput, input_scale, input_zero_point = quantization.scaled_int8_quant(\n            input=input,\n            scale=None,\n            azp=None,\n            symmetric=self.input_symmetric,\n        )\n\n        if self.input_symmetric:\n            return quantization.cutlass_scaled_mm(\n                a=qinput,\n                b=self.weight,\n                scale_a=input_scale,\n                scale_b=self.weight_scale,\n                out_dtype=input.dtype,\n                bias=self.bias,\n            )\n        else:\n            assert (\n                self.zero_point_adj is not None\n                and input_scale is not None\n                and (self.input_symmetric or input_zero_point is not None)\n            )\n\n            return quantization.cutlass_scaled_mm_azp(\n                a=qinput,\n                b=self.weight,\n                scale_a=input_scale,\n                scale_b=self.weight_scale,\n                out_dtype=input.dtype,\n                azp_adj=self.zero_point_adj,\n                azp=input_zero_point,\n                bias=self.bias,\n            )\n"
  },
  {
    "path": "server/text_generation_server/layers/compressed_tensors/w8an_fp.py",
    "content": "from typing import List, Optional, Union\n\nimport torch\nfrom compressed_tensors.quantization import QuantizationArgs, QuantizationType\n\nfrom text_generation_server.layers.fp8 import (\n    Fp8Weight,\n    _load_scalar_or_matrix_scale,\n    requantize_with_max_scale,\n    normalize_e4m3fn_to_native_float8,\n)\nfrom text_generation_server.utils.weights import Weights, WeightsLoader\nfrom text_generation_server.utils.import_utils import SYSTEM\n\n\nclass W8ANFpLoader(WeightsLoader):\n    \"\"\"\n    Loader for W8A8/W8A16 FP compressed-tensors parameters.\n    \"\"\"\n\n    def __init__(\n        self,\n        *,\n        input_activations: Optional[QuantizationArgs],\n        weights: QuantizationArgs,\n    ):\n        assert weights.type == QuantizationType.FLOAT and weights.num_bits == 8\n\n        # We ignore the `strategy` option which sets the scales to be\n        # per-tensor, per-channel or per-token. What scales are supported\n        # is dependent on the kernels used (e.g. 
cutlass can do tokenwise,\n        # Torch cannot, and FP8-Marlin does not quantize inputs at all).\n        # So, instead we try to use the best-possible configuration.\n\n        self.load_weight_scale = not weights.dynamic\n        self.load_input_scale = (\n            input_activations is not None and not input_activations.dynamic\n        )\n        self.force_w8a16 = (\n            input_activations is not None and input_activations.num_bits == 16\n        )\n\n    def __str__(self) -> str:\n        def scale_to_str(scale):\n            return \"static\" if scale else \"dynamic\"\n\n        quantization_type = f\"W8A{16 if self.force_w8a16 else 8}\"\n\n        return f\"{self.__class__.__name__} ({quantization_type}, weight: {scale_to_str(self.load_weight_scale)}, input: {scale_to_str(self.load_input_scale)})\"\n\n    def get_weights(self, weights: \"Weights\", prefix: str):\n        w = weights.get_tensor(f\"{prefix}.weight\")\n\n        weight_scale = None\n        if self.load_weight_scale:\n            weight_scale = weights.get_tensor(f\"{prefix}.weight_scale\", to_dtype=False)\n\n            if SYSTEM == \"cuda\":\n                weight_scale = weight_scale.reshape(-1).expand(w.shape[0])\n\n        input_scale = None\n        if self.load_input_scale:\n            input_scale = weights.get_tensor(\n                f\"{prefix}.input_scale\", to_dtype=False\n            ).reshape(-1)\n\n        return Fp8Weight(\n            weight=w,\n            weight_scale=weight_scale,\n            input_scale=input_scale,\n            dtype=weights.dtype,\n            force_w8a16=self.force_w8a16,\n        )\n\n    def get_weights_col_packed(\n        self,\n        weights: Weights,\n        prefix: str,\n        block_sizes: Union[int, List[int]],\n    ):\n        w = weights.get_packed_sharded(\n            f\"{prefix}.weight\", dim=0, block_sizes=block_sizes\n        )\n\n        weight_scale = None\n        if self.load_weight_scale:\n            weight_scale 
= weights.get_tensor(f\"{prefix}.weight_scale\", to_dtype=False)\n            if weight_scale.numel() > 1:\n                weight_scale = weights.get_packed_sharded(\n                    f\"{prefix}.weight_scale\",\n                    dim=0,\n                    block_sizes=block_sizes,\n                    to_dtype=False,\n                )\n            if SYSTEM == \"cuda\":\n                weight_scale = weight_scale.reshape(-1).expand(w.shape[0])\n\n        input_scale = None\n        if self.load_input_scale:\n            input_scale = weights.get_tensor(f\"{prefix}.input_scale\", to_dtype=False)\n            if input_scale.numel() > 1:\n                input_scale = weights.get_packed_sharded(\n                    f\"{prefix}.input_scale\",\n                    dim=0,\n                    block_sizes=block_sizes,\n                    to_dtype=False,\n                )\n            input_scale = input_scale.reshape(-1).max()\n\n        return Fp8Weight(\n            weight=w,\n            weight_scale=weight_scale,\n            input_scale=input_scale,\n            dtype=weights.dtype,\n            force_w8a16=self.force_w8a16,\n        )\n\n    def get_multi_weights_col(self, weights: \"Weights\", prefixes: List[str], dim: int):\n        # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet\n        w = [\n            weights.get_sharded(f\"{p}.weight\", dim=0, to_device=False) for p in prefixes\n        ]\n        shapes = [x.shape for x in w]\n\n        # Concat then send to the device\n        w = torch.cat(w, dim=dim).to(weights.device)\n\n        weight_scale = None\n        if self.load_weight_scale:\n            weight_scale = [\n                _load_scalar_or_matrix_scale(weights, f\"{p}.weight_scale\", shape)\n                for p, shape in zip(prefixes, shapes)\n            ]\n            weight_scale = torch.cat(weight_scale, dim=0).reshape(-1)\n\n        input_scale = None\n        if self.load_input_scale:\n 
           input_scale = [\n                _load_scalar_or_matrix_scale(weights, f\"{p}.input_scale\", shape)\n                for p, shape in zip(prefixes, shapes)\n                if weights.has_tensor(f\"{p}.input_scale\")\n            ]\n            assert len(input_scale) == 0 or len(input_scale) == len(prefixes)\n            input_scale = (\n                torch.cat(input_scale, dim=0).reshape(-1).max()\n                if len(input_scale) != 0\n                else None\n            )\n\n        if self.load_weight_scale and SYSTEM == \"rocm\":\n            w, weight_scale, input_scale = normalize_e4m3fn_to_native_float8(\n                w, weight_scale, input_scale\n            )\n\n            if weight_scale.numel() == len(prefixes):\n                logical_widths = [x[0] for x in shapes]\n                w, weight_scale = requantize_with_max_scale(\n                    w, weight_scale.to(weights.device), logical_widths, weights.dtype\n                )\n\n        return Fp8Weight(\n            weight=w,\n            weight_scale=weight_scale,\n            input_scale=input_scale,\n            dtype=weights.dtype,\n            force_w8a16=self.force_w8a16,\n        )\n\n    def get_weights_row(self, weights: \"Weights\", prefix: str):\n        w = weights.get_sharded(f\"{prefix}.weight\", dim=1)\n        weight_scale = None\n        if self.load_weight_scale:\n            weight_scale = weights.get_tensor(f\"{prefix}.weight_scale\", to_dtype=False)\n\n            if SYSTEM == \"cuda\":\n                weight_scale = weight_scale.reshape(-1).expand(w.shape[0])\n\n        input_scale = None\n        if self.load_input_scale:\n            input_scale = weights.get_tensor(\n                f\"{prefix}.input_scale\", to_dtype=False\n            ).reshape(-1)\n\n        return Fp8Weight(\n            weight=w,\n            weight_scale=weight_scale,\n            input_scale=input_scale,\n            dtype=weights.dtype,\n            
force_w8a16=self.force_w8a16,\n        )\n"
  },
  {
    "path": "server/text_generation_server/layers/compressed_tensors/wna16_int.py",
    "content": "from typing import List, Union\n\nimport torch\nfrom compressed_tensors.quantization import ActivationOrdering, QuantizationArgs\nfrom loguru import logger\n\nfrom text_generation_server.layers.marlin.gptq import repack_gptq_for_marlin\nfrom text_generation_server.utils.log import log_once\nfrom text_generation_server.utils.weights import Weights, WeightsLoader\n\n\nclass WNA16IntLoader(WeightsLoader):\n    \"\"\"\n    Loader for W4A16/W8A16 INT compressed-tensors parameters.\n    \"\"\"\n\n    def __init__(self, weights: QuantizationArgs):\n        self.weights = weights\n        self.desc_act = self.weights.actorder == ActivationOrdering.GROUP\n        self.groupsize = (\n            -1 if self.weights.group_size is None else self.weights.group_size\n        )\n\n    def __str__(self) -> str:\n        quantization_type = f\"W{self.weights.num_bits}A16\"\n\n        return f\"{self.__class__.__name__} ({quantization_type})\"\n\n    def get_weights(self, weights: Weights, prefix: str):\n        log_once(logger.info, \"Using GPTQ-Marlin kernels\")\n        try:\n            weight_packed = weights.get_tensor(f\"{prefix}.weight_packed\").t()\n        except RuntimeError:\n            raise RuntimeError(\n                f\"Cannot load w{self.weights.num_bits}a16 weight, make sure the model is already quantized\"\n            )\n\n        zero_point = None\n        if not self.weights.symmetric:\n            zero_point = weights.get_tensor(f\"{prefix}.weight_zero_point\").t()\n\n        g_idx = None\n        if self.desc_act:\n            g_idx = weights.get_tensor(f\"{prefix}.weight_g_idx\")\n\n        scales = weights.get_tensor(f\"{prefix}.weight_scale\").t()\n\n        return repack_gptq_for_marlin(\n            qweight=weight_packed.contiguous(),\n            scales=scales,\n            qzeros=zero_point,\n            g_idx=g_idx,\n            bits=self.weights.num_bits,\n            desc_act=self.desc_act,\n            groupsize=self.groupsize,\n 
           quant_method=\"compressed-tensors\",\n            sym=self.weights.symmetric,\n            sharded_infeatures=False,\n        )\n\n    def get_weights_col_packed(\n        self,\n        weights: Weights,\n        prefix: str,\n        block_sizes: Union[int, List[int]],\n    ):\n        try:\n            weight_packed = weights.get_packed_sharded(\n                f\"{prefix}.weight_packed\", dim=0, block_sizes=block_sizes\n            ).t()\n        except RuntimeError:\n            raise RuntimeError(\n                f\"Cannot load w{self.weights.num_bits}a16 weight, make sure the model is already quantized\"\n            )\n        scales = weights.get_packed_sharded(\n            f\"{prefix}.weight_scale\", dim=0, block_sizes=block_sizes\n        ).t()\n        scales = scales.to(dtype=weights.dtype)\n\n        zero_point = None\n        if not self.weights.symmetric:\n            zero_point = weights.get_packed_sharded(\n                f\"{prefix}.qzeros\", dim=0, block_sizes=block_sizes\n            ).t()\n\n        g_idx = None\n        if self.desc_act:\n            g_idx = weights.get_tensor(f\"{prefix}.g_idx\")\n\n        return repack_gptq_for_marlin(\n            qweight=weight_packed.contiguous(),\n            scales=scales,\n            qzeros=zero_point,\n            g_idx=g_idx,\n            bits=self.weights.num_bits,\n            desc_act=self.desc_act,\n            groupsize=self.groupsize,\n            quant_method=\"compressed-tensors\",\n            sym=self.weights.symmetric,\n            sharded_infeatures=False,\n        )\n\n    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):\n        try:\n            weight_packed = torch.cat(\n                [\n                    weights.get_sharded(f\"{p}.weight_packed\", dim=0).t()\n                    for p in prefixes\n                ],\n                dim=1,\n            )\n        except RuntimeError:\n            raise RuntimeError(\n            
    f\"Cannot load w{self.weights.num_bits}a16 weight, make sure the model is already quantized\"\n            )\n\n        scales = torch.cat(\n            [weights.get_sharded(f\"{p}.weight_scale\", dim=0).t() for p in prefixes],\n            dim=1,\n        )\n\n        zero_point = None\n        if not self.weights.symmetric:\n            zero_point = torch.cat(\n                [weights.get_sharded(f\"{p}.qzeros\", dim=0).t() for p in prefixes], dim=1\n            ).t()\n\n        g_idx = None\n        if self.desc_act:\n            w = [weights.get_tensor(f\"{p}.g_idx\") for p in prefixes]\n            for w2 in w[1:]:\n                torch.testing.assert_close(w2, w[0])\n            g_idx = w[0]\n\n        return repack_gptq_for_marlin(\n            qweight=weight_packed.contiguous(),\n            scales=scales,\n            qzeros=zero_point,\n            g_idx=g_idx,\n            bits=self.weights.num_bits,\n            desc_act=self.desc_act,\n            groupsize=self.groupsize,\n            quant_method=\"compressed-tensors\",\n            sym=self.weights.symmetric,\n            sharded_infeatures=False,\n        )\n\n    def get_weights_row(self, weights: Weights, prefix: str):\n        log_once(logger.info, \"Using GPTQ-Marlin kernels\")\n        try:\n            weight_packed = weights.get_sharded(f\"{prefix}.weight_packed\", dim=1).t()\n        except RuntimeError:\n            raise RuntimeError(\n                f\"Cannot load w{self.weights.num_bits}a16 weight, make sure the model is already quantized\"\n            )\n\n        zero_point = None\n        if not self.weights.symmetric:\n            if self.desc_act or self.groupsize == -1:\n                zero_point = weights.get_tensor(f\"{prefix}.weight_zero_point\").t()\n            else:\n                zero_point = weights.get_sharded(\n                    f\"{prefix}.weight_zero_point\", dim=1\n                ).t()\n\n        g_idx = None\n        if self.desc_act:\n            g_idx = 
weights.get_sharded(f\"{prefix}.g_idx\", dim=0)\n\n        if self.desc_act or self.groupsize == -1:\n            scales = weights.get_tensor(f\"{prefix}.weight_scale\").t()\n        else:\n            scales = weights.get_sharded(f\"{prefix}.weight_scale\", dim=1).t()\n\n        sharded_in_features = weights.process_group.size() > 1\n\n        return repack_gptq_for_marlin(\n            qweight=weight_packed.contiguous(),\n            scales=scales,\n            qzeros=zero_point,\n            g_idx=g_idx,\n            bits=self.weights.num_bits,\n            desc_act=self.desc_act,\n            groupsize=self.groupsize,\n            quant_method=\"compressed-tensors\",\n            sym=self.weights.symmetric,\n            sharded_infeatures=sharded_in_features,\n        )\n"
  },
  {
    "path": "server/text_generation_server/layers/compressed_tensors/wna16_int_24.py",
    "content": "from typing import List, Union\n\nimport torch\n\n\nfrom compressed_tensors.quantization import QuantizationArgs, QuantizationType\nfrom text_generation_server.layers.marlin.marlin import GPTQMarlin24Weight\nfrom text_generation_server.utils.weights import Weights, WeightsLoader\n\n\nclass WNA16Int24Loader(WeightsLoader):\n    \"\"\"\n    Loader for W4A16/W8A16 INT 2:4 sparsity compressed-tensors checkpoints.\n    \"\"\"\n\n    def __init__(self, weight_args: QuantizationArgs):\n        super().__init__()\n\n        if weight_args.type != QuantizationType.INT:\n            raise ValueError(\n                f\"{type(self).__name__} only supports wNa16 int checkpoints\"\n            )\n\n        if weight_args.strategy == \"group\" and weight_args.group_size is None:\n            raise ValueError(\"`group_size` must be set when `strategy` is `group`\")\n\n        self.bits = weight_args.num_bits\n        self.group_size = weight_args.group_size\n\n    def __str__(self) -> str:\n        quantization_type = f\"W{self.bits}A16 2:4 sparsity\"\n\n        return f\"{self.__class__.__name__} ({quantization_type})\"\n\n    def get_weights(self, weights: Weights, prefix: str):\n        \"\"\"\n        Get weights at the given prefix and apply without tensor parallelism.\n        \"\"\"\n        weight_packed = weights.get_tensor(f\"{prefix}.weight_packed\")\n        meta = weights.get_tensor(f\"{prefix}.meta\")\n        scale_packed = weights.get_tensor(f\"{prefix}.scale_packed\")\n        return GPTQMarlin24Weight(\n            weight_packed=weight_packed,\n            meta=meta,\n            scale_packed=scale_packed,\n            bits=self.bits,\n        )\n\n    def get_weights_col_packed(\n        self,\n        weights: Weights,\n        prefix: str,\n        block_sizes: Union[int, List[int]],\n    ):\n        weight_packed = weights.get_packed_sharded(\n            f\"{prefix}.weight_packed\", dim=1, block_sizes=block_sizes\n        )\n        meta = 
weights.get_packed_sharded(\n            f\"{prefix}.meta\", dim=1, block_sizes=block_sizes\n        )\n        scale_packed = weights.get_packed_sharded(\n            f\"{prefix}.scale_packed\", dim=1, block_sizes=block_sizes\n        )\n        return GPTQMarlin24Weight(\n            weight_packed=weight_packed,\n            meta=meta,\n            scale_packed=scale_packed,\n            bits=self.bits,\n        )\n\n    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):\n        weight_packed = torch.cat(\n            [weights.get_sharded(f\"{p}.weight_packed\", dim=1) for p in prefixes], dim=1\n        )\n        meta = torch.cat(\n            [weights.get_sharded(f\"{p}.meta\", dim=1) for p in prefixes], dim=1\n        )\n        scale_packed = torch.cat(\n            [weights.get_sharded(f\"{p}.scale_packed\", dim=1) for p in prefixes], dim=1\n        )\n        return GPTQMarlin24Weight(\n            weight_packed=weight_packed,\n            meta=meta,\n            scale_packed=scale_packed,\n            bits=self.bits,\n        )\n\n    def get_weights_row(self, weights: Weights, prefix: str):\n        weight_packed = weights.get_sharded(f\"{prefix}.weight_packed\", dim=0)\n        meta = weights.get_sharded(f\"{prefix}.meta\", dim=0)\n        if self.group_size is None:\n            scale_packed = weights.get_tensor(f\"{prefix}.scale_packed\")\n        else:\n            scale_packed = weights.get_sharded(f\"{prefix}.scale_packed\", dim=0)\n\n        return GPTQMarlin24Weight(\n            weight_packed=weight_packed,\n            meta=meta,\n            scale_packed=scale_packed,\n            bits=self.bits,\n        )\n"
  },
  {
    "path": "server/text_generation_server/layers/conv.py",
    "content": "from accelerate import init_empty_weights\nimport torch\n\n\n@classmethod\ndef load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):\n    weight = weights.get_tensor(f\"{prefix}.weight\")\n    bias = weights.get_tensor(f\"{prefix}.bias\")\n    with init_empty_weights():\n        conv2d = cls(\n            in_channels=in_channels,\n            out_channels=out_channels,\n            kernel_size=kernel_size,\n            stride=stride,\n        )\n\n    conv2d.weight = torch.nn.Parameter(weight)\n    conv2d.bias = torch.nn.Parameter(bias)\n    return conv2d\n\n\n@classmethod\ndef load_conv2d_no_bias(\n    cls, prefix, weights, in_channels, out_channels, kernel_size, stride\n):\n    weight = weights.get_tensor(f\"{prefix}.weight\")\n    with init_empty_weights():\n        conv2d = cls(\n            in_channels=in_channels,\n            out_channels=out_channels,\n            kernel_size=kernel_size,\n            stride=stride,\n        )\n\n    conv2d.weight = torch.nn.Parameter(weight)\n    conv2d.bias = None\n    return conv2d\n\n\ntorch.nn.Conv2d.load = load_conv2d\ntorch.nn.Conv2d.load_no_bias = load_conv2d_no_bias\n"
  },
  {
    "path": "server/text_generation_server/layers/eetq.py",
    "content": "from dataclasses import dataclass\n\nimport torch\nfrom text_generation_server.utils.kernels import load_kernel\nfrom text_generation_server.utils.weights import UnquantizedWeight\n\nquantization_eetq = load_kernel(\n    module=\"quantization_eetq\", repo_id=\"kernels-community/quantization-eetq\"\n)\n\n\n@dataclass\nclass EETQWeight(UnquantizedWeight):\n    weight: torch.Tensor\n\n    def get_linear(self, bias: torch.Tensor):\n        try:\n            from text_generation_server.layers.eetq import EETQLinear\n\n            return EETQLinear(self.weight, bias)\n        except ImportError:\n            raise ImportError(\n                \"Please install EETQ from https://github.com/NetEase-FuXi/EETQ\"\n            )\n\n\nclass EETQLinear(torch.nn.Module):\n    def __init__(\n        self,\n        weight,\n        bias,\n    ) -> None:\n        super().__init__()\n        device = weight.device\n        if weight.dtype != torch.float16:\n            weight = weight.to(dtype=torch.float16)\n        weight = torch.t(weight).contiguous().cpu()\n        weight, scale = quantization_eetq.quant_weights(weight, torch.int8, False)\n\n        self.weight = weight.cuda(device)\n        self.scale = scale.cuda(device)\n        self.bias = bias.cuda(device) if bias is not None else None\n\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        output = quantization_eetq.w8_a16_gemm(input, self.weight, self.scale)\n        output = output + self.bias if self.bias is not None else output\n        return output\n"
  },
  {
    "path": "server/text_generation_server/layers/exl2.py",
    "content": "from dataclasses import dataclass\nfrom typing import List, Union\n\nimport torch\nfrom text_generation_server.utils.weights import Weight, Weights, WeightsLoader\n\n\n@dataclass\nclass Exl2Weight(Weight):\n    \"\"\"\n    Exllama2 exl2 quantized weights.\n    \"\"\"\n\n    q_weight: torch.Tensor\n    q_scale: torch.Tensor\n    q_invperm: torch.Tensor\n    q_scale_max: torch.Tensor\n    q_groups: torch.Tensor\n\n    def __post_init__(self):\n        self.q_scale_max /= 256\n        self.q_invperm = self.q_invperm.short()\n\n    @property\n    def device(self) -> torch.device:\n        return self.q_weight.device\n\n    def get_linear(self, bias: torch.Tensor):\n        from text_generation_server.layers.gptq import ExllamaQuantLinear\n\n        return ExllamaQuantLinear(self, bias)\n\n\nclass Exl2WeightsLoader(WeightsLoader):\n    \"\"\"Loader for exl2-quantized weights.\"\"\"\n\n    def get_weights(self, weights: \"Weights\", prefix: str):\n        \"\"\"\n        Get weights at the given prefix and apply without tensor paralllism.\n        \"\"\"\n        try:\n            q_weight = weights.get_tensor(f\"{prefix}.q_weight\")\n        except RuntimeError:\n            raise RuntimeError(\n                \"Cannot load `exl2`-quantized weight, make sure the model is already quantized.\"\n            )\n\n        q_scale = weights.get_tensor(f\"{prefix}.q_scale\")\n        q_invperm = weights.get_tensor(f\"{prefix}.q_invperm\")\n        q_scale_max = weights.get_tensor(f\"{prefix}.q_scale_max\")\n        q_groups = weights.get_tensor(f\"{prefix}.q_groups\")\n\n        return Exl2Weight(\n            q_weight=q_weight,\n            q_scale=q_scale,\n            q_invperm=q_invperm,\n            q_scale_max=q_scale_max,\n            q_groups=q_groups,\n        )\n\n    def get_weights_col_packed(\n        self,\n        weights: Weights,\n        prefix: str,\n        block_sizes: Union[int, List[int]],\n    ):\n        raise 
RuntimeError(\"Column-packed weights are not supported for exl\")\n\n    def get_weights_col(self, weights: Weights, prefix: str):\n        # Sharding is not yet supported, so we return the weights as-is.\n        return self.get_weights(weights, prefix)\n\n    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):\n        raise ValueError(\"get_multi_weights_col is not supported for exl2\")\n\n    def get_weights_row(self, weights: Weights, prefix: str):\n        # Sharding is not yet supported, so we return the weights as-is.\n        return self.get_weights(weights, prefix)\n"
  },
  {
    "path": "server/text_generation_server/layers/fp8.py",
    "content": "from dataclasses import dataclass\nimport os\nfrom typing import Optional, Tuple, Type, Union, List\n\nimport torch\nfrom loguru import logger\n\nfrom text_generation_server.utils.import_utils import SYSTEM\nfrom text_generation_server.utils.kernels import load_kernel\nfrom text_generation_server.utils.weights import (\n    Weight,\n    WeightsLoader,\n    UnquantizedWeight,\n    Weights,\n)\nfrom text_generation_server.utils.log import log_once\n\nif SYSTEM == \"cuda\":\n    quantization = load_kernel(\n        module=\"quantization\", repo_id=\"kernels-community/quantization\"\n    )\nelse:\n    quantization = None\n\ntry:\n    from moe_kernels.fp8_utils import w8a8_block_fp8_matmul, per_token_group_quant_fp8\nexcept ImportError:\n    w8a8_block_fp8_matmul = None\n    per_token_group_quant_fp8 = None\n\nquant_dtype: torch.dtype = (\n    torch.float8_e4m3fnuz if SYSTEM == \"rocm\" else torch.float8_e4m3fn\n)\n\nif SYSTEM == \"cuda\" and quantization is not None:\n    major, minor = torch.cuda.get_device_capability()\n    CUTLASS_FP8_AVAILABLE = quantization.cutlass_scaled_mm_supports_fp8(\n        major * 10 + minor\n    )\nelse:\n    CUTLASS_FP8_AVAILABLE = False\n\n\ndef get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:\n    \"\"\"\n    Return an FP8 linear `Module` that is compatible with the current system.\n    \"\"\"\n\n    if SYSTEM == \"cuda\":\n        major, _ = torch.cuda.get_device_capability()\n        # Marlin is W8A16, use it when:\n        #\n        # - On capability 8.x where x < 8: W8A8 FP8 GEMM is not supported.\n        # - On capability 8.9: W8A8 FP8 GEMM is supported, but Marlin-FP8 is faster.\n        # - On capability 9.x when force_w8a16: cutlass kernels do not support W8A16.\n        if (major == 8 or (major == 9 and force_w8a16)) and os.getenv(\n            \"USE_CUTLASS_W8A8\", \"0\"\n        ) != \"1\":\n            # NOTE: Capability 8.9 is supported by cutlass kernels, but FP8-Marlin\n            #  
     gives better decoding throughput on L4 and L40.\n            from text_generation_server.layers.marlin import GPTQMarlinFP8Linear\n\n            if major == 8 and minor == 9:\n                log_once(\n                    logger.info,\n                    \"GPU supports FP8, but using Marlin FP8 kernel for better performance\",\n                )\n            else:\n                log_once(\n                    logger.info, \"GPU does not support FP8, using Marlin FP8 kernel\"\n                )\n\n            return GPTQMarlinFP8Linear\n\n    # On other systems let Torch decide if the hardware supports FP8.\n    return Fp8Linear\n\n\ndef normalize_e4m3fn_to_native_float8(\n    weight: torch.Tensor,\n    weight_scale: torch.Tensor,\n    input_scale: Optional[torch.Tensor] = None,\n) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:\n    if weight.dtype == torch.float8_e4m3fn and SYSTEM == \"rocm\":\n        # The bits pattern 10000000(-128) represents zero in e4m3fn\n        # but NaN in e4m3fnuz. 
So here we set it to 0.\n        # https://onnx.ai/onnx/technical/float8.html\n        weight_as_int8 = weight.view(torch.int8)\n        ROCM_FP8_NAN_AS_INT = -128\n        weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0\n        weight = weight_as_int8.view(torch.float8_e4m3fnuz)\n\n        # For the same bits representation, e4m3fnuz value is half of\n        # the e4m3fn value, so we should double the scaling factor to\n        # get the same dequantized value.\n        # https://onnx.ai/onnx/technical/float8.html\n        weight_scale = weight_scale * 2.0\n        if input_scale is not None:\n            input_scale = input_scale * 2.0\n    return weight, weight_scale, input_scale\n\n\ndef per_tensor_dequantize(\n    tensor: torch.Tensor,\n    inv_scale: Union[float, torch.Tensor],\n    dtype: torch.dtype = torch.float16,\n) -> torch.Tensor:\n    fake_qweight = tensor.to(dtype)\n    dq_weight = fake_qweight * inv_scale\n    return dq_weight\n\n\ndef requantize_with_max_scale(\n    weight: torch.Tensor,\n    weight_scale: torch.Tensor,\n    logical_widths: int,\n    dtype: torch.dtype,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    # Max scale to be used for requanitzation.\n    max_w_scale = weight_scale.max().float()\n\n    start = 0\n    for idx, logical_width in enumerate(logical_widths):\n        end = start + logical_width\n        weight_dq = per_tensor_dequantize(\n            weight[start:end, :], weight_scale[idx], dtype\n        )\n        weight[start:end, :], max_w_scale_normalized = fp8_quantize(\n            weight_dq, max_w_scale\n        )\n        start = end\n\n    return weight, max_w_scale_normalized\n\n\ndef fp8_quantize(\n    weight: torch.Tensor,\n    scale: Optional[torch.Tensor] = None,\n    scale_upper_bound: Optional[torch.Tensor] = None,\n    qdtype: torch.dtype = torch.float8_e4m3fn,\n    scalar: bool = False,\n):\n    \"\"\"\n    This function returns a reciprocal of the scale, so that a tensor can be unscaled\n    by 
multiplying it with the returned scale. If a scale is given through the `scale`\n    argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can\n    be used without modification).\n    \"\"\"\n    if quantization is not None:\n        shape = weight.shape\n        qweight, scale = quantization.scaled_fp8_quant(\n            weight.reshape(-1, shape[-1]),\n            scale=scale,\n            scale_ub=scale_upper_bound,\n            # TODO: don't do this when we have to use the Torch kernel.\n            use_per_token_if_dynamic=not scalar,\n        )\n\n        return qweight.reshape(shape), scale\n\n    finfo = torch.finfo(qdtype)\n\n    if scale is None:\n        # Calculate the scale as dtype max divided by absmax\n        scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)\n        # scale and clamp the tensor to bring it to\n        # the representative range of float8 data type\n        # (as default cast is unsaturated)\n        qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)\n        scale = scale.float().reciprocal()\n    else:\n        if SYSTEM == \"rocm\":\n            scale = scale / 2.0\n        # Use reciprocal to avoid more expensive division.\n        qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)\n\n    # Return both float8 data and the inverse scale (as float),\n    # as both required as inputs to torch._scaled_mm\n    qweight = qweight.to(qdtype)\n\n    if SYSTEM == \"rocm\":\n        qweight, scale, _ = normalize_e4m3fn_to_native_float8(qweight, scale)\n\n    return qweight, scale\n\n\nclass HybridFP8UnquantLoader(WeightsLoader):\n    \"\"\"Weight loader that loads FP8 and unquantized Torch tensors.\"\"\"\n\n    def __init__(\n        self,\n        activation_scale_ub: Optional[float],\n        to_fp8: bool,\n        weight_block_size: Optional[List[int]] = None,\n    ):\n        self.activation_scale_ub = activation_scale_ub\n        self.to_fp8 = 
to_fp8\n        self.weight_block_size = weight_block_size\n\n    def get_weights(self, weights: \"Weights\", prefix: str):\n        w = weights.get_tensor(f\"{prefix}.weight\")\n\n        if w.dtype == torch.float8_e4m3fn:\n            if self.weight_block_size is not None:\n                scale = weights.get_tensor(f\"{prefix}.weight_scale_inv\")\n                return Fp8Weight(\n                    weight=w,\n                    weight_scale=scale,\n                    activation_scale_ub=self.activation_scale_ub,\n                    dtype=weights.dtype,\n                    weight_block_size=self.weight_block_size,\n                )\n            # FP8 branch\n            scale = weights.get_tensor(f\"{prefix}.weight_scale\", to_dtype=False)\n\n            if SYSTEM == \"cuda\":\n                scale.reshape(-1).expand(w.shape[0])\n\n            input_scale = None\n            if weights.has_tensor(f\"{prefix}.input_scale\"):\n                input_scale = (\n                    weights.get_tensor(f\"{prefix}.input_scale\", to_dtype=False)\n                    .reshape(-1)\n                    .max()\n                )\n\n            return Fp8Weight(\n                weight=w,\n                weight_scale=scale,\n                input_scale=input_scale,\n                activation_scale_ub=self.activation_scale_ub,\n                dtype=weights.dtype,\n            )\n        if self.to_fp8:\n            return Fp8Weight(weight=w, dtype=weights.dtype)\n\n        return UnquantizedWeight(w)\n\n    def get_weights_col_packed(\n        self,\n        weights: Weights,\n        prefix: str,\n        block_sizes: Union[int, List[int]],\n    ):\n        w = weights.get_packed_sharded(\n            f\"{prefix}.weight\", dim=0, block_sizes=block_sizes\n        )\n\n        if w.dtype == torch.float8_e4m3fn:\n            # FP8 branch\n            scale = weights.get_tensor(f\"{prefix}.weight_scale\", to_dtype=False)\n\n            if scale.numel() > 1:\n          
      scale = weights.get_packed_sharded(\n                    f\"{prefix}.weight_scale\",\n                    dim=0,\n                    block_sizes=block_sizes,\n                    to_dtype=False,\n                )\n            if SYSTEM == \"cuda\":\n                scale = scale.reshape(-1).expand(w.shape[0])\n\n            input_scale = None\n            if weights.has_tensor(f\"{prefix}.input_scale\"):\n                input_scale = weights.get_tensor(\n                    f\"{prefix}.input_scale\", to_dtype=False\n                )\n                if input_scale.numel() > 1:\n                    input_scale = weights.get_packed_sharded(\n                        f\"{prefix}.input_scale\",\n                        dim=0,\n                        block_sizes=block_sizes,\n                        to_dtype=False,\n                    )\n                input_scale = input_scale.reshape(-1).max()\n\n            return Fp8Weight(\n                weight=w,\n                weight_scale=scale,\n                input_scale=input_scale,\n                activation_scale_ub=self.activation_scale_ub,\n                dtype=weights.dtype,\n            )\n        if self.to_fp8:\n            return Fp8Weight(weight=w, dtype=weights.dtype)\n\n        return UnquantizedWeight(w)\n\n    def get_multi_weights_col(self, weights: \"Weights\", prefixes: List[str], dim: int):\n        # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet\n        w = [\n            weights.get_sharded(f\"{p}.weight\", dim=0, to_device=False) for p in prefixes\n        ]\n        shapes = [x.shape for x in w]\n\n        # Concat then send to the device\n        w = torch.cat(w, dim=dim).to(weights.device)\n\n        # FP8 branch\n        if w.dtype == torch.float8_e4m3fn:\n            if self.weight_block_size is not None:\n                scale = [\n                    weights.get_sharded(f\"{p}.weight_scale_inv\", dim=0, to_device=False)\n                   
 for p in prefixes\n                ]\n                scale = torch.cat(scale, dim=dim)\n                scale = scale.to(weights.device)\n                return Fp8Weight(\n                    weight=w,\n                    weight_scale=scale,\n                    activation_scale_ub=self.activation_scale_ub,\n                    dtype=weights.dtype,\n                    weight_block_size=self.weight_block_size,\n                )\n\n            scale = [\n                _load_scalar_or_matrix_scale(weights, f\"{p}.weight_scale\", shape)\n                for p, shape in zip(prefixes, shapes)\n            ]\n            scale = torch.cat(scale, dim=0).reshape(-1)\n\n            input_scale = [\n                _load_scalar_or_matrix_scale(weights, f\"{p}.input_scale\", shape)\n                for p, shape in zip(prefixes, shapes)\n                if weights.has_tensor(f\"{p}.input_scale\")\n            ]\n            assert len(input_scale) == 0 or len(input_scale) == len(prefixes)\n            input_scale = (\n                torch.cat(input_scale, dim=0).reshape(-1).max()\n                if len(input_scale) != 0\n                else None\n            )\n\n            if SYSTEM == \"rocm\":\n                w, scale, input_scale = normalize_e4m3fn_to_native_float8(\n                    w, scale, input_scale\n                )\n\n                if scale.numel() == len(prefixes):\n                    logical_widths = [x[0] for x in shapes]\n                    w, scale = requantize_with_max_scale(\n                        w, scale.to(weights.device), logical_widths, weights.dtype\n                    )\n\n            return Fp8Weight(\n                weight=w,\n                weight_scale=scale,\n                input_scale=input_scale,\n                activation_scale_ub=self.activation_scale_ub,\n                dtype=weights.dtype,\n            )\n        if self.to_fp8:\n            return Fp8Weight(weight=w, dtype=weights.dtype)\n\n        return 
UnquantizedWeight(w)\n\n    def get_weights_row(self, weights: \"Weights\", prefix: str):\n        w = weights.get_sharded(f\"{prefix}.weight\", dim=1)\n        # FP8 branch\n        if w.dtype == torch.float8_e4m3fn:\n            if self.weight_block_size is not None:\n                # XXX: Yes the weights is named scale_inv, but corresponds to scale it seems.\n                scale = weights.get_sharded(f\"{prefix}.weight_scale_inv\", dim=1)\n\n                return Fp8Weight(\n                    weight=w,\n                    weight_scale=scale,\n                    activation_scale_ub=self.activation_scale_ub,\n                    dtype=weights.dtype,\n                    weight_block_size=self.weight_block_size,\n                )\n\n            scale = weights.get_tensor(f\"{prefix}.weight_scale\", to_dtype=False)\n\n            if SYSTEM == \"cuda\":\n                scale = scale.reshape(-1).expand(w.shape[0])\n\n            input_scale = None\n            if weights.has_tensor(f\"{prefix}.input_scale\"):\n                input_scale = (\n                    weights.get_tensor(f\"{prefix}.input_scale\", to_dtype=False)\n                    .reshape(-1)\n                    .max()\n                )\n\n            return Fp8Weight(\n                weight=w,\n                weight_scale=scale,\n                input_scale=input_scale,\n                activation_scale_ub=self.activation_scale_ub,\n                dtype=weights.dtype,\n            )\n        if self.to_fp8:\n            return Fp8Weight(weight=w, dtype=weights.dtype)\n\n        return UnquantizedWeight(w)\n\n\n@dataclass\nclass Fp8Weight(Weight):\n    weight: torch.Tensor\n    dtype: torch.dtype\n    weight_scale: Optional[torch.Tensor] = None\n    input_scale: Optional[torch.Tensor] = None\n    activation_scale_ub: Optional[float] = None\n    force_w8a16: bool = False\n    weight_block_size: Optional[List[int]] = None\n\n    def get_linear(self, bias: torch.Tensor):\n        if 
self.weight_scale is None:\n            return get_fp8_linear(force_w8a16=self.force_w8a16).from_unquant(\n                self.weight, bias, self.dtype\n            )\n        # This is not checked by the fbgemm kernels, but they require contiguous\n        # memory. Can be non-contiguous when we e.g. expand from scalars.\n        self.weight_scale = self.weight_scale.contiguous()\n        return get_fp8_linear(force_w8a16=self.force_w8a16).from_fp8(\n            weight=self.weight,\n            scale=self.weight_scale,\n            dtype=self.dtype,\n            bias=bias,\n            input_scale=self.input_scale,\n            scale_upper_bound=self.activation_scale_ub,\n            weight_block_size=self.weight_block_size,\n        )\n\n\nclass Fp8Linear(torch.nn.Module):\n    _device_identity_cache = {}\n\n    def __init__(\n        self,\n        qweight: torch.Tensor,\n        scale: torch.Tensor,\n        dtype: torch.dtype,\n        bias: Optional[torch.Tensor] = None,\n        input_scale: Optional[torch.Tensor] = None,\n        scale_upper_bound: Optional[float] = None,\n        weight_block_size: Optional[List[int]] = None,\n    ) -> None:\n        super().__init__()\n        if CUTLASS_FP8_AVAILABLE:\n            log_once(logger.info, \"Using cutlass w8a8 kernels\")\n        if SYSTEM == \"rocm\" and qweight.dtype == torch.float8_e4m3fn:\n            qweight, scale, input_scale = normalize_e4m3fn_to_native_float8(\n                weight=qweight, weight_scale=scale, input_scale=input_scale\n            )\n\n        self.dtype = dtype\n        self.qweight = qweight\n        self.scale = scale.float()\n        self.input_scale = input_scale.float() if input_scale is not None else None\n        self.weight_block_size = weight_block_size\n\n        if CUTLASS_FP8_AVAILABLE and scale_upper_bound is not None:\n            self.scale_upper_bound = torch.tensor(\n                scale_upper_bound, dtype=torch.float32, device=qweight.device\n            )\n    
    else:\n            self.scale_upper_bound = scale_upper_bound\n\n        self.bias = bias if bias is not None else None\n\n    @classmethod\n    def from_unquant(cls, weight, bias, dtype):\n        qweight, scale = fp8_quantize(weight, scalar=not CUTLASS_FP8_AVAILABLE)\n        return cls(\n            qweight=qweight,\n            scale=scale,\n            dtype=dtype,\n            bias=bias,\n            input_scale=None,\n            scale_upper_bound=None,\n        )\n\n    @classmethod\n    def from_fp8(\n        cls,\n        weight: torch.Tensor,\n        scale: torch.Tensor,\n        dtype: torch.dtype,\n        bias: Optional[torch.Tensor] = None,\n        **kwargs,\n    ) -> \"Fp8Linear\":\n        input_scale = kwargs.get(\"input_scale\", None)\n        scale_upper_bound = kwargs.get(\"scale_upper_bound\", None)\n        weight_block_size = kwargs.get(\"weight_block_size\", None)\n\n        return cls(\n            qweight=weight,\n            scale=scale,\n            input_scale=input_scale,\n            scale_upper_bound=scale_upper_bound,\n            bias=bias,\n            dtype=dtype,\n            weight_block_size=weight_block_size,\n        )\n\n    @classmethod\n    def get_shared_device_identity(cls, device):\n        # Input scaling factors are no longer optional in _scaled_mm starting\n        # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale\n        if device not in cls._device_identity_cache:\n            cls._device_identity_cache[device] = torch.ones(1, device=device)\n        return cls._device_identity_cache[device]\n\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        if self.weight_block_size is not None:\n            # https://arxiv.org/pdf/2412.19437\n            # At a more granular level. 
As illustrated in Figure 7 (a), (1) for activations, we group and\n            # scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we\n            # group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output\n            # channels).\n            qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1])\n            output = w8a8_block_fp8_matmul(\n                qinput,\n                self.qweight,\n                scale,\n                self.scale,\n                self.weight_block_size,\n                output_dtype=input.dtype,\n            )\n\n            if self.bias is not None:\n                output = output + self.bias\n            return output.to(dtype=input.dtype)\n        if CUTLASS_FP8_AVAILABLE:\n            # cutlass FP8 supports per-token scales, so get non-scalar scales.\n            qinput, scale = fp8_quantize(\n                input, scale_upper_bound=self.scale_upper_bound, scalar=False\n            )\n            return quantization.cutlass_scaled_mm(\n                qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias\n            )\n\n        qinput, scale = fp8_quantize(\n            input,\n            self.input_scale,\n            scale_upper_bound=self.scale_upper_bound,\n            scalar=True,\n        )\n\n        per_tensor_weights = self.scale.numel() == 1\n        per_tensor_activations = scale.numel() == 1\n\n        if SYSTEM != \"rocm\" or (per_tensor_weights and per_tensor_activations):\n            output = torch._scaled_mm(\n                qinput,\n                self.qweight.t(),\n                out_dtype=self.dtype,\n                scale_a=scale,\n                scale_b=self.scale,\n                bias=self.bias,\n            )\n\n            if isinstance(output, tuple) and len(output) == 2:\n                output = output[0]\n        else:\n            device_identity = None\n            if 
SYSTEM == \"rocm\":\n                device_identity = self.get_shared_device_identity(self.qweight.device)\n\n            output = torch._scaled_mm(\n                qinput,\n                self.qweight.t(),\n                scale_a=device_identity,\n                scale_b=device_identity,\n                out_dtype=torch.float32,\n            )\n            if isinstance(output, tuple) and len(output) == 2:\n                output = output[0]\n\n            output = output * scale * self.scale.t()\n            if self.bias is not None:\n                output = output + self.bias\n\n            output = output.to(dtype=self.dtype)\n\n        return output\n\n\ndef _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):\n    scale = weights.get_tensor(prefix, to_dtype=False)\n\n    if scale.numel() > 1:\n        scale = weights.get_sharded(prefix, dim=0, to_dtype=False)\n    elif SYSTEM == \"rocm\":\n        return scale.reshape(-1)\n    return scale.reshape(-1).expand(shape[0])\n"
  },
  {
    "path": "server/text_generation_server/layers/gptq/__init__.py",
    "content": "import os\nfrom dataclasses import dataclass\nfrom typing import List, Optional, Union\n\nimport torch\nfrom loguru import logger\nfrom text_generation_server.utils.import_utils import SYSTEM\nfrom text_generation_server.utils.log import log_once\nfrom text_generation_server.utils.weights import (\n    Weight,\n    Weights,\n    WeightsLoader,\n    DefaultWeightsLoader,\n)\nimport math\n\n\n@dataclass\nclass GPTQWeight(Weight):\n    qweight: torch.Tensor\n    qzeros: torch.Tensor\n    scales: torch.Tensor\n    g_idx: Optional[torch.Tensor]\n    bits: int\n    groupsize: int\n    use_awq_kernel: bool\n    use_exllama: bool\n\n    def __post_init__(self):\n        if self.scales.dtype == torch.float:\n            self.scales = self.scales.half()\n\n    @property\n    def device(self) -> torch.device:\n        return self.qweight.device\n\n    def get_linear(self, bias: torch.Tensor):\n        if self.use_awq_kernel:\n            if SYSTEM == \"rocm\":\n                raise NotImplementedError(\n                    \"AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead \"\n                    \"to use Exllama/GPTQ kernels for AWQ inference.\"\n                )\n            try:\n                from text_generation_server.layers.awq.quantize import WQLinear\n\n                return WQLinear(\n                    w_bit=self.bits,\n                    group_size=self.groupsize,\n                    qweight=self.qweight,\n                    qzeros=self.qzeros,\n                    scales=self.scales,\n                    bias=bias,\n                )\n            except ImportError:\n                raise NotImplementedError(\n                    \"You do not seem to have awq installed, either install it (cd server &&  make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly\"\n                )\n        elif self.use_exllama:\n            try:\n                from 
text_generation_server.layers.gptq import ExllamaQuantLinear\n            except ImportError:\n                raise NotImplementedError(\n                    \"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`\"\n                )\n\n            return ExllamaQuantLinear(self, bias)\n        else:\n            if SYSTEM == \"ipex\" and not (\n                self.device.type == \"xpu\"\n                and (\n                    self.bits != 4\n                    or math.ceil(\n                        (self.qweight.shape[0] * 32 // self.bits) / self.groupsize\n                    )\n                    != self.scales.shape[0]\n                )\n            ):\n                from .ipex import QuantLinear\n            else:\n                from .triton import QuantLinear\n            return QuantLinear(\n                self.qweight,\n                self.qzeros,\n                self.scales,\n                self.g_idx,\n                bias,\n                self.bits,\n                self.groupsize,\n            )\n\n\nclass GPTQWeightsLoader(WeightsLoader):\n    \"\"\"\n    Loader for GPTQ- and AWQ-quantized weights.\n    \"\"\"\n\n    def __init__(\n        self,\n        *,\n        bits: int,\n        desc_act: bool,\n        groupsize: int,\n        quant_method: str,\n        quantize: str,\n        sym: bool,\n        modules_to_not_convert: List[str],\n    ):\n        self.bits = bits\n        self.desc_act = desc_act\n        self.groupsize = groupsize\n        self.quant_method = quant_method\n        self.quantize = quantize\n        self.sym = sym\n        self.modules_to_not_convert = modules_to_not_convert\n\n    def get_weights(self, weights: Weights, prefix: str):\n        self._get_gptq_params(weights)\n\n        use_exllama = True\n        if self.bits != 4:\n            use_exllama = False\n\n        if self.desc_act:\n          
  log_once(logger.warning, \"Disabling exllama because desc_act=True\")\n            use_exllama = False\n\n        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):\n            return DefaultWeightsLoader.get_weights(weights, prefix)\n\n        try:\n            qweight = weights.get_tensor(f\"{prefix}.qweight\")\n        except RuntimeError:\n            raise RuntimeError(\n                \"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`\"\n            )\n\n        if self.quantize == \"gptq\" and self.quant_method == \"gptq\":\n            g_idx = weights.get_tensor(f\"{prefix}.g_idx\")\n        else:\n            g_idx = None\n\n        from text_generation_server.layers.gptq import (\n            HAS_EXLLAMA,\n            CAN_EXLLAMA,\n            GPTQWeight,\n        )\n\n        if use_exllama:\n            if not HAS_EXLLAMA:\n                if CAN_EXLLAMA:\n                    log_once(\n                        logger.warning,\n                        \"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True\",\n                    )\n                use_exllama = False\n            else:\n                log_once(logger.info, f\"Using exllama kernels v{HAS_EXLLAMA}\")\n\n        qzeros = weights.get_tensor(f\"{prefix}.qzeros\")\n        scales = weights.get_tensor(f\"{prefix}.scales\")\n\n        if use_exllama and g_idx is not None:\n            g_idx = g_idx - g_idx[0]\n\n        if self.quantize == \"gptq\" and self.quant_method == \"awq\":\n            log_once(\n                logger.info, \"Converting AWQ model to Exllama/GPTQ packing format.\"\n            )\n            from text_generation_server.layers.awq.conversion_utils import (\n                fast_awq_to_gptq,\n            )\n\n            qweight, qzeros = 
fast_awq_to_gptq(qweight, qzeros)\n            if use_exllama:\n                g_idx = None\n            else:\n                g_idx = (\n                    torch.arange(\n                        qweight.shape[0] * (32 // self.bits),\n                        device=qweight.device,\n                    )\n                    // self.groupsize\n                ).to(dtype=torch.int32)\n\n        return GPTQWeight(\n            qweight=qweight,\n            qzeros=qzeros,\n            scales=scales,\n            g_idx=g_idx,\n            bits=self.bits,\n            groupsize=self.groupsize,\n            use_awq_kernel=self.quantize == \"awq\",\n            use_exllama=use_exllama,\n        )\n\n    def is_layer_skipped_quantization(\n        self, prefix: str, modules_to_not_convert: List[str]\n    ):\n        return any(module_name in prefix for module_name in modules_to_not_convert)\n\n    def get_weights_col_packed(\n        self,\n        weights: Weights,\n        prefix: str,\n        block_sizes: Union[int, List[int]],\n    ):\n        if self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):\n            return DefaultWeightsLoader.get_weights_col_packed(\n                weights, prefix, block_sizes\n            )\n        try:\n            qweight = weights.get_packed_sharded(\n                f\"{prefix}.qweight\", dim=1, block_sizes=block_sizes\n            )\n        except RuntimeError:\n            raise RuntimeError(\n                f\"Cannot load `{self.quantize}` weight, make sure the model is already quantized.\"\n            )\n        scales = weights.get_packed_sharded(\n            f\"{prefix}.scales\", dim=1, block_sizes=block_sizes\n        )\n        scales = scales.to(dtype=weights.dtype)\n\n        self._get_gptq_params(weights)\n\n        qzeros = weights.get_packed_sharded(\n            f\"{prefix}.qzeros\", dim=1, block_sizes=block_sizes\n        )\n        if self.quantize == \"gptq\" and self.quant_method == 
\"gptq\":\n            g_idx = weights.get_tensor(f\"{prefix}.g_idx\")\n        elif self.quantize == \"gptq\" and self.quant_method == \"awq\":\n            log_once(\n                logger.info, \"Converting AWQ model to Exllama/GPTQ packing format.\"\n            )\n            from text_generation_server.layers.awq.conversion_utils import (\n                fast_awq_to_gptq,\n            )\n\n            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)\n            g_idx = (\n                torch.arange(\n                    qweight.shape[0] * (32 // self.bits),\n                    device=qweight.device,\n                )\n                // self.groupsize\n            ).to(dtype=torch.int32)\n        else:\n            g_idx = None\n\n        return GPTQWeight(\n            qweight=qweight,\n            qzeros=qzeros,\n            scales=scales,\n            g_idx=g_idx,\n            bits=self.bits,\n            groupsize=self.groupsize,\n            use_awq_kernel=self.quantize == \"awq\",\n            use_exllama=False,\n        )\n\n    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):\n        if self.is_layer_skipped_quantization(prefixes[0], self.modules_to_not_convert):\n            return DefaultWeightsLoader.get_multi_weights_col(weights, prefixes, dim)\n        try:\n            qweight = torch.cat(\n                [weights.get_sharded(f\"{p}.qweight\", dim=1) for p in prefixes], dim=1\n            )\n        except RuntimeError:\n            raise RuntimeError(\n                f\"Cannot load `{self.quantize}` weight, make sure the model is already quantized\"\n            )\n\n        scales = torch.cat(\n            [weights.get_sharded(f\"{p}.scales\", dim=1) for p in prefixes], dim=1\n        )\n\n        self._get_gptq_params(weights)\n\n        qzeros = torch.cat(\n            [weights.get_sharded(f\"{p}.qzeros\", dim=1) for p in prefixes], dim=1\n        )\n\n        from 
text_generation_server.layers.gptq import HAS_EXLLAMA\n\n        use_exllama = (\n            self.bits == 4\n            and HAS_EXLLAMA\n            and self.quantize == \"gptq\"\n            and not self.desc_act\n        )\n\n        if self.quantize == \"gptq\" and self.quant_method == \"gptq\":\n            w = [weights.get_tensor(f\"{p}.g_idx\") for p in prefixes]\n            for w2 in w[1:]:\n                torch.testing.assert_close(w2, w[0])\n            g_idx = w[0]\n        elif self.quantize == \"gptq\" and self.quant_method == \"awq\":\n            log_once(\n                logger.info, \"Converting AWQ model to Exllama/GPTQ packing format.\"\n            )\n            from text_generation_server.layers.awq.conversion_utils import (\n                fast_awq_to_gptq,\n            )\n\n            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)\n            if use_exllama:\n                g_idx = None\n            else:\n                g_idx = (\n                    torch.arange(\n                        qweight.shape[0] * (32 // self.bits),\n                        device=qweight.device,\n                    )\n                    // self.groupsize\n                ).to(dtype=torch.int32)\n        else:\n            g_idx = None\n\n        return GPTQWeight(\n            qweight=qweight,\n            qzeros=qzeros,\n            scales=scales,\n            g_idx=g_idx,\n            bits=self.bits,\n            groupsize=self.groupsize,\n            use_awq_kernel=self.quantize == \"awq\",\n            use_exllama=use_exllama,\n        )\n\n    def get_weights_row(self, weights: Weights, prefix: str):\n        self._get_gptq_params(weights)\n\n        use_exllama = True\n        desc_act = self.desc_act\n        if self.bits != 4:\n            use_exllama = False\n\n        if self.desc_act:\n            log_once(logger.warning, \"Disabling exllama because desc_act=True\")\n            use_exllama = False\n\n        if 
self.is_layer_skipped_quantization(prefix, self.modules_to_not_convert):\n            return DefaultWeightsLoader.get_weights_row(weights, prefix)\n        try:\n            qweight = weights.get_sharded(f\"{prefix}.qweight\", dim=0)\n        except RuntimeError:\n            raise RuntimeError(\n                \"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`\"\n            )\n\n        if self.quantize == \"gptq\" and self.quant_method == \"gptq\":\n            g_idx = weights.get_sharded(f\"{prefix}.g_idx\", dim=0)\n        else:\n            g_idx = None\n\n        if weights.process_group.size() > 1:\n            if g_idx is not None:\n                if (\n                    not torch.equal(\n                        # Remove g_idx[0] to adapt the check with TP>1.\n                        (g_idx - g_idx[0]).cpu(),\n                        torch.tensor(\n                            [i // self.groupsize for i in range(g_idx.shape[0])],\n                            dtype=torch.int32,\n                        ),\n                    )\n                    and not (g_idx == 0).all()\n                ):\n                    # Exllama implementation does not support row tensor parallelism with act-order, as\n                    # it would require to reorder input activations that are split unto several GPUs\n                    use_exllama = False\n                    desc_act = True\n\n        from text_generation_server.layers.gptq import (\n            CAN_EXLLAMA,\n            HAS_EXLLAMA,\n            GPTQWeight,\n        )\n\n        if use_exllama:\n            if not HAS_EXLLAMA:\n                if CAN_EXLLAMA:\n                    log_once(\n                        logger.warning,\n                        \"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True\",\n           
         )\n                use_exllama = False\n            else:\n                log_once(logger.info, f\"Using exllama kernels v{HAS_EXLLAMA}\")\n\n        if not desc_act and self.groupsize != -1:\n            qzeros = weights.get_sharded(f\"{prefix}.qzeros\", dim=0)\n            scales = weights.get_sharded(f\"{prefix}.scales\", dim=0)\n            if g_idx is not None:\n                # qzeros, scales sharded, and g_idx must be adjusted accordingly\n                g_idx = g_idx - g_idx[0]\n        else:\n            qzeros = weights.get_tensor(f\"{prefix}.qzeros\")\n            scales = weights.get_tensor(f\"{prefix}.scales\")\n\n        if self.quantize == \"gptq\" and self.quant_method == \"awq\":\n            log_once(\n                logger.info, \"Converting AWQ model to Exllama/GPTQ packing format.\"\n            )\n            from text_generation_server.layers.awq.conversion_utils import (\n                fast_awq_to_gptq,\n            )\n\n            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)\n            if use_exllama:\n                g_idx = None\n            else:\n                g_idx = (\n                    torch.arange(\n                        qweight.shape[0] * (32 // self.bits),\n                        device=qweight.device,\n                    )\n                    // self.groupsize\n                ).to(dtype=torch.int32)\n\n        return GPTQWeight(\n            qweight=qweight,\n            qzeros=qzeros,\n            scales=scales,\n            g_idx=g_idx,\n            bits=self.bits,\n            groupsize=self.groupsize,\n            use_awq_kernel=self.quantize == \"awq\",\n            use_exllama=use_exllama,\n        )\n\n    def _get_gptq_params(self, weights: Weights):\n        if weights.has_tensor(\"gptq_bits\") and weights.has_tensor(\"gptq_groupsize\"):\n            self.bits = weights.get_tensor(\"gptq_bits\").item()\n            self.groupsize = weights.get_tensor(\"gptq_groupsize\").item()\n         
   self.desc_act = False\n            # `server quantize` used asymmetric quantization unconditionally\n            # before the `gptq_sym` setting tensor was added.\n            self.sym = (\n                weights.get_tensor(\"gptq_sym\").item()\n                if weights.has_tensor(\"gptq_sym\")\n                else False\n            )\n            self.quant_method = \"gptq\"\n\n\n# Needs to be at the end because circular import.\ntry:\n    major, _minor = torch.cuda.get_device_capability()\nexcept Exception:\n    major = 1\n\nHAS_EXLLAMA = False\nCAN_EXLLAMA = major >= 8 or SYSTEM == \"rocm\"\nV2 = os.getenv(\"EXLLAMA_VERSION\", \"2\") == \"2\"\nif os.getenv(\"DISABLE_EXLLAMA\") == \"True\":\n    HAS_EXLLAMA = False\nelif CAN_EXLLAMA:\n    try:\n        if V2:\n            from text_generation_server.layers.gptq.exllamav2 import (\n                QuantLinear as ExllamaQuantLinear,  # noqa: F401\n                create_exllama_buffers,  # noqa: F401\n                set_device,  # noqa: F401\n            )\n\n            HAS_EXLLAMA = \"2\"\n        else:\n            from text_generation_server.layers.gptq.exllama import (\n                Ex4bitLinear as ExllamaQuantLinear,  # noqa: F401\n                create_exllama_buffers,  # noqa: F401\n                set_device,  # noqa: F401\n            )\n\n            HAS_EXLLAMA = \"1\"\n\n    except ImportError:\n        pass\n"
  },
  {
    "path": "server/text_generation_server/layers/gptq/custom_autotune.py",
    "content": "# https://github.com/fpgaminer/GPTQ-triton\n\"\"\"\nMostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100.\n\"\"\"\n\nimport builtins\nimport math\nimport time\nfrom typing import Dict\n\nimport triton\n\n\nclass Autotuner(triton.KernelInterface):\n    def __init__(\n        self,\n        fn,\n        arg_names,\n        configs,\n        key,\n        reset_to_zero,\n        prune_configs_by: Dict = None,\n        nearest_power_of_two: bool = False,\n    ):\n        \"\"\"\n        :param prune_configs_by: a dict of functions that are used to prune configs, fields:\n                'perf_model': performance model used to predicate running time with different configs, returns running time\n                'top_k': number of configs to bench\n                'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs.\n                'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results\n        \"\"\"\n        if not configs:\n            self.configs = [triton.Config({}, num_warps=4, num_stages=2)]\n        else:\n            self.configs = configs\n        self.key_idx = [arg_names.index(k) for k in key]\n        self.nearest_power_of_two = nearest_power_of_two\n        self.cache = {}\n        # hook to reset all required tensor to zeros before relaunching a kernel\n        self.hook = lambda args: 0\n        if reset_to_zero is not None:\n            self.reset_idx = [arg_names.index(k) for k in reset_to_zero]\n\n            def _hook(args):\n                for i in self.reset_idx:\n                    args[i].zero_()\n\n            self.hook = _hook\n        self.arg_names = arg_names\n        # prune configs\n        if prune_configs_by:\n            perf_model, top_k = (\n                prune_configs_by[\"perf_model\"],\n                
prune_configs_by[\"top_k\"],\n            )\n            if \"early_config_prune\" in prune_configs_by:\n                early_config_prune = prune_configs_by[\"early_config_prune\"]\n        else:\n            perf_model, top_k, early_config_prune = None, None, None\n        self.perf_model, self.configs_top_k = perf_model, top_k\n        self.early_config_prune = early_config_prune\n        self.fn = fn\n\n    def _bench(self, *args, config, **meta):\n        # check for conflicts, i.e. meta-parameters both provided\n        # as kwargs and by the autotuner\n        conflicts = meta.keys() & config.kwargs.keys()\n        if conflicts:\n            raise ValueError(\n                f\"Conflicting meta-parameters: {', '.join(conflicts)}.\"\n                \" Make sure that you don't re-define auto-tuned symbols.\"\n            )\n        # augment meta-parameters with tunable ones\n        current = dict(meta, **config.kwargs)\n\n        def kernel_call():\n            if config.pre_hook:\n                config.pre_hook(self.nargs)\n            self.hook(args)\n            self.fn.run(\n                *args,\n                num_warps=config.num_warps,\n                num_stages=config.num_stages,\n                **current,\n            )\n\n        try:\n            # In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses\n            # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default\n            return triton.testing.do_bench(\n                kernel_call, quantiles=(0.5, 0.2, 0.8), rep=40\n            )\n        except triton.OutOfResources:\n            return [float(\"inf\"), float(\"inf\"), float(\"inf\")]\n\n    def run(self, *args, **kwargs):\n        self.nargs = dict(zip(self.arg_names, args))\n        if len(self.configs) > 1:\n            key = tuple(args[i] for i in self.key_idx)\n\n            # This reduces the amount of autotuning by rounding the keys to the 
nearest power of two\n            # In my testing this gives decent results, and greatly reduces the amount of tuning required\n            if self.nearest_power_of_two:\n                key = tuple([2 ** int(math.log2(x) + 0.5) for x in key])\n\n            if key not in self.cache:\n                # prune configs\n                pruned_configs = self.prune_configs(kwargs)\n                bench_start = time.time()\n                timings = {\n                    config: self._bench(*args, config=config, **kwargs)\n                    for config in pruned_configs\n                }\n                bench_end = time.time()\n                self.bench_time = bench_end - bench_start\n                self.cache[key] = builtins.min(timings, key=timings.get)\n                self.hook(args)\n                self.configs_timings = timings\n            config = self.cache[key]\n        else:\n            config = self.configs[0]\n        self.best_config = config\n        if config.pre_hook is not None:\n            config.pre_hook(self.nargs)\n        return self.fn.run(\n            *args,\n            num_warps=config.num_warps,\n            num_stages=config.num_stages,\n            **kwargs,\n            **config.kwargs,\n        )\n\n    def prune_configs(self, kwargs):\n        pruned_configs = self.configs\n        if self.early_config_prune:\n            pruned_configs = self.early_config_prune(self.configs, self.nargs)\n        if self.perf_model:\n            top_k = self.configs_top_k\n            if isinstance(top_k, float) and top_k <= 1.0:\n                top_k = int(len(self.configs) * top_k)\n            if len(pruned_configs) > top_k:\n                est_timing = {\n                    config: self.perf_model(\n                        **self.nargs,\n                        **kwargs,\n                        **config.kwargs,\n                        num_stages=config.num_stages,\n                        num_warps=config.num_warps,\n                   
 )\n                    for config in pruned_configs\n                }\n                pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[\n                    :top_k\n                ]\n        return pruned_configs\n\n    def warmup(self, *args, **kwargs):\n        self.nargs = dict(zip(self.arg_names, args))\n        for config in self.prune_configs(kwargs):\n            self.fn.warmup(\n                *args,\n                num_warps=config.num_warps,\n                num_stages=config.num_stages,\n                **kwargs,\n                **config.kwargs,\n            )\n        self.nargs = None\n\n\ndef autotune(\n    configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False\n):\n    \"\"\"\n    Decorator for auto-tuning a :code:`triton.jit`'d function.\n    .. highlight:: python\n    .. code-block:: python\n            @triton.autotune(configs=[\n                    triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),\n                    triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),\n                    ],\n                    key=['x_size'] # the two above configs will be evaluated anytime\n                                                    # the value of x_size changes\n            )\n            @triton.jit\n            def kernel(x_ptr, x_size, **META):\n                    BLOCK_SIZE = META['BLOCK_SIZE']\n    :note: When all the configurations are evaluated, the kernel will run multiple time.\n                    This means that whatever value the kernel updates will be updated multiple times.\n                    To avoid this undesired behavior, you can use the `reset_to_zero` argument, which\n                    reset the value of the provided tensor to `zero` before running any configuration.\n    :param configs: a list of :code:`triton.Config` objects\n    :type configs: list[triton.Config]\n    :param key: a list of argument names whose change in value will trigger the evaluation of 
all provided configs.\n    :type key: list[str]\n    :param prune_configs_by: a dict of functions that are used to prune configs, fields:\n            'perf_model': performance model used to predicate running time with different configs, returns running time\n            'top_k': number of configs to bench\n            'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs.\n    :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.\n    :type reset_to_zero: list[str]\n    \"\"\"\n\n    def decorator(fn):\n        return Autotuner(\n            fn,\n            fn.arg_names,\n            configs,\n            key,\n            reset_to_zero,\n            prune_configs_by,\n            nearest_power_of_two,\n        )\n\n    return decorator\n\n\ndef matmul248_kernel_config_pruner(configs, nargs):\n    \"\"\"\n    The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller.\n    \"\"\"\n    m = max(2 ** int(math.ceil(math.log2(nargs[\"M\"]))), 16)\n    n = max(2 ** int(math.ceil(math.log2(nargs[\"N\"]))), 16)\n    k = max(2 ** int(math.ceil(math.log2(nargs[\"K\"]))), 16)\n\n    used = set()\n    for config in configs:\n        block_size_m = min(m, config.kwargs[\"BLOCK_SIZE_M\"])\n        block_size_n = min(n, config.kwargs[\"BLOCK_SIZE_N\"])\n        block_size_k = min(k, config.kwargs[\"BLOCK_SIZE_K\"])\n        group_size_m = config.kwargs[\"GROUP_SIZE_M\"]\n\n        if (\n            block_size_m,\n            block_size_n,\n            block_size_k,\n            group_size_m,\n            config.num_stages,\n            config.num_warps,\n        ) in used:\n            continue\n\n        used.add(\n            (\n                block_size_m,\n                block_size_n,\n                block_size_k,\n                group_size_m,\n                
config.num_stages,\n                config.num_warps,\n            )\n        )\n        yield triton.Config(\n            {\n                \"BLOCK_SIZE_M\": block_size_m,\n                \"BLOCK_SIZE_N\": block_size_n,\n                \"BLOCK_SIZE_K\": block_size_k,\n                \"GROUP_SIZE_M\": group_size_m,\n            },\n            num_stages=config.num_stages,\n            num_warps=config.num_warps,\n        )\n"
  },
  {
    "path": "server/text_generation_server/layers/gptq/exllama.py",
    "content": "from text_generation_server.layers.gptq import GPTQWeight\nimport torch\nfrom exllama_kernels import make_q4, q4_matmul, prepare_buffers, set_tuning_params\n\n# Dummy tensor to pass instead of g_idx since there is no way to pass \"None\" to a C++ extension\nnone_tensor = torch.empty((1, 1), device=\"meta\")\n\n\ndef ext_make_q4(qweight, qzeros, scales, g_idx, device):\n    \"\"\"Construct Q4Matrix, return handle\"\"\"\n    return make_q4(\n        qweight, qzeros, scales, g_idx if g_idx is not None else none_tensor, device\n    )\n\n\ndef ext_q4_matmul(x, q4, q4_width):\n    \"\"\"Matrix multiplication, returns x @ q4\"\"\"\n    outshape = x.shape[:-1] + (q4_width,)\n    x = x.view(-1, x.shape[-1])\n    output = torch.empty((x.shape[0], q4_width), dtype=torch.float16, device=x.device)\n\n    q4_matmul(x, q4, output)\n\n    return output.view(outshape)\n\n\nMAX_DQ = 1\nMAX_INNER = 1\nACT_ORDER = False\nDEVICE = None\n\nTEMP_STATE = None\nTEMP_DQ = None\n\n\ndef set_device(device):\n    global DEVICE\n    DEVICE = device\n\n\ndef create_exllama_buffers(max_total_tokens: int):\n    global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ\n\n    assert DEVICE is not None, \"call set_device first\"\n\n    if not ACT_ORDER:\n        max_total_tokens = 1\n\n    # This temp_state buffer is required to reorder X in the act-order case.\n    temp_state = torch.zeros(\n        (max_total_tokens, MAX_INNER), dtype=torch.float16, device=DEVICE\n    )\n    temp_dq = torch.zeros((1, MAX_DQ), dtype=torch.float16, device=DEVICE)\n\n    # This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.\n    prepare_buffers(DEVICE, temp_state, temp_dq)\n\n    matmul_recons_thd = 8\n    matmul_fused_remap = False\n    matmul_no_half2 = False\n    set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)\n\n    TEMP_STATE, TEMP_DQ = temp_state, temp_dq\n\n\nclass Ex4bitLinear(torch.nn.Module):\n    \"\"\"Linear 
layer implementation with per-group 4-bit quantization of the weights\"\"\"\n\n    def __init__(self, weight: GPTQWeight, bias):\n        super().__init__()\n        global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE\n        assert weight.bits == 4\n\n        self.device = weight.qweight.device\n        self.qweight = weight.qweight\n        self.qzeros = weight.qzeros\n        self.scales = weight.scales\n        self.g_idx = weight.g_idx.cpu() if weight.g_idx is not None else None\n        self.bias = bias if bias is not None else None\n\n        if self.g_idx is not None and (\n            (self.g_idx == 0).all()\n            or torch.equal(\n                weight.g_idx.cpu(),\n                torch.tensor(\n                    [i // weight.groupsize for i in range(weight.g_idx.shape[0])],\n                    dtype=torch.int32,\n                ),\n            )\n        ):\n            self.empty_g_idx = True\n            self.g_idx = None\n\n        assert self.device.type == \"cuda\"\n        assert self.device.index is not None\n\n        self.q4 = ext_make_q4(\n            self.qweight, self.qzeros, self.scales, self.g_idx, self.device.index\n        )\n\n        self.height = weight.qweight.shape[0] * 8\n        self.width = weight.qweight.shape[1]\n\n        # Infer groupsize from height of qzeros\n        self.groupsize = None\n        if self.qzeros.shape[0] > 1:\n            self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])\n\n        if self.groupsize is not None:\n            assert weight.groupsize == self.groupsize\n\n        # Handle act-order matrix\n        if self.g_idx is not None:\n            if self.groupsize is None:\n                raise ValueError(\"Found group index but no groupsize. 
What do?\")\n            self.act_order = True\n        else:\n            self.act_order = False\n\n        DEVICE = self.qweight.device\n\n        MAX_DQ = max(MAX_DQ, self.qweight.numel() * 8)\n\n        if self.act_order:\n            MAX_INNER = max(MAX_INNER, self.height, self.width)\n\n            ACT_ORDER = True\n\n    def forward(self, x):\n        out = ext_q4_matmul(x, self.q4, self.width)\n\n        if self.bias is not None:\n            out.add_(self.bias)\n        return out\n"
  },
  {
    "path": "server/text_generation_server/layers/gptq/exllamav2.py",
    "content": "# Adapted from turboderp exllama: https://github.com/turboderp/exllamav2\r\n\r\nfrom dataclasses import dataclass\r\nfrom typing import Optional\r\nimport torch\r\nimport torch.nn as nn\r\n\r\nfrom loguru import logger\r\n\r\nfrom text_generation_server.layers.exl2 import Exl2Weight\r\nfrom text_generation_server.layers.gptq import GPTQWeight\r\nfrom text_generation_server.utils.log import log_master\r\n\r\ntry:\r\n    from exllamav2.ext import exllamav2_ext\r\n\r\n    make_q_matrix = exllamav2_ext.make_q_matrix\r\n    gemm_half_q_half = exllamav2_ext.gemm_half_q_half\r\nexcept ImportError:\r\n    log_master(logger.warning, \"exllamav2_kernels not installed.\")\r\n    raise\r\n\r\n# Dummy tensor to pass instead of g_idx since there is no way to pass \"None\" to a C++ extension\r\nnone_tensor = torch.empty((1, 1), device=\"meta\")\r\n\r\n\r\n@dataclass\r\nclass _ExtraTensors:\r\n    \"\"\"Additional generated quantizer tensors.\"\"\"\r\n\r\n    q_group_map: Optional[torch.Tensor] = None\r\n    q_invperm: Optional[torch.Tensor] = None\r\n    q_perm: Optional[torch.Tensor] = None\r\n\r\n\r\ndef ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):\r\n    \"\"\"Matrix multiplication, returns x @ q4\"\"\"\r\n    output_shape = x.shape[:-1] + (q4_width,)\r\n    x = x.view(-1, x.shape[-1])\r\n    output = torch.empty((x.shape[0], q4_width), dtype=torch.half, device=x.device)\r\n    gemm_half_q_half(x, q_handle, output, force_cuda)\r\n    return output.view(output_shape)\r\n\r\n\r\ndef make_group_map(q_groups: torch.Tensor, num_qrows: int):\r\n    gr = q_groups.tolist()\r\n    group_map = []\r\n    num_groups = len(gr) // 2\r\n\r\n    for i in range(num_groups):\r\n        bits = gr[i * 2]\r\n        if i < num_groups - 1:\r\n            qrows = gr[i * 2 + 3] - gr[i * 2 + 1]\r\n        else:\r\n            qrows = num_qrows - gr[i * 2 + 1]\r\n        rows = qrows * 32 // bits\r\n        for j in range(rows):\r\n            group_map += [i]\r\n            
group_map += [rows - j]\r\n\r\n    return torch.tensor(group_map, dtype=torch.short, device=q_groups.device)\r\n\r\n\r\n# Create Q matrix\r\n\r\n\r\ndef ext_make_q_matrix(\r\n    w: Exl2Weight | GPTQWeight,\r\n    extra: _ExtraTensors,\r\n    temp_dq,\r\n    key: Optional[str] = None,\r\n):\r\n    \"\"\"\r\n    Create Q matrix\r\n    \"\"\"\r\n    # max_dq_size = 512*(1024**2)\r\n    # max_dq_rows = max_dq_size // out_features[0]\r\n    max_dq_rows = 0\r\n\r\n    # EXL2\r\n    if isinstance(w, Exl2Weight):\r\n        extra.q_group_map = make_group_map(w.q_groups, w.q_weight.shape[0])\r\n        extra.q_perm = torch.argsort(w.q_invperm).short()\r\n\r\n        return make_q_matrix(\r\n            w.q_weight,\r\n            extra.q_perm,\r\n            w.q_invperm,\r\n            w.q_scale,\r\n            w.q_scale_max,\r\n            w.q_groups,\r\n            extra.q_group_map,\r\n            none_tensor,  # zeros\r\n            none_tensor,  # scales\r\n            none_tensor,  # g_idx\r\n            none_tensor,  # bias\r\n            temp_dq,\r\n            max_dq_rows,\r\n        )\r\n    # GPTQ\r\n    elif isinstance(w, GPTQWeight):\r\n        if w.scales.dtype == torch.float:\r\n            w.scales = w.scales.half()\r\n\r\n        # GPTQ with g_idx (act_order)\r\n        if w.g_idx is not None and not (w.g_idx == 0).all().item():\r\n            extra.q_perm = torch.empty(\r\n                (w.qweight.shape[0] * 8,),\r\n                dtype=torch.short,\r\n                device=w.qweight.device,\r\n            )\r\n            extra.q_invperm = torch.empty_like(extra.q_perm)\r\n            # make_q4 segfaults if g_idx is not on cpu in the act-order case. 
In the non act-order case, None needs to be passed for g_idx.\r\n            return make_q_matrix(\r\n                w.qweight,\r\n                extra.q_perm,\r\n                extra.q_invperm,\r\n                none_tensor,  # q_scale\r\n                none_tensor,  # q_scale_max\r\n                none_tensor,  # q_groups\r\n                none_tensor,  # q_group_map\r\n                w.qzeros,\r\n                w.scales,\r\n                w.g_idx.cpu(),\r\n                none_tensor,  # bias\r\n                temp_dq,\r\n                max_dq_rows,\r\n            )\r\n        # GPTQ without g_idx\r\n        else:\r\n            return make_q_matrix(\r\n                w.qweight,\r\n                none_tensor,  # q_perm\r\n                none_tensor,  # q_invperm\r\n                none_tensor,  # q_scale\r\n                none_tensor,  # q_scale_max\r\n                none_tensor,  # q_groups\r\n                none_tensor,  # q_group_map\r\n                w.qzeros,\r\n                w.scales,\r\n                none_tensor,  # g_idx\r\n                none_tensor,  # bias\r\n                temp_dq,\r\n                max_dq_rows,\r\n            )\r\n    else:\r\n        RuntimeError(\"Cannot create handle\")\r\n\r\n\r\nDEVICE = None\r\nLAYERS = []\r\n\r\n\r\ndef set_device(device):\r\n    global DEVICE\r\n    DEVICE = device\r\n\r\n\r\ndef create_exllama_buffers(max_total_tokens: int):\r\n    global LAYERS, DEVICE\r\n\r\n    # No need to initialize scratch space if there are no layers\r\n    # that use ExLLamav2.\r\n    if len(LAYERS) == 0:\r\n        return\r\n\r\n    # Find the size of the scratch space.\r\n    scratch_bytes = max(\r\n        layer.scratch_space_fixed(max_input_len=max_total_tokens, max_batch_size=1)\r\n        for layer in LAYERS\r\n    )\r\n    temp_dq = ExLlamaV2DeviceTensors(DEVICE, scratch_bytes)\r\n\r\n    for layer in LAYERS:\r\n        layer.post_init(temp_dq)\r\n\r\n\r\nclass QuantLinear(nn.Module):\r\n    
QUANT_TYPE = \"exllamav2\"\r\n\r\n    \"\"\"Linear layer implementation with per-group 4-bit quantization of the weights\"\"\"\r\n\r\n    def __init__(\r\n        self,\r\n        weight: Exl2Weight | GPTQWeight,\r\n        bias: torch.Tensor,\r\n    ):\r\n        super().__init__()\r\n\r\n        self.q_handle = None\r\n        self.q_tensors = weight\r\n        self.extra_tensors = _ExtraTensors()\r\n\r\n        if isinstance(weight, Exl2Weight):\r\n            self.infeatures = weight.q_invperm.shape[0]\r\n            self.outfeatures = weight.q_weight.shape[1]\r\n        elif isinstance(weight, GPTQWeight):\r\n            if weight.bits != 4:\r\n                raise ValueError(\r\n                    f\"Exllamav2 kernel supports only bits=4, requested bits={weight.bits}. Something is wrong in the model initialization.\"\r\n                )\r\n\r\n            self.infeatures = weight.qweight.shape[0] // weight.bits * 32\r\n            self.outfeatures = weight.qweight.shape[1]\r\n\r\n        self.padding = -self.outfeatures % 32\r\n        self.outfeatures = self.outfeatures + self.padding\r\n\r\n        self.device = weight.device\r\n        self.bias = bias if bias is not None else None\r\n\r\n        global LAYERS\r\n        LAYERS.append(self)\r\n\r\n    def post_init(self, temp_dq):\r\n        device = self.q_tensors.device\r\n        assert device.type == \"cuda\"\r\n        assert device.index is not None\r\n        temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())\r\n\r\n        # We NEED to keep a pointer on Python side, otherwise the garbage collector will mess with us,\r\n        # and `Memory access fault by GPU node-2` will EAT you.\r\n        self.temp_dq = temp_dq\r\n        self.q_handle = ext_make_q_matrix(self.q_tensors, self.extra_tensors, temp_dq)\r\n\r\n    def forward(self, x, force_cuda=False):\r\n        output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda)\r\n\r\n        if self.bias is not None:\r\n      
      output.add_(self.bias)\r\n        return output\r\n\r\n    def temp_dq_size(self):\r\n        return self.infeatures * self.outfeatures * 2 + 128\r\n\r\n    def temp_fwd_size(self, max_input_len, max_batch_size):\r\n        return self.outfeatures * max_input_len * max_batch_size * 4 + 128\r\n\r\n    def scratch_space_fixed(self, max_input_len, max_batch_size):\r\n        return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size)\r\n\r\n\r\nclass ExLlamaV2DeviceTensors:\r\n\r\n    device_idx: int\r\n    scratch_bytes: int\r\n    scratch_idx: int\r\n    scratch: torch.tensor = None\r\n\r\n    def __init__(self, device, scratch_bytes):\r\n        self.device = device\r\n        self.scratch_bytes = scratch_bytes\r\n\r\n    def prepare(self):\r\n        self.scratch = torch.empty(\r\n            (self.scratch_bytes // 2,), dtype=torch.half, device=self.device\r\n        )\r\n\r\n    def get_scratch_slice(self, size_bytes):\r\n\r\n        if self.scratch is None:\r\n            self.prepare()\r\n\r\n        size_bytes = ((size_bytes + 127) // 128) * 128\r\n        size_half = size_bytes // 2\r\n        scratch_slice = self.scratch.narrow(0, 0, size_half)\r\n        return scratch_slice\r\n"
  },
  {
    "path": "server/text_generation_server/layers/gptq/ipex.py",
    "content": "import math\r\nimport numpy as np\r\nimport torch\r\nimport torch.nn as nn\r\n\r\nimport intel_extension_for_pytorch as ipex\r\n\r\n\r\nclass QuantLinear(nn.Module):\r\n    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):\r\n        super().__init__()\r\n        self.register_buffer(\"qweight\", qweight)\r\n        self.register_buffer(\"qzeros\", qzeros)\r\n        self.register_buffer(\"scales\", scales)\r\n        self.register_buffer(\"g_idx\", g_idx)\r\n        if bias is not None:\r\n            self.register_buffer(\"bias\", bias)\r\n        else:\r\n            self.bias = None\r\n        if bits not in [4]:\r\n            raise NotImplementedError(\"Only 4 bits are supported.\")\r\n        self.bits = bits\r\n        self.maxq = 2**self.bits - 1\r\n        self.groupsize = groupsize\r\n\r\n        self.outfeatures = qweight.shape[1]\r\n        self.infeatures = qweight.shape[0] * 32 // bits\r\n        self.woq_linear = (\r\n            ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight(\r\n                self.qweight,\r\n                self.scales,\r\n                self.qzeros,\r\n                self.infeatures,\r\n                self.outfeatures,\r\n                bias=self.bias,\r\n                group_size=self.groupsize,\r\n                g_idx=g_idx,\r\n                quant_method=ipex.llm.quantization.QuantMethod.GPTQ_GEMM,\r\n                dtype=ipex.llm.quantization.QuantDtype.INT4,\r\n            )\r\n        )\r\n\r\n    @classmethod\r\n    def new(cls, bits, groupsize, infeatures, outfeatures, bias):\r\n        if bits not in [4]:\r\n            raise NotImplementedError(\"Only 4 bits are supported.\")\r\n\r\n        qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)\r\n        qzeros = torch.zeros(\r\n            (math.ceil(infeatures / groupsize), outfeatures // 32 * bits),\r\n            dtype=torch.int32,\r\n        )\r\n        scales = 
torch.zeros(\r\n            (math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16\r\n        )\r\n        g_idx = torch.tensor(\r\n            [i // groupsize for i in range(infeatures)], dtype=torch.int32\r\n        )\r\n        if bias:\r\n            bias = torch.zeros((outfeatures), dtype=torch.float16)\r\n        else:\r\n            bias = None\r\n        return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)\r\n\r\n    def pack(self, linear, scales, zeros, g_idx=None):\r\n        self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx\r\n\r\n        scales = scales.t().contiguous()\r\n        zeros = zeros.t().contiguous()\r\n        scale_zeros = zeros * scales\r\n        self.scales = scales.clone().half()\r\n        if linear.bias is not None:\r\n            self.bias = linear.bias.clone().half()\r\n\r\n        intweight = []\r\n        for idx in range(self.infeatures):\r\n            intweight.append(\r\n                torch.round(\r\n                    (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])\r\n                    / self.scales[self.g_idx[idx]]\r\n                ).to(torch.int)[:, None]\r\n            )\r\n        intweight = torch.cat(intweight, dim=1)\r\n        intweight = intweight.t().contiguous()\r\n        intweight = intweight.numpy().astype(np.uint32)\r\n        qweight = np.zeros(\r\n            (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32\r\n        )\r\n        i = 0\r\n        row = 0\r\n        while row < qweight.shape[0]:\r\n            if self.bits in [4]:\r\n                for j in range(i, i + (32 // self.bits)):\r\n                    qweight[row] |= intweight[j] << (self.bits * (j - i))\r\n                i += 32 // self.bits\r\n                row += 1\r\n            else:\r\n                raise NotImplementedError(\"Only 4 bits are supported.\")\r\n\r\n        qweight = qweight.astype(np.int32)\r\n        self.qweight = 
torch.from_numpy(qweight)\r\n\r\n        zeros -= 1\r\n        zeros = zeros.numpy().astype(np.uint32)\r\n        qzeros = np.zeros(\r\n            (zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32\r\n        )\r\n        i = 0\r\n        col = 0\r\n        while col < qzeros.shape[1]:\r\n            if self.bits in [4]:\r\n                for j in range(i, i + (32 // self.bits)):\r\n                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))\r\n                i += 32 // self.bits\r\n                col += 1\r\n            else:\r\n                raise NotImplementedError(\"Only 4 bits are supported.\")\r\n\r\n        qzeros = qzeros.astype(np.int32)\r\n        self.qzeros = torch.from_numpy(qzeros)\r\n\r\n    def forward(self, x):\r\n        out_shape = x.shape[:-1] + (self.outfeatures,)\r\n        out = self.woq_linear(x.reshape(-1, x.shape[-1]))\r\n        return out.reshape(out_shape)\r\n"
  },
  {
    "path": "server/text_generation_server/layers/gptq/quantize.py",
    "content": "import time\nimport torch.nn as nn\nimport math\nimport json\nimport os\nimport torch\nimport transformers\n\nfrom texttable import Texttable\nfrom transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer\nfrom huggingface_hub import HfApi\nfrom accelerate import init_empty_weights\nfrom text_generation_server.utils import initialize_torch_distributed, Weights\nfrom text_generation_server.utils.hub import weight_files\nfrom text_generation_server.layers.gptq import QuantLinear\nfrom loguru import logger\nfrom typing import Optional\nfrom text_generation_server.layers.gptq.utils import torch_snr_error\n\nfrom text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight\n\nDEV = torch.device(\"cuda:0\")\n\n\nclass Quantizer(nn.Module):\n    def __init__(self, shape=1):\n        super(Quantizer, self).__init__()\n        self.register_buffer(\"maxq\", torch.tensor(0))\n        self.register_buffer(\"scale\", torch.zeros(shape))\n        self.register_buffer(\"zero\", torch.zeros(shape))\n\n    def configure(\n        self,\n        bits,\n        perchannel=False,\n        sym=True,\n        mse=False,\n        norm=2.4,\n        grid=100,\n        maxshrink=0.8,\n        trits=False,\n    ):\n        self.maxq = torch.tensor(2**bits - 1)\n        self.perchannel = perchannel\n        self.sym = sym\n        self.mse = mse\n        self.norm = norm\n        self.grid = grid\n        self.maxshrink = maxshrink\n        if trits:\n            self.maxq = torch.tensor(-1)\n        self.scale = torch.zeros_like(self.scale)\n\n    def _quantize(self, x, scale, zero, maxq):\n        if maxq < 0:\n            return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero\n        q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)\n        return scale * (q - zero)\n\n    def find_params(self, x, weight=False):\n        dev = x.device\n        self.maxq = self.maxq.to(dev)\n\n        shape = x.shape\n        if 
self.perchannel:\n            if weight:\n                x = x.flatten(1)\n            else:\n                if len(shape) == 4:\n                    x = x.permute([1, 0, 2, 3])\n                    x = x.flatten(1)\n                if len(shape) == 3:\n                    x = x.reshape((-1, shape[-1])).t()\n                if len(shape) == 2:\n                    x = x.t()\n        else:\n            x = x.flatten().unsqueeze(0)\n\n        tmp = torch.zeros(x.shape[0], device=dev)\n        xmin = torch.minimum(x.min(1)[0], tmp)\n        xmax = torch.maximum(x.max(1)[0], tmp)\n\n        if self.sym:\n            xmax = torch.maximum(torch.abs(xmin), xmax)\n            tmp = xmin < 0\n            if torch.any(tmp):\n                xmin[tmp] = -xmax[tmp]\n        tmp = (xmin == 0) & (xmax == 0)\n        xmin[tmp] = -1\n        xmax[tmp] = +1\n\n        if self.maxq < 0:\n            self.scale = xmax\n            self.zero = xmin\n        else:\n            self.scale = (xmax - xmin) / self.maxq\n            if self.sym:\n                self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)\n            else:\n                self.zero = torch.round(-xmin / self.scale)\n\n        if self.mse:\n            best = torch.full([x.shape[0]], float(\"inf\"), device=dev)\n            for i in range(int(self.maxshrink * self.grid)):\n                p = 1 - i / self.grid\n                xmin1 = p * xmin\n                xmax1 = p * xmax\n                scale1 = (xmax1 - xmin1) / self.maxq\n                zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero\n                q = self._quantize(\n                    x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq\n                )\n                q -= x\n                q.abs_()\n                q.pow_(self.norm)\n                err = torch.sum(q, 1)\n                tmp = err < best\n                if torch.any(tmp):\n                    best[tmp] = err[tmp]\n                    
self.scale[tmp] = scale1[tmp]\n                    self.zero[tmp] = zero1[tmp]\n        if not self.perchannel:\n            if weight:\n                tmp = shape[0]\n            else:\n                tmp = shape[1] if len(shape) != 3 else shape[2]\n            self.scale = self.scale.repeat(tmp)\n            self.zero = self.zero.repeat(tmp)\n\n        if weight:\n            shape = [-1] + [1] * (len(shape) - 1)\n            self.scale = self.scale.reshape(shape)\n            self.zero = self.zero.reshape(shape)\n            return\n        if len(shape) == 4:\n            self.scale = self.scale.reshape((1, -1, 1, 1))\n            self.zero = self.zero.reshape((1, -1, 1, 1))\n        if len(shape) == 3:\n            self.scale = self.scale.reshape((1, 1, -1))\n            self.zero = self.zero.reshape((1, 1, -1))\n        if len(shape) == 2:\n            self.scale = self.scale.unsqueeze(0)\n            self.zero = self.zero.unsqueeze(0)\n\n    def quantize(self, x):\n        if self.ready():\n            return self._quantize(x, self.scale, self.zero, self.maxq)\n\n        return x\n\n    def enabled(self):\n        return self.maxq > 0\n\n    def ready(self):\n        return torch.all(self.scale != 0)\n\n\nclass GPTQ:\n    def __init__(self, layer, observe=False):\n        self.layer = layer\n        self.dev = self.layer.weight.device\n        W = layer.weight.data.clone()\n        if isinstance(self.layer, nn.Conv2d):\n            W = W.flatten(1)\n        if isinstance(self.layer, transformers.Conv1D):\n            W = W.t()\n        self.rows = W.shape[0]\n        self.columns = W.shape[1]\n        self.H = torch.zeros((self.columns, self.columns), device=self.dev)\n        self.nsamples = 0\n        self.quantizer = Quantizer()\n        self.observe = observe\n\n    def add_batch(self, inp, out):\n        # Hessian H = 2 X XT + λ I\n        if self.observe:\n            self.inp1 = inp\n            self.out1 = out\n        else:\n            self.inp1 
= None\n            self.out1 = None\n\n        if len(inp.shape) == 2:\n            inp = inp.unsqueeze(0)\n        tmp = inp.shape[0]\n        if isinstance(self.layer, nn.Linear) or isinstance(\n            self.layer, transformers.Conv1D\n        ):\n            if len(inp.shape) == 3:\n                inp = inp.reshape((-1, inp.shape[-1]))\n            inp = inp.t()\n        if isinstance(self.layer, nn.Conv2d):\n            unfold = nn.Unfold(\n                self.layer.kernel_size,\n                dilation=self.layer.dilation,\n                padding=self.layer.padding,\n                stride=self.layer.stride,\n            )\n            inp = unfold(inp)\n            inp = inp.permute([1, 0, 2])\n            inp = inp.flatten(1)\n        self.H *= self.nsamples / (self.nsamples + tmp)\n        self.nsamples += tmp\n        # inp = inp.float()\n        inp = math.sqrt(2 / self.nsamples) * inp.float()\n        # self.H += 2 / self.nsamples * inp.matmul(inp.t())\n        self.H += inp.matmul(inp.t())\n\n    def print_loss(self, name, q_weight, weight_error, timecost):\n        table = Texttable()\n        length = 28\n        name = (\n            (name + \" \" * (length - len(name)))\n            if len(name) <= length\n            else name[:length]\n        )\n\n        table.header([\"name\", \"weight_error\", \"fp_inp_SNR\", \"q_inp_SNR\", \"time\"])\n\n        # assign weight\n        self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to(\n            self.layer.weight.data.dtype\n        )\n\n        if self.inp1 is not None:\n            # quantize input to int8\n            quantizer = Quantizer()\n            quantizer.configure(8, perchannel=False, sym=True, mse=False)\n            quantizer.find_params(self.inp1)\n            q_in = quantizer.quantize(self.inp1).type(torch.float16)\n            q_out = self.layer(q_in)\n\n            # get kinds of SNR\n            q_SNR = torch_snr_error(q_out, self.out1).item()\n            
fp_SNR = torch_snr_error(self.layer(self.inp1), self.out1).item()\n        else:\n            q_SNR = \"-\"\n            fp_SNR = \"-\"\n\n        table.add_row([name, weight_error, fp_SNR, q_SNR, timecost])\n        print(table.draw().split(\"\\n\")[-2])\n\n    def fasterquant(\n        self, blocksize=128, percdamp=0.01, groupsize=-1, act_order=False, name=\"\"\n    ):\n        self.layer.to(self.dev)\n\n        W = self.layer.weight.data.clone()\n        if isinstance(self.layer, nn.Conv2d):\n            W = W.flatten(1)\n        if isinstance(self.layer, transformers.Conv1D):\n            W = W.t()\n        W = W.float()\n\n        tick = time.time()\n\n        if not self.quantizer.ready():\n            self.quantizer.find_params(W, weight=True)\n\n        H = self.H\n        if not self.observe:\n            del self.H\n        dead = torch.diag(H) == 0\n        H[dead, dead] = 1\n        W[:, dead] = 0\n\n        if act_order:\n            perm = torch.argsort(torch.diag(H), descending=True)\n            W = W[:, perm]\n            H = H[perm][:, perm]\n\n        Losses = torch.zeros_like(W)\n        Q = torch.zeros_like(W)\n\n        damp = percdamp * torch.mean(torch.diag(H))\n        diag = torch.arange(self.columns, device=self.dev)\n        H[diag, diag] += damp\n        H = torch.linalg.cholesky(H)\n        H = torch.cholesky_inverse(H)\n        try:\n            H = torch.linalg.cholesky(H, upper=True)\n        except Exception:\n            # Addition because Falcon fails on h_to_4h\n            H = torch.linalg.cholesky(\n                H + 1e-5 * torch.eye(H.shape[0]).to(H.device), upper=True\n            )\n        Hinv = H\n\n        g_idx = []\n        scale = []\n        zero = []\n        now_idx = 1\n\n        for i1 in range(0, self.columns, blocksize):\n            i2 = min(i1 + blocksize, self.columns)\n            count = i2 - i1\n\n            W1 = W[:, i1:i2].clone()\n            Q1 = torch.zeros_like(W1)\n            Err1 = 
torch.zeros_like(W1)\n            Losses1 = torch.zeros_like(W1)\n            Hinv1 = Hinv[i1:i2, i1:i2]\n\n            for i in range(count):\n                w = W1[:, i]\n                d = Hinv1[i, i]\n\n                if groupsize != -1:\n                    if (i1 + i) % groupsize == 0:\n                        self.quantizer.find_params(\n                            W[:, (i1 + i) : (i1 + i + groupsize)], weight=True\n                        )\n\n                    if ((i1 + i) // groupsize) - now_idx == -1:\n                        scale.append(self.quantizer.scale)\n                        zero.append(self.quantizer.zero)\n                        now_idx += 1\n\n                q = self.quantizer.quantize(w.unsqueeze(1)).flatten()\n                Q1[:, i] = q\n                Losses1[:, i] = (w - q) ** 2 / d**2\n\n                err1 = (w - q) / d\n                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))\n                Err1[:, i] = err1\n\n            Q[:, i1:i2] = Q1\n            Losses[:, i1:i2] = Losses1 / 2\n\n            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])\n\n        torch.cuda.synchronize()\n        error = torch.sum(Losses).item()\n\n        groupsize = groupsize if groupsize != -1 else self.columns\n        g_idx = [i // groupsize for i in range(self.columns)]\n        g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)\n        if act_order:\n            invperm = torch.argsort(perm)\n            Q = Q[:, invperm]\n            g_idx = g_idx[invperm]\n\n        if isinstance(self.layer, transformers.Conv1D):\n            Q = Q.t()\n\n        self.print_loss(\n            name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick)\n        )\n\n        if scale == []:\n            scale.append(self.quantizer.scale)\n            zero.append(self.quantizer.zero)\n        scale = torch.cat(scale, dim=1)\n        zero = torch.cat(zero, dim=1)\n        return scale, zero, g_idx, error\n\n    
def free(self):\n        self.inp1 = None\n        self.out1 = None\n        self.H = None\n        self.Losses = None\n        self.Trace = None\n        torch.cuda.empty_cache()\n\n\ndef get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code):\n    from datasets import load_dataset\n\n    traindata = load_dataset(\"wikitext\", \"wikitext-2-raw-v1\", split=\"train\")\n    testdata = load_dataset(\"wikitext\", \"wikitext-2-raw-v1\", split=\"test\")\n\n    try:\n        tokenizer = AutoTokenizer.from_pretrained(\n            model_id, use_fast=False, trust_remote_code=trust_remote_code\n        )\n    except Exception:\n        tokenizer = AutoTokenizer.from_pretrained(\n            model_id, use_fast=True, trust_remote_code=trust_remote_code\n        )\n\n    trainenc = tokenizer(\"\\n\\n\".join(traindata[\"text\"]), return_tensors=\"pt\")\n    testenc = tokenizer(\"\\n\\n\".join(testdata[\"text\"]), return_tensors=\"pt\")\n\n    import random\n\n    random.seed(seed)\n    trainloader = []\n    for _ in range(nsamples):\n        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)\n        j = i + seqlen\n        inp = trainenc.input_ids[:, i:j]\n        tar = inp.clone()\n        tar[:, :-1] = -100\n        trainloader.append((inp, tar))\n    return trainloader, testenc\n\n\ndef get_ptb(nsamples, seed, seqlen, model_id, trust_remote_code):\n    from datasets import load_dataset\n\n    traindata = load_dataset(\"ptb_text_only\", \"penn_treebank\", split=\"train\")\n    valdata = load_dataset(\"ptb_text_only\", \"penn_treebank\", split=\"validation\")\n\n    try:\n        tokenizer = AutoTokenizer.from_pretrained(\n            model_id, use_fast=False, trust_remote_code=trust_remote_code\n        )\n    except Exception:\n        tokenizer = AutoTokenizer.from_pretrained(\n            model_id, use_fast=True, trust_remote_code=trust_remote_code\n        )\n\n    trainenc = tokenizer(\"\\n\\n\".join(traindata[\"sentence\"]), 
return_tensors=\"pt\")\n    testenc = tokenizer(\"\\n\\n\".join(valdata[\"sentence\"]), return_tensors=\"pt\")\n\n    import random\n\n    random.seed(seed)\n    trainloader = []\n    for _ in range(nsamples):\n        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)\n        j = i + seqlen\n        inp = trainenc.input_ids[:, i:j]\n        tar = inp.clone()\n        tar[:, :-1] = -100\n        trainloader.append((inp, tar))\n    return trainloader, testenc\n\n\ndef get_c4(nsamples, seed, seqlen, model_id, trust_remote_code):\n    from datasets import load_dataset\n\n    traindata = load_dataset(\n        \"allenai/c4\",\n        \"allenai--c4\",\n        data_files={\"train\": \"en/c4-train.00000-of-01024.json.gz\"},\n        split=\"train\",\n        use_auth_token=False,\n    )\n    valdata = load_dataset(\n        \"allenai/c4\",\n        \"allenai--c4\",\n        data_files={\"validation\": \"en/c4-validation.00000-of-00008.json.gz\"},\n        split=\"validation\",\n        use_auth_token=False,\n    )\n\n    try:\n        tokenizer = AutoTokenizer.from_pretrained(\n            model_id, use_fast=False, trust_remote_code=trust_remote_code\n        )\n    except Exception:\n        tokenizer = AutoTokenizer.from_pretrained(\n            model_id, use_fast=True, trust_remote_code=trust_remote_code\n        )\n\n    import random\n\n    random.seed(seed)\n    trainloader = []\n    for _ in range(nsamples):\n        while True:\n            i = random.randint(0, len(traindata) - 1)\n            trainenc = tokenizer(traindata[i][\"text\"], return_tensors=\"pt\")\n            if trainenc.input_ids.shape[1] >= seqlen:\n                break\n        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)\n        j = i + seqlen\n        inp = trainenc.input_ids[:, i:j]\n        tar = inp.clone()\n        tar[:, :-1] = -100\n        trainloader.append((inp, tar))\n\n    import random\n\n    random.seed(0)\n    valenc = []\n    for _ in 
range(256):\n        while True:\n            i = random.randint(0, len(valdata) - 1)\n            tmp = tokenizer(valdata[i][\"text\"], return_tensors=\"pt\")\n            if tmp.input_ids.shape[1] >= seqlen:\n                break\n        i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1)\n        j = i + seqlen\n        valenc.append(tmp.input_ids[:, i:j])\n    valenc = torch.hstack(valenc)\n\n    class TokenizerWrapper:\n        def __init__(self, input_ids):\n            self.input_ids = input_ids\n\n    valenc = TokenizerWrapper(valenc)\n\n    return trainloader, valenc\n\n\ndef get_ptb_new(nsamples, seed, seqlen, model_id, trust_remote_code):\n    from datasets import load_dataset\n\n    traindata = load_dataset(\"ptb_text_only\", \"penn_treebank\", split=\"train\")\n    testdata = load_dataset(\"ptb_text_only\", \"penn_treebank\", split=\"test\")\n\n    try:\n        tokenizer = AutoTokenizer.from_pretrained(\n            model_id, use_fast=False, trust_remote_code=trust_remote_code\n        )\n    except Exception:\n        tokenizer = AutoTokenizer.from_pretrained(\n            model_id, use_fast=True, trust_remote_code=trust_remote_code\n        )\n\n    trainenc = tokenizer(\" \".join(traindata[\"sentence\"]), return_tensors=\"pt\")\n    testenc = tokenizer(\" \".join(testdata[\"sentence\"]), return_tensors=\"pt\")\n\n    import random\n\n    random.seed(seed)\n    trainloader = []\n    for _ in range(nsamples):\n        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)\n        j = i + seqlen\n        inp = trainenc.input_ids[:, i:j]\n        tar = inp.clone()\n        tar[:, :-1] = -100\n        trainloader.append((inp, tar))\n    return trainloader, testenc\n\n\ndef get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code):\n    from datasets import load_dataset\n\n    traindata = load_dataset(\n        \"allenai/c4\",\n        \"allenai--c4\",\n        data_files={\"train\": \"en/c4-train.00000-of-01024.json.gz\"},\n     
   split=\"train\",\n    )\n    valdata = load_dataset(\n        \"allenai/c4\",\n        \"allenai--c4\",\n        data_files={\"validation\": \"en/c4-validation.00000-of-00008.json.gz\"},\n        split=\"validation\",\n    )\n\n    try:\n        tokenizer = AutoTokenizer.from_pretrained(\n            model_id, use_fast=False, trust_remote_code=trust_remote_code\n        )\n    except Exception:\n        tokenizer = AutoTokenizer.from_pretrained(\n            model_id, use_fast=True, trust_remote_code=trust_remote_code\n        )\n\n    import random\n\n    random.seed(seed)\n    trainloader = []\n    for _ in range(nsamples):\n        while True:\n            i = random.randint(0, len(traindata) - 1)\n            trainenc = tokenizer(traindata[i][\"text\"], return_tensors=\"pt\")\n            if trainenc.input_ids.shape[1] >= seqlen:\n                break\n        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)\n        j = i + seqlen\n        inp = trainenc.input_ids[:, i:j]\n        tar = inp.clone()\n        tar[:, :-1] = -100\n        trainloader.append((inp, tar))\n\n    valenc = tokenizer(\" \".join(valdata[:1100][\"text\"]), return_tensors=\"pt\")\n    valenc = valenc.input_ids[:, : (256 * seqlen)]\n\n    class TokenizerWrapper:\n        def __init__(self, input_ids):\n            self.input_ids = input_ids\n\n    valenc = TokenizerWrapper(valenc)\n\n    return trainloader, valenc\n\n\ndef get_loaders(\n    name, nsamples=128, seed=0, seqlen=2048, model_id=\"\", trust_remote_code=False\n):\n    if \"wikitext2\" in name:\n        return get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code)\n    if \"ptb\" in name:\n        if \"new\" in name:\n            return get_ptb_new(nsamples, seed, seqlen, model_id, trust_remote_code)\n        return get_ptb(nsamples, seed, seqlen, model_id, trust_remote_code)\n    if \"c4\" in name:\n        if \"new\" in name:\n            return get_c4_new(nsamples, seed, seqlen, model_id, 
trust_remote_code)\n        return get_c4(nsamples, seed, seqlen, model_id, trust_remote_code)\n\n\ndef find_layers(module, layers=(nn.Conv2d, nn.Linear), name=\"\"):\n    # Skip last lm_head linear\n    # Need isintance Falcon is inheriting Linear.\n    if isinstance(module, layers) and \"lm_head\" not in name:\n        return {name: module}\n    res = {}\n    for name1, child in module.named_children():\n        res.update(\n            find_layers(\n                child, layers=layers, name=name + \".\" + name1 if name != \"\" else name1\n            )\n        )\n    return res\n\n\n@torch.no_grad()\ndef sequential(\n    model,\n    dataloader,\n    dev,\n    nsamples,\n    bits,\n    groupsize,\n    *,\n    hooks,\n    percdamp=0.01,\n    sym: bool = False,\n    act_order: bool = False,\n):\n    print(\"Starting ...\")\n\n    use_cache = model.config.use_cache\n    model.config.use_cache = False\n    try:\n        layers = model.model.layers\n        prefix = \"model.layers\"\n    except Exception:\n        layers = model.transformer.h\n        prefix = \"transformer.h\"\n\n    dtype = next(iter(model.parameters())).dtype\n    inps = torch.zeros(\n        (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev\n    )\n\n    cache = {\"i\": 0}\n    extra = {}\n\n    class Catcher(nn.Module):\n        def __init__(self, module):\n            super().__init__()\n            self.module = module\n\n        def forward(self, inp, **kwargs):\n            inps[cache[\"i\"]] = inp\n            cache[\"i\"] += 1\n            extra.update(kwargs.copy())\n            raise ValueError\n\n    layers[0] = Catcher(layers[0])\n    for batch in dataloader:\n        try:\n            model(batch[0].cuda())\n        except ValueError:\n            pass\n    layers[0] = layers[0].module\n\n    # layers[0] = layers[0].cpu()\n    # model.model.embed_tokens = model.model.embed_tokens.cpu()\n    # model.model.norm = model.model.norm.cpu()\n    
torch.cuda.empty_cache()\n    for hook in hooks:\n        hook.remove()\n\n    outs = torch.zeros_like(inps)\n\n    extra = {\n        k: v.to(dev) if isinstance(v, torch.Tensor) else v for k, v in extra.items()\n    }\n\n    print(\"Ready.\")\n\n    quantizers = {}\n    for i in range(len(layers)):\n        print(f\"Quantizing layer {i+1}/{len(layers)}..\")\n        print(\"+------------------+--------------+------------+-----------+-------+\")\n        print(\"|       name       | weight_error | fp_inp_SNR | q_inp_SNR | time  |\")\n        print(\"+==================+==============+============+===========+=======+\")\n\n        layer = layers[i]\n        layer.load()\n        full = find_layers(layer)\n        sequential = [list(full.keys())]\n\n        for names in sequential:\n            subset = {n: full[n] for n in names}\n            gptq = {}\n            for name in subset:\n                gptq[name] = GPTQ(subset[name])\n                gptq[name].quantizer.configure(\n                    bits, perchannel=True, sym=sym, mse=False\n                )\n                pass\n\n            def add_batch(name):\n                nonlocal gptq\n\n                def tmp(_, inp, out):\n                    gptq[name].add_batch(inp[0].data, out.data)\n\n                return tmp\n\n            handles = []\n            for name in subset:\n                handles.append(subset[name].register_forward_hook(add_batch(name)))\n            for j in range(nsamples):\n                outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]\n            for h in handles:\n                h.remove()\n\n            for name in subset:\n                scale, zero, g_idx, error = gptq[name].fasterquant(\n                    percdamp=percdamp,\n                    groupsize=groupsize,\n                    act_order=act_order,\n                    name=name,\n                )\n                quantizers[f\"{prefix}.{i}.{name}\"] = (\n                    
gptq[name].quantizer.cpu(),\n                    scale.cpu(),\n                    zero.cpu(),\n                    g_idx.cpu(),\n                    bits,\n                    groupsize,\n                )\n\n                gptq[name].free()\n\n        for j in range(nsamples):\n            outs[j] = layer(inps[j].unsqueeze(0), **extra)[0]\n\n        layer.unload()\n        del layer\n        del gptq\n        torch.cuda.empty_cache()\n\n        inps, outs = outs, inps\n        print(\"+------------------+--------------+------------+-----------+-------+\")\n        print(\"\\n\")\n\n    model.config.use_cache = use_cache\n\n    return quantizers\n\n\ndef make_quant_linear(module, names, bits, groupsize, name=\"\"):\n    if isinstance(module, QuantLinear):\n        return\n    for attr in dir(module):\n        tmp = getattr(module, attr)\n        name1 = name + \".\" + attr if name != \"\" else attr\n        if name1 in names:\n            delattr(module, attr)\n            setattr(\n                module,\n                attr,\n                QuantLinear.new(\n                    bits,\n                    groupsize,\n                    tmp.in_features,\n                    tmp.out_features,\n                    tmp.bias is not None,\n                ),\n            )\n    for name1, child in module.named_children():\n        make_quant_linear(\n            child, names, bits, groupsize, name + \".\" + name1 if name != \"\" else name1\n        )\n\n\n# TODO: perform packing on GPU\ndef pack(model, quantizers, bits, groupsize):\n    layers = find_layers(model)\n    layers = {n: layers[n] for n in quantizers}\n    make_quant_linear(model, quantizers, bits, groupsize)\n    qlayers = find_layers(model, (QuantLinear,))\n    print(\"Packing ...\")\n    for name in qlayers:\n        print(name)\n        quantizers[name], scale, zero, g_idx, _, _ = quantizers[name]\n        qlayers[name].pack(layers[name], scale, zero, g_idx)\n    print(\"Done.\")\n    return 
model\n\n\ndef setdeepattr(module, full_name, tensor):\n    current = module\n    tokens = full_name.split(\".\")\n    for token in tokens[:-1]:\n        current = getattr(current, token)\n    setattr(current, tokens[-1], tensor)\n\n\ndef getdeepattr(module, full_name):\n    current = module\n    tokens = full_name.split(\".\")\n    for token in tokens:\n        current = getattr(current, token)\n    return current\n\n\ndef load_weights_pre_hook(module_name, weights, recursive=False):\n    def inner(module, args):\n        print(f\"Pre hook {module_name}\")\n        local_params = {}\n        for k, v in module.named_parameters():\n            if not recursive and k.count(\".\") != 1:\n                continue\n            local_params[k] = v\n        for k, v in module.named_buffers():\n            if not recursive and k.count(\".\") != 1:\n                continue\n            local_params[k] = v\n\n        for local_param in local_params:\n            current_tensor = getdeepattr(module, local_param)\n            if current_tensor.device == torch.device(\"meta\"):\n                # print(f\"Loading {local_param}\")\n                if module_name:\n                    tensor_name = f\"{module_name}.{local_param}\"\n                else:\n                    tensor_name = local_param\n                tensor = weights.get_tensor(tensor_name)\n                setdeepattr(module, local_param, nn.Parameter(tensor))\n            else:\n                tensor = current_tensor.to(device=torch.device(\"cuda:0\"))\n                if current_tensor.requires_grad:\n                    tensor = nn.Parameter(tensor)\n                setdeepattr(module, local_param, tensor)\n\n    return inner\n\n\ndef load_weights_post_hook(module_name, weights, recursive=False):\n    def inner(module, args, output):\n        print(f\"Post hook {module_name}\")\n        local_params = {}\n        for k, v in module.named_parameters():\n            if not recursive and k.count(\".\") != 1:\n 
               continue\n            local_params[k] = v\n        for k, v in module.named_buffers():\n            if not recursive and k.count(\".\") != 1:\n                continue\n            local_params[k] = v\n        for local_param in local_params:\n            # print(f\"Unloading {local_param}\")\n            current_tensor = getdeepattr(module, local_param)\n            setdeepattr(\n                module,\n                local_param,\n                nn.Parameter(current_tensor.to(device=torch.device(\"cpu\"))),\n            )\n        return output\n\n    return inner\n\n\ndef quantize(\n    model_id: str,\n    bits: int,\n    groupsize: int,\n    output_dir: str,\n    revision: str,\n    trust_remote_code: bool,\n    upload_to_model_id: Optional[str],\n    percdamp: float,\n    act_order: bool,\n    sym: bool,\n):\n    print(\"loading model\")\n    config = AutoConfig.from_pretrained(\n        model_id,\n        trust_remote_code=trust_remote_code,\n    )\n\n    with init_empty_weights():\n        model = AutoModelForCausalLM.from_config(\n            config, torch_dtype=torch.float16, trust_remote_code=trust_remote_code\n        )\n    model = model.eval()\n\n    print(\"LOADED model\")\n    files = weight_files(model_id, revision, extension=\".safetensors\")\n    process_group, _, _ = initialize_torch_distributed()\n    weights = Weights(\n        files,\n        device=torch.device(\"cuda:0\"),\n        dtype=torch.float16,\n        process_group=process_group,\n        aliases={\"embed_tokens.weight\": [\"lm_head.weight\"]},\n        weights_loader=DefaultWeightsLoader(UnquantizedWeight),\n    )\n    hooks = []\n    for name, module in model.named_modules():\n\n        def load(module, name):\n            def _load():\n                load_weights_pre_hook(name, weights, recursive=True)(module, None)\n\n            return _load\n\n        def unload(module, name):\n            def _unload():\n                load_weights_post_hook(name, 
weights, recursive=True)(\n                    module, None, None\n                )\n\n            return _unload\n\n        module.load = load(module, name)\n        module.unload = unload(module, name)\n        hooks.append(\n            module.register_forward_pre_hook(load_weights_pre_hook(name, weights))\n        )\n        hooks.append(\n            module.register_forward_hook(load_weights_post_hook(name, weights))\n        )\n    model.seqlen = 2048\n\n    dataset = \"wikitext2\"\n    nsamples = 128\n    seed = None\n\n    dataloader, testloader = get_loaders(\n        dataset,\n        nsamples=nsamples,\n        seed=seed,\n        model_id=model_id,\n        seqlen=model.seqlen,\n        trust_remote_code=trust_remote_code,\n    )\n\n    tick = time.time()\n    quantizers = sequential(\n        model,\n        dataloader,\n        DEV,\n        nsamples,\n        bits,\n        groupsize,\n        percdamp=percdamp,\n        act_order=act_order,\n        hooks=hooks,\n        sym=sym,\n    )\n    print(time.time() - tick)\n\n    pack(model, quantizers, bits, groupsize)\n    from safetensors.torch import save_file\n    from huggingface_hub import split_torch_state_dict_into_shards\n\n    state_dict = model.state_dict()\n    state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}\n\n    max_shard_size = \"10GB\"\n    state_dict_split = split_torch_state_dict_into_shards(\n        state_dict,\n        filename_pattern=\"model.safetensors\",\n        max_shard_size=max_shard_size,\n    )\n    index = None\n    if state_dict_split.is_sharded:\n        index = {\n            \"metadata\": state_dict_split.metadata,\n            \"weight_map\": state_dict_split.tensor_to_filename,\n        }\n    shards = state_dict_split.filename_to_tensors\n    os.makedirs(output_dir, exist_ok=True)\n    for shard_file, shard in shards.items():\n        save_file(\n            shard,\n            os.path.join(output_dir, shard_file),\n            metadata={\n   
             \"format\": \"pt\",\n                \"quantized\": \"gptq\",\n                \"origin\": \"text-generation-inference\",\n            },\n        )\n    if index is None:\n        path_to_weights = os.path.join(output_dir, \"model.safetensors\")\n        logger.info(f\"Model weights saved in {path_to_weights}\")\n    else:\n        save_index_file = \"model.safetensors.index.json\"\n        save_index_file = os.path.join(output_dir, save_index_file)\n        with open(save_index_file, \"w\", encoding=\"utf-8\") as f:\n            content = json.dumps(index, indent=2, sort_keys=True) + \"\\n\"\n            f.write(content)\n        logger.info(\n            f\"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be \"\n            f\"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the \"\n            f\"index located at {save_index_file}.\"\n        )\n    config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)\n    config.quantization_config = {\n        \"bits\": bits,\n        \"group_size\": groupsize,\n        \"damp_percent\": percdamp,\n        \"desc_act\": act_order,\n        \"static_groups\": False,\n        \"sym\": sym,\n        \"quant_method\": \"gptq\",\n    }\n    config.save_pretrained(output_dir)\n    logger.info(\"Saved config\")\n    logger.info(\"Saving tokenizer\")\n    tokenizer = AutoTokenizer.from_pretrained(\n        model_id, trust_remote_code=trust_remote_code\n    )\n    tokenizer.save_pretrained(output_dir)\n    logger.info(\"Saved tokenizer\")\n\n    if upload_to_model_id:\n        api = HfApi()\n\n        api.upload_folder(\n            folder_path=output_dir, repo_id=upload_to_model_id, repo_type=\"model\"\n        )\n"
  },
  {
    "path": "server/text_generation_server/layers/gptq/triton.py",
    "content": "import math\r\nimport numpy as np\r\nimport torch\r\nimport torch.nn as nn\r\nfrom torch.cuda.amp import custom_fwd\r\n\r\nimport triton\r\nimport triton.language as tl\r\nfrom . import custom_autotune\r\n\r\n\r\n# code based https://github.com/fpgaminer/GPTQ-triton\r\n@custom_autotune.autotune(\r\n    configs=[\r\n        triton.Config(\r\n            {\r\n                \"BLOCK_SIZE_M\": 64,\r\n                \"BLOCK_SIZE_N\": 256,\r\n                \"BLOCK_SIZE_K\": 32,\r\n                \"GROUP_SIZE_M\": 8,\r\n            },\r\n            num_stages=4,\r\n            num_warps=4,\r\n        ),\r\n        triton.Config(\r\n            {\r\n                \"BLOCK_SIZE_M\": 128,\r\n                \"BLOCK_SIZE_N\": 128,\r\n                \"BLOCK_SIZE_K\": 32,\r\n                \"GROUP_SIZE_M\": 8,\r\n            },\r\n            num_stages=4,\r\n            num_warps=4,\r\n        ),\r\n        triton.Config(\r\n            {\r\n                \"BLOCK_SIZE_M\": 64,\r\n                \"BLOCK_SIZE_N\": 128,\r\n                \"BLOCK_SIZE_K\": 32,\r\n                \"GROUP_SIZE_M\": 8,\r\n            },\r\n            num_stages=4,\r\n            num_warps=4,\r\n        ),\r\n        triton.Config(\r\n            {\r\n                \"BLOCK_SIZE_M\": 128,\r\n                \"BLOCK_SIZE_N\": 32,\r\n                \"BLOCK_SIZE_K\": 32,\r\n                \"GROUP_SIZE_M\": 8,\r\n            },\r\n            num_stages=4,\r\n            num_warps=4,\r\n        ),\r\n        triton.Config(\r\n            {\r\n                \"BLOCK_SIZE_M\": 64,\r\n                \"BLOCK_SIZE_N\": 64,\r\n                \"BLOCK_SIZE_K\": 32,\r\n                \"GROUP_SIZE_M\": 8,\r\n            },\r\n            num_stages=4,\r\n            num_warps=4,\r\n        ),\r\n        triton.Config(\r\n            {\r\n                \"BLOCK_SIZE_M\": 64,\r\n                \"BLOCK_SIZE_N\": 128,\r\n                \"BLOCK_SIZE_K\": 32,\r\n                
\"GROUP_SIZE_M\": 8,\r\n            },\r\n            num_stages=2,\r\n            num_warps=8,\r\n        ),\r\n        triton.Config(\r\n            {\r\n                \"BLOCK_SIZE_M\": 64,\r\n                \"BLOCK_SIZE_N\": 64,\r\n                \"BLOCK_SIZE_K\": 64,\r\n                \"GROUP_SIZE_M\": 8,\r\n            },\r\n            num_stages=3,\r\n            num_warps=8,\r\n        ),\r\n        triton.Config(\r\n            {\r\n                \"BLOCK_SIZE_M\": 32,\r\n                \"BLOCK_SIZE_N\": 32,\r\n                \"BLOCK_SIZE_K\": 128,\r\n                \"GROUP_SIZE_M\": 8,\r\n            },\r\n            num_stages=2,\r\n            num_warps=4,\r\n        ),\r\n    ],\r\n    key=[\"M\", \"N\", \"K\"],\r\n    nearest_power_of_two=True,\r\n    prune_configs_by={\r\n        \"early_config_prune\": custom_autotune.matmul248_kernel_config_pruner,\r\n        \"perf_model\": None,\r\n        \"top_k\": None,\r\n    },\r\n)\r\n@triton.jit\r\ndef matmul_248_kernel(\r\n    a_ptr,\r\n    b_ptr,\r\n    c_ptr,\r\n    scales_ptr,\r\n    zeros_ptr,\r\n    g_ptr,\r\n    M,\r\n    N,\r\n    K,\r\n    bits,\r\n    maxq,\r\n    stride_am,\r\n    stride_ak,\r\n    stride_bk,\r\n    stride_bn,\r\n    stride_cm,\r\n    stride_cn,\r\n    stride_scales,\r\n    stride_zeros,\r\n    BLOCK_SIZE_M: tl.constexpr,\r\n    BLOCK_SIZE_N: tl.constexpr,\r\n    BLOCK_SIZE_K: tl.constexpr,\r\n    GROUP_SIZE_M: tl.constexpr,\r\n):\r\n    \"\"\"\r\n    Compute the matrix multiplication C = A x B.\r\n    A is of shape (M, K) float16\r\n    B is of shape (K//8, N) int32\r\n    C is of shape (M, N) float16\r\n    scales is of shape (G, N) float16\r\n    zeros is of shape (G, N) float16\r\n    g_ptr is of shape (K) int32\r\n    \"\"\"\r\n    infearure_per_bits = 32 // bits\r\n\r\n    pid = tl.program_id(axis=0)\r\n    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\r\n    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\r\n    num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\r\n    num_pid_in_group = 
GROUP_SIZE_M * num_pid_n\r\n    group_id = pid // num_pid_in_group\r\n    first_pid_m = group_id * GROUP_SIZE_M\r\n    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\r\n    pid_m = first_pid_m + (pid % group_size_m)\r\n    pid_n = (pid % num_pid_in_group) // group_size_m\r\n\r\n    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\r\n    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\r\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\r\n    a_ptrs = a_ptr + (\r\n        offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak\r\n    )  # (BLOCK_SIZE_M, BLOCK_SIZE_K)\r\n    a_mask = offs_am[:, None] < M\r\n    # b_ptrs is set up such that it repeats elements along the K axis 8 times\r\n    b_ptrs = b_ptr + (\r\n        (offs_k[:, None] // infearure_per_bits) * stride_bk\r\n        + offs_bn[None, :] * stride_bn\r\n    )  # (BLOCK_SIZE_K, BLOCK_SIZE_N)\r\n    g_ptrs = g_ptr + offs_k\r\n    # shifter is used to extract the N bits of each element in the 32-bit word from B\r\n    scales_ptrs = scales_ptr + offs_bn[None, :]\r\n    zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)\r\n\r\n    shifter = (offs_k % infearure_per_bits) * bits\r\n    zeros_shifter = (offs_bn % infearure_per_bits) * bits\r\n    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\r\n\r\n    for k in range(0, num_pid_k):\r\n        g_idx = tl.load(g_ptrs)\r\n\r\n        # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\r\n        scales = tl.load(\r\n            scales_ptrs + g_idx[:, None] * stride_scales\r\n        )  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\r\n        zeros = tl.load(\r\n            zeros_ptrs + g_idx[:, None] * stride_zeros\r\n        )  # (BLOCK_SIZE_K, BLOCK_SIZE_N,)\r\n\r\n        zeros = (zeros >> zeros_shifter[None, :]) & maxq\r\n        zeros = (zeros + 1) & maxq  # eventually avoid overflow\r\n\r\n        a = tl.load(a_ptrs, mask=a_mask, other=0.0)  # (BLOCK_SIZE_M, 
BLOCK_SIZE_K)\r\n        b = tl.load(b_ptrs)  # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\r\n\r\n        # Now we need to unpack b (which is N-bit values) into 32-bit values\r\n        b = (b >> shifter[:, None]) & maxq  # Extract the N-bit values\r\n        b = (b - zeros) * scales  # Scale and shift\r\n\r\n        accumulator += tl.dot(a, b)\r\n        a_ptrs += BLOCK_SIZE_K\r\n        b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk\r\n        g_ptrs += BLOCK_SIZE_K\r\n\r\n    c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]\r\n    c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\r\n    tl.store(c_ptrs, accumulator, mask=c_mask)\r\n\r\n\r\ndef matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):\r\n    with (\r\n        torch.xpu.device(input.device)\r\n        if torch.xpu.is_available()\r\n        else torch.cuda.device(input.device)\r\n    ):\r\n        output = torch.empty(\r\n            (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16\r\n        )\r\n\r\n        def grid(META):\r\n            return (\r\n                triton.cdiv(input.shape[0], META[\"BLOCK_SIZE_M\"])\r\n                * triton.cdiv(qweight.shape[1], META[\"BLOCK_SIZE_N\"]),\r\n            )\r\n\r\n        matmul_248_kernel[grid](\r\n            input,\r\n            qweight,\r\n            output,\r\n            scales,\r\n            qzeros,\r\n            g_idx,\r\n            input.shape[0],\r\n            qweight.shape[1],\r\n            input.shape[1],\r\n            bits,\r\n            maxq,\r\n            input.stride(0),\r\n            input.stride(1),\r\n            qweight.stride(0),\r\n            qweight.stride(1),\r\n            output.stride(0),\r\n            output.stride(1),\r\n            scales.stride(0),\r\n            qzeros.stride(0),\r\n        )\r\n        return output\r\n\r\n\r\nclass QuantLinearFunction(torch.autograd.Function):\r\n    @staticmethod\r\n    
@custom_fwd(cast_inputs=torch.float16)\r\n    def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):\r\n        output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq)\r\n        return output\r\n\r\n\r\nclass QuantLinear(nn.Module):\r\n    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):\r\n        super().__init__()\r\n        self.register_buffer(\"qweight\", qweight)\r\n        self.register_buffer(\"qzeros\", qzeros)\r\n        self.register_buffer(\"scales\", scales)\r\n        self.register_buffer(\"g_idx\", g_idx)\r\n        if bias is not None:\r\n            self.register_buffer(\"bias\", bias)\r\n        else:\r\n            self.bias = None\r\n        if bits not in [2, 4, 8]:\r\n            raise NotImplementedError(\"Only 2,4,8 bits are supported.\")\r\n        self.bits = bits\r\n        self.maxq = 2**self.bits - 1\r\n        self.groupsize = groupsize\r\n\r\n        self.outfeatures = qweight.shape[1]\r\n        self.infeatures = qweight.shape[0] * 32 // bits\r\n\r\n    @classmethod\r\n    def new(cls, bits, groupsize, infeatures, outfeatures, bias):\r\n        if bits not in [2, 4, 8]:\r\n            raise NotImplementedError(\"Only 2,4,8 bits are supported.\")\r\n\r\n        qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)\r\n        qzeros = torch.zeros(\r\n            (math.ceil(infeatures / groupsize), outfeatures // 32 * bits),\r\n            dtype=torch.int32,\r\n        )\r\n        scales = torch.zeros(\r\n            (math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16\r\n        )\r\n        g_idx = torch.tensor(\r\n            [i // groupsize for i in range(infeatures)], dtype=torch.int32\r\n        )\r\n        if bias:\r\n            bias = torch.zeros((outfeatures), dtype=torch.float16)\r\n        else:\r\n            bias = None\r\n        return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)\r\n\r\n    def pack(self, 
linear, scales, zeros, g_idx=None):\r\n        self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx\r\n\r\n        scales = scales.t().contiguous()\r\n        zeros = zeros.t().contiguous()\r\n        scale_zeros = zeros * scales\r\n        self.scales = scales.clone().half()\r\n        if linear.bias is not None:\r\n            self.bias = linear.bias.clone().half()\r\n\r\n        intweight = []\r\n        for idx in range(self.infeatures):\r\n            intweight.append(\r\n                torch.round(\r\n                    (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])\r\n                    / self.scales[self.g_idx[idx]]\r\n                ).to(torch.int)[:, None]\r\n            )\r\n        intweight = torch.cat(intweight, dim=1)\r\n        intweight = intweight.t().contiguous()\r\n        intweight = intweight.numpy().astype(np.uint32)\r\n        qweight = np.zeros(\r\n            (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32\r\n        )\r\n        i = 0\r\n        row = 0\r\n        while row < qweight.shape[0]:\r\n            if self.bits in [2, 4, 8]:\r\n                for j in range(i, i + (32 // self.bits)):\r\n                    qweight[row] |= intweight[j] << (self.bits * (j - i))\r\n                i += 32 // self.bits\r\n                row += 1\r\n            else:\r\n                raise NotImplementedError(\"Only 2,4,8 bits are supported.\")\r\n\r\n        qweight = qweight.astype(np.int32)\r\n        self.qweight = torch.from_numpy(qweight)\r\n\r\n        zeros -= 1\r\n        zeros = zeros.numpy().astype(np.uint32)\r\n        qzeros = np.zeros(\r\n            (zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32\r\n        )\r\n        i = 0\r\n        col = 0\r\n        while col < qzeros.shape[1]:\r\n            if self.bits in [2, 4, 8]:\r\n                for j in range(i, i + (32 // self.bits)):\r\n                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - 
i))\r\n                i += 32 // self.bits\r\n                col += 1\r\n            else:\r\n                raise NotImplementedError(\"Only 2,4,8 bits are supported.\")\r\n\r\n        qzeros = qzeros.astype(np.int32)\r\n        self.qzeros = torch.from_numpy(qzeros)\r\n\r\n    def forward(self, x):\r\n        out_shape = x.shape[:-1] + (self.outfeatures,)\r\n        out = QuantLinearFunction.apply(\r\n            x.reshape(-1, x.shape[-1]),\r\n            self.qweight,\r\n            self.scales,\r\n            self.qzeros,\r\n            self.g_idx,\r\n            self.bits,\r\n            self.maxq,\r\n        )\r\n        out = out + self.bias if self.bias is not None else out\r\n        return out.reshape(out_shape)\r\n"
  },
  {
    "path": "server/text_generation_server/layers/gptq/utils.py",
    "content": "import torch\n\n\n# copied from https://github.com/openppl-public/ppq/blob/master/ppq/quantization/measure/norm.py\ndef torch_snr_error(\n    y_pred: torch.Tensor, y_real: torch.Tensor, reduction: str = \"mean\"\n) -> torch.Tensor:\n    \"\"\"\n    Compute SNR between y_pred(tensor) and y_real(tensor)\n\n    SNR can be calcualted as following equation:\n\n        SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2\n\n    if x and y are matrixs, SNR error over matrix should be the mean value of SNR error over all elements.\n\n        SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2)\n\n    Args:\n        y_pred (torch.Tensor): _description_\n        y_real (torch.Tensor): _description_\n        reduction (str, optional): _description_. Defaults to 'mean'.\n\n    Raises:\n        ValueError: _description_\n        ValueError: _description_\n\n    Returns:\n        torch.Tensor: _description_\n    \"\"\"\n    if y_pred.shape != y_real.shape:\n        raise ValueError(\n            f\"Can not compute snr loss for tensors with different shape. \"\n            f\"({y_pred.shape} and {y_real.shape})\"\n        )\n    reduction = str(reduction).lower()\n\n    if y_pred.ndim == 1:\n        y_pred = y_pred.unsqueeze(0)\n        y_real = y_real.unsqueeze(0)\n\n    y_pred = y_pred.flatten(start_dim=1)\n    y_real = y_real.flatten(start_dim=1)\n\n    noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1)\n    signal_power = torch.pow(y_real, 2).sum(dim=-1)\n    snr = (noise_power) / (signal_power + 1e-7)\n\n    if reduction == \"mean\":\n        return torch.mean(snr)\n    elif reduction == \"sum\":\n        return torch.sum(snr)\n    elif reduction == \"none\":\n        return snr\n    else:\n        raise ValueError(\"Unsupported reduction method.\")\n"
  },
  {
    "path": "server/text_generation_server/layers/layernorm.py",
    "content": "import torch\nfrom torch import nn\nfrom accelerate import init_empty_weights\nfrom text_generation_server.utils.import_utils import (\n    SYSTEM,\n)\n\n\n# Monkey patching\n@classmethod\ndef load_layer_norm(cls, prefix, weights, eps):\n    weight = weights.get_tensor(f\"{prefix}.weight\")\n    bias = weights.get_tensor(f\"{prefix}.bias\")\n    with init_empty_weights():\n        ln = cls(weight.shape, eps=eps)\n\n    ln.weight = torch.nn.Parameter(weight)\n    ln.bias = torch.nn.Parameter(bias)\n    return ln\n\n\n@classmethod\ndef load_layer_norm_no_bias(cls, prefix, weights, eps):\n    weight = weights.get_tensor(f\"{prefix}.weight\")\n    with init_empty_weights():\n        ln = cls(weight.shape, eps=eps)\n\n    ln.weight = torch.nn.Parameter(weight)\n    ln.bias = None\n    return ln\n\n\ntorch.nn.LayerNorm.load = load_layer_norm\ntorch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias\n\nif SYSTEM == \"cuda\":\n    import dropout_layer_norm\n\n    class FastLayerNorm(nn.LayerNorm):\n        def forward(self, hidden_states, residual=None):\n            if hidden_states.shape[-1] > 8192:\n                if residual is not None:\n                    hidden_states += residual\n                residual = hidden_states\n\n                return super(FastLayerNorm, self).forward(hidden_states), residual\n            else:\n                (\n                    normed_hidden_states,\n                    residual,\n                    *rest,\n                ) = dropout_layer_norm.dropout_add_ln_fwd(\n                    hidden_states,\n                    residual,\n                    self.weight,\n                    self.bias,\n                    None,\n                    None,\n                    None,\n                    None,\n                    0.0,\n                    self.eps,\n                    1.0,\n                    0,\n                    None,\n                    False,\n                    False,\n                )\n    
            if residual is None:\n                    residual = hidden_states\n\n                return normed_hidden_states, residual\n\nelif SYSTEM == \"rocm\":\n    import vllm._custom_ops as ops\n\n    class FastLayerNorm(nn.LayerNorm):\n        def forward(self, hidden_states, residual=None):\n            if residual is not None:\n                hidden_states += residual\n            residual = hidden_states\n\n            return super().forward(hidden_states), residual\n\nelif SYSTEM == \"ipex\":\n    import intel_extension_for_pytorch as ipex\n\n    class FastLayerNorm(nn.LayerNorm):\n        def forward(self, hidden_states, residual=None):\n            out = ipex.llm.functional.add_layer_norm(\n                residual,\n                hidden_states,\n                self.weight,\n                self.bias,\n                self.eps,\n                residual is not None,\n            )\n            return out, residual if residual is not None else hidden_states\n\n\nclass FastRMSNorm(nn.Module):\n    def __init__(self, weight: torch.Tensor, eps: float):\n        super().__init__()\n\n        self.weight = nn.Parameter(weight)\n        self.variance_epsilon = eps\n\n    @classmethod\n    def load(cls, prefix, weights, eps=1e-6):\n        weight = weights.get_tensor(f\"{prefix}.weight\")\n        return cls(weight, eps)\n\n    def forward(self, hidden_states, residual=None):\n        if SYSTEM == \"ipex\":\n            out = ipex.llm.functional.add_rms_norm(\n                residual,\n                hidden_states,\n                self.weight,\n                None,\n                self.variance_epsilon,\n                residual is not None,\n            )\n            return out, residual if residual is not None else hidden_states\n        elif SYSTEM == \"rocm\":\n            # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.\n            if residual is not None:\n                
ops.fused_add_rms_norm(\n                    hidden_states,\n                    residual,\n                    self.weight.data,\n                    self.variance_epsilon,\n                )\n                return hidden_states, residual\n\n            residual = hidden_states\n\n            out = torch.empty_like(hidden_states)\n            ops.rms_norm(\n                out,\n                hidden_states,\n                self.weight.data,\n                self.variance_epsilon,\n            )\n            return out, residual\n        elif hidden_states.shape[-1] > 8192:\n            if residual is not None:\n                hidden_states += residual\n            residual = hidden_states\n\n            hidden_states = hidden_states.to(torch.float32)\n            variance = hidden_states.pow(2).mean(-1, keepdim=True)\n            hidden_states = hidden_states * torch.rsqrt(\n                variance + self.variance_epsilon\n            )\n\n            # convert into half-precision if necessary\n            if self.weight.dtype in [torch.float16, torch.bfloat16]:\n                hidden_states = hidden_states.to(self.weight.dtype)\n\n            return self.weight * hidden_states, residual\n        elif SYSTEM == \"cuda\":\n            # faster post attention rms norm\n            (\n                normed_hidden_states,\n                res,\n                *rest,\n            ) = dropout_layer_norm.dropout_add_ln_fwd(\n                hidden_states,\n                residual,\n                self.weight,\n                None,\n                None,\n                None,\n                None,\n                None,\n                0.0,\n                self.variance_epsilon,\n                1.0,\n                0,\n                None,\n                False,\n                True,  # Activate RMSNorm\n            )\n            if res is None:\n                res = hidden_states\n\n            return normed_hidden_states, res\n        else:\n      
      raise ValueError(\n                \"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.\"\n            )\n"
  },
  {
    "path": "server/text_generation_server/layers/linear.py",
    "content": "import torch\nfrom text_generation_server.utils.import_utils import SYSTEM\nfrom torch.nn import functional as F\nimport os\n\nif SYSTEM == \"rocm\":\n    ROCM_USE_SKINNY_GEMM = os.getenv(\"ROCM_USE_SKINNY_GEMM\", \"True\").lower() in (\n        \"true\",\n        \"1\",\n    )\n\n    if ROCM_USE_SKINNY_GEMM:\n        try:\n            import vllm._custom_ops as ops\n        except Exception as e:\n            raise ImportError(\n                f\"Could not load `vllm._custom_ops` for ROCm skinny gemm. Full error: {e}\"\n            )\n\n\nclass FastLinear(torch.nn.Module):\n    def __init__(\n        self,\n        weight,\n        bias,\n    ) -> None:\n        super().__init__()\n        self.weight = torch.nn.Parameter(weight, requires_grad=False)\n        if bias is not None:\n            self.bias = torch.nn.Parameter(bias, requires_grad=False)\n        else:\n            self.bias = None\n\n    @classmethod\n    def load(cls, config, prefix: str, weights, bias: bool):\n        weight = weights.get_tensor(f\"{prefix}.weight\")\n        if bias:\n            bias = weights.get_tensor(f\"{prefix}.bias\")\n        else:\n            bias = None\n        return cls(weight, bias)\n\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        return F.linear(input, self.weight, self.bias)\n\n\nclass FastLinearROCm(torch.nn.Module):\n    def __init__(\n        self,\n        weight,\n        bias,\n    ) -> None:\n        super().__init__()\n        self.weight = torch.nn.Parameter(weight)\n        if bias is not None:\n            self.bias = torch.nn.Parameter(bias)\n        else:\n            self.bias = None\n\n        self.cu_count = torch.cuda.get_device_properties(\n            device=\"cuda\"\n        ).multi_processor_count\n        self.use_skinny_gemm = (\n            ROCM_USE_SKINNY_GEMM\n            and \"gfx1\" not in torch.cuda.get_device_properties(\"cuda\").gcnArchName\n        )\n\n    @classmethod\n    def load(cls, 
config, prefix: str, weights, bias: bool):\n        weight = weights.get_tensor(f\"{prefix}.weight\")\n        if bias:\n            bias = weights.get_tensor(f\"{prefix}.bias\")\n        else:\n            bias = None\n        return cls(weight, bias)\n\n    def forward(self, inp: torch.Tensor) -> torch.Tensor:\n        weight = self.weight\n        bias = self.bias\n\n        if (\n            self.use_skinny_gemm\n            and inp.dtype == torch.float16\n            and inp.shape[-1] % 8 == 0\n        ):\n            batched = False\n            inp_shape = inp.shape\n\n            if inp.dim() == 3:\n                inp = inp.view(-1, inp_shape[-1])\n                batched = True\n\n            m, n, k = weight.shape[0], inp_shape[0], inp_shape[1]\n            if m > 8 and n <= 4:\n                out = torch.empty(\n                    inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device\n                )\n                ops.wvSpltK(weight, inp, out, n, self.cu_count)\n            elif m % 4 == 0 and n == 1 and k <= 8192:\n                out = torch.empty(\n                    inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device\n                )\n                ops.LLMM1(weight, inp, out, 4)\n            else:\n                out = F.linear(inp, weight)\n\n            if batched:\n                out.view(*inp_shape[:-1], out.shape[-1])\n\n            if bias is not None:\n                out = out + bias\n            return out\n        return F.linear(inp, self.weight, self.bias)\n\n\ndef get_linear(weight, bias):\n    # Weights that are loaded through methods that are not\n    # quantization-aware are still bare tensors. We may want\n    # to change this in the future.\n    if isinstance(weight, torch.Tensor):\n        if SYSTEM == \"rocm\":\n            return FastLinearROCm(weight, bias)\n        else:\n            return FastLinear(weight, bias)\n\n    return weight.get_linear(bias)\n"
  },
  {
    "path": "server/text_generation_server/layers/lora.py",
    "content": "from typing import TYPE_CHECKING, Optional, List\n\nimport torch\nimport torch.distributed\nfrom torch import nn\nfrom torch.distributed import ProcessGroup\nfrom text_generation_server.utils.import_utils import SYSTEM\n\nfrom text_generation_server.utils.kernels import load_kernel\n\nif SYSTEM == \"cuda\":\n    punica_sgmv = load_kernel(\n        module=\"punica_sgmv\", repo_id=\"kernels-community/punica-sgmv\"\n    )\nelse:\n    punica_sgmv = None\n\nif SYSTEM == \"ipex\":\n    try:\n        from intel_extension_for_pytorch.llm.functional import (\n            bgmv_expand,\n            bgmv_shrink,\n            sgmv_expand,\n            sgmv_shrink,\n        )\n    except ImportError:\n        pass\n\n\nif TYPE_CHECKING:\n    from text_generation_server.adapters import AdapterBatchData\n    from text_generation_server.adapters.lora import BatchLoraWeights\n\n\nclass LoraLinear(nn.Module):\n    def __init__(\n        self, base_layer: nn.Module, layer_id: int, process_group: ProcessGroup\n    ):\n        super().__init__()\n        self.base_layer = base_layer\n        self.layer_id = layer_id\n        self.process_group = process_group\n\n    def forward_layer_type(\n        self,\n        result: torch.Tensor,\n        input: torch.Tensor,\n        adapter_data: \"AdapterBatchData\",\n        layer_type: str,\n        start_idx: int,\n        end_idx: int,\n    ) -> torch.Tensor:\n        if adapter_data is None:\n            return result\n        data: Optional[\"BatchLoraWeights\"] = adapter_data.data.get(layer_type)\n\n        if data is not None and (\n            SYSTEM == \"ipex\"\n            or (punica_sgmv is not None and data.can_vectorize(self.process_group))\n        ):\n            # In tensor-parallel configurations, each GPU processes a specific segment of the output.\n            # The 'result' tensor represents the full output, which can vary in size based on\n            # the layer type (e.g., attention vs. 
feed-forward layers). We define the current\n            # segment using start_idx and end_idx. If the segment size doesn't match this GPU's\n            # slice of 'result', we create a zero tensor of the correct size for LoRA computation.\n            # This approach ensures accurate LoRA application across various layer sizes and\n            # configurations, adapting to different model architectures and parallelization strategies.\n            #\n            # Example scenarios where this is necessary:\n            # 1. The adapter's size doesn't evenly divide across GPUs.\n            # 2. We're processing the last segment which might be smaller.\n            # 3. Different projection layers (q, k, v) have different sizes.\n            if end_idx - start_idx != result.shape[1]:\n                proj = torch.zeros_like(result[:, start_idx:end_idx])\n            else:\n                proj = result\n\n            for r, rank_segments in data.rank_data.items():\n                if SYSTEM == \"ipex\":\n                    lora_a_ptr = rank_segments.lora_a_ptr[\n                        :, self.layer_id, :\n                    ].contiguous()\n                    lora_b_ptr = rank_segments.lora_b_ptr[\n                        :, self.layer_id, :\n                    ].contiguous()\n                else:\n                    lora_a_ptr = rank_segments.lora_a_ptr\n                    lora_b_ptr = rank_segments.lora_b_ptr\n\n                if lora_a_ptr is None or lora_b_ptr is None:\n                    raise ValueError(\"LoRA data is missing\")\n\n                if data.use_sgmv:\n                    if SYSTEM == \"ipex\":\n                        # Use SGMV for prefill\n                        seq_len_tensor = (\n                            rank_segments.segment_ends - rank_segments.segment_starts\n                        ).to(torch.int64)\n                        b_seq_start_loc = rank_segments.segment_starts.to(torch.int64)\n                        total_tokens 
= seq_len_tensor.sum()\n                        v = torch.zeros(\n                            (total_tokens, r), dtype=input.dtype, device=input.device\n                        )\n                        bs = seq_len_tensor.shape[0]\n                        sgmv_shrink(\n                            input,\n                            lora_a_ptr,\n                            v,\n                            b_seq_start_loc,\n                            seq_len_tensor,\n                            rank_segments.indices,\n                            bs,\n                            seq_len_tensor.max().item(),\n                            1.0,\n                        )\n                    else:\n                        # Use SGMV for prefill\n                        v = punica_sgmv.lora_a_sgmv_cutlass(\n                            input,\n                            rank_segments.tmp_shrink,\n                            lora_a_ptr,\n                            rank_segments.segment_starts,\n                            rank_segments.segment_ends,\n                            self.layer_id,\n                            r,\n                        )\n\n                    if self.process_group.size() > 1:\n                        v = self.collect_lora_a(v)\n                    if SYSTEM == \"ipex\":\n                        sgmv_expand(\n                            v,\n                            lora_b_ptr,\n                            proj,\n                            b_seq_start_loc,\n                            seq_len_tensor,\n                            rank_segments.indices,\n                            bs,\n                            seq_len_tensor.max().item(),\n                            add_inputs=True,\n                        )\n                    else:\n                        punica_sgmv.lora_b_sgmv_cutlass(\n                            proj,\n                            v,\n                            rank_segments.tmp_expand,\n                       
     lora_b_ptr,\n                            rank_segments.segment_starts,\n                            rank_segments.segment_ends,\n                            self.layer_id,\n                        )\n                else:\n                    # Use BGMV for decode\n                    v = torch.zeros(\n                        (input.size(0), r), dtype=input.dtype, device=input.device\n                    )\n                    if SYSTEM == \"ipex\":\n                        bgmv_shrink(\n                            input,\n                            lora_a_ptr,\n                            v,\n                            rank_segments.indices,\n                            1.0,\n                        )\n                    else:\n                        # TODO: error with [-1, 0], but not [0, -1]\n                        punica_sgmv.add_lora_a_bgmv(\n                            v,\n                            input,\n                            lora_a_ptr,\n                            rank_segments.indices,\n                            self.layer_id,\n                        )\n\n                    if self.process_group.size() > 1:\n                        v = self.collect_lora_a(v)\n\n                    if SYSTEM == \"ipex\":\n                        bgmv_expand(\n                            v,\n                            lora_b_ptr,\n                            proj,\n                            rank_segments.indices,\n                            add_inputs=True,\n                        )\n                    else:\n                        punica_sgmv.add_lora_b_bgmv(\n                            proj,\n                            v,\n                            lora_b_ptr,\n                            rank_segments.indices,\n                            self.layer_id,\n                        )\n\n            if end_idx - start_idx != result.shape[1]:\n                result[:, start_idx:end_idx] += proj\n        else:\n            for adapter_index in 
adapter_data.meta.adapter_set:\n                if data is not None and data.has_adapter(adapter_index):\n                    adapter_mask = (\n                        (adapter_data.meta.adapter_indices == adapter_index)\n                        .to(input.dtype)\n                        .view(-1, 1)\n                    )\n                    layer_result = self.forward_lora(\n                        input, data, adapter_index, adapter_mask\n                    )\n                    result[:, start_idx:end_idx] += layer_result\n\n        return result\n\n    def forward_lora(\n        self,\n        input: torch.Tensor,\n        data: \"BatchLoraWeights\",\n        adapter_index: int,\n        adapter_mask: torch.Tensor,\n    ) -> torch.Tensor:\n        lora_a = data.lora_a[adapter_index][self.layer_id, :, :]\n        lora_b = data.lora_b[adapter_index][self.layer_id, :, :]\n\n        lora_a = punica_sgmv.orient_for_rank(lora_a, lora_b.size(0))\n\n        a_out = input @ lora_a\n        if self.process_group.size() > 1:\n            a_out = self.collect_lora_a(a_out)\n\n        result = (a_out @ lora_b) * adapter_mask\n        return result\n\n    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:\n        raise NotImplementedError(\"Implemented in subclasses\")\n\n\nclass TensorParallelMultiAdapterLinear(LoraLinear):\n    def __init__(\n        self,\n        base_layer: nn.Module,\n        layer_id: int,\n        layer_names: List[str],\n        sizes: List[int],\n        process_group: ProcessGroup,\n    ):\n        super().__init__(base_layer, layer_id, process_group)\n        self.layer_names = layer_names\n        self.sizes = sizes\n\n    @classmethod\n    def load(\n        cls,\n        base_layer: nn.Module,\n        layer_id: int,\n        layer_names: List[str],\n        sizes: List[int],\n        process_group: ProcessGroup,\n    ):\n        return TensorParallelMultiAdapterLinear(\n            base_layer, layer_id, layer_names, sizes, 
process_group\n        )\n\n    def forward(\n        self, input: torch.Tensor, adapter_data: \"AdapterBatchData\"\n    ) -> torch.Tensor:\n        result = self.base_layer(input)\n\n        # noop if no layer names are provided (e.g. for models without adapters)\n        if self.layer_names is None:\n            return result\n\n        # handle models like Bloom that have inputs of shape\n        # (batch_size, sequence_length, hidden_size)\n        # we need to reshape them to (batch_size * sequence_length, hidden_size)\n        # for the LoRA computation, then reshape back\n        prev_shape = result.shape\n        is_3d = len(input.shape) >= 3\n        if is_3d:\n            input = input.reshape(-1, input.shape[-1])\n            result = result.reshape(-1, result.shape[-1])\n\n        offset = 0\n        for i, layer_name in enumerate(self.layer_names):\n            start_idx = offset // self.process_group.size()\n            # The 'sizes' parameter is essential in tensor-parallel setups for handling multiple\n            # projection layers (q_proj, k_proj, v_proj) by defining their output dimensions. It\n            # ensures correct slicing of the result tensor, accommodating variations like grouped-query\n            # attention where k_proj and v_proj differ from q_proj. 
This allows precise application of\n            # LoRA adapters to each sub-component of the multi-head attention mechanism, managing the\n            # different projection sizes across layers and model architectures.\n            if self.sizes is not None:\n                offset += self.sizes[i]\n                end_idx = offset // self.process_group.size()\n            else:\n                end_idx = result.shape[1]\n\n            result = self.forward_layer_type(\n                result, input, adapter_data, layer_name, start_idx, end_idx\n            )\n\n        if is_3d:\n            result = result.reshape(prev_shape)\n\n        return result\n\n    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:\n        # Tensor parallel implementation of X @ A@B, where A and B are sharded column-wise.\n        # We use an all-gather between X@A and (X@A)@B to ensure alignment across ranks.\n        #\n        # TODO(travis): this is not very efficient as we do an all-gather for every adapter,\n        #   instead we could pre-allocate a (B, a, r) tensor for all adapters with the same\n        #   rank, compute `a_out` on each, and then slice them into the buffer as shown here:\n        #   https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609\n        gathered_tensors = [\n            torch.empty_like(a_out) for _ in range(self.process_group.size())\n        ]\n        torch.distributed.all_gather(gathered_tensors, a_out)\n        return torch.cat(gathered_tensors, dim=1)\n\n\nclass TensorParallelAdapterRowLinear(LoraLinear):\n    def __init__(self, base_layer, layer_id, layer_name, process_group):\n        super().__init__(base_layer, layer_id, process_group)\n        self.layer_name = layer_name\n\n    @classmethod\n    def load(cls, base_layer, layer_id, layer_name, process_group):\n        return cls(base_layer, layer_id, layer_name, process_group)\n\n    def forward(\n        self, input: torch.Tensor, adapter_data: 
\"AdapterBatchData\"\n    ) -> torch.Tensor:\n        result = self.base_layer(input)\n\n        if self.layer_name is None:\n            return result\n\n        # Fused all-gather + all-reduce from S-LoRA paper: https://arxiv.org/abs/2311.03285\n        stride = result.shape[-1] // self.process_group.size()\n        start_idx = self.process_group.rank() * stride\n        end_idx = (self.process_group.rank() + 1) * stride\n\n        self.forward_layer_type(\n            result, input, adapter_data, self.layer_name, start_idx, end_idx\n        )\n\n        return result\n\n    def collect_lora_a(self, a_out: torch.Tensor) -> torch.Tensor:\n        # Tensor parallel implementation of X @ A@B, where A and B are sharded row-wise.\n        # We use an all-reduce between X@A and (X@A)@B to ensure alignment across ranks.\n        #\n        # TODO(travis): this is not very efficient as we do an all-reduce for every adapter,\n        #   instead we could pre-allocate a (B, a, r) tensor for all adapters with the same\n        #   rank, compute `a_out` on each, and then slice them into the buffer as shown here:\n        #   https://discuss.pytorch.org/t/concatenate-tensors-without-memory-copying/34609\n        torch.distributed.all_reduce(a_out, group=self.process_group)\n        return a_out\n"
  },
  {
    "path": "server/text_generation_server/layers/marlin/__init__.py",
    "content": "from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear\nfrom text_generation_server.layers.marlin.gptq import (\n    GPTQMarlinWeightsLoader,\n    can_use_gptq_marlin,\n    repack_gptq_for_marlin,\n)\nfrom text_generation_server.layers.marlin.marlin import MarlinWeightsLoader\n\n__all__ = [\n    \"GPTQMarlinFP8Linear\",\n    \"GPTQMarlinWeightsLoader\",\n    \"MarlinWeightsLoader\",\n    \"can_use_gptq_marlin\",\n    \"repack_gptq_for_marlin\",\n]\n"
  },
  {
    "path": "server/text_generation_server/layers/marlin/fp8.py",
    "content": "from typing import Optional\n\nimport torch\nimport torch.nn as nn\nfrom text_generation_server.layers.fp8 import fp8_quantize\nfrom text_generation_server.layers.marlin.gptq import _check_valid_shape\nfrom text_generation_server.layers.marlin.util import (\n    _check_marlin_kernels,\n    permute_scales,\n)\nfrom text_generation_server.utils.import_utils import SYSTEM\nfrom text_generation_server.utils.kernels import load_kernel\n\nif SYSTEM == \"cuda\":\n    quantization = load_kernel(\n        module=\"quantization\", repo_id=\"kernels-community/quantization\"\n    )\nelse:\n    quantization = None\n\n\nMARLIN_TILE_SIZE = 16\n\n\nclass GPTQMarlinFP8Linear(nn.Module):\n    \"\"\"\n    FP8 GPTQ-Marlin linear layer.\n    \"\"\"\n\n    def __init__(\n        self,\n        qweight: torch.Tensor,\n        scales: torch.Tensor,\n        bias: Optional[torch.Tensor],\n    ) -> None:\n        super().__init__()\n\n        _check_marlin_kernels()\n        assert quantization is not None\n\n        scales = scales.unsqueeze(0)\n        if scales.shape[1] == 1:\n            out_features, in_features = qweight.shape\n            scales = scales.repeat(1, out_features)\n        qweight, scales = repack_fp8_for_marlin(qweight, scales)\n\n        in_features = qweight.shape[0] * MARLIN_TILE_SIZE\n        out_features = scales.shape[1]\n        _check_valid_shape(in_features=in_features, out_features=out_features)\n\n        self.qweight = qweight\n        self.scales = scales\n        self.bias = bias if bias is not None else None\n\n        self.workspace = torch.zeros(\n            out_features // 64 * 16, dtype=torch.int, device=qweight.device\n        )\n\n    @classmethod\n    def from_unquant(cls, weight, bias, dtype):\n        qweight, scales = fp8_quantize(weight)\n        return cls(qweight=qweight, scales=scales.to(dtype), bias=bias)\n\n    @classmethod\n    def from_fp8(\n        cls,\n        weight: torch.Tensor,\n        scale: torch.Tensor,\n     
   bias: torch.Tensor,\n        dtype: torch.dtype,\n        **kwargs,\n    ):\n        return cls(qweight=weight, scales=scale.to(dtype), bias=bias)\n\n    def forward(self, A: torch.Tensor) -> torch.Tensor:\n        assert quantization is not None\n\n        A_flat = A.view(-1, A.shape[-1])\n        C = quantization.fp8_marlin_gemm(\n            A_flat,\n            self.qweight,\n            self.scales,\n            self.workspace,\n            8,\n            A_flat.shape[0],\n            self.scales.shape[1],\n            A_flat.shape[1],\n        )\n        C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))\n\n        if self.bias is not None:\n            C += self.bias\n\n        return C\n\n\ndef pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Repack FP8 weights to gptq format (packed int32 elements).\n    \"\"\"\n    assert fp8_tensor.dtype == torch.float8_e4m3fn\n\n    if fp8_tensor.shape[0] % 4 != 0:\n        raise ValueError(\n            f\"Leading tensor dimension is not divisable by 4: {fp8_tensor.shape[0]}\"\n        )\n\n    # Reshape to prepare for packing\n    reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])\n\n    # Convert fp8 to uint8 (byte) representation\n    byte_tensor = reshaped.view(torch.uint8)\n\n    # Pack 4 uint8 values into one int32\n    packed = torch.zeros(\n        fp8_tensor.shape[0] // 4,\n        fp8_tensor.shape[1],\n        dtype=torch.int32,\n        device=fp8_tensor.device,\n    )\n\n    for i in range(4):\n        packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8)\n\n    return packed\n\n\ndef repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor):\n    \"\"\"\n    Repack FP8 tensor for GPTQ-Marlin.\n    \"\"\"\n\n    out_features, in_features = weight.shape\n\n    # Torch linear layer weights have shape [out_features, in_features],\n    # GPTQ-quantized weights use [in_features/pack_factor, out_features],\n    # so transpose before packing.\n    
qweight = pack_fp8_as_int32(weight.t())\n\n    perm = torch.empty(0, dtype=torch.int, device=qweight.device)\n    repacked = quantization.gptq_marlin_repack(\n        qweight, perm, in_features, out_features, 8\n    )\n\n    scales = permute_scales(scales)\n\n    return repacked, scales\n"
  },
  {
    "path": "server/text_generation_server/layers/marlin/gptq.py",
    "content": "from dataclasses import dataclass\nfrom typing import List, Optional, Union\n\nimport numpy\nimport torch\nimport torch.nn as nn\nfrom loguru import logger\nfrom text_generation_server.layers.marlin.util import (\n    _check_marlin_kernels,\n    marlin_zero_points,\n    permute_scales,\n    unpack_cols,\n)\nfrom text_generation_server.utils.import_utils import SYSTEM\nfrom text_generation_server.utils.kernels import load_kernel\nfrom text_generation_server.utils.log import log_once\nfrom text_generation_server.utils.weights import Weight, Weights, WeightsLoader\n\nif SYSTEM == \"cuda\":\n    quantization = load_kernel(\n        module=\"quantization\", repo_id=\"kernels-community/quantization\"\n    )\nelse:\n    quantization = None\n\n\ntry:\n    major, _minor = torch.cuda.get_device_capability()\n    has_sm_8_0 = major >= 8\nexcept Exception:\n    has_sm_8_0 = False\n\n\nGPTQ_MARLIN_BITS = [4, 8]\nGPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]\nMARLIN_TILE_SIZE = 16\n\n\ndef can_use_gptq_marlin(\n    *, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool\n) -> bool:\n    return (\n        SYSTEM == \"cuda\"\n        and quantization is not None\n        and has_sm_8_0\n        and quantize in {\"awq\", \"gptq\"}\n        and quant_method in {\"awq\", \"gptq\"}\n        and bits in GPTQ_MARLIN_BITS\n        and groupsize in GPTQ_MARLIN_GROUP_SIZES\n        # We only support asymmetric quantization for AWQ.\n        and (sym or quant_method == \"awq\")\n    )\n\n\nclass GPTQMarlinWeightsLoader(WeightsLoader):\n    \"\"\"\n    Loader for using GPTQ- and AWQ-quantized weights with Marlin kernels.\n    \"\"\"\n\n    def __init__(\n        self,\n        *,\n        bits: int,\n        desc_act: bool,\n        groupsize: int,\n        quant_method: str,\n        quantize: str,\n        sym: bool,\n    ):\n        self.bits = bits\n        self.desc_act = desc_act\n        self.groupsize = groupsize\n        self.quant_method = 
quant_method\n        self.quantize = quantize\n        self.sym = sym\n\n    def get_weights(self, weights: Weights, prefix: str):\n        log_once(logger.info, \"Using GPTQ-Marlin kernels\")\n        try:\n            qweight = weights.get_tensor(f\"{prefix}.qweight\")\n        except RuntimeError:\n            raise RuntimeError(\n                f\"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized\"\n            )\n\n        if not self.sym:\n            qzeros = weights.get_tensor(f\"{prefix}.qzeros\")\n        else:\n            qzeros = None\n\n        if self.quant_method == \"awq\":\n            g_idx = None\n        else:\n            g_idx = weights.get_tensor(f\"{prefix}.g_idx\")\n        scales = weights.get_tensor(f\"{prefix}.scales\")\n\n        return repack_gptq_for_marlin(\n            qweight=qweight,\n            scales=scales,\n            qzeros=qzeros,\n            g_idx=g_idx,\n            bits=self.bits,\n            desc_act=self.desc_act,\n            groupsize=self.groupsize,\n            quant_method=self.quant_method,\n            sym=self.sym,\n            sharded_infeatures=False,\n        )\n\n    def get_weights_col_packed(\n        self,\n        weights: Weights,\n        prefix: str,\n        block_sizes: Union[int, List[int]],\n    ):\n        try:\n            qweight = weights.get_packed_sharded(\n                f\"{prefix}.qweight\", dim=1, block_sizes=block_sizes\n            )\n        except RuntimeError:\n            raise RuntimeError(\n                f\"Cannot load `{self.quantize}` weight, make sure the model is already quantized.\"\n            )\n        scales = weights.get_packed_sharded(\n            f\"{prefix}.scales\", dim=1, block_sizes=block_sizes\n        )\n        scales = scales.to(dtype=weights.dtype)\n\n        if not self.sym:\n            qzeros = weights.get_packed_sharded(\n                f\"{prefix}.qzeros\", dim=1, 
block_sizes=block_sizes\n            )\n        else:\n            qzeros = None\n\n        if self.quant_method == \"awq\":\n            g_idx = None\n        else:\n            g_idx = weights.get_tensor(f\"{prefix}.g_idx\")\n        return repack_gptq_for_marlin(\n            qweight=qweight,\n            scales=scales,\n            qzeros=qzeros,\n            g_idx=g_idx,\n            bits=self.bits,\n            desc_act=self.desc_act,\n            groupsize=self.groupsize,\n            quant_method=self.quant_method,\n            sym=self.sym,\n            sharded_infeatures=False,\n        )\n\n    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):\n        try:\n            qweight = torch.cat(\n                [weights.get_sharded(f\"{p}.qweight\", dim=1) for p in prefixes], dim=1\n            )\n        except RuntimeError:\n            raise RuntimeError(\n                f\"Cannot load `{self.quantize}` weight, make sure the model is already quantized\"\n            )\n\n        scales = torch.cat(\n            [weights.get_sharded(f\"{p}.scales\", dim=1) for p in prefixes], dim=1\n        )\n\n        if not self.sym:\n            qzeros = torch.cat(\n                [weights.get_sharded(f\"{p}.qzeros\", dim=1) for p in prefixes], dim=1\n            )\n        else:\n            qzeros = None\n\n        if self.quant_method == \"awq\":\n            g_idx = None\n        else:\n            w = [weights.get_tensor(f\"{p}.g_idx\") for p in prefixes]\n            for w2 in w[1:]:\n                torch.testing.assert_close(w2, w[0])\n            g_idx = w[0]\n\n        return repack_gptq_for_marlin(\n            qweight=qweight,\n            scales=scales,\n            qzeros=qzeros,\n            g_idx=g_idx,\n            bits=self.bits,\n            desc_act=self.desc_act,\n            groupsize=self.groupsize,\n            quant_method=self.quant_method,\n            sym=self.sym,\n            sharded_infeatures=False,\n    
    )\n\n    def get_weights_row(self, weights: Weights, prefix: str):\n        log_once(logger.info, \"Using GPTQ-Marlin kernels\")\n        try:\n            qweight = weights.get_sharded(f\"{prefix}.qweight\", dim=0)\n        except RuntimeError:\n            raise RuntimeError(\n                f\"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized\"\n            )\n\n        if not self.sym:\n            if self.desc_act or self.groupsize == -1:\n                qzeros = weights.get_tensor(f\"{prefix}.qzeros\")\n            else:\n                qzeros = weights.get_sharded(f\"{prefix}.qzeros\", dim=0)\n        else:\n            qzeros = None\n\n        if self.quant_method == \"awq\":\n            g_idx = None\n        else:\n            g_idx = weights.get_sharded(f\"{prefix}.g_idx\", dim=0)\n\n        if self.desc_act or self.groupsize == -1:\n            scales = weights.get_tensor(f\"{prefix}.scales\")\n        else:\n            scales = weights.get_sharded(f\"{prefix}.scales\", dim=0)\n\n        sharded_in_features = weights.process_group.size() > 1\n\n        return repack_gptq_for_marlin(\n            qweight=qweight,\n            scales=scales,\n            qzeros=qzeros,\n            g_idx=g_idx,\n            bits=self.bits,\n            desc_act=self.desc_act,\n            groupsize=self.groupsize,\n            quant_method=self.quant_method,\n            sym=self.sym,\n            sharded_infeatures=sharded_in_features,\n        )\n\n    def _get_gptq_params(self, weights: Weights):\n        if weights.has_tensor(\"gptq_bits\") and weights.has_tensor(\"gptq_groupsize\"):\n            self.bits = weights.get_tensor(\"gptq_bits\").item()\n            self.groupsize = weights.get_tensor(\"gptq_groupsize\").item()\n            self.desc_act = False\n            # `server quantize` used asymmetric quantization unconditionally\n            # before the `gptq_sym` setting tensor was added.\n       
     self.sym = (\n                weights.get_tensor(\"gptq_sym\").item()\n                if weights.has_tensor(\"gptq_sym\")\n                else False\n            )\n            self.quant_method = \"gptq\"\n\n\n@dataclass\nclass GPTQMarlinWeight(Weight):\n    \"\"\"\n    Repacked GPTQ Marlin weights.\n    \"\"\"\n\n    qweight: torch.Tensor\n    qzeros: torch.Tensor\n    scales: torch.Tensor\n    g_idx: torch.Tensor\n    perm: torch.Tensor\n    bits: int\n    is_full_k: bool\n\n    def __post_init__(self):\n        assert self.qweight.dtype == torch.int32\n        assert self.scales.dtype in (torch.float16, torch.bfloat16)\n        assert self.g_idx.dtype == torch.int32\n        assert self.perm.dtype == torch.int32\n\n    def get_linear(self, bias: torch.Tensor):\n        return GPTQMarlinLinear(\n            weight=self,\n            bias=bias,\n        )\n\n\ndef repack_gptq_for_marlin(\n    *,\n    qweight: torch.Tensor,\n    qzeros: Optional[torch.Tensor],\n    scales: torch.Tensor,\n    g_idx: Optional[torch.Tensor],\n    bits: int,\n    desc_act: bool,\n    groupsize: int,\n    quant_method: str,\n    sym: bool,\n    sharded_infeatures: bool,\n) -> GPTQMarlinWeight:\n    \"\"\"Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels.\"\"\"\n    _check_marlin_kernels()\n    assert quantization is not None\n\n    if bits not in GPTQ_MARLIN_BITS:\n        supported_bits = \", \".join(str(b) for b in GPTQ_MARLIN_BITS)\n        raise RuntimeError(\n            f\"Repacking {bits}-bit GPTQ weights as Marlin is not supported, must be one of: {supported_bits}\"\n        )\n\n    if groupsize not in GPTQ_MARLIN_GROUP_SIZES:\n        supported_sizes = \", \".join(str(b) for b in GPTQ_MARLIN_GROUP_SIZES)\n        raise RuntimeError(\n            f\"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}\"\n        )\n    if not (sym or quant_method == \"awq\" or quant_method == 
\"compressed-tensors\"):\n        raise RuntimeError(\n            \"Repacking GPTQ weights with asymmetric quantization as Marlin is not supported.\"\n        )\n\n    log_once(logger.info, f\"Converting {quant_method} model to Marlin packing format.\")\n\n    weights_per_int = 32 // bits\n    in_features = qweight.shape[0]\n    out_features = qweight.shape[1]\n\n    # AWQ uses column packing, GPTQ uses row packing\n    if quant_method == \"awq\":\n        out_features *= weights_per_int\n    else:\n        in_features *= weights_per_int\n\n    if in_features % groupsize != 0:\n        raise ValueError(\n            f\"Number of input features ({in_features}) not divisible by group size ({groupsize})\"\n        )\n\n    if g_idx is not None and desc_act and groupsize != -1:\n        perm = torch.argsort(g_idx).to(torch.int)\n        g_idx = g_idx[perm]\n    else:\n        perm = torch.empty(0, dtype=torch.int, device=qweight.device)\n        g_idx = torch.empty(0, dtype=torch.int, device=qweight.device)\n\n    if quant_method == \"awq\":\n        repacked = quantization.awq_marlin_repack(\n            qweight, in_features, out_features, bits\n        )\n        if qzeros is not None:\n            qzeros = awq_to_marlin_zero_points(\n                qzeros,\n                in_features // groupsize,\n                out_features,\n                bits,\n            )\n\n    else:\n        repacked = quantization.gptq_marlin_repack(\n            qweight, perm, in_features, out_features, bits\n        )\n\n    if qzeros is None:\n        qzeros = torch.empty(0, dtype=torch.int, device=qweight.device)\n\n    scales = permute_scales(scales)\n\n    is_full_k = not (desc_act and groupsize != -1 and sharded_infeatures)\n\n    return GPTQMarlinWeight(\n        qweight=repacked,\n        qzeros=qzeros,\n        scales=scales,\n        g_idx=g_idx,\n        perm=perm,\n        bits=bits,\n        is_full_k=is_full_k,\n    )\n\n\nclass GPTQMarlinLinear(nn.Module):\n    
\"\"\"\n    Linear layer for GPTQ weights that were converted for the GPTQ-Marlin\n    kernels.\n    \"\"\"\n\n    def __init__(\n        self,\n        *,\n        weight: GPTQMarlinWeight,\n        bias: Optional[torch.Tensor],\n    ):\n        super().__init__()\n\n        _check_marlin_kernels()\n        assert quantization is not None\n\n        in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE\n        out_features = weight.scales.shape[1]\n        _check_valid_shape(in_features=in_features, out_features=out_features)\n\n        if weight.bits not in (4, 8):\n            raise ValueError(\"GPTQMarlinLinear only supports 4 and 8-bit quantization\")\n\n        if weight.qzeros.numel() > 0:\n            if weight.bits == 4:\n                self.quant_type = quantization.scalar_types.uint4\n            else:\n                self.quant_type = quantization.scalar_types.uint8\n        else:\n            if weight.bits == 4:\n                self.quant_type = quantization.scalar_types.uint4b8\n            else:\n                self.quant_type = quantization.scalar_types.uint8b128\n\n        self.is_full_k = weight.is_full_k\n\n        self.qweight = weight.qweight\n        self.qzeros = weight.qzeros\n        self.scales = weight.scales\n        self.g_idx = weight.g_idx\n        self.perm = weight.perm\n        if bias is not None:\n            self.bias = bias\n        else:\n            self.bias = None\n\n        self.workspace = torch.zeros(\n            out_features // 64 * 16, dtype=torch.int, device=weight.qweight.device\n        )\n\n    def forward(self, A: torch.Tensor) -> torch.Tensor:\n        assert quantization is not None\n\n        A_flat = A.view(-1, A.shape[-1])\n        C = quantization.gptq_marlin_gemm(\n            A_flat,\n            self.qweight,\n            self.scales,\n            self.qzeros,\n            self.g_idx,\n            self.perm,\n            self.workspace,\n            self.quant_type,\n            
A_flat.shape[0],\n            self.scales.shape[1],\n            A_flat.shape[1],\n            self.is_full_k,\n            self.qzeros.numel() > 0,\n            True,\n        )\n        C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))\n\n        if self.bias is not None:\n            C += self.bias\n\n        return C\n\n\ndef awq_to_marlin_zero_points(\n    q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int\n) -> torch.Tensor:\n    # AWQ zero-points are quantized and packed on the column dim.\n    # In addition, the values are permuted based on dequantizer.\n    # Here we undo both of these, and then apply marlin permutation\n    # and pack it back.\n    q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)\n\n    # Undo interleaving (use argsort(..) to get inverse perm)\n    if num_bits == 4:\n        undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))\n    elif num_bits == 8:\n        undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))\n    else:\n        raise Exception(\"num_bits must be 4 or 8, got {}\".format(num_bits))\n\n    q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()\n    q_zp = q_zp.reshape((-1, size_n)).contiguous()\n\n    marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)\n    return marlin_zp\n\n\ndef _check_valid_shape(in_features: int, out_features: int):\n    if (in_features % 128 != 0 or out_features % 64 != 0) and (\n        in_features % 64 != 0 or out_features % 128 != 0\n    ):\n        raise ValueError(\n            f\"The GPTQ Marlin kernel does not have a valid thread configuration for weight matrix with shape ({out_features}, {in_features}).\"\n            \" The shape elements must be divisible by (128, 64) or (64, 128).\"\n        )\n"
  },
  {
    "path": "server/text_generation_server/layers/marlin/marlin.py",
    "content": "from dataclasses import dataclass\nfrom typing import List, Optional, Union\n\nimport torch\nimport torch.nn as nn\n\nfrom text_generation_server.layers.marlin.util import _check_marlin_kernels\nfrom text_generation_server.utils.import_utils import SYSTEM\nfrom text_generation_server.utils.kernels import load_kernel\nfrom text_generation_server.utils.weights import Weight, Weights, WeightsLoader\n\nif SYSTEM == \"cuda\":\n    quantization = load_kernel(\n        module=\"quantization\", repo_id=\"kernels-community/quantization\"\n    )\nelse:\n    quantization = None\n\n\nclass MarlinWeightsLoader(WeightsLoader):\n    \"\"\"Loader for Marlin-quantized weights.\"\"\"\n\n    def __init__(self, *, bits: int, is_marlin_24: bool):\n        self.bits = bits\n        self.is_marlin_24 = is_marlin_24\n\n    def get_weights(self, weights: \"Weights\", prefix: str):\n        \"\"\"\n        Get weights at the given prefix and apply without tensor paralllism.\n        \"\"\"\n        is_marlin_24 = getattr(self, \"gptq_checkpoint_format\", None) == \"marlin_24\"\n        if is_marlin_24:\n            try:\n                B = weights.get_tensor(f\"{prefix}.B_24\")\n            except RuntimeError:\n                raise RuntimeError(\n                    \"Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized.\"\n                )\n\n            B_meta = weights.get_tensor(f\"{prefix}.B_meta\")\n            s = weights.get_tensor(f\"{prefix}.s\")\n            weight = GPTQMarlin24Weight(\n                weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits\n            )\n        else:\n            try:\n                B = weights.get_tensor(f\"{prefix}.B\")\n            except RuntimeError:\n                raise RuntimeError(\n                    \"Cannot load `marlin` weight, make sure the model is already quantized.\"\n                )\n\n            s = weights.get_tensor(f\"{prefix}.s\")\n            weight = 
MarlinWeight(B=B, s=s)\n\n        return weight\n\n    def get_weights_col_packed(\n        self,\n        weights: Weights,\n        prefix: str,\n        block_sizes: Union[int, List[int]],\n    ):\n        if self.is_marlin_24:\n            B = weights.get_packed_sharded(\n                f\"{prefix}.B_24\", dim=1, block_sizes=block_sizes\n            )\n            B_meta = weights.get_packed_sharded(\n                f\"{prefix}.B_meta\", dim=1, block_sizes=block_sizes\n            )\n            s = weights.get_packed_sharded(\n                f\"{prefix}.s\", dim=1, block_sizes=block_sizes\n            )\n\n            weight = GPTQMarlin24Weight(\n                weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits\n            )\n        else:\n            B = weights.get_packed_sharded(\n                f\"{prefix}.B\", dim=1, block_sizes=block_sizes\n            )\n            s = weights.get_packed_sharded(\n                f\"{prefix}.s\", dim=1, block_sizes=block_sizes\n            )\n            weight = MarlinWeight(B=B, s=s)\n\n        return weight\n\n    def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):\n        if self.is_marlin_24:\n            try:\n                B = torch.cat(\n                    [weights.get_sharded(f\"{p}.B_24\", dim=1) for p in prefixes], dim=1\n                )\n            except RuntimeError:\n                raise RuntimeError(\n                    \"Cannot load `marlin` weight, make sure the model is already quantized\"\n                )\n\n            B_meta = torch.cat(\n                [weights.get_sharded(f\"{p}.B_meta\", dim=1) for p in prefixes], dim=1\n            )\n\n            s = torch.cat(\n                [weights.get_sharded(f\"{p}.s\", dim=1) for p in prefixes], dim=1\n            )\n\n            weight = GPTQMarlin24Weight(\n                weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits\n            )\n        else:\n            try:\n        
        B = torch.cat(\n                    [weights.get_sharded(f\"{p}.B\", dim=1) for p in prefixes], dim=1\n                )\n            except RuntimeError:\n                raise RuntimeError(\n                    \"Cannot load `marlin` weight, make sure the model is already quantized\"\n                )\n            s = torch.cat(\n                [weights.get_sharded(f\"{p}.s\", dim=1) for p in prefixes], dim=1\n            )\n\n            weight = MarlinWeight(B=B, s=s)\n\n        return weight\n\n    def get_weights_row(self, weights: Weights, prefix: str):\n        if self.is_marlin_24:\n            try:\n                B = weights.get_sharded(f\"{prefix}.B_24\", dim=0)\n            except RuntimeError:\n                raise RuntimeError(\n                    \"Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized.\"\n                )\n\n            B_meta = weights.get_sharded(f\"{prefix}.B_meta\", dim=0)\n            num_groups = weights._get_slice(f\"{prefix}.s\").get_shape()[0]\n            if num_groups == 1:\n                # The number of groups is 1 when groupsize == -1. share\n                # scales between all shards in this case.\n                s = weights.get_tensor(f\"{prefix}.s\")\n            else:\n                s = weights.get_sharded(f\"{prefix}.s\", dim=0)\n\n            weight = GPTQMarlin24Weight(\n                weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits\n            )\n        else:\n            try:\n                B = weights.get_sharded(f\"{prefix}.B\", dim=0)\n            except RuntimeError:\n                raise RuntimeError(\n                    \"Cannot load `marlin` weight, make sure the model is already quantized.\"\n                )\n\n            num_groups = weights._get_slice(f\"{prefix}.s\").get_shape()[0]\n            if num_groups == 1:\n                # The number of groups is 1 when groupsize == -1. 
share\n                # scales between all shards in this case.\n                s = weights.get_tensor(f\"{prefix}.s\")\n            else:\n                s = weights.get_sharded(f\"{prefix}.s\", dim=0)\n            weight = MarlinWeight(B=B, s=s)\n\n        return weight\n\n\n@dataclass\nclass MarlinWeight(Weight):\n    \"\"\"\n    Marlin weights.\n\n    Attributes:\n        B (torch.Tensor): int4-quantized weights packed into int32.\n        s (torch.Tensor): bfloat16/float16 scales.\n    \"\"\"\n\n    B: torch.Tensor\n    s: torch.Tensor\n\n    def __post_init__(self):\n        assert self.B.dtype == torch.int32\n        assert self.s.dtype in [torch.float16, torch.bfloat16]\n\n    def get_linear(self, bias: torch.Tensor):\n        return MarlinLinear(weight=self, bias=bias)\n\n\nclass MarlinLinear(nn.Module):\n    def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]):\n        super().__init__()\n\n        _check_marlin_kernels()\n        assert quantization is not None\n\n        in_features = weight.B.shape[0] * MARLIN_TILE_SIZE\n        out_features = weight.s.shape[1]\n        assert (\n            in_features % 128 == 0\n        ), f\"Number of input features ({in_features}) not divisable by 128\"\n        assert (\n            out_features % 256 == 0\n        ), f\"Number of output features ({out_features}) not divisable by 256\"\n\n        groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0]\n        assert groupsize in {\n            -1,\n            128,\n        }, f\"Group size must be -1 or 128, was {groupsize}\"\n\n        self.B = weight.B\n        self.s = weight.s\n        if bias is not None:\n            self.bias = bias\n        else:\n            self.bias = None\n\n        self.workspace = torch.zeros(\n            out_features // 64 * 16, dtype=torch.int, device=weight.B.device\n        )\n\n    def forward(self, A: torch.Tensor) -> torch.Tensor:\n        assert quantization is not None\n\n  
      C = quantization.marlin_gemm(\n            A.view(-1, A.shape[-1]),\n            self.B,\n            self.s,\n            self.workspace,\n            A.shape[0],\n            self.s.shape[1],\n            A.shape[1],\n        )\n        C = C.reshape(A.shape[:-1] + (self.s.shape[1],))\n\n        if self.bias is not None:\n            C += self.bias\n\n        return C\n\n\nGPTQ_MARLIN_24_MIN_THREAD_N = 128\nGPTQ_MARLIN_24_MIN_THREAD_K = 128\nGPTQ_MARLIN_24_MAX_PARALLEL = 64\nGPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]\nGPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]\nMARLIN_TILE_SIZE = 16\n\n\n@dataclass\nclass GPTQMarlin24Weight:\n    \"\"\"\n    GPTQ-Marlin 2:4 weights.\n\n    Attributes:\n        B (torch.Tensor): int4-quantized weights packed into int32.\n        B_meta (torch.Tensor): metadata for 2:4 sparsity.\n        s (torch.Tensor): float16 scales.\n        bits: quantized weight size.\n    \"\"\"\n\n    weight_packed: torch.Tensor\n    meta: torch.Tensor\n    scale_packed: torch.Tensor\n    bits: int\n\n    def __post_init__(self):\n        assert self.weight_packed.dtype == torch.int32\n        assert self.meta.dtype == torch.int16\n        assert self.scale_packed.dtype == torch.float16\n\n    def get_linear(self, bias: torch.Tensor):\n        return GPTQMarlin24Linear(\n            weight=self,\n            bias=bias,\n        )\n\n\nclass GPTQMarlin24Linear(nn.Module):\n    def __init__(self, *, weight: GPTQMarlin24Weight, bias: Optional[torch.Tensor]):\n        super().__init__()\n\n        _check_marlin_kernels()\n        assert quantization is not None\n\n        if weight.bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:\n            supported_bits = \", \".join(\n                str(b) for b in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS\n            )\n            raise RuntimeError(\n                f\"{weight.bits}-bit GPTQ Sparse 2:4 Marlin is not supported, must be one of: {supported_bits}\"\n            )\n\n        in_features = 
weight.weight_packed.shape[0] * MARLIN_TILE_SIZE * 2\n        out_features = weight.scale_packed.shape[1]\n        groupsize = (\n            -1\n            if weight.scale_packed.shape[0] == 1\n            else in_features // weight.scale_packed.shape[0]\n        )\n\n        if groupsize not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:\n            supported_sizes = \", \".join(\n                str(b) for b in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES\n            )\n            raise RuntimeError(\n                f\"Group size {groupsize} is not supported, must be one of: {supported_sizes}\"\n            )\n\n        if weight.bits == 4:\n            self.quant_type = quantization.scalar_types.uint4b8\n        else:\n            self.quant_type = quantization.scalar_types.uint8b128\n        weights_per_int32 = 32 // weight.bits\n\n        assert (\n            out_features % GPTQ_MARLIN_24_MIN_THREAD_N == 0\n        ), f\"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_N} threads\"\n        assert (\n            out_features % weights_per_int32 == 0\n        ), f\"Number of output features ({out_features}) not divisable by weights per int32 ({weights_per_int32})\"\n\n        assert (\n            in_features % GPTQ_MARLIN_24_MIN_THREAD_K == 0\n        ), f\"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_K} threads\"\n        if groupsize != -1 and in_features % groupsize != 0:\n            raise ValueError(\n                f\"Number of input features ({in_features}) not divisable by group size ({groupsize})\"\n            )\n\n        self.weight_packed = weight.weight_packed\n        self.meta = weight.meta\n        self.scale_packed = weight.scale_packed\n        if bias is not None:\n            self.bias = bias\n        else:\n            self.bias = None\n\n        self.workspace = torch.zeros(\n            (out_features // GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL,\n 
           dtype=torch.int,\n            device=weight.weight_packed.device,\n        )\n\n    def forward(self, A: torch.Tensor) -> torch.Tensor:\n        assert quantization is not None\n\n        C = quantization.gptq_marlin_24_gemm(\n            A.view(-1, A.shape[-1]),\n            self.weight_packed,\n            self.meta,\n            self.scale_packed,\n            self.workspace,\n            self.quant_type,\n            A.shape[0],\n            self.scale_packed.shape[1],\n            A.shape[1],\n        )\n\n        C = C.reshape(A.shape[:-1] + (self.scale_packed.shape[1],))\n\n        if self.bias is not None:\n            C += self.bias\n\n        return C\n"
  },
  {
    "path": "server/text_generation_server/layers/marlin/util.py",
    "content": "import functools\nfrom typing import List, Tuple\n\nimport numpy\nimport torch\nfrom text_generation_server.utils.import_utils import SYSTEM\nfrom text_generation_server.utils.kernels import load_kernel\n\nif SYSTEM == \"cuda\":\n    quantization = load_kernel(\n        module=\"quantization\", repo_id=\"kernels-community/quantization\"\n    )\nelse:\n    quantization = None\n\ntry:\n    major, _minor = torch.cuda.get_device_capability()\n    has_sm_8_0 = major >= 8\nexcept Exception:\n    has_sm_8_0 = False\n\n\ndef _check_marlin_kernels():\n    if not (SYSTEM == \"cuda\" and has_sm_8_0):\n        raise NotImplementedError(\n            \"Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later.\"\n        )\n\n    if quantization is None:\n        raise NotImplementedError(\n            \"marlin is not installed, install it with: pip install server/marlin\"\n        )\n\n\n# https://github.com/IST-DASLab/marlin/blob/2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c/marlin/__init__.py#L40C1-L68C54\n@functools.cache\ndef get_perms() -> Tuple[List[int], List[int]]:\n    scale_perm = []\n    for i in range(8):\n        scale_perm.extend([i + 8 * j for j in range(8)])\n    scale_perm_single = []\n    for i in range(4):\n        scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])\n    return scale_perm, scale_perm_single\n\n\ndef permute_scales(scales: torch.Tensor):\n    scale_perm, scale_perm_single = get_perms()\n    out_features = scales.shape[1]\n    if scales.shape[0] == 1:\n        scales = scales.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]\n    else:\n        scales = scales.reshape((-1, len(scale_perm)))[:, scale_perm]\n    return scales.reshape((-1, out_features)).contiguous()\n\n\n# Functions below are from vLLM\n\n\ndef get_pack_factor(bits: int) -> int:\n    if 32 % bits != 0:\n        raise ValueError(f\"Cannot {bits} bit values into uint32\")\n    return 32 // bits\n\n\ndef pack_cols(\n 
   q_w: torch.Tensor,\n    num_bits: int,\n    size_k: int,\n    size_n: int,\n):\n    assert q_w.shape == (size_k, size_n)\n\n    pack_factor = get_pack_factor(num_bits)\n    assert size_n % pack_factor == 0\n\n    orig_device = q_w.device\n\n    q_w = q_w.cpu().numpy().astype(numpy.uint32)\n\n    q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)\n\n    for i in range(pack_factor):\n        q_res |= q_w[:, i::pack_factor] << num_bits * i\n\n    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)\n    q_res = q_res.contiguous()\n\n    return q_res\n\n\ndef unpack_cols(\n    packed_q_w: torch.Tensor,\n    num_bits: int,\n    size_k: int,\n    size_n: int,\n):\n    pack_factor = get_pack_factor(num_bits)\n    assert size_n % pack_factor == 0\n    assert packed_q_w.shape == (\n        size_k,\n        size_n // pack_factor,\n    ), \"packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}\".format(\n        packed_q_w.shape, size_k, size_n, pack_factor\n    )\n\n    orig_device = packed_q_w.device\n\n    packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)\n    q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)\n\n    mask = (1 << num_bits) - 1\n    for i in range(pack_factor):\n        vals = packed_q_w_cpu & mask\n        packed_q_w_cpu >>= num_bits\n        q_res[:, i::pack_factor] = vals\n\n    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)\n    q_res = q_res.contiguous()\n\n    return q_res\n\n\ndef marlin_zero_points(\n    zp: torch.Tensor, size_k: int, size_n: int, num_bits: int\n) -> torch.Tensor:\n    scale_perm, _ = get_perms()\n    # Permute zero-points in a similar way to scales, but do not use the\n    # \"single\" permutation, since zero-points are applied on every MMA\n    zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]\n\n    # Interleave column dim (for the dequantize code) and pack it to int32\n    if num_bits == 4:\n        interleave = numpy.array([0, 2, 4, 
6, 1, 3, 5, 7])\n    elif num_bits == 8:\n        interleave = numpy.array([0, 2, 1, 3])\n    else:\n        raise Exception(\"num_bits must be 4 or 8, got {}\".format(num_bits))\n\n    zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()\n    zp = zp.reshape((-1, size_n)).contiguous()\n    zp = pack_cols(zp, num_bits, size_k, size_n)\n\n    return zp\n"
  },
  {
    "path": "server/text_generation_server/layers/medusa.py",
    "content": "import torch\nfrom torch import nn\nfrom typing import Tuple, Optional\nfrom text_generation_server.utils.speculate import get_speculate\nfrom text_generation_server.layers.linear import FastLinear\nfrom text_generation_server.layers.tensor_parallel import (\n    TensorParallelHead,\n    TensorParallelColumnLinear,\n)\n\n\nclass ResBlock(torch.nn.Module):\n    def __init__(self, config, prefix, weights):\n        super().__init__()\n        self.linear = FastLinear.load(\n            config, prefix=f\"{prefix}.linear\", weights=weights, bias=True\n        )\n        self.act = torch.nn.SiLU()\n\n    def forward(self, x):\n        return x + self.act(self.linear(x))\n\n\nclass MedusaModel(torch.nn.Module):\n    def __init__(self, config, medusa_config, weights):\n        super().__init__()\n        self.heads = torch.nn.ModuleList(\n            [\n                MedusaHead(config, medusa_config, prefix=f\"{i}\", weights=weights)\n                for i in range(get_speculate())\n            ]\n        )\n\n    def forward(self, x):\n        if not self.heads:\n            return None\n        speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)\n        return speculative_logits\n\n\nclass MedusaHead(torch.nn.Module):\n    def __init__(self, config, medusa_config, prefix, weights):\n        super().__init__()\n        self.blocks = torch.nn.ModuleList(\n            [\n                ResBlock(config, prefix=f\"{prefix}.{i}\", weights=weights)\n                for i in range(medusa_config[\"medusa_num_layers\"])\n            ]\n        )\n        n = len(self.blocks)\n        self.out = FastLinear.load(\n            config, prefix=f\"{prefix}.{n}\", weights=weights, bias=False\n        )\n\n    def forward(self, x):\n        for block in self.blocks:\n            x = block(x)\n        x = self.out(x)\n        return x\n\n\nclass MedusaHeadV1(nn.Module):\n    def __init__(self, lm_head, medusa):\n        super().__init__()\n        
self.lm_head = lm_head\n        self.medusa = medusa\n\n    @staticmethod\n    def load(config, prefix: str, weights):\n        from pathlib import Path\n        from safetensors import safe_open\n        import json\n\n        speculator = config.speculator\n\n        path = speculator[\"path\"]\n        medusa_config = str(Path(path) / \"config.json\")\n\n        for fname in speculator[\"model_paths\"]:\n            filename = str(Path(path) / fname)\n\n            with open(medusa_config, \"r\") as f:\n                medusa_config = json.load(f)\n            routing = weights.routing\n            with safe_open(filename, framework=\"pytorch\") as f:\n                for k in f.keys():\n                    if k in routing and routing[k] != filename:\n                        raise RuntimeError(\n                            f\"Key {k} was found in multiple files: {filename} and {routing[k]}\"\n                        )\n                    routing[k] = filename\n\n        medusa = MedusaModel(config, medusa_config, weights)\n        lm_head = TensorParallelHead.load(config, prefix, weights)\n        return MedusaHeadV1(lm_head, medusa)\n\n    def forward(\n        self, input: torch.Tensor\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        logits = self.lm_head(input)\n        # If we have too many tokens, we skip speculative logits\n        if input.shape[0] > 128:\n            return logits, None\n\n        speculative_logits = self.medusa(input)\n        return logits, speculative_logits\n\n\nclass MedusaHeadV2(nn.Module):\n    def __init__(self, config, prefix, weights):\n        super().__init__()\n        from pathlib import Path\n        from safetensors import safe_open\n        import json\n\n        speculator_path = config.speculator[\"path\"]\n\n        medusa_config = str(Path(speculator_path) / \"config.json\")\n        filename = str(Path(speculator_path) / \"medusa_lm_head.safetensors\")\n\n        with open(medusa_config, \"r\") as 
f:\n            medusa_config = json.load(f)\n        routing = weights.routing\n        with safe_open(filename, framework=\"pytorch\") as f:\n            for k in f.keys():\n                if k in routing and routing[k] != filename:\n                    raise RuntimeError(\n                        f\"Key {k} was found in multiple files: {filename} and {routing[k]}\"\n                    )\n                routing[k] = filename\n\n        self.n_medusa_heads = get_speculate()\n\n        assert medusa_config[\"medusa_num_layers\"] == 1\n        self.linear = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{i}.0.linear\" for i in range(self.n_medusa_heads)],\n            dim=0,\n            weights=weights,\n            bias=True,\n        )\n        self.process_group = weights.process_group\n        self.world_size = self.process_group.size()\n        self.rank = self.process_group.rank()\n\n        self.act = torch.nn.SiLU()\n\n        self.lm_head = TensorParallelHead.load(config, prefix, weights)\n\n    def forward(self, x):\n        # If we have too many tokens, we skip speculative logits\n        if x.shape[0] > 128:\n            logits = self.lm_head(x)\n            return logits, None\n\n        size = x.shape[-1]\n        block_size = (size + self.world_size - 1) // self.world_size\n        start = self.rank * block_size\n        stop = (self.rank + 1) * block_size\n\n        x_block = x[:, start:stop]\n\n        # Compute all medusa heads at the same time, then reshape and move the n_medusa_heads dim to dim 1\n        medusa_res = self.act(self.linear(x)).reshape(\n            *x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1]\n        )\n\n        # Apply all residual medusa heads\n        output = x[:, start:stop].unsqueeze(-2) + medusa_res\n\n        # Gather medusa heads\n        world_output = [\n            torch.empty_like(output) for _ in range(self.process_group.size())\n        ]\n        
torch.distributed.all_gather(world_output, output, group=self.process_group)\n        world_output = torch.cat(world_output, dim=-1)\n\n        # Stack x and medusa residual x\n        stacked_x = torch.cat([x.unsqueeze(-2), world_output], dim=-2)\n\n        # Compute lm head on x + medusa residual x\n        logits = self.lm_head(stacked_x)\n\n        # Finally, split logits from speculative logits\n        logits, speculative_logits = torch.split(\n            logits, [1, self.n_medusa_heads], dim=-2\n        )\n        # Squeeze added dimension\n        logits = logits.squeeze(-2)\n\n        return logits, speculative_logits\n"
  },
  {
    "path": "server/text_generation_server/layers/mlp.py",
    "content": "import torch\nimport math\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom typing import Optional, Tuple\nfrom text_generation_server.layers import TensorParallelEmbedding, FastLinear\nfrom text_generation_server.layers.tensor_parallel import TensorParallelHead\nfrom text_generation_server.utils.speculate import get_speculate\n\n\nclass MLPSpeculatorLayerNorm(nn.Module):\n    \"\"\"\n    A L2 normalization implementation\n    ...\n    Args\n    ----\n    normalized_shape : int\n        Dimensionality of input data (size of final tensor axis)\n    elementwise_scale_weight : torch.Tensor\n        learned scaling term after normalization?\n    elementwise_shift_bias : torch.Tensor\n        learned bias term after normalization?\n    eps : float\n        Safety term to prevent division by zero. Make sure the chosen value fits in the range of your encoding scheme (i.e. fp16 requires eps >= 6e-8).\n    \"\"\"\n\n    def __init__(\n        self,\n        prefix,\n        config,\n        weights,\n        eps=1e-06,\n    ):\n        super(MLPSpeculatorLayerNorm, self).__init__()\n        self.weight = weights.get_tensor(f\"{prefix}.weight\")\n        self.bias = weights.get_tensor(f\"{prefix}.bias\")\n        self.eps = eps\n\n    def forward(self, x):\n        xf = x\n        xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)\n        x = xf.type_as(x)\n        x = self.weight * x\n        x = x + self.bias\n        return x\n\n\nINV_SQRT2 = 2**-0.5\n\n\ndef simple_norm(x: torch.Tensor, eps=1e-06):\n    xf = x\n    xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)\n    x = xf.type_as(x)\n    return x * INV_SQRT2\n\n\nclass MLPSpeculatorModelTied(torch.nn.Module):\n    def __init__(self, config, prefix, weights):\n        super().__init__()\n        self.config = config\n        self.n_predict = get_speculate()\n        self.hidden_size = config.hidden_size\n\n        self.emb = 
TensorParallelEmbedding(f\"{prefix}.emb.0\", weights)\n        self.proj0 = FastLinear.load(\n            config,\n            prefix=f\"{prefix}.proj.0\",\n            weights=weights,\n            bias=False,\n        )\n        self.proj1 = FastLinear.load(\n            config,\n            prefix=f\"{prefix}.proj.1\",\n            weights=weights,\n            bias=False,\n        )\n        self.head = FastLinear.load(config, f\"{prefix}.head.0\", weights, bias=False)\n        self.ln = MLPSpeculatorLayerNorm(\n            prefix=f\"{prefix}.ln.0\",\n            config=config,\n            weights=weights,\n        )\n\n        # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation\n        self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1\n        self.activation = nn.GELU()\n        self.vsize = config.vocab_size\n        self.inner_dim = config.speculator_config[\"inner_dim\"]\n        self.top_k_tokens_per_head = [1] * self.n_predict\n        self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt(\n            self.inner_dim / 2\n        )\n        self.emb.weight *= self.emb_weight\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        input_ids: torch.Tensor,\n    ):\n        top_k_tokens_per_head = self.top_k_tokens_per_head\n\n        # k indicates # of candidates\n        # h indicates # of generated tokens\n        state = hidden_states\n        b = state.size(0)\n        ind = input_ids.unsqueeze(0)\n        all_probs = torch.empty(\n            b, self.n_predict, self.vsize, device=state.device\n        )  # b k h v\n        assert (\n            len(top_k_tokens_per_head) == self.n_predict\n        ), f\"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)\"\n        for i in range(self.n_predict):\n            # Project and predict\n            z = self.emb(ind)\n            # z = 
z.mul(self.emb_weight)  # b k d\n            if i == 0:\n                state = self.proj0(state) * self.state_weight + z\n            else:\n                state = self.proj1(state) * self.state_weight + z\n            state = self.activation(self.ln(state))  # b k d\n            probs = F.log_softmax(self.head(state), dim=-1)  # b k v\n            _probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1)  # b k k'\n\n            # Update candidate set with new predictions\n\n            # Update distribution set with new logits\n            all_probs[:, i] = probs.exp()\n\n            # Update state, log_probs and ind for new predictions\n            state = state.unsqueeze(2).expand(\n                -1, -1, top_k_tokens_per_head[i], -1\n            )  # b k k' d\n            state = state.reshape(-1, b, state.size(3))  # b kk' d\n            ind = preds.view(-1, b)  # b kk'\n\n        speculative_logits = all_probs\n        return speculative_logits\n\n\nclass MLPSpeculatorModel(torch.nn.Module):\n    def __init__(self, config, prefix, weights):\n        super().__init__()\n        self.config = config\n        self.n_predict = get_speculate()\n        self.hidden_size = config.hidden_size\n\n        self.emb = nn.ModuleList(\n            [\n                TensorParallelEmbedding(f\"{prefix}.emb.{i}\", weights)\n                for i in range(self.n_predict)\n            ]\n        )\n        self.proj = [\n            FastLinear.load(\n                config,\n                prefix=f\"{prefix}.proj.{i}\",\n                weights=weights,\n                bias=False,\n            )\n            for i in range(self.n_predict)\n        ]\n        self.head = nn.ModuleList(\n            [\n                FastLinear.load(config, f\"{prefix}.head.{i}\", weights, bias=False)\n                for i in range(self.n_predict)\n            ]\n        )\n        self.ln = nn.ModuleList(\n            [\n                MLPSpeculatorLayerNorm(\n                    
prefix=f\"{prefix}.ln.{i}\",\n                    config=config,\n                    weights=weights,\n                )\n                for i in range(self.n_predict)\n            ]\n        )\n\n        # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation\n        self.state_weight = 0.5 ** (0.5 / self.n_predict) if self.n_predict > 0 else 1\n        self.activation = nn.GELU()\n        self.vsize = config.vocab_size\n        self.inner_dim = config.speculator_config[\"inner_dim\"]\n        self.top_k_tokens_per_head = [1] * self.n_predict\n        self.emb_weight = math.sqrt(1 - self.state_weight**2) * math.sqrt(\n            self.inner_dim / 2\n        )\n        self.emb.weight *= self.emb_weight\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        input_ids: torch.Tensor,\n    ):\n        top_k_tokens_per_head = self.top_k_tokens_per_head\n\n        # k indicates # of candidates\n        # h indicates # of generated tokens\n        state = hidden_states\n        b = state.size(0)\n        ind = input_ids.unsqueeze(0)\n        all_probs = torch.empty(\n            b, self.n_predict, self.vsize, device=state.device\n        )  # b k h v\n        assert (\n            len(top_k_tokens_per_head) == self.n_predict\n        ), f\"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)\"\n        for i in range(self.n_predict):\n            # Project and predict\n            z = self.emb[i](ind)\n            # z = z.mul(self.emb_weight)  # b k d\n            state = self.proj[i](state) * self.state_weight + z\n            state = self.activation(self.ln[i](state))  # b k d\n            probs = F.log_softmax(self.head[i](state), dim=-1)  # b k v\n            _probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1)  # b k k'\n\n            # Update candidate set with new predictions\n\n            # Update distribution set with new logits\n 
           all_probs[:, i] = probs.exp()\n\n            # Update state, log_probs and ind for new predictions\n            state = state.unsqueeze(2).expand(\n                -1, -1, top_k_tokens_per_head[i], -1\n            )  # b k k' d\n            state = state.reshape(-1, b, state.size(3))  # b kk' d\n            ind = preds.view(-1, b)  # b kk'\n\n        speculative_logits = all_probs\n        return speculative_logits\n\n\nclass MLPSpeculatorHead(nn.Module):\n    def __init__(self, lm_head, mlp_speculator, scale_input: bool):\n        super().__init__()\n        self.lm_head = lm_head\n        self.mlp_speculator = mlp_speculator\n        self.scale_input = scale_input\n\n    def forward(\n        self, input: torch.Tensor\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        logits = self.lm_head(input)\n        # If we have too many tokens, we skip speculative logits\n        if input.shape[0] > 128:\n            return logits, None\n\n        input_ids = logits.argmax(dim=-1)\n        if self.scale_input:\n            input = simple_norm(input)\n        speculative_logits = self.mlp_speculator(input, input_ids)\n        return logits, speculative_logits\n\n    @staticmethod\n    def load(config, prefix: str, weights):\n        from pathlib import Path\n        from safetensors import safe_open\n\n        speculator_path = config.speculator[\"path\"]\n\n        for fname in config.speculator[\"model_paths\"]:\n            filename = str(Path(speculator_path) / fname)\n            routing = weights.routing\n            with safe_open(filename, framework=\"pytorch\") as f:\n                for k in f.keys():\n                    if k in routing and routing[k] != filename:\n                        raise RuntimeError(\n                            f\"Key {k} was found in multiple files: {filename} and {routing[k]}\"\n                        )\n                    routing[k] = filename\n\n        tie_weights = 
config.speculator_config.get(\"tie_weights\", False)\n        if tie_weights:\n            mlp_speculator = MLPSpeculatorModelTied(config, \"speculator\", weights)\n        else:\n            mlp_speculator = MLPSpeculatorModel(config, \"speculator\", weights)\n        # This is used in https://huggingface.co/ibm-fms/llama3-70b-accelerator\n        scale_input = config.speculator_config.get(\"scale_input\", False)\n        lm_head = TensorParallelHead.load(config, prefix, weights)\n        return MLPSpeculatorHead(lm_head, mlp_speculator, scale_input)\n"
  },
  {
    "path": "server/text_generation_server/layers/moe/__init__.py",
    "content": "from typing import Optional, Protocol, runtime_checkable\n\nimport torch\nimport torch.nn as nn\nfrom loguru import logger\nfrom transformers.activations import ACT2FN\n\nfrom text_generation_server.layers import (\n    TensorParallelColumnLinear,\n    TensorParallelRowLinear,\n)\nfrom text_generation_server.layers.fp8 import HybridFP8UnquantLoader\nfrom text_generation_server.layers.marlin import GPTQMarlinWeightsLoader\nfrom text_generation_server.layers.moe.gptq_marlin import (\n    GPTQMarlinSparseMoELayer,\n    can_use_marlin_moe_gemm,\n)\nfrom text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer\nfrom text_generation_server.layers.moe.fp8 import FP8SparseMoELayer\nfrom text_generation_server.utils.import_utils import SYSTEM\nfrom text_generation_server.utils.kernels import load_kernel\nfrom text_generation_server.utils.log import log_once\nfrom text_generation_server.utils.weights import (\n    DefaultWeightsLoader,\n    Weights,\n    UnquantizedWeight,\n)\n\nif SYSTEM == \"ipex\":\n    from .fused_moe_ipex import fused_topk, grouped_topk\nelif SYSTEM == \"cuda\":\n    moe_kernels = load_kernel(module=\"moe\", repo_id=\"kernels-community/moe\")\n    fused_topk = moe_kernels.fused_topk\n    grouped_topk = moe_kernels.grouped_topk\nelse:\n    from moe_kernels.fused_moe import fused_topk, grouped_topk\n\n\n# NOTE: we are using a protocol here, because multiple inherance is not nice.\n#       We need `Module`, and `Module` -> some abstract class -> some concrete\n#       class inheritance is whacky.\n\n\n@runtime_checkable\nclass MoELayer(Protocol):\n    def __init__(\n        self,\n        *,\n        n_expert_group: Optional[int],\n        n_experts: int,\n        prefix: str,\n        renormalize: bool,\n        topk: int,\n        topk_group: Optional[int],\n        weights: Weights,\n        gate_proj_name: str = \"gate_proj\",\n        up_proj_name: str = \"up_proj\",\n        down_proj_name: str = \"down_proj\",\n 
       hidden_act: str = \"silu\",\n        scoring_func: Optional[str] = None,\n        e_score_correction_bias: Optional[float] = None,\n    ): ...\n\n    def forward(\n        self, x: torch.Tensor, *, gating_output: torch.Tensor\n    ) -> torch.Tensor: ...\n\n\nclass DenseMoELayer(nn.Module):\n    \"\"\"\n    Layer for MoE that applies *all* experts to each tokens and then weights\n    their outputs based on the calculated routing. This layer is much slower\n    than `SparseMoELayer` and should only be used when no fused kernels are\n    available (e.g. for unsupported quantizers).\n    \"\"\"\n\n    def __init__(\n        self,\n        *,\n        n_expert_group: Optional[int],\n        n_experts: int,\n        prefix: str,\n        renormalize: bool,\n        topk: int,\n        topk_group: Optional[int],\n        weights: Weights,\n        gate_proj_name: str = \"gate_proj\",\n        up_proj_name: str = \"up_proj\",\n        down_proj_name: str = \"down_proj\",\n        hidden_act: str = \"silu\",\n        scoring_func: Optional[str] = None,\n        e_score_correction_bias: Optional[float] = None,\n    ):\n        super().__init__()\n\n        assert scoring_func is None, \"scoring func is not handled\"\n        assert e_score_correction_bias is None, \"scoring correction bias is not handled\"\n\n        log_once(\n            logger.info,\n            \"No fused layers are available for this model type, using (slower) dense MoE layer\",\n        )\n\n        assert (n_expert_group is None) == (\n            topk_group is None\n        ), \"n_expert_group and topk_group must both be None or have some value\"\n\n        self.n_expert_group = n_expert_group\n        self.n_experts = n_experts\n        self.renormalize = renormalize\n        self.topk = topk\n        self.topk_group = topk_group\n\n        if \"gelu\" in hidden_act:\n            self.act = lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                
    \"tanh\"\n                    if hidden_act in [\"gelu_fast\", \"gelu_pytorch_tanh\"]\n                    else \"none\"\n                ),\n            )\n        elif \"silu\" in hidden_act:\n            self.act = torch.nn.functional.silu\n        else:\n            self.act = ACT2FN[hidden_act]\n\n        self.gate_proj = [\n            TensorParallelColumnLinear.load(\n                None,\n                prefix=f\"{prefix}.{i}.{gate_proj_name}\",\n                weights=weights,\n                bias=False,\n            )\n            for i in range(self.n_experts)\n        ]\n        self.up_proj = [\n            TensorParallelColumnLinear.load(\n                None,\n                prefix=f\"{prefix}.{i}.{up_proj_name}\",\n                weights=weights,\n                bias=False,\n            )\n            for i in range(self.n_experts)\n        ]\n        self.down_proj = [\n            TensorParallelRowLinear.load(\n                None,\n                prefix=f\"{prefix}.{i}.{down_proj_name}\",\n                weights=weights,\n                bias=False,\n            )\n            for i in range(self.n_experts)\n        ]\n\n        self.process_group = weights.process_group\n\n    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        x: (sequence_length, model_dim)\n        gating_output: (sequence_length, n_experts)\n        \"\"\"\n        # optional reshape\n        input_shape = x.shape\n        x = x.view(-1, input_shape[-1])\n\n        if self.n_expert_group is not None and self.topk_group is not None:\n            topk_weights, topk_ids = grouped_topk(\n                x,\n                gating_output,\n                self.topk,\n                renormalize=self.renormalize,\n                num_expert_group=self.n_expert_group,\n                topk_group=self.topk_group,\n            )\n        else:\n            topk_weights, topk_ids = fused_topk(\n                x, 
gating_output, self.topk, self.renormalize\n            )\n            topk_weights = topk_weights.to(x.dtype)\n\n        weights = torch.zeros(\n            topk_ids.shape[0], self.n_experts, dtype=x.dtype, device=x.device\n        )\n\n        weights.scatter_(1, topk_ids.long(), topk_weights.to(weights.dtype))\n\n        out = torch.zeros_like(x)\n        for i in range(self.n_experts):\n            h = self.act(self.gate_proj[i](x)) * self.up_proj[i](x)\n            h = self.down_proj[i](h, reduce=False)\n            out += h * weights[:, i].view(-1, 1)\n\n        return out\n\n\nclass SparseMoELayer(nn.Module):\n    \"\"\"\n    Layer for MoE that uses fused kernels to only apply the active experts\n    for each token (rather than applying all experts and selecting the\n    outputs of active experts).\n    \"\"\"\n\n    def __init__(\n        self,\n        *,\n        n_expert_group: Optional[int],\n        n_experts: int,\n        prefix: str,\n        renormalize: bool,\n        topk: int,\n        topk_group: Optional[int],\n        weights: Weights,\n        scoring_func: Optional[str] = \"softmax\",\n        e_score_correction_bias: Optional[float] = None,\n        gate_proj_name: str = \"gate_proj\",\n        up_proj_name: str = \"up_proj\",\n        down_proj_name: str = \"down_proj\",\n    ):\n        super().__init__()\n        if (\n            isinstance(weights.loader, DefaultWeightsLoader)\n            and isinstance(weights.loader.weight_class, UnquantizedWeight)\n        ) or isinstance(weights.loader, HybridFP8UnquantLoader):\n            if (\n                isinstance(weights.loader, HybridFP8UnquantLoader)\n                and weights.loader.to_fp8\n            ):\n                cls = FP8SparseMoELayer\n            else:\n                cls = UnquantizedSparseMoELayer\n        elif isinstance(\n            weights.loader, GPTQMarlinWeightsLoader\n        ) and can_use_marlin_moe_gemm(\n            
quant_method=weights.loader.quant_method,\n            quantize=weights.loader.quantize,\n            sym=weights.loader.sym,\n        ):\n            cls = GPTQMarlinSparseMoELayer\n        else:\n            raise ValueError(\n                f\"Unsupported weights loader: {type(weights.loader)}, sparse MoE is only supported for unquantized, AWQ, and GPTQ weights\"\n            )\n\n        log_once(\n            logger.info,\n            \"Using MoE layer wih fused gemm\",\n        )\n\n        self.moe = cls(\n            n_expert_group=n_expert_group,\n            n_experts=n_experts,\n            prefix=prefix,\n            renormalize=renormalize,\n            topk=topk,\n            topk_group=topk_group,\n            weights=weights,\n            scoring_func=scoring_func,\n            e_score_correction_bias=e_score_correction_bias,\n            gate_proj_name=gate_proj_name,\n            up_proj_name=up_proj_name,\n            down_proj_name=down_proj_name,\n        )\n\n    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:\n        return self.moe(x, gating_output=gating_output)\n\n    @staticmethod\n    def is_supported(weights: Weights) -> bool:\n        return (\n            (\n                isinstance(weights.loader, DefaultWeightsLoader)\n                and isinstance(weights.loader.weight_class, UnquantizedWeight)\n            )\n            or isinstance(weights.loader, HybridFP8UnquantLoader)\n            or (\n                isinstance(weights.loader, GPTQMarlinWeightsLoader)\n                and can_use_marlin_moe_gemm(\n                    quant_method=weights.loader.quant_method,\n                    quantize=weights.loader.quantize,\n                    sym=weights.loader.sym,\n                )\n            )\n        )\n"
  },
  {
    "path": "server/text_generation_server/layers/moe/fp8.py",
    "content": "from typing import Optional\n\nimport torch\nimport torch.nn as nn\n\nfrom text_generation_server.utils.weights import Weights\nfrom text_generation_server.layers.fp8 import (\n    Fp8Weight,\n    fp8_quantize,\n    quant_dtype,\n    normalize_e4m3fn_to_native_float8,\n)\n\ntry:\n    from .unquantized import fused_moe\nexcept Exception:\n    fused_moe = None\n\n\nclass FP8SparseMoELayer(nn.Module):\n    def __init__(\n        self,\n        *,\n        n_expert_group: Optional[int],\n        n_experts: int,\n        prefix: str,\n        renormalize: bool,\n        topk: int,\n        topk_group: Optional[int],\n        weights: Weights,\n        scoring_func: Optional[str] = \"softmax\",\n        e_score_correction_bias: Optional[float] = None,\n        gate_proj_name: str = \"gate_proj\",\n        up_proj_name: str = \"up_proj\",\n        down_proj_name: str = \"down_proj\",\n    ):\n        super().__init__()\n\n        assert (n_expert_group is None) == (\n            topk_group is None\n        ), \"n_expert_group and topk_group must both be None or have some value\"\n\n        self.n_expert_group = n_expert_group\n        self.topk = topk\n        self.topk_group = topk_group\n        self.renormalize = renormalize\n        self.weight_block_size = weights.weights_loader.weight_block_size\n        self.scoring_func = scoring_func\n        self.e_score_correction_bias = e_score_correction_bias\n\n        (\n            self.gate_up_proj,\n            self.gate_up_proj_weight_scale,\n            self.gate_up_proj_input_scale,\n        ) = _load_expert_multi_weights_col(\n            prefix=prefix,\n            n_experts=n_experts,\n            gate_proj_name=gate_proj_name,\n            up_proj_name=up_proj_name,\n            weights=weights,\n        )\n\n        self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = (\n            _load_expert_weights_row(\n                prefix=prefix,\n                
n_experts=n_experts,\n                name=down_proj_name,\n                weights=weights,\n            )\n        )\n\n    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:\n        return fused_moe(\n            x,\n            w1=self.gate_up_proj,\n            w2=self.down_proj,\n            gating_output=gating_output,\n            topk=self.topk,\n            renormalize=self.renormalize,\n            inplace=True,\n            use_grouped_topk=self.n_expert_group is not None,\n            num_expert_group=self.n_expert_group,\n            topk_group=self.topk_group,\n            scoring_func=self.scoring_func,\n            e_score_correction_bias=self.e_score_correction_bias,\n            use_fp8_w8a8=True,\n            w1_scale=self.gate_up_proj_weight_scale,\n            w2_scale=self.down_proj_weight_scale,\n            a1_scale=self.gate_up_proj_input_scale,\n            a2_scale=self.down_proj_input_scale,\n        )\n\n\ndef _load_expert_weights(\n    get_weight_fn,\n    *,\n    prefix: str,\n    n_experts: int,\n    name: str,\n    weights: Weights,\n) -> torch.Tensor:\n    all_weight = None\n    all_weight_scales = None\n    max_input_scale = None\n\n    for i in range(n_experts):\n        weight = get_weight_fn(prefix, i, name, weights)\n\n        assert isinstance(weight, Fp8Weight)\n\n        if all_weight is None:\n            all_weight = torch.empty(\n                (n_experts,) + weight.weight.shape,\n                dtype=quant_dtype,\n                device=weight.weight.device,\n            )\n        if all_weight_scales is None:\n            all_weight_scales = torch.empty(\n                (n_experts,) + weight.weight_scale.shape,\n                dtype=torch.float32,\n                device=weight.weight.device,\n            )\n\n        if weight.weight.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz}:\n            all_weight[i], all_weight_scales[i], current_input_scale = (\n                
normalize_e4m3fn_to_native_float8(\n                    weight.weight, weight.weight_scale, weight.input_scale\n                )\n            )\n            if current_input_scale is not None:\n                if max_input_scale is None or current_input_scale > max_input_scale:\n                    max_input_scale = current_input_scale\n        else:\n            all_weight[i], all_weight_scales[i] = fp8_quantize(\n                weight.weight, scalar=True\n            )\n\n    assert all_weight is not None\n\n    return all_weight, all_weight_scales, max_input_scale\n\n\ndef _load_expert_multi_weights_col(\n    *,\n    prefix: str,\n    n_experts: int,\n    gate_proj_name: str,\n    up_proj_name: str,\n    weights: Weights,\n) -> torch.Tensor:\n    def get_weight_fn(prefix, i, name, weights):\n        return weights.get_multi_weights_col(\n            [f\"{prefix}.{i}.{gate_proj_name}\", f\"{prefix}.{i}.{up_proj_name}\"], 0\n        )\n\n    return _load_expert_weights(\n        get_weight_fn, prefix=prefix, n_experts=n_experts, name=None, weights=weights\n    )\n\n\ndef _load_expert_weights_row(\n    *,\n    prefix: str,\n    n_experts: int,\n    name: str,\n    weights: Weights,\n) -> torch.Tensor:\n    def get_weight_fn(prefix, i, name, weights):\n        return weights.get_weights_row(f\"{prefix}.{i}.{name}\")\n\n    return _load_expert_weights(\n        get_weight_fn, prefix=prefix, n_experts=n_experts, name=name, weights=weights\n    )\n"
  },
  {
    "path": "server/text_generation_server/layers/moe/fused_moe_ipex.py",
    "content": "# coding=utf-8\n# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Tuple\n\nimport torch\n\n\ndef grouped_topk(\n    hidden_states: torch.Tensor,\n    gating_output: torch.Tensor,\n    topk: int,\n    renormalize: bool,\n    num_expert_group: int = 0,\n    topk_group: int = 0,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    scores = torch.softmax(gating_output, dim=-1)\n    num_token = scores.shape[0]\n    group_scores = (\n        scores.view(num_token, num_expert_group, -1).max(dim=-1).values\n    )  # [n, n_group]\n    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[\n        1\n    ]  # [n, top_k_group]\n    group_mask = torch.zeros_like(group_scores)  # [n, n_group]\n    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]\n    score_mask = (\n        group_mask.unsqueeze(-1)\n        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)\n        .reshape(num_token, -1)\n    )  # [n, e]\n    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]\n    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)\n\n    if renormalize:\n        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)\n\n    return topk_weights, topk_ids\n\n\ndef fused_topk(\n    hidden_states: torch.Tensor,\n    gating_output: torch.Tensor,\n    topk: int,\n    
renormalize: bool,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    topk_weights = torch.nn.functional.softmax(\n        gating_output, dim=1, dtype=torch.float32\n    )\n    topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)\n    if renormalize:\n        topk_weights /= topk_weights.sum(dim=-1, keepdim=True)\n    return topk_weights, topk_ids\n"
  },
  {
    "path": "server/text_generation_server/layers/moe/gptq_marlin.py",
    "content": "from dataclasses import dataclass\nfrom typing import Callable, List, Optional\n\nimport torch\nimport torch.nn as nn\n\nfrom text_generation_server.layers import moe\nfrom text_generation_server.utils.import_utils import SYSTEM\nfrom text_generation_server.utils.kernels import load_kernel\nfrom text_generation_server.utils.weights import Weights\nfrom text_generation_server.layers.marlin.gptq import (\n    GPTQMarlinWeight,\n    GPTQMarlinWeightsLoader,\n)\n\nif SYSTEM == \"cuda\":\n    moe_kernels = load_kernel(module=\"moe\", repo_id=\"kernels-community/moe\")\nelse:\n    moe_kernels = None\n\n\ntry:\n    major, _minor = torch.cuda.get_device_capability()\n    has_sm_8_0 = major >= 8\nexcept Exception:\n    has_sm_8_0 = False\n\n\ndef can_use_marlin_moe_gemm(\n    *,\n    quant_method: str,\n    quantize: str,\n    sym: bool,\n):\n    return (\n        SYSTEM == \"cuda\"\n        and moe is not None\n        and has_sm_8_0\n        and quantize in {\"awq\", \"gptq\"}\n        and quant_method in {\"awq\", \"gptq\"}\n        # We only support asymmetric quantization for AWQ.\n        and (sym or quant_method == \"awq\")\n    )\n\n\n@dataclass\nclass GPTQMarlinMoEWeight:\n    qweight: torch.Tensor\n    qzeros: torch.Tensor\n    scales: torch.Tensor\n    g_idx: torch.Tensor\n    perm: torch.Tensor\n    is_full_k: bool\n\n\nclass GPTQMarlinSparseMoELayer(nn.Module):\n    \"\"\"\n    MoE layer that uses a fused GPTQ-Marlin kernel.\n    \"\"\"\n\n    def __init__(\n        self,\n        *,\n        n_expert_group: Optional[int],\n        n_experts: int,\n        prefix: str,\n        renormalize: bool,\n        topk: int,\n        topk_group: Optional[int],\n        weights: Weights,\n        gate_proj_name: str = \"gate_proj\",\n        up_proj_name: str = \"up_proj\",\n        down_proj_name: str = \"down_proj\",\n        scoring_func: Optional[str] = None,\n        e_score_correction_bias: Optional[float] = None,\n    ):\n        assert 
scoring_func in (\n            \"sigmoid\",\n            \"softmax\",\n        ), f\"scoring func {scoring_func} is not handled\"\n        super().__init__()\n\n        if not (\n            isinstance(weights.loader, GPTQMarlinWeightsLoader)\n            and can_use_marlin_moe_gemm(\n                quant_method=weights.loader.quant_method,\n                quantize=weights.loader.quantize,\n                sym=weights.loader.sym,\n            )\n        ):\n            raise ValueError(\n                f\"Unsupported weights loader: {type(weights.loader)}, only GPTQMarlinWeightsLoader with AWQ and symmetric GPTQ quantization is supported\"\n            )\n\n        assert (n_expert_group is None) == (\n            topk_group is None\n        ), \"n_expert_group and topk_group must both be None or have some value\"\n\n        self.n_expert_group = n_expert_group\n        self.topk = topk\n        self.topk_group = topk_group\n        self.renormalize = renormalize\n        self.scoring_func = scoring_func\n        self.e_score_correction_bias = e_score_correction_bias\n\n        self.gate_up_proj = _load_expert_multi_weights_col(\n            prefix=prefix,\n            n_experts=n_experts,\n            names=[gate_proj_name, up_proj_name],\n            weights=weights,\n        )\n\n        self.down_proj = _load_expert_weights_row(\n            prefix=prefix, n_experts=n_experts, name=down_proj_name, weights=weights\n        )\n\n        self.bits = weights.loader.bits\n\n    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:\n        return fused_marlin_moe(\n            hidden_states=x,\n            w1=self.gate_up_proj.qweight,\n            w2=self.down_proj.qweight,\n            w1_scale=self.gate_up_proj.scales,\n            w2_scale=self.down_proj.scales,\n            w1_zeros=(\n                self.gate_up_proj.qzeros\n                if self.gate_up_proj.qzeros.numel() > 0\n                else None\n            ),\n    
        w2_zeros=(\n                self.down_proj.qzeros if self.down_proj.qzeros.numel() > 0 else None\n            ),\n            g_idx1=self.gate_up_proj.g_idx,\n            g_idx2=self.down_proj.g_idx,\n            sort_indices1=self.gate_up_proj.perm,\n            sort_indices2=self.down_proj.perm,\n            is_k_full=self.gate_up_proj.is_full_k or self.down_proj.is_full_k,\n            gating_output=gating_output,\n            topk=self.topk,\n            renormalize=self.renormalize,\n            use_grouped_topk=self.n_expert_group is not None,\n            num_expert_group=self.n_expert_group,\n            topk_group=self.topk_group,\n            num_bits=self.bits,\n            scoring_func=self.scoring_func,\n            e_score_correction_bias=self.e_score_correction_bias,\n        )\n\n\ndef _load_expert_multi_weights_col(\n    *,\n    prefix: str,\n    n_experts: int,\n    names: List[str],\n    weights: Weights,\n) -> GPTQMarlinMoEWeight:\n    moe_weight = None\n    for i in range(n_experts):\n        weight = weights.get_multi_weights_col(\n            [f\"{prefix}.{i}.{name}\" for name in names], 0\n        )\n        assert isinstance(weight, GPTQMarlinWeight)\n        moe_weight = _pack_weight(\n            n_experts=n_experts, expert=i, weight=weight, moe_weight=moe_weight\n        )\n    assert moe_weight is not None\n    return moe_weight\n\n\ndef _load_expert_weights_row(\n    *,\n    prefix: str,\n    n_experts: int,\n    name: str,\n    weights: Weights,\n) -> GPTQMarlinMoEWeight:\n    moe_weight = None\n    for i in range(n_experts):\n        weight = weights.get_weights_row(\n            f\"{prefix}.{i}.{name}\",\n        )\n        assert isinstance(weight, GPTQMarlinWeight)\n        moe_weight = _pack_weight(\n            n_experts=n_experts, expert=i, weight=weight, moe_weight=moe_weight\n        )\n    assert moe_weight is not None\n    return moe_weight\n\n\ndef _pack_weight(\n    *,\n    n_experts: int,\n    expert: int,\n    
moe_weight: Optional[GPTQMarlinMoEWeight],\n    weight: GPTQMarlinWeight,\n) -> GPTQMarlinMoEWeight:\n    if moe_weight is None:\n        qweight = torch.empty(\n            (n_experts,) + weight.qweight.shape,\n            dtype=weight.qweight.dtype,\n            device=weight.qweight.device,\n        )\n        qzeros = torch.empty(\n            (n_experts,) + weight.qzeros.shape,\n            dtype=weight.qzeros.dtype,\n            device=weight.qzeros.device,\n        )\n        scales = torch.empty(\n            (n_experts,) + weight.scales.shape,\n            dtype=weight.scales.dtype,\n            device=weight.scales.device,\n        )\n        g_idx = torch.empty(\n            (n_experts,) + weight.g_idx.shape,\n            dtype=weight.g_idx.dtype,\n            device=weight.g_idx.device,\n        )\n        perm = torch.empty(\n            (n_experts,) + weight.perm.shape,\n            dtype=weight.perm.dtype,\n            device=weight.perm.device,\n        )\n\n        moe_weight = GPTQMarlinMoEWeight(\n            qweight=qweight,\n            qzeros=qzeros,\n            scales=scales,\n            g_idx=g_idx,\n            perm=perm,\n            is_full_k=weight.is_full_k,\n        )\n\n    moe_weight.qweight[expert] = weight.qweight\n    moe_weight.qzeros[expert] = weight.qzeros\n    moe_weight.scales[expert] = weight.scales\n    moe_weight.g_idx[expert] = weight.g_idx\n    moe_weight.perm[expert] = weight.perm\n\n    return moe_weight\n\n\ndef fused_marlin_moe(\n    *,\n    hidden_states: torch.Tensor,\n    w1: torch.Tensor,\n    w2: torch.Tensor,\n    w1_scale: Optional[torch.Tensor] = None,\n    w2_scale: Optional[torch.Tensor] = None,\n    gating_output: torch.Tensor,\n    g_idx1: torch.Tensor,\n    g_idx2: torch.Tensor,\n    sort_indices1: torch.Tensor,\n    sort_indices2: torch.Tensor,\n    w1_zeros: Optional[torch.Tensor] = None,\n    w2_zeros: Optional[torch.Tensor] = None,\n    is_k_full: bool,\n    topk: int,\n    renormalize: bool,\n    
num_bits: int = 8,\n    use_grouped_topk: bool = False,\n    num_expert_group: Optional[int] = None,\n    custom_routing_function: Optional[Callable] = None,\n    topk_group: Optional[int] = None,\n    scoring_func: Optional[str] = None,\n    e_score_correction_bias: Optional[float] = None,\n) -> torch.Tensor:\n    \"\"\"\n    This function computes a Mixture of Experts (MoE) layer using two sets of\n    weights, w1 and w2, and top-k gating mechanism.\n\n    Parameters:\n    - hidden_states (torch.Tensor): The input tensor to the MoE layer.\n    - w1 (torch.Tensor): The first set of expert weights.\n    - w2 (torch.Tensor): The second set of expert weights.\n    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for\n        w1.\n    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for\n        w2.\n    - gating_output (torch.Tensor): The output of the gating operation\n        (before softmax).\n    - g_idx1 (torch.Tensor): The first set of act_order indices.\n    - g_idx2 (torch.Tensor): The second set of act_order indices.\n    - sort_indices1 (torch.Tensor): The first act_order input permutation.\n    - sort_indices2 (torch.Tensor): The second act_order input permutation.\n    - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.\n    - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.\n    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.\n    - num_bits (bool): The number of bits in expert weights quantization.\n\n    Returns:\n    - torch.Tensor: The output tensor after applying the MoE layer.\n    \"\"\"\n    # Check constraints.\n    assert hidden_states.shape[0] == gating_output.shape[0], \"Number of tokens mismatch\"\n    assert hidden_states.shape[1] == w1.shape[1] * 16, \"Hidden size mismatch w1\"\n    assert hidden_states.shape[1] == w2.shape[2] // (\n        num_bits // 2\n    ), \"Hidden size mismatch w2\"\n    assert gating_output.shape[1] == 
w1.shape[0], \"Number of experts mismatch\"\n    assert hidden_states.is_contiguous(), \"Hidden_states must be contiguous\"\n    assert w1.is_contiguous(), \"Expert weights1 must be contiguous\"\n    assert w2.is_contiguous(), \"Expert weights2 must be contiguous\"\n    assert hidden_states.dtype == torch.float16\n    assert num_bits in [4, 8]\n\n    # DeekSeekv2 uses grouped_top_k\n    if use_grouped_topk:\n        assert topk_group is not None\n        assert num_expert_group is not None\n        topk_weights, topk_ids = moe_kernels.grouped_topk(\n            hidden_states=hidden_states,\n            gating_output=gating_output,\n            topk=topk,\n            renormalize=renormalize,\n            num_expert_group=num_expert_group,\n            topk_group=topk_group,\n            scoring_func=scoring_func,\n            e_score_correction_bias=e_score_correction_bias,\n        )\n    elif custom_routing_function is None:\n        topk_weights, topk_ids = moe_kernels.fused_topk(\n            hidden_states=hidden_states,\n            gating_output=gating_output,\n            topk=topk,\n            renormalize=renormalize,\n        )\n    else:\n        topk_weights, topk_ids = custom_routing_function(\n            hidden_states=hidden_states,\n            gating_output=gating_output,\n            topk=topk,\n            renormalize=renormalize,\n        )\n    return moe_kernels.fused_marlin_moe(\n        hidden_states=hidden_states,\n        w1=w1,\n        w2=w2,\n        w1_scale=w1_scale,\n        w2_scale=w2_scale,\n        gating_output=gating_output,\n        topk_weights=topk_weights,\n        topk_ids=topk_ids,\n        g_idx1=g_idx1,\n        g_idx2=g_idx2,\n        sort_indices1=sort_indices1,\n        sort_indices2=sort_indices2,\n        w1_zeros=w1_zeros,\n        w2_zeros=w2_zeros,\n        num_bits=num_bits,\n        is_k_full=is_k_full,\n    )\n"
  },
  {
    "path": "server/text_generation_server/layers/moe/unquantized.py",
    "content": "from typing import Callable, List, Optional\n\nimport torch\nimport torch.nn as nn\n\nfrom text_generation_server.utils.import_utils import SYSTEM\nfrom text_generation_server.utils.kernels import load_kernel\nfrom text_generation_server.utils.weights import UnquantizedWeight, Weights\n\nif SYSTEM == \"ipex\":\n    from intel_extension_for_pytorch.llm.modules import GatedMLPMOE\nelif SYSTEM == \"cuda\":\n    moe_kernels = load_kernel(module=\"moe\", repo_id=\"kernels-community/moe\")\nelse:\n    import moe_kernels\n\n\nclass UnquantizedSparseMoELayer(nn.Module):\n    def __init__(\n        self,\n        *,\n        n_expert_group: Optional[int],\n        n_experts: int,\n        prefix: str,\n        renormalize: bool,\n        topk: int,\n        topk_group: Optional[int],\n        weights: Weights,\n        scoring_func: Optional[str] = \"softmax\",\n        e_score_correction_bias: Optional[float] = None,\n        gate_proj_name: str = \"gate_proj\",\n        up_proj_name: str = \"up_proj\",\n        down_proj_name: str = \"down_proj\",\n    ):\n        super().__init__()\n\n        assert (n_expert_group is None) == (\n            topk_group is None\n        ), \"n_expert_group and topk_group must both be None or have some value\"\n\n        self.n_expert_group = n_expert_group\n        self.topk = topk\n        self.topk_group = topk_group\n        self.renormalize = renormalize\n        self.weight_block_size = weights.weights_loader.weight_block_size\n        self.scoring_func = scoring_func\n        self.e_score_correction_bias = e_score_correction_bias\n\n        self.gate_up_proj = _load_expert_multi_weights_col(\n            prefix=prefix,\n            n_experts=n_experts,\n            gate_proj_name=gate_proj_name,\n            up_proj_name=up_proj_name,\n            weights=weights,\n        )\n\n        self.down_proj = _load_expert_weights_row(\n            prefix=prefix,\n            n_experts=n_experts,\n            
name=down_proj_name,\n            weights=weights,\n        )\n        if SYSTEM == \"ipex\":\n            self.ipex_fused_moe = GatedMLPMOE(\n                W13=self.gate_up_proj, W2=self.down_proj, use_prepack=True\n            )\n\n    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:\n        if SYSTEM == \"rocm\":\n            return moe_kernels.fused_moe(\n                x,\n                self.gate_up_proj,\n                self.down_proj,\n                gating_output,\n                self.topk,\n                renormalize=self.renormalize,\n                inplace=True,\n            )\n        elif SYSTEM == \"ipex\":\n            return self.ipex_fused_moe(\n                hidden_states=x,\n                router_logits=gating_output,\n                top_k=self.topk,\n                renormalize=self.renormalize,\n                use_grouped_topk=self.n_expert_group is not None,\n                num_expert_group=self.n_expert_group,\n                topk_group=self.topk_group,\n                scoring_func=self.scoring_func,\n                e_score_correction_bias=self.e_score_correction_bias,\n            )\n        return fused_moe(\n            x,\n            w1=self.gate_up_proj,\n            w2=self.down_proj,\n            gating_output=gating_output,\n            topk=self.topk,\n            renormalize=self.renormalize,\n            inplace=True,\n            use_grouped_topk=self.n_expert_group is not None,\n            num_expert_group=self.n_expert_group,\n            topk_group=self.topk_group,\n            scoring_func=self.scoring_func,\n            e_score_correction_bias=self.e_score_correction_bias,\n        )\n\n\ndef _load_expert_multi_weights_col(\n    *,\n    prefix: str,\n    n_experts: int,\n    gate_proj_name: str,\n    up_proj_name: str,\n    weights: Weights,\n) -> torch.Tensor:\n    all_weight = None\n    for i in range(n_experts):\n        weight = weights.get_multi_weights_col(\n      
      [f\"{prefix}.{i}.{gate_proj_name}\", f\"{prefix}.{i}.{up_proj_name}\"], 0\n        )\n\n        assert isinstance(weight, UnquantizedWeight)\n\n        if all_weight is None:\n            all_weight = torch.empty(\n                (n_experts,) + weight.weight.shape,\n                dtype=weight.weight.dtype,\n                device=weight.weight.device,\n            )\n\n        all_weight[i] = weight.weight\n\n    assert all_weight is not None\n\n    return all_weight\n\n\ndef _load_expert_weights_row(\n    *,\n    prefix: str,\n    n_experts: int,\n    name: str,\n    weights: Weights,\n) -> torch.Tensor:\n    all_weight = None\n    for i in range(n_experts):\n        weight = weights.get_weights_row(\n            f\"{prefix}.{i}.{name}\",\n        )\n\n        assert isinstance(weight, UnquantizedWeight)\n\n        if all_weight is None:\n            all_weight = torch.empty(\n                (n_experts,) + weight.weight.shape,\n                dtype=weight.weight.dtype,\n                device=weight.weight.device,\n            )\n\n        all_weight[i] = weight.weight\n\n    assert all_weight is not None\n\n    return all_weight\n\n\ndef fused_moe(\n    hidden_states: torch.Tensor,\n    w1: torch.Tensor,\n    w2: torch.Tensor,\n    gating_output: torch.Tensor,\n    topk: int,\n    renormalize: bool,\n    inplace: bool = False,\n    use_grouped_topk: bool = False,\n    num_expert_group: Optional[int] = None,\n    topk_group: Optional[int] = None,\n    custom_routing_function: Optional[Callable] = None,\n    scoring_func: str = \"softmax\",\n    e_score_correction_bias: Optional[torch.Tensor] = None,\n    use_fp8_w8a8: bool = False,\n    use_int8_w8a16: bool = False,\n    use_int4_w4a16: bool = False,\n    w1_scale: Optional[torch.Tensor] = None,\n    w2_scale: Optional[torch.Tensor] = None,\n    a1_scale: Optional[torch.Tensor] = None,\n    a2_scale: Optional[torch.Tensor] = None,\n    block_shape: Optional[List[int]] = None,\n) -> torch.Tensor:\n    
\"\"\"\n    This function computes a Mixture of Experts (MoE) layer using two sets of\n    weights, w1 and w2, and top-k gating mechanism.\n\n    Parameters:\n    - hidden_states (torch.Tensor): The input tensor to the MoE layer.\n    - w1 (torch.Tensor): The first set of expert weights.\n    - w2 (torch.Tensor): The second set of expert weights.\n    - gating_output (torch.Tensor): The output of the gating operation\n        (before softmax).\n    - topk (int): The number of top-k experts to select.\n    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.\n    - inplace (bool): If True, perform the operation in-place.\n        Defaults to False.\n    - num_expert_group: Optional[int]: additional parameter for grouped_topk\n    - topk_group: Optional[int]: additional parameter for grouped_topk\n    - use_grouped_topk: If True, use grouped_topk instead of fused_topk\n        note: Deepseekv2 model uses grouped_topk\n    - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner\n        products for w1 and w2. Defaults to False.\n    - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16\n        activation to compute the inner products for w1 and w2. 
Defaults to False.\n    - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16\n        activation to compute the inner products for w1 and w2.\n        Defaults to False.\n    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for\n        w1.\n    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for\n        w2.\n    - a1_scale (Optional[torch.Tensor]): Optional scale to be used for\n        a1.\n    - a2_scale (Optional[torch.Tensor]): Optional scale to be used for\n        a2.\n    - block_shape: (Optional[List[int]]): Optional block size for block-wise\n        quantization.\n    Returns:\n    - torch.Tensor: The output tensor after applying the MoE layer.\n    \"\"\"\n    # Check constraints.\n    assert gating_output.shape[1] == w1.shape[0], \"Number of experts mismatch\"\n\n    if use_grouped_topk:\n        assert num_expert_group is not None and topk_group is not None\n        from loguru import logger\n        import inspect\n\n        logger.info(f\"{inspect.signature(moe_kernels.grouped_topk)}\")\n        topk_weights, topk_ids = moe_kernels.grouped_topk(\n            hidden_states,\n            gating_output,\n            topk,\n            renormalize,\n            num_expert_group,\n            topk_group,\n            scoring_func=scoring_func,\n            e_score_correction_bias=e_score_correction_bias,\n        )\n    elif custom_routing_function is None:\n        topk_weights, topk_ids = moe_kernels.fused_topk(\n            hidden_states, gating_output, topk, renormalize\n        )\n    else:\n        topk_weights, topk_ids = custom_routing_function(\n            hidden_states, gating_output, topk, renormalize\n        )\n\n    return moe_kernels.fused_experts(\n        hidden_states,\n        w1,\n        w2,\n        topk_weights,\n        topk_ids,\n        inplace=inplace,\n        use_fp8_w8a8=use_fp8_w8a8,\n        use_int8_w8a16=use_int8_w8a16,\n        use_int4_w4a16=use_int4_w4a16,\n        
w1_scale=w1_scale,\n        w2_scale=w2_scale,\n        a1_scale=a1_scale,\n        a2_scale=a2_scale,\n        block_shape=block_shape,\n    )\n"
  },
  {
    "path": "server/text_generation_server/layers/rotary.py",
    "content": "import os\nimport math\nimport torch\nfrom torch import nn\nfrom text_generation_server.utils.import_utils import SYSTEM\n\nif SYSTEM == \"cuda\":\n    from text_generation_server.utils.kernels import load_kernel\n\n    rotary = load_kernel(module=\"rotary\", repo_id=\"kernels-community/rotary\")\nelif SYSTEM == \"rocm\":\n    import vllm._custom_ops as ops\nelif SYSTEM == \"ipex\":\n    import intel_extension_for_pytorch as ipex\n\n\ndef _create_inv_freq(dim, base, device):\n    inv_freq = 1.0 / (\n        base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)\n    )\n    return inv_freq\n\n\ndef _get_rope_config(config):\n    if os.getenv(\"ROPE_SCALING\", None) is not None:\n        rope_scaling = {\n            \"type\": os.environ[\"ROPE_SCALING\"],\n            \"factor\": float(os.environ[\"ROPE_FACTOR\"]),\n        }\n        return rope_scaling\n    return getattr(config, \"rope_scaling\", None)\n\n\nclass PositionRotaryEmbedding(nn.Module):\n    def __init__(self, inv_freq, scaling_factor):\n        super().__init__()\n        self.inv_freq = inv_freq\n        self._seq_len_cached = 0\n        self._cos_cached = None\n        self._sin_cached = None\n        self._cos_k_cached = None\n        self._sin_k_cached = None\n        self.scaling_factor = scaling_factor\n        self.dynamic_args = None\n\n    def forward(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        cos: torch.Tensor,\n        sin: torch.Tensor,\n    ):\n        # Such controlflows may add some overhead.\n        if SYSTEM == \"cuda\":\n            rotary_dim = cos.shape[-1]\n            q1 = query[..., :rotary_dim]\n            q2 = query[..., rotary_dim : 2 * rotary_dim]\n\n            rotary.apply_rotary(q1, q2, cos, sin, q1, q2, False)\n\n            k1 = key[..., :rotary_dim]\n            k2 = key[..., rotary_dim : 2 * rotary_dim]\n\n            rotary.apply_rotary(k1, k2, cos, sin, k1, k2, False)\n        elif 
SYSTEM == \"rocm\":\n            # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.\n            # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773\n\n            head_size = query.shape[-1]\n\n            # Inplace operation, updating query and key.\n            ops.rotary_embedding(query, key, head_size, cos, sin, True)\n        elif SYSTEM == \"ipex\":\n            ipex.llm.functional.rotary_embedding(\n                query, key, sin, cos, query.size(-1), True\n            )\n        else:\n            raise ValueError(\n                \"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.\"\n            )\n\n    @classmethod\n    def static(cls, config, dim, base, device):\n        inv_freq = _create_inv_freq(dim, base, device)\n        scaling_factor = None\n        rope_scaling = _get_rope_config(config)\n        if rope_scaling is not None:\n            # `rope_type` is now standard in transformers, but some existing models\n            # have `type` instead.\n            rope_type = rope_scaling.get(\"rope_type\", rope_scaling.get(\"type\", None))\n\n            if rope_type == \"linear\":\n                pass\n            elif rope_type == \"default\":\n                pass\n            elif rope_type == \"mrope\":\n                mrope_section = rope_scaling[\"mrope_section\"]\n                if mrope_section is not None:\n                    return RotaryPositionEmbeddingMultimodalSections(\n                        inv_freq, scaling_factor, mrope_section\n                    )\n            elif rope_type == \"dynamic\":\n                scaling_factor = 
rope_scaling[\"factor\"]\n                return DynamicPositionRotaryEmbedding(\n                    dim=dim,\n                    max_position_embeddings=config.max_position_embeddings,\n                    base=base,\n                    device=inv_freq.device,\n                    scaling_factor=scaling_factor,\n                )\n            elif rope_type == \"llama3\":\n                inv_freq = apply_llama3_scaling(\n                    inv_freq,\n                    scaling_factor=rope_scaling[\"factor\"],\n                    low_freq_factor=rope_scaling[\"low_freq_factor\"],\n                    high_freq_factor=rope_scaling[\"high_freq_factor\"],\n                    original_max_position_embeddings=rope_scaling[\n                        \"original_max_position_embeddings\"\n                    ],\n                )\n\n                return cls(inv_freq, scaling_factor)\n\n            elif rope_type == \"yarn\":\n                scaling_factor = rope_scaling[\"factor\"]\n                mscale = rope_scaling.get(\"mscale\", 1.0)\n                mscale_all_dim = rope_scaling.get(\"mscale_all_dim\", 0.0)\n                return YarnPositionRotaryEmbedding(\n                    dim=2 * inv_freq.shape[0],\n                    max_position_embeddings=rope_scaling[\n                        \"original_max_position_embeddings\"\n                    ],\n                    base=base,\n                    device=inv_freq.device,\n                    scaling_factor=scaling_factor,\n                    extrapolation_factor=1,\n                    attn_factor=1,\n                    beta_fast=32,\n                    beta_slow=1,\n                    mscale=mscale,\n                    mscale_all_dim=mscale_all_dim,\n                )\n            elif rope_type in [\"su\", \"longrope\"]:\n                short_factor = torch.tensor(\n                    rope_scaling[\"short_factor\"], dtype=torch.float32, device=device\n                )\n                
short_inv_freq = 1.0 / (\n                    short_factor\n                    * base\n                    ** (\n                        torch.arange(0, dim, 2, device=device, dtype=torch.float32)\n                        / dim\n                    )\n                )\n                long_factor = torch.tensor(\n                    rope_scaling[\"long_factor\"], dtype=torch.float32, device=device\n                )\n                long_inv_freq = 1.0 / (\n                    long_factor\n                    * base\n                    ** (\n                        torch.arange(0, dim, 2, device=device, dtype=torch.float32)\n                        / dim\n                    )\n                )\n\n                original_max_position_embeddings = (\n                    config.original_max_position_embeddings\n                )\n                max_position_embeddings = config.max_position_embeddings\n                if max_position_embeddings <= original_max_position_embeddings:\n                    scaling_factor = 1.0\n                else:\n                    scale = max_position_embeddings / original_max_position_embeddings\n                    scaling_factor = math.sqrt(\n                        1 + math.log(scale) / math.log(original_max_position_embeddings)\n                    )\n\n                # if short_mscale and long_mscale are provided we need to scale the freqs\n                # using the Phi3LongRoPEScaledRotaryEmbedding\n                if (\"short_mscale\" in rope_scaling) and (\"long_mscale\" in rope_scaling):\n                    short_mscale = rope_scaling[\"short_mscale\"]\n                    long_mscale = rope_scaling[\"long_mscale\"]\n                    return Phi3LongRoPEScaledRotaryEmbedding(\n                        short_inv_freq=short_inv_freq,\n                        long_inv_freq=long_inv_freq,\n                        max_position_embeddings=config.max_position_embeddings,\n                        
short_mscale=short_mscale,\n                        long_mscale=long_mscale,\n                        original_max_position_embeddings=original_max_position_embeddings,\n                    )\n\n                return SuRotaryEmbedding(\n                    short_inv_freq=short_inv_freq,\n                    long_inv_freq=long_inv_freq,\n                    scaling_factor=scaling_factor,\n                    original_max_position_embeddings=original_max_position_embeddings,\n                )\n            else:\n                raise NotImplementedError(\n                    f\"rope scaling type {rope_scaling['type']} is not implemented or invalid\"\n                )\n        return cls(inv_freq, scaling_factor)\n\n    @classmethod\n    def load(cls, config, prefix, weights):\n        # XXX: Always load this in float32 !\n        dtype = weights.dtype\n        weights.dtype = torch.float32\n        inv_freq = weights.get_tensor(f\"{prefix}.inv_freq\")\n        weights.dtype = dtype\n\n        scaling_factor = None\n        rope_scaling = _get_rope_config(config)\n        if rope_scaling is not None:\n            scaling_factor = rope_scaling[\"factor\"]\n            if rope_scaling[\"type\"] == \"linear\":\n                pass\n            elif rope_scaling[\"type\"] == \"dynamic\":\n                return DynamicPositionRotaryEmbedding(\n                    dim=2 * inv_freq.shape[0],\n                    max_position_embeddings=config.max_position_embeddings,\n                    base=10000.0,\n                    device=inv_freq.device,\n                    scaling_factor=scaling_factor,\n                )\n            elif rope_scaling[\"type\"] == \"yarn\":\n                mscale = rope_scaling.get(\"mscale\", 1.0)\n                mscale_all_dim = rope_scaling.get(\"mscale_all_dim\", 0.0)\n                return YarnPositionRotaryEmbedding(\n                    dim=2 * inv_freq.shape[0],\n                    max_position_embeddings=rope_scaling[\n           
             \"original_max_position_embeddings\"\n                    ],\n                    base=10000.0,\n                    device=inv_freq.device,\n                    scaling_factor=scaling_factor,\n                    extrapolation_factor=1,\n                    attn_factor=1,\n                    beta_fast=32,\n                    beta_slow=1,\n                    mscale=mscale,\n                    mscale_all_dim=mscale_all_dim,\n                )\n            else:\n                raise NotImplementedError(\n                    f\"rope scaling type {rope_scaling['type']} is not implemented or invalid\"\n                )\n        return cls(inv_freq, scaling_factor)\n\n    def _update_cos_sin_cache(self, dtype, device, seqlen):\n        # Reset the tables if the sequence length has changed,\n        # or if we're on a new device (possibly due to tracing for instance)\n        if (\n            seqlen > self._seq_len_cached\n            or self._cos_cached.device != device\n            or self._cos_cached.dtype != dtype\n        ):\n            self._seq_len_cached = seqlen\n            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)\n            if self.scaling_factor is not None:\n                t /= self.scaling_factor\n            # Don't do einsum, it converts fp32 to fp16\n            # freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n\n            freqs = torch.outer(t, self.inv_freq.to(device=t.device))\n            self._cos_cached = torch.cos(freqs).to(dtype)\n            self._sin_cached = torch.sin(freqs).to(dtype)\n\n    def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype):\n        \"\"\"\n        Return cos and sin for the asked position ids\n        \"\"\"\n        if SYSTEM == \"rocm\":\n            # For RoCm, we always use float cos/sin to avoid a cast.\n            # For NVIDIA, for some reason, the flash-attn rotary kernel requires cos/sin and query/key to be of same dtype: 
https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary.cpp#L26\n            # But later on goes and cast cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal.\n            dtype = torch.float32\n\n        self._update_cos_sin_cache(dtype, position_ids.device, max_s)\n\n        cos = torch.index_select(self._cos_cached, 0, position_ids)\n        sin = torch.index_select(self._sin_cached, 0, position_ids)\n\n        # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.\n        return cos.unsqueeze(1), sin.unsqueeze(1)\n\n\nclass SuRotaryEmbedding(PositionRotaryEmbedding):\n    def __init__(\n        self,\n        short_inv_freq,\n        long_inv_freq,\n        scaling_factor,\n        original_max_position_embeddings,\n    ):\n        super(PositionRotaryEmbedding, self).__init__()\n        self.short_inv_freq = short_inv_freq\n        self.long_inv_freq = long_inv_freq\n        self.scaling_factor = scaling_factor\n        self.original_max_position_embeddings = original_max_position_embeddings\n        self._seq_len_cached = 0\n        self._cos_cached = None\n        self._sin_cached = None\n        self._cos_k_cached = None\n        self._sin_k_cached = None\n        self.dynamic_args = None\n\n    def _update_cos_sin_cache(self, dtype, device, seqlen):\n        # Reset the tables if the sequence length has changed,\n        # or if we're on a new device (possibly due to tracing for instance)\n        if (\n            seqlen > self._seq_len_cached\n            or self._cos_cached is None\n            or self._cos_cached.device != device\n            or self._cos_cached.dtype != dtype\n        ):\n            self._seq_len_cached = seqlen\n\n            t = torch.arange(seqlen, device=device, 
dtype=self.short_inv_freq.dtype)\n            short_freqs = torch.outer(\n                t[: self.original_max_position_embeddings],\n                self.short_inv_freq.to(device=t.device),\n            )\n            long_freqs = torch.outer(\n                t[self.original_max_position_embeddings :],\n                self.long_inv_freq.to(device=t.device),\n            )\n\n            freqs = torch.cat([short_freqs, long_freqs])\n\n            self._cos_cached = (torch.cos(freqs) * self.scaling_factor).to(dtype)\n            self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype)\n\n\nclass Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):\n    def __init__(\n        self,\n        short_inv_freq: torch.Tensor,\n        long_inv_freq: torch.Tensor,\n        max_position_embeddings: int,\n        short_mscale: float,\n        long_mscale: float,\n        original_max_position_embeddings: int,\n    ):\n        super(PositionRotaryEmbedding, self).__init__()\n        self.short_inv_freq = short_inv_freq\n        self.long_inv_freq = long_inv_freq\n        self.max_position_embeddings = max_position_embeddings\n        self.short_mscale = short_mscale\n        self.long_mscale = long_mscale\n        self.original_max_position_embeddings = original_max_position_embeddings\n\n        # cache\n        self._seq_len_cached = 0\n        self._cos_cached = None\n        self._sin_cached = None\n        self._cos_k_cached = None\n        self._sin_k_cached = None\n        self.dynamic_args = None\n\n    def _update_cos_sin_cache(self, dtype, device, seqlen):\n        if (\n            seqlen > self._seq_len_cached\n            or self._cos_cached is None\n            or self._cos_cached.device != device\n            or self._cos_cached.dtype != dtype\n        ):\n            self._seq_len_cached = seqlen\n            t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)\n\n            short_freqs = torch.outer(\n                
t[: self.original_max_position_embeddings],\n                self.short_inv_freq.to(device=t.device),\n            )\n\n            long_freqs = torch.outer(\n                t[self.original_max_position_embeddings :],\n                self.long_inv_freq.to(device=t.device),\n            )\n\n            short_freqs = short_freqs * self.short_mscale\n            long_freqs = long_freqs * self.long_mscale\n\n            freqs = torch.empty((seqlen, short_freqs.shape[1]), device=device)\n            freqs[: self.original_max_position_embeddings] = short_freqs\n            freqs[self.original_max_position_embeddings :] = long_freqs\n\n            self._cos_cached = torch.cos(freqs).to(dtype)\n            self._sin_cached = torch.sin(freqs).to(dtype)\n\n\nclass DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):\n    def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):\n        inv_freq = _create_inv_freq(dim, base, device)\n        super().__init__(inv_freq, scaling_factor)\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n\n    def _update_cos_sin_cache(self, dtype, device, seqlen):\n        # Reset the tables if the sequence length has changed,\n        # or if we're on a new device (possibly due to tracing for instance)\n        if (\n            seqlen > self._seq_len_cached\n            or self._cos_cached.device != device\n            or self._cos_cached.dtype != dtype\n        ):\n            if seqlen > self.max_position_embeddings:\n                newbase = self.base * (\n                    (self.scaling_factor * seqlen / self.max_position_embeddings)\n                    - (self.scaling_factor - 1)\n                ) ** (self.dim / (self.dim - 2))\n                self.inv_freq = _create_inv_freq(\n                    self.dim, newbase, self.inv_freq.device\n                )\n            self._seq_len_cached = seqlen\n            t = torch.arange(seqlen, 
device=device, dtype=self.inv_freq.dtype)\n            # Don't do einsum, it converts fp32 to fp16\n            # freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n\n            freqs = torch.outer(t, self.inv_freq.to(device=t.device))\n            self._cos_cached = torch.cos(freqs).to(dtype)\n            self._sin_cached = torch.sin(freqs).to(dtype)\n\n\ndef find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):\n    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (\n        2 * math.log(base)\n    )\n\n\n# Find dim range bounds based on rotations\ndef find_correction_range(\n    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048\n):\n    low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))\n    high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))\n    return max(low, 0), min(high, dim - 1)  # Clamp values just in case\n\n\ndef linear_ramp_mask(min, max, dim):\n    if min == max:\n        max += 0.001  # Prevent singularity\n\n    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)\n    ramp_func = torch.clamp(linear_func, 0, 1)\n    return ramp_func\n\n\ndef get_mscale(scale: float = 1.0, mscale: float = 1.0):\n    if scale <= 1:\n        return 1.0\n    return 0.1 * mscale * math.log(scale) + 1.0\n\n\nclass YarnPositionRotaryEmbedding(PositionRotaryEmbedding):\n    def __init__(\n        self,\n        dim,\n        max_position_embeddings,\n        base,\n        device,\n        scaling_factor,\n        *,\n        extrapolation_factor,\n        attn_factor,\n        beta_fast,\n        beta_slow,\n        mscale: float,\n        mscale_all_dim: float,\n    ):\n        inv_freq = _create_inv_freq(dim, base, device)\n        super().__init__(inv_freq, scaling_factor)\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n        
self.extrapolation_factor = extrapolation_factor\n        self.attn_factor = attn_factor\n        self.beta_fast = beta_fast\n        self.beta_slow = beta_slow\n        self.mscale_all_dim = mscale_all_dim\n        self.scaling_factor = scaling_factor\n        self.mscale = float(\n            get_mscale(self.scaling_factor, mscale)\n            / get_mscale(self.scaling_factor, mscale_all_dim)\n            * self.attn_factor\n        )  # Get n-d magnitude scaling corrected for interpolation\n\n    def _update_cos_sin_cache(self, dtype, device, seqlen):\n        # Reset the tables if the sequence length has changed,\n        # or if we're on a new device (possibly due to tracing for instance)\n        if (\n            seqlen > self._seq_len_cached\n            or self._cos_cached.device != device\n            or self._cos_cached.dtype != dtype\n        ):\n            if seqlen > self.max_position_embeddings or True:\n                inv_freq_extrapolation = _create_inv_freq(\n                    self.dim, self.base, self.inv_freq.device\n                )\n                freqs = 1.0 / inv_freq_extrapolation\n                inv_freq_interpolation = 1.0 / (self.scaling_factor * freqs)\n                low, high = find_correction_range(\n                    self.beta_fast,\n                    self.beta_slow,\n                    self.dim,\n                    self.base,\n                    self.max_position_embeddings,\n                )\n\n                inv_freq_mask = (\n                    1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)\n                ) * self.extrapolation_factor  # Get n-d rotational scaling corrected for extrapolation\n                inv_freq = (\n                    inv_freq_interpolation * (1 - inv_freq_mask)\n                    + inv_freq_extrapolation * inv_freq_mask\n                )\n\n                self.inv_freq = inv_freq\n\n            self._seq_len_cached = seqlen\n            t = torch.arange(seqlen, 
device=device, dtype=self.inv_freq.dtype)\n            # Don't do einsum, it converts fp32 to fp16\n            # freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n\n            freqs = torch.outer(t, self.inv_freq.to(device=t.device))\n            self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype)\n            self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)\n\n\ndef apply_llama3_scaling(\n    freqs: torch.Tensor,\n    *,\n    scaling_factor: int,\n    low_freq_factor: int,\n    high_freq_factor: int,\n    original_max_position_embeddings: int,\n):\n    low_freq_wavelen = original_max_position_embeddings / low_freq_factor\n    high_freq_wavelen = original_max_position_embeddings / high_freq_factor\n    new_freqs = []\n\n    for freq in freqs:\n        wavelen = 2 * math.pi / freq\n\n        if wavelen < high_freq_wavelen:\n            new_freqs.append(freq)\n        elif wavelen > low_freq_wavelen:\n            new_freqs.append(freq / scaling_factor)\n        else:\n            assert low_freq_wavelen != high_freq_wavelen\n            smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (\n                high_freq_factor - low_freq_factor\n            )\n            new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq)\n\n    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)\n\n\nclass RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):\n    def __init__(self, inv_freq: torch.Tensor, scaling_factor: float, sections: list):\n        super().__init__(inv_freq, scaling_factor)\n        self.sections = sections\n        self._cos_cached = None\n        self._sin_cached = None\n        self.section_indices = (\n            torch.arange(len(self.sections))\n            .repeat_interleave(torch.tensor(self.sections))\n            .view(1, 1, -1)\n            .to(inv_freq.device)\n        )\n\n    def _update_cos_sin_cache(\n        self, dtype: torch.dtype, device: 
torch.device, seqlen: int\n    ):\n        # always cache the cos/sin for the full sequence length to avoid\n        # recomputing if the sequence length is smaller than the cached one\n        if (\n            seqlen > self._seq_len_cached\n            or self._cos_cached.device != device\n            or self._cos_cached.dtype != dtype\n        ):\n            self._seq_len_cached = seqlen\n            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)\n            freqs = torch.outer(t, self.inv_freq.to(device=t.device))\n            self._cos_cached = torch.cos(freqs).to(dtype)\n            self._sin_cached = torch.sin(freqs).to(dtype)\n            self._sections = self.section_indices.expand(seqlen, -1, -1)\n\n    def get_cos_sin(\n        self,\n        position_ids: torch.Tensor,\n        max_s: int,\n        dtype: torch.dtype,\n    ):\n        self._update_cos_sin_cache(dtype, position_ids.device, max_s)\n        slen = position_ids.shape[0]\n\n        cos = self._cos_cached[position_ids].gather(1, self._sections[:slen])\n        sin = self._sin_cached[position_ids].gather(1, self._sections[:slen])\n        return cos, sin\n"
  },
  {
    "path": "server/text_generation_server/layers/speculative.py",
    "content": "import torch\nimport json\nfrom typing import Tuple, Optional\nfrom text_generation_server.layers.tensor_parallel import TensorParallelHead\nfrom text_generation_server.layers.medusa import MedusaHeadV1, MedusaHeadV2\nfrom text_generation_server.layers.mlp import MLPSpeculatorHead\n\n\nclass SpeculativeHead(torch.nn.Module):\n    def __init__(self, lm_head, speculator):\n        super().__init__()\n        self.head = lm_head\n        self.speculator = speculator\n\n    @staticmethod\n    def load(config, prefix: str, weights):\n        speculator = config.speculator\n        if speculator:\n            speculator_path = config.speculator[\"path\"]\n            speculator_config = str(speculator_path / \"config.json\")\n\n            with open(speculator_config, \"r\") as f:\n                speculator_config = json.load(f)\n\n            config.speculator_config = speculator_config\n            try:\n                architecture = speculator_config[\"architectures\"][0]\n\n                if architecture == \"MLPSpeculatorPreTrainedModel\":\n                    speculator = MLPSpeculatorHead.load(config, prefix, weights)\n                else:\n                    speculator = None\n            except KeyError:\n                try:\n                    speculator = MedusaHeadV1.load(config, prefix, weights)\n                except Exception:\n                    speculator = MedusaHeadV2(config, prefix, weights)\n            lm_head = None\n        else:\n            lm_head = TensorParallelHead.load(config, prefix, weights)\n            speculator = None\n        return SpeculativeHead(lm_head, speculator)\n\n    def forward(\n        self, input: torch.Tensor\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        if self.speculator is not None:\n            return self.speculator(input)\n\n        assert self.head is not None\n        logits = self.head(input)\n        return logits, None\n"
  },
  {
    "path": "server/text_generation_server/layers/tensor_parallel.py",
    "content": "import torch\nfrom torch.nn import functional as F\nfrom typing import Iterable, List\nfrom text_generation_server.layers.linear import get_linear, FastLinear\nfrom text_generation_server.utils.import_utils import SYSTEM\n\nif SYSTEM == \"ipex\":\n    import intel_extension_for_pytorch as ipex\n\n\nclass LayerConcat(torch.nn.Module):\n    \"\"\"\n    Apply multiple layers to the input and concatenate their\n    outputs.\n    \"\"\"\n\n    def __init__(self, layers: Iterable[torch.nn.Module], dim: int = -1):\n        \"\"\"\n        `dim` is the dimension along which layer outputs are concatenated.\n        \"\"\"\n        super().__init__()\n        self.layers = layers\n        self.dim = dim\n\n    def forward(self, x: torch.Tensor):\n        outputs = [layer(x) for layer in self.layers]\n        return torch.cat(outputs, self.dim)\n\n\nclass SuperLayer(torch.nn.Module):\n    def __init__(self, linear):\n        super().__init__()\n        self.linear = linear\n\n    def forward(self, x):\n        return self.linear.forward(x)\n\n\nclass TensorParallelHead(SuperLayer):\n    def __init__(self, linear, process_group, should_gather: bool):\n        super().__init__(linear)\n        self.process_group = process_group\n        self.should_gather = should_gather\n\n    @staticmethod\n    def load(config, prefix: str, weights):\n        if config.quantize == \"exl2\":\n            try:\n                # If the piece and LM head embeddings are shared, we have\n                # non-quantized weights...\n                weight = weights.get_tensor(f\"{prefix}.weight\")\n            except Exception:\n                # ...otherwise they are quantized.\n                weight = weights.get_weights_col(prefix)\n            should_gather = weights.process_group.size() > 1\n        elif weights.process_group.size() > 1:\n            try:\n                weight = weights.get_sharded(f\"{prefix}.weight\", dim=0)\n                should_gather = True\n           
 except AssertionError:\n                # If the vocab size is not divisible by number of shards\n                # just load the entire thing.\n                weight = weights.get_tensor(f\"{prefix}.weight\")\n                should_gather = False\n        else:\n            weight = weights.get_tensor(f\"{prefix}.weight\")\n            should_gather = False\n\n        return TensorParallelHead(\n            get_linear(weight, bias=None),\n            process_group=weights.process_group,\n            should_gather=should_gather,\n        )\n\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        if not self.should_gather:\n            return super().forward(input)\n\n        world_size = self.process_group.size()\n        if len(input.shape) == 2 and isinstance(self.linear, FastLinear):\n            out_dim = self.linear.weight.shape[0]\n\n            if input.shape[0] == 1:\n                world_out = input.new_empty(1, out_dim * world_size)\n                local_out = input.new_empty(1, out_dim)\n                gather_input = local_out\n            else:\n                world_out = input.new_empty(out_dim * world_size, input.shape[0])\n                gather_input = input.new_empty(out_dim, input.shape[0])\n                local_out = gather_input.T\n\n            torch.mm(input, self.linear.weight.T, out=local_out)\n            if SYSTEM == \"ipex\" and gather_input.device.type == \"cpu\":\n                ipex.distributed.all_gather_into_tensor(\n                    world_out, gather_input, group=self.process_group\n                )\n            else:\n                torch.distributed.all_gather_into_tensor(\n                    world_out, gather_input, group=self.process_group\n                )\n\n            if input.shape[0] == 1:\n                return world_out\n            return world_out.T\n\n        output = super().forward(input)\n        world_output = [\n            torch.empty_like(output) for _ in 
range(self.process_group.size())\n        ]\n        if SYSTEM == \"ipex\" and output.device.type == \"cpu\":\n            ipex.distributed.all_gather(world_output, output, group=self.process_group)\n        else:\n            torch.distributed.all_gather(world_output, output, group=self.process_group)\n        world_output = torch.cat(world_output, dim=-1)\n        return world_output\n\n\nclass TensorParallelColumnLinear(SuperLayer):\n    @classmethod\n    def load_gate_up(cls, config, prefix: str, weights, bias: bool):\n        \"\"\"Specific method when the QKV was joined after the fact\"\"\"\n        weight = weights.get_weights_col_packed_gate_up(prefix)\n        if bias:\n            raise NotImplementedError(\"packed_gate_up only implemented without bias\")\n        else:\n            bias = None\n        linear = get_linear(weight, bias)\n        return cls(linear)\n\n    @classmethod\n    def load_qkv(\n        cls,\n        config,\n        prefix: str,\n        weights,\n        bias: bool,\n        num_heads: int,\n        num_key_value_heads: int,\n    ):\n        \"\"\"Specific method when the QKV was joined after the fact\"\"\"\n        weight = weights.get_weights_col_packed_qkv(\n            prefix,\n            num_heads=num_heads,\n            num_key_value_heads=num_key_value_heads,\n        )\n        if bias:\n            raise NotImplementedError(\"packed_qkv only implemented for baichuan\")\n        else:\n            bias = None\n        linear = get_linear(weight, bias)\n        return cls(linear)\n\n    @classmethod\n    def load(cls, config, prefix: str, weights, bias: bool):\n        weight = weights.get_weights_col(prefix)\n        if bias:\n            bias = weights.get_sharded(f\"{prefix}.bias\", dim=0)\n        else:\n            bias = None\n        linear = get_linear(weight, bias)\n        return cls(linear)\n\n    @classmethod\n    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):\n        if 
config.quantize == \"exl2\":\n            linears = []\n            for prefix in prefixes:\n                weight = weights.get_weights_col(prefix)\n                b = weights.get_tensor(f\"{prefix}.bias\") if bias else None\n                linears.append(get_linear(weight, b))\n            linear = LayerConcat(linears)\n        else:\n            weight = weights.get_multi_weights_col(prefixes, dim=dim)\n            if bias:\n                b = [weights.get_sharded(f\"{p}.bias\", dim=0) for p in prefixes]\n                bias = torch.cat(b, dim=dim)\n            else:\n                bias = None\n            linear = get_linear(weight, bias)\n        return cls(linear)\n\n\nclass TensorParallelRowLinear(SuperLayer):\n    def __init__(self, linear, process_group):\n        super().__init__(linear)\n        self.process_group = process_group\n\n    @classmethod\n    def load(cls, config, prefix: str, weights, bias: bool):\n        weight = weights.get_weights_row(prefix)\n\n        if bias and weights.process_group.rank() == 0:\n            # Rank is only on the first rank process\n            bias = weights.get_tensor(f\"{prefix}.bias\")\n        else:\n            bias = None\n        return cls(\n            get_linear(weight, bias),\n            process_group=weights.process_group,\n        )\n\n    def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:\n        out = super().forward(input)\n        if self.process_group.size() > 1 and reduce:\n            if SYSTEM == \"ipex\" and out.device.type == \"cpu\":\n                ipex.distributed.all_reduce(out, group=self.process_group)\n            else:\n                torch.distributed.all_reduce(out, group=self.process_group)\n        return out\n\n\nclass TensorParallelEmbedding(torch.nn.Module):\n    def __init__(self, prefix: str, weights, reduce=True):\n        super().__init__()\n        weight = weights.get_partial_sharded(f\"{prefix}.weight\", dim=0)\n        num_embeddings 
= weights.get_shape(f\"{prefix}.weight\")[0]\n\n        process_group = weights.process_group\n\n        world_size = process_group.size()\n        rank = process_group.rank()\n\n        block_size = (num_embeddings + world_size - 1) // world_size\n        self.min_id = rank * block_size\n        self.max_id = min(num_embeddings, (rank + 1) * block_size)\n        self.null_idx = weight.shape[\n            0\n        ]  # Usually block_size, might be less in non even vocab_size.\n        self.process_group = weights.process_group\n        self.reduce = reduce\n\n        \"\"\"Additional 0 entry used for masking\"\"\"\n        self.weight = torch.nn.Parameter(F.pad(weight, (0, 0, 0, 1)))\n\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        # default all out of bounds values to `self.null_idx` that will then be mapped to 0\n        # translate for [0, self.max_id - self.min_id[\n        input = torch.where(\n            (self.min_id > input) | (input >= self.max_id),\n            self.null_idx,\n            input - self.min_id,\n        )\n        out = torch.nn.functional.embedding(input, self.weight)\n        if self.reduce and self.process_group.size() > 1:\n            if SYSTEM == \"ipex\" and out.device.type == \"cpu\":\n                ipex.distributed.all_reduce(out, group=self.process_group)\n            else:\n                torch.distributed.all_reduce(out, group=self.process_group)\n        return out\n"
  },
  {
    "path": "server/text_generation_server/models/__init__.py",
    "content": "# ruff: noqa: F821\n# the above line disables the `undefined-name` rule for the model type variables\n\nfrom compressed_tensors.compressors.model_compressors.model_compressor import (\n    QuantizationConfig,\n)\nfrom compressed_tensors.quantization import QuantizationType\nfrom pydantic import ValidationError\nimport enum\nimport os\nfrom typing import Optional, List, Dict\nfrom pathlib import Path\nfrom loguru import logger\n\nimport torch\nimport transformers\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.models.auto import modeling_auto\nfrom transformers.dynamic_module_utils import get_class_from_dynamic_module\nfrom huggingface_hub import hf_hub_download, HfApi\n\nfrom text_generation_server.utils.speculate import get_speculate, set_speculate\nfrom text_generation_server.models.model import Model\nfrom text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast\n\nfrom text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM\nfrom text_generation_server.models.custom_modeling.mpt_modeling import (\n    MPTForCausalLM,\n)\nfrom text_generation_server.models.bloom import BloomCausalLMBatch\nfrom text_generation_server.models.custom_modeling.bloom_modeling import (\n    BloomForCausalLM,\n)\nfrom text_generation_server.models.globals import ATTENTION\nfrom text_generation_server.models.seq2seq_lm import Seq2SeqLM\nfrom text_generation_server.models.galactica import GalacticaCausalLMBatch\nfrom text_generation_server.models.custom_modeling.neox_modeling import (\n    GPTNeoxForCausalLM,\n)\nfrom text_generation_server.models.custom_modeling.phi_modeling import (\n    PhiConfig,\n    PhiForCausalLM,\n)\nfrom text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (\n    PhiMoEConfig,\n)\nfrom text_generation_server.models.custom_modeling.t5_modeling import (\n    T5ForConditionalGeneration,\n)\n\n\nfrom text_generation_server.utils.adapter import (\n   
 AdapterParameters,\n    build_layer_weight_lookup,\n    load_and_merge_adapters,\n    AdapterInfo,\n)\nfrom text_generation_server.adapters.lora import LoraWeights\n\n\nfrom text_generation_server.utils.import_utils import SYSTEM\nfrom text_generation_server.utils.log import log_master\n\n# The flag below controls whether to allow TF32 on matmul. This flag defaults to False\n# in PyTorch 1.12 and later.\ntorch.backends.cuda.matmul.allow_tf32 = True\n\n# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.\ntorch.backends.cudnn.allow_tf32 = True\n\n# Disable gradients\ntorch.set_grad_enabled(False)\n\n__all__ = [\n    \"Model\",\n    \"CausalLM\",\n    \"Seq2SeqLM\",\n    \"get_model_with_lora_adapters\",\n]\n\nFLASH_ATT_ERROR_MESSAGE = \"{} requires Flash Attention enabled models.\"\n\nFLASH_ATTENTION = True\n\ntry:\n    from text_generation_server.models.flash_causal_lm import FlashCausalLM\n    from text_generation_server.models.vlm_causal_lm import VlmCausalLM\n    from text_generation_server.models.mllama_causal_lm import MllamaCausalLM\n    from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (\n        FlashDeepseekV2ForCausalLM,\n        DeepseekV2Config,\n    )\n    from text_generation_server.models.custom_modeling.flash_deepseek_v3_modeling import (\n        FlashDeepseekV3ForCausalLM,\n        DeepseekV3Config,\n    )\n    from text_generation_server.models.custom_modeling.flash_llama_modeling import (\n        FlashLlamaForCausalLM,\n    )\n    from text_generation_server.models.custom_modeling.flash_cohere_modeling import (\n        FlashCohereForCausalLM,\n    )\n    from text_generation_server.models.custom_modeling.flash_gemma_modeling import (\n        FlashGemmaForCausalLM,\n    )\n    from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (\n        FlashGemma2ForCausalLM,\n    )\n    from text_generation_server.models.custom_modeling.flash_gemma3_modeling 
import (\n        FlashGemma3ForCausalLM,\n        Gemma3ForConditionalGeneration,\n    )\n    from text_generation_server.models.custom_modeling.gemma3.processing_gemma3 import (\n        Gemma3Processor,\n    )\n    from text_generation_server.models.custom_modeling.gemma3.configuration_gemma3 import (\n        Gemma3Config,\n        Gemma3TextConfig,\n    )\n    from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (\n        FlashDbrxForCausalLM,\n        DbrxConfig,\n    )\n    from text_generation_server.models.custom_modeling.flash_rw_modeling import (\n        RWConfig,\n        FlashRWForCausalLM,\n    )\n    from text_generation_server.models.custom_modeling.flash_neox_modeling import (\n        FlashGPTNeoXForCausalLM,\n    )\n    from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (\n        PaliGemmaForConditionalGeneration,\n    )\n    from text_generation_server.models.custom_modeling.flash_phi_modeling import (\n        FlashPhiForCausalLM,\n    )\n    from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM\n    from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch\n    from text_generation_server.models.custom_modeling.mllama import (\n        MllamaForConditionalGeneration,\n    )\n    from text_generation_server.models.custom_modeling.llava_next import (\n        LlavaNextForConditionalGeneration,\n    )\n\n    from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (\n        FlashSantacoderForCausalLM,\n    )\n    from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (\n        FlashStarcoder2ForCausalLM,\n    )\n    from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (\n        Qwen2ForCausalLM,\n    )\n    from text_generation_server.models.custom_modeling.flash_mistral_modeling import (\n        FlashMistralForCausalLM,\n    )\n    from 
text_generation_server.models.custom_modeling.flash_mixtral_modeling import (\n        FlashMixtralForCausalLM,\n    )\n    from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (\n        FlashGPT2ForCausalLM,\n    )\n    from text_generation_server.models.custom_modeling.flash_gptj_modeling import (\n        FlashGPTJForCausalLM,\n    )\n    from text_generation_server.models.custom_modeling.idefics2 import (\n        Idefics2ForConditionalGeneration,\n    )\n    from text_generation_server.models.custom_modeling.idefics3 import (\n        Idefics3ForConditionalGeneration,\n    )\n    from text_generation_server.models.custom_modeling.qwen2_vl import (\n        Qwen2VLForConditionalGeneration,\n    )\n    from text_generation_server.models.custom_modeling.qwen2_5_vl import (\n        Qwen2_5VLForConditionalGeneration,\n        Qwen2_5_VLConfig,\n        Qwen2_5_VLProcessor,\n    )\n    from text_generation_server.layers.attention import SUPPORTS_WINDOWING\nexcept ImportError as e:\n    log_master(logger.warning, f\"Could not import Flash Attention enabled models: {e}\")\n    SUPPORTS_WINDOWING = False\n    FLASH_ATTENTION = False\n\nif FLASH_ATTENTION:\n    __all__.append(FlashCausalLM)\n    __all__.append(IdeficsCausalLM)\n\nMAMBA_AVAILABLE = True\ntry:\n    from text_generation_server.models.mamba import Mamba\nexcept ImportError as e:\n    log_master(logger.warning, f\"Could not import Mamba: {e}\")\n    MAMBA_AVAILABLE = False\n\nif MAMBA_AVAILABLE:\n    __all__.append(Mamba)\n\nFLASH_TRANSFORMERS_BACKEND = torch.cuda.is_available() or SYSTEM == \"ipex\"\n\ntry:\n    from text_generation_server.models.transformers_flash_causal_lm import (\n        TransformersFlashCausalLM,\n    )\n    from text_generation_server.models.transformers_flash_vlm import (\n        TransformersFlashVlmCausalLM,\n        TransformersGemma3VlmCausalLM,\n        TransformersLlama4VlmCausalLM,\n    )\nexcept ImportError as e:\n    log_master(logger.warning, 
f\"Could not import Flash Transformers Backend: {e}\")\n    FLASH_TRANSFORMERS_BACKEND = False\n\n\nclass ModelType(enum.Enum):\n    DEEPSEEK_V2 = {\n        \"type\": \"deepseek_v2\",\n        \"name\": \"Deepseek V2\",\n        \"url\": \"https://huggingface.co/deepseek-ai/DeepSeek-V2\",\n    }\n    DEEPSEEK_V3 = {\n        \"type\": \"deepseek_v3\",\n        \"name\": \"Deepseek V3\",\n        \"url\": \"https://huggingface.co/deepseek-ai/DeepSeek-V3\",\n    }\n    IDEFICS2 = {\n        \"type\": \"idefics2\",\n        \"name\": \"Idefics 2\",\n        \"url\": \"https://huggingface.co/HuggingFaceM4/idefics2-8b\",\n        \"multimodal\": True,\n    }\n    IDEFICS3 = {\n        \"type\": \"idefics3\",\n        \"name\": \"Idefics 3\",\n        \"url\": \"https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3\",\n        \"multimodal\": True,\n    }\n    LLAVA_NEXT = {\n        \"type\": \"llava_next\",\n        \"name\": \"Llava Next (1.6)\",\n        \"url\": \"https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf\",\n        \"multimodal\": True,\n    }\n    LLAMA = {\n        \"type\": \"llama\",\n        \"name\": \"Llama\",\n        \"url\": \"https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f\",\n    }\n    LLAMA4 = {\n        \"type\": \"llama4\",\n        \"name\": \"Llama4\",\n        \"url\": \"https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f\",\n    }\n    PHI3 = {\n        \"type\": \"phi3\",\n        \"name\": \"Phi 3\",\n        \"url\": \"https://huggingface.co/microsoft/Phi-3-mini-4k-instruct\",\n    }\n    GRANITE = {\n        \"type\": \"granite\",\n        \"name\": \"Granite\",\n        \"url\": \"https://huggingface.co/ibm-granite/granite-3.0-8b-instruct\",\n    }\n    GEMMA = {\n        \"type\": \"gemma\",\n        \"name\": \"Gemma\",\n        \"url\": \"https://huggingface.co/google/gemma-7b\",\n    }\n    PALIGEMMA = {\n        \"type\": \"paligemma\",\n        \"name\": 
\"PaliGemma\",\n        \"url\": \"https://huggingface.co/google/paligemma-3b-pt-224\",\n    }\n    GEMMA2 = {\n        \"type\": \"gemma2\",\n        \"name\": \"Gemma2\",\n        \"url\": \"https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315\",\n    }\n    GEMMA3 = {\n        \"type\": \"gemma3\",\n        \"name\": \"Gemma3\",\n        \"url\": \"https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d\",\n    }\n    GEMMA3_TEXT = {\n        \"type\": \"gemma3_text\",\n        \"name\": \"Gemma3 Text\",\n        \"url\": \"https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d\",\n    }\n    COHERE = {\n        \"type\": \"cohere\",\n        \"name\": \"Cohere\",\n        \"url\": \"https://huggingface.co/CohereForAI/c4ai-command-r-plus\",\n    }\n    DBRX = {\n        \"type\": \"dbrx\",\n        \"name\": \"Dbrx\",\n        \"url\": \"https://huggingface.co/databricks/dbrx-instruct\",\n    }\n    MAMBA = {\n        \"type\": \"mamba\",\n        \"name\": \"Mamba\",\n        \"url\": \"https://huggingface.co/state-spaces/mamba-2.8b-slimpj\",\n    }\n    MISTRAL = {\n        \"type\": \"mistral\",\n        \"name\": \"Mistral\",\n        \"url\": \"https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407\",\n    }\n    MIXTRAL = {\n        \"type\": \"mixtral\",\n        \"name\": \"Mixtral\",\n        \"url\": \"https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1\",\n    }\n    GPT_BIGCODE = {\n        \"type\": \"gpt_bigcode\",\n        \"name\": \"Gpt Bigcode\",\n        \"url\": \"https://huggingface.co/bigcode/gpt_bigcode-santacoder\",\n    }\n    PHI = {\n        \"type\": \"phi\",\n        \"name\": \"Phi\",\n        \"url\": \"https://huggingface.co/microsoft/phi-1_5\",\n    }\n    PHI_MOE = {\n        \"type\": \"phimoe\",\n        \"name\": \"PhiMoe\",\n        \"url\": \"https://huggingface.co/microsoft/Phi-3.5-MoE-instruct\",\n    }\n    BAICHUAN = {\n   
     \"type\": \"baichuan\",\n        \"name\": \"Baichuan\",\n        \"url\": \"https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat\",\n    }\n    FALCON = {\n        \"type\": \"falcon\",\n        \"name\": \"Falcon\",\n        \"url\": \"https://huggingface.co/tiiuae/falcon-7b-instruct\",\n    }\n    STARCODER2 = {\n        \"type\": \"starcoder2\",\n        \"name\": \"StarCoder 2\",\n        \"url\": \"https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1\",\n    }\n    QWEN2 = {\n        \"type\": \"qwen2\",\n        \"name\": \"Qwen 2\",\n        \"url\": \"https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f\",\n    }\n    QWEN2_VL = {\n        \"type\": \"qwen2_vl\",\n        \"name\": \"Qwen 2 VL\",\n        \"url\": \"https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d\",\n    }\n    QWEN2_5_VL = {\n        \"type\": \"qwen2_5_vl\",\n        \"name\": \"Qwen 2.5 VL\",\n        \"url\": \"https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e\",\n    }\n    OPT = {\n        \"type\": \"opt\",\n        \"name\": \"Opt\",\n        \"url\": \"https://huggingface.co/facebook/opt-6.7b\",\n    }\n    T5 = {\n        \"type\": \"t5\",\n        \"name\": \"T5\",\n        \"url\": \"https://huggingface.co/google/flan-t5-xxl\",\n    }\n    GALACTICA = {\n        \"type\": \"galactica\",\n        \"name\": \"Galactica\",\n        \"url\": \"https://huggingface.co/facebook/galactica-120b\",\n    }\n    SANTACODER = {\n        \"type\": \"santacoder\",\n        \"name\": \"SantaCoder\",\n        \"url\": \"https://huggingface.co/bigcode/santacoder\",\n    }\n    BLOOM = {\n        \"type\": \"bloom\",\n        \"name\": \"Bloom\",\n        \"url\": \"https://huggingface.co/bigscience/bloom-560m\",\n    }\n    MPT = {\n        \"type\": \"mpt\",\n        \"name\": \"Mpt\",\n        \"url\": \"https://huggingface.co/mosaicml/mpt-7b-instruct\",\n    }\n    GPT2 = {\n        \"type\": \"gpt2\",\n        
\"name\": \"Gpt2\",\n        \"url\": \"https://huggingface.co/openai-community/gpt2\",\n    }\n    GPT_NEOX = {\n        \"type\": \"gpt_neox\",\n        \"name\": \"Gpt Neox\",\n        \"url\": \"https://huggingface.co/EleutherAI/gpt-neox-20b\",\n    }\n    GPTJ = {\n        \"type\": \"gptj\",\n        \"name\": \"Gptj\",\n        \"url\": \"https://huggingface.co/EleutherAI/gpt-j-6b\",\n    }\n    IDEFICS = {\n        \"type\": \"idefics\",\n        \"name\": \"Idefics\",\n        \"url\": \"https://huggingface.co/HuggingFaceM4/idefics-9b\",\n        \"multimodal\": True,\n    }\n    MLLAMA = {\n        \"type\": \"mllama\",\n        \"name\": \"Mllama\",\n        \"url\": \"https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct\",\n        \"multimodal\": True,\n    }\n\n\n__GLOBALS = locals()\nfor data in ModelType:\n    __GLOBALS[data.name] = data.value[\"type\"]\n\n\ndef get_model(\n    model_id: str,\n    lora_adapter_ids: Optional[List[str]],\n    revision: Optional[str],\n    sharded: bool,\n    quantize: Optional[str],\n    speculate: Optional[int],\n    dtype: Optional[str],\n    kv_cache_dtype: Optional[str],\n    trust_remote_code: bool,\n    max_input_tokens: int,\n) -> Model:\n    global FLASH_ATTENTION\n\n    config_dict, _ = PretrainedConfig.get_config_dict(\n        model_id, revision=revision, trust_remote_code=trust_remote_code\n    )\n    model_type = config_dict.get(\"model_type\", None)\n\n    quantization_config = config_dict.get(\"quantization_config\", None)\n    if quantization_config is None:\n        quantization_config = config_dict.get(\"compression_config\", None)\n    if quantization_config is not None and quantize is None:\n        method = quantization_config.get(\"quant_method\", None)\n        if method in {\"gptq\", \"awq\", \"exl2\"}:\n            log_master(logger.info, f\"Auto selecting quantization method {method}\")\n            quantize = method\n        elif method == \"fbgemm_fp8\" or method == \"fp8\":\n    
        log_master(logger.info, \"Auto selecting quantization method fp8\")\n            quantize = \"fp8\"\n        if method == \"compressed-tensors\":\n            log_master(\n                logger.info, \"Auto selecting quantization method compressed-tensors\"\n            )\n            quantize = \"compressed-tensors\"\n\n        else:\n            log_master(logger.warning, f\"Unknown quantization method {method}\")\n\n    if dtype is None:\n        if quantize in [\"awq\", \"exl2\", \"gptq\", \"marlin\"]:\n            if SYSTEM == \"ipex\" and not (\n                hasattr(torch, \"xpu\") and torch.xpu.is_available()\n            ):\n                dtype = torch.bfloat16\n            else:\n                # These quantizers only work with float16 params.\n                dtype = torch.float16\n        else:\n            # Keep it as default for now and let\n            # every model resolve their own default dtype.\n            dtype = None\n    elif dtype == \"float16\":\n        dtype = torch.float16\n    elif dtype == \"bfloat16\":\n        dtype = torch.bfloat16\n    else:\n        raise RuntimeError(f\"Unknown dtype {dtype}\")\n\n    compressed_tensors_config = None\n    if quantize == \"compressed-tensors\":\n        try:\n            compressed_tensors_config = QuantizationConfig.model_validate(\n                quantization_config\n            )\n        except ValidationError as e:\n            raise ValueError(\"Cannot parse compressed-tensors configuration\") from e\n\n    if kv_cache_dtype is None:\n        kv_cache_scheme = (\n            compressed_tensors_config.kv_cache_scheme\n            if isinstance(compressed_tensors_config, QuantizationConfig)\n            else None\n        )\n        if (\n            kv_cache_scheme is not None\n            and kv_cache_scheme.type == QuantizationType.FLOAT\n            and kv_cache_scheme.num_bits == 8\n            and SYSTEM == \"cuda\"\n            and ATTENTION == \"flashinfer\"\n        
):\n            kv_cache_dtype = torch.float8_e4m3fn\n        else:\n            kv_cache_dtype = dtype\n    elif kv_cache_dtype == \"fp8_e4m3fn\":\n        kv_cache_dtype = torch.float8_e4m3fn\n    elif kv_cache_dtype == \"fp8_e5m2\":\n        kv_cache_dtype = torch.float8_e5m2\n    else:\n        raise RuntimeError(f\"Unknown kv_cache_dtype: {kv_cache_dtype}\")\n\n    if speculate is not None:\n        set_speculate(speculate)\n    else:\n        set_speculate(0)\n\n    speculator = None\n    if \"medusa_num_heads\" in config_dict:\n        medusa_model_id = model_id\n        medusa_revision = revision\n        model_id = config_dict[\"base_model_name_or_path\"]\n        revision = \"main\"\n        speculate_medusa = config_dict[\"medusa_num_heads\"]\n        if speculate is not None:\n            if speculate > speculate_medusa:\n                raise RuntimeError(\n                    f\"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match\"\n                )\n            else:\n                set_speculate(speculate)\n        else:\n            set_speculate(speculate_medusa)\n\n        config_dict, _ = PretrainedConfig.get_config_dict(\n            model_id, revision=revision, trust_remote_code=trust_remote_code\n        )\n        # Reload model type from parent.\n        model_type = config_dict.get(\"model_type\", None)\n        is_local = Path(medusa_model_id).exists()\n        if not is_local:\n            medusa_config = hf_hub_download(\n                medusa_model_id, revision=medusa_revision, filename=\"config.json\"\n            )\n            hf_hub_download(\n                medusa_model_id,\n                revision=medusa_revision,\n                filename=\"medusa_lm_head.safetensors\",\n            )\n            speculator = {\n                \"path\": Path(medusa_config).parent,\n                \"model_paths\": [\"medusa_lm_head.safetensors\"],\n            }\n        
else:\n            speculator = {\n                \"path\": Path(medusa_model_id),\n                \"model_paths\": [\"medusa_lm_head.safetensors\"],\n            }\n\n        method = \"medusa\"\n    elif model_type == \"mlp_speculator\":\n        mlp_model_id = model_id\n        mlp_revision = revision\n        model_id = config_dict[\"base_model_name_or_path\"]\n        revision = \"main\"\n        speculate_mlp = config_dict[\"n_predict\"]\n        if speculate is not None:\n            if speculate > speculate_mlp:\n                raise RuntimeError(\n                    f\"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match\"\n                )\n            else:\n                set_speculate(speculate)\n        else:\n            set_speculate(speculate_mlp)\n\n        config_dict, _ = PretrainedConfig.get_config_dict(\n            model_id, revision=revision, trust_remote_code=trust_remote_code\n        )\n        # Reload model type from parent.\n        model_type = config_dict.get(\"model_type\", None)\n        is_local = Path(mlp_model_id).exists()\n        extension = \".safetensors\"\n        if not is_local:\n            mlp_speculator_config = hf_hub_download(\n                mlp_model_id, revision=mlp_revision, filename=\"config.json\"\n            )\n            api = HfApi()\n            info = api.model_info(mlp_model_id, revision=mlp_revision)\n            filenames = [\n                s.rfilename\n                for s in info.siblings\n                if s.rfilename.endswith(extension)\n                and len(s.rfilename.split(\"/\")) == 1\n                and \"arguments\" not in s.rfilename\n                and \"args\" not in s.rfilename\n                and \"training\" not in s.rfilename\n            ]\n            for filename in filenames:\n                hf_hub_download(\n                    mlp_model_id,\n                    revision=mlp_revision,\n         
           filename=filename,\n                )\n            speculator_dir_path = Path(mlp_speculator_config).parent\n            # if these are downloaded, they get converted to safetensors\n            filenames.extend(\n                [p for p in os.listdir(speculator_dir_path) if p.endswith(extension)]\n            )\n            speculator = {\n                \"path\": Path(mlp_speculator_config).parent,\n                \"model_paths\": filenames,\n            }\n        else:\n            speculator = Path(mlp_model_id)\n            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]\n            speculator = {\"path\": speculator, \"model_paths\": filenames}\n        method = \"mlp_speculator\"\n    else:\n        method = \"n-gram\"\n\n    speculate = get_speculate()\n    if speculate > 0:\n        log_master(\n            logger.info, f\"Using speculation {method} with {speculate} input ids.\"\n        )\n\n    if model_type is None:\n        # TODO: fix how we determine model type for Mamba\n        if \"ssm_cfg\" in config_dict:\n            # *only happens in Mamba case\n            model_type = \"mamba\"\n        else:\n            raise RuntimeError(\n                f\"Could not determine model type for {model_id} revision {revision}\"\n            )\n\n    if quantize == \"exl2\" and sharded:\n        raise RuntimeError(\n            \"Sharding is currently not supported with `exl2` quantization\"\n        )\n\n    sliding_window = (\n        config_dict.get(\"sliding_window\")\n        if config_dict.get(\"sliding_window\") is not None\n        else -1\n    )\n\n    use_sliding_window = sliding_window is not None and sliding_window != -1\n    needs_sliding_window = (\n        max_input_tokens is not None and max_input_tokens > sliding_window\n    )\n    if use_sliding_window and needs_sliding_window and not SUPPORTS_WINDOWING:\n        raise ValueError(\n            f\"The backend {SYSTEM} does not support sliding window 
attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens}).\"\n        )\n    if model_type == DEEPSEEK_V2:\n        if FLASH_ATTENTION:\n            head_size = max(\n                config_dict.get(\"qk_nope_dim\", 128)\n                + config_dict.get(\"qk_rope_dim\", 64),\n                config_dict.get(\"v_head_dim\", 128),\n            )\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashDeepseekV2ForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                default_dtype=torch.bfloat16,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n                config_class=DeepseekV2Config,\n                head_size=head_size,\n            )\n        elif sharded:\n            raise NotImplementedError(\n                FLASH_ATT_ERROR_MESSAGE.format(\"Sharded Deepseek V2\")\n            )\n        else:\n            return CausalLM.fallback(\n                model_id,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n    elif model_type == DEEPSEEK_V3:\n        if FLASH_ATTENTION:\n            head_size = max(\n                config_dict.get(\"qk_nope_dim\", 128)\n                + config_dict.get(\"qk_rope_dim\", 64),\n                config_dict.get(\"v_head_dim\", 128),\n            )\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashDeepseekV3ForCausalLM,\n                
revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                default_dtype=torch.bfloat16,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n                config_class=DeepseekV3Config,\n                head_size=head_size,\n            )\n        elif sharded:\n            raise NotImplementedError(\n                FLASH_ATT_ERROR_MESSAGE.format(\"Sharded Deepseek V3\")\n            )\n        else:\n            return CausalLM.fallback(\n                model_id,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n    elif model_type == MAMBA:\n        return Mamba(\n            model_id,\n            revision,\n            quantize=quantize,\n            speculator=speculator,\n            dtype=dtype,\n            trust_remote_code=trust_remote_code,\n        )\n    elif model_type == \"ssm\":\n        raise RuntimeError(\n            \"`ssm` models have been deprecated in favor of `mamba` models, which follow standard HF formats. 
Check out a list here: https://huggingface.co/models?search=mamba%20-hf\"\n        )\n\n    if model_id.startswith(\"facebook/galactica\"):\n        return CausalLM(\n            model_id=model_id,\n            # Yes galactica is just an OPT model.\n            model_class=OPTForCausalLM,\n            revision=revision,\n            quantize=quantize,\n            speculator=speculator,\n            dtype=dtype,\n            trust_remote_code=trust_remote_code,\n            batch_class=GalacticaCausalLMBatch,\n        )\n\n    if (\n        model_type == GPT_BIGCODE\n        or model_type == GPT2\n        and model_id.startswith(\"bigcode/\")\n    ):\n        if FLASH_ATTENTION:\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashSantacoderForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n                aliases={\"transformer.wte.weight\": [\"lm_head.weight\"]},\n                num_kv_heads=1,\n            )\n        elif sharded:\n            raise NotImplementedError(\n                FLASH_ATT_ERROR_MESSAGE.format(\"Sharded Santacoder\")\n            )\n        else:\n            return CausalLM.fallback(\n                model_id=model_id,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n\n    if model_type == BLOOM:\n        return CausalLM(\n            model_id=model_id,\n            model_class=BloomForCausalLM,\n            revision=revision,\n            quantize=quantize,\n            speculator=speculator,\n            dtype=dtype,\n            trust_remote_code=trust_remote_code,\n   
         batch_class=BloomCausalLMBatch,\n        )\n    elif model_type == MPT:\n        return CausalLM(\n            model_id=model_id,\n            model_class=MPTForCausalLM,\n            revision=revision,\n            quantize=quantize,\n            speculator=speculator,\n            dtype=dtype,\n            trust_remote_code=trust_remote_code,\n            batch_class=CausalLMBatchKeysLast,\n        )\n    elif model_type == GPT2:\n        if FLASH_ATTENTION:\n            try:\n                return FlashCausalLM(\n                    model_id=model_id,\n                    model_class=FlashGPT2ForCausalLM,\n                    revision=revision,\n                    quantize=quantize,\n                    speculator=speculator,\n                    dtype=dtype,\n                    kv_cache_dtype=kv_cache_dtype,\n                    trust_remote_code=trust_remote_code,\n                    lora_adapter_ids=lora_adapter_ids,\n                )\n            except RuntimeError as e:\n                # Lots of legacy models with various weight names.\n                log_master(logger.warning, f\"Couldn't load flash gpt2 variant: {e}\")\n                return CausalLM.fallback(\n                    model_id,\n                    revision,\n                    quantize=quantize,\n                    speculator=speculator,\n                    dtype=dtype,\n                    trust_remote_code=trust_remote_code,\n                )\n        elif sharded:\n            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(\"Sharded GPT-2\"))\n        else:\n            return CausalLM.fallback(\n                model_id,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n    elif model_type == GPTJ:\n        if FLASH_ATTENTION:\n            try:\n                return FlashCausalLM(\n                    
model_id=model_id,\n                    model_class=FlashGPTJForCausalLM,\n                    revision=revision,\n                    quantize=quantize,\n                    speculator=speculator,\n                    dtype=dtype,\n                    kv_cache_dtype=kv_cache_dtype,\n                    trust_remote_code=trust_remote_code,\n                    lora_adapter_ids=lora_adapter_ids,\n                )\n            except RuntimeError as e:\n                # Lots of legacy models with various weight names.\n                log_master(logger.warning, f\"Couldn't load flash gptj variant: {e}\")\n                return CausalLM.fallback(\n                    model_id,\n                    revision,\n                    quantize=quantize,\n                    speculator=speculator,\n                    dtype=dtype,\n                    trust_remote_code=trust_remote_code,\n                )\n        elif sharded:\n            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(\"Sharded GPT-J\"))\n        else:\n            return CausalLM.fallback(\n                model_id,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n    elif model_type == GPT_NEOX:\n        if FLASH_ATTENTION:\n            from text_generation_server.models.custom_modeling.flash_neox_modeling import (\n                GPTNeoXConfig,\n            )\n\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashGPTNeoXForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n                config_class=GPTNeoXConfig,\n            )\n     
   elif FLASH_TRANSFORMERS_BACKEND:\n            return TransformersFlashCausalLM.fallback(\n                model_id,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n        elif sharded:\n            return CausalLM(\n                model_id=model_id,\n                model_class=GPTNeoxForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n        else:\n            return CausalLM.fallback(\n                model_id,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n\n    elif model_type == PHI:\n        if FLASH_ATTENTION:\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashPhiForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n            )\n        else:\n            return TransformersFlashCausalLM.fallback(\n                model_id,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n\n    elif model_type == PHI_MOE:\n        if FLASH_ATTENTION:\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashLlamaForCausalLM,\n                config_class=PhiMoEConfig,\n                
revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n            )\n        else:\n            return CausalLM.fallback(\n                model_id,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n\n    elif model_type == \"phi-msft\":\n        if FLASH_ATTENTION:\n            raise NotImplementedError(\n                \"Legacy phi-msft is not supported with Flash Attention\"\n            )\n        else:\n            return CausalLM(\n                model_id=model_id,\n                model_class=PhiForCausalLM,\n                config_class=PhiConfig,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n\n    elif model_type == LLAMA or model_type == PHI3 or model_type == GRANITE:\n        if FLASH_ATTENTION:\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashLlamaForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n            )\n        elif FLASH_TRANSFORMERS_BACKEND:\n            return TransformersFlashCausalLM.fallback(\n                model_id,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                
trust_remote_code=trust_remote_code,\n            )\n        elif sharded:\n            raise NotImplementedError(\n                FLASH_ATT_ERROR_MESSAGE.format(f\"Sharded {model_type}\")\n            )\n        else:\n            return CausalLM.fallback(\n                model_id,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n    elif model_type == LLAMA4:\n        if FLASH_TRANSFORMERS_BACKEND:\n            from transformers import Llama4ForConditionalGeneration as Llama4Model\n\n            return TransformersLlama4VlmCausalLM.fallback(\n                model_id,\n                Llama4Model,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=torch.bfloat16,\n                trust_remote_code=trust_remote_code,\n                processor_kwargs={\n                    \"use_fast\": True,\n                },\n            )\n    elif model_type == BAICHUAN:\n        if FLASH_ATTENTION:\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashLlamaForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n            )\n        elif sharded:\n            raise NotImplementedError(\n                FLASH_ATT_ERROR_MESSAGE.format(f\"Sharded {model_type}\")\n            )\n        else:\n            return CausalLM.fallback(\n                model_id,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n          
  )\n\n    if model_type == GEMMA:\n        if FLASH_ATTENTION:\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashGemmaForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                # Works better for these models\n                default_dtype=torch.bfloat16,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n            )\n        elif FLASH_TRANSFORMERS_BACKEND:\n            return TransformersFlashCausalLM.fallback(\n                model_id,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n        elif sharded:\n            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(\"Sharded Gemma\"))\n        else:\n            return CausalLM.fallback(\n                model_id,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n    elif model_type == GEMMA2:\n        if FLASH_ATTENTION:\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashGemma2ForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                # Works better for these models\n                default_dtype=torch.bfloat16,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n            )\n        elif FLASH_TRANSFORMERS_BACKEND:\n            return 
TransformersFlashCausalLM.fallback(\n                model_id,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n        elif sharded:\n            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(\"Sharded Gemma2\"))\n        else:\n            return CausalLM.fallback(\n                model_id,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n    elif model_type == GEMMA3_TEXT:\n        if FLASH_ATTENTION:\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashGemma3ForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                # TODO: once implemented in transformers, use the config class\n                # and processor class from there.\n                config_class=Gemma3TextConfig,\n                # Works better for these models\n                default_dtype=torch.bfloat16,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n            )\n        elif FLASH_TRANSFORMERS_BACKEND:\n            return TransformersFlashCausalLM.fallback(\n                model_id,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n        elif sharded:\n            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(\"Sharded Gemma3\"))\n        else:\n            return CausalLM.fallback(\n                model_id,\n                revision,\n            
    quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n    elif model_type == GEMMA3:\n        if FLASH_ATTENTION:\n            return VlmCausalLM(\n                model_id=model_id,\n                model_class=Gemma3ForConditionalGeneration,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                # TODO: once implemented in transformers, use the config class\n                # and processor class from there.\n                config_class=Gemma3Config,\n                processor_class=Gemma3Processor,\n                default_dtype=torch.bfloat16,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n                support_chunking=False,\n            )\n        elif FLASH_TRANSFORMERS_BACKEND:\n            from transformers import Gemma3ForConditionalGeneration as Gemma3Model\n\n            return TransformersGemma3VlmCausalLM.fallback(\n                model_id,\n                Gemma3Model,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=torch.bfloat16,\n                trust_remote_code=trust_remote_code,\n                support_chunking=False,\n            )\n        elif sharded:\n            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(\"Sharded Gemma3\"))\n        else:\n            return CausalLM.fallback(\n                model_id,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n\n    if model_type == COHERE:\n        if FLASH_ATTENTION:\n            return FlashCausalLM(\n                
model_id=model_id,\n                model_class=FlashCohereForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n            )\n        elif FLASH_TRANSFORMERS_BACKEND:\n            return TransformersFlashCausalLM.fallback(\n                model_id,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n        elif sharded:\n            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(\"Sharded Cohere\"))\n        else:\n            return CausalLM.fallback(\n                model_id,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n\n    if model_type == DBRX:\n        if FLASH_ATTENTION:\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashDbrxForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                # Dbrx works better in bfloat16.\n                default_dtype=torch.bfloat16,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n                config_class=DbrxConfig,\n            )\n        elif sharded:\n            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(\"Sharded DBRX\"))\n        else:\n            return CausalLM.fallback(\n                model_id,\n                revision,\n                quantize=quantize,\n     
           speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n\n    if model_type in [\"RefinedWeb\", \"RefinedWebModel\", FALCON]:\n        if sharded:\n            if FLASH_ATTENTION:\n                if config_dict.get(\"alibi\", False):\n                    raise NotImplementedError(\"sharded is not supported for this model\")\n                return FlashCausalLM(\n                    model_id=model_id,\n                    model_class=FlashRWForCausalLM,\n                    revision=revision,\n                    quantize=quantize,\n                    speculator=speculator,\n                    dtype=dtype,\n                    kv_cache_dtype=kv_cache_dtype,\n                    aliases={\n                        \"lm_head.weight\": [\"transformer.word_embeddings.weight\"],\n                        \"transformer.word_embeddings.weight\": [\"lm_head.weight\"],\n                    },\n                    trust_remote_code=trust_remote_code,\n                    lora_adapter_ids=lora_adapter_ids,\n                    config_class=RWConfig,\n                )\n            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(\"Sharded Falcon\"))\n        else:\n            if FLASH_ATTENTION and not config_dict.get(\"alibi\", False):\n                return FlashCausalLM(\n                    model_id=model_id,\n                    model_class=FlashRWForCausalLM,\n                    revision=revision,\n                    quantize=quantize,\n                    speculator=speculator,\n                    dtype=dtype,\n                    kv_cache_dtype=kv_cache_dtype,\n                    aliases={\n                        \"lm_head.weight\": [\"transformer.word_embeddings.weight\"],\n                        \"transformer.word_embeddings.weight\": [\"lm_head.weight\"],\n                    },\n                    trust_remote_code=trust_remote_code,\n                    
lora_adapter_ids=lora_adapter_ids,\n                    config_class=RWConfig,\n                )\n            else:\n                return CausalLM.fallback(\n                    model_id,\n                    revision,\n                    quantize=quantize,\n                    speculator=speculator,\n                    dtype=dtype,\n                    trust_remote_code=trust_remote_code,\n                )\n\n    if model_type == MISTRAL:\n        if FLASH_ATTENTION:\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashMistralForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n            )\n        elif FLASH_TRANSFORMERS_BACKEND:\n            return TransformersFlashCausalLM.fallback(\n                model_id,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n        elif sharded:\n            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(\"Sharded Mistral\"))\n        else:\n            return CausalLM.fallback(\n                model_id,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n\n    if model_type == MIXTRAL:\n        if FLASH_ATTENTION:\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashMixtralForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                
kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n            )\n        elif FLASH_TRANSFORMERS_BACKEND:\n            return TransformersFlashCausalLM.fallback(\n                model_id,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n        elif sharded:\n            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(\"Sharded Mixtral\"))\n        else:\n            return CausalLM.fallback(\n                model_id,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n\n    if model_type == STARCODER2:\n        if FLASH_ATTENTION:\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=FlashStarcoder2ForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n            )\n        elif FLASH_TRANSFORMERS_BACKEND:\n            return TransformersFlashCausalLM.fallback(\n                model_id,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n        elif sharded:\n            raise NotImplementedError(\n                FLASH_ATT_ERROR_MESSAGE.format(\"Sharded Starcoder2\")\n            )\n        else:\n            return CausalLM.fallback(\n                model_id,\n                revision,\n                
quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n\n    if model_type == QWEN2:\n        if FLASH_ATTENTION:\n            return FlashCausalLM(\n                model_id=model_id,\n                model_class=Qwen2ForCausalLM,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n            )\n        elif FLASH_TRANSFORMERS_BACKEND:\n            return TransformersFlashCausalLM.fallback(\n                model_id,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n        elif sharded:\n            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(\"Sharded Qwen2\"))\n        else:\n            return CausalLM.fallback(\n                model_id,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n\n    if model_type == OPT:\n        return CausalLM(\n            model_id=model_id,\n            model_class=OPTForCausalLM,\n            revision=revision,\n            quantize=quantize,\n            speculator=speculator,\n            dtype=dtype,\n            trust_remote_code=trust_remote_code,\n        )\n\n    if model_type == T5:\n        return Seq2SeqLM(\n            model_id=model_id,\n            model_class=T5ForConditionalGeneration,\n            revision=revision,\n            quantize=quantize,\n            speculator=speculator,\n            dtype=dtype,\n            
trust_remote_code=trust_remote_code,\n            aliases={\n                \"shared.weight\": [\n                    \"encoder.embed_tokens.weight\",\n                    \"decoder.embed_tokens.weight\",\n                ]\n            },\n        )\n    if model_type == IDEFICS:\n        if FLASH_ATTENTION:\n            return IdeficsCausalLM(\n                model_id,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n        else:\n            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(\"Idefics\"))\n    if model_type == QWEN2_VL:\n        if FLASH_ATTENTION:\n            return VlmCausalLM(\n                model_id=model_id,\n                model_class=Qwen2VLForConditionalGeneration,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                default_dtype=torch.bfloat16,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n                # TODO: Fix bug in rust image_text_replacement implementation\n                support_chunking=False,\n            )\n        # TODO: Uncomment when transformers is refactored\n        # elif FLASH_TRANSFORMERS_BACKEND:\n        #     from transformers import Qwen2VLForConditionalGeneration as Qwen2VLModel\n\n        #     return TransformersQwen2VlmCausalLM.fallback(\n        #         model_id,\n        #         Qwen2VLModel,\n        #         revision,\n        #         quantize=quantize,\n        #         speculator=speculator,\n        #         dtype=torch.bfloat16,\n        #         trust_remote_code=trust_remote_code,\n        #     )\n        else:\n            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(\"Qwen2_VL\"))\n    if 
model_type == QWEN2_5_VL:\n        if FLASH_ATTENTION:\n            return VlmCausalLM(\n                model_id=model_id,\n                model_class=Qwen2_5VLForConditionalGeneration,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                default_dtype=torch.bfloat16,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n                config_class=Qwen2_5_VLConfig,\n                processor_class=Qwen2_5_VLProcessor,\n                # TODO: Fix bug in rust image_text_replacement implementation\n                support_chunking=False,\n            )\n        # TODO: Uncomment when transformers is refactored\n        # elif FLASH_TRANSFORMERS_BACKEND:\n        #     return TransformersQwen2VlmCausalLM.fallback(\n        #         model_id,\n        #         Qwen2VLModel,\n        #         revision,\n        #         quantize=quantize,\n        #         speculator=speculator,\n        #         dtype=torch.bfloat16,\n        #         trust_remote_code=trust_remote_code,\n        #         config_class=Qwen2_5_VLConfig,\n        #         processor_class=Qwen2_5_VLProcessor,\n        #     )\n        else:\n            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(\"Qwen2_5_VL\"))\n    if model_type == MLLAMA:\n        if FLASH_ATTENTION:\n            return MllamaCausalLM(\n                model_id=model_id,\n                model_class=MllamaForConditionalGeneration,\n                batch_class=MllamaCausalLMBatch,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                default_dtype=torch.bfloat16,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n                
support_chunking=False,\n            )\n        # TODO: Uncomment when transformers is refactored and cross attn is added\n        # elif FLASH_TRANSFORMERS_BACKEND:\n        #     from transformers import MllamaForConditionalGeneration as MllamaModel\n\n        #     return TransformersFlashVlmCausalLM.fallback(\n        #         model_id,\n        #         MllamaModel,\n        #         revision,\n        #         quantize=quantize,\n        #         speculator=speculator,\n        #         dtype=torch.bfloat16,\n        #         trust_remote_code=trust_remote_code,\n        #         batch_class=MllamaCausalLMBatch,\n        #     )\n        else:\n            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(\"Mllama\"))\n    if model_type == IDEFICS2:\n        if FLASH_ATTENTION:\n            return VlmCausalLM(\n                model_id=model_id,\n                model_class=Idefics2ForConditionalGeneration,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n                # XXX: Extremely important to cap resolution in order to limit\n                # VRAM usage.\n                processor_kwargs={\"size\": {\"longest_edge\": 448, \"shortest_edge\": 378}},\n            )\n        elif FLASH_TRANSFORMERS_BACKEND:\n            from transformers import Idefics2ForConditionalGeneration as Idefics2Model\n\n            return TransformersFlashVlmCausalLM.fallback(\n                model_id,\n                Idefics2Model,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n                processor_kwargs={\"size\": {\"longest_edge\": 448, \"shortest_edge\": 378}},\n        
    )\n        else:\n            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(\"Idefics\"))\n    if model_type == IDEFICS3:\n        if FLASH_ATTENTION:\n            return VlmCausalLM(\n                model_id=model_id,\n                model_class=Idefics3ForConditionalGeneration,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                default_dtype=torch.bfloat16,\n                trust_remote_code=trust_remote_code,\n                lora_adapter_ids=lora_adapter_ids,\n                # XXX: Extremely important to cap resolution in order to limit\n                # VRAM usage.\n                processor_kwargs={\"size\": {\"longest_edge\": 1456}},\n            )\n        elif FLASH_TRANSFORMERS_BACKEND:\n            from transformers import Idefics3ForConditionalGeneration as Idefics3Model\n\n            return TransformersFlashVlmCausalLM.fallback(\n                model_id,\n                Idefics3Model,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=torch.bfloat16,\n                trust_remote_code=trust_remote_code,\n                processor_kwargs={\"size\": {\"longest_edge\": 1456}},\n            )\n        else:\n            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(\"Idefics\"))\n    if model_type == PALIGEMMA:\n        if FLASH_ATTENTION:\n            return VlmCausalLM(\n                model_id=model_id,\n                model_class=PaliGemmaForConditionalGeneration,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                # Works better for these models\n                default_dtype=torch.bfloat16,\n                trust_remote_code=trust_remote_code,\n                
lora_adapter_ids=lora_adapter_ids,\n            )\n        elif FLASH_TRANSFORMERS_BACKEND:\n            from transformers import PaliGemmaForConditionalGeneration as PaliGemmaModel\n\n            return TransformersFlashVlmCausalLM.fallback(\n                model_id,\n                PaliGemmaModel,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=torch.bfloat16,\n                trust_remote_code=trust_remote_code,\n            )\n        else:\n            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(\"PaliGemma\"))\n    if model_type == LLAVA_NEXT:\n        if FLASH_ATTENTION:\n            return VlmCausalLM(\n                model_class=LlavaNextForConditionalGeneration,\n                model_id=model_id,\n                revision=revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                kv_cache_dtype=kv_cache_dtype,\n                trust_remote_code=trust_remote_code,\n            )\n        elif FLASH_TRANSFORMERS_BACKEND:\n            from transformers import LlavaNextForConditionalGeneration as LlavaNextModel\n\n            return TransformersFlashVlmCausalLM.fallback(\n                model_id,\n                LlavaNextModel,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n        else:\n            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(\"LlavaNext\"))\n\n    if quantize == \"gptq\":\n        raise NotImplementedError(\n            \"gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`\"\n        )\n    if quantize == \"awq\":\n        raise NotImplementedError(\"awq quantization is not supported for AutoModel\")\n    elif 
(quantize == \"bitsandbytes-fp4\") or (quantize == \"bitsandbytes-nf4\"):\n        raise NotImplementedError(\"4bit quantization is not supported for AutoModel\")\n    elif quantize == \"eetq\":\n        raise NotImplementedError(\"Eetq quantization is not supported for AutoModel\")\n    elif quantize == \"exl2\":\n        raise NotImplementedError(\"exl2 quantization is not supported for AutoModel\")\n\n    auto_map = config_dict.get(\"auto_map\", None)\n    model_class = None\n\n    # If the model is already in the library\n    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:\n        model_class = getattr(\n            transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]\n        )\n    elif (\n        trust_remote_code\n        and auto_map is not None\n        and \"AutoModelForCausalLM\" in auto_map.keys()\n    ):\n        model_class = get_class_from_dynamic_module(\n            config_dict[\"auto_map\"][\"AutoModelForCausalLM\"], model_id\n        )\n\n    # This means the model is ForCausalLM\n    if model_class is not None:\n        if FLASH_TRANSFORMERS_BACKEND and model_class.is_backend_compatible():\n            return TransformersFlashCausalLM.fallback(\n                model_id,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n        elif sharded:\n            raise NotImplementedError(\"sharded is not supported for AutoModel\")\n        else:\n            return CausalLM.fallback(\n                model_id,\n                revision,\n                quantize=quantize,\n                speculator=speculator,\n                dtype=dtype,\n                trust_remote_code=trust_remote_code,\n            )\n\n    # Not supported at this point\n    if sharded:\n        raise NotImplementedError(\"sharded is not supported for AutoModel\")\n\n    # This means it 
is a ForSeq2SeqLM model\n    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES or (\n        trust_remote_code\n        and auto_map is not None\n        and \"AutoModelForSeq2SeqLM\" in auto_map.keys()\n    ):\n        return Seq2SeqLM.fallback(\n            model_id,\n            revision,\n            quantize=quantize,\n            speculator=speculator,\n            dtype=dtype,\n            trust_remote_code=trust_remote_code,\n        )\n\n    raise ValueError(f\"Unsupported model type {model_type}\")\n\n\n# get_model_with_lora_adapters wraps the internal get_model function and adds support for loading adapters\n# this provides a post model loading hook to load adapters into the model after the model has been loaded\ndef get_model_with_lora_adapters(\n    model_id: str,\n    lora_adapters: Optional[List[AdapterInfo]],\n    revision: Optional[str],\n    sharded: bool,\n    quantize: Optional[str],\n    speculate: Optional[int],\n    dtype: Optional[str],\n    kv_cache_dtype: Optional[str],\n    trust_remote_code: bool,\n    max_input_tokens: int,\n    adapter_to_index: Dict[str, int],\n):\n    lora_adapter_ids = [adapter.id for adapter in lora_adapters]\n    model = get_model(\n        model_id,\n        lora_adapter_ids,\n        revision,\n        sharded,\n        quantize,\n        speculate,\n        dtype,\n        kv_cache_dtype,\n        trust_remote_code,\n        max_input_tokens,\n    )\n\n    if len(lora_adapters) > 0:\n        target_to_layer = build_layer_weight_lookup(model.model)\n\n        for index, adapter in enumerate(lora_adapters):\n            # The AdapterParameters object allows for merging multiple adapters into a single adapter.\n            # At the moment, we only support loading a single adapter into the model, but we keep the\n            # AdapterParameters object for easier extension in the future.\n            adapter_parameters = AdapterParameters(\n                adapter_info=[adapter],\n          
      # when merging multiple adapters we can weight them differently\n                # if this is not set, all adapters will be weighted equally\n                # see: text_generation_server.utils.merges.strategies for impl\n                weights=None,\n                merge_strategy=0,\n                density=1.0,\n                majority_sign_method=0,\n            )\n\n            adapter_index = index + 1\n            adapter_to_index[adapter.id] = adapter_index\n\n            logger.info(\n                f\"Loading adapter weights into model: {','.join([adapter.id for adapter in adapter_parameters.adapter_info])}\"\n            )\n            weight_names = tuple([v[0] for v in target_to_layer.values()])\n            (\n                module_map,\n                adapter_config,\n                adapter_weight_names,\n                adapter_tokenizer,\n            ) = load_and_merge_adapters(\n                model.model_id,\n                adapter_parameters,\n                adapter_index,\n                weight_names,\n                False,\n            )\n\n            unused_weight_names = adapter_weight_names.copy()\n\n            adapter_layers = [\n                \"q_proj\",\n                \"k_proj\",\n                \"v_proj\",\n                \"o_proj\",\n                \"gate_proj\",\n                \"up_proj\",\n                \"down_proj\",\n                \"qkv_proj\",\n                # add c_* layers used in starcoder2\n                \"c_proj\",\n                \"c_fc\",\n            ]\n\n            for layer_name in adapter_layers:\n                nlayers = (\n                    1 if layer_name == \"lm_head\" else len(model.model.model.layers)\n                )\n                adapter_weights = LoraWeights.prepare_weights(\n                    config=adapter_config,\n                    module_map=module_map,\n                    layer_type=layer_name,\n                    
unused_weight_names=unused_weight_names,\n                    nlayers=nlayers,\n                    dtype=model.dtype,\n                    world_size=model.world_size,\n                    process_group=model.process_group,\n                    target_to_layer=target_to_layer,\n                )\n\n                if adapter_weights is None:\n                    continue\n\n                model.layer_to_adapter_weights[layer_name].add_adapter(\n                    adapter_index, adapter_weights\n                )\n\n            if len(unused_weight_names) > 0:\n                logger.warning(\n                    f\"{','.join([a.id for a in lora_adapters])} unused adapter weights: {unused_weight_names}\"\n                )\n\n            if adapter_tokenizer is not None:\n                model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)\n\n            model.loaded_adapters.add(adapter_index)\n\n    return model\n"
  },
  {
    "path": "server/text_generation_server/models/bloom.py",
    "content": "import torch\nimport torch.distributed\n\nfrom typing import Optional, Type\n\nfrom transformers import (\n    PreTrainedTokenizerBase,\n)\n\nfrom text_generation_server.models import CausalLM\nfrom text_generation_server.models.causal_lm import CausalLMBatch\nfrom text_generation_server.pb import generate_pb2\n\n\nclass BloomCausalLMBatch(CausalLMBatch):\n    @classmethod\n    def from_pb(\n        cls,\n        pb: generate_pb2.Batch,\n        tokenizer: PreTrainedTokenizerBase,\n        dtype: torch.dtype,\n        device: torch.device,\n    ) -> \"CausalLMBatch\":\n        batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)\n        batch.keys_head_dim_last = False\n        return batch\n\n\nclass BLOOMSharded(CausalLM):\n    @property\n    def batch_type(self) -> Type[CausalLMBatch]:\n        return BloomCausalLMBatch\n\n    def forward(\n        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None\n    ):\n        outputs, speculative_logits = self.model.forward(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            use_cache=True,\n        )\n\n        logits = outputs.logits\n        return logits, speculative_logits, outputs.past_key_values\n"
  },
  {
    "path": "server/text_generation_server/models/causal_lm.py",
    "content": "import torch\nimport time\nimport torch.distributed\n\nfrom dataclasses import dataclass\nfrom opentelemetry import trace\nfrom transformers import (\n    AutoConfig,\n    AutoTokenizer,\n    AutoModelForCausalLM,\n    PreTrainedTokenizerBase,\n)\nfrom typing import Optional, Tuple, List, Type, Dict\n\nfrom text_generation_server.utils import (\n    initialize_torch_distributed,\n    weight_files,\n    Weights,\n)\nfrom text_generation_server.models import Model\nfrom text_generation_server.utils.chunks import concat_text_chunks\nfrom text_generation_server.utils.import_utils import SYSTEM\nfrom text_generation_server.utils.quantization import get_loader\nfrom text_generation_server.utils.tokens import batch_top_tokens\nfrom text_generation_server.models.types import (\n    Batch,\n    Tokens,\n    Generation,\n    GeneratedText,\n)\nfrom text_generation_server.pb import generate_pb2\nfrom text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling\n\ntracer = trace.get_tracer(__name__)\n\n\n@dataclass\nclass CausalLMBatch(Batch):\n    batch_id: int\n    requests: List[generate_pb2.Request]\n    requests_idx_mapping: Dict[int, int]\n\n    # Decoder values\n    input_ids: torch.Tensor\n    attention_mask: torch.Tensor\n    position_ids: torch.Tensor\n    past_key_values: Optional[List[Tuple]]\n\n    # All tokens\n    all_input_ids: List[torch.Tensor]\n\n    # Lengths of all generations present in the batch\n    input_lengths: List[int]\n    prefix_offsets: List[int]\n    read_offsets: List[int]\n\n    # Generation helpers\n    next_token_choosers: List[NextTokenChooser]\n    stopping_criterias: List[StoppingCriteria]\n    top_n_tokens: List[int]\n    top_n_tokens_tensor: torch.Tensor\n\n    # Metadata used for padding\n    max_input_length: int\n    padding_right_offset: int\n\n    # Maximum number of tokens this batch will grow to\n    max_tokens: int\n\n    # Past metadata\n    keys_head_dim_last: bool = True\n\n    def 
to_pb(self) -> generate_pb2.CachedBatch:\n        return generate_pb2.CachedBatch(\n            id=self.batch_id,\n            request_ids=[r.id for r in self.requests],\n            size=len(self),\n            max_tokens=self.max_tokens,\n            current_tokens=len(self.input_ids),\n        )\n\n    @classmethod\n    def from_pb(\n        cls,\n        pb: generate_pb2.Batch,\n        tokenizer: PreTrainedTokenizerBase,\n        dtype: torch.dtype,\n        device: torch.device,\n    ) -> \"CausalLMBatch\":\n        inputs = []\n        next_token_choosers = []\n        stopping_criterias = []\n        top_n_tokens = []\n        prefix_offsets = []\n        read_offsets = []\n        requests_idx_mapping = {}\n\n        # Parse batch\n        max_truncation = 0\n        padding_right_offset = 0\n        max_decode_tokens = 0\n        for i, r in enumerate(pb.requests):\n            requests_idx_mapping[r.id] = i\n            inputs.append(concat_text_chunks(r.input_chunks.chunks))\n\n            next_token_choosers.append(\n                NextTokenChooser.from_pb(r.parameters, device, tokenizer)\n            )\n            stopping_criteria = StoppingCriteria.from_pb(\n                r.stopping_parameters, tokenizer\n            )\n            stopping_criterias.append(stopping_criteria)\n            top_n_tokens.append(r.top_n_tokens)\n            max_truncation = max(max_truncation, r.truncate)\n            max_decode_tokens += stopping_criteria.max_new_tokens\n            padding_right_offset = max(\n                padding_right_offset, stopping_criteria.max_new_tokens\n            )\n\n        tokenized_inputs = tokenizer(\n            inputs,\n            return_tensors=\"pt\",\n            padding=True,\n            return_token_type_ids=False,\n            truncation=True,\n            max_length=max_truncation,\n        ).to(device)\n        for _ in pb.requests:\n            input_len = tokenized_inputs[\"input_ids\"].shape[1]\n            
prefix_offsets.append(input_len - 5)\n            read_offsets.append(input_len)\n\n        input_lengths = tokenized_inputs[\"attention_mask\"].sum(1)\n        max_input_length = input_lengths.max()\n\n        input_ids = tokenized_inputs[\"input_ids\"]\n        # Allocate maximum attention_mask\n        attention_mask = input_ids.new_zeros(\n            (pb.size, max_input_length + padding_right_offset)\n        )\n        # Copy tokenizer attention_mask into fully allocated attention_mask\n        attention_mask[:, :max_input_length] = tokenized_inputs[\"attention_mask\"]\n\n        position_ids = tokenized_inputs[\"attention_mask\"].long().cumsum(-1) - 1\n        position_ids.masked_fill_(tokenized_inputs[\"attention_mask\"] == 0, 1)\n        all_input_ids = tokenized_inputs[\"input_ids\"].T.split(1, dim=1)\n        top_n_tokens_tensor = torch.tensor(\n            top_n_tokens, device=device, dtype=torch.int64\n        )\n\n        max_tokens = len(inputs) * (max_input_length + max_decode_tokens)\n\n        return cls(\n            batch_id=pb.id,\n            requests=pb.requests,\n            requests_idx_mapping=requests_idx_mapping,\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=None,\n            all_input_ids=list(all_input_ids),\n            input_lengths=input_lengths.tolist(),\n            prefix_offsets=prefix_offsets,\n            read_offsets=read_offsets,\n            next_token_choosers=next_token_choosers,\n            stopping_criterias=stopping_criterias,\n            top_n_tokens=top_n_tokens,\n            top_n_tokens_tensor=top_n_tokens_tensor,\n            max_input_length=max_input_length.item(),\n            padding_right_offset=padding_right_offset,\n            max_tokens=max_tokens,\n        )\n\n    @tracer.start_as_current_span(\"filter\")\n    def filter(self, request_ids: List[int]) -> Optional[\"CausalLMBatch\"]:\n        if 
len(request_ids) == 0:\n            raise ValueError(\"Batch must have at least one request\")\n        if len(request_ids) == len(self):\n            return self\n\n        keep_indices = []\n\n        # New values after filtering\n        requests_idx_mapping = {}\n        requests = []\n        input_lengths = []\n        prefix_offsets = []\n        read_offsets = []\n        all_input_ids = []\n        max_input_length = 0\n\n        next_token_choosers = []\n        stopping_criterias = []\n        top_n_tokens = []\n\n        total_remaining_decode_tokens = 0\n        new_padding_right_offset = 0\n\n        for i, request_id in enumerate(request_ids):\n            idx = self.requests_idx_mapping[request_id]\n            requests_idx_mapping[request_id] = i\n            keep_indices.append(idx)\n\n            requests.append(self.requests[idx])\n            prefix_offsets.append(self.prefix_offsets[idx])\n            read_offsets.append(self.read_offsets[idx])\n            all_input_ids.append(self.all_input_ids[idx])\n\n            request_input_length = self.input_lengths[idx]\n            input_lengths.append(request_input_length)\n            max_input_length = max(max_input_length, request_input_length)\n\n            next_token_choosers.append(self.next_token_choosers[idx])\n            stopping_criteria = self.stopping_criterias[idx]\n            stopping_criterias.append(stopping_criteria)\n            top_n_tokens.append(self.top_n_tokens[idx])\n            remaining_decode_tokens = (\n                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens\n            )\n            total_remaining_decode_tokens += remaining_decode_tokens\n            new_padding_right_offset = max(\n                new_padding_right_offset, remaining_decode_tokens\n            )\n\n        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached\n        input_ids = self.input_ids[keep_indices]\n        
position_ids = self.position_ids[keep_indices]\n        self.attention_mask = self.attention_mask[\n            keep_indices,\n            -(self.padding_right_offset + max_input_length) : (\n                self.attention_mask.shape[1] - self.padding_right_offset\n            )\n            + new_padding_right_offset,\n        ]\n\n        # Ensure that past_key_values tensors can be updated in-place\n        if type(self.past_key_values[0]) is tuple:\n            self.past_key_values = [list(layer) for layer in self.past_key_values]\n\n        # Update tensors in-place to allow incremental garbage collection\n        past_kv_length = max_input_length - 1\n        for layer in self.past_key_values:\n            past_keys, past_values = layer\n            if len(past_keys.shape) == 3:\n                # Force past to be of dim [self_size, num_heads, ...] for easy indexing\n                past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])\n                past_values = past_values.view(len(self), -1, *past_values.shape[-2:])\n            if self.keys_head_dim_last:\n                layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]\n            else:\n                layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]\n            del past_keys\n            layer[1] = past_values[keep_indices, :, -past_kv_length:, :]\n            del past_values\n\n        top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]\n        max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens\n\n        self.requests = requests\n        self.requests_idx_mapping = requests_idx_mapping\n        self.input_ids = input_ids\n        self.position_ids = position_ids\n        self.all_input_ids = all_input_ids\n        self.input_lengths = input_lengths\n        self.prefix_offsets = prefix_offsets\n        self.read_offsets = read_offsets\n        self.next_token_choosers = next_token_choosers\n        self.stopping_criterias = 
stopping_criterias\n        self.top_n_tokens = top_n_tokens\n        self.top_n_tokens_tensor = top_n_tokens_tensor\n        self.max_input_length = max_input_length\n        self.padding_right_offset = new_padding_right_offset\n        self.max_tokens = max_tokens\n\n        return self\n\n    @classmethod\n    @tracer.start_as_current_span(\"concatenate\")\n    def concatenate(cls, batches: List[\"CausalLMBatch\"]) -> \"CausalLMBatch\":\n        # Used for padding\n        total_batch_size = 0\n        max_input_length = 0\n        padding_right_offset = 0\n        for batch in batches:\n            total_batch_size += len(batch)\n            max_input_length = max(max_input_length, batch.max_input_length)\n            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)\n\n        # Batch attributes\n        requests = []\n        requests_idx_mapping = {}\n        input_lengths = []\n        prefix_offsets = []\n        read_offsets = []\n        all_input_ids = []\n        next_token_choosers = []\n        stopping_criterias = []\n        top_n_tokens = []\n        max_tokens = 0\n\n        # Batch tensors\n        input_ids = None\n        attention_mask = None\n        position_ids = None\n        past_key_values = []\n        top_n_tokens_tensor = None\n\n        # Used for slicing correctly inside the tensors\n        # Equivalent to a cumsum on batch sizes\n        start_index = 0\n        for i, batch in enumerate(batches):\n            requests.extend(batch.requests)\n            input_lengths.extend(batch.input_lengths)\n            prefix_offsets.extend(batch.prefix_offsets)\n            read_offsets.extend(batch.read_offsets)\n            all_input_ids.extend(batch.all_input_ids)\n            next_token_choosers.extend(batch.next_token_choosers)\n            stopping_criterias.extend(batch.stopping_criterias)\n            top_n_tokens.extend(batch.top_n_tokens)\n\n            if i == 0:\n                requests_idx_mapping = 
batch.requests_idx_mapping\n            else:\n                # We need to offset the mapping for each batch by the cumulative batch size\n                for k, v in batch.requests_idx_mapping.items():\n                    requests_idx_mapping[k] = v + start_index\n\n            # Slicing end index for this batch\n            end_index = start_index + len(batch)\n\n            # We only concatenate batches that did at least one step\n            if batch.past_key_values is None:\n                raise ValueError(\"only concatenate prefilled batches\")\n\n            # Create empty tensor\n            # input_ids is always of shape [batch_size, 1]\n            # We do not need to pad it\n            if input_ids is None:\n                input_ids = batch.input_ids.new_empty((total_batch_size, 1))\n            # Copy to correct indices\n            input_ids[start_index:end_index] = batch.input_ids\n\n            # Create padded tensor\n            if attention_mask is None:\n                attention_mask = batch.attention_mask.new_zeros(\n                    (total_batch_size, max_input_length + padding_right_offset),\n                )\n\n            if top_n_tokens_tensor is None:\n                top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(\n                    total_batch_size,\n                )\n            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor\n\n            # We need to slice the attention mask to remove padding from previous steps\n            # and to remove unused allocated space\n            left_offset = max_input_length - batch.max_input_length\n            batch_left_offset = (\n                batch.attention_mask.shape[1]\n                - batch.max_input_length\n                - batch.padding_right_offset\n            )\n            attention_mask[\n                start_index:end_index,\n                left_offset:-padding_right_offset,\n            ] = batch.attention_mask[\n              
  :,\n                batch_left_offset : -batch.padding_right_offset,\n            ]\n\n            # Create empty tensor\n            # position_ids is always of shape [batch_size, 1]\n            if position_ids is None:\n                position_ids = batch.position_ids.new_empty((total_batch_size, 1))\n            position_ids[start_index:end_index] = batch.position_ids\n\n            # Shenanigans to get dimensions because BLOOM outputs a past with a different shape\n            # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]\n            # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]\n            # And ensure that we can update tensors in-place\n            if isinstance(batch.past_key_values[0], tuple):\n                batch.past_key_values = [\n                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]\n                    for layer in batch.past_key_values\n                ]\n            elif len(batch.past_key_values[0][0].shape) == 3:\n                for layer in batch.past_key_values:\n                    for k, t in enumerate(layer):\n                        layer[k] = t.view(len(batch), -1, *t.shape[-2:])\n\n            # Add eventual padding tokens that were added while concatenating\n            max_tokens += batch.max_tokens + (\n                max_input_length - batch.max_input_length\n            ) * len(batch)\n\n            start_index = end_index\n\n        first_past_kvs = batches[0].past_key_values\n        _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape\n\n        padded_past_values_shape = (\n            total_batch_size,\n            num_heads,\n            max_input_length - 1,\n            head_dim,\n        )\n\n        if batches[0].keys_head_dim_last:\n            padded_past_keys_shape = padded_past_values_shape\n        else:\n            # seq_length is last for BLOOM\n            padded_past_keys_shape = (\n                total_batch_size,\n           
     num_heads,\n                head_dim,\n                max_input_length - 1,\n            )\n\n        # Iterate over attention layers\n        # Concatenate past key values layer by layer to allow incremental garbage collection\n        for j in range(len(first_past_kvs)):\n            padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape)\n            start_index = 0\n            for batch in batches:\n                past_keys = batch.past_key_values[j][0]\n                # Clear reference to the original tensor\n                batch.past_key_values[j][0] = None\n\n                # Slicing end index for this batch\n                end_index = start_index + len(batch)\n                # We slice the keys to remove the padding from previous batches\n                past_seq_len = batch.max_input_length - 1\n                if batch.keys_head_dim_last:\n                    padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (\n                        past_keys[:, :, -past_seq_len:, :]\n                    )\n                else:\n                    # BLOOM case\n                    padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (\n                        past_keys[:, :, :, -past_seq_len:]\n                    )\n                del past_keys\n\n                start_index = end_index\n\n            padded_past_values = first_past_kvs[j][1].new_zeros(\n                padded_past_values_shape\n            )\n            start_index = 0\n            for batch in batches:\n                past_values = batch.past_key_values[j][1]\n                # Clear reference to the original tensor\n                batch.past_key_values[j][1] = None\n\n                # Slicing end index for this batch\n                end_index = start_index + len(batch)\n                # We slice the past values to remove the padding from previous batches\n                past_seq_len = batch.max_input_length - 1\n                
padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (\n                    past_values[:, :, -past_seq_len:, :]\n                )\n                del past_values\n\n                # Update values\n                start_index = end_index\n\n            past_key_values.append([padded_past_keys, padded_past_values])\n\n        return cls(\n            batch_id=batches[0].batch_id,\n            requests=requests,\n            requests_idx_mapping=requests_idx_mapping,\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            all_input_ids=all_input_ids,\n            input_lengths=input_lengths,\n            prefix_offsets=prefix_offsets,\n            read_offsets=read_offsets,\n            next_token_choosers=next_token_choosers,\n            stopping_criterias=stopping_criterias,\n            top_n_tokens=top_n_tokens,\n            top_n_tokens_tensor=top_n_tokens_tensor,\n            max_input_length=max_input_length,\n            padding_right_offset=padding_right_offset,\n            keys_head_dim_last=batches[0].keys_head_dim_last,\n            max_tokens=max_tokens,\n        )\n\n    def __len__(self):\n        return len(self.requests)\n\n\n@dataclass\nclass CausalLMBatchKeysLast(CausalLMBatch):\n    keys_head_dim_last: bool = False\n\n\nclass CausalLM(Model):\n    def __init__(\n        self,\n        model_id: str,\n        model_class,\n        revision: Optional[str] = None,\n        quantize: Optional[str] = None,\n        speculator: Optional[str] = None,\n        dtype: Optional[torch.dtype] = None,\n        default_dtype=torch.float16,\n        trust_remote_code: bool = False,\n        tokenizer_class=AutoTokenizer,\n        config_class=AutoConfig,\n        batch_class=CausalLMBatch,\n    ):\n        self.quantize = quantize\n        self.batch_class = batch_class\n        self.process_group, rank, world_size = 
initialize_torch_distributed()\n        if torch.cuda.is_available():\n            device = torch.device(f\"cuda:{rank}\")\n            dtype = default_dtype if dtype is None else dtype\n        elif hasattr(torch, \"xpu\") and torch.xpu.is_available():\n            device = torch.device(f\"xpu:{rank}\")\n            dtype = default_dtype if dtype is None else dtype\n        elif SYSTEM == \"ipex\":\n            device = torch.device(\"cpu\")\n            # Float16 doesn't exist on target.\n            dtype = torch.bfloat16 if dtype is None else dtype\n        else:\n            device = torch.device(\"cpu\")\n            dtype = torch.float32 if dtype is None else dtype\n\n        tokenizer = tokenizer_class.from_pretrained(\n            model_id,\n            revision=revision,\n            padding_side=\"left\",\n            truncation_side=\"left\",\n            trust_remote_code=trust_remote_code,\n        )\n\n        config = config_class.from_pretrained(\n            model_id,\n            revision=revision,\n            trust_remote_code=trust_remote_code,\n        )\n        config.quantize = quantize\n        config.speculator = speculator\n        if tokenizer.pad_token_id is None:\n            if config.pad_token_id is not None:\n                tokenizer.pad_token_id = config.pad_token_id\n            elif config.eos_token_id is not None:\n                tokenizer.pad_token_id = config.eos_token_id\n            elif tokenizer.eos_token_id is not None:\n                tokenizer.pad_token_id = tokenizer.eos_token_id\n\n        torch.distributed.barrier(group=self.process_group)\n        weights_loader = get_loader(\n            quantize=quantize, model_id=model_id, revision=revision\n        )\n        filenames = weight_files(model_id, revision=revision, extension=\".safetensors\")\n        weights = Weights(\n            filenames,\n            device=device,\n            dtype=dtype,\n            process_group=self.process_group,\n            
weights_loader=weights_loader,\n        )\n\n        prefix = \"\"\n        model = model_class(prefix, config, weights)\n\n        torch.distributed.barrier(group=self.process_group)\n        super().__init__(\n            model_id=model_id,\n            model=model,\n            tokenizer=tokenizer,\n            requires_padding=True,\n            dtype=dtype,\n            device=device,\n            rank=rank,\n            world_size=world_size,\n        )\n\n    @classmethod\n    def fallback(\n        cls,\n        model_id: str,\n        revision: Optional[str] = None,\n        quantize: Optional[str] = None,\n        speculator: Optional[str] = None,\n        dtype: Optional[torch.dtype] = None,\n        trust_remote_code: bool = False,\n    ):\n        if speculator:\n            raise RuntimeError(\"Speculator decoding is not enabled for AutoModel\")\n\n        device_count = 0\n        if torch.cuda.is_available():\n            device = torch.device(\"cuda\")\n            device_count = torch.cuda.device_count()\n            dtype = torch.float16 if dtype is None else dtype\n        elif hasattr(torch, \"xpu\") and torch.xpu.is_available():\n            device = torch.device(\"xpu\")\n            device_count = torch.xpu.device_count()\n            dtype = torch.float16 if dtype is None else dtype\n        else:\n            if quantize:\n                raise ValueError(\"quantization is not available on CPU\")\n\n            device = torch.device(\"cpu\")\n            dtype = torch.float32 if dtype is None else dtype\n\n        tokenizer = AutoTokenizer.from_pretrained(\n            model_id,\n            revision=revision,\n            padding_side=\"left\",\n            truncation_side=\"left\",\n            trust_remote_code=trust_remote_code,\n        )\n        model = AutoModelForCausalLM.from_pretrained(\n            model_id,\n            revision=revision,\n            torch_dtype=dtype,\n            device_map=(\"auto\" if device_count > 1 
else None),\n            load_in_8bit=quantize == \"bitsandbytes\",\n            trust_remote_code=trust_remote_code,\n        )\n        if device_count == 1 and quantize != \"bitsandbytes\":\n            model = model.to(device)\n\n        if tokenizer.pad_token_id is None:\n            if model.config.pad_token_id is not None:\n                tokenizer.pad_token_id = model.config.pad_token_id\n            elif model.config.eos_token_id is not None and isinstance(\n                model.config.eos_token_id, int\n            ):\n                tokenizer.pad_token_id = model.config.eos_token_id\n            elif tokenizer.eos_token_id is not None:\n                tokenizer.pad_token_id = tokenizer.eos_token_id\n            else:\n                tokenizer.add_special_tokens({\"pad_token\": \"[PAD]\"})\n\n        self = cls.__new__(\n            cls,\n        )\n        self.batch_class = CausalLMBatch\n        super().__init__(\n            self,\n            model_id=model_id,\n            model=model,\n            tokenizer=tokenizer,\n            requires_padding=True,\n            dtype=dtype,\n            device=device,\n        )\n        self.quantize = quantize\n        return self\n\n    @property\n    def batch_type(self) -> Type[CausalLMBatch]:\n        return self.batch_class\n\n    def forward(\n        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None\n    ) -> Tuple[\n        torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]\n    ]:\n        # Model Forward\n        kwargs = {\n            \"input_ids\": input_ids,\n            \"attention_mask\": attention_mask,\n            \"past_key_values\": past_key_values,\n            \"use_cache\": True,\n            \"return_dict\": True,\n        }\n        if self.has_position_ids:\n            kwargs[\"position_ids\"] = position_ids\n\n        outputs = self.model.forward(**kwargs)\n        if isinstance(outputs, tuple):\n            outputs, 
speculative_logits = outputs\n        else:\n            speculative_logits = None\n        return outputs.logits, speculative_logits, outputs.past_key_values\n\n    @tracer.start_as_current_span(\"generate_token\")\n    def generate_token(\n        self, batch: CausalLMBatch\n    ) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]:\n        start = time.time_ns()\n        # slice the attention mask to the correct shape\n        attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]\n\n        logits, speculative_logits, past = self.forward(\n            batch.input_ids,\n            attention_mask,\n            batch.position_ids,\n            batch.past_key_values,\n        )\n\n        # Results\n        generations: List[Generation] = []\n        stopped = True\n\n        # Speculation is not active for causal\n        accepted_ids = torch.ones_like(batch.input_ids)[:, 0]\n        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(\n            batch.top_n_tokens,\n            batch.top_n_tokens_tensor,\n            torch.log_softmax(logits[:, -1], -1),\n            accepted_ids,\n        )\n\n        start_decode = time.time_ns()\n\n        # Zipped iterator\n        iterator = zip(\n            batch.requests,\n            batch.input_lengths,\n            batch.prefix_offsets,\n            batch.read_offsets,\n            logits,\n            batch.next_token_choosers,\n            batch.stopping_criterias,\n            batch.all_input_ids,\n            batch.top_n_tokens,\n            batch_top_token_ids,\n            batch_top_token_logprobs,\n        )\n\n        # For each member of the batch\n        for i, (\n            request,\n            input_length,\n            prefix_offset,\n            read_offset,\n            logits,\n            next_token_chooser,\n            stopping_criteria,\n            all_input_ids,\n            top_n_tokens,\n            top_token_ids,\n            
top_token_logprobs,\n        ) in enumerate(iterator):\n            # Select next token\n            next_token_id, logprobs = next_token_chooser(\n                all_input_ids.view(1, -1), logits[-1:, :]\n            )\n\n            # Append next token to all tokens\n            all_input_ids = torch.cat([all_input_ids, next_token_id])\n            new_input_length = input_length + 1\n\n            # Generated token\n            next_token_logprob = logprobs[-1, next_token_id]\n            next_token_id_squeezed = next_token_id.squeeze()\n            next_token_text, prefix_offset, read_offset = self.decode_token(\n                all_input_ids[:, 0], prefix_offset, read_offset\n            )\n\n            # Evaluate stopping criteria\n            stop, reason = stopping_criteria(\n                next_token_id_squeezed,\n                next_token_text,\n            )\n\n            if not stop:\n                stopped = False\n\n            # Shard generations\n            # All generations will be appended in the rust sharded client\n            if i % self.world_size == self.rank:\n                if stop:\n                    # Decode generated tokens\n                    output_text, _, _ = self.decode_token(\n                        all_input_ids[:, 0],\n                        prefix_offset=len(all_input_ids)\n                        - stopping_criteria.current_tokens\n                        - 1,\n                        read_offset=len(all_input_ids)\n                        - stopping_criteria.current_tokens,\n                        skip_special_tokens=True,\n                    )\n                    # Get seed\n                    if isinstance(next_token_chooser.choice, Sampling):\n                        seed = next_token_chooser.choice.seed\n                    else:\n                        seed = None\n\n                    generated_text = GeneratedText(\n                        output_text, stopping_criteria.current_tokens, reason, seed\n  
                  )\n                else:\n                    generated_text = None\n\n                # Prefill\n                if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:\n                    # Remove generated token to only have prefill and add nan for first prompt token\n                    prefill_logprobs = [float(\"nan\")] + torch.log_softmax(\n                        logits, -1\n                    ).gather(1, all_input_ids[1:]).squeeze(1)[\n                        -new_input_length:-1\n                    ].tolist()\n                    prefill_token_ids = all_input_ids[-new_input_length:-1]\n                    prefill_texts = self.tokenizer.batch_decode(\n                        prefill_token_ids,\n                        clean_up_tokenization_spaces=False,\n                        skip_special_tokens=False,\n                    )\n                    prefill_tokens = Tokens(\n                        prefill_token_ids,\n                        prefill_logprobs,\n                        prefill_texts,\n                        is_special=[],\n                    )\n                else:\n                    prefill_tokens = None\n\n                if top_n_tokens > 0:\n                    all_top_tokens = []\n                    for top_token_ids, top_token_logprobs in zip(\n                        top_token_ids, top_token_logprobs\n                    ):\n                        toptoken_texts = self.tokenizer.batch_decode(\n                            top_token_ids,\n                            clean_up_tokenization_spaces=False,\n                            skip_special_tokens=False,\n                        )\n                        special_toptokens = [\n                            token_id in self.all_special_ids\n                            for token_id in top_token_ids\n                        ]\n                        top_tokens = Tokens(\n                            top_token_ids,\n                            
top_token_logprobs,\n                            toptoken_texts,\n                            special_toptokens,\n                        )\n                        all_top_tokens.append(top_tokens)\n                    top_tokens = all_top_tokens\n                else:\n                    top_tokens = None\n\n                generation = Generation(\n                    request.id,\n                    prefill_tokens,\n                    Tokens(\n                        [next_token_id_squeezed],\n                        [next_token_logprob],\n                        [next_token_text],\n                        [next_token_id_squeezed.item() in self.all_special_ids],\n                    ),\n                    generated_text,\n                    top_tokens,\n                )\n\n                generations.append(generation)\n\n            # Update values\n            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(\n                next_token_id_squeezed.item()\n            )\n            batch.input_ids[i, 0] = next_token_id\n            batch.all_input_ids[i] = all_input_ids\n            batch.input_lengths[i] = new_input_length\n            batch.prefix_offsets[i] = prefix_offset\n            batch.read_offsets[i] = read_offset\n            batch.max_input_length = max(batch.max_input_length, new_input_length)\n\n        # We finished all generations in the batch; there is no next batch\n        if stopped:\n            forward_ns = start_decode - start\n            decode_ns = time.time_ns() - start_decode\n            return generations, None, (forward_ns, decode_ns)\n\n        # Slice unused values from prefill\n        batch.input_ids = batch.input_ids[:, :1]\n\n        # Update attention_mask as we added a new token to input_ids\n        batch.attention_mask[:, -batch.padding_right_offset] = 1\n        # Decrease right offset\n        batch.padding_right_offset -= 1\n\n        # Update position_ids\n        batch.position_ids 
= batch.position_ids[:, -1:] + 1\n\n        # Update past key values\n        batch.past_key_values = past\n\n        forward_ns = start_decode - start\n        decode_ns = time.time_ns() - start_decode\n        return generations, batch, (forward_ns, decode_ns)\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/__init__.py",
    "content": ""
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/bloom_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2022 HuggingFace Inc. team and BigScience workshop.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch BLOOM model.\"\"\"\n\nimport math\nimport os\nimport warnings\nfrom typing import Optional, Tuple, Union\n\nimport torch\nimport torch.distributed\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom torch.nn import LayerNorm\nfrom torch.nn import functional as F\n\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPastAndCrossAttentions,\n    CausalLMOutputWithCrossAttentions,\n)\nfrom transformers import BloomConfig, PreTrainedModel\n\nfrom text_generation_server.layers import (\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    TensorParallelRowLinear,\n    SpeculativeHead,\n)\n\nCUSTOM_KERNELS_ENABLED = False\nif (\n    torch.cuda.is_available()\n    and not os.environ.get(\"DISABLE_CUSTOM_KERNELS\", \"False\") == \"True\"\n):\n    try:\n        from custom_kernels import fused_bloom_attention_cuda\n\n        CUSTOM_KERNELS_ENABLED = True\n    except ImportError:\n        pass\n\n_CHECKPOINT_FOR_DOC = \"bigscience/bloom-560m\"\n_CONFIG_FOR_DOC = \"BloomConfig\"\n\nBLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = [\n    \"bigscience/bigscience-small-testing\",\n    \"bigscience/bloom-560m\",\n    \"bigscience/bloom-1b1\",\n    \"bigscience/bloom-1b7\",\n    \"bigscience/bloom-3b\",\n    \"bigscience/bloom-7b1\",\n    \"bigscience/bloom\",\n]\n\n\ndef _make_causal_mask(\n   
 input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int\n) -> torch.BoolTensor:\n    \"\"\"\n    Make causal mask used for self-attention.\n    \"\"\"\n    batch_size, target_length = input_ids_shape\n    mask = torch.ones(\n        (target_length, target_length + past_key_values_length),\n        dtype=torch.bool,\n        device=device,\n    )\n    mask = mask.triu(1 + past_key_values_length)\n\n    expanded_mask = mask.unsqueeze(0).expand(\n        batch_size, target_length, target_length + past_key_values_length\n    )\n    return expanded_mask\n\n\ndef _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:\n    \"\"\"\n    Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.\n    \"\"\"\n    batch_size, src_length = mask.shape\n    tgt_length = tgt_length if tgt_length is not None else src_length\n\n    expanded_mask = ~(mask[:, None, :].to(torch.bool))\n    return expanded_mask.expand(batch_size, tgt_length, src_length)\n\n\ndef build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int) -> torch.Tensor:\n    \"\"\"\n    Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it\n    relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value\n    `softmax(l+a) = softmax(l)`. 
Based on\n    https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742\n    TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.\n\n    Args:\n    Returns tensor shaped (batch_size * num_heads, 1, max_seq_len)\n        attention_mask (`torch.Tensor`):\n            Token-wise attention mask, this should be of shape (batch_size, max_seq_len).\n        num_heads (`int`, *required*):\n            number of heads\n        dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):\n            dtype of the output tensor\n    \"\"\"\n    batch_size, seq_length = attention_mask.shape\n    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))\n    base = torch.tensor(\n        2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),\n        device=attention_mask.device,\n        dtype=torch.float32,\n    )\n    powers = torch.arange(\n        1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32\n    )\n    slopes = torch.pow(base, powers)\n\n    if closest_power_of_2 != num_heads:\n        extra_base = torch.tensor(\n            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),\n            device=attention_mask.device,\n            dtype=torch.float32,\n        )\n        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)\n        extra_powers = torch.arange(\n            1,\n            1 + 2 * num_remaining_heads,\n            2,\n            device=attention_mask.device,\n            dtype=torch.int32,\n        )\n        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)\n\n    # Note: alibi will be added to the attention bias that will be applied to the query, key product of attention\n    # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)\n    # => here we set (batch_size=1, num_heads=num_heads, query_length=1, 
key_length=max_length)\n    # => the query_length dimension will then be broadcasted correctly\n    # This is more or less identical to T5's relative position bias:\n    # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527\n    arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]\n    alibi = slopes[..., None] * arange_tensor\n    return alibi\n\n\n# @torch.jit.script\ndef dropout_add(\n    x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool\n) -> torch.Tensor:\n    \"\"\"\n    Dropout add function\n\n    Args:\n        x (`torch.tensor`, *required*):\n            input tensor\n        residual (`torch.tensor`, *required*):\n            residual tensor\n        prob (`float`, *required*):\n            dropout probability\n        training (`bool`, *required*):\n            training mode\n    \"\"\"\n    out = F.dropout(x, p=prob, training=training)\n    out = residual + out\n    return out\n\n\n# @torch.jit.script # this is shit for unknown reasons.\ndef _split_heads(\n    fused_qkv: torch.Tensor, num_heads: int, head_dim: int\n) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory\n    storage as `fused_qkv`\n\n    Args:\n        fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]\n\n    Returns:\n        query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]\n        value: [batch_size, seq_length, num_heads, head_dim]\n    \"\"\"\n    batch_size, seq_length, three_times_hidden_size = fused_qkv.shape\n    fused_qkv = fused_qkv.view(batch_size, seq_length, num_heads, 3 * head_dim)\n    query_layer, key_layer, value_layer = fused_qkv.split(head_dim, dim=-1)\n\n    query_layer = query_layer.transpose(1, 2).reshape(\n        
batch_size * num_heads, seq_length, head_dim\n    )\n    key_layer = key_layer.permute(0, 2, 3, 1).reshape(\n        batch_size * num_heads, head_dim, seq_length\n    )\n    value_layer = value_layer.transpose(1, 2).reshape(\n        batch_size * num_heads, seq_length, head_dim\n    )\n\n    return query_layer, key_layer, value_layer\n\n\n# @torch.jit.script\ndef _merge_heads(x: torch.Tensor, num_heads: int, head_dim: int) -> torch.Tensor:\n    \"\"\"\n    Merge heads together over the last dimension\n\n    Args:\n        x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]\n\n    Returns:\n        torch.tensor: [batch_size, seq_length, num_heads * head_dim]\n    \"\"\"\n    # What we want to achieve is:\n    # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim\n    batch_size_and_num_heads, seq_length, _ = x.shape\n    batch_size = batch_size_and_num_heads // num_heads\n\n    # First view to decompose the batch size\n    # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim\n    x = x.view(batch_size, num_heads, seq_length, head_dim)\n\n    # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim\n    x = x.permute(0, 2, 1, 3)\n\n    # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim\n    return x.reshape(batch_size, seq_length, num_heads * head_dim)\n\n\nclass BloomAttention(nn.Module):\n    def __init__(self, prefix, config: BloomConfig, weights):\n        super().__init__()\n\n        self.pretraining_tp = config.pretraining_tp\n        self.slow_but_exact = config.slow_but_exact\n\n        self.process_group = weights.process_group\n\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.n_head\n        self.head_dim = self.hidden_size // self.num_heads\n        self.split_size = self.hidden_size\n        self.hidden_dropout = 
config.hidden_dropout\n\n        if self.head_dim * self.num_heads != self.hidden_size:\n            raise ValueError(\n                f\"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:\"\n                f\" {self.num_heads}).\"\n            )\n\n        # Layer-wise attention scaling\n        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)\n        self.beta = 1.0\n\n        process_group = weights.process_group\n        if self.num_heads % process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // process_group.size()\n        self.query_key_value = TensorParallelColumnLinear.load(\n            config=config,\n            prefix=f\"{prefix}.query_key_value\",\n            weights=weights,\n            bias=True,\n        )\n        self.dense = TensorParallelRowLinear.load(\n            config=config, prefix=f\"{prefix}.dense\", weights=weights, bias=True\n        )\n        self.attention_dropout = nn.Dropout(config.attention_dropout)\n\n    @staticmethod\n    def compute_attention(\n        fused_qkv: torch.Tensor,\n        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]],\n        alibi: torch.Tensor,\n        attention_mask: torch.Tensor,\n        head_mask: Optional[torch.Tensor],\n        beta: float,\n        inv_norm_factor: float,\n        num_heads: int,\n        use_cache: bool,\n    ):\n        batch_size, q_length, three_times_hidden_size = fused_qkv.shape\n        head_dim = three_times_hidden_size // (3 * num_heads)\n        batch_size * num_heads\n\n        ### TODO @thomasw21: this takes quite a bit of time, how do I accelerate that?\n        # 3 x [batch_size, seq_length, num_heads, head_dim]\n        (query_layer, key_layer, value_layer) = _split_heads(\n   
         fused_qkv, num_heads=num_heads, head_dim=head_dim\n        )\n\n        if layer_past is not None:\n            past_key, past_value = layer_past\n            # concatenate along seq_length dimension:\n            #  - key: [batch_size * self.num_heads, head_dim, kv_length]\n            #  - value: [batch_size * self.num_heads, kv_length, head_dim]\n            past_key = past_key.view(-1, *past_key.shape[-2:])\n            key_layer = torch.cat((past_key, key_layer), dim=2)\n            past_value = past_value.view(-1, *past_value.shape[-2:])\n            value_layer = torch.cat((past_value, value_layer), dim=1)\n\n        _, _, kv_length = key_layer.shape\n\n        if use_cache is True:\n            present = (key_layer, value_layer)\n        else:\n            present = None\n        ###\n\n        # [batch_size * num_heads, q_length, kv_length]\n        # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11\n        attention_scores = alibi.baddbmm(\n            batch1=query_layer,\n            batch2=key_layer,\n            beta=beta,\n            alpha=inv_norm_factor,\n        )\n\n        # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]\n        input_dtype = attention_scores.dtype\n        # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`\n        if input_dtype == torch.float16:\n            attention_scores = attention_scores.to(torch.float)\n        # torch.finfo not supported by torch.jit, we temporarily remplace with `-1e34`\n        attn_weights = attention_scores.masked_fill_(\n            attention_mask, torch.finfo(attention_scores.dtype).min\n        )\n        attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(\n            input_dtype\n        )\n\n        # # [batch_size, num_heads, q_length, kv_length]\n        
# attention_probs = self.attention_dropout(attention_probs)\n\n        if head_mask is not None:\n            attention_probs = attention_probs * head_mask\n\n        # matmul: [batch_size * num_heads, q_length, head_dim]\n        context_layer = torch.bmm(attention_probs, value_layer, out=query_layer)\n\n        # change view [batch_size, num_heads, q_length, head_dim]\n        context_layer = _merge_heads(\n            context_layer, num_heads=num_heads, head_dim=head_dim\n        )\n\n        return context_layer, present, attention_probs\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        residual: torch.Tensor,\n        alibi: torch.Tensor,\n        attention_mask: torch.Tensor,\n        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        use_cache: bool = False,\n        output_attentions: bool = False,\n    ):\n        fused_qkv = self.query_key_value(\n            hidden_states\n        )  # [batch_size, seq_length, 3 x hidden_size]\n        batch_size, q_length, _ = fused_qkv.shape\n\n        if layer_past is not None:\n            past_key, past_value = layer_past\n            layer_past = (\n                past_key.view(-1, *past_key.shape[-2:]),\n                past_value.view(-1, *past_value.shape[-2:]),\n            )\n\n        if CUSTOM_KERNELS_ENABLED and attention_mask.shape[-1] < 4096:\n            assert self.training is False, \"Only foward pass was implemented\"\n            assert (\n                attention_mask.shape[-1] < 4096\n            ), \"Custom kernel support only up to 4096 tokens\"\n            (\n                context_layer,\n                present,\n                attention_probs,\n            ) = fused_bloom_attention_cuda.forward(\n                fused_qkv,\n                layer_past,\n                alibi,\n                attention_mask,\n                head_mask,\n                self.beta,\n                
self.inv_norm_factor,\n                self.num_heads,\n                use_cache,\n            )\n        else:\n            context_layer, present, attention_probs = self.compute_attention(\n                fused_qkv=fused_qkv,\n                layer_past=layer_past,\n                alibi=alibi,\n                attention_mask=attention_mask,\n                head_mask=head_mask,\n                beta=self.beta,\n                inv_norm_factor=self.inv_norm_factor,\n                num_heads=self.num_heads,\n                use_cache=use_cache,\n            )\n\n        # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232\n        if self.pretraining_tp > 1 and self.slow_but_exact:\n            slices = self.hidden_size / self.pretraining_tp\n            output_tensor = torch.zeros_like(context_layer)\n            for i in range(self.pretraining_tp):\n                output_tensor = output_tensor + F.linear(\n                    context_layer[:, :, int(i * slices) : int((i + 1) * slices)],\n                    self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],\n                )\n        else:\n            output_tensor = self.dense(context_layer)\n\n        # output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)\n        output_tensor += residual\n\n        outputs = (output_tensor, present)\n        if output_attentions:\n            outputs += (attention_probs,)\n\n        return outputs\n\n\nclass BloomMLP(nn.Module):\n    def __init__(self, prefix, config: BloomConfig, weights):\n        super().__init__()\n\n        self.pretraining_tp = config.pretraining_tp\n        self.slow_but_exact = config.slow_but_exact\n        self.dense_h_to_4h = TensorParallelColumnLinear.load(\n            config=config, prefix=f\"{prefix}.dense_h_to_4h\", weights=weights, bias=True\n        )\n        self.dense_4h_to_h = TensorParallelRowLinear.load(\n            config=config, 
prefix=f\"{prefix}.dense_4h_to_h\", weights=weights, bias=True\n        )\n        self.gelu_impl = torch.nn.GELU(approximate=\"tanh\")\n        self.hidden_dropout = config.hidden_dropout\n\n    def forward(\n        self, hidden_states: torch.Tensor, residual: torch.Tensor\n    ) -> torch.Tensor:\n        hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))\n\n        if self.pretraining_tp > 1 and self.slow_but_exact:\n            intermediate_output = torch.zeros_like(residual)\n            slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp\n            for i in range(self.pretraining_tp):\n                intermediate_output = intermediate_output + F.linear(\n                    hidden_states[:, :, int(i * slices) : int((i + 1) * slices)],\n                    self.dense_4h_to_h.weight[\n                        :, int(i * slices) : int((i + 1) * slices)\n                    ],\n                )\n        else:\n            intermediate_output = self.dense_4h_to_h(hidden_states)\n\n        # output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)\n        intermediate_output += residual\n\n        return intermediate_output\n\n\nclass BloomBlock(nn.Module):\n    def __init__(self, layer_id: int, config: BloomConfig, weights):\n        super().__init__()\n\n        prefix = f\"h.{layer_id}\"\n        self.input_layernorm = LayerNorm.load(\n            prefix=f\"{prefix}.input_layernorm\",\n            weights=weights,\n            eps=config.layer_norm_epsilon,\n        )\n        self.num_heads = config.n_head\n        self.self_attention = BloomAttention(\n            prefix=f\"{prefix}.self_attention\", config=config, weights=weights\n        )\n        self.post_attention_layernorm = LayerNorm.load(\n            prefix=f\"{prefix}.post_attention_layernorm\",\n            weights=weights,\n            eps=config.layer_norm_epsilon,\n        )\n\n        self.mlp = BloomMLP(prefix=f\"{prefix}.mlp\", 
config=config, weights=weights)\n        self.apply_residual_connection_post_layernorm = (\n            config.apply_residual_connection_post_layernorm\n        )\n        self.hidden_dropout = config.hidden_dropout\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        alibi: torch.Tensor,\n        attention_mask: torch.Tensor,\n        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        use_cache: bool = False,\n        output_attentions: bool = False,\n    ):\n        # hidden_states: [batch_size, seq_length, hidden_size]\n\n        # Layer norm at the beginning of the transformer layer.\n        layernorm_output = self.input_layernorm(hidden_states)\n\n        # Layer norm post the self attention.\n        if self.apply_residual_connection_post_layernorm:\n            residual = layernorm_output\n        else:\n            residual = hidden_states\n\n        # Self attention.\n        attn_outputs = self.self_attention(\n            layernorm_output,\n            residual,\n            layer_past=layer_past,\n            attention_mask=attention_mask,\n            alibi=alibi,\n            head_mask=head_mask,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n        )\n\n        attention_output = attn_outputs[0]\n\n        outputs = attn_outputs[1:]\n\n        layernorm_output = self.post_attention_layernorm(attention_output)\n\n        # Get residual\n        if self.apply_residual_connection_post_layernorm:\n            residual = layernorm_output\n        else:\n            residual = attention_output\n\n        # MLP.\n        output = self.mlp(layernorm_output, residual)\n\n        if use_cache:\n            outputs = (output,) + outputs\n        else:\n            outputs = (output,) + outputs[1:]\n\n        return outputs  # hidden_states, present, attentions\n\n\nclass BloomPreTrainedModel(PreTrainedModel):\n    
config_class = BloomConfig\n    base_model_prefix = \"transformer\"\n    _no_split_modules = [\"BloomBlock\"]\n\n    @staticmethod\n    def _convert_to_standard_cache(\n        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int\n    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:\n        \"\"\"\n        Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,\n        num_heads, ...]))\n        \"\"\"\n        batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape\n        num_heads = batch_size_times_num_heads // batch_size\n        # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]\n        # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]\n        return tuple(\n            (\n                layer_past[0].view(batch_size, num_heads, head_dim, seq_length),\n                layer_past[1].view(batch_size, num_heads, seq_length, head_dim),\n            )\n            for layer_past in past_key_value\n        )\n\n    @staticmethod\n    def _convert_to_bloom_cache(\n        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]],\n    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:\n        \"\"\"\n        Converts the cache to the format expected by Bloom, i.e. 
to tuple(tuple([batch_size * num_heads, ...]))\n        \"\"\"\n        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape\n        batch_size_times_num_heads = batch_size * num_heads\n        # key:  [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]\n        # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]\n        return tuple(\n            (\n                layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),\n                layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),\n            )\n            for layer_past in past_key_value\n        )\n\n\nclass BloomModel(BloomPreTrainedModel):\n    def __init__(self, config: BloomConfig, weights):\n        super().__init__(config)\n\n        self.embed_dim = config.hidden_size\n        self.num_heads = config.n_head\n\n        process_group = weights.process_group\n        self.tp_rank = process_group.rank()\n        self.tp_world_size = process_group.size()\n\n        self.word_embeddings = TensorParallelEmbedding(\n            prefix=\"word_embeddings\", weights=weights\n        )\n\n        self.word_embeddings_layernorm = LayerNorm.load(\n            prefix=\"word_embeddings_layernorm\",\n            weights=weights,\n            eps=config.layer_norm_epsilon,\n        )\n\n        # Transformer blocks\n        self.h = nn.ModuleList(\n            [\n                BloomBlock(layer_id=layer_id, config=config, weights=weights)\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n\n        # Final Layer Norm\n        self.ln_f = LayerNorm.load(\n            prefix=\"ln_f\", weights=weights, eps=config.layer_norm_epsilon\n        )\n\n    def _prepare_attn_mask(\n        self,\n        attention_mask: torch.Tensor,\n        input_shape: Tuple[int, int],\n        past_key_values_length: int,\n    ) -> 
torch.BoolTensor:\n        # create causal mask\n        # [batch_size, seq_length] -> [batch_size, tgt_length, src_length]\n        combined_attention_mask = None\n        device = attention_mask.device\n        _, src_length = input_shape\n\n        if src_length > 1:\n            combined_attention_mask = _make_causal_mask(\n                input_shape,\n                device=device,\n                past_key_values_length=past_key_values_length,\n            )\n\n        # [batch_size, seq_length] -> [batch_size, tgt_length, src_length]\n        expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)\n        combined_attention_mask = (\n            expanded_attn_mask\n            if combined_attention_mask is None\n            else expanded_attn_mask | combined_attention_mask\n        )\n\n        return combined_attention_mask\n\n    def set_input_embeddings(self, new_embeddings: torch.Tensor):\n        self.word_embeddings = new_embeddings\n\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        head_mask: Optional[torch.LongTensor] = None,\n        inputs_embeds: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        **deprecated_arguments,\n    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:\n        if deprecated_arguments.pop(\"position_ids\", False) is not False:\n            # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`\n            warnings.warn(\n                \"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. 
You can safely ignore\"\n                \" passing `position_ids`.\",\n                FutureWarning,\n            )\n        if len(deprecated_arguments) > 0:\n            raise ValueError(f\"Got unexpected arguments: {deprecated_arguments}\")\n\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\n                \"You cannot specify both input_ids and inputs_embeds at the same time\"\n            )\n        elif input_ids is not None:\n            batch_size, seq_length = input_ids.shape\n        elif inputs_embeds is not None:\n            batch_size, seq_length, _ = inputs_embeds.shape\n        else:\n            raise ValueError(\"You have to specify either input_ids or inputs_embeds\")\n\n        if past_key_values is None:\n            past_key_values = tuple([None] * len(self.h))\n\n        # Prepare head mask if needed\n        # 1.0 in head_mask indicate we keep the head\n        # attention_probs has shape batch_size x num_heads x N x N\n        # head_mask has shape n_layer x batch x num_heads x N x N\n        head_mask = self.get_head_mask(head_mask, self.config.n_layer)\n\n        if inputs_embeds is None:\n            inputs_embeds = self.word_embeddings(input_ids)\n\n        hidden_states = self.word_embeddings_layernorm(inputs_embeds)\n\n        presents = () if use_cache else None\n        all_self_attentions = () if output_attentions else None\n        
all_hidden_states = () if output_hidden_states else None\n\n        # Compute alibi tensor: check build_alibi_tensor documentation\n        seq_length_with_past = seq_length\n        past_key_values_length = 0\n        if past_key_values[0] is not None:\n            past_key_values_length = past_key_values[0][0].shape[-1]\n            seq_length_with_past = seq_length_with_past + past_key_values_length\n        if attention_mask is None:\n            attention_mask = torch.ones(\n                (batch_size, seq_length_with_past), device=hidden_states.device\n            )\n        else:\n            attention_mask = attention_mask.to(hidden_states.device)\n\n        alibi = build_alibi_tensor(attention_mask, self.num_heads)\n\n        causal_mask = self._prepare_attn_mask(\n            attention_mask,\n            input_shape=(batch_size, seq_length),\n            past_key_values_length=past_key_values_length,\n        )\n\n        if hasattr(self, \"tp_rank\"):\n            assert self.num_heads % self.tp_world_size == 0\n            block_size = self.num_heads // self.tp_world_size\n            alibi = alibi[\n                :, self.tp_rank * block_size : (self.tp_rank + 1) * block_size\n            ]\n            alibi = alibi.reshape(batch_size * block_size, 1, seq_length_with_past)\n            causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0)\n        else:\n            alibi = alibi.reshape(batch_size * self.num_heads, 1, seq_length_with_past)\n            causal_mask = torch.repeat_interleave(causal_mask, self.num_heads, dim=0)\n\n        alibi = alibi.to(hidden_states.dtype)\n\n        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):\n            if output_hidden_states:\n                all_hidden_states = all_hidden_states + (hidden_states,)\n\n            outputs = block(\n                hidden_states,\n                layer_past=layer_past,\n                attention_mask=causal_mask,\n                
head_mask=head_mask[i],\n                use_cache=use_cache,\n                output_attentions=output_attentions,\n                alibi=alibi,\n            )\n\n            hidden_states = outputs[0]\n            if use_cache is True:\n                presents = presents + (outputs[1],)\n\n            if output_attentions:\n                all_self_attentions = all_self_attentions + (\n                    outputs[2 if use_cache else 1],\n                )\n\n        # Add last hidden state\n        hidden_states = self.ln_f(hidden_states)\n\n        if output_hidden_states:\n            all_hidden_states = all_hidden_states + (hidden_states,)\n\n        if not return_dict:\n            return tuple(\n                v\n                for v in [\n                    hidden_states,\n                    presents,\n                    all_hidden_states,\n                    all_self_attentions,\n                ]\n                if v is not None\n            )\n\n        return BaseModelOutputWithPastAndCrossAttentions(\n            last_hidden_state=hidden_states,\n            past_key_values=presents,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attentions,\n        )\n\n\nclass BloomForCausalLM(BloomPreTrainedModel):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__(config)\n        self.transformer = BloomModel(config, weights)\n\n        self.lm_head = SpeculativeHead.load(\n            config,\n            prefix=\"word_embeddings\",\n            weights=weights,\n        )\n\n    def prepare_inputs_for_generation(\n        self,\n        input_ids: torch.LongTensor,\n        past_key_values: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        inputs_embeds: Optional[torch.Tensor] = None,\n        **kwargs,\n    ) -> dict:\n        # only last token for input_ids if past is not None\n        if past_key_values:\n            input_ids = input_ids[:, 
-1].unsqueeze(-1)\n\n            # the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed\n            if past_key_values[0][0].shape[0] == input_ids.shape[0]:\n                past_key_values = self._convert_to_bloom_cache(past_key_values)\n\n        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step\n        if inputs_embeds is not None and past_key_values is None:\n            model_inputs = {\"inputs_embeds\": inputs_embeds}\n        else:\n            model_inputs = {\"input_ids\": input_ids}\n\n        model_inputs.update(\n            {\n                \"past_key_values\": past_key_values,\n                \"use_cache\": kwargs.get(\"use_cache\"),\n                \"attention_mask\": attention_mask,\n            }\n        )\n        return model_inputs\n\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        inputs_embeds: Optional[torch.Tensor] = None,\n        labels: Optional[torch.Tensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        **deprecated_arguments,\n    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set\n            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`\n            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`\n        \"\"\"\n        if deprecated_arguments.pop(\"position_ids\", False) is not False:\n            # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`\n            warnings.warn(\n                \"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore\"\n                \" passing `position_ids`.\",\n                FutureWarning,\n            )\n        if len(deprecated_arguments) > 0:\n            raise ValueError(f\"Got unexpected arguments: {deprecated_arguments}\")\n\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        transformer_outputs = self.transformer(\n            input_ids,\n            past_key_values=past_key_values,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = transformer_outputs[0]\n\n        logits, speculative_logits = self.lm_head(hidden_states)\n        loss = None\n\n        if not return_dict:\n            output = (logits,) + transformer_outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return (\n            CausalLMOutputWithCrossAttentions(\n                loss=loss,\n                logits=logits,\n                past_key_values=transformer_outputs.past_key_values,\n                hidden_states=transformer_outputs.hidden_states,\n                
attentions=transformer_outputs.attentions,\n            ),\n            speculative_logits,\n        )\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/clip.py",
    "content": "from typing import Optional, Tuple\n\nimport torch\nfrom torch import nn\n\nfrom transformers.activations import ACT2FN\nfrom transformers.modeling_attn_mask_utils import (\n    _create_4d_causal_attention_mask,\n    _prepare_4d_attention_mask,\n)\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPooling,\n)\nfrom transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig\n\nfrom text_generation_server.layers import (\n    TensorParallelEmbedding,\n    TensorParallelColumnLinear,\n    TensorParallelRowLinear,\n)\n\n\nclass CLIPVisionEmbeddings(nn.Module):\n    def __init__(self, prefix, config: CLIPVisionConfig, weights):\n        super().__init__()\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.image_size = config.image_size\n        self.patch_size = config.patch_size\n\n        # TODO Should we TP this ?\n        self.class_embedding = weights.get_tensor(f\"{prefix}.class_embedding\")\n\n        self.patch_embedding = nn.Conv2d(\n            in_channels=config.num_channels,\n            out_channels=self.embed_dim,\n            kernel_size=self.patch_size,\n            stride=self.patch_size,\n            bias=False,\n        )\n        self.patch_embedding.weight = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.patch_embedding.weight\"), requires_grad=False\n        )\n\n        self.num_patches = (self.image_size // self.patch_size) ** 2\n        self.num_positions = self.num_patches + 1\n        self.position_embedding = TensorParallelEmbedding(\n            prefix=f\"{prefix}.position_embedding\", weights=weights\n        )\n        self.register_buffer(\n            \"position_ids\",\n            torch.arange(self.num_positions, device=weights.device).expand((1, -1)),\n            persistent=False,\n        )\n\n    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:\n        batch_size = pixel_values.shape[0]\n        target_dtype = 
self.patch_embedding.weight.dtype\n        patch_embeds = self.patch_embedding(\n            pixel_values.to(dtype=target_dtype)\n        )  # shape = [*, width, grid, grid]\n        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)\n\n        class_embeds = self.class_embedding.expand(batch_size, 1, -1)\n        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)\n        embeddings = embeddings + self.position_embedding(self.position_ids)\n        return embeddings\n\n\nclass CLIPTextEmbeddings(nn.Module):\n    def __init__(self, config: CLIPTextConfig):\n        super().__init__()\n        embed_dim = config.hidden_size\n\n        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)\n        self.position_embedding = nn.Embedding(\n            config.max_position_embeddings, embed_dim\n        )\n\n        # position_ids (1, len position emb) is contiguous in memory and exported when serialized\n        self.register_buffer(\n            \"position_ids\",\n            torch.arange(config.max_position_embeddings).expand((1, -1)),\n            persistent=False,\n        )\n\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n    ) -> torch.Tensor:\n        seq_length = (\n            input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]\n        )\n\n        if position_ids is None:\n            position_ids = self.position_ids[:, :seq_length]\n\n        if inputs_embeds is None:\n            inputs_embeds = self.token_embedding(input_ids)\n\n        position_embeddings = self.position_embedding(position_ids)\n        embeddings = inputs_embeds + position_embeddings\n\n        return embeddings\n\n\nclass CLIPAttention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, prefix, config, weights):\n       
 super().__init__()\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_size = self.embed_dim // self.num_heads\n        if self.head_size * self.num_heads != self.embed_dim:\n            raise ValueError(\n                f\"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:\"\n                f\" {self.num_heads}).\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.embed_dim = self.embed_dim // weights.process_group.size()\n        self.scale = self.head_size**-0.5\n        self.dropout = config.attention_dropout\n\n        self.qkv = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n            dim=0,\n            weights=weights,\n            bias=True,\n        )\n        self.out_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.out_proj\",\n            weights=weights,\n            bias=True,\n        )\n\n    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n        return (\n            tensor.view(bsz, seq_len, self.num_heads, self.head_size)\n            .transpose(1, 2)\n            .contiguous()\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        causal_attention_mask: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        \"\"\"Input shape: Batch x Time x Channel\"\"\"\n\n        bsz, tgt_len, _ = hidden_states.size()\n\n        # get query proj\n\n        qkv = self.qkv(hidden_states)\n        query_states, key_states, value_states = qkv.split(\n            [\n                self.head_size * self.num_heads,\n            ]\n            * 3,\n     
       dim=2,\n        )\n        query_states = query_states * self.scale\n        key_states = self._shape(key_states, -1, bsz)\n        value_states = self._shape(value_states, -1, bsz)\n\n        proj_shape = (bsz * self.num_heads, -1, self.head_size)\n        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)\n        key_states = key_states.view(*proj_shape)\n        value_states = value_states.view(*proj_shape)\n\n        src_len = key_states.size(1)\n        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))\n\n        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is\"\n                f\" {attn_weights.size()}\"\n            )\n\n        # apply the causal_attention_mask first\n        if causal_attention_mask is not None:\n            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is\"\n                    f\" {causal_attention_mask.size()}\"\n                )\n            attn_weights = (\n                attn_weights.view(bsz, self.num_heads, tgt_len, src_len)\n                + causal_attention_mask\n            )\n            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)\n\n        if attention_mask is not None:\n            if attention_mask.size() != (bsz, 1, tgt_len, src_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}\"\n                )\n            attn_weights = (\n                attn_weights.view(bsz, self.num_heads, tgt_len, src_len)\n                + attention_mask\n            )\n            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)\n\n      
  attn_weights = nn.functional.softmax(attn_weights, dim=-1)\n\n        attn_probs = nn.functional.dropout(\n            attn_weights, p=self.dropout, training=self.training\n        )\n\n        attn_output = torch.bmm(attn_probs, value_states)\n\n        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_size):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_size)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_size)\n        attn_output = attn_output.transpose(1, 2)\n        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)\n\n        attn_output = self.out_proj(attn_output)\n\n        return attn_output, None\n\n\nclass CLIPMLP(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        self.activation_fn = ACT2FN[config.hidden_act]\n        self.fc1 = TensorParallelColumnLinear.load(\n            prefix=f\"{prefix}.fc1\", config=config, weights=weights, bias=True\n        )\n        self.fc2 = TensorParallelRowLinear.load(\n            prefix=f\"{prefix}.fc2\", config=config, weights=weights, bias=True\n        )\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = self.activation_fn(hidden_states)\n        hidden_states = self.fc2(hidden_states)\n        return hidden_states\n\n\nclass CLIPEncoderLayer(nn.Module):\n    def __init__(self, prefix, config: CLIPConfig, weights):\n        super().__init__()\n        self.embed_dim = config.hidden_size\n        self.self_attn = CLIPAttention(\n            prefix=f\"{prefix}.self_attn\", config=config, weights=weights\n        )\n        self.layer_norm1 = nn.LayerNorm.load(\n            prefix=f\"{prefix}.layer_norm1\", weights=weights, 
eps=config.layer_norm_eps\n        )\n        self.mlp = CLIPMLP(prefix=f\"{prefix}.mlp\", config=config, weights=weights)\n        self.layer_norm2 = nn.LayerNorm.load(\n            prefix=f\"{prefix}.layer_norm2\", weights=weights, eps=config.layer_norm_eps\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: torch.Tensor,\n        causal_attention_mask: torch.Tensor,\n    ):\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`): attention mask of size\n                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n                `(config.encoder_attention_heads,)`.\n        \"\"\"\n        residual = hidden_states\n\n        hidden_states = self.layer_norm1(hidden_states)\n        hidden_states, attn_weights = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            causal_attention_mask=causal_attention_mask,\n        )\n        hidden_states = residual + hidden_states\n\n        residual = hidden_states\n        hidden_states = self.layer_norm2(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        return hidden_states\n\n\nclass CLIPPreTrainedModel(nn.Module):\n    \"\"\"\n    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained\n    models.\n    \"\"\"\n\n    config_class = CLIPConfig\n    base_model_prefix = \"clip\"\n    supports_gradient_checkpointing = True\n\n\nCLIP_START_DOCSTRING = r\"\"\"\n    This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the\n    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads\n    etc.)\n\n    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.\n    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage\n    and behavior.\n\n    Parameters:\n        config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.\n            Initializing with a config file does not load the weights associated with the model, only the\n            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n\"\"\"\n\nCLIP_TEXT_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide\n            it.\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Indices of positions of each input sequence tokens in the position embeddings. 
Selected in the range `[0,\n            config.max_position_embeddings - 1]`.\n\n            [What are position IDs?](../glossary#position-ids)\n\"\"\"\n\nCLIP_VISION_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):\n            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using\n            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.\n\"\"\"\n\nCLIP_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide\n            it.\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,\n            config.max_position_embeddings - 1]`.\n\n            [What are position IDs?](../glossary#position-ids)\n        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):\n            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using\n            [`AutoImageProcessor`]. 
See [`CLIPImageProcessor.__call__`] for details.\n        return_loss (`bool`, *optional*):\n            Whether or not to return the contrastive loss.\n\"\"\"\n\n\nclass CLIPEncoder(nn.Module):\n    \"\"\"\n    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a\n    [`CLIPEncoderLayer`].\n\n    Args:\n        config: CLIPConfig\n    \"\"\"\n\n    def __init__(self, prefix, config: CLIPConfig, weights):\n        super().__init__()\n        self.config = config\n        self.layers = nn.ModuleList(\n            [\n                CLIPEncoderLayer(\n                    prefix=f\"{prefix}.layers.{i}\", config=config, weights=weights\n                )\n                for i in range(config.num_hidden_layers)\n            ]\n        )\n\n    def forward(\n        self,\n        inputs_embeds,\n        attention_mask: Optional[torch.Tensor] = None,\n        causal_attention_mask: Optional[torch.Tensor] = None,\n    ):\n        r\"\"\"\n        Args:\n            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):\n                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.\n                This is useful if you want more control over how to convert `input_ids` indices into associated vectors\n                than the model's internal embedding lookup matrix.\n            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n                - 1 for tokens that are **not masked**,\n                - 0 for tokens that are **masked**.\n\n                [What are attention masks?](../glossary#attention-mask)\n            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Causal mask for the text model. 
Mask values selected in `[0, 1]`:\n\n                - 1 for tokens that are **not masked**,\n                - 0 for tokens that are **masked**.\n\n                [What are attention masks?](../glossary#attention-mask)\n        \"\"\"\n\n        hidden_states = inputs_embeds\n        for idx, encoder_layer in enumerate(self.layers):\n            hidden_states = encoder_layer(\n                hidden_states,\n                attention_mask,\n                causal_attention_mask,\n            )\n\n        return hidden_states\n\n\nclass CLIPTextTransformer(nn.Module):\n    def __init__(self, prefix: str, config: CLIPTextConfig, weights=None):\n        super().__init__()\n        self.config = config\n        embed_dim = config.hidden_size\n        self.embeddings = CLIPTextEmbeddings(config)\n        # Initialize weights and apply final processing with `self.post_init()`\n        self.encoder = CLIPEncoder(\n            prefix=f\"{prefix}.encoder\", config=config, weights=weights\n        )\n        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)\n\n        # For `pooled_output` computation\n        self.eos_token_id = config.eos_token_id\n\n    def forward(\n        self,\n        input_ids: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.Tensor] = None,\n    ):\n        r\"\"\"\n        Returns:\n\n        \"\"\"\n        if input_ids is None:\n            raise ValueError(\"You have to specify input_ids\")\n\n        input_shape = input_ids.size()\n        input_ids = input_ids.view(-1, input_shape[-1])\n\n        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)\n\n        # CLIP's text model uses causal mask, prepare it here.\n        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324\n        causal_attention_mask = _create_4d_causal_attention_mask(\n            input_shape, 
hidden_states.dtype, device=hidden_states.device\n        )\n        # expand attention_mask\n        if attention_mask is not None:\n            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n            attention_mask = _prepare_4d_attention_mask(\n                attention_mask, hidden_states.dtype\n            )\n\n        encoder_outputs = self.encoder(\n            inputs_embeds=hidden_states,\n            attention_mask=attention_mask,\n            causal_attention_mask=causal_attention_mask,\n        )\n\n        last_hidden_state = encoder_outputs[0]\n        last_hidden_state = self.final_layer_norm(last_hidden_state)\n\n        if self.eos_token_id == 2:\n            # The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here.\n            # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added\n            # ------------------------------------------------------------\n            # text_embeds.shape = [batch_size, sequence_length, transformer.width]\n            # take features from the eot embedding (eot_token is the highest number in each sequence)\n            # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14\n            last_hidden_state[\n                torch.arange(\n                    last_hidden_state.shape[0], device=last_hidden_state.device\n                ),\n                input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(\n                    dim=-1\n                ),\n            ]\n        else:\n            # The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible)\n            last_hidden_state[\n                torch.arange(\n                    last_hidden_state.shape[0], device=last_hidden_state.device\n                ),\n                # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to 
`eos_token_id`)\n                (\n                    input_ids.to(dtype=torch.int, device=last_hidden_state.device)\n                    == self.eos_token_id\n                )\n                .int()\n                .argmax(dim=-1),\n            ]\n\n        return last_hidden_state\n\n\nclass CLIPTextModel(CLIPPreTrainedModel):\n    config_class = CLIPTextConfig\n\n    _no_split_modules = [\"CLIPTextEmbeddings\", \"CLIPEncoderLayer\"]\n\n    def __init__(self, prefix, config: CLIPTextConfig):\n        super().__init__(config)\n        self.text_model = CLIPTextTransformer(prefix, config)\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def forward(\n        self,\n        input_ids: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.Tensor] = None,\n    ):\n        r\"\"\"\n        Returns:\n\n        Examples:\n\n        ```python\n        >>> from transformers import AutoTokenizer, CLIPTextModel\n\n        >>> model = CLIPTextModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n        >>> tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-base-patch32\")\n\n        >>> inputs = tokenizer([\"a photo of a cat\", \"a photo of a dog\"], padding=True, return_tensors=\"pt\")\n\n        >>> outputs = model(**inputs)\n        >>> last_hidden_state = outputs.last_hidden_state\n        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states\n        ```\"\"\"\n\n        return self.text_model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n        )\n\n\nclass CLIPVisionTransformer(nn.Module):\n    def __init__(self, prefix, config: CLIPVisionConfig, weights):\n        super().__init__()\n        self.config = config\n\n        self.embeddings = CLIPVisionEmbeddings(\n            prefix=f\"{prefix}.embeddings\", config=config, weights=weights\n        )\n 
       self.pre_layrnorm = nn.LayerNorm.load(\n            prefix=f\"{prefix}.pre_layrnorm\", weights=weights, eps=config.layer_norm_eps\n        )\n        self.encoder = CLIPEncoder(\n            prefix=f\"{prefix}.encoder\", config=config, weights=weights\n        )\n        # self.post_layernorm = nn.LayerNorm.load(prefix=f\"{prefix}.post_layernorm\", weights=weights, eps=config.layer_norm_eps)\n\n    def forward(\n        self,\n        pixel_values: Optional[torch.FloatTensor] = None,\n    ):\n        r\"\"\"\n        Returns:\n\n        \"\"\"\n        if pixel_values is None:\n            raise ValueError(\"You have to specify pixel_values\")\n\n        hidden_states = self.embeddings(pixel_values)\n        hidden_states = self.pre_layrnorm(hidden_states)\n\n        encoder_outputs = self.encoder(\n            inputs_embeds=hidden_states,\n        )\n        last_hidden_state = encoder_outputs\n        # pooled_output = last_hidden_state[:, 0, :]\n        # pooled_output = self.post_layernorm(pooled_output)\n\n        return BaseModelOutputWithPooling(\n            last_hidden_state=last_hidden_state,\n            # pooler_output=pooled_output,\n            # hidden_states=encoder_outputs,\n        )\n\n\nclass CLIPVisionModel(CLIPPreTrainedModel):\n    config_class = CLIPVisionConfig\n    main_input_name = \"pixel_values\"\n    _no_split_modules = [\"CLIPEncoderLayer\"]\n\n    def __init__(self, config: CLIPVisionConfig):\n        super().__init__(config)\n        self.vision_model = CLIPVisionTransformer(config)\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self) -> nn.Module:\n        return self.vision_model.embeddings.patch_embedding\n\n    def forward(\n        self,\n        pixel_values: Optional[torch.FloatTensor] = None,\n    ):\n        r\"\"\"\n        Returns:\n\n        Examples:\n\n        ```python\n        >>> from PIL import Image\n        >>> import requests\n        >>> 
from transformers import AutoProcessor, CLIPVisionModel\n\n        >>> model = CLIPVisionModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n        >>> processor = AutoProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n\n        >>> url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n        >>> image = Image.open(requests.get(url, stream=True).raw)\n\n        >>> inputs = processor(images=image, return_tensors=\"pt\")\n\n        >>> outputs = model(**inputs)\n        >>> last_hidden_state = outputs.last_hidden_state\n        >>> pooled_output = outputs.pooler_output  # pooled CLS states\n        ```\"\"\"\n\n        return self.vision_model(\n            pixel_values=pixel_values,\n        )\n\n\nclass CLIPModel(nn.Module):\n    def __init__(self, prefix, config: CLIPConfig, weights):\n        super().__init__()\n        text_config = config.text_config\n        vision_config = config.vision_config\n\n        self.projection_dim = config.projection_dim\n        self.text_embed_dim = text_config.hidden_size\n        self.vision_embed_dim = vision_config.hidden_size\n\n        self.text_model = CLIPTextTransformer(text_config)\n        self.vision_model = CLIPVisionTransformer(vision_config)\n\n        self.visual_projection = nn.Linear(\n            self.vision_embed_dim, self.projection_dim, bias=False\n        )\n        self.text_projection = nn.Linear(\n            self.text_embed_dim, self.projection_dim, bias=False\n        )\n        self.logit_scale = nn.Parameter(\n            torch.tensor(self.config.logit_scale_init_value)\n        )\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_text_features(\n        self,\n        input_ids: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.Tensor] = None,\n    ) -> torch.FloatTensor:\n        r\"\"\"\n        Returns:\n            text_features 
(`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by\n            applying the projection layer to the pooled output of [`CLIPTextModel`].\n\n        Examples:\n\n        ```python\n        >>> from transformers import AutoTokenizer, CLIPModel\n\n        >>> model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n        >>> tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-base-patch32\")\n\n        >>> inputs = tokenizer([\"a photo of a cat\", \"a photo of a dog\"], padding=True, return_tensors=\"pt\")\n        >>> text_features = model.get_text_features(**inputs)\n        ```\"\"\"\n        text_outputs = self.text_model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n        )\n\n        pooled_output = text_outputs[1]\n        text_features = self.text_projection(pooled_output)\n\n        return text_features\n\n    def get_image_features(\n        self,\n        pixel_values: Optional[torch.FloatTensor] = None,\n    ) -> torch.FloatTensor:\n        r\"\"\"\n        Returns:\n            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by\n            applying the projection layer to the pooled output of [`CLIPVisionModel`].\n\n        Examples:\n\n        ```python\n        >>> from PIL import Image\n        >>> import requests\n        >>> from transformers import AutoProcessor, CLIPModel\n\n        >>> model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n        >>> processor = AutoProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n\n        >>> url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n        >>> image = Image.open(requests.get(url, stream=True).raw)\n\n        >>> inputs = processor(images=image, return_tensors=\"pt\")\n\n        >>> image_features = model.get_image_features(**inputs)\n        ```\"\"\"\n        # Use CLIP 
model's config for some fields (if specified) instead of those of vision & text components.\n        vision_outputs = self.vision_model(\n            pixel_values=pixel_values,\n        )\n\n        pooled_output = vision_outputs[1]  # pooled_output\n        image_features = self.visual_projection(pooled_output)\n\n        return image_features\n\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        pixel_values: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n    ):\n        r\"\"\"\n        Returns:\n\n        Examples:\n\n        ```python\n        >>> from PIL import Image\n        >>> import requests\n        >>> from transformers import AutoProcessor, CLIPModel\n\n        >>> model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n        >>> processor = AutoProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n\n        >>> url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n        >>> image = Image.open(requests.get(url, stream=True).raw)\n\n        >>> inputs = processor(\n        ...     text=[\"a photo of a cat\", \"a photo of a dog\"], images=image, return_tensors=\"pt\", padding=True\n        ... 
)\n\n        >>> outputs = model(**inputs)\n        >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score\n        >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities\n        ```\"\"\"\n        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.\n        vision_outputs = self.vision_model(\n            pixel_values=pixel_values,\n        )\n\n        text_outputs = self.text_model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n        )\n\n        image_embeds = vision_outputs[1]\n        image_embeds = self.visual_projection(image_embeds)\n\n        text_embeds = text_outputs[1]\n        text_embeds = self.text_projection(text_embeds)\n\n        # normalized features\n        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)\n        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)\n\n        # cosine similarity as logits\n        logit_scale = self.logit_scale.exp()\n        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale\n        logits_per_image = logits_per_text.t()\n\n        return logits_per_image, logits_per_text\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2024 Cohere team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.distributed\n\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom typing import Optional, List, Tuple\n\nfrom text_generation_server.layers.attention import (\n    paged_attention,\n    attention,\n    Seqlen,\n)\nfrom text_generation_server.layers.attention.kv_cache import get_kv_scales\nfrom text_generation_server.utils.import_utils import SYSTEM\nfrom text_generation_server.layers import (\n    TensorParallelRowLinear,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    SpeculativeHead,\n    get_linear,\n)\nfrom text_generation_server.layers.layernorm import (\n    FastLayerNorm,\n)\nfrom text_generation_server.layers.rotary import (\n    PositionRotaryEmbedding,\n)\nfrom text_generation_server.utils.weights import UnquantizedWeight\n\nif SYSTEM == \"cuda\":\n    import dropout_layer_norm\nelse:\n    dropout_layer_norm = None\n\n\nclass CohereRotary(PositionRotaryEmbedding):\n    def forward(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        
cos: torch.Tensor,\n        sin: torch.Tensor,\n    ):\n        # Such controlflows may add some overhead.\n        if SYSTEM == \"cuda\":\n            from text_generation_server.utils.kernels import load_kernel\n\n            rotary = load_kernel(module=\"rotary\", repo_id=\"kernels-community/rotary\")\n\n            q1 = query[..., ::2]\n            q2 = query[..., 1::2]\n\n            rotary.apply_rotary(q1, q2, cos, sin, q1, q2, False)\n\n            k1 = key[..., ::2]\n            k2 = key[..., 1::2]\n\n            rotary.apply_rotary(k1, k2, cos, sin, k1, k2, False)\n        elif SYSTEM == \"rocm\":\n            import vllm._custom_ops as ops\n\n            # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.\n            # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773\n\n            head_size = query.shape[-1]\n\n            # Inplace operation, updating query and key.\n            ops.rotary_embedding(query, key, head_size, cos, sin, False)\n        elif SYSTEM == \"ipex\":\n            import intel_extension_for_pytorch as ipex\n\n            ipex.llm.functional.rotary_embedding(\n                query, key, sin, cos, query.size(-1), False\n            )\n        else:\n            raise ValueError(\n                \"Your system seem to be not supported. 
Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.\"\n            )\n\n\nclass CohereLayerNorm(nn.Module):\n    def __init__(self, prefix, weights, eps):\n        super().__init__()\n        weight = weights.get_sharded(f\"{prefix}.weight\", dim=0)\n        self.weight = nn.Parameter(weight)\n        # Fake weights\n        self.ones = weight.new_ones(weight.shape[1])\n        self.eps = eps\n\n    def forward(self, hidden_states):\n        if hidden_states.shape[-1] > 8192 or SYSTEM != \"cuda\":\n            hidden_states = hidden_states.reshape(\n                -1, self.weight.shape[0], self.weight.shape[1]\n            )\n            input_dtype = hidden_states.dtype\n            hidden_states = hidden_states.to(torch.float32)\n            mean = hidden_states.mean(-1, keepdim=True)\n            hidden_states_minus_mean = hidden_states - mean\n            variance = hidden_states_minus_mean.pow(2).mean(-1, keepdim=True)\n            hidden_states = hidden_states_minus_mean * torch.rsqrt(variance + self.eps)\n            hidden_states = self.weight.to(torch.float32) * hidden_states\n            hidden_states = hidden_states.view(-1, self.weight.shape[1])\n            return hidden_states.to(input_dtype)\n\n        (\n            hidden_states,\n            *rest,\n        ) = dropout_layer_norm.dropout_add_ln_fwd(\n            hidden_states,\n            None,\n            self.ones,\n            None,\n            None,\n            None,\n            None,\n            None,\n            0.0,\n            self.eps,\n            1.0,\n            0,\n            None,\n            False,\n            False,\n        )\n\n        # Required to apply one weight matrix per head\n        hidden_states = hidden_states.view(\n            -1, self.weight.shape[0], self.weight.shape[1]\n        )\n        hidden_states = self.weight * hidden_states\n        hidden_states = 
hidden_states.view(-1, self.weight.shape[1])\n\n        return hidden_states\n\n\ndef load_attention(config, prefix, weights):\n    if config.num_attention_heads != config.num_key_value_heads:\n        return _load_gqa(config, prefix, weights)\n    else:\n        return TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n            dim=0,\n            weights=weights,\n            bias=config.attention_bias,\n        )\n\n\ndef _load_gqa(config, prefix: str, weights):\n    assert config.hidden_size % config.num_attention_heads == 0\n    assert config.num_attention_heads % weights.process_group.size() == 0\n\n    weight = weights.get_multi_weights_col(\n        prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n        dim=0,\n    )\n\n    if isinstance(weight, UnquantizedWeight):\n        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)\n\n        head_size = config.hidden_size // config.num_attention_heads\n        num_heads = config.num_attention_heads // weights.process_group.size()\n        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()\n        assert list(weight.weight.shape) == [\n            (num_heads + 2 * num_key_value_heads) * head_size,\n            config.hidden_size,\n        ], f\"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}\"\n\n    if config.attention_bias:\n        w = [\n            weights.get_sharded(f\"{p}.bias\", dim=0)\n            for p in [f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"]\n        ]\n        bias = torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device)\n    else:\n        bias = None\n\n    return TensorParallelColumnLinear(get_linear(weight, bias=bias))\n\n\nclass FlashCohereAttention(torch.nn.Module):\n    def __init__(\n        
self,\n        prefix: str,\n        config,\n        weights,\n    ):\n        super().__init__()\n        self.num_heads = config.num_attention_heads\n        self.hidden_size = config.hidden_size\n        self.head_size = self.hidden_size // self.num_heads\n\n        self.rotary_emb = CohereRotary.static(\n            config=config,\n            dim=self.head_size,\n            base=config.rope_theta,\n            device=weights.device,\n        )\n\n        self.softmax_scale = self.head_size**-0.5\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.num_key_value_heads = (\n            config.num_key_value_heads // weights.process_group.size()\n        )\n\n        self.query_key_value = load_attention(config, prefix, weights)\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n\n        self.use_qk_norm = config.use_qk_norm\n        if self.use_qk_norm:\n            self.q_norm = CohereLayerNorm(\n                prefix=f\"{prefix}.q_norm\",\n                weights=weights,\n                eps=config.layer_norm_eps,\n            )\n            self.k_norm = CohereLayerNorm(\n                prefix=f\"{prefix}.k_norm\",\n                weights=weights,\n                eps=config.layer_norm_eps,\n            )\n        else:\n            self.q_norm = None\n            self.k_norm = None\n\n        self.o_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.o_proj\",\n            weights=weights,\n            bias=config.attention_bias,\n        )\n        self.num_groups = self.num_heads // self.num_key_value_heads\n        self.kv_head_mapping = torch.arange(\n            0, 
self.num_key_value_heads, dtype=torch.int32, device=weights.device\n        ).repeat_interleave(self.num_groups)\n\n    def forward(\n        self,\n        hidden_states,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n    ):\n        qkv = self.query_key_value(hidden_states)\n        query, key, value = qkv.split(\n            [\n                self.head_size * self.num_heads,\n                self.head_size * self.num_key_value_heads,\n                self.head_size * self.num_key_value_heads,\n            ],\n            dim=1,\n        )\n\n        if self.use_qk_norm:\n            query = query.reshape(-1, self.head_size)\n            key = key.reshape(-1, self.head_size)\n            query = self.q_norm(query.contiguous())\n            key = self.k_norm(key.contiguous())\n\n        query = query.view(-1, self.num_heads, self.head_size)\n        key = key.view(-1, self.num_key_value_heads, self.head_size)\n        value = value.view(-1, self.num_key_value_heads, self.head_size)\n\n        self.rotary_emb(query, key, cos, sin)\n\n        kv_cache.store(\n            key=key,\n            value=value,\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # flash attention\n            attn_output = attention(\n                query=query,\n                key=key,\n                value=value,\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                block_tables=block_tables,\n                softmax_scale=self.softmax_scale,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                block_tables,\n                
seqlen,\n                max_s,\n                kv_scales=self.kv_scales,\n            )\n\n        return self.o_proj(\n            attn_output.view(-1, self.num_heads * self.head_size), reduce=False\n        )\n\n\nclass CohereMLP(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        act = config.hidden_act\n        self.act = (\n            ACT2FN[act]\n            if \"gelu\" not in act\n            else lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\" if act in [\"gelu_fast\", \"gelu_pytorch_tanh\"] else \"none\"\n                ),\n            )\n        )\n        # Fuse gate and up proj\n        self.gate_up_proj = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.gate_proj\", f\"{prefix}.up_proj\"],\n            weights=weights,\n            dim=0,\n            bias=False,\n        )\n        self.down_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.down_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.intermediate_size = (\n            config.intermediate_size // weights.process_group.size()\n        )\n\n    def forward(self, hidden_states):\n        gate_up_states = self.gate_up_proj(hidden_states)\n        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)\n        return self.down_proj(\n            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=False\n        )\n\n\nclass FlashCohereLayer(nn.Module):\n    def __init__(self, prefix: str, layer_id, config, weights):\n        super().__init__()\n        prefix = f\"{prefix}.layers.{layer_id}\"\n        self.self_attn = FlashCohereAttention(\n            prefix=f\"{prefix}.self_attn\", config=config, weights=weights\n        )\n        self.mlp = CohereMLP(prefix=f\"{prefix}.mlp\", config=config, weights=weights)\n\n        
self.input_layernorm = FastLayerNorm.load_no_bias(\n            prefix=f\"{prefix}.input_layernorm\",\n            weights=weights,\n            eps=config.layer_norm_eps,\n        )\n        self.process_group = weights.process_group\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n    ):\n        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        attn_output = self.self_attn(\n            normed_hidden_states,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n        )\n\n        mlp_output = self.mlp(normed_hidden_states)\n        output = attn_output + mlp_output\n\n        if self.process_group.size() > 1:\n            torch.distributed.all_reduce(output, group=self.process_group)\n\n        return output, res\n\n\nclass FlashCohereModel(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n\n        process_group = weights.process_group\n        self.tp_rank = process_group.rank()\n        self.tp_world_size = process_group.size()\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=f\"{prefix}.embed_tokens\", weights=weights\n        )\n        self.layers = nn.ModuleList(\n            [\n                FlashCohereLayer(\n                    prefix,\n                    layer_id,\n                    config,\n                    weights,\n                )\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n        self.norm = FastLayerNorm.load_no_bias(\n            prefix=f\"{prefix}.norm\", weights=weights, eps=config.layer_norm_eps\n        )\n\n        self.gradient_checkpointing = 
False\n\n        self.head_size = self.layers[0].self_attn.head_size\n        self.num_heads = self.layers[0].self_attn.num_heads\n        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: torch.Tensor,\n        max_s: int,\n    ) -> torch.Tensor:\n        hidden_states = self.embed_tokens(input_ids)\n\n        # Get rotary cos and sin for this forward\n        # Avoid indexing in each layer\n        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(\n            position_ids, max_s, hidden_states.dtype\n        )\n\n        residual = None\n\n        for i, layer in enumerate(self.layers):\n            hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                block_tables,\n                slots,\n                seqlen,\n                max_s,\n            )\n\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n\nclass FlashCohereForCausalLM(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n\n        if not prefix:\n            prefix = \"model\"\n        else:\n            prefix = f\"{prefix}.model\"\n\n        self.model = FlashCohereModel(prefix, config, weights)\n        try:\n            self.lm_head = SpeculativeHead.load(\n                config,\n                prefix=\"lm_head\",\n                weights=weights,\n            )\n        except RuntimeError:\n            self.lm_head = SpeculativeHead.load(\n                config,\n                
prefix=f\"{prefix}.embed_tokens\",\n                weights=weights,\n            )\n        self.logit_scale = config.logit_scale\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        hidden_states = self.model(\n            input_ids,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.lm_head(hidden_states)\n        logits *= self.logit_scale\n        if speculative_logits is not None:\n            speculative_logits *= self.logit_scale\n        return logits, speculative_logits\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2022 HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.distributed\n\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom transformers.configuration_utils import PretrainedConfig\nfrom typing import Optional, List, Tuple, Any\nfrom text_generation_server.layers.attention.kv_cache import get_kv_scales\nfrom text_generation_server.utils.import_utils import SYSTEM\nfrom text_generation_server.utils.kernels import load_kernel\n\nif SYSTEM == \"ipex\":\n    from intel_extension_for_pytorch.llm.modules import GatedMLPMOE\nelif SYSTEM == \"cuda\":\n    moe_kernels = load_kernel(module=\"moe\", repo_id=\"kernels-community/moe\")\nelse:\n    import moe_kernels\n\nfrom text_generation_server.layers.attention import (\n    paged_attention,\n    attention,\n    Seqlen,\n)\nfrom text_generation_server.layers import (\n    FastLinear,\n    TensorParallelRowLinear,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    SpeculativeHead,\n    get_linear,\n)\nfrom text_generation_server.layers.rotary import (\n    PositionRotaryEmbedding,\n)\nfrom text_generation_server.layers.layernorm import (\n    FastLayerNorm,\n)\n\n\nclass DbrxAttentionConfig(PretrainedConfig):\n    def __init__(\n        self,\n        attn_pdrop: float = 0,\n        clip_qkv: Optional[float] = None,\n        kv_n_heads: int = 1,\n        rope_theta: float = 
10000.0,\n        **kwargs: Any,\n    ):\n        super().__init__(**kwargs)\n        self.attn_pdrop = attn_pdrop\n        self.clip_qkv = clip_qkv\n        self.kv_n_heads = kv_n_heads\n        self.rope_theta = rope_theta\n\n        for k in [\"model_type\"]:\n            if k in kwargs:\n                kwargs.pop(k)\n        if len(kwargs) != 0:\n            raise ValueError(f\"Found unknown {kwargs=}\")\n\n\nclass DbrxFFNConfig(PretrainedConfig):\n    def __init__(\n        self,\n        ffn_act_fn: Optional[dict] = None,\n        ffn_hidden_size: int = 3584,\n        moe_num_experts: int = 4,\n        moe_top_k: int = 1,\n        moe_jitter_eps: Optional[float] = None,\n        moe_loss_weight: float = 0.01,\n        moe_normalize_expert_weights: Optional[float] = 1,\n        uniform_expert_assignment: bool = False,\n        **kwargs: Any,\n    ):\n        super().__init__()\n        if ffn_act_fn is None:\n            ffn_act_fn = {\"name\": \"silu\"}\n        self.ffn_act_fn = ffn_act_fn\n        self.ffn_hidden_size = ffn_hidden_size\n        self.moe_num_experts = moe_num_experts\n        self.moe_top_k = moe_top_k\n        self.moe_jitter_eps = moe_jitter_eps\n        self.moe_loss_weight = moe_loss_weight\n        self.moe_normalize_expert_weights = moe_normalize_expert_weights\n        self.uniform_expert_assignment = uniform_expert_assignment\n\n        if uniform_expert_assignment:\n            raise ValueError(\"`uniform_expert_assignment = True` is not supported\")\n\n        for k in [\"model_type\"]:\n            if k in kwargs:\n                kwargs.pop(k)\n        if len(kwargs) != 0:\n            raise ValueError(f\"Found unknown {kwargs=}\")\n\n\nclass DbrxConfig(PretrainedConfig):\n    attribute_map = {\n        \"hidden_size\": \"d_model\",\n        \"num_attention_heads\": \"n_heads\",\n        \"num_hidden_layers\": \"n_layers\",\n    }\n\n    def __init__(\n        self,\n        d_model: int = 2048,\n        n_heads: int = 16,\n     
   n_layers: int = 24,\n        max_seq_len: int = 2048,\n        vocab_size: int = 32000,\n        resid_pdrop: float = 0.0,\n        emb_pdrop: float = 0.0,\n        attn_config: Optional[DbrxAttentionConfig] = None,\n        ffn_config: Optional[DbrxFFNConfig] = None,\n        use_cache: bool = True,\n        initializer_range: float = 0.02,\n        output_router_logits: bool = False,\n        router_aux_loss_coef: float = 0.05,\n        **kwargs: Any,\n    ):\n        if attn_config is None:\n            self.attn_config = DbrxAttentionConfig()\n        elif isinstance(attn_config, dict):\n            self.attn_config = DbrxAttentionConfig(**attn_config)\n        else:\n            self.attn_config = attn_config\n\n        if ffn_config is None:\n            self.ffn_config = DbrxFFNConfig()\n        elif isinstance(ffn_config, dict):\n            self.ffn_config = DbrxFFNConfig(**ffn_config)\n        else:\n            self.ffn_config = ffn_config\n\n        self.d_model = d_model\n        self.n_heads = n_heads\n        self.n_layers = n_layers\n        self.max_seq_len = max_seq_len\n        self.vocab_size = vocab_size\n        self.resid_pdrop = resid_pdrop\n        self.emb_pdrop = emb_pdrop\n        self.use_cache = use_cache\n        self.initializer_range = initializer_range\n        self.output_router_logits = output_router_logits\n        self.router_aux_loss_coef = router_aux_loss_coef\n\n        tie_word_embeddings = kwargs.pop(\"tie_word_embeddings\", False)\n        if tie_word_embeddings:\n            raise ValueError(\"tie_word_embeddings is not supported for Dbrx models.\")\n\n        super().__init__(\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n    @property\n    def num_key_value_heads(self):\n        # We can't use the attribute map, since the number of KV\n        # heads is not top-level.\n        return self.attn_config.kv_n_heads\n\n\ndef promote_scalar(x: torch.Tensor) -> 
torch.Tensor:\n    return x.view(1) if len(x.size()) == 0 else x\n\n\ndef load_attention(config, prefix, weights):\n    return TensorParallelColumnLinear.load_qkv(\n        config,\n        prefix=f\"{prefix}.Wqkv\",\n        weights=weights,\n        bias=False,\n        num_heads=config.n_heads,\n        num_key_value_heads=config.attn_config.kv_n_heads,\n    )\n\n\ndef _load_experts(config, prefix, weights):\n    world_size = weights.process_group.size()\n    rank = weights.process_group.rank()\n\n    assert (\n        config.ffn_config.ffn_hidden_size % world_size == 0\n    ), f\"The chosen size {config.ffn_config.ffn_hidden_size} is not compatible with sharding on {world_size} shards\"\n\n    expert_size = config.ffn_config.ffn_hidden_size\n    block_size = expert_size // world_size\n    start = rank * block_size\n    stop = (rank + 1) * block_size\n\n    tensor = torch.empty(\n        (config.ffn_config.moe_num_experts * block_size, config.d_model),\n        dtype=weights.dtype,\n        device=weights.device,\n    )\n\n    slice_ = weights._get_slice(f\"{prefix}\")\n\n    for i in range(config.ffn_config.moe_num_experts):\n        offset = i * expert_size\n        expert_slice = slice_[start + offset : stop + offset]\n\n        tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(\n            dtype=weights.dtype\n        ).to(device=weights.device)\n    return tensor\n\n\ndef _load_experts_quantized(config, prefix, weights, cls):\n    world_size = weights.process_group.size()\n    rank = weights.process_group.rank()\n\n    assert (\n        config.ffn_config.ffn_hidden_size % world_size == 0\n    ), f\"The chosen size {config.ffn_config.ffn_hidden_size} is not compatible with sharding on {world_size} shards\"\n\n    expert_size = config.ffn_config.ffn_hidden_size\n    block_size = expert_size // world_size\n    start = rank * block_size\n    stop = (rank + 1) * block_size\n\n    slice_ = weights._get_slice(f\"{prefix}\")\n\n    experts = []\n    
for i in range(config.ffn_config.moe_num_experts):\n        if config.quantize in [\"gptq\", \"awq\"]:\n            raise NotImplementedError(\n                \"Dbrx does not support gptq/awq quantization yet.\"\n            )\n        else:\n            offset = i * expert_size\n            expert_slice = (\n                slice_[start + offset : stop + offset]\n                .to(dtype=weights.dtype)\n                .to(device=weights.device)\n            )\n\n        if cls == TensorParallelRowLinear:\n            expert_slice = expert_slice.t().contiguous()\n            linear = get_linear(expert_slice, None)\n            experts.append(cls(linear, weights.process_group))\n        else:\n            linear = get_linear(expert_slice, None)\n            experts.append(cls(linear))\n\n    return experts\n\n\nclass DbrxAttention(torch.nn.Module):\n    def __init__(\n        self,\n        prefix: str,\n        config,\n        weights,\n    ):\n        super().__init__()\n        self.clip_qkv = config.attn_config.clip_qkv\n        self.num_heads = config.n_heads\n        self.hidden_size = config.d_model\n        self.head_size = self.hidden_size // self.num_heads\n\n        self.rotary_emb = PositionRotaryEmbedding.static(\n            config=config,\n            dim=self.head_size,\n            base=config.attn_config.rope_theta,\n            device=weights.device,\n        )\n\n        self.softmax_scale = self.head_size**-0.5\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.num_key_value_heads = (\n            config.attn_config.kv_n_heads // weights.process_group.size()\n        )\n\n        self.query_key_value = load_attention(config, 
prefix, weights)\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n\n        self.o_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.out_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.num_groups = self.num_heads // self.num_key_value_heads\n        self.kv_head_mapping = torch.arange(\n            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device\n        ).repeat_interleave(self.num_groups)\n\n    def forward(\n        self,\n        hidden_states,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n    ):\n        qkv = self.query_key_value(hidden_states)\n        if self.clip_qkv is not None:\n            qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)\n\n        query, kv = qkv.split(\n            [\n                self.head_size * self.num_heads,\n                2 * self.head_size * self.num_key_value_heads,\n            ],\n            dim=1,\n        )\n        query = query.view(-1, self.num_heads, self.head_size)\n        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)\n\n        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)\n\n        kv_cache.store(\n            key=kv[:, 0],\n            value=kv[:, 1],\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # flash attention\n            attn_output = attention(\n                query=query,\n                key=kv[:, 0],\n                value=kv[:, 1],\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                block_tables=block_tables,\n                softmax_scale=self.softmax_scale,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n        
        query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                block_tables,\n                seqlen,\n                max_s,\n                kv_scales=self.kv_scales,\n            )\n\n        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))\n\n\nclass DbrxNormAttentionNorm(nn.Module):\n    def __init__(\n        self,\n        prefix: str,\n        config,\n        weights,\n    ):\n        super().__init__()\n        self.norm_1 = FastLayerNorm.load_no_bias(\n            prefix=f\"{prefix}.norm_1\", weights=weights, eps=1e-5\n        )\n        self.self_attn = DbrxAttention(\n            prefix=f\"{prefix}.attn\", config=config, weights=weights\n        )\n        self.norm_2 = FastLayerNorm.load_no_bias(\n            prefix=f\"{prefix}.norm_2\",\n            weights=weights,\n            eps=1e-5,\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n    ):\n        normed_hidden_states, res = self.norm_1(hidden_states, residual)\n\n        # Self Attention\n        attn_output = self.self_attn(\n            normed_hidden_states,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n        )\n\n        # faster post attention layer norm\n        normed_attn_res_output, attn_res = self.norm_2(attn_output, res)\n\n        return normed_attn_res_output, attn_res\n\n\n@torch.jit.script\ndef select_experts(\n    gate_logits: torch.Tensor, top_k: int, moe_normalize_expert_weights: int\n):\n    # all_probs: (sequence_length, n_experts) and upcast for softmax\n    all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)\n    # weights, 
selected_experts: (sequence_length, top-k)\n    weights, selected_experts = torch.topk(all_probs, top_k, dim=-1)\n    if moe_normalize_expert_weights:\n        weights = weights / torch.norm(\n            weights, p=moe_normalize_expert_weights, dim=-1, keepdim=True\n        )\n    weights = weights.view(-1)\n    selected_experts = selected_experts.view(-1)\n\n    return selected_experts, weights\n\n\n@torch.jit.script\ndef round_up(x: torch.Tensor, value: int):\n    return torch.div(x + (value - 1), value, rounding_mode=\"trunc\") * value\n\n\nclass BlockSparseMoE(nn.Module):\n    def __init__(self, prefix, config: DbrxConfig, weights):\n        super().__init__()\n        self.moe_normalize_expert_weights = (\n            config.ffn_config.moe_normalize_expert_weights\n        )\n        self.hidden_dim = config.d_model\n        self.ffn_dim = config.ffn_config.ffn_hidden_size // weights.process_group.size()\n        self.num_experts = config.ffn_config.moe_num_experts\n        self.top_k = config.ffn_config.moe_top_k\n\n        act = config.ffn_config.ffn_act_fn[\"name\"]\n        if \"gelu\" in act:\n            self.act = lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\" if act in [\"gelu_fast\", \"gelu_pytorch_tanh\"] else \"none\"\n                ),\n            )\n        elif \"silu\" in act:\n            self.act = torch.nn.functional.silu\n        else:\n            self.act = ACT2FN[act]\n\n        # gating\n        self.gate = FastLinear.load(\n            config, f\"{prefix}.router.layer\", weights, bias=False\n        )\n\n        # merged expert weights, all of size  (n_experts * ffn_dim, hidden_dim)\n        w1 = _load_experts(config, f\"{prefix}.experts.mlp.w1\", weights).view(\n            self.num_experts, self.ffn_dim, self.hidden_dim\n        )\n        v1 = _load_experts(config, f\"{prefix}.experts.mlp.v1\", weights).view(\n            self.num_experts, self.ffn_dim, 
self.hidden_dim\n        )\n        self.wv1 = torch.cat([w1, v1], dim=1)\n        self.w2 = (\n            _load_experts(config, f\"{prefix}.experts.mlp.w2\", weights)\n            .view(self.num_experts, self.ffn_dim, self.hidden_dim)\n            .transpose(1, 2)\n            .contiguous()\n        )\n\n        self.process_group = weights.process_group\n        if SYSTEM == \"ipex\":\n            self.ipex_fused_moe = GatedMLPMOE(\n                W13=self.wv1, W2=self.w2, use_prepack=True\n            )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        # router_logits: (num_tokens, n_experts)\n        router_logits = self.gate(x)\n\n        if SYSTEM == \"ipex\":\n            out = self.ipex_fused_moe(\n                hidden_states=x,\n                router_logits=router_logits,\n                top_k=self.top_k,\n                renormalize=self.moe_normalize_expert_weights,\n                use_grouped_topk=False,\n                num_expert_group=None,\n                topk_group=None,\n            )\n        else:\n            out = moe_kernels.fused_moe(\n                x,\n                self.wv1,\n                self.w2,\n                router_logits,\n                self.top_k,\n                renormalize=self.moe_normalize_expert_weights,\n                inplace=True,\n            )\n\n        # Reduce sum\n        if self.process_group.size() > 1:\n            torch.distributed.all_reduce(out, group=self.process_group)\n\n        return out.view(*x.shape)\n\n\nclass DenseMoE(nn.Module):\n    def __init__(self, prefix, config: DbrxConfig, weights):\n        super().__init__()\n\n        self.moe_normalize_expert_weights = (\n            config.ffn_config.moe_normalize_expert_weights\n        )\n        self.hidden_dim = config.d_model\n        self.ffn_dim = config.ffn_config.ffn_hidden_size // weights.process_group.size()\n        self.num_experts = config.ffn_config.moe_num_experts\n        self.top_k = 
config.ffn_config.moe_top_k\n\n        act = config.ffn_config.ffn_act_fn[\"name\"]\n        if \"gelu\" in act:\n            self.act = lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\" if act in [\"gelu_fast\", \"gelu_pytorch_tanh\"] else \"none\"\n                ),\n            )\n        elif \"silu\" in act:\n            self.act = torch.nn.functional.silu\n        else:\n            self.act = ACT2FN[act]\n\n        # gating\n        self.gate = FastLinear.load(\n            config, f\"{prefix}.router.layer\", weights, bias=False\n        )\n\n        self.w1 = _load_experts_quantized(\n            config,\n            prefix=f\"{prefix}.experts.mlp.w1\",\n            weights=weights,\n            cls=TensorParallelColumnLinear,\n        )\n        self.w2 = _load_experts_quantized(\n            config,\n            prefix=f\"{prefix}.experts.mlp.w2\",\n            weights=weights,\n            cls=TensorParallelRowLinear,\n        )\n        self.v1 = _load_experts_quantized(\n            config,\n            prefix=f\"{prefix}.experts.mlp.v1\",\n            weights=weights,\n            cls=TensorParallelColumnLinear,\n        )\n\n        self.process_group = weights.process_group\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        x: (sequence_length, model_dim)\n        gate_logits: (sequence_length, n_experts)\n        \"\"\"\n        # optional reshape\n        input_shape = x.shape\n        x = x.view(-1, input_shape[-1])\n\n        # gate_logits: (sequence_length, n_experts)\n        gate_logits = self.gate(x)\n        # all_probs: (sequence_length, n_experts) and upcast for softmax\n        weights = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)\n\n        if self.top_k < self.num_experts:\n            _, not_selected_experts = torch.topk(\n                weights,\n                self.num_experts - self.top_k,\n                
largest=False,\n                sorted=False,\n                dim=1,\n            )\n            # Mask not selected experts\n            weights.scatter_(1, not_selected_experts, 0)\n\n        # Re-normalize\n        if self.moe_normalize_expert_weights:\n            weights = weights / torch.norm(\n                weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True\n            )\n        weights = weights.to(x.dtype)\n\n        # Final output tensor\n        out = x.new_zeros(x.shape[0], self.hidden_dim)\n        for i in range(self.num_experts):\n            h = self.act(self.w1[i](x)) * self.v1[i](x)\n            h = self.w2[i](h, reduce=False)\n            # Add expert output to out with masking\n            out += h * weights[:, i].view(-1, 1)\n\n        # Reduce sum\n        if self.process_group.size() > 1:\n            torch.distributed.all_reduce(out, group=self.process_group)\n\n        return out\n\n\nclass DbrxLayer(nn.Module):\n    def __init__(self, prefix: str, layer_id, config, weights):\n        super().__init__()\n        prefix = f\"{prefix}.blocks.{layer_id}\"\n\n        self.attn = DbrxNormAttentionNorm(\n            prefix=f\"{prefix}.norm_attn_norm\", config=config, weights=weights\n        )\n\n        moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE\n        self.moe = moe_cls(f\"{prefix}.ffn\", config, weights)\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n    ):\n        # Self Attention\n        attn_output, attn_res = self.attn(\n            hidden_states,\n            residual,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n        )\n\n        moe_output = self.moe(attn_output)\n\n        
return moe_output, attn_res\n\n\nclass DbrxModel(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=f\"{prefix}.wte\", weights=weights\n        )\n\n        self.layers = nn.ModuleList(\n            [\n                DbrxLayer(\n                    prefix,\n                    layer_id,\n                    config,\n                    weights,\n                )\n                for layer_id in range(config.n_layers)\n            ]\n        )\n        self.norm = FastLayerNorm.load_no_bias(\n            prefix=f\"{prefix}.norm_f\", weights=weights, eps=1e-5\n        )\n\n        self.head_size = self.layers[0].attn.self_attn.head_size\n        self.num_heads = self.layers[0].attn.self_attn.num_heads\n        self.num_key_value_heads = self.layers[0].attn.self_attn.num_key_value_heads\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n    ) -> torch.Tensor:\n        hidden_states = self.embed_tokens(input_ids)\n\n        # Get rotary cos and sin for this forward\n        # Avoid indexing in each layer\n        cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(\n            position_ids, max_s, hidden_states.dtype\n        )\n\n        residual = None\n        for i, layer in enumerate(self.layers):\n            hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                block_tables,\n                slots,\n                seqlen,\n                max_s,\n            )\n\n        hidden_states, 
_ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n\nclass FlashDbrxForCausalLM(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n\n        if not prefix:\n            prefix = \"transformer\"\n        else:\n            prefix = f\"{prefix}.transformer\"\n\n        self.model = DbrxModel(prefix, config, weights)\n        self.lm_head = SpeculativeHead.load(\n            config,\n            prefix=\"lm_head\",\n            weights=weights,\n        )\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        hidden_states = self.model(\n            input_ids,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.lm_head(hidden_states)\n        return logits, speculative_logits\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import List, Optional, Tuple, Type\n\nimport torch\nimport torch.distributed\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom text_generation_server.layers import (\n    FastLinear,\n    SpeculativeHead,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    TensorParallelRowLinear,\n    get_linear,\n)\nfrom text_generation_server.layers.attention import (\n    Seqlen,\n    attention,\n    paged_attention,\n)\nfrom text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales\nfrom text_generation_server.layers.layernorm import FastRMSNorm\nfrom text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer\nfrom text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale\nfrom text_generation_server.utils.import_utils import SYSTEM\nfrom text_generation_server.utils.weights import Weights\n\nif SYSTEM == \"rocm\":\n    try:\n        import vllm._custom_ops as ops\n    except Exception as e:\n        raise ImportError(f\"Could not load `vllm._custom_ops`. 
Full error: {e}\")\n\n\nclass DeepseekV2Config(PretrainedConfig):\n    def __init__(\n        self,\n        vocab_size=102400,\n        hidden_size=4096,\n        intermediate_size=11008,\n        moe_intermediate_size=1407,\n        num_hidden_layers=30,\n        num_attention_heads=32,\n        num_key_value_heads=32,\n        n_shared_experts=2,\n        n_routed_experts=160,\n        ep_size=1,\n        routed_scaling_factor=1.0,\n        kv_lora_rank=512,\n        q_lora_rank=1536,\n        qk_rope_head_dim=64,\n        v_head_dim=128,\n        qk_nope_head_dim=128,\n        topk_method=\"gready\",\n        n_group=8,\n        topk_group=3,\n        num_experts_per_tok=6,\n        moe_layer_freq=1,\n        first_k_dense_replace=0,\n        norm_topk_prob=False,\n        scoring_func=\"softmax\",\n        aux_loss_alpha=0.001,\n        seq_aux=True,\n        hidden_act=\"silu\",\n        max_position_embeddings=2048,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        pad_token_id=None,\n        bos_token_id=100000,\n        eos_token_id=100001,\n        pretraining_tp=1,\n        tie_word_embeddings=False,\n        rope_theta=10000.0,\n        rope_scaling=None,\n        attention_bias=False,\n        attention_dropout=0.0,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.moe_intermediate_size = moe_intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.n_shared_experts = n_shared_experts\n        self.n_routed_experts = n_routed_experts\n        self.ep_size = ep_size\n        self.routed_scaling_factor = routed_scaling_factor\n        self.kv_lora_rank = kv_lora_rank\n        self.q_lora_rank = q_lora_rank\n        self.qk_rope_head_dim = 
qk_rope_head_dim\n        self.v_head_dim = v_head_dim\n        self.qk_nope_head_dim = qk_nope_head_dim\n        self.topk_method = topk_method\n        self.n_group = n_group\n        self.topk_group = topk_group\n        self.num_experts_per_tok = num_experts_per_tok\n        self.moe_layer_freq = moe_layer_freq\n        self.first_k_dense_replace = first_k_dense_replace\n        self.norm_topk_prob = norm_topk_prob\n        self.scoring_func = scoring_func\n        self.aux_loss_alpha = aux_loss_alpha\n        self.seq_aux = seq_aux\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.pretraining_tp = pretraining_tp\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n\n        tie_word_embeddings = kwargs.pop(\"tie_word_embeddings\", False)\n        if tie_word_embeddings:\n            raise ValueError(\n                \"tie_word_embeddings is not supported for Deepseek V2 models.\"\n            )\n\n        if ep_size != 1:\n            raise ValueError(\n                f\"Currently only ep_size == 1 is supported for Deepseek V2 models, was {ep_size}\"\n            )\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n\nclass DeepseekV2Attention(torch.nn.Module):\n    def __init__(\n        self,\n        prefix: str,\n        config,\n        weights: Weights,\n    ):\n        super().__init__()\n        self.num_heads = 
config.num_attention_heads\n        self.hidden_size = config.hidden_size\n        self.kv_lora_rank = config.kv_lora_rank\n        self.q_lora_rank = config.q_lora_rank\n        self.qk_nope_head_dim = config.qk_nope_head_dim\n        self.qk_rope_head_dim = config.qk_rope_head_dim\n        self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim\n        self.value_head_size = config.v_head_dim\n        self.head_pad_size = max(self.head_size, self.value_head_size)\n\n        self.rotary_emb = PositionRotaryEmbedding.static(\n            config=config,\n            dim=self.qk_rope_head_dim,\n            base=config.rope_theta,\n            device=weights.device,\n        )\n\n        mscale = get_mscale(\n            self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim\n        )\n        self.softmax_scale = self.head_size**-0.5 * mscale * mscale\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.num_key_value_heads = (\n            config.num_key_value_heads // weights.process_group.size()\n        )\n\n        if self.q_lora_rank is None:\n            self.q_proj = TensorParallelColumnLinear.load(\n                config,\n                prefix=f\"{prefix}.q_proj\",\n                weights=weights,\n                bias=config.attention_bias,\n            )\n        else:\n            self.q_a_proj = get_linear(\n                weight=weights.get_weights(f\"{prefix}.q_a_proj\"),\n                bias=(\n                    weights.get_tensor(f\"{prefix}.q_a_proj.bias\")\n                    if config.attention_bias\n                    else None\n                ),\n            )\n            self.q_a_layernorm = 
FastRMSNorm.load(\n                prefix=f\"{prefix}.q_a_layernorm\",\n                weights=weights,\n                eps=config.rms_norm_eps,\n            )\n            self.q_b_proj = TensorParallelColumnLinear.load(\n                config,\n                prefix=f\"{prefix}.q_b_proj\",\n                weights=weights,\n                bias=config.attention_bias,\n            )\n\n        self.kv_a_proj_with_mqa = get_linear(\n            weight=weights.get_weights(f\"{prefix}.kv_a_proj_with_mqa\"),\n            bias=(\n                weights.get_tensor(f\"{prefix}.kv_a_proj_with_mqa.bias\")\n                if config.attention_bias\n                else None\n            ),\n        )\n\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n\n        self.kv_a_layernorm = FastRMSNorm.load(\n            prefix=f\"{prefix}.kv_a_layernorm\", weights=weights, eps=config.rms_norm_eps\n        )\n\n        self.kv_b_proj = TensorParallelColumnLinear.load(\n            config,\n            prefix=f\"{prefix}.kv_b_proj\",\n            weights=weights,\n            bias=config.attention_bias,\n        )\n\n        self.o_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.o_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.num_groups = self.num_heads // self.num_key_value_heads\n        self.kv_head_mapping = torch.arange(\n            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device\n        ).repeat_interleave(self.num_groups)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        cos: torch.Tensor,\n        sin: torch.Tensor,\n        cu_seqlen_prefill: torch.Tensor,\n        kv_cache: KVCache,\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n    ):\n        if self.q_lora_rank is None:\n            query = self.q_proj(hidden_states)\n        else:\n            
query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])\n        query = query.view(-1, self.num_heads, self.head_size)\n\n        _, query_pe = torch.split(\n            query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1\n        )\n\n        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)\n        compressed_kv, key_pe = torch.split(\n            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n        )\n\n        key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)\n        kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(\n            -1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size\n        )\n\n        key_nope, value = torch.split(\n            kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1\n        )\n\n        batch_size, heads, head_dim = query_pe.shape\n        query_pe = (\n            query_pe.view(batch_size, heads, head_dim // 2, 2)\n            .transpose(2, 3)\n            .reshape(batch_size, heads, head_dim)\n        )\n        batch_size, heads, head_dim = key_pe.shape\n        key_pe = (\n            key_pe.view(batch_size, heads, head_dim // 2, 2)\n            .transpose(2, 3)\n            .reshape(batch_size, heads, head_dim)\n        )\n        self.rotary_emb(query_pe, key_pe, cos, sin)\n\n        query[..., self.qk_nope_head_dim :] = query_pe\n        key = torch.empty_like(query)\n        key[..., : self.qk_nope_head_dim] = key_nope\n        key[..., self.qk_nope_head_dim :] = key_pe\n\n        # We need to pad the heads because Flash Attention does not support\n        # qk and v with different head sizes.\n        query = torch.nn.functional.pad(\n            query, (0, self.head_pad_size - self.head_size), value=0\n        )\n        key = torch.nn.functional.pad(\n            key, (0, self.head_pad_size - self.head_size), value=0\n        )\n        value = torch.nn.functional.pad(\n            value, (0, 
self.head_pad_size - self.value_head_size), value=0\n        )\n\n        kv_cache.store(\n            key=key,\n            value=value,\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # flash attention\n            attn_output = attention(\n                query=query,\n                key=key,\n                value=value,\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                block_tables=block_tables,\n                softmax_scale=self.softmax_scale,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                block_tables,\n                seqlen,\n                max_s,\n                kv_scales=self.kv_scales,\n            )\n\n        # Remove padding.\n        attn_output = attn_output[..., : self.value_head_size]\n\n        return self.o_proj(\n            attn_output.reshape(-1, self.num_heads * self.value_head_size)\n        )\n\n\nclass DeepseekV2MLP(nn.Module):\n    def __init__(self, prefix: str, config, weights, intermediate_size: int):\n        super().__init__()\n        self.hidden_act = config.hidden_act\n        if self.hidden_act != \"silu\":\n            # Bail out because MoE only supports silu.\n            raise NotImplementedError(\n                \"Currently only `silu` is supported as an activation for Deepseek V2.\"\n            )\n        self.act = ACT2FN[self.hidden_act]\n\n        self.gate_up_proj = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.gate_proj\", f\"{prefix}.up_proj\"],\n            weights=weights,\n            dim=0,\n            bias=False,\n        )\n\n        self.down_proj = TensorParallelRowLinear.load(\n          
  config,\n            prefix=f\"{prefix}.down_proj\",\n            weights=weights,\n            bias=False,\n        )\n\n        self.intermediate_size = intermediate_size // weights.process_group.size()\n\n        # TODO: This is a hotfix to be removed & properly refactored.\n        self.quantize = config.quantize\n\n    def forward(self, hidden_states: torch.Tensor, reduce: bool = True):\n        if (\n            SYSTEM == \"rocm\"\n            and self.hidden_act == \"silu\"\n            and hidden_states.dtype == torch.float16\n            and hidden_states.shape[0] == 1\n            and not self.quantize\n        ):\n            out = torch.empty(\n                hidden_states.shape[0],\n                self.intermediate_size,\n                dtype=hidden_states.dtype,\n                device=\"cuda\",\n            )\n            ops.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)\n            return self.down_proj(out, reduce=reduce)\n        else:\n            gate_up_states = self.gate_up_proj(hidden_states)\n            gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)\n            return self.down_proj(\n                self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce\n            )\n\n\nclass DeepseekV2MoE(nn.Module):\n    def __init__(\n        self,\n        prefix,\n        config: DeepseekV2Config,\n        moe_layer_cls: Type[MoELayer],\n        weights,\n    ):\n        super().__init__()\n\n        self.hidden_dim = config.hidden_size\n        self.moe_intermediate_size = (\n            config.moe_intermediate_size // weights.process_group.size()\n        )\n        self.routed_scaling_factor = config.routed_scaling_factor\n\n        # Gating\n        self.gate = FastLinear.load(config, f\"{prefix}.gate\", weights, bias=False)\n\n        self.moe_layer = moe_layer_cls(\n            prefix=f\"{prefix}.experts\",\n            n_experts=config.n_routed_experts,\n            
n_expert_group=config.n_group,\n            renormalize=config.norm_topk_prob,\n            topk=config.num_experts_per_tok,\n            topk_group=config.topk_group,\n            weights=weights,\n        )\n        assert isinstance(self.moe_layer, MoELayer)\n\n        if config.n_shared_experts is not None:\n            self.shared_experts = DeepseekV2MLP(\n                prefix=f\"{prefix}.shared_experts\",\n                config=config,\n                weights=weights,\n                intermediate_size=config.moe_intermediate_size\n                * config.n_shared_experts,\n            )\n        else:\n            self.shared_experts = None\n\n        self.process_group = weights.process_group\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        if self.shared_experts is not None:\n            shared_output = self.shared_experts(x, reduce=False)\n        else:\n            shared_output = None\n\n        router_logits = self.gate(x)\n\n        out = self.moe_layer(x, gating_output=router_logits)\n\n        if shared_output is not None:\n            out = out + shared_output\n\n        # Reduce sum\n        if self.process_group.size() > 1:\n            torch.distributed.all_reduce(out, group=self.process_group)\n\n        return out.view(*x.shape)\n\n\nclass DeepseekV2Layer(nn.Module):\n    def __init__(self, prefix, layer_id, config, weights):\n        super().__init__()\n        prefix = f\"{prefix}.layers.{layer_id}\"\n\n        self.self_attn = DeepseekV2Attention(\n            prefix=f\"{prefix}.self_attn\",\n            config=config,\n            weights=weights,\n        )\n\n        if (\n            config.n_routed_experts is not None\n            and layer_id >= config.first_k_dense_replace\n            and layer_id % config.moe_layer_freq == 0\n        ):\n            moe_layer_cls = (\n                SparseMoELayer\n                if SparseMoELayer.is_supported(weights)\n                else DenseMoELayer\n            )\n   
         self.mlp = DeepseekV2MoE(f\"{prefix}.mlp\", config, moe_layer_cls, weights)\n        else:\n            self.mlp = DeepseekV2MLP(\n                prefix=f\"{prefix}.mlp\",\n                config=config,\n                weights=weights,\n                intermediate_size=config.intermediate_size,\n            )\n\n        self.input_layernorm = FastRMSNorm.load(\n            prefix=f\"{prefix}.input_layernorm\", weights=weights, eps=config.rms_norm_eps\n        )\n        self.post_attention_layernorm = FastRMSNorm.load(\n            prefix=f\"{prefix}.post_attention_layernorm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        residual: torch.Tensor,\n        cos: torch.Tensor,\n        sin: torch.Tensor,\n        cu_seqlen_prefill: torch.Tensor,\n        kv_cache,\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n    ):\n        normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        attn_output = self.self_attn(\n            normed_hidden_states,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n        )\n\n        # faster post attention rms norm\n        normed_attn_res_output, residual = self.post_attention_layernorm(\n            attn_output, residual\n        )\n\n        output = self.mlp(normed_attn_res_output)\n\n        return output, residual\n\n\nclass DeepseekV2Model(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights: Weights):\n        super().__init__()\n\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=f\"{prefix}.embed_tokens\", weights=weights\n        )\n\n        self.layers = nn.ModuleList(\n            [\n                
DeepseekV2Layer(\n                    prefix,\n                    layer_id,\n                    config,\n                    weights,\n                )\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n        self.norm = FastRMSNorm.load(\n            prefix=f\"{prefix}.norm\", weights=weights, eps=config.rms_norm_eps\n        )\n\n        self.head_size = self.layers[0].self_attn.head_size\n        self.num_heads = self.layers[0].self_attn.num_heads\n        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n    ) -> torch.Tensor:\n        hidden_states = self.embed_tokens(input_ids)\n\n        # Get rotary cos and sin for this forward\n        # Avoid to index in each layer\n        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(\n            position_ids, max_s, hidden_states.dtype\n        )\n\n        residual = None\n        for i, layer in enumerate(self.layers):\n            hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                block_tables,\n                slots,\n                seqlen,\n                max_s,\n            )\n\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n\nclass FlashDeepseekV2ForCausalLM(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights: Weights):\n        super().__init__()\n\n        self.model = DeepseekV2Model(\n            \"model\" if not prefix else f\"{prefix}.model\", config, weights\n      
  )\n        self.lm_head = SpeculativeHead.load(\n            config,\n            prefix=\"lm_head\" if not prefix else f\"{prefix}.lm_head\",\n            weights=weights,\n        )\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        hidden_states = self.model(\n            input_ids,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.lm_head(hidden_states)\n        return logits, speculative_logits\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import List, Optional, Tuple, Type\n\nimport torch\nimport torch.distributed\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom text_generation_server.layers import (\n    FastLinear,\n    SpeculativeHead,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    TensorParallelRowLinear,\n    get_linear,\n)\nfrom text_generation_server.layers.attention import (\n    Seqlen,\n    attention,\n    paged_attention,\n)\nfrom text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales\nfrom text_generation_server.layers.layernorm import FastRMSNorm\nfrom text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer\nfrom text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale\nfrom text_generation_server.utils.import_utils import SYSTEM\nfrom text_generation_server.utils.weights import Weights\n\nif SYSTEM == \"rocm\":\n    try:\n        import vllm._custom_ops as ops\n    except Exception as e:\n        raise ImportError(f\"Could not load `vllm._custom_ops`. 
Full error: {e}\")\n\n\nclass DeepseekV3Config(PretrainedConfig):\n    def __init__(\n        self,\n        vocab_size=102400,\n        hidden_size=4096,\n        intermediate_size=11008,\n        moe_intermediate_size=1407,\n        num_hidden_layers=30,\n        num_attention_heads=32,\n        num_key_value_heads=32,\n        n_shared_experts=2,\n        n_routed_experts=160,\n        ep_size=1,\n        routed_scaling_factor=1.0,\n        kv_lora_rank=512,\n        q_lora_rank=1536,\n        qk_rope_head_dim=64,\n        v_head_dim=128,\n        qk_nope_head_dim=128,\n        topk_method=\"gready\",\n        n_group=8,\n        topk_group=3,\n        num_experts_per_tok=6,\n        moe_layer_freq=1,\n        first_k_dense_replace=0,\n        norm_topk_prob=False,\n        scoring_func=\"softmax\",\n        aux_loss_alpha=0.001,\n        seq_aux=True,\n        hidden_act=\"silu\",\n        max_position_embeddings=2048,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        pad_token_id=None,\n        bos_token_id=100000,\n        eos_token_id=100001,\n        pretraining_tp=1,\n        tie_word_embeddings=False,\n        rope_theta=10000.0,\n        rope_scaling=None,\n        attention_bias=False,\n        attention_dropout=0.0,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.moe_intermediate_size = moe_intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.n_shared_experts = n_shared_experts\n        self.n_routed_experts = n_routed_experts\n        self.ep_size = ep_size\n        self.routed_scaling_factor = routed_scaling_factor\n        self.kv_lora_rank = kv_lora_rank\n        self.q_lora_rank = q_lora_rank\n        self.qk_rope_head_dim = 
qk_rope_head_dim\n        self.v_head_dim = v_head_dim\n        self.qk_nope_head_dim = qk_nope_head_dim\n        self.topk_method = topk_method\n        self.n_group = n_group\n        self.topk_group = topk_group\n        self.num_experts_per_tok = num_experts_per_tok\n        self.moe_layer_freq = moe_layer_freq\n        self.first_k_dense_replace = first_k_dense_replace\n        self.norm_topk_prob = norm_topk_prob\n        self.scoring_func = scoring_func\n        self.aux_loss_alpha = aux_loss_alpha\n        self.seq_aux = seq_aux\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.pretraining_tp = pretraining_tp\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n\n        tie_word_embeddings = kwargs.pop(\"tie_word_embeddings\", False)\n        if tie_word_embeddings:\n            raise ValueError(\n                \"tie_word_embeddings is not supported for Deepseek V2 models.\"\n            )\n\n        if ep_size != 1:\n            raise ValueError(\n                f\"Currently only ep_size == 1 is supported for Deepseek V2 models, was {ep_size}\"\n            )\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n\nclass DeepseekV3Attention(torch.nn.Module):\n    def __init__(\n        self,\n        prefix: str,\n        config,\n        weights: Weights,\n    ):\n        super().__init__()\n        self.num_heads = 
config.num_attention_heads\n        self.hidden_size = config.hidden_size\n        self.kv_lora_rank = config.kv_lora_rank\n        self.q_lora_rank = config.q_lora_rank\n        self.qk_nope_head_dim = config.qk_nope_head_dim\n        self.qk_rope_head_dim = config.qk_rope_head_dim\n        self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim\n        self.value_head_size = config.v_head_dim\n        self.head_pad_size = max(self.head_size, self.value_head_size)\n\n        self.rotary_emb = PositionRotaryEmbedding.static(\n            config=config,\n            dim=self.qk_rope_head_dim,\n            base=config.rope_theta,\n            device=weights.device,\n        )\n\n        mscale = get_mscale(\n            self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim\n        )\n        self.softmax_scale = self.head_size**-0.5 * mscale * mscale\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.num_key_value_heads = (\n            config.num_key_value_heads // weights.process_group.size()\n        )\n\n        if self.q_lora_rank is None:\n            self.q_proj = TensorParallelColumnLinear.load(\n                config,\n                prefix=f\"{prefix}.q_proj\",\n                weights=weights,\n                bias=config.attention_bias,\n            )\n        else:\n            self.q_a_proj = get_linear(\n                weight=weights.get_weights(f\"{prefix}.q_a_proj\"),\n                bias=(\n                    weights.get_tensor(f\"{prefix}.q_a_proj.bias\")\n                    if config.attention_bias\n                    else None\n                ),\n            )\n            self.q_a_layernorm = 
FastRMSNorm.load(\n                prefix=f\"{prefix}.q_a_layernorm\",\n                weights=weights,\n                eps=config.rms_norm_eps,\n            )\n            self.q_b_proj = TensorParallelColumnLinear.load(\n                config,\n                prefix=f\"{prefix}.q_b_proj\",\n                weights=weights,\n                bias=config.attention_bias,\n            )\n\n        self.kv_a_proj_with_mqa = get_linear(\n            weight=weights.get_weights(f\"{prefix}.kv_a_proj_with_mqa\"),\n            bias=(\n                weights.get_tensor(f\"{prefix}.kv_a_proj_with_mqa.bias\")\n                if config.attention_bias\n                else None\n            ),\n        )\n\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n\n        self.kv_a_layernorm = FastRMSNorm.load(\n            prefix=f\"{prefix}.kv_a_layernorm\", weights=weights, eps=config.rms_norm_eps\n        )\n\n        self.kv_b_proj = TensorParallelColumnLinear.load(\n            config,\n            prefix=f\"{prefix}.kv_b_proj\",\n            weights=weights,\n            bias=config.attention_bias,\n        )\n\n        self.o_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.o_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.num_groups = self.num_heads // self.num_key_value_heads\n        self.kv_head_mapping = torch.arange(\n            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device\n        ).repeat_interleave(self.num_groups)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        cos: torch.Tensor,\n        sin: torch.Tensor,\n        cu_seqlen_prefill: torch.Tensor,\n        kv_cache: KVCache,\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n    ):\n        if self.q_lora_rank is None:\n            query = self.q_proj(hidden_states)\n        else:\n            
query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])\n        query = query.view(-1, self.num_heads, self.head_size)\n\n        _, query_pe = torch.split(\n            query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1\n        )\n\n        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)\n        compressed_kv, key_pe = torch.split(\n            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n        )\n\n        key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)\n        kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(\n            -1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size\n        )\n\n        key_nope, value = torch.split(\n            kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1\n        )\n\n        batch_size, heads, head_dim = query_pe.shape\n        query_pe = (\n            query_pe.view(batch_size, heads, head_dim // 2, 2)\n            .transpose(2, 3)\n            .reshape(batch_size, heads, head_dim)\n        )\n        batch_size, heads, head_dim = key_pe.shape\n        key_pe = (\n            key_pe.view(batch_size, heads, head_dim // 2, 2)\n            .transpose(2, 3)\n            .reshape(batch_size, heads, head_dim)\n        )\n        self.rotary_emb(query_pe, key_pe, cos, sin)\n\n        query[..., self.qk_nope_head_dim :] = query_pe\n        key = torch.empty_like(query)\n        key[..., : self.qk_nope_head_dim] = key_nope\n        key[..., self.qk_nope_head_dim :] = key_pe\n\n        # We need to pad the heads because Flash Attention does not support\n        # qk and v with different head sizes.\n        query = torch.nn.functional.pad(\n            query, (0, self.head_pad_size - self.head_size), value=0\n        )\n        key = torch.nn.functional.pad(\n            key, (0, self.head_pad_size - self.head_size), value=0\n        )\n        value = torch.nn.functional.pad(\n            value, (0, 
self.head_pad_size - self.value_head_size), value=0\n        )\n\n        kv_cache.store(\n            key=key,\n            value=value,\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # flash attention\n            attn_output = attention(\n                query=query,\n                key=key,\n                value=value,\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                block_tables=block_tables,\n                softmax_scale=self.softmax_scale,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                block_tables,\n                seqlen,\n                max_s,\n                kv_scales=self.kv_scales,\n            )\n\n        # Remove padding.\n        attn_output = attn_output[..., : self.value_head_size]\n\n        return self.o_proj(\n            attn_output.reshape(-1, self.num_heads * self.value_head_size)\n        )\n\n\nclass DeepseekV3MLP(nn.Module):\n    def __init__(self, prefix: str, config, weights, intermediate_size: int):\n        super().__init__()\n        self.hidden_act = config.hidden_act\n        if self.hidden_act != \"silu\":\n            # Bail out because MoE only supports silu.\n            raise NotImplementedError(\n                \"Currently only `silu` is supported as an activation for Deepseek V2.\"\n            )\n        self.act = ACT2FN[self.hidden_act]\n\n        self.gate_up_proj = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.gate_proj\", f\"{prefix}.up_proj\"],\n            weights=weights,\n            dim=0,\n            bias=False,\n        )\n\n        self.down_proj = TensorParallelRowLinear.load(\n          
  config,\n            prefix=f\"{prefix}.down_proj\",\n            weights=weights,\n            bias=False,\n        )\n\n        self.intermediate_size = intermediate_size // weights.process_group.size()\n\n        # TODO: This is a hotfix to be removed & properly refactored.\n        self.quantize = config.quantize\n\n    def forward(self, hidden_states: torch.Tensor, reduce: bool = True):\n        if (\n            SYSTEM == \"rocm\"\n            and self.hidden_act == \"silu\"\n            and hidden_states.dtype == torch.float16\n            and hidden_states.shape[0] == 1\n            and not self.quantize\n        ):\n            out = torch.empty(\n                hidden_states.shape[0],\n                self.intermediate_size,\n                dtype=hidden_states.dtype,\n                device=\"cuda\",\n            )\n            ops.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)\n            return self.down_proj(out, reduce=reduce)\n        else:\n            gate_up_states = self.gate_up_proj(hidden_states)\n            gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)\n            return self.down_proj(\n                self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce\n            )\n\n\nclass DeepseekV3MoE(nn.Module):\n    def __init__(\n        self,\n        prefix,\n        config: DeepseekV3Config,\n        moe_layer_cls: Type[MoELayer],\n        weights,\n    ):\n        super().__init__()\n\n        self.hidden_dim = config.hidden_size\n        self.moe_intermediate_size = (\n            config.moe_intermediate_size // weights.process_group.size()\n        )\n        self.routed_scaling_factor = config.routed_scaling_factor\n\n        # Gating\n        self.gate = FastLinear.load(config, f\"{prefix}.gate\", weights, bias=False)\n\n        if config.topk_method == \"noaux_tc\":\n            self.gate.e_score_correction_bias = torch.zeros(\n                config.n_routed_experts, 
device=weights.device\n            )\n        else:\n            self.gate.e_score_correction_bias = None\n\n        self.moe_layer = moe_layer_cls(\n            prefix=f\"{prefix}.experts\",\n            n_experts=config.n_routed_experts,\n            n_expert_group=config.n_group,\n            renormalize=config.norm_topk_prob,\n            topk=config.num_experts_per_tok,\n            topk_group=config.topk_group,\n            weights=weights,\n            scoring_func=config.scoring_func,\n            e_score_correction_bias=self.gate.e_score_correction_bias,\n        )\n        assert isinstance(self.moe_layer, MoELayer)\n\n        if config.n_shared_experts is not None:\n            self.shared_experts = DeepseekV3MLP(\n                prefix=f\"{prefix}.shared_experts\",\n                config=config,\n                weights=weights,\n                intermediate_size=config.moe_intermediate_size\n                * config.n_shared_experts,\n            )\n        else:\n            self.shared_experts = None\n\n        self.process_group = weights.process_group\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        if self.shared_experts is not None:\n            shared_output = self.shared_experts(x, reduce=False)\n        else:\n            shared_output = None\n\n        router_logits = self.gate(x)\n\n        out = self.moe_layer(x, gating_output=router_logits)\n\n        if shared_output is not None:\n            out = out + shared_output\n\n        # Reduce sum\n        if self.process_group.size() > 1:\n            torch.distributed.all_reduce(out, group=self.process_group)\n\n        return out.view(*x.shape)\n\n\nclass DeepseekV3Layer(nn.Module):\n    def __init__(self, prefix, layer_id, config, weights):\n        super().__init__()\n        prefix = f\"{prefix}.layers.{layer_id}\"\n\n        self.self_attn = DeepseekV3Attention(\n            prefix=f\"{prefix}.self_attn\",\n            config=config,\n            weights=weights,\n   
     )\n\n        if (\n            config.n_routed_experts is not None\n            and layer_id >= config.first_k_dense_replace\n            and layer_id % config.moe_layer_freq == 0\n        ):\n            moe_layer_cls = (\n                SparseMoELayer\n                if SparseMoELayer.is_supported(weights)\n                else DenseMoELayer\n            )\n            self.mlp = DeepseekV3MoE(f\"{prefix}.mlp\", config, moe_layer_cls, weights)\n        else:\n            self.mlp = DeepseekV3MLP(\n                prefix=f\"{prefix}.mlp\",\n                config=config,\n                weights=weights,\n                intermediate_size=config.intermediate_size,\n            )\n\n        self.input_layernorm = FastRMSNorm.load(\n            prefix=f\"{prefix}.input_layernorm\", weights=weights, eps=config.rms_norm_eps\n        )\n        self.post_attention_layernorm = FastRMSNorm.load(\n            prefix=f\"{prefix}.post_attention_layernorm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        residual: torch.Tensor,\n        cos: torch.Tensor,\n        sin: torch.Tensor,\n        cu_seqlen_prefill: torch.Tensor,\n        kv_cache,\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n    ):\n        normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        attn_output = self.self_attn(\n            normed_hidden_states,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n        )\n\n        # faster post attention rms norm\n        normed_attn_res_output, residual = self.post_attention_layernorm(\n            attn_output, residual\n        )\n\n        output = self.mlp(normed_attn_res_output)\n\n      
  return output, residual\n\n\nclass DeepseekV3Model(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights: Weights):\n        super().__init__()\n\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=f\"{prefix}.embed_tokens\", weights=weights\n        )\n\n        self.layers = nn.ModuleList(\n            [\n                DeepseekV3Layer(\n                    prefix,\n                    layer_id,\n                    config,\n                    weights,\n                )\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n        self.norm = FastRMSNorm.load(\n            prefix=f\"{prefix}.norm\", weights=weights, eps=config.rms_norm_eps\n        )\n\n        self.head_size = self.layers[0].self_attn.head_size\n        self.num_heads = self.layers[0].self_attn.num_heads\n        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n    ) -> torch.Tensor:\n        hidden_states = self.embed_tokens(input_ids)\n\n        # Get rotary cos and sin for this forward\n        # Avoid to index in each layer\n        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(\n            position_ids, max_s, hidden_states.dtype\n        )\n\n        residual = None\n        for i, layer in enumerate(self.layers):\n            hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                block_tables,\n                slots,\n                seqlen,\n                max_s,\n            )\n\n   
     hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n\nclass FlashDeepseekV3ForCausalLM(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights: Weights):\n        super().__init__()\n\n        self.model = DeepseekV3Model(\n            \"model\" if not prefix else f\"{prefix}.model\", config, weights\n        )\n        self.lm_head = SpeculativeHead.load(\n            config,\n            prefix=\"lm_head\" if not prefix else f\"{prefix}.lm_head\",\n            weights=weights,\n        )\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        hidden_states = self.model(\n            input_ids,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.lm_head(hidden_states)\n        return logits, speculative_logits\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.distributed\n\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom transformers.configuration_utils import PretrainedConfig\nfrom typing import Optional, List, Tuple\nfrom text_generation_server.layers.attention import (\n    paged_attention,\n    attention,\n    Seqlen,\n)\nfrom text_generation_server.layers import (\n    TensorParallelRowLinear,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    SpeculativeHead,\n    get_linear,\n    TensorParallelMultiAdapterLinear,\n    TensorParallelAdapterRowLinear,\n)\nfrom text_generation_server.layers.attention.kv_cache import get_kv_scales\nfrom text_generation_server.layers.rotary import PositionRotaryEmbedding\nfrom text_generation_server.layers.layernorm import (\n    FastRMSNorm,\n)\nfrom text_generation_server.utils.weights import UnquantizedWeight\n\n\nclass Gemma2Config(PretrainedConfig):\n    def __init__(\n        self,\n        vocab_size=256128,\n        hidden_size=3072,\n        
intermediate_size=24576,\n        num_hidden_layers=28,\n        num_attention_heads=16,\n        num_key_value_heads=16,\n        head_dim=256,\n        hidden_act=\"gelu_pytorch_tanh\",\n        max_position_embeddings=8192,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        pad_token_id=None,\n        bos_token_id=1,\n        eos_token_id=2,\n        tie_word_embeddings=True,\n        rope_theta=10000.0,\n        rope_scaling=None,\n        attention_bias=False,\n        attention_dropout=0.0,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.head_dim = head_dim\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n\nclass Gemma2FastRMSNorm(FastRMSNorm):\n    @classmethod\n    def load(cls, prefix: str, weights, eps=1e-6):\n        dtype = weights.dtype\n        weights.dtype = torch.float32\n        weight = weights.get_tensor(f\"{prefix}.weight\") + 1\n        weights.dtype = dtype\n        new = cls(weight, eps)\n    
    new.dtype = dtype\n        return new\n\n    # perform the multiplication in full precision and downcast after\n    def forward(self, hidden_states, residual=None):\n        if residual is not None:\n            hidden_states += residual\n        residual = hidden_states\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        hidden_states = hidden_states * self.weight\n        return hidden_states.to(self.dtype), residual\n\n\ndef load_attention(config, prefix: str, weights):\n    if config.num_attention_heads != config.num_key_value_heads:\n        return _load_gqa(config, prefix, weights)\n    else:\n        return TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n            dim=0,\n            weights=weights,\n            bias=False,\n        )\n\n\ndef _load_gqa(config, prefix: str, weights):\n    assert config.num_attention_heads % weights.process_group.size() == 0\n\n    weight = weights.get_multi_weights_col(\n        prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n        dim=0,\n    )\n\n    if isinstance(weight, UnquantizedWeight):\n        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)\n\n        head_size = config.head_dim\n        num_heads = config.num_attention_heads // weights.process_group.size()\n        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()\n        assert list(weight.weight.shape) == [\n            (num_heads + 2 * num_key_value_heads) * head_size,\n            config.hidden_size,\n        ], f\"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}\"\n\n    return TensorParallelColumnLinear(get_linear(weight, 
bias=None))\n\n\nclass FlashGemma2Attention(torch.nn.Module):\n    def __init__(\n        self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool\n    ):\n        super().__init__()\n        self.num_heads = config.num_attention_heads\n        self.head_size = config.head_dim\n        self.causal = causal\n        if is_sliding:\n            self.window_size = config.sliding_window\n        else:\n            self.window_size = -1\n\n        self.rotary_emb = PositionRotaryEmbedding.static(\n            config=config,\n            dim=self.head_size,\n            base=config.rope_theta,\n            device=weights.device,\n        )\n\n        # self.softmax_scale = self.head_size**-0.5\n        self.softmax_scale = config.query_pre_attn_scalar**-0.5\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.num_key_value_heads = (\n            config.num_key_value_heads // weights.process_group.size()\n        )\n        self.softcap = config.attn_logit_softcapping\n\n        query_key_value = load_attention(config, prefix, weights)\n        self.query_key_value = TensorParallelMultiAdapterLinear.load(\n            query_key_value,\n            layer_id,\n            [\"q_proj\", \"k_proj\", \"v_proj\"],\n            sizes=[\n                self.head_size * config.num_attention_heads,\n                self.head_size * config.num_key_value_heads,\n                self.head_size * config.num_key_value_heads,\n            ],\n            process_group=weights.process_group,\n        )\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n\n        o_proj = TensorParallelRowLinear.load(\n            config,\n            
prefix=f\"{prefix}.o_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.o_proj = TensorParallelAdapterRowLinear.load(\n            o_proj,\n            layer_id,\n            \"o_proj\",\n            process_group=weights.process_group,\n        )\n\n        self.num_groups = self.num_heads // self.num_key_value_heads\n        self.kv_head_mapping = torch.arange(\n            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device\n        ).repeat_interleave(self.num_groups)\n\n    def forward(\n        self,\n        hidden_states,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n        adapter_data,\n    ):\n        qkv = self.query_key_value(hidden_states, adapter_data)\n        query, kv = qkv.split(\n            [\n                self.head_size * self.num_heads,\n                2 * self.head_size * self.num_key_value_heads,\n            ],\n            dim=1,\n        )\n        query = query.view(-1, self.num_heads, self.head_size)\n        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)\n\n        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)\n\n        kv_cache.store(\n            key=kv[:, 0],\n            value=kv[:, 1],\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # flash attention\n            attn_output = attention(\n                query=query,\n                key=kv[:, 0],\n                value=kv[:, 1],\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                block_tables=block_tables,\n                softmax_scale=self.softmax_scale,\n                window_size_left=self.window_size,\n                softcap=self.softcap,\n            )\n        # Decode\n        else:\n            
attn_output = paged_attention(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                block_tables,\n                seqlen,\n                max_s,\n                softcap=self.softcap,\n                kv_scales=self.kv_scales,\n                window_size_left=self.window_size,\n            )\n\n        return self.o_proj(\n            attn_output.view(-1, self.num_heads * self.head_size), adapter_data\n        )\n\n\nclass Gemma2MLP(nn.Module):\n    def __init__(self, prefix, config, weights, layer_id):\n        super().__init__()\n        act = config.hidden_activation\n        self.act = (\n            ACT2FN[act]\n            if \"gelu\" not in act\n            else lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\" if act in [\"gelu_fast\", \"gelu_pytorch_tanh\"] else \"none\"\n                ),\n            )\n        )\n        # Fuse gate and up proj\n        gate_up_proj = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.gate_proj\", f\"{prefix}.up_proj\"],\n            weights=weights,\n            dim=0,\n            bias=False,\n        )\n        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(\n            gate_up_proj,\n            layer_id,\n            [\"gate_proj\", \"up_proj\"],\n            sizes=[\n                config.intermediate_size,\n                config.intermediate_size,\n            ],\n            process_group=weights.process_group,\n        )\n\n        down_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.down_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.down_proj = TensorParallelAdapterRowLinear.load(\n            down_proj,\n            layer_id,\n            \"down_proj\",\n            process_group=weights.process_group,\n        
)\n\n        self.intermediate_size = (\n            config.intermediate_size // weights.process_group.size()\n        )\n\n    def forward(self, hidden_states, adapter_data):\n        gate_up_states = self.gate_up_proj(hidden_states, adapter_data)\n        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)\n        return self.down_proj(\n            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data\n        )\n\n\nclass FlashGemma2Layer(nn.Module):\n    def __init__(\n        self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool\n    ):\n        super().__init__()\n        self.self_attn = FlashGemma2Attention(\n            prefix=f\"{prefix}.self_attn\",\n            config=config,\n            weights=weights,\n            layer_id=layer_id,\n            causal=causal,\n            is_sliding=is_sliding,\n        )\n        self.mlp = Gemma2MLP(\n            prefix=f\"{prefix}.mlp\", config=config, weights=weights, layer_id=layer_id\n        )\n\n        self.input_layernorm = Gemma2FastRMSNorm.load(\n            prefix=f\"{prefix}.input_layernorm\", weights=weights, eps=config.rms_norm_eps\n        )\n        self.post_attention_layernorm = Gemma2FastRMSNorm.load(\n            prefix=f\"{prefix}.post_attention_layernorm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n        self.pre_feedforward_layernorm = Gemma2FastRMSNorm.load(\n            prefix=f\"{prefix}.pre_feedforward_layernorm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n        self.post_feedforward_layernorm = Gemma2FastRMSNorm.load(\n            prefix=f\"{prefix}.post_feedforward_layernorm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        
seqlen,\n        max_s,\n        adapter_data,\n    ):\n        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        attn_output = self.self_attn(\n            normed_hidden_states,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n            adapter_data,\n        )\n\n        # faster post attention rms norm\n        normed_attn_res_output, _ = self.post_attention_layernorm(attn_output)\n        normed_attn_res_output = normed_attn_res_output + res\n        res = normed_attn_res_output\n\n        pre_normed, _ = self.pre_feedforward_layernorm(normed_attn_res_output)\n        mlp_output = self.mlp(pre_normed, adapter_data)\n        post_hidden_states, _ = self.post_feedforward_layernorm(mlp_output)\n\n        return post_hidden_states, normed_attn_res_output\n\n\nclass FlashGemma2Model(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights, causal: bool):\n        super().__init__()\n\n        process_group = weights.process_group\n        self.tp_rank = process_group.rank()\n        self.tp_world_size = process_group.size()\n        self.layers = nn.ModuleList(\n            [\n                FlashGemma2Layer(\n                    prefix=f\"{prefix}.layers.{layer_id}\",\n                    config=config,\n                    weights=weights,\n                    layer_id=layer_id,\n                    causal=causal,\n                    is_sliding=layer_id % 2 == 0,\n                )\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n        self.norm = Gemma2FastRMSNorm.load(\n            prefix=f\"{prefix}.norm\", weights=weights, eps=config.rms_norm_eps\n        )\n\n        self.head_size = self.layers[0].self_attn.head_size\n        self.num_heads = self.layers[0].self_attn.num_heads\n        
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        hidden_states = inputs_embeds\n\n        # Get rotary cos and sin for this forward\n        # Avoid to index in each layer\n        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(\n            position_ids, max_s, hidden_states.dtype\n        )\n\n        residual = None\n        for i, layer in enumerate(self.layers):\n            hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                block_tables,\n                slots,\n                seqlen,\n                max_s,\n                adapter_data,\n            )\n\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n\nclass FlashGemma2ForCausalLM(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights, *, causal: bool = True):\n        super().__init__()\n\n        embed_norm = config.hidden_size**0.5\n        if not prefix:\n            prefix = \"model\"\n        else:\n            prefix = f\"{prefix}.model\"\n\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=f\"{prefix}.embed_tokens\", weights=weights\n        )\n        self.embed_tokens.weight *= embed_norm\n\n        self.model = FlashGemma2Model(\n            prefix=prefix, config=config, weights=weights, causal=causal\n        )\n        self.lm_head = SpeculativeHead.load(\n            prefix=(\n                
f\"{prefix}.embed_tokens\"\n                if config.tie_word_embeddings\n                else f\"{prefix}.lm_head\"\n            ),\n            config=config,\n            weights=weights,\n        )\n        self.softcap = config.final_logit_softcapping\n        assert isinstance(self.softcap, float)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        input_embeds = self.embed_tokens(input_ids)\n        hidden_states = self.model(\n            input_embeds,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n            adapter_data,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.lm_head(hidden_states)\n\n        logits /= self.softcap\n        logits = torch.tanh(logits)\n        logits *= self.softcap\n\n        return logits, speculative_logits\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2024 HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.distributed\nfrom torch import nn\nfrom typing import Optional, List, Tuple\nimport copy\n\nfrom text_generation_server.layers import (\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    TensorParallelRowLinear,\n    get_linear,\n    #\n    SpeculativeHead,\n    TensorParallelMultiAdapterLinear,\n    TensorParallelAdapterRowLinear,\n)\n\nimport torch\nimport torch.nn.functional as F\n\n\nfrom text_generation_server.models.custom_modeling.vlm import (\n    load_text_model,\n    load_vision_model,\n)\n\n\nfrom text_generation_server.layers.attention.kv_cache import get_kv_scales\nfrom text_generation_server.layers.rotary import PositionRotaryEmbedding\nfrom text_generation_server.layers.layernorm import (\n    FastRMSNorm,\n)\nfrom text_generation_server.models.globals import ATTENTION\nfrom text_generation_server.utils.weights import UnquantizedWeight\nfrom transformers.activations import ACT2FN\nfrom text_generation_server.layers.attention import (\n    paged_attention,\n    attention,\n    Seqlen,\n)\n\n\nATTENTION_TYPE_GLOBAL = \"global\"\nATTENTION_TYPE_LOCAL = \"local_sliding\"\n\n\nclass Gemma3FastRMSNorm(FastRMSNorm):\n    @classmethod\n    def load(cls, prefix: str, weights, eps=1e-6):\n        dtype = weights.dtype\n        weights.dtype = torch.float32\n        weight 
= weights.get_tensor(f\"{prefix}.weight\") + 1\n        weights.dtype = dtype\n        new = cls(weight, eps)\n        new.dtype = dtype\n        return new\n\n    # perform the multiplication in full precision and downcast after\n    def forward(self, hidden_states, residual=None):\n        if residual is not None:\n            hidden_states += residual\n        residual = hidden_states\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        hidden_states = hidden_states * self.weight\n        return hidden_states.to(self.dtype), residual\n\n\ndef load_attention(config, prefix: str, weights):\n    if config.num_attention_heads != config.num_key_value_heads:\n        return _load_gqa(config, prefix, weights)\n    else:\n        return TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n            dim=0,\n            weights=weights,\n            bias=False,\n        )\n\n\ndef _load_gqa(config, prefix: str, weights):\n    assert config.num_attention_heads % weights.process_group.size() == 0\n\n    weight = weights.get_multi_weights_col(\n        prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n        dim=0,\n    )\n\n    if isinstance(weight, UnquantizedWeight):\n        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)\n\n        head_size = config.head_dim\n        num_heads = config.num_attention_heads // weights.process_group.size()\n        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()\n        assert list(weight.weight.shape) == [\n            (num_heads + 2 * num_key_value_heads) * head_size,\n            config.hidden_size,\n        ], f\"{list(weight.weight.shape)} != {[(num_heads + 2 * 
config.num_key_value_heads) * head_size, config.hidden_size]}\"\n\n    return TensorParallelColumnLinear(get_linear(weight, bias=None))\n\n\nclass FlashGemma3Attention(torch.nn.Module):\n    def __init__(\n        self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool\n    ):\n        super().__init__()\n        self.num_heads = config.num_attention_heads\n        self.head_size = config.head_dim\n        self.causal = causal\n        if is_sliding:\n            self.window_size = config.sliding_window\n            # TODO: remove this hack to support local sliding window\n            config = copy.deepcopy(config)\n            config.rope_scaling = dict(rope_type=\"default\")\n            self.rotary_emb = PositionRotaryEmbedding.static(\n                config=config,\n                dim=config.head_dim,\n                base=config.rope_local_base_freq,\n                device=weights.device,\n            )\n        else:\n            self.window_size = -1\n            self.rotary_emb = PositionRotaryEmbedding.static(\n                config=config,\n                dim=config.head_dim,\n                base=config.rope_theta,\n                device=weights.device,\n            )\n\n        self.softmax_scale = (\n            config.query_pre_attn_scalar**-0.5\n            if config.query_pre_attn_scalar is not None\n            else None\n        )\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.num_key_value_heads = (\n            config.num_key_value_heads // weights.process_group.size()\n        )\n        self.softcap = None  # config.attn_logit_softcapping\n\n        query_key_value = load_attention(config, prefix, 
weights)\n        self.query_key_value = TensorParallelMultiAdapterLinear.load(\n            query_key_value,\n            layer_id,\n            [\"q_proj\", \"k_proj\", \"v_proj\"],\n            sizes=[\n                self.head_size * config.num_attention_heads,\n                self.head_size * config.num_key_value_heads,\n                self.head_size * config.num_key_value_heads,\n            ],\n            process_group=weights.process_group,\n        )\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n\n        o_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.o_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.o_proj = TensorParallelAdapterRowLinear.load(\n            o_proj,\n            layer_id,\n            \"o_proj\",\n            process_group=weights.process_group,\n        )\n\n        self.num_groups = self.num_heads // self.num_key_value_heads\n        self.kv_head_mapping = torch.arange(\n            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device\n        ).repeat_interleave(self.num_groups)\n        self.q_norm = Gemma3FastRMSNorm.load(\n            prefix=f\"{prefix}.q_norm\", weights=weights, eps=config.rms_norm_eps\n        )\n        self.k_norm = Gemma3FastRMSNorm.load(\n            prefix=f\"{prefix}.k_norm\", weights=weights, eps=config.rms_norm_eps\n        )\n        self.enable_gqa = self.num_heads != self.num_key_value_heads\n\n    def forward(\n        self,\n        hidden_states,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n        adapter_data,\n        attention_mask,\n    ):\n\n        qkv = self.query_key_value(hidden_states, adapter_data)\n        query, kv = qkv.split(\n            [\n                self.head_size * self.num_heads,\n                2 * self.head_size * self.num_key_value_heads,\n         
   ],\n            dim=1,\n        )\n\n        kv = kv.view(-1, 2, self.num_key_value_heads * self.head_size)\n        key = kv[:, 0]\n        value = kv[:, 1]\n\n        query = query.reshape(-1, self.head_size)\n        key = key.reshape(-1, self.head_size)\n\n        query, _ = self.q_norm(query.contiguous())\n        key, _ = self.k_norm(key.contiguous())\n\n        query = query.view(-1, self.num_heads, self.head_size)\n        key = key.view(-1, self.num_key_value_heads, self.head_size)\n        value = value.view(-1, self.num_key_value_heads, self.head_size)\n\n        self.rotary_emb(query, key, cos, sin)\n\n        kv_cache.store(\n            key=key,\n            value=value,\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            if attention_mask is None or ATTENTION == \"flashinfer\":\n                # flash attention\n                attn_output = attention(\n                    query=query,\n                    key=key,\n                    value=value,\n                    kv_cache=kv_cache,\n                    kv_scales=self.kv_scales,\n                    seqlen=seqlen,\n                    block_tables=block_tables,\n                    softmax_scale=self.softmax_scale,\n                    window_size_left=self.window_size,\n                    softcap=self.softcap,\n                )\n            else:\n                lengths = cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]\n\n                # Split tensors using vectorized split\n                query_list = torch.split(query, lengths.tolist(), dim=0)\n                key_list = torch.split(key, lengths.tolist(), dim=0)\n                value_list = torch.split(value, lengths.tolist(), dim=0)\n\n                padded_query = torch.nn.utils.rnn.pad_sequence(\n                    query_list, batch_first=True\n                )\n                padded_key = 
torch.nn.utils.rnn.pad_sequence(key_list, batch_first=True)\n                padded_value = torch.nn.utils.rnn.pad_sequence(\n                    value_list, batch_first=True\n                )\n\n                padded_query = padded_query.transpose(1, 2).contiguous()\n                padded_key = padded_key.transpose(1, 2).contiguous()\n                padded_value = padded_value.transpose(1, 2).contiguous()\n\n                # Compute attention\n                attn_output = F.scaled_dot_product_attention(\n                    padded_query,\n                    padded_key,\n                    padded_value,\n                    attn_mask=attention_mask,\n                    scale=self.softmax_scale,\n                    enable_gqa=self.enable_gqa,\n                )\n\n                attn_output = attn_output.transpose(\n                    1, 2\n                )  # [batch_size, seq_len, num_heads, head_dim]\n                max_seq_len = padded_query.size(2)\n                seq_range = torch.arange(\n                    max_seq_len, device=padded_query.device\n                ).unsqueeze(0)\n                lengths_tensor = torch.tensor(\n                    lengths, device=padded_query.device\n                ).unsqueeze(1)\n                mask = seq_range < lengths_tensor  # [batch, max_seq_len]\n                attn_output = attn_output[mask]  # [total_seq_len, num_heads, head_dim]\n\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                block_tables,\n                seqlen,\n                max_s,\n                softcap=self.softcap,\n                kv_scales=self.kv_scales,\n                window_size_left=self.window_size,\n            )\n\n        return self.o_proj(\n            attn_output.view(-1, self.num_heads * self.head_size), adapter_data\n        )\n\n\nclass 
Gemma3MLP(nn.Module):\n    def __init__(self, prefix, config, weights, layer_id):\n        super().__init__()\n        act = config.hidden_activation\n        self.act = (\n            ACT2FN[act]\n            if \"gelu\" not in act\n            else lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\" if act in [\"gelu_fast\", \"gelu_pytorch_tanh\"] else \"none\"\n                ),\n            )\n        )\n        # Fuse gate and up proj\n        gate_up_proj = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.gate_proj\", f\"{prefix}.up_proj\"],\n            weights=weights,\n            dim=0,\n            bias=False,\n        )\n        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(\n            gate_up_proj,\n            layer_id,\n            [\"gate_proj\", \"up_proj\"],\n            sizes=[\n                config.intermediate_size,\n                config.intermediate_size,\n            ],\n            process_group=weights.process_group,\n        )\n\n        down_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.down_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.down_proj = TensorParallelAdapterRowLinear.load(\n            down_proj,\n            layer_id,\n            \"down_proj\",\n            process_group=weights.process_group,\n        )\n\n        self.intermediate_size = (\n            config.intermediate_size // weights.process_group.size()\n        )\n\n    def forward(self, hidden_states, adapter_data):\n        gate_up_states = self.gate_up_proj(hidden_states, adapter_data)\n        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)\n        return self.down_proj(\n            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data\n        )\n\n\nclass FlashGemma3Layer(nn.Module):\n    def __init__(\n        
self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool\n    ):\n        super().__init__()\n        self.self_attn = FlashGemma3Attention(\n            prefix=f\"{prefix}.self_attn\",\n            config=config,\n            weights=weights,\n            layer_id=layer_id,\n            causal=causal,\n            is_sliding=is_sliding,\n        )\n        self.mlp = Gemma3MLP(\n            prefix=f\"{prefix}.mlp\", config=config, weights=weights, layer_id=layer_id\n        )\n\n        self.input_layernorm = Gemma3FastRMSNorm.load(\n            prefix=f\"{prefix}.input_layernorm\", weights=weights, eps=config.rms_norm_eps\n        )\n        self.post_attention_layernorm = Gemma3FastRMSNorm.load(\n            prefix=f\"{prefix}.post_attention_layernorm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n        self.pre_feedforward_layernorm = Gemma3FastRMSNorm.load(\n            prefix=f\"{prefix}.pre_feedforward_layernorm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n        self.post_feedforward_layernorm = Gemma3FastRMSNorm.load(\n            prefix=f\"{prefix}.post_feedforward_layernorm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n        adapter_data,\n        attention_mask,\n    ):\n        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        attn_output = self.self_attn(\n            normed_hidden_states,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n            adapter_data,\n            attention_mask,\n        )\n\n    
    # faster post attention rms norm\n        normed_attn_res_output, _ = self.post_attention_layernorm(attn_output)\n        normed_attn_res_output = normed_attn_res_output + res\n        res = normed_attn_res_output\n\n        pre_normed, _ = self.pre_feedforward_layernorm(normed_attn_res_output)\n        mlp_output = self.mlp(pre_normed, adapter_data)\n        post_hidden_states, _ = self.post_feedforward_layernorm(mlp_output)\n\n        return post_hidden_states, normed_attn_res_output\n\n\nclass FlashGemma3Model(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights, causal: bool):\n        super().__init__()\n\n        process_group = weights.process_group\n        self.tp_rank = process_group.rank()\n        self.tp_world_size = process_group.size()\n\n        self.layers = nn.ModuleList(\n            [\n                FlashGemma3Layer(\n                    prefix=f\"{prefix}.layers.{layer_id}\",\n                    config=config,\n                    weights=weights,\n                    layer_id=layer_id,\n                    causal=causal,\n                    is_sliding=bool((layer_id + 1) % config.sliding_window_pattern),\n                )\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n        self.norm = Gemma3FastRMSNorm.load(\n            prefix=f\"{prefix}.norm\", weights=weights, eps=config.rms_norm_eps\n        )\n\n        self.head_size = self.layers[0].self_attn.head_size\n        self.num_heads = self.layers[0].self_attn.num_heads\n        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        adapter_data: 
Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        attention_mask_local: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        hidden_states = inputs_embeds\n\n        # Get rotary cos and sin for this forward\n        # Avoid to index in each layer\n\n        residual = None\n        for i, layer in enumerate(self.layers):\n            cos, sin = self.layers[i].self_attn.rotary_emb.get_cos_sin(\n                position_ids, max_s, hidden_states.dtype\n            )\n\n            hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                block_tables,\n                slots,\n                seqlen,\n                max_s,\n                adapter_data,\n                (\n                    attention_mask\n                    if self.layers[i].self_attn.window_size == -1\n                    else attention_mask_local\n                ),\n            )\n\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n\nclass FlashGemma3ForCausalLM(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights, *, causal: bool = True):\n        super().__init__()\n\n        embed_norm = config.hidden_size**0.5\n        if not prefix:\n            prefix = \"model\"\n        else:\n            prefix = f\"{prefix}.model\"\n\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=f\"{prefix}.embed_tokens\", weights=weights\n        )\n        self.embed_tokens.weight *= embed_norm\n\n        self.model = FlashGemma3Model(\n            prefix=prefix, config=config, weights=weights, causal=causal\n        )\n        self.lm_head = SpeculativeHead.load(\n            prefix=(\n                f\"{prefix}.embed_tokens\"\n                if config.tie_word_embeddings\n                else 
f\"{prefix}.lm_head\"\n            ),\n            config=config,\n            weights=weights,\n        )\n        # self.softcap = config.attn_logit_softcapping\n        # assert isinstance(self.softcap, float)\n        self.softcap = None\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        input_embeds = self.embed_tokens(input_ids)\n\n        hidden_states = self.model(\n            input_embeds,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n            adapter_data,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.lm_head(hidden_states)\n\n        return logits, speculative_logits\n\n\nclass Gemma3MultimodalInputProjection(torch.nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n\n        self.mm_input_projection_weight = weights.get_tensor(\n            \"multi_modal_projector.mm_input_projection_weight\"\n        )\n\n        self.mm_soft_emb_norm = Gemma3FastRMSNorm.load(\n            prefix=f\"{prefix}.mm_soft_emb_norm\",\n            weights=weights,\n            eps=config.vision_config.layer_norm_eps,\n        )\n\n        self.patches_per_image = int(\n            config.vision_config.image_size // config.vision_config.patch_size\n        )\n        self.tokens_per_side = 
int(config.mm_tokens_per_image**0.5)\n        self.kernel_size = self.patches_per_image // self.tokens_per_side\n        self.avg_pool = nn.AvgPool2d(\n            kernel_size=self.kernel_size, stride=self.kernel_size\n        )\n\n    def forward(self, vision_outputs: torch.Tensor):\n        batch_size, _, seq_length = vision_outputs.shape\n\n        reshaped_vision_outputs = vision_outputs.transpose(1, 2)\n        reshaped_vision_outputs = reshaped_vision_outputs.reshape(\n            batch_size, seq_length, self.patches_per_image, self.patches_per_image\n        )\n        reshaped_vision_outputs = reshaped_vision_outputs.contiguous()\n\n        pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)\n        pooled_vision_outputs = pooled_vision_outputs.flatten(2)\n        pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)\n\n        normed_vision_outputs, _ = self.mm_soft_emb_norm(pooled_vision_outputs)\n\n        projected_vision_outputs = torch.matmul(\n            normed_vision_outputs, self.mm_input_projection_weight\n        )\n        return projected_vision_outputs.type_as(vision_outputs)\n\n\nclass Gemma3ForConditionalGeneration(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n\n        self.config = config\n\n        if config.vision_config is not None:\n\n            config.vision_config.quantize = config.quantize\n\n            self.post_vision_model_layernorm = nn.LayerNorm.load(\n                prefix=\"vision_tower.vision_model.post_layernorm\",\n                weights=weights,\n                eps=config.vision_config.layer_norm_eps,\n            )\n\n            self.multimodal_projector = Gemma3MultimodalInputProjection(\n                prefix=\"multi_modal_projector\",\n                config=config,\n                weights=weights,\n            )\n\n            text_config = config.text_config\n            text_config.speculator = config.speculator\n            
text_config.quantize = config.quantize\n\n            self.vision_model = load_vision_model(\n                prefix=\"vision_tower\" if not prefix else f\"{prefix}.vision_tower\",\n                config=config.vision_config,\n                weights=weights,\n            )\n\n            self.text_model = load_text_model(\n                prefix=\"language_model\" if not prefix else f\"{prefix}.language_model\",\n                config=config.text_config,\n                weights=weights,\n            )\n        else:\n            config.text_config.quantize = config.quantize\n            config.text_config.speculator = config.speculator\n            self.text_model = load_text_model(\n                prefix=prefix,\n                config=config.text_config,\n                weights=weights,\n            )\n\n        self.pad_token_id = (\n            config.pad_token_id if config.pad_token_id is not None else -1\n        )\n        self.dtype = weights.dtype\n\n    def get_attention_mask(\n        self,\n        input_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        dtype: torch.dtype,\n        bool_mask: bool = False,\n    ):\n        image_token_mask = (input_ids == self.config.image_token_index).to(\n            input_ids.device\n        )\n\n        device = input_ids.device\n        min_dtype = torch.finfo(dtype).min\n\n        lengths = (cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]).tolist()\n        batch_size = len(lengths)\n\n        sequence_length = max(lengths)\n        target_length = sequence_length\n        # Create the padding mask from the computed lengths.\n        # pad_mask: [batch, sequence_length] where True indicates valid tokens.\n        seq_range = torch.arange(sequence_length, device=device).unsqueeze(0)\n        lengths_tensor = torch.tensor(lengths, device=device).unsqueeze(1)\n        pad_mask = seq_range < lengths_tensor  # shape: [batch, sequence_length]\n\n        # Build the base causal mask 
(for non-image tokens):\n        causal_mask = torch.tril(\n            torch.ones(\n                (sequence_length, sequence_length), dtype=torch.bool, device=device\n            )\n        )\n        base_mask = pad_mask.unsqueeze(2) & pad_mask.unsqueeze(\n            1\n        )  # [batch, sequence_length, sequence_length]\n        base_mask = base_mask & causal_mask.unsqueeze(0)  # apply causal constraint\n\n        image_token_mask = torch.nn.utils.rnn.pad_sequence(\n            torch.split(image_token_mask, lengths), batch_first=True, padding_value=0\n        )\n        bidirectional_mask = image_token_mask.unsqueeze(2) & image_token_mask.unsqueeze(\n            1\n        )\n\n        # Combine the causal base mask and the bidirectional mask.\n        combined_mask = torch.logical_or(\n            base_mask.unsqueeze(1), bidirectional_mask.unsqueeze(1)\n        ).to(device)\n        # combined_mask now has shape [batch, 1, sequence_length, sequence_length]\n\n        full_attention_mask = torch.zeros(\n            (batch_size, 1, sequence_length, target_length),\n            device=device,\n            dtype=torch.bool,\n        )\n        full_attention_mask[:, :, :, :sequence_length] = combined_mask\n\n        if bool_mask:\n            return full_attention_mask\n        else:\n            return torch.where(full_attention_mask, 0, min_dtype).to(device)\n\n    def get_vision_embeds(\n        self,\n        pixel_values: torch.FloatTensor,\n        pixel_attention_mask: Optional[torch.FloatTensor] = None,\n        image_sizes: Optional[torch.Tensor] = None,\n        image_grid_thw: Optional[torch.LongTensor] = None,\n    ):\n        pixel_values = pixel_values.to(dtype=self.dtype)\n        image_outputs = self.vision_model(pixel_values)\n        vision_outputs = self.post_vision_model_layernorm(\n            image_outputs.last_hidden_state\n        )\n        image_features = self.multimodal_projector(vision_outputs)\n        image_features = 
image_features.view(-1, image_features.shape[-1])\n        return image_features\n\n    def get_inputs_embeds(\n        self,\n        input_ids: torch.Tensor,\n        vision_embeds: torch.Tensor = None,\n    ):\n        inputs_embeds = self.text_model.embed_tokens(input_ids)\n\n        if vision_embeds is not None:\n            # Replace the image token embeddings with the vision features\n            image_token_mask = (input_ids == self.config.image_token_index).to(\n                input_ids.device\n            )\n            inputs_embeds[image_token_mask] = vision_embeds.view(\n                -1, vision_embeds.shape[-1]\n            )\n        return inputs_embeds\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor] = None,\n        lm_head_indices: Optional[torch.Tensor] = None,\n        pixel_values: torch.FloatTensor = None,\n        # Unused here\n        attention_mask: Optional[torch.BoolTensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        if cu_seqlen_prefill is not None:\n            max_s += 1\n            position_ids += 1\n\n        # Use flash attention for text-only input\n        # else:\n        #     if cu_seqlen_prefill is not None:\n        #         min_dtype = torch.finfo(inputs_embeds.dtype).min\n        #         lengths = (cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]).tolist()\n\n        #         # Determine the maximum sequence length (after padding) from query.\n        #         sequence_length = max(lengths)\n        #         target_length = sequence_length\n\n        #         # Create the padding mask from the 
computed lengths.\n        #         # pad_mask: [batch, sequence_length] where True indicates valid tokens.\n        #         seq_range = torch.arange(\n        #             sequence_length, device=input_ids.device\n        #         ).unsqueeze(0)\n        #         lengths_tensor = torch.tensor(\n        #             lengths, device=input_ids.device\n        #         ).unsqueeze(1)\n        #         pad_mask = seq_range < lengths_tensor  # shape: [batch, sequence_length]\n\n        #         # Build the base causal mask (for non-image tokens):\n        #         causal_mask = torch.tril(\n        #             torch.ones(\n        #                 (sequence_length, sequence_length),\n        #                 dtype=torch.bool,\n        #                 device=input_ids.device,\n        #             )\n        #         )\n        #         base_mask = pad_mask.unsqueeze(2) & pad_mask.unsqueeze(\n        #             1\n        #         )  # [batch, sequence_length, sequence_length]\n        #         base_mask = base_mask & causal_mask.unsqueeze(0)\n        #         attention_mask = base_mask.unsqueeze(\n        #             1\n        #         )  # [batch, 1, sequence_length, sequence_length]\n        #         full_attention_mask = torch.zeros(\n        #             (len(lengths), 1, sequence_length, target_length),\n        #             device=input_ids.device,\n        #             dtype=torch.bool,\n        #         )\n        #         full_attention_mask[:, :, :, :sequence_length] = attention_mask\n\n        #         attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(\n        #             input_ids.device\n        #         )\n\n        if attention_mask is not None:\n            min_dtype = torch.finfo(inputs_embeds.dtype).min\n            # prefill may be larger than sliding window\n            effective_seq_len = max(\n                position_ids.shape[0], self.config.text_config.sliding_window\n            )\n      
      sliding_window_mask = torch.tril(\n                torch.ones_like(attention_mask, dtype=torch.bool),\n                diagonal=-self.config.text_config.sliding_window,\n            )\n            attention_mask_local = torch.where(\n                sliding_window_mask, min_dtype, attention_mask\n            )\n            offset = max(0, position_ids.shape[0] - effective_seq_len)\n            attention_mask_local = attention_mask_local[\n                :, :, :, offset : offset + effective_seq_len\n            ]\n        else:\n            attention_mask_local = None\n\n        hidden_states = self.text_model.model(\n            inputs_embeds=inputs_embeds,\n            position_ids=position_ids,\n            cu_seqlen_prefill=cu_seqlen_prefill,\n            kv_cache=kv_cache,\n            block_tables=block_tables,\n            slots=slots,\n            seqlen=seqlen,\n            max_s=max_s,\n            attention_mask=attention_mask,\n            attention_mask_local=attention_mask_local,\n        )\n\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.text_model.lm_head(hidden_states)\n\n        # pad logit with 1 zero logit for the image token\n        if pixel_values is not None:\n            logits = torch.cat(\n                [logits, torch.zeros(logits.size(0), 1, device=logits.device)], dim=1\n            )\n            if speculative_logits is not None:\n                speculative_logits = torch.cat(\n                    [\n                        speculative_logits,\n                        torch.zeros(\n                            speculative_logits.size(0),\n                            1,\n                            device=speculative_logits.device,\n                        ),\n                    ],\n                    dim=1,\n                )\n\n        return logits, speculative_logits\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.distributed\n\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom transformers.configuration_utils import PretrainedConfig\nfrom typing import Optional, List, Tuple\nfrom text_generation_server.layers.attention import (\n    paged_attention,\n    attention,\n    Seqlen,\n)\nfrom text_generation_server.layers import (\n    TensorParallelRowLinear,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    SpeculativeHead,\n    get_linear,\n)\nfrom text_generation_server.layers.attention.kv_cache import get_kv_scales\nfrom text_generation_server.layers.rotary import PositionRotaryEmbedding\nfrom text_generation_server.layers.layernorm import (\n    FastRMSNorm,\n)\nfrom text_generation_server.utils.weights import UnquantizedWeight\n\n\nclass GemmaConfig(PretrainedConfig):\n    def __init__(\n        self,\n        vocab_size=256128,\n        hidden_size=3072,\n        intermediate_size=24576,\n        num_hidden_layers=28,\n        num_attention_heads=16,\n  
      num_key_value_heads=16,\n        head_dim=256,\n        hidden_act=\"gelu_pytorch_tanh\",\n        max_position_embeddings=8192,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        pad_token_id=None,\n        bos_token_id=1,\n        eos_token_id=2,\n        tie_word_embeddings=True,\n        rope_theta=10000.0,\n        rope_scaling=None,\n        attention_bias=False,\n        attention_dropout=0.0,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.head_dim = head_dim\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n\nclass GemmaFastRMSNorm(FastRMSNorm):\n    @classmethod\n    def load(cls, prefix: str, weights, eps=1e-6):\n        dtype = weights.dtype\n        weights.dtype = torch.float32\n        weight = weights.get_tensor(f\"{prefix}.weight\") + 1\n        weights.dtype = dtype\n        new = cls(weight, eps)\n        new.dtype = dtype\n        return new\n\n    # perform the multiplication in full 
precision and downcast after\n    def forward(self, hidden_states, residual=None):\n        if residual is not None:\n            hidden_states += residual\n        residual = hidden_states\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        hidden_states = hidden_states * self.weight\n        return hidden_states.to(self.dtype), residual\n\n\ndef load_attention(config, prefix: str, weights):\n    if config.num_attention_heads != config.num_key_value_heads:\n        return _load_gqa(config, prefix, weights)\n    else:\n        return TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n            dim=0,\n            weights=weights,\n            bias=False,\n        )\n\n\ndef _load_gqa(config, prefix: str, weights):\n    assert config.num_attention_heads % weights.process_group.size() == 0\n\n    weight = weights.get_multi_weights_col(\n        prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n        dim=0,\n    )\n\n    if isinstance(weight, UnquantizedWeight):\n        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)\n\n        head_size = config.head_dim\n        num_heads = config.num_attention_heads // weights.process_group.size()\n        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()\n        assert list(weight.weight.shape) == [\n            (num_heads + 2 * num_key_value_heads) * head_size,\n            config.hidden_size,\n        ], f\"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}\"\n\n    return TensorParallelColumnLinear(get_linear(weight, bias=None))\n\n\nclass FlashGemmaAttention(torch.nn.Module):\n    def __init__(self, prefix: str, 
config, weights, causal: bool):\n        super().__init__()\n        self.num_heads = config.num_attention_heads\n        self.head_size = config.head_dim\n        self.causal = causal\n\n        self.rotary_emb = PositionRotaryEmbedding.static(\n            config=config,\n            dim=self.head_size,\n            base=config.rope_theta,\n            device=weights.device,\n        )\n\n        self.softmax_scale = self.head_size**-0.5\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.num_key_value_heads = (\n            config.num_key_value_heads // weights.process_group.size()\n        )\n\n        self.query_key_value = load_attention(config, prefix, weights)\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n\n        self.o_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.o_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.num_groups = self.num_heads // self.num_key_value_heads\n        self.kv_head_mapping = torch.arange(\n            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device\n        ).repeat_interleave(self.num_groups)\n\n    def forward(\n        self,\n        hidden_states,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n    ):\n        qkv = self.query_key_value(hidden_states)\n        query, kv = qkv.split(\n            [\n                self.head_size * self.num_heads,\n                2 * self.head_size * self.num_key_value_heads,\n            ],\n            dim=1,\n        )\n        query = query.view(-1, 
self.num_heads, self.head_size)\n        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)\n\n        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)\n\n        kv_cache.store(\n            key=kv[:, 0],\n            value=kv[:, 1],\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # flash attention\n            attn_output = attention(\n                query=query,\n                key=kv[:, 0],\n                value=kv[:, 1],\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                block_tables=block_tables,\n                softmax_scale=self.softmax_scale,\n                causal=self.causal,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                block_tables,\n                seqlen,\n                max_s,\n                kv_scales=self.kv_scales,\n            )\n\n        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))\n\n\nclass GemmaMLP(nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n        act = config.hidden_act\n        self.act = (\n            ACT2FN[act]\n            if \"gelu\" not in act\n            else lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\" if act in [\"gelu_fast\", \"gelu_pytorch_tanh\"] else \"none\"\n                ),\n            )\n        )\n        # Fuse gate and up proj\n        self.gate_up_proj = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.gate_proj\", f\"{prefix}.up_proj\"],\n            weights=weights,\n            dim=0,\n            bias=False,\n   
     )\n        self.down_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.down_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.intermediate_size = (\n            config.intermediate_size // weights.process_group.size()\n        )\n\n    def forward(self, hidden_states):\n        gate_up_states = self.gate_up_proj(hidden_states)\n        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)\n        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])\n\n\nclass FlashGemmaLayer(nn.Module):\n    def __init__(self, prefix: str, config, weights, causal: bool):\n        super().__init__()\n        self.self_attn = FlashGemmaAttention(\n            prefix=f\"{prefix}.self_attn\", config=config, weights=weights, causal=causal\n        )\n        self.mlp = GemmaMLP(prefix=f\"{prefix}.mlp\", config=config, weights=weights)\n\n        self.input_layernorm = GemmaFastRMSNorm.load(\n            prefix=f\"{prefix}.input_layernorm\", weights=weights, eps=config.rms_norm_eps\n        )\n        self.post_attention_layernorm = GemmaFastRMSNorm.load(\n            prefix=f\"{prefix}.post_attention_layernorm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n    ):\n        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        attn_output = self.self_attn(\n            normed_hidden_states,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n        )\n\n        # faster post attention rms norm\n        normed_attn_res_output, 
attn_res = self.post_attention_layernorm(\n            attn_output, res\n        )\n\n        mlp_output = self.mlp(normed_attn_res_output)\n\n        return mlp_output, attn_res\n\n\nclass FlashGemmaModel(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights, causal: bool):\n        super().__init__()\n\n        process_group = weights.process_group\n        self.tp_rank = process_group.rank()\n        self.tp_world_size = process_group.size()\n        self.layers = nn.ModuleList(\n            [\n                FlashGemmaLayer(\n                    prefix=f\"{prefix}.layers.{layer_id}\",\n                    config=config,\n                    weights=weights,\n                    causal=causal,\n                )\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n        self.norm = GemmaFastRMSNorm.load(\n            prefix=f\"{prefix}.norm\", weights=weights, eps=config.rms_norm_eps\n        )\n\n        self.head_size = self.layers[0].self_attn.head_size\n        self.num_heads = self.layers[0].self_attn.num_heads\n        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n    ) -> torch.Tensor:\n        hidden_states = inputs_embeds\n\n        # Get rotary cos and sin for this forward\n        # Avoid to index in each layer\n        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(\n            position_ids, max_s, hidden_states.dtype\n        )\n\n        residual = None\n        for i, layer in enumerate(self.layers):\n            hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                
cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                block_tables,\n                slots,\n                seqlen,\n                max_s,\n            )\n\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n\nclass FlashGemmaForCausalLM(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights, *, causal: bool = True):\n        super().__init__()\n\n        embed_norm = config.hidden_size**0.5\n        if not prefix:\n            prefix = \"model\"\n        else:\n            prefix = f\"{prefix}.model\"\n\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=f\"{prefix}.embed_tokens\", weights=weights\n        )\n        self.embed_tokens.weight *= embed_norm\n\n        self.model = FlashGemmaModel(\n            prefix=prefix, config=config, weights=weights, causal=causal\n        )\n        self.lm_head = SpeculativeHead.load(\n            prefix=(\n                f\"{prefix}.embed_tokens\"\n                if config.tie_word_embeddings\n                else f\"{prefix}.lm_head\"\n            ),\n            config=config,\n            weights=weights,\n        )\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        input_embeds = self.embed_tokens(input_ids)\n        hidden_states = self.model(\n            input_embeds,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            
block_tables,\n            slots,\n            seqlen,\n            max_s,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.lm_head(hidden_states)\n        return logits, speculative_logits\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.distributed\n\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom typing import Optional, List, Tuple\nfrom text_generation_server.layers.attention import (\n    paged_attention,\n    attention,\n    Seqlen,\n)\nfrom text_generation_server.layers import (\n    TensorParallelRowLinear,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    SpeculativeHead,\n    get_linear,\n)\nfrom text_generation_server.layers.attention.kv_cache import get_kv_scales\n\n\ndef load_qkv(config, prefix: str, weights, head_size, num_heads):\n    if config.quantize == \"gptq\":\n        return _load_qkv_gptq(\n            config,\n            prefix,\n            weights,\n        )\n    elif config.quantize == \"marlin\":\n        raise RuntimeError(\n            \"GPT-2 models with marlin quantization are not yet supported\"\n        )\n    else:\n        return _load_qkv(config, prefix, weights, head_size, num_heads)\n\n\ndef _load_qkv_gptq(config, prefix: str, 
weights):\n    world_size = weights.process_group.size()\n    rank = weights.process_group.rank()\n\n    # Weights\n    weight = weights.get_weights_col_packed_qkv(\n        f\"{prefix}.c_attn\",\n        config.num_attention_heads,\n        config.num_attention_heads,\n    )\n\n    # Bias\n    slice_ = weights._get_slice(f\"{prefix}.c_attn.bias\")\n    shape = slice_.get_shape()\n    total_size = shape[0]\n    assert total_size % 3 == 0, f\"Prepacked is not divisible by {3}\"\n    single_size = total_size // 3\n    assert single_size % world_size == 0\n    block_size = single_size // world_size\n    start = rank * block_size\n    stop = (rank + 1) * block_size\n    tensors = []\n    for i in range(3):\n        tensor = slice_[start + i * single_size : stop + i * single_size]\n        tensors.append(tensor)\n    bias = torch.cat(tensors, dim=0)\n    bias = bias.to(device=weights.device)\n\n    return TensorParallelColumnLinear(get_linear(weight, bias))\n\n\ndef _load_qkv(config, prefix: str, weights, head_size, num_heads):\n    \"\"\"Load QKV from a single, transposed matrix.\"\"\"\n\n    slice_ = weights._get_slice(f\"{prefix}.c_attn.weight\")\n    shape = slice_.get_shape()\n    total_size = shape[1]\n    assert total_size % 3 == 0, f\"Prepacked is not divisible by {3}\"\n    world_size = weights.process_group.size()\n    single_size = total_size // 3\n    assert single_size % world_size == 0\n    rank = weights.process_group.rank()\n\n    # Weights\n    block_size = single_size // world_size\n    start = rank * block_size\n    stop = (rank + 1) * block_size\n    tensors = []\n    for i in range(3):\n        tensor = slice_[:, start + i * single_size : stop + i * single_size]\n        tensors.append(tensor)\n    weight = torch.cat(tensors, dim=1).T\n    weight = weight.to(dtype=weights.dtype)\n    weight = weight.to(device=weights.device)\n\n    # Bias\n    slice_ = weights._get_slice(f\"{prefix}.c_attn.bias\")\n    shape = slice_.get_shape()\n    total_size = 
shape[0]\n    single_size = total_size // 3\n    block_size = single_size // world_size\n    assert single_size % world_size == 0\n    start = rank * block_size\n    stop = (rank + 1) * block_size\n    b = []\n    for i in range(3):\n        tensor = slice_[start + i * single_size : stop + i * single_size]\n        b.append(tensor)\n    bias = torch.cat(b, dim=0)\n    bias = bias.to(dtype=weights.dtype)\n    bias = bias.to(device=weights.device)\n    assert list(bias.shape) == [\n        3 * num_heads * head_size\n    ], f\"{weight.shape} != {[3 * num_heads * head_size]}\"\n\n    return TensorParallelColumnLinear(get_linear(weight, bias))\n\n\ndef load_row(config, prefix: str, weights, bias: bool):\n    \"\"\"load_row, but with transposed weight matrices.\"\"\"\n\n    if config.quantize == \"gptq\":\n        weight = weights.get_weights_row(prefix)\n    else:\n        weight = weights.get_sharded(f\"{prefix}.weight\", dim=0).T\n\n    if bias and weights.process_group.rank() == 0:\n        # Rank is only on the first rank process\n        bias = weights.get_tensor(f\"{prefix}.bias\")\n    else:\n        bias = None\n\n    return TensorParallelRowLinear(\n        get_linear(weight, bias), process_group=weights.process_group\n    )\n\n\ndef load_col(config, prefix: str, weights, bias: bool):\n    \"\"\"load_col, but with transposed weight matrices.\"\"\"\n    if config.quantize == \"gptq\":\n        weight = weights.get_multi_weights_col([prefix], dim=1)\n    else:\n        weight = weights.get_sharded(f\"{prefix}.weight\", dim=1).T\n\n    if bias:\n        bias = weights.get_sharded(f\"{prefix}.bias\", dim=0)\n    else:\n        bias = None\n\n    return TensorParallelColumnLinear(get_linear(weight, bias))\n\n\nclass FlashGPT2Attention(torch.nn.Module):\n    def __init__(\n        self,\n        prefix: str,\n        config,\n        weights,\n    ):\n        super().__init__()\n        self.num_heads = config.num_attention_heads\n        self.hidden_size = 
config.hidden_size\n\n        self.head_size = self.hidden_size // self.num_heads\n        self.softmax_scale = self.head_size**-0.5\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n\n        self.query_key_value = load_qkv(\n            config,\n            prefix=prefix,\n            weights=weights,\n            head_size=self.head_size,\n            num_heads=self.num_heads,\n        )\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n\n        self.o_proj = load_row(\n            config,\n            prefix=f\"{prefix}.c_proj\",\n            weights=weights,\n            bias=True,\n        )\n\n        self.kv_head_mapping = torch.arange(\n            0, self.num_heads, dtype=torch.int32, device=weights.device\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n    ):\n        query, key, value = self.query_key_value(hidden_states).split(\n            self.head_size * self.num_heads, dim=1\n        )\n        query = query.view(-1, self.num_heads, self.head_size)\n        key = key.view(-1, self.num_heads, self.head_size)\n        value = value.view(-1, self.num_heads, self.head_size)\n\n        kv_cache.store(\n            key=key,\n            value=value,\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # flash attention\n            attn_output = attention(\n                query=query,\n                key=key,\n                value=value,\n                kv_cache=kv_cache,\n                
kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                block_tables=block_tables,\n                softmax_scale=self.softmax_scale,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                block_tables,\n                seqlen,\n                max_s,\n                kv_scales=self.kv_scales,\n            )\n\n        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))\n\n\nclass GPT2MLP(nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n        act = config.activation_function\n        self.act = (\n            ACT2FN[act]\n            if \"gelu\" not in act\n            else lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\" if act in [\"gelu_fast\", \"gelu_pytorch_tanh\"] else \"none\"\n                ),\n            )\n        )\n\n        self.c_fc = load_col(\n            config, prefix=f\"{prefix}.c_fc\", weights=weights, bias=True\n        )\n        self.c_proj = load_row(\n            config,\n            prefix=f\"{prefix}.c_proj\",\n            weights=weights,\n            bias=True,\n        )\n\n        intermediate_size = (\n            config.n_inner if config.n_inner is not None else 4 * config.hidden_size\n        )\n\n        self.intermediate_size = intermediate_size // weights.process_group.size()\n\n    def forward(self, hidden_states):\n        hidden_states = self.c_fc(hidden_states)\n        hidden_states = self.act(hidden_states)\n        return self.c_proj(hidden_states)\n\n\nclass FlashGPT2Layer(nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n        self.self_attn = FlashGPT2Attention(\n            prefix=f\"{prefix}.attn\", config=config, weights=weights\n    
    )\n        self.mlp = GPT2MLP(prefix=f\"{prefix}.mlp\", config=config, weights=weights)\n\n        self.input_layernorm = nn.LayerNorm.load(\n            prefix=f\"{prefix}.ln_1\", weights=weights, eps=config.layer_norm_epsilon\n        )\n        self.post_attention_layernorm = nn.LayerNorm.load(\n            prefix=f\"{prefix}.ln_2\",\n            weights=weights,\n            eps=config.layer_norm_epsilon,\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n    ):\n        residual = hidden_states\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        attn_output = self.self_attn(\n            hidden_states,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n        )\n\n        hidden_states = attn_output + residual\n        residual = hidden_states\n\n        hidden_states = self.post_attention_layernorm(hidden_states)\n\n        mlp_output = self.mlp(hidden_states)\n\n        return residual + mlp_output, residual\n\n\nclass FlashGPT2Model(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n\n        process_group = weights.process_group\n        self.tp_rank = process_group.rank()\n        self.tp_world_size = process_group.size()\n        self.layers = nn.ModuleList(\n            [\n                FlashGPT2Layer(\n                    prefix=(\n                        f\"h.{layer_id}\" if not prefix else f\"{prefix}.h.{layer_id}\"\n                    ),\n                    config=config,\n                    weights=weights,\n                )\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n\n        self.norm = nn.LayerNorm.load(\n            prefix=\"ln_f\" if not 
prefix else f\"{prefix}.ln_f\",\n            weights=weights,\n            eps=config.layer_norm_epsilon,\n        )\n\n        self.gradient_checkpointing = False\n\n        self.head_size = self.layers[0].self_attn.head_size\n        self.num_heads = self.layers[0].self_attn.num_heads\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        true_max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor],\n    ) -> torch.Tensor:\n        hidden_states = inputs_embeds\n\n        residual = None\n        for i, layer in enumerate(self.layers):\n            hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                block_tables,\n                slots,\n                seqlen,\n                max_s,\n            )\n\n        hidden_states = self.norm(hidden_states)\n\n        return hidden_states\n\n\nclass FlashGPT2ForCausalLM(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=(\"wte\" if not prefix else f\"{prefix}.wte\"),\n            weights=weights,\n        )\n        self.embed_positions = TensorParallelEmbedding(\n            prefix=(\"wpe\" if not prefix else f\"{prefix}.wpe\"),\n            weights=weights,\n        )\n\n        self.model = FlashGPT2Model(prefix, config, weights)\n        self.lm_head = SpeculativeHead.load(\n            config,\n            prefix=\"wte\" if not prefix else f\"{prefix}.wte\",\n            weights=weights,\n        )\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        
position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor] = None,\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        token_embeds = self.embed_tokens(input_ids)\n        position_embeds = self.embed_positions(position_ids)\n        inputs_embeds = token_embeds + position_embeds\n        hidden_states = self.model(\n            inputs_embeds,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n            true_max_s=max_s,\n            prefill_cache_indices=prefill_cache_indices,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.lm_head(hidden_states)\n        return logits, speculative_logits\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.distributed\n\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom typing import Optional, List, Tuple\nfrom text_generation_server.layers.attention.kv_cache import get_kv_scales\nfrom text_generation_server.utils.import_utils import SYSTEM\nfrom text_generation_server.layers.attention import (\n    paged_attention,\n    attention,\n    Seqlen,\n)\nfrom text_generation_server.layers import (\n    TensorParallelRowLinear,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    SpeculativeHead,\n    get_linear,\n)\nfrom text_generation_server.layers.rotary import (\n    PositionRotaryEmbedding,\n)\nfrom text_generation_server.layers.layernorm import (\n    FastLayerNorm,\n)\n\n\ndef load_attention(config, prefix: str, weights):\n    return TensorParallelColumnLinear.load_multi(\n        config,\n        prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n        dim=0,\n        weights=weights,\n        bias=False,\n    )\n\n\ndef 
load_row(config, prefix: str, weights, bias: bool):\n    weight = weights.get_weights_row(prefix)\n\n    if bias and weights.process_group.rank() == 0:\n        # Bias is only loaded on the first rank process\n        bias = weights.get_tensor(f\"{prefix}.bias\")\n    else:\n        bias = None\n\n    linear = get_linear(weight, bias)\n    return TensorParallelRowLinear(linear, process_group=weights.process_group)\n\n\nclass GPTJRotary(PositionRotaryEmbedding):\n    def forward(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        cos: torch.Tensor,\n        sin: torch.Tensor,\n    ):\n        # Such control flow may add some overhead.\n        if SYSTEM == \"cuda\":\n            from text_generation_server.utils.kernels import load_kernel\n\n            rotary = load_kernel(module=\"rotary\", repo_id=\"kernels-community/rotary\")\n\n            q1 = query[..., ::2]\n            q2 = query[..., 1::2]\n\n            rotary.apply_rotary(q1, q2, cos, sin, q1, q2, False)\n\n            k1 = key[..., ::2]\n            k2 = key[..., 1::2]\n\n            rotary.apply_rotary(k1, k2, cos, sin, k1, k2, False)\n        elif SYSTEM == \"rocm\":\n            import vllm._custom_ops as ops\n\n            # NOTE: On ROCm systems, we use a ROPE implementation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.\n            # Compiling flash-attn rotary on ROCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773\n\n            head_size = query.shape[-1]\n\n            # Inplace operation, updating query and key.\n            ops.rotary_embedding(query, key, head_size, cos, sin, False)\n        elif SYSTEM == \"ipex\":\n            import intel_extension_for_pytorch as ipex\n\n            ipex.llm.functional.rotary_embedding(\n                query, key, sin, cos, query.size(-1), 
False\n            )\n        else:\n            raise ValueError(\n                \"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.\"\n            )\n\n\nclass FlashGPTJAttention(torch.nn.Module):\n    def __init__(\n        self,\n        prefix: str,\n        config,\n        weights,\n    ):\n        super().__init__()\n        self.num_heads = config.num_attention_heads\n        self.hidden_size = config.hidden_size\n\n        self.head_size = self.hidden_size // self.num_heads\n        self.softmax_scale = self.head_size**-0.5\n        self.rotary_dim = config.rotary_dim\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n\n        self.query_key_value = load_attention(\n            config,\n            prefix=prefix,\n            weights=weights,\n        )\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n\n        self.o_proj = load_row(\n            config,\n            prefix=f\"{prefix}.out_proj\",\n            weights=weights,\n            bias=False,\n        )\n\n        self.kv_head_mapping = torch.arange(\n            0, self.num_heads, dtype=torch.int32, device=weights.device\n        )\n\n        self.rotary_emb = GPTJRotary.static(\n            config=config,\n            dim=self.rotary_dim,\n            base=10000,\n            device=weights.device,\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n    ):\n        query, key, value = 
self.query_key_value(hidden_states).split(\n            self.head_size * self.num_heads, dim=1\n        )\n        query = query.view(-1, self.num_heads, self.head_size)\n        key = key.view(-1, self.num_heads, self.head_size)\n        value = value.view(-1, self.num_heads, self.head_size)\n\n        # Compute rotary embeddings on rotary_ndims\n        if self.rotary_dim is not None:\n            self.rotary_emb(\n                query[..., : self.rotary_dim], key[..., : self.rotary_dim], cos, sin\n            )\n        else:\n            self.rotary_emb(query, key, cos, sin)\n\n        kv_cache.store(\n            key=key,\n            value=value,\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # flash attention\n            attn_output = attention(\n                query=query,\n                key=key,\n                value=value,\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                block_tables=block_tables,\n                softmax_scale=self.softmax_scale,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                block_tables,\n                seqlen,\n                max_s,\n                kv_scales=self.kv_scales,\n            )\n\n        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))\n\n\nclass GPTJMLP(nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n        act = config.activation_function\n        self.act = (\n            ACT2FN[act]\n            if \"gelu\" not in act\n            else lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\" if act in 
[\"gelu_fast\", \"gelu_pytorch_tanh\"] else \"none\"\n                ),\n            )\n        )\n\n        self.fc_in = TensorParallelColumnLinear.load(\n            config, prefix=f\"{prefix}.fc_in\", weights=weights, bias=True\n        )\n\n        self.fc_out = load_row(\n            config,\n            prefix=f\"{prefix}.fc_out\",\n            weights=weights,\n            bias=True,\n        )\n\n    def forward(self, hidden_states):\n        hidden_states = self.fc_in(hidden_states)\n        hidden_states = self.act(hidden_states)\n        return self.fc_out(hidden_states)\n\n\nclass FlashGPTJLayer(nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n        self.self_attn = FlashGPTJAttention(\n            prefix=f\"{prefix}.attn\", config=config, weights=weights\n        )\n        self.mlp = GPTJMLP(prefix=f\"{prefix}.mlp\", config=config, weights=weights)\n\n        self.input_layernorm = FastLayerNorm.load(\n            prefix=f\"{prefix}.ln_1\", weights=weights, eps=config.layer_norm_epsilon\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n    ):\n        hidden_states, residual = self.input_layernorm(hidden_states, residual)\n        # Self Attention\n        attn_output = self.self_attn(\n            hidden_states,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n        )\n\n        feed_forward_hidden_states = self.mlp(hidden_states)\n\n        return attn_output + feed_forward_hidden_states, residual\n\n\nclass FlashGPTJModel(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n        self.config = config\n\n        self.wte = 
TensorParallelEmbedding(prefix=f\"{prefix}.wte\", weights=weights)\n        self.layers = nn.ModuleList(\n            [\n                FlashGPTJLayer(\n                    prefix=(\n                        f\"h.{layer_id}\" if not prefix else f\"{prefix}.h.{layer_id}\"\n                    ),\n                    config=config,\n                    weights=weights,\n                )\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n\n        self.ln_f = FastLayerNorm.load(\n            prefix=\"ln_f\" if not prefix else f\"{prefix}.ln_f\",\n            weights=weights,\n            eps=config.layer_norm_epsilon,\n        )\n\n        self.gradient_checkpointing = False\n\n        self.head_size = self.layers[0].self_attn.head_size\n        self.num_heads = self.layers[0].self_attn.num_heads\n\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor],\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor],\n    ) -> torch.Tensor:\n        hidden_states = self.wte(input_ids)\n\n        # Get rotary cos and sin once for this forward pass\n        # to avoid indexing them in each layer\n        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(\n            position_ids, max_s, hidden_states.dtype\n        )\n\n        residual = None\n        for i, layer in enumerate(self.layers):\n            hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                block_tables,\n                slots,\n                seqlen,\n                max_s,\n            )\n\n        hidden_states, _ = 
self.ln_f(hidden_states, residual)\n\n        return hidden_states\n\n\nclass FlashGPTJForCausalLM(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n        if not prefix:\n            prefix = \"transformer\"\n        else:\n            prefix = f\"{prefix}.transformer\"\n        self.model = FlashGPTJModel(prefix, config, weights)\n        self.lm_head = SpeculativeHead.load(\n            config,\n            prefix=\"lm_head\",\n            weights=weights,\n        )\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor] = None,\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        hidden_states = self.model(\n            input_ids,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n            prefill_cache_indices=prefill_cache_indices,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.lm_head(hidden_states)\n        return logits, speculative_logits\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/flash_llama_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom contextlib import contextmanager\nfrom typing import List, Optional, Tuple, Type\n\nimport torch\nimport torch.distributed\n\nfrom torch import nn\nfrom transformers.activations import ACT2FN\n\nfrom text_generation_server.layers.attention import (\n    KVCache,\n    get_kv_scales,\n)\nfrom text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer\nfrom text_generation_server.utils.import_utils import SYSTEM\nfrom text_generation_server.layers.attention import (\n    paged_attention,\n    attention,\n    Seqlen,\n)\nfrom text_generation_server.layers import (\n    TensorParallelRowLinear,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    SpeculativeHead,\n    TensorParallelMultiAdapterLinear,\n    TensorParallelAdapterRowLinear,\n)\nfrom text_generation_server.layers.rotary import PositionRotaryEmbedding\nfrom text_generation_server.layers.layernorm import (\n    FastRMSNorm,\n    FastLayerNorm,\n)\nfrom text_generation_server.layers import (\n    
FastLinear,\n)\nfrom text_generation_server.utils.weights import (\n    Weights,\n)\nfrom text_generation_server.layers.fp8 import HybridFP8UnquantLoader\n\nif SYSTEM != \"ipex\":\n    pass\n\nif SYSTEM == \"rocm\":\n    try:\n        import vllm._custom_ops as ops\n    except Exception as e:\n        raise ImportError(f\"Could not load `vllm._custom_ops`. Full error: {e}\")\n\n\ndef load_attention(config, prefix: str, weights, layer_id):\n    # Only defined in granite.\n    bias = getattr(config, \"attention_bias\", False)\n    head_size = config.hidden_size // config.num_attention_heads\n    sizes = None\n    prefixes = None\n\n    if config.model_type == \"phi3\":\n        base_layer = TensorParallelColumnLinear.load_qkv(\n            config,\n            prefix=f\"{prefix}.qkv_proj\",\n            weights=weights,\n            bias=bias,\n            num_heads=config.num_attention_heads,\n            num_key_value_heads=config.num_key_value_heads,\n        )\n        prefixes = [\"qkv_proj\"]\n    elif config.model_type == \"baichuan\":\n        prefix = f\"{prefix}.W_pack\"\n        base_layer = TensorParallelColumnLinear.load_qkv(\n            config,\n            prefix=prefix,\n            weights=weights,\n            bias=bias,\n            num_heads=config.num_attention_heads,\n            num_key_value_heads=config.num_key_value_heads,\n        )\n        prefixes = [prefix]\n    else:\n        prefixes = [\"q_proj\", \"k_proj\", \"v_proj\"]\n        sizes = [\n            head_size * config.num_attention_heads,\n            head_size * config.num_key_value_heads,\n            head_size * config.num_key_value_heads,\n        ]\n        base_layer = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n            dim=0,\n            weights=weights,\n            bias=bias,\n        )\n\n    return TensorParallelMultiAdapterLinear.load(\n        
base_layer=base_layer,\n        layer_id=layer_id,\n        layer_names=prefixes,\n        sizes=sizes,\n        process_group=weights.process_group,\n    )\n\n\n@contextmanager\ndef no_fp8(weights: Weights):\n    \"\"\"De-activate fp8 auto conversion for the duration of this context manager\"\"\"\n    weights_loader = weights.weights_loader\n    if isinstance(weights_loader, HybridFP8UnquantLoader) and weights_loader.to_fp8:\n        weights_loader = HybridFP8UnquantLoader(\n            weights_loader.activation_scale_ub, to_fp8=False\n        )\n\n    with weights.use_loader(weights_loader):\n        yield\n\n\nclass FlashLlamaAttention(torch.nn.Module):\n    def __init__(\n        self,\n        index: int,\n        prefix: str,\n        config,\n        weights,\n    ):\n        super().__init__()\n        self.num_heads = config.num_attention_heads\n        self.hidden_size = config.hidden_size\n        self.head_size = self.hidden_size // self.num_heads\n\n        # Setting defaults for baichuan custom config which doesn't apply them.\n        config.rope_theta = getattr(config, \"rope_theta\", 10000)\n        config.num_key_value_heads = getattr(\n            config, \"num_key_value_heads\", config.num_attention_heads\n        )\n        self.rotary_emb = PositionRotaryEmbedding.static(\n            config=config,\n            dim=self.head_size,\n            base=config.rope_theta,\n            device=weights.device,\n        )\n\n        # `config.attention_multiplier` is used in Granite\n        self.softmax_scale = getattr(\n            config, \"attention_multiplier\", self.head_size**-0.5\n        )\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        if config.num_key_value_heads % weights.process_group.size() != 
0:\n            raise ValueError(\n                f\"`num_key_value_heads` must be divisible by `num_shards` (got `num_key_value_heads`: {config.num_key_value_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.num_key_value_heads = (\n            config.num_key_value_heads // weights.process_group.size()\n        )\n\n        self.query_key_value = load_attention(config, prefix, weights, index)\n        self.index = index\n\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n\n        o_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.o_proj\",\n            weights=weights,\n            bias=getattr(config, \"attention_bias\", False),\n        )\n\n        self.o_proj = TensorParallelAdapterRowLinear.load(\n            o_proj,\n            index,\n            \"o_proj\",\n            process_group=weights.process_group,\n        )\n\n        self.num_groups = self.num_heads // self.num_key_value_heads\n        self.kv_head_mapping = torch.arange(\n            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device\n        ).repeat_interleave(self.num_groups)\n\n    def forward(\n        self,\n        hidden_states,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache: KVCache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n        adapter_data,\n    ):\n        qkv = self.query_key_value(hidden_states, adapter_data)\n        query, kv = qkv.split(\n            [\n                self.head_size * self.num_heads,\n                2 * self.head_size * self.num_key_value_heads,\n            ],\n            dim=1,\n        )\n        query = query.view(-1, self.num_heads, self.head_size)\n        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)\n\n        self.rotary_emb(query, torch.select(kv, dim=1, index=0), 
cos, sin)\n\n        kv_cache.store(\n            key=kv[:, 0],\n            value=kv[:, 1],\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # flash attention\n            attn_output = attention(\n                query=query,\n                key=kv[:, 0],\n                value=kv[:, 1],\n                kv_scales=self.kv_scales,\n                kv_cache=kv_cache,\n                seqlen=seqlen,\n                block_tables=block_tables,\n                softmax_scale=self.softmax_scale,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                block_tables,\n                seqlen,\n                max_s,\n                kv_scales=self.kv_scales,\n            )\n\n        return self.o_proj(\n            attn_output.view(-1, self.num_heads * self.head_size), adapter_data\n        )\n\n\nclass Phi3MoE(nn.Module):\n    def __init__(\n        self, prefix: str, config, moe_layer_cls: Type[MoELayer], weights: Weights\n    ):\n        super().__init__()\n\n        # gating\n        self.gate = FastLinear.load(config, f\"{prefix}.gate\", weights, bias=False)\n\n        self.moe = moe_layer_cls(\n            prefix=f\"{prefix}.experts\",\n            n_experts=config.num_local_experts,\n            n_expert_group=None,\n            renormalize=True,\n            topk=config.num_experts_per_tok,\n            topk_group=None,\n            weights=weights,\n            gate_proj_name=\"w1\",\n            up_proj_name=\"w3\",\n            down_proj_name=\"w2\",\n        )\n\n        self.process_group = weights.process_group\n\n    def forward(self, x, adapter_data) -> torch.Tensor:\n        # router_logits: (num_tokens, n_experts)\n        router_logits = self.gate(x)\n        out = 
self.moe(x, gating_output=router_logits)\n\n        # Reduce sum\n        if self.process_group.size() > 1:\n            torch.distributed.all_reduce(out, group=self.process_group)\n\n        return out.view(*x.shape)\n\n\nclass LlamaMLP(nn.Module):\n    def __init__(self, prefix, config, weights, index):\n        super().__init__()\n        self.hidden_act = config.hidden_act\n        self.act = (\n            ACT2FN[self.hidden_act]\n            if \"gelu\" not in self.hidden_act\n            else lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\"\n                    if self.hidden_act in [\"gelu_fast\", \"gelu_pytorch_tanh\"]\n                    else \"none\"\n                ),\n            )\n        )\n        prefixes = None\n        sizes = None\n\n        # Fuse gate and up proj\n        bias = getattr(config, \"mlp_bias\", False)\n        if config.model_type == \"phi3\":\n            gate_up_proj = TensorParallelColumnLinear.load_gate_up(\n                config,\n                prefix=f\"{prefix}.gate_up_proj\",\n                weights=weights,\n                bias=bias,\n            )\n        else:\n            prefixes = [\"gate_proj\", \"up_proj\"]\n            sizes = [\n                config.intermediate_size,\n                config.intermediate_size,\n            ]\n            gate_up_proj = TensorParallelColumnLinear.load_multi(\n                config,\n                prefixes=[f\"{prefix}.gate_proj\", f\"{prefix}.up_proj\"],\n                weights=weights,\n                dim=0,\n                bias=bias,\n            )\n\n        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(\n            gate_up_proj,\n            index,\n            layer_names=prefixes,\n            sizes=sizes,\n            process_group=weights.process_group,\n        )\n\n        down_proj = TensorParallelRowLinear.load(\n            config,\n            
prefix=f\"{prefix}.down_proj\",\n            weights=weights,\n            bias=bias,\n        )\n\n        self.down_proj = TensorParallelAdapterRowLinear.load(\n            down_proj,\n            index,\n            \"down_proj\",\n            process_group=weights.process_group,\n        )\n\n        self.intermediate_size = (\n            config.intermediate_size // weights.process_group.size()\n        )\n\n        # TODO: This is a hotfix to be removed & properly refactored.\n        self.quantize = config.quantize\n\n        self.hidden_size = config.hidden_size\n\n    def forward(self, hidden_states, adapter_data):\n        if (\n            SYSTEM == \"rocm\"\n            and self.hidden_act == \"silu\"\n            and hidden_states.dtype == torch.float16\n            and hidden_states.shape[0] == 1\n            and not self.quantize\n            and self.hidden_size\n            != 16384  # TODO: Temporary workaround for `LLMM_Silu` kernel not working with LLama3.1 405B; needs refactoring once fixed.\n        ):\n            out = torch.empty(\n                hidden_states.shape[0],\n                self.intermediate_size,\n                dtype=hidden_states.dtype,\n                device=\"cuda\",\n            )\n            ops.LLMM_Silu(\n                self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8\n            )\n            return self.down_proj(out, adapter_data)\n        else:\n            gate_up_states = self.gate_up_proj(hidden_states, adapter_data)\n            gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)\n            return self.down_proj(\n                self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data\n            )\n\n\nclass FlashLlamaLayer(nn.Module):\n    def __init__(self, index, prefix, config, weights):\n        super().__init__()\n\n        with no_fp8(weights):\n            self.self_attn = FlashLlamaAttention(\n                index=index,\n                
prefix=f\"{prefix}.self_attn\",\n                config=config,\n                weights=weights,\n            )\n\n        if config.model_type == \"phimoe\":\n            moe_layer_cls = (\n                SparseMoELayer\n                if SparseMoELayer.is_supported(weights)\n                else DenseMoELayer\n            )\n            self.mlp = Phi3MoE(\n                f\"{prefix}.block_sparse_moe\", config, moe_layer_cls, weights\n            )\n            # with moe the layernorms are not rmsnorms and they have bias\n            self.input_layernorm = FastLayerNorm.load(\n                prefix=f\"{prefix}.input_layernorm\",\n                weights=weights,\n                eps=config.rms_norm_eps,\n            )\n            self.post_attention_layernorm = FastLayerNorm.load(\n                prefix=f\"{prefix}.post_attention_layernorm\",\n                weights=weights,\n                eps=config.rms_norm_eps,\n            )\n        else:\n            self.mlp = LlamaMLP(\n                prefix=f\"{prefix}.mlp\", config=config, weights=weights, index=index\n            )\n            self.input_layernorm = FastRMSNorm.load(\n                prefix=f\"{prefix}.input_layernorm\",\n                weights=weights,\n                eps=config.rms_norm_eps,\n            )\n            self.post_attention_layernorm = FastRMSNorm.load(\n                prefix=f\"{prefix}.post_attention_layernorm\",\n                weights=weights,\n                eps=config.rms_norm_eps,\n            )\n\n        # Used in Granite\n        # This could eventually be baked into the weights like we do for the embeddings/lm_head\n        # but this would mean modifying the lora code\n        self.residual_multiplier = getattr(config, \"residual_multiplier\", None)\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        
seqlen,\n        max_s,\n        adapter_data,\n        cross_attention_states,\n    ):\n        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        attn_output = self.self_attn(\n            normed_hidden_states,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n            adapter_data,\n        )\n        if self.residual_multiplier is not None:\n            attn_output *= self.residual_multiplier\n\n        normed_attn_res_output, attn_res = self.post_attention_layernorm(\n            attn_output, res\n        )\n\n        mlp_output = self.mlp(normed_attn_res_output, adapter_data)\n        if self.residual_multiplier is not None:\n            mlp_output *= self.residual_multiplier\n\n        return mlp_output, attn_res\n\n\nclass FlashLlamaModel(torch.nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n\n        process_group = weights.process_group\n        self.tp_rank = process_group.rank()\n        self.tp_world_size = process_group.size()\n\n        # Skip fp8 quant for first and last layers\n        self.layers = nn.ModuleList()\n        self.cross_attention_layers = getattr(config, \"cross_attention_layers\", [])\n        with no_fp8(weights):\n            self.layers.append(\n                FlashLlamaLayer(\n                    index=0,\n                    prefix=f\"{prefix}.layers.0\",\n                    config=config,\n                    weights=weights,\n                )\n            )\n\n        # Skip first and last layers\n        for layer_id in range(1, config.num_hidden_layers - 1):\n            if layer_id in self.cross_attention_layers:\n                from text_generation_server.models.custom_modeling.mllama import (\n                    FlashLlamaCrossLayer,\n                )\n\n                
self.layers.append(\n                    FlashLlamaCrossLayer(\n                        index=layer_id,\n                        prefix=(f\"{prefix}.layers.{layer_id}\"),\n                        config=config,\n                        weights=weights,\n                    )\n                )\n            else:\n                self.layers.append(\n                    FlashLlamaLayer(\n                        index=layer_id,\n                        prefix=(f\"{prefix}.layers.{layer_id}\"),\n                        config=config,\n                        weights=weights,\n                    )\n                )\n\n        with no_fp8(weights):\n            last_layer_id = config.num_hidden_layers - 1\n            self.layers.append(\n                FlashLlamaLayer(\n                    index=last_layer_id,\n                    prefix=(f\"{prefix}.layers.{last_layer_id}\"),\n                    config=config,\n                    weights=weights,\n                )\n            )\n\n        self.norm = FastRMSNorm.load(\n            prefix=f\"{prefix}.norm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n\n        self.gradient_checkpointing = False\n\n        self.head_size = self.layers[0].self_attn.head_size\n        self.num_heads = self.layers[0].self_attn.num_heads\n        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        true_max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor],\n        adapter_data,\n        cross_attention_states=None,\n    ) -> torch.Tensor:\n        hidden_states = inputs_embeds\n\n        # Get rotary cos and sin for this 
forward\n        # Avoid indexing in each layer\n        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(\n            position_ids, max_s, hidden_states.dtype\n        )\n\n        residual = None\n        for i, layer in enumerate(self.layers):\n            hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                block_tables,\n                slots,\n                seqlen,\n                max_s,\n                adapter_data,\n                cross_attention_states,\n            )\n\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n\nclass FlashLlamaForCausalLM(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights, name=None):\n        if name is None:\n            name = \"model\"\n        super().__init__()\n        with no_fp8(weights):\n            self.embed_tokens = TensorParallelEmbedding(\n                prefix=(\n                    f\"{name}.embed_tokens\"\n                    if not prefix\n                    else f\"{prefix}.{name}.embed_tokens\"\n                ),\n                weights=weights,\n            )\n        self.model = FlashLlamaModel(\n            prefix=name if not prefix else f\"{prefix}.{name}\",\n            config=config,\n            weights=weights,\n        )\n        if config.tie_word_embeddings:\n            suffix = \"model.embed_tokens\"\n        else:\n            suffix = \"lm_head\"\n\n        # Used in Granite\n        embedding_multiplier = getattr(config, \"embedding_multiplier\", None)\n        if embedding_multiplier is not None:\n            self.embed_tokens.weight.data *= embedding_multiplier\n        prefix = suffix if not prefix or name != \"model\" else f\"{prefix}.{suffix}\"\n        with no_fp8(weights):\n            self.lm_head = SpeculativeHead.load(\n             
   config,\n                prefix,\n                weights,\n            )\n\n        # Used in Granite\n        self.logits_scaling = getattr(config, \"logits_scaling\", None)\n        if self.logits_scaling is not None and self.lm_head.head is not None:\n            try:\n                # Scale the weights directly\n                self.lm_head.head.linear.weight.data /= self.logits_scaling\n                self.logits_scaled = True\n            except Exception:\n                self.logits_scaled = False\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor] = None,\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n        cross_attention_states=None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        inputs_embeds = self.embed_tokens(input_ids)\n        hidden_states = self.model(\n            inputs_embeds,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n            true_max_s=max_s,\n            prefill_cache_indices=prefill_cache_indices,\n            adapter_data=adapter_data,\n            cross_attention_states=cross_attention_states,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.lm_head(hidden_states)\n\n        # Used in Granite\n        if self.logits_scaling is not None and not self.logits_scaled:\n            logits /= self.logits_scaling\n            if speculative_logits is not None:\n                speculative_logits 
/= self.logits_scaling\n\n        return logits, speculative_logits\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.distributed\n\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom transformers.configuration_utils import PretrainedConfig\nfrom typing import Optional, List, Tuple\n\nfrom text_generation_server.layers.attention.kv_cache import get_kv_scales\nfrom text_generation_server.utils.import_utils import SYSTEM\nfrom text_generation_server.layers.attention import (\n    paged_attention,\n    attention,\n    Seqlen,\n)\nfrom text_generation_server.layers import (\n    TensorParallelRowLinear,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    SpeculativeHead,\n    TensorParallelMultiAdapterLinear,\n    TensorParallelAdapterRowLinear,\n)\nfrom text_generation_server.layers.rotary import PositionRotaryEmbedding\nfrom text_generation_server.layers.layernorm import (\n    FastRMSNorm,\n)\n\n\nif SYSTEM == \"rocm\":\n    try:\n        import vllm._custom_ops as ops\n    except Exception as e:\n        raise ImportError(f\"Could not load `vllm._custom_ops`. 
Full error: {e}\")\n\n\nclass MistralConfig(PretrainedConfig):\n    model_type = \"mistral\"\n\n    def __init__(\n        self,\n        vocab_size=32000,\n        hidden_size=4096,\n        intermediate_size=14336,\n        num_hidden_layers=32,\n        num_attention_heads=32,\n        num_key_value_heads=8,\n        hidden_act=\"silu\",\n        max_position_embeddings=4096 * 32,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        pad_token_id=None,\n        bos_token_id=1,\n        eos_token_id=2,\n        pretraining_tp=1,\n        tie_word_embeddings=False,\n        rope_theta=10000.0,\n        sliding_window=None,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.sliding_window = sliding_window\n\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.pretraining_tp = pretraining_tp\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n\nclass MistralAttention(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights, layer_id):\n        super().__init__()\n        self.max_past = (\n            config.sliding_window if config.sliding_window is not None else -1\n        )\n        
self.num_heads = config.num_attention_heads\n        self.hidden_size = config.hidden_size\n        if getattr(config, \"head_dim\", None) is not None:\n            self.head_size = config.head_dim\n        else:\n            self.head_size = self.hidden_size // self.num_heads\n        self.rotary_emb = PositionRotaryEmbedding.static(\n            config=config,\n            dim=self.head_size,\n            base=config.rope_theta,\n            device=weights.device,\n        )\n\n        self.softmax_scale = self.head_size**-0.5\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.num_key_value_heads = (\n            config.num_key_value_heads // weights.process_group.size()\n        )\n\n        query_key_value = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n            dim=0,\n            weights=weights,\n            bias=False,\n        )\n\n        self.query_key_value = TensorParallelMultiAdapterLinear.load(\n            query_key_value,\n            layer_id,\n            [\"q_proj\", \"k_proj\", \"v_proj\"],\n            sizes=[\n                self.head_size * config.num_attention_heads,\n                self.head_size * config.num_key_value_heads,\n                self.head_size * config.num_key_value_heads,\n            ],\n            process_group=weights.process_group,\n        )\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n\n        o_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.o_proj\",\n            weights=weights,\n            bias=False,\n        )\n        
self.o_proj = TensorParallelAdapterRowLinear.load(\n            o_proj,\n            layer_id,\n            \"o_proj\",\n            process_group=weights.process_group,\n        )\n        self.num_groups = self.num_heads // self.num_key_value_heads\n        self.kv_head_mapping = torch.arange(\n            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device\n        ).repeat_interleave(self.num_groups)\n\n    def forward(\n        self,\n        hidden_states,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n        prefill_cache_indices,\n        adapter_data,\n    ):\n        qkv = self.query_key_value(hidden_states, adapter_data)\n        query, kv = qkv.split(\n            [\n                self.head_size * self.num_heads,\n                2 * self.head_size * self.num_key_value_heads,\n            ],\n            dim=1,\n        )\n        query = query.view(-1, self.num_heads, self.head_size)\n        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)\n\n        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)\n\n        if prefill_cache_indices is not None:\n            kv_to_cache = kv[prefill_cache_indices]\n        else:\n            kv_to_cache = kv\n\n        kv_cache.store(\n            key=kv_to_cache[:, 0],\n            value=kv_to_cache[:, 1],\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # flash attention\n            attn_output = attention(\n                query=query,\n                key=kv_to_cache[:, 0],\n                value=kv_to_cache[:, 1],\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                block_tables=block_tables,\n                softmax_scale=self.softmax_scale,\n                
window_size_left=self.max_past,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                block_tables,\n                seqlen,\n                max_s,\n                kv_scales=self.kv_scales,\n                window_size_left=self.max_past,\n            )\n\n        return self.o_proj(\n            attn_output.view(-1, self.num_heads * self.head_size), adapter_data\n        )\n\n\nclass MistralMLP(nn.Module):\n    def __init__(self, prefix: str, config, weights, layer_id):\n        super().__init__()\n        self.hidden_act = config.hidden_act\n        self.act = (\n            ACT2FN[self.hidden_act]\n            if \"gelu\" not in self.hidden_act\n            else lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\"\n                    if self.hidden_act in [\"gelu_fast\", \"gelu_pytorch_tanh\"]\n                    else \"none\"\n                ),\n            )\n        )\n        # Fuse gate and up proj\n        gate_up_proj = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.gate_proj\", f\"{prefix}.up_proj\"],\n            weights=weights,\n            dim=0,\n            bias=False,\n        )\n        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(\n            gate_up_proj,\n            layer_id,\n            [\"gate_proj\", \"up_proj\"],\n            sizes=[\n                config.intermediate_size,\n                config.intermediate_size,\n            ],\n            process_group=weights.process_group,\n        )\n\n        down_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.down_proj\",\n            weights=weights,\n            bias=False,\n        )\n\n        self.down_proj = 
TensorParallelAdapterRowLinear.load(\n            down_proj,\n            layer_id,\n            \"down_proj\",\n            process_group=weights.process_group,\n        )\n        self.intermediate_size = (\n            config.intermediate_size // weights.process_group.size()\n        )\n\n        # TODO: This is a hotfix to be removed & properly refactored.\n        self.quantize = config.quantize\n\n    def forward(self, hidden_states, adapter_data):\n        if (\n            SYSTEM == \"rocm\"\n            and self.hidden_act == \"silu\"\n            and hidden_states.dtype == torch.float16\n            and hidden_states.shape[0] == 1\n            and not self.quantize\n        ):\n            out = torch.empty(\n                hidden_states.shape[0],\n                self.intermediate_size,\n                dtype=hidden_states.dtype,\n                device=\"cuda\",\n            )\n            ops.LLMM_Silu(\n                self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8\n            )\n            return self.down_proj(out, adapter_data)\n        else:\n            gate_up_states = self.gate_up_proj(hidden_states, adapter_data)\n            gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)\n            return self.down_proj(\n                self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data\n            )\n\n\nclass MistralLayer(nn.Module):\n    def __init__(self, prefix: str, config, weights, layer_id):\n        super().__init__()\n        self.self_attn = MistralAttention(\n            prefix=f\"{prefix}.self_attn\",\n            config=config,\n            weights=weights,\n            layer_id=layer_id,\n        )\n        self.mlp = MistralMLP(\n            prefix=f\"{prefix}.mlp\", config=config, weights=weights, layer_id=layer_id\n        )\n\n        self.input_layernorm = FastRMSNorm.load(\n            prefix=f\"{prefix}.input_layernorm\", weights=weights, eps=config.rms_norm_eps\n        )\n 
       self.post_attention_layernorm = FastRMSNorm.load(\n            prefix=f\"{prefix}.post_attention_layernorm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n        prefill_cache_indices,\n        adapter_data,\n    ):\n        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        attn_output = self.self_attn(\n            normed_hidden_states,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n            prefill_cache_indices,\n            adapter_data,\n        )\n\n        # faster post attention rms norm\n        normed_attn_res_output, attn_res = self.post_attention_layernorm(\n            attn_output, res\n        )\n\n        mlp_output = self.mlp(normed_attn_res_output, adapter_data)\n\n        return mlp_output, attn_res\n\n\nclass MistralModel(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n\n        process_group = weights.process_group\n        self.tp_rank = process_group.rank()\n        self.tp_world_size = process_group.size()\n        self.layers = nn.ModuleList(\n            [\n                MistralLayer(\n                    prefix=f\"{prefix}.layers.{layer_id}\",\n                    config=config,\n                    weights=weights,\n                    layer_id=layer_id,\n                )\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n        self.norm = FastRMSNorm.load(\n            prefix=f\"{prefix}.norm\", weights=weights, eps=config.rms_norm_eps\n        )\n\n        
self.gradient_checkpointing = False\n\n        self.head_size = self.layers[0].self_attn.head_size\n        self.num_heads = self.layers[0].self_attn.num_heads\n        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        true_max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor],\n        adapter_data: Optional[torch.Tensor] = None,\n    ):\n        hidden_states = inputs_embeds\n        # Get rotary cos and sin for this forward\n        # Avoid indexing in each layer\n        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(\n            position_ids, true_max_s, hidden_states.dtype\n        )\n\n        residual = None\n        for i, layer in enumerate(self.layers):\n            hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                block_tables,\n                slots,\n                seqlen,\n                max_s,\n                prefill_cache_indices,\n                adapter_data,\n            )\n\n        hidden_states, _ = self.norm(hidden_states, residual)\n        return hidden_states\n\n\nclass FlashMistralForCausalLM(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights, name=None):\n        if name is None:\n            name = \"model\"\n        super().__init__()\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=(\n                f\"{name}.embed_tokens\"\n                if not prefix\n                else f\"{prefix}.{name}.embed_tokens\"\n            ),\n          
  weights=weights,\n        )\n        self.model = MistralModel(\n            prefix=name if not prefix else f\"{prefix}.{name}\",\n            config=config,\n            weights=weights,\n        )\n        self.lm_head = SpeculativeHead.load(\n            config,\n            # TODO dirty hack for idefics2.\n            prefix=(\n                \"lm_head\" if not prefix or name != \"model\" else f\"{prefix}.lm_head\"\n            ),\n            weights=weights,\n        )\n        self.max_past = config.sliding_window\n        self.max_past_tensor = (\n            torch.tensor(config.sliding_window, device=weights.device)\n            if self.max_past is not None\n            else None\n        )\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        true_max_s = max_s\n        if prefill_cache_indices is not None:\n            # Slots also need to be sliced as it has the same size as the whole kv tensor\n            slots = slots[prefill_cache_indices]\n        elif self.max_past is not None:\n            # Clamp in decode mode as paged attention requires clamped values whereas the flash attention\n            # kernel requires the true values\n            seqlen = seqlen.clamp(max=self.max_past_tensor)\n\n        inputs_embeds = self.embed_tokens(input_ids)\n        hidden_states = self.model(\n            inputs_embeds,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n  
          true_max_s,\n            prefill_cache_indices,\n            adapter_data,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits = self.lm_head(hidden_states)\n        return logits\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import List, Optional, Tuple, Type\n\nimport torch\nimport torch.distributed\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\n\nfrom text_generation_server.layers import (\n    FastLinear,\n    SpeculativeHead,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    TensorParallelRowLinear,\n    get_linear,\n)\nfrom text_generation_server.layers.attention import (\n    Seqlen,\n    attention,\n    paged_attention,\n)\nfrom text_generation_server.layers.attention.kv_cache import get_kv_scales\nfrom text_generation_server.layers.layernorm import FastRMSNorm\nfrom text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer\nfrom text_generation_server.layers.rotary import PositionRotaryEmbedding\nfrom text_generation_server.utils.weights import UnquantizedWeight\n\n\nclass MixtralConfig(PretrainedConfig):\n    model_type = \"mixtral\"\n\n    def __init__(\n        self,\n        vocab_size=32000,\n        hidden_size=4096,\n        
intermediate_size=14336,\n        num_hidden_layers=32,\n        num_attention_heads=32,\n        num_key_value_heads=8,\n        hidden_act=\"silu\",\n        max_position_embeddings=4096 * 32,\n        initializer_range=0.02,\n        rms_norm_eps=1e-05,\n        use_cache=True,\n        pad_token_id=None,\n        bos_token_id=1,\n        eos_token_id=2,\n        pretraining_tp=1,\n        tie_word_embeddings=False,\n        rope_theta=10000.0,\n        sliding_window=None,\n        num_experts_per_tok=2,\n        num_local_experts=8,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.sliding_window = sliding_window\n\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.pretraining_tp = pretraining_tp\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.num_experts_per_tok = num_experts_per_tok\n        self.num_local_experts = num_local_experts\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n\ndef promote_scalar(x: torch.Tensor) -> torch.Tensor:\n    return x.view(1) if len(x.size()) == 0 else x\n\n\ndef load_attention(config, prefix: str, weights):\n    if config.num_attention_heads != config.num_key_value_heads:\n        return _load_gqa(config, prefix, weights)\n    
else:\n        return TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n            dim=0,\n            weights=weights,\n            bias=False,\n        )\n\n\ndef _load_gqa(config, prefix: str, weights):\n    assert config.hidden_size % config.num_attention_heads == 0\n    assert config.num_attention_heads % weights.process_group.size() == 0\n\n    weight = weights.get_multi_weights_col(\n        prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n        dim=0,\n    )\n\n    if isinstance(weight, UnquantizedWeight):\n        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)\n\n        head_size = config.hidden_size // config.num_attention_heads\n        num_heads = config.num_attention_heads // weights.process_group.size()\n        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()\n        assert list(weight.weight.shape) == [\n            (num_heads + 2 * num_key_value_heads) * head_size,\n            config.hidden_size,\n        ], f\"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}\"\n\n    return TensorParallelColumnLinear(get_linear(weight, bias=None))\n\n\ndef _load_experts(config, prefix: str, mat, weights):\n    if config.quantize is not None:\n        raise NotImplementedError(\"Mixtral does not support weight quantization yet.\")\n\n    assert mat in [\"w1\", \"w2\", \"w3\"]\n\n    world_size = weights.process_group.size()\n    rank = weights.process_group.rank()\n\n    assert (\n        config.intermediate_size % world_size == 0\n    ), f\"The chosen size {config.intermediate_size} is not compatible with sharding on {world_size} shards\"\n\n    block_size = config.intermediate_size // world_size\n    start = rank * block_size\n    stop = (rank + 1) * block_size\n\n    tensor = torch.empty(\n        
(config.num_local_experts * block_size, config.hidden_size),\n        dtype=weights.dtype,\n        device=weights.device,\n    )\n\n    for i in range(config.num_local_experts):\n        slice_ = weights._get_slice(f\"{prefix}.{i}.{mat}.weight\")\n\n        if mat == \"w2\":\n            expert_slice = slice_[:, start:stop].t().contiguous()\n        else:\n            expert_slice = slice_[start:stop]\n        tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(\n            dtype=weights.dtype\n        ).to(device=weights.device)\n    return tensor\n\n\nclass MixtralAttention(torch.nn.Module):\n    def __init__(\n        self,\n        prefix: str,\n        config,\n        weights,\n    ):\n        super().__init__()\n        self.max_past = (\n            config.sliding_window if config.sliding_window is not None else -1\n        )\n        self.num_heads = config.num_attention_heads\n        self.hidden_size = config.hidden_size\n        self.head_size = self.hidden_size // self.num_heads\n\n        self.rotary_emb = PositionRotaryEmbedding.static(\n            config=config,\n            dim=self.head_size,\n            base=config.rope_theta,\n            device=weights.device,\n        )\n\n        self.softmax_scale = self.head_size**-0.5\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.num_key_value_heads = (\n            config.num_key_value_heads // weights.process_group.size()\n        )\n\n        self.query_key_value = load_attention(config, prefix, weights)\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n\n        self.o_proj = TensorParallelRowLinear.load(\n            config,\n            
prefix=f\"{prefix}.o_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.num_groups = self.num_heads // self.num_key_value_heads\n        self.kv_head_mapping = torch.arange(\n            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device\n        ).repeat_interleave(self.num_groups)\n\n    def forward(\n        self,\n        hidden_states,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n        prefill_cache_indices,\n    ):\n        qkv = self.query_key_value(hidden_states)\n        query, kv = qkv.split(\n            [\n                self.head_size * self.num_heads,\n                2 * self.head_size * self.num_key_value_heads,\n            ],\n            dim=1,\n        )\n        query = query.view(-1, self.num_heads, self.head_size)\n        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)\n\n        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)\n\n        if prefill_cache_indices is not None:\n            kv_to_cache = kv[prefill_cache_indices]\n        else:\n            kv_to_cache = kv\n\n        kv_cache.store(\n            key=kv_to_cache[:, 0],\n            value=kv_to_cache[:, 1],\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # flash attention\n            attn_output = attention(\n                query=query,\n                key=kv_to_cache[:, 0],\n                value=kv_to_cache[:, 1],\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                block_tables=block_tables,\n                softmax_scale=self.softmax_scale,\n                window_size_left=self.max_past,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query,\n 
               kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                block_tables,\n                seqlen,\n                max_s,\n                kv_scales=self.kv_scales,\n                window_size_left=self.max_past,\n            )\n\n        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))\n\n\n@torch.jit.script\ndef select_experts(gate_logits: torch.Tensor, top_k: int):\n    # all_probs: (sequence_length, n_experts) and upcast for softmax\n    all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)\n    # weights, selected_experts: (sequence_length, top-k)\n    weights, selected_experts = torch.topk(all_probs, top_k, dim=-1)\n    weights /= weights.sum(dim=-1, keepdim=True)\n    weights = weights.view(-1)\n    selected_experts = selected_experts.view(-1)\n\n    return selected_experts, weights\n\n\n@torch.jit.script\ndef round_up(x: torch.Tensor, value: int):\n    return torch.div(x + (value - 1), value, rounding_mode=\"trunc\") * value\n\n\nclass MixtralMoE(nn.Module):\n    def __init__(\n        self, prefix, config: MixtralConfig, moe_layer_cls: Type[MoELayer], weights\n    ):\n        super().__init__()\n\n        # gating\n        self.gate = FastLinear.load(config, f\"{prefix}.gate\", weights, bias=False)\n\n        self.moe = moe_layer_cls(\n            n_expert_group=None,\n            n_experts=config.num_local_experts,\n            prefix=f\"{prefix}.experts\",\n            renormalize=True,\n            topk=config.num_experts_per_tok,\n            topk_group=None,\n            weights=weights,\n            gate_proj_name=\"w1\",\n            up_proj_name=\"w3\",\n            down_proj_name=\"w2\",\n        )\n        assert isinstance(self.moe, MoELayer)\n\n        self.process_group = weights.process_group\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        # router_logits: (num_tokens, n_experts)\n        router_logits = 
self.gate(x)\n        out = self.moe(x, gating_output=router_logits)\n\n        # Reduce sum\n        if self.process_group.size() > 1:\n            torch.distributed.all_reduce(out, group=self.process_group)\n\n        return out.view(*x.shape)\n\n\nclass MixtralLayer(nn.Module):\n    def __init__(self, prefix: str, layer_id, config, weights):\n        super().__init__()\n        prefix = f\"{prefix}.layers.{layer_id}\"\n\n        self.self_attn = MixtralAttention(\n            prefix=f\"{prefix}.self_attn\", config=config, weights=weights\n        )\n\n        moe_layer_cls = (\n            SparseMoELayer if SparseMoELayer.is_supported(weights) else DenseMoELayer\n        )\n        self.moe = MixtralMoE(\n            f\"{prefix}.block_sparse_moe\", config, moe_layer_cls, weights\n        )\n\n        self.input_layernorm = FastRMSNorm.load(\n            prefix=f\"{prefix}.input_layernorm\", weights=weights, eps=config.rms_norm_eps\n        )\n        self.post_attention_layernorm = FastRMSNorm.load(\n            prefix=f\"{prefix}.post_attention_layernorm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n        prefill_cache_indices,\n    ):\n        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        attn_output = self.self_attn(\n            normed_hidden_states,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n            prefill_cache_indices,\n        )\n\n        # faster post attention rms norm\n        normed_attn_res_output, attn_res = self.post_attention_layernorm(\n            attn_output, res\n        )\n\n    
    moe_output = self.moe(normed_attn_res_output)\n\n        return moe_output, attn_res\n\n\nclass MixtralModel(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=(\n                \"model.embed_tokens\" if not prefix else f\"{prefix}.model.embed_tokens\"\n            ),\n            weights=weights,\n        )\n\n        self.layers = nn.ModuleList(\n            [\n                MixtralLayer(\n                    \"model\" if not prefix else f\"{prefix}.model\",\n                    layer_id,\n                    config,\n                    weights,\n                )\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n        self.norm = FastRMSNorm.load(\n            prefix=\"model.norm\" if not prefix else f\"{prefix}.model.norm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n\n        self.head_size = self.layers[0].self_attn.head_size\n        self.num_heads = self.layers[0].self_attn.num_heads\n        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        true_max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor],\n    ) -> torch.Tensor:\n        hidden_states = self.embed_tokens(input_ids)\n\n        # Get rotary cos and sin for this forward\n        # Avoid to index in each layer\n        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(\n            position_ids, true_max_s, hidden_states.dtype\n        )\n\n        residual = None\n        for i, layer in 
enumerate(self.layers):\n            hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                block_tables,\n                slots,\n                seqlen,\n                max_s,\n                prefill_cache_indices,\n            )\n\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n\nclass FlashMixtralForCausalLM(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n\n        self.model = MixtralModel(prefix, config, weights)\n        self.lm_head = SpeculativeHead.load(\n            config,\n            prefix=\"lm_head\" if not prefix else f\"{prefix}.lm_head\",\n            weights=weights,\n        )\n        self.max_past = config.sliding_window\n        self.max_past_tensor = (\n            torch.tensor(config.sliding_window, device=weights.device)\n            if self.max_past is not None\n            else None\n        )\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        true_max_s = max_s\n        if prefill_cache_indices is not None:\n            # Slots also need to be sliced as it has the same size as the whole kv tensor\n            slots = slots[prefill_cache_indices]\n        elif self.max_past is not None:\n            # Clamp in decode mode as paged attention requires clamped values whereas the flash attention\n         
   # kernel requires the true values\n            seqlen = seqlen.clamp(max=self.max_past_tensor)\n\n        hidden_states = self.model(\n            input_ids,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n            true_max_s,\n            prefill_cache_indices,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits = self.lm_head(hidden_states)\n        return logits\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/flash_neox_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.distributed\n\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom transformers.modeling_utils import PreTrainedModel\nfrom transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig\nfrom typing import Optional, List, Tuple\nfrom text_generation_server.layers.attention import (\n    paged_attention,\n    attention,\n    Seqlen,\n)\nfrom text_generation_server.layers import (\n    TensorParallelRowLinear,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    SpeculativeHead,\n    get_linear,\n)\nfrom text_generation_server.layers.attention.kv_cache import get_kv_scales\nfrom text_generation_server.layers.layernorm import (\n    FastLayerNorm,\n)\nfrom text_generation_server.layers.rotary import (\n    PositionRotaryEmbedding,\n)\nfrom text_generation_server.utils.weights import UnquantizedWeight\n\n\nclass GPTNeoXConfig(TransformersGPTNeoXConfig):\n    attribute_map = {\n        \"num_key_value_heads\": \"num_attention_heads\",\n    
}\n\n\ndef load_row(config, prefix: str, weights, bias: bool):\n    weight = weights.get_weights_row(prefix)\n\n    if bias and weights.process_group.rank() == 0:\n        # Rank is only on the first rank process\n        bias = weights.get_tensor(f\"{prefix}.bias\")\n    else:\n        bias = None\n\n    linear = get_linear(weight, bias)\n    if config.use_parallel_residual:\n        return linear\n    else:\n        return TensorParallelRowLinear(linear, process_group=weights.process_group)\n\n\ndef load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):\n    weight = weights.get_multi_weights_col([prefix], dim=0)\n    if isinstance(weight, UnquantizedWeight):\n        # Only on non quantized versions\n        weight.weight = (\n            weight.weight.view(\n                num_heads,\n                3,\n                head_size,\n                hidden_size,\n            )\n            .permute(1, 0, 2, 3)\n            .reshape(-1, hidden_size)\n        )\n\n    bias = weights.get_sharded(f\"{prefix}.bias\", dim=0)\n    bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1)\n\n    linear = get_linear(weight, bias)\n    if config.use_parallel_residual:\n        return linear\n    else:\n        return TensorParallelColumnLinear(linear)\n\n\nclass FlashNeoxAttention(torch.nn.Module):\n    def __init__(self, config, prefix, weights):\n        super().__init__()\n        num_heads = config.num_attention_heads\n        hidden_size = config.hidden_size\n\n        self.num_heads = num_heads\n        self.hidden_size = hidden_size\n        self.head_size = hidden_size // num_heads\n\n        self.rotary_dim = int(config.rotary_pct * self.head_size)\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            
)\n        self.num_heads = self.num_heads // weights.process_group.size()\n\n        self.rotary_emb = PositionRotaryEmbedding.static(\n            config=config,\n            dim=self.rotary_dim,\n            base=config.rotary_emb_base,\n            device=weights.device,\n        )\n\n        self.softmax_scale = self.head_size ** (-0.5)\n\n        self.query_key_value = load_qkv(\n            config,\n            prefix=f\"{prefix}.query_key_value\",\n            weights=weights,\n            num_heads=self.num_heads,\n            head_size=self.head_size,\n            hidden_size=self.hidden_size,\n        )\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n        self.dense = load_row(\n            config, prefix=f\"{prefix}.dense\", weights=weights, bias=True\n        )\n        self.kv_head_mapping = torch.arange(\n            0, self.num_heads, dtype=torch.int32, device=weights.device\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n    ):\n        qkv = self.query_key_value(hidden_states)\n        qkv = qkv.view(-1, 3, self.num_heads, self.head_size)\n\n        # Compute rotary embeddings on rotary_ndims\n        query_rot = qkv[:, 0][..., : self.rotary_dim]\n        query_pass = qkv[:, 0][..., self.rotary_dim :]\n        key_rot = qkv[:, 1][..., : self.rotary_dim]\n        key_pass = qkv[:, 1][..., self.rotary_dim :]\n\n        # Inplace rotary\n        self.rotary_emb(query_rot, key_rot, cos, sin)\n        qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1)\n        qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1)\n\n        kv_cache.store(\n            key=qkv[:, 1],\n            value=qkv[:, 2],\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # flash 
attention\n            attn_output = attention(\n                query=qkv[:, 0],\n                key=qkv[:, 1],\n                value=qkv[:, 2],\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                block_tables=block_tables,\n                softmax_scale=self.softmax_scale,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                qkv[:, 0],\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                block_tables,\n                seqlen,\n                max_s,\n                kv_scales=self.kv_scales,\n            )\n\n        return self.dense(attn_output.view(-1, self.num_heads * self.head_size))\n\n\nclass FlashMLP(nn.Module):\n    def __init__(self, config, prefix, weights):\n        super().__init__()\n        act = config.hidden_act\n        self.act = (\n            ACT2FN[act]\n            if \"gelu\" not in act\n            else lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\" if act in [\"gelu_fast\", \"gelu_pytorch_tanh\"] else \"none\"\n                ),\n            )\n        )\n\n        self.dense_h_to_4h = TensorParallelColumnLinear.load(\n            config, prefix=f\"{prefix}.dense_h_to_4h\", weights=weights, bias=True\n        )\n        self.dense_4h_to_h = load_row(\n            config, prefix=f\"{prefix}.dense_4h_to_h\", weights=weights, bias=True\n        )\n\n    def forward(self, hidden_states):\n        hidden_states = self.dense_h_to_4h(hidden_states)\n        hidden_states = self.act(hidden_states)\n        hidden_states = self.dense_4h_to_h(hidden_states)\n        return hidden_states\n\n\nclass FlashNeoXLayer(nn.Module):\n    def __init__(self, layer_id, config, weights):\n        super().__init__()\n\n        layer_norm_eps = config.layer_norm_eps\n\n        prefix = 
f\"gpt_neox.layers.{layer_id}\"\n\n        self.use_parallel_residual = config.use_parallel_residual\n        self.input_layernorm = FastLayerNorm.load(\n            prefix=f\"{prefix}.input_layernorm\", weights=weights, eps=layer_norm_eps\n        )\n        self.post_attention_layernorm = FastLayerNorm.load(\n            prefix=f\"{prefix}.post_attention_layernorm\",\n            weights=weights,\n            eps=layer_norm_eps,\n        )\n        self.attention = FlashNeoxAttention(\n            config, prefix=f\"{prefix}.attention\", weights=weights\n        )\n\n        self.mlp = FlashMLP(config, prefix=f\"{prefix}.mlp\", weights=weights)\n        self.process_group = weights.process_group\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n    ):\n        if self.use_parallel_residual:\n            ln1_hidden_states, _ = self.input_layernorm(hidden_states)\n\n            attn_output = self.attention(\n                ln1_hidden_states,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache,\n                block_tables,\n                slots,\n                seqlen,\n                max_s,\n            )\n\n            ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)\n\n            mlp_output = self.mlp(ln2_hidden_states)\n            intermediate = mlp_output + attn_output\n\n            if self.process_group.size() > 1:\n                torch.distributed.all_reduce(intermediate, group=self.process_group)\n\n            return intermediate + hidden_states, None\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n            hidden_states = self.attention(\n                hidden_states,\n                cos,\n                sin,\n                
cu_seqlen_prefill,\n                kv_cache,\n                block_tables,\n                slots,\n                seqlen,\n                max_s,\n            )\n\n            hidden_states, residual = self.post_attention_layernorm(\n                hidden_states, residual\n            )\n\n            mlp_output = self.mlp(hidden_states)\n\n            return mlp_output, residual\n\n\nclass FlashGPTNeoXPreTrainedModel(PreTrainedModel):\n    config_class = GPTNeoXConfig\n    base_model_prefix = \"gpt_neox\"\n    supports_gradient_checkpointing = False\n    _no_split_modules = None\n\n\nclass FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__(config)\n        self.config = config\n\n        self.embed_in = TensorParallelEmbedding(\n            prefix=f\"{prefix}.embed_in\", weights=weights\n        )\n\n        self.layers = nn.ModuleList(\n            [\n                FlashNeoXLayer(layer_id, config, weights)\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n        self.final_layer_norm = FastLayerNorm.load(\n            prefix=f\"{prefix}.final_layer_norm\",\n            weights=weights,\n            eps=config.layer_norm_eps,\n        )\n\n        self.gradient_checkpointing = False\n\n        self.head_size = self.layers[0].attention.head_size\n        self.num_heads = self.layers[0].attention.num_heads\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n    ) -> torch.Tensor:\n        hidden_states = self.embed_in(input_ids)\n\n        # Get rotary cos and sin for this forward\n        # Avoid to index in each layer\n        cos, sin = 
self.layers[0].attention.rotary_emb.get_cos_sin(\n            position_ids, max_s, hidden_states.dtype\n        )\n\n        residual = None\n        for i, layer in enumerate(self.layers):\n            hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                block_tables,\n                slots,\n                seqlen,\n                max_s,\n            )\n\n        hidden_states, _ = self.final_layer_norm(hidden_states, residual)\n\n        return hidden_states\n\n\nclass FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):\n    def __init__(self, prefix, config, weights):\n        super().__init__(config)\n\n        if not prefix:\n            prefix = \"gpt_neox\"\n        else:\n            prefix = f\"{prefix}.gpt_neox\"\n\n        self.gpt_neox = FlashGPTNeoXModel(prefix, config, weights)\n\n        self.embed_out = SpeculativeHead.load(\n            config, prefix=\"embed_out\", weights=weights\n        )\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        hidden_states = self.gpt_neox(\n            input_ids,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits = 
self.embed_out(hidden_states)\n        return logits\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2024 HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.distributed\nfrom torch import nn\nfrom typing import Optional, List, Tuple\n\nfrom text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear\nfrom text_generation_server.layers.attention import Seqlen\nfrom text_generation_server.models.custom_modeling.vlm import (\n    load_text_model,\n    load_vision_model,\n)\n\n\nclass PaliGemmaForConditionalGeneration(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        config.vision_config.quantize = config.quantize\n        self.vision_tower = load_vision_model(\n            prefix=\"vision_tower\" if not prefix else f\"{prefix}.vision_tower\",\n            config=config.vision_config,\n            weights=weights,\n        )\n        self.post_vision_tower_layernorm = nn.LayerNorm.load(\n            prefix=\"vision_tower.vision_model.post_layernorm\",\n            weights=weights,\n            eps=config.vision_config.layer_norm_eps,\n        )\n\n        self.multi_modal_projector = TensorParallelColumnLinear.load(\n            config,\n            prefix=\"multi_modal_projector.linear\",\n            weights=weights,\n            bias=True,\n        )\n\n        self.vocab_size = config.text_config.vocab_size\n        self.config = config\n\n        text_config = 
config.text_config\n        text_config.speculator = config.speculator\n        text_config.quantize = config.quantize\n        self.text_model = load_text_model(\n            prefix=\"language_model\" if not prefix else f\"{prefix}.language_model\",\n            config=config.text_config,\n            weights=weights,\n        )\n        self.pad_token_id = (\n            config.pad_token_id if config.pad_token_id is not None else -1\n        )\n        self.dtype = weights.dtype\n\n    def get_vision_embeds(\n        self,\n        pixel_values: torch.FloatTensor,\n        pixel_attention_mask: Optional[torch.FloatTensor] = None,\n        image_sizes: Optional[torch.Tensor] = None,\n        image_grid_thw: Optional[torch.LongTensor] = None,\n    ):\n        pixel_values = pixel_values.to(dtype=self.dtype)\n        image_outputs = self.vision_tower(pixel_values)\n        last_hidden_state = self.post_vision_tower_layernorm(\n            image_outputs.last_hidden_state\n        )\n        image_features = self.multi_modal_projector(last_hidden_state)\n        image_features = image_features.view(-1, image_features.shape[-1])\n        return image_features\n\n    def get_inputs_embeds(\n        self,\n        input_ids: torch.Tensor,\n        vision_embeds: torch.Tensor = None,\n    ):\n        inputs_embeds = self.text_model.embed_tokens(input_ids)\n\n        if vision_embeds is not None:\n            mask = input_ids == self.config.image_token_index\n            inputs_embeds[mask] = vision_embeds\n\n        return inputs_embeds\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor] = None,\n        lm_head_indices: 
Optional[torch.Tensor] = None,\n        # Unused here\n        attention_mask: Optional[torch.BoolTensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        # TODO This is odd but apparently pali gemma position ids start at 1.\n        if cu_seqlen_prefill is not None:\n            max_s += 1\n            position_ids += 1\n\n        hidden_states = self.text_model.model(\n            inputs_embeds=inputs_embeds,\n            position_ids=position_ids,\n            cu_seqlen_prefill=cu_seqlen_prefill,\n            kv_cache=kv_cache,\n            block_tables=block_tables,\n            slots=slots,\n            seqlen=seqlen,\n            max_s=max_s,\n        )\n\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.text_model.lm_head(hidden_states)\n\n        return logits, speculative_logits\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/flash_phi_modeling.py",
    "content": "import torch\nimport torch.distributed\n\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom transformers.configuration_utils import PretrainedConfig\nfrom typing import Optional, List, Tuple\n\nfrom text_generation_server.layers.attention import (\n    paged_attention,\n    attention,\n    Seqlen,\n)\nfrom text_generation_server.layers import (\n    TensorParallelRowLinear,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    SpeculativeHead,\n    get_linear,\n)\nfrom text_generation_server.layers.attention.kv_cache import get_kv_scales\nfrom text_generation_server.layers.layernorm import (\n    FastLayerNorm,\n)\nfrom text_generation_server.layers.rotary import (\n    PositionRotaryEmbedding,\n)\n\n\nclass PhiConfig(PretrainedConfig):\n    def __init__(\n        self,\n        vocab_size=51200,\n        hidden_size=2560,\n        num_hidden_layers=32,\n        num_attention_heads=32,\n        num_key_value_heads=32,\n        hidden_act=\"gelu_fast\",  # llama uses silu\n        layer_norm_eps=1e-05,  # rms in llama,\n        pad_token_id=0,\n        bos_token_id=1,\n        eos_token_id=2,\n        tie_word_embeddings=False,\n        rope_theta=10000.0,\n        resid_pdrop=0.1,  # llama doesn't have this\n        partial_rotary_factor=0.5,  # important difference between llama and phi\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.hidden_size = hidden_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.layer_norm_eps = layer_norm_eps\n        self.rope_theta = rope_theta\n        self.resid_pdrop = resid_pdrop\n        self.partial_rotary_factor = partial_rotary_factor\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            
eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n\n# this is the same as llama except for Phi uses bias=True\ndef load_attention(config, prefix, weights):\n    if config.num_attention_heads != config.num_key_value_heads:\n        return _load_gqa(config, prefix, weights)\n    else:\n        return TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n            dim=0,\n            weights=weights,\n            bias=True,\n        )\n\n\ndef _load_gqa(config, prefix: str, weights):\n    assert config.hidden_size % config.num_attention_heads == 0\n    assert config.num_attention_heads % weights.process_group.size() == 0\n\n    weight = weights.get_multi_weights_col(\n        prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n        dim=0,\n    )\n\n    if config.quantize not in [\"gptq\", \"awq\", \"marlin\"]:\n        weight = weight.to(dtype=weights.dtype).to(device=weights.device)\n\n        head_size = config.hidden_size // config.num_attention_heads\n        num_heads = config.num_attention_heads // weights.process_group.size()\n        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()\n        assert list(weight.shape) == [\n            (num_heads + 2 * num_key_value_heads) * head_size,\n            config.hidden_size,\n        ], f\"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}\"\n\n    # this is the same as llama except for Phi uses bias=True\n    return TensorParallelColumnLinear(get_linear(weight, bias=True))\n\n\nclass FlashPhiAttention(torch.nn.Module):\n    def __init__(\n        self,\n        prefix: str,\n        config,\n        weights,\n    ):\n        super().__init__()\n        self.num_heads = config.num_attention_heads\n        self.hidden_size = 
config.hidden_size\n        self.head_size = self.hidden_size // self.num_heads\n\n        self.softmax_scale = self.head_size**-0.5\n        self.rotary_dim = int(config.partial_rotary_factor * self.head_size)\n\n        self.rotary_emb = PositionRotaryEmbedding.static(\n            config=config,\n            dim=self.rotary_dim,\n            base=config.rope_theta,\n            device=weights.device,\n        )\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.num_key_value_heads = (\n            config.num_key_value_heads // weights.process_group.size()\n        )\n\n        self.query_key_value = load_attention(config, prefix, weights)\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n\n        # in llama the dense layer is called \"o_proj\" and has bias=False\n        self.dense = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.dense\",\n            weights=weights,\n            bias=True,\n        )\n        self.num_groups = self.num_heads // self.num_key_value_heads\n        self.kv_head_mapping = torch.arange(\n            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device\n        ).repeat_interleave(self.num_groups)\n\n    def forward(\n        self,\n        hidden_states,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n    ):\n        # Compute query, key, value and split\n        qkv = self.query_key_value(hidden_states)\n        query, kv = qkv.split(\n            [\n                self.head_size * self.num_heads,\n                2 * self.head_size * 
self.num_key_value_heads,\n            ],\n            dim=1,\n        )\n\n        # Reshape query and key for rotary embeddings\n        query = query.view(-1, self.num_heads, self.head_size)\n        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)\n\n        # NOTE: this is the main difference between Llama and Phi\n        # in llama the rotary embeddings are applied to the whole query and key.\n        # Phi uses PARTIAL rotary embeddings, which are applied to the first 32 dimensions\n        #\n        # Apply partial positional embeddings in place\n        self.rotary_emb(\n            query[:, :, : self.rotary_dim], kv[:, 0, :, : self.rotary_dim], cos, sin\n        )\n\n        # Reshape key and value and cache\n        kv_cache.store(\n            key=kv[:, 0],\n            value=kv[:, 1],\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            attn_output = attention(\n                query=query,\n                key=kv[:, 0],\n                value=kv[:, 1],\n                kv_scales=self.kv_scales,\n                kv_cache=kv_cache,\n                seqlen=seqlen,\n                block_tables=block_tables,\n                softmax_scale=self.softmax_scale,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                block_tables,\n                seqlen,\n                max_s,\n                kv_scales=self.kv_scales,\n            )\n\n        return self.dense(attn_output.view(-1, self.num_heads * self.head_size))\n\n\nclass PhiMLP(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        act = config.hidden_act\n        self.act = (\n            ACT2FN[act]\n            if \"gelu\" not in act\n            else lambda x: 
torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\" if act in [\"gelu_fast\", \"gelu_pytorch_tanh\"] else \"none\"\n                ),\n            )\n        )\n\n        # llama weights are up_proj and down_proj and bias=False\n        self.up_proj = TensorParallelColumnLinear.load(\n            config,\n            prefix=f\"{prefix}.fc1\",\n            weights=weights,\n            bias=True,\n        )\n        self.down_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.fc2\",\n            weights=weights,\n            bias=True,\n        )\n\n    def forward(self, hidden_states):\n        # NOTE: Llama requires the gate up states to an intermediate size\n        # Phi does not and we can avoid the `view` operation\n        return self.down_proj(self.act(self.up_proj(hidden_states)))\n\n\nclass FlashPhiLayer(nn.Module):\n    def __init__(self, prefix: str, layer_id, config, weights):\n        super().__init__()\n        prefix = f\"{prefix}.layers.{layer_id}\"\n        self.self_attn = FlashPhiAttention(\n            prefix=f\"{prefix}.self_attn\", config=config, weights=weights\n        )\n        self.mlp = PhiMLP(prefix=f\"{prefix}.mlp\", config=config, weights=weights)\n        self.input_layernorm = FastLayerNorm.load(\n            prefix=f\"{prefix}.input_layernorm\",\n            weights=weights,\n            eps=config.layer_norm_eps,\n        )\n        self.resid_dropout = torch.nn.Dropout(config.resid_pdrop)\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n    ):\n        hidden_states, res = self.input_layernorm(hidden_states, residual)\n        # Self Attention\n        attn_output = self.self_attn(\n            hidden_states,\n            cos,\n            sin,\n            
cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n        )\n\n        hidden_states = self.resid_dropout(attn_output).add(\n            self.resid_dropout(self.mlp(hidden_states))\n        )\n\n        return hidden_states, res\n\n\nclass FlashPhiModel(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n\n        process_group = weights.process_group\n        self.tp_rank = process_group.rank()\n        self.tp_world_size = process_group.size()\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=f\"{prefix}.embed_tokens\", weights=weights\n        )\n        self.layers = nn.ModuleList(\n            [\n                FlashPhiLayer(\n                    prefix,\n                    layer_id,\n                    config,\n                    weights,\n                )\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n        self.gradient_checkpointing = False\n\n        self.head_size = self.layers[0].self_attn.head_size\n        self.num_heads = self.layers[0].self_attn.num_heads\n        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads\n\n        self.norm = FastLayerNorm.load(\n            prefix=\"model.final_layernorm\",\n            weights=weights,\n            eps=config.layer_norm_eps,\n        )\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n    ) -> torch.Tensor:\n        hidden_states = self.embed_tokens(input_ids)\n\n        # Get rotary cos and sin for this forward\n        # Avoid to index in each layer\n        cos, sin = 
self.layers[0].self_attn.rotary_emb.get_cos_sin(\n            position_ids, max_s, hidden_states.dtype\n        )\n\n        residual = None\n        for i, layer in enumerate(self.layers):\n            hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                block_tables,\n                slots,\n                seqlen,\n                max_s,\n            )\n\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n\nclass FlashPhiForCausalLM(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n\n        if not prefix:\n            prefix = \"model\"\n        else:\n            prefix = f\"{prefix}.model\"\n\n        self.model = FlashPhiModel(prefix, config, weights)\n        self.lm_head = SpeculativeHead.load(\n            config,\n            prefix=\"lm_head\",\n            weights=weights,\n        )\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        hidden_states = self.model(\n            input_ids,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n\n        return self.lm_head(hidden_states)\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/flash_phi_moe_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"PyTorch Phi-MoE model.\"\"\"\n\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.utils import logging\n\n\nlogger = logging.get_logger(__name__)\n\n\nPHIMOE_PRETRAINED_CONFIG_ARCHIVE_MAP = {\n    \"microsoft/Phi-3.5-MoE-instruct\": \"https://huggingface.co/microsoft/Phi-3.5-MoE-instruct/resolve/main/config.json\",\n}\n\n\nclass PhiMoEConfig(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`PhiMoEModel`]. It is used to instantiate a Phi-MoE\n    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the\n    defaults will yield a similar configuration to that of the\n    [microsoft/Phi-3.5-MoE-instruct](https://huggingface.co/microsoft/Phi-3.5-MoE-instruct).\n\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the\n    documentation from [`PretrainedConfig`] for more information.\n\n\n    Args:\n        vocab_size (`int`, *optional*, defaults to 32064):\n            Vocabulary size of the PhiMoE model. 
Defines the number of different tokens that can be represented by the\n            `inputs_ids` passed when calling [`PhiMoEModel`]\n        hidden_size (`int`, *optional*, defaults to 4096):\n            Dimension of the hidden representations.\n        intermediate_size (`int`, *optional*, defaults to 6400):\n            Dimension of the MLP representations.\n        num_hidden_layers (`int`, *optional*, defaults to 32):\n            Number of hidden layers in the Transformer encoder.\n        num_attention_heads (`int`, *optional*, defaults to 32):\n            Number of attention heads for each attention layer in the Transformer encoder.\n        num_key_value_heads (`int`, *optional*, defaults to 8):\n            This is the number of key_value heads that should be used to implement Grouped Query Attention. If\n            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if\n            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When\n            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed\n            by meanpooling all the original heads within that group. For more details check out [this\n            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.\n        hidden_act (`str` or `function`, *optional*, defaults to `\"silu\"`):\n            The non-linear activation function (function or string) in the decoder.\n        max_position_embeddings (`int`, *optional*, defaults to `4096*32`):\n            The maximum sequence length that this model might ever be used with. 
Mixtral's sliding window attention\n            allows sequences of up to 4096*32 tokens.\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        rms_norm_eps (`float`, *optional*, defaults to 1e-05):\n            The epsilon used by the rms normalization layers.\n        use_cache (`bool`, *optional*, defaults to `True`):\n            Whether or not the model should return the last key/values attentions (not used by all models). Only\n            relevant if `config.is_decoder=True`.\n        pad_token_id (`int`, *optional*):\n            The id of the padding token.\n        bos_token_id (`int`, *optional*, defaults to 1):\n            The id of the \"beginning-of-sequence\" token.\n        eos_token_id (`int`, *optional*, defaults to 2):\n            The id of the \"end-of-sequence\" token.\n        tie_word_embeddings (`bool`, *optional*, defaults to `False`):\n            Whether the model's input and output word embeddings should be tied.\n        rope_theta (`float`, *optional*, defaults to 1000000.0):\n            The base period of the RoPE embeddings.\n        rope_scaling (`dict`, *optional*):\n            The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must\n            contain the following keys: `type`, `short_factor`, `long_factor`, `short_mscale`, `long_mscale` and\n            `original_max_position_embeddings`. The `type` must be `longrope`, the `short_mscale` and `long_mscale` must\n            be numbers, the `short_factor` and `long_factor` must be lists of numbers with the same length as half of\n            the attention head size and the `original_max_position_embeddings` must be an integer.\n        sliding_window (`int`, *optional*):\n            Sliding window attention window size. 
If not specified, will default to `262144`.\n        attention_dropout (`float`, *optional*, defaults to 0.0):\n            The dropout ratio for the attention probabilities.\n        num_experts_per_tok (`int`, *optional*, defaults to 2):\n            The number of experts to route per-token, can be also interpreted as the `top-k` routing\n            parameter\n        num_local_experts (`int`, *optional*, defaults to 16):\n            Number of experts per Sparse MLP layer.\n        output_router_logits (`bool`, *optional*, defaults to `False`):\n            Whether or not the router logits should be returned by the model. Enabling this will also\n            allow the model to output the auxiliary loss. See [here]() for more details\n        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):\n            The aux loss factor for the total loss.\n        router_jitter_noise (`float`, *optional*, defaults to 0.01):\n            Amount of noise to add to the router.\n\n    ```python\n    >>> from transformers import PhiMoEModel, PhiMoEConfig\n\n    >>> # Initializing a Phi-3 style configuration\n    >>> configuration = PhiMoEConfig.from_pretrained(\"microsoft/Phi-3.5-MoE-instruct\")\n\n    >>> # Initializing a model from the configuration\n    >>> model = PhiMoEModel(configuration)\n\n    >>> # Accessing the model configuration\n    >>> configuration = model.config\n    ```\"\"\"\n\n    model_type = \"phimoe\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    def __init__(\n        self,\n        vocab_size=32064,\n        hidden_size=4096,\n        intermediate_size=6400,\n        num_hidden_layers=32,\n        num_attention_heads=32,\n        num_key_value_heads=8,\n        hidden_act=\"silu\",\n        max_position_embeddings=4096 * 32,\n        initializer_range=0.02,\n        rms_norm_eps=1e-5,\n        use_cache=True,\n        pad_token_id=None,\n        bos_token_id=1,\n        eos_token_id=2,\n        tie_word_embeddings=False,\n     
   rope_theta=1e6,\n        rope_scaling=None,\n        sliding_window=None,\n        attention_dropout=0.0,\n        num_experts_per_tok=2,\n        num_local_experts=16,\n        output_router_logits=False,\n        router_aux_loss_coef=0.001,\n        router_jitter_noise=0.01,\n        input_jitter_noise=0.0,\n        attention_bias=False,\n        lm_head_bias=False,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.sliding_window = sliding_window\n        self.attention_bias = attention_bias\n        self.lm_head_bias = lm_head_bias\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.attention_dropout = attention_dropout\n\n        self.num_experts_per_tok = num_experts_per_tok\n        self.num_local_experts = num_local_experts\n        self.output_router_logits = output_router_logits\n        self.router_aux_loss_coef = router_aux_loss_coef\n        self.router_jitter_noise = router_jitter_noise\n        self.input_jitter_noise = input_jitter_noise\n\n        self.rope_scaling = rope_scaling\n        self._rope_scaling_validation()\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n    def _rope_scaling_validation(self):\n        
\"\"\"\n        Validate the `rope_scaling` configuration.\n        \"\"\"\n        if self.rope_scaling is None:\n            return\n\n        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 6:\n            raise ValueError(\n                \"`rope_scaling` must be a dictionary with three fields, `type`, `short_factor`, `long_factor`, \"\n                f\"`short_mscale`, `long_mscale` and `original_max_position_embeddings`, got {self.rope_scaling}\"\n            )\n        rope_scaling_type = self.rope_scaling.get(\"type\", None)\n        rope_scaling_short_factor = self.rope_scaling.get(\"short_factor\", None)\n        rope_scaling_long_factor = self.rope_scaling.get(\"long_factor\", None)\n        rope_scaling_short_mscale = self.rope_scaling.get(\"short_mscale\", None)\n        rope_scaling_long_mscale = self.rope_scaling.get(\"long_mscale\", None)\n        original_max_position_embeddings = self.rope_scaling.get(\n            \"original_max_position_embeddings\", None\n        )\n        if rope_scaling_type is None or rope_scaling_type not in [\"longrope\"]:\n            raise ValueError(\n                f\"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}\"\n            )\n        if not (\n            isinstance(rope_scaling_short_factor, list)\n            and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)\n        ):\n            raise ValueError(\n                f\"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}\"\n            )\n        if (\n            not len(rope_scaling_short_factor)\n            == self.hidden_size // self.num_attention_heads // 2\n        ):\n            raise ValueError(\n                f\"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}\"\n            )\n        if not (\n            
isinstance(rope_scaling_long_factor, list)\n            and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor)\n        ):\n            raise ValueError(\n                f\"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}\"\n            )\n        if (\n            not len(rope_scaling_long_factor)\n            == self.hidden_size // self.num_attention_heads // 2\n        ):\n            raise ValueError(\n                f\"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}\"\n            )\n        if not isinstance(rope_scaling_short_mscale, (int, float)):\n            raise ValueError(\n                f\"`rope_scaling`'s short_mscale field must be a number, got {rope_scaling_short_mscale}\"\n            )\n        if not isinstance(rope_scaling_long_mscale, (int, float)):\n            raise ValueError(\n                f\"`rope_scaling`'s long_mscale field must be a number, got {rope_scaling_long_mscale}\"\n            )\n        if not isinstance(original_max_position_embeddings, int):\n            raise ValueError(\n                f\"`rope_scaling`'s original_max_position_embeddings field must be an integer, got {original_max_position_embeddings}\"\n            )\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py",
    "content": "import torch\nimport torch.distributed\n\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom typing import Optional, List, Tuple\n\nfrom text_generation_server.layers.attention import (\n    paged_attention,\n    attention,\n    Seqlen,\n)\nfrom text_generation_server.layers import (\n    TensorParallelMultiAdapterLinear,\n    TensorParallelAdapterRowLinear,\n    TensorParallelRowLinear,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    SpeculativeHead,\n)\nfrom text_generation_server.layers.attention.kv_cache import get_kv_scales\nfrom text_generation_server.layers.rotary import PositionRotaryEmbedding\nfrom text_generation_server.layers.layernorm import (\n    FastRMSNorm,\n)\n\n\ndef load_attention(config, prefix, weights, layer_id):\n    prefixes = [f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"]\n    head_size = config.hidden_size // config.num_attention_heads\n    sizes = [\n        head_size * config.num_attention_heads,\n        head_size * config.num_key_value_heads,\n        head_size * config.num_key_value_heads,\n    ]\n    if config.num_attention_heads != config.num_key_value_heads:\n        base_layer = _load_gqa(config, prefix, weights)\n    else:\n        base_layer = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=prefixes,\n            dim=0,\n            weights=weights,\n            bias=True,\n        )\n    return TensorParallelMultiAdapterLinear.load(\n        base_layer=base_layer,\n        layer_id=layer_id,\n        layer_names=prefixes,\n        sizes=sizes,\n        process_group=weights.process_group,\n    )\n\n\ndef _load_gqa(config, prefix: str, weights):\n    assert config.hidden_size % config.num_attention_heads == 0\n    assert config.num_attention_heads % weights.process_group.size() == 0\n\n    return TensorParallelColumnLinear.load_multi(\n        config,\n        prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", 
f\"{prefix}.v_proj\"],\n        dim=0,\n        weights=weights,\n        bias=True,\n    )\n\n\nclass Qwen2Attention(torch.nn.Module):\n    def __init__(\n        self,\n        index: int,\n        prefix: str,\n        config,\n        weights,\n    ):\n        super().__init__()\n        self.window_size = (\n            config.sliding_window if config.sliding_window is not None else -1\n        )\n        self.num_heads = config.num_attention_heads\n        self.hidden_size = config.hidden_size\n        self.head_size = self.hidden_size // self.num_heads\n\n        self.rotary_emb = PositionRotaryEmbedding.static(\n            config=config,\n            dim=self.head_size,\n            base=config.rope_theta,\n            device=weights.device,\n        )\n\n        self.softmax_scale = self.head_size**-0.5\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.num_key_value_heads = (\n            config.num_key_value_heads // weights.process_group.size()\n        )\n\n        self.query_key_value = load_attention(config, prefix, weights, index)\n\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n\n        o_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.o_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.o_proj = TensorParallelAdapterRowLinear.load(\n            o_proj,\n            index,\n            \"o_proj\",\n            process_group=weights.process_group,\n        )\n        self.num_groups = self.num_heads // self.num_key_value_heads\n        self.kv_head_mapping = torch.arange(\n            0, self.num_key_value_heads, dtype=torch.int32, 
device=weights.device\n        ).repeat_interleave(self.num_groups)\n\n    def forward(\n        self,\n        hidden_states,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n        prefill_cache_indices,\n        adapter_data,\n    ):\n        qkv = self.query_key_value(hidden_states, adapter_data)\n        query, kv = qkv.split(\n            [\n                self.head_size * self.num_heads,\n                2 * self.head_size * self.num_key_value_heads,\n            ],\n            dim=1,\n        )\n        query = query.view(-1, self.num_heads, self.head_size)\n        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)\n\n        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)\n\n        if prefill_cache_indices is not None:\n            kv_to_cache = kv[prefill_cache_indices]\n        else:\n            kv_to_cache = kv\n\n        kv_cache.store(\n            key=kv_to_cache[:, 0],\n            value=kv_to_cache[:, 1],\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # flash attention\n            attn_output = attention(\n                query=query,\n                key=kv_to_cache[:, 0],\n                value=kv_to_cache[:, 1],\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                block_tables=block_tables,\n                softmax_scale=self.softmax_scale,\n                window_size_left=self.window_size,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                block_tables,\n                seqlen,\n                max_s,\n                kv_scales=self.kv_scales,\n       
         window_size_left=self.window_size,\n            )\n\n        return self.o_proj(\n            attn_output.view(-1, self.num_heads * self.head_size), adapter_data\n        )\n\n\nclass Qwen2MLP(nn.Module):\n    def __init__(self, prefix, config, weights, index):\n        super().__init__()\n        act = config.hidden_act\n        self.act = (\n            ACT2FN[act]\n            if \"gelu\" not in act\n            else lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\" if act in [\"gelu_fast\", \"gelu_pytorch_tanh\"] else \"none\"\n                ),\n            )\n        )\n        # Fuse gate and up proj\n        prefixes = [f\"{prefix}.gate_proj\", f\"{prefix}.up_proj\"]\n        sizes = [\n            config.intermediate_size,\n            config.intermediate_size,\n        ]\n        gate_up_proj = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=prefixes,\n            weights=weights,\n            dim=0,\n            bias=False,\n        )\n        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(\n            gate_up_proj,\n            layer_id=index,\n            layer_names=prefixes,\n            sizes=sizes,\n            process_group=weights.process_group,\n        )\n        down_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.down_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.down_proj = TensorParallelAdapterRowLinear.load(\n            down_proj,\n            index,\n            \"down_proj\",\n            process_group=weights.process_group,\n        )\n        self.intermediate_size = (\n            config.intermediate_size // weights.process_group.size()\n        )\n\n    def forward(self, hidden_states, adapter_data):\n        gate_up_states = self.gate_up_proj(hidden_states, adapter_data)\n        gate_up_states = gate_up_states.view(-1, 2, 
self.intermediate_size)\n        return self.down_proj(\n            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data\n        )\n\n\nclass Qwen2Layer(nn.Module):\n    def __init__(self, prefix, layer_id, config, weights):\n        super().__init__()\n        prefix = f\"{prefix}.layers.{layer_id}\"\n        self.self_attn = Qwen2Attention(\n            index=layer_id, prefix=f\"{prefix}.self_attn\", config=config, weights=weights\n        )\n        self.mlp = Qwen2MLP(\n            prefix=f\"{prefix}.mlp\", config=config, weights=weights, index=layer_id\n        )\n        self.input_layernorm = FastRMSNorm.load(\n            prefix=f\"{prefix}.input_layernorm\", weights=weights, eps=config.rms_norm_eps\n        )\n        self.post_attention_layernorm = FastRMSNorm.load(\n            prefix=f\"{prefix}.post_attention_layernorm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n        prefill_cache_indices,\n        adapter_data,\n    ):\n        normed_hidden_states, residual = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        attn_output = self.self_attn(\n            normed_hidden_states,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n            prefill_cache_indices,\n            adapter_data,\n        )\n        hidden_states = attn_output + residual\n\n        # faster post attention rms norm\n        hidden_states, residual = self.post_attention_layernorm(hidden_states)\n        mlp_output = self.mlp(hidden_states, adapter_data)\n        hidden_states = mlp_output + residual\n        return hidden_states\n\n\nclass 
Qwen2Model(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n\n        prefix = f\"{prefix}.model\" if prefix else \"model\"\n\n        process_group = weights.process_group\n        self.tp_rank = process_group.rank()\n        self.tp_world_size = process_group.size()\n        self.layers = nn.ModuleList(\n            [\n                Qwen2Layer(\n                    prefix,\n                    layer_id,\n                    config,\n                    weights,\n                )\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n        self.norm = FastRMSNorm.load(\n            prefix=f\"{prefix}.norm\", weights=weights, eps=config.rms_norm_eps\n        )\n\n        self.gradient_checkpointing = False\n\n        self.head_size = self.layers[0].self_attn.head_size\n        self.num_heads = self.layers[0].self_attn.num_heads\n        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        true_max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor],\n        adapter_data,\n    ) -> torch.Tensor:\n        hidden_states = inputs_embeds\n\n        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(\n            position_ids,\n            true_max_s,\n            hidden_states.dtype,\n        )\n\n        residual = None\n        for i, layer in enumerate(self.layers):\n            hidden_states = layer(\n                hidden_states,\n                residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                
block_tables,\n                slots,\n                seqlen,\n                max_s,\n                prefill_cache_indices,\n                adapter_data,\n            )\n\n        hidden_states, _ = self.norm(hidden_states)\n\n        return hidden_states\n\n\nclass Qwen2ForCausalLM(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n\n        self.model = Qwen2Model(prefix, config, weights)\n\n        if config.tie_word_embeddings:\n            suffix = \"model.embed_tokens\"\n        else:\n            suffix = \"lm_head\"\n\n        self.lm_head = SpeculativeHead.load(\n            config,\n            prefix=f\"{prefix}.{suffix}\" if prefix else suffix,\n            weights=weights,\n        )\n\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=f\"{prefix}.embed_tokens\" if prefix else \"model.embed_tokens\",\n            weights=weights,\n        )\n\n        self.window_size = config.sliding_window\n        self.window_size_tensor = (\n            torch.tensor(config.sliding_window, device=weights.device)\n            if self.window_size is not None\n            else None\n        )\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor] = None,\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        true_max_s = max_s\n        if prefill_cache_indices is not None:\n            # Slots also need to be sliced as it has the same size as the whole kv tensor\n            slots = slots[prefill_cache_indices]\n        elif self.window_size is not None:\n            # Clamp in 
decode mode as paged attention requires clamped values whereas the flash attention\n            # kernel requires the true values\n            seqlen = seqlen.clamp(max=self.window_size_tensor)\n\n        inputs_embeds = self.embed_tokens(input_ids)\n\n        hidden_states = self.model(\n            inputs_embeds,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n            true_max_s,\n            prefill_cache_indices,\n            adapter_data,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits = self.lm_head(hidden_states)\n        return logits\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/flash_rw_modeling.py",
    "content": "from typing import List, Optional, Tuple\n\nimport torch\nimport torch.distributed\nfrom torch import nn\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.modeling_utils import PreTrainedModel\nfrom text_generation_server.layers import (\n    SpeculativeHead,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    TensorParallelRowLinear,\n    get_linear,\n)\nfrom text_generation_server.layers.attention.kv_cache import get_kv_scales\nfrom text_generation_server.layers.layernorm import FastLayerNorm\nfrom text_generation_server.layers.rotary import PositionRotaryEmbedding\nfrom text_generation_server.layers.attention import (\n    attention,\n    paged_attention,\n    Seqlen,\n)\n\n\ndef load_row(config, prefix: str, weights, bias: bool):\n    weight = weights.get_weights_row(prefix)\n\n    if bias and weights.process_group.rank() == 0:\n        # Rank is only on the first rank process\n        bias = weights.get_tensor(f\"{prefix}.bias\")\n    else:\n        bias = None\n\n    linear = get_linear(weight, bias)\n    if config.parallel_attn:\n        return linear\n    else:\n        return TensorParallelRowLinear(linear, process_group=weights.process_group)\n\n\nclass RWConfig(PretrainedConfig):\n    attribute_map = {\n        \"num_hidden_layers\": \"n_layer\",\n        \"num_attention_heads\": \"n_head\",\n        \"num_key_value_heads\": \"n_head_kv\",\n    }\n\n    def __init__(\n        self,\n        model_type=\"RefinedWeb\",\n        vocab_size=250880,\n        hidden_size=64,\n        num_hidden_layers=None,\n        num_attention_heads=None,\n        num_ln_in_prallel_attention=None,\n        layer_norm_epsilon=1e-5,\n        initializer_range=0.02,\n        use_cache=True,\n        bos_token_id=1,\n        eos_token_id=2,\n        hidden_dropout=0.0,\n        attention_dropout=0.0,\n        num_kv_heads=None,\n        multi_query=False,\n        alibi=False,\n        
new_decoder_architecture=None,\n        bias=False,\n        parallel_attn=False,\n        rope_theta=10_000.0,\n        **kwargs,\n    ):\n        if alibi:\n            raise NotImplementedError(\n                \"alibi is not supported by this version of the model\"\n            )\n\n        self.model_type = model_type\n        self.alibi = False\n        self.rotary = True\n        self.rope_theta = rope_theta\n        self.max_position_embeddings = 2048\n\n        self.vocab_size = vocab_size\n        # Backward compatibility with n_embed kwarg\n        n_embed = kwargs.pop(\"n_embed\", None)\n        self.hidden_size = hidden_size if n_embed is None else n_embed\n        self.n_layer = (\n            num_hidden_layers\n            if num_hidden_layers is not None\n            else kwargs.pop(\"n_layer\", 2)\n        )\n        self.n_head = (\n            num_attention_heads\n            if num_attention_heads is not None\n            else kwargs.pop(\"n_head\", 8)\n        )\n        self.layer_norm_epsilon = layer_norm_epsilon\n        self.num_ln_in_parallel_attn = num_ln_in_prallel_attention\n        self.initializer_range = initializer_range\n        self.use_cache = use_cache\n        self.hidden_dropout = hidden_dropout\n        self.attention_dropout = attention_dropout\n        self.bias = bias\n        self.parallel_attn = parallel_attn\n\n        self.bos_token_id = bos_token_id\n        self.eos_token_id = eos_token_id\n\n        if num_kv_heads is not None:\n            self.n_head_kv = num_kv_heads\n        else:\n            old_n_head_kv = kwargs.pop(\"n_head_kv\", None)\n            if old_n_head_kv is not None:\n                self.n_head_kv = old_n_head_kv\n            else:\n                self.n_head_kv = 1 if multi_query else self.n_head\n\n        if new_decoder_architecture is not None:\n            self.new_decoder_architecture = new_decoder_architecture\n        elif model_type == \"RefinedWeb\":\n            
self.new_decoder_architecture = True\n        else:\n            self.new_decoder_architecture = False\n\n        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)\n\n\nclass FlashRWAttention(torch.nn.Module):\n    def __init__(\n        self,\n        config,\n        prefix: str,\n        weights,\n    ):\n        super().__init__()\n        self.num_heads = config.n_head\n        self.num_heads_kv = config.n_head_kv\n        self.hidden_size = config.hidden_size\n        self.head_size = self.hidden_size // self.num_heads\n        self.rope_theta = config.rope_theta\n\n        self.rotary_emb = PositionRotaryEmbedding.static(\n            config=config,\n            dim=self.head_size,\n            base=self.rope_theta,\n            device=weights.device,\n        )\n        self.softmax_scale = self.head_size ** (-0.5)\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n\n        self.query_key_value = TensorParallelColumnLinear.load(\n            config,\n            prefix=f\"{prefix}.query_key_value\",\n            weights=weights,\n            bias=config.bias,\n        )\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n        self.dense = load_row(\n            config, prefix=f\"{prefix}.dense\", weights=weights, bias=config.bias\n        )\n\n        if self.num_heads_kv == 1:\n            self.kv_head_mapping = torch.zeros(\n                self.num_heads, dtype=torch.int32, device=weights.device\n            )\n        else:\n            self.kv_head_mapping = torch.arange(\n                0, self.num_heads, dtype=torch.int32, device=weights.device\n            )\n\n    def forward(\n        self,\n    
    hidden_states,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n    ):\n        qkv = self.query_key_value(hidden_states)\n\n        # Split query from key_value\n        query, kv = qkv.split(\n            [self.head_size * self.num_heads, 2 * self.head_size * self.num_heads_kv],\n            dim=1,\n        )\n\n        # Prepare query and key_value for indexing\n        query = query.view(-1, self.num_heads, self.head_size)\n        kv = kv.view(-1, 2, self.num_heads_kv, self.head_size)\n\n        # Inplace rotary\n        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)\n\n        kv_cache.store(\n            key=kv[:, 0],\n            value=kv[:, 1],\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # flash attention\n            attn_output = attention(\n                query=query,\n                key=kv[:, 0],\n                value=kv[:, 1],\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                block_tables=block_tables,\n                softmax_scale=self.softmax_scale,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                block_tables,\n                seqlen,\n                max_s,\n                kv_scales=self.kv_scales,\n            )\n\n        return self.dense(attn_output.view(-1, self.num_heads * self.head_size))\n\n\nclass FlashRWLargeAttention(torch.nn.Module):\n    def __init__(\n        self,\n        config,\n        prefix: str,\n        weights,\n    ):\n        super().__init__()\n\n        hidden_size = config.hidden_size\n        num_heads = config.n_head\n       
 # num_heads_kv = config.n_head_kv\n        num_groups = config.n_head_kv\n\n        self.hidden_size = hidden_size\n        self.head_size = hidden_size // num_heads\n        self.num_groups = num_groups\n        self.rope_theta = config.rope_theta\n\n        self.rotary_emb = PositionRotaryEmbedding.static(\n            config=config,\n            dim=self.head_size,\n            base=self.rope_theta,\n            device=weights.device,\n        )\n        self.softmax_scale = self.head_size ** (-0.5)\n\n        # self.num_groups = num_heads // (num_heads_kv * 2)\n        self.num_heads = num_heads // self.num_groups\n        # self.num_heads_kv = num_heads_kv // self.num_groups\n        process_group = weights.process_group\n\n        if process_group.size() > self.num_groups:\n            raise NotImplementedError(\n                \"Tensor Parallelism is not implemented for world_size > n groups\"\n            )\n        if self.num_groups % process_group.size() != 0:\n            raise NotImplementedError(\n                f\"Tensor Parallelism is not implemented for {self.num_groups} not divisible by {process_group.size()}\"\n            )\n\n        self.num_groups = self.num_groups // process_group.size()\n\n        self.query_key_value = TensorParallelColumnLinear.load(\n            config,\n            prefix=f\"{prefix}.query_key_value\",\n            weights=weights,\n            bias=config.bias,\n        )\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n        self.dense = load_row(\n            config, prefix=f\"{prefix}.dense\", weights=weights, bias=config.bias\n        )\n\n        self.kv_head_mapping = torch.arange(\n            0, self.num_groups, dtype=torch.int32, device=weights.device\n        ).repeat_interleave(self.num_heads)\n\n    def forward(\n        self,\n        hidden_states,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n       
 max_s,\n    ):\n        qkv = self.query_key_value(hidden_states)\n        qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)\n\n        # Split on group dimension\n        query, kv = qkv.split(\n            [self.num_heads, 2],\n            dim=2,\n        )\n        # Merge groups and heads\n        query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size)\n\n        # Inplace rotary\n        self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)\n\n        kv_cache.store(\n            key=kv[:, :, 0].contiguous(),\n            value=kv[:, :, 1].contiguous(),\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # flash attention\n            attn_output = attention(\n                query=query,\n                key=kv[:, :, 0],\n                value=kv[:, :, 1],\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                block_tables=block_tables,\n                softmax_scale=self.softmax_scale,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                block_tables,\n                seqlen,\n                max_s,\n                kv_scales=self.kv_scales,\n            )\n\n        return self.dense(\n            attn_output.view(-1, self.num_groups * self.num_heads * self.head_size)\n        )\n\n\nclass FlashMLP(nn.Module):\n    def __init__(self, config, prefix: str, weights):\n        super().__init__()\n        self.act = torch.nn.functional.gelu\n\n        self.dense_h_to_4h = TensorParallelColumnLinear.load(\n            config, prefix=f\"{prefix}.dense_h_to_4h\", weights=weights, bias=config.bias\n        )\n        self.dense_4h_to_h = load_row(\n            
config, prefix=f\"{prefix}.dense_4h_to_h\", weights=weights, bias=config.bias\n        )\n\n    def forward(self, hidden_states):\n        hidden_states = self.dense_h_to_4h(hidden_states)\n        hidden_states = self.act(hidden_states)\n        hidden_states = self.dense_4h_to_h(hidden_states)\n        return hidden_states\n\n\nclass FlashRWLayer(nn.Module):\n    def __init__(\n        self,\n        layer_id,\n        prefix: str,\n        config,\n        weights,\n    ):\n        super().__init__()\n\n        parallel_attn = config.parallel_attn\n        self.parallel_attn = parallel_attn\n\n        prefix = f\"{prefix}.h.{layer_id}\"\n\n        # NOTE: Falcon 180B uses the ln_attn prefix\n        ln_prefix = \"input_layernorm\"\n        if config.num_hidden_layers == 80:\n            ln_prefix = \"ln_attn\"\n\n        self.input_layernorm = FastLayerNorm.load(\n            prefix=f\"{prefix}.{ln_prefix}\",\n            weights=weights,\n            eps=config.layer_norm_epsilon,\n        )\n        self.self_attention = FlashRWAttention(\n            config,\n            prefix=f\"{prefix}.self_attention\",\n            weights=weights,\n        )\n        self.post_attention_layernorm = (\n            FastLayerNorm.load(\n                prefix=f\"{prefix}.post_attention_layernorm\",\n                weights=weights,\n                eps=config.layer_norm_epsilon,\n            )\n            if not parallel_attn\n            else None\n        )\n\n        self.mlp = FlashMLP(\n            config,\n            prefix=f\"{prefix}.mlp\",\n            weights=weights,\n        )\n\n        self.process_group = weights.process_group\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n    ):\n        if self.parallel_attn:\n            ln_hidden_states, residual = 
self.input_layernorm(hidden_states, residual)\n\n            attn_output = self.self_attention(\n                ln_hidden_states,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache,\n                block_tables,\n                slots,\n                seqlen,\n                max_s,\n            )\n\n            mlp_output = self.mlp(ln_hidden_states)\n            intermediate = mlp_output + attn_output\n\n            if self.process_group.size() > 1:\n                torch.distributed.all_reduce(intermediate, group=self.process_group)\n\n            return intermediate, residual\n        else:\n            hidden_states, residual = self.input_layernorm(hidden_states, residual)\n\n            hidden_states = self.self_attention(\n                hidden_states,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache,\n                block_tables,\n                slots,\n                seqlen,\n                max_s,\n            )\n\n            if self.post_attention_layernorm is not None:\n                hidden_states, residual = self.post_attention_layernorm(\n                    hidden_states, residual\n                )\n\n            mlp_output = self.mlp(hidden_states)\n\n            return mlp_output, residual\n\n\nclass FlashRWLayerNorm(nn.Module):\n    def __init__(self, config, prefix: str, weights):\n        super().__init__()\n        # Falcon2 includes the number of layer norms in the config\n        # in the case no number of layer norms is provided, we default to 1\n        self.num_ln = getattr(config, \"num_ln_in_parallel_attn\", 1)\n\n        # Falcon 180B uses the ln_attn prefix and has 2 layer norms\n        if config.num_hidden_layers == 80:\n            self.num_ln = 2\n\n        if self.num_ln == 1:\n            self.input_ln = FastLayerNorm.load(\n                prefix=f\"{prefix}.input_layernorm\",\n                
weights=weights,\n                eps=config.layer_norm_epsilon,\n            )\n        elif self.num_ln == 2:\n            self.ln_attn = FastLayerNorm.load(\n                prefix=f\"{prefix}.ln_attn\",\n                weights=weights,\n                eps=config.layer_norm_epsilon,\n            )\n            self.ln_mlp = FastLayerNorm.load(\n                prefix=f\"{prefix}.ln_mlp\",\n                weights=weights,\n                eps=config.layer_norm_epsilon,\n            )\n        else:\n            raise ValueError(\"Number of layer norms can either be 1 or 2.\")\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n    ):\n        if self.num_ln == 1:\n            ln_hidden_states, residual = self.input_ln(hidden_states, residual)\n            return ln_hidden_states, ln_hidden_states, residual\n        elif self.num_ln == 2:\n            ln_attn, residual = self.ln_attn(hidden_states, residual)\n            ln_mlp, _ = self.ln_mlp(residual)\n            return ln_attn, ln_mlp, residual\n\n\nclass FlashRWLargeLayer(nn.Module):\n    def __init__(self, layer_id, prefix: str, config, weights):\n        super().__init__()\n        prefix = f\"{prefix}.h.{layer_id}\"\n\n        self.ln_layer = FlashRWLayerNorm(config, prefix, weights)\n\n        self.self_attention = FlashRWLargeAttention(\n            config,\n            prefix=f\"{prefix}.self_attention\",\n            weights=weights,\n        )\n        assert config.parallel_attn, \"This version doesn't support non parallel_attn\"\n\n        self.mlp = FlashMLP(config, prefix=f\"{prefix}.mlp\", weights=weights)\n\n        self.process_group = weights.process_group\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n    ):\n        # Layer norm.\n        ln_attn, ln_mlp, residual = 
self.ln_layer(hidden_states, residual)\n\n        # Self attention.\n        attn_output = self.self_attention(\n            ln_attn,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n        )\n\n        # MLP.\n        mlp_output = self.mlp(ln_mlp)\n\n        intermediate = attn_output + mlp_output\n\n        if self.process_group.size() > 1:\n            torch.distributed.all_reduce(intermediate, group=self.process_group)\n\n        return intermediate, residual\n\n\nclass FlashRWPreTrainedModel(PreTrainedModel):\n    config_class = RWConfig\n\n\nclass FlashRWModel(FlashRWPreTrainedModel):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__(config)\n        self.config = config\n\n        self.word_embeddings = TensorParallelEmbedding(\n            prefix=f\"{prefix}.word_embeddings\", weights=weights\n        )\n\n        if config.new_decoder_architecture:\n            self.h = nn.ModuleList(\n                [\n                    FlashRWLargeLayer(layer_id, prefix, config, weights)\n                    for layer_id in range(config.num_hidden_layers)\n                ]\n            )\n            self.cache_size = self.h[0].self_attention.num_groups\n        else:\n            self.h = nn.ModuleList(\n                [\n                    FlashRWLayer(layer_id, prefix, config, weights)\n                    for layer_id in range(config.num_hidden_layers)\n                ]\n            )\n            self.cache_size = self.h[0].self_attention.num_heads_kv\n\n        self.ln_f = FastLayerNorm.load(\n            prefix=f\"{prefix}.ln_f\",\n            weights=weights,\n            eps=config.layer_norm_epsilon,\n        )\n\n        self.head_size = self.h[0].self_attention.head_size\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        
cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n    ) -> torch.Tensor:\n        hidden_states = self.word_embeddings(input_ids)\n\n        # Get rotary cos and sin for this forward\n        # Avoid to index in each layer\n        cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(\n            position_ids, max_s, hidden_states.dtype\n        )\n\n        residual = None\n        for i, layer in enumerate(self.h):\n            hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                block_tables,\n                slots,\n                seqlen,\n                max_s,\n            )\n\n        hidden_states, _ = self.ln_f(hidden_states, residual)\n\n        return hidden_states\n\n\nclass FlashRWForCausalLM(FlashRWPreTrainedModel):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__(config)\n\n        if not prefix:\n            prefix = \"transformer\"\n        else:\n            prefix = f\"{prefix}.transformer\"\n\n        self.transformer = FlashRWModel(prefix, config, weights)\n\n        self.lm_head = SpeculativeHead.load(config, prefix=\"lm_head\", weights=weights)\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        
hidden_states = self.transformer(\n            input_ids,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits = self.lm_head(hidden_states)\n        return logits\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py",
    "content": "import torch\nimport torch.distributed\n\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom typing import Optional, List, Tuple\n\nfrom text_generation_server.layers.attention import (\n    paged_attention,\n    attention,\n    Seqlen,\n)\nfrom text_generation_server.layers import (\n    TensorParallelRowLinear,\n    TensorParallelColumnLinear,\n    SpeculativeHead,\n    TensorParallelEmbedding,\n    get_linear,\n)\nfrom text_generation_server.layers.attention.kv_cache import get_kv_scales\nfrom text_generation_server.layers.gptq import GPTQWeightsLoader\nfrom text_generation_server.layers.layernorm import (\n    FastLayerNorm,\n)\n\n\ndef load_multi_mqa(\n    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size\n):\n    if config.quantize == \"gptq\":\n        return _load_multi_mqa_gptq(\n            config, prefix, weights, bias, head_size, num_heads, hidden_size\n        )\n    elif config.quantize == \"marlin\":\n        raise RuntimeError(\n            \"santacoder models with marlin quantization are not yet supported\"\n        )\n    else:\n        return _load_multi_mqa(\n            config, prefix, weights, bias, head_size, num_heads, hidden_size\n        )\n\n\ndef _load_multi_mqa_gptq(\n    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size\n):\n    from text_generation_server.layers.gptq import GPTQWeight\n\n    if any(\"c_attn\" in k for k in weights.routing.keys()) and not config.transpose:\n        world_size = weights.process_group.size()\n        rank = weights.process_group.rank()\n\n        slice_ = weights._get_slice(f\"{prefix}.c_attn.qweight\")\n        shape = slice_.get_shape()\n        block_size = (shape[1] - 2 * head_size) // world_size\n        start = rank * block_size\n        stop = (rank + 1) * block_size\n        assert (shape[1] - 2 * head_size) % world_size == 0\n        q_tensor = slice_[:, start:stop]\n        kv_tensor = slice_[:, -2 * 
head_size :]\n        qweight = torch.cat([q_tensor, kv_tensor], dim=1)\n        qweight = qweight.to(device=weights.device)\n\n        slice_ = weights._get_slice(f\"{prefix}.c_attn.scales\")\n        shape = slice_.get_shape()\n        block_size = (shape[1] - 2 * head_size) // world_size\n        start = rank * block_size\n        stop = (rank + 1) * block_size\n        assert (shape[1] - 2 * head_size) % world_size == 0\n        q_tensor = slice_[:, start:stop]\n        kv_tensor = slice_[:, -2 * head_size :]\n        scales = torch.cat([q_tensor, kv_tensor], dim=1)\n        scales = scales.to(device=weights.device)\n\n        slice_ = weights._get_slice(f\"{prefix}.c_attn.qzeros\")\n        shape = slice_.get_shape()\n        block_size = (shape[1] - (2 * head_size) * 4 // 32) // world_size\n        start = rank * block_size\n        stop = (rank + 1) * block_size\n        assert 2 * head_size % (32 // 4) == 0\n        q_tensor = slice_[:, start:stop]\n        kv_tensor = slice_[:, -2 * head_size * 4 // 32 :]\n        qzeros = torch.cat([q_tensor, kv_tensor], dim=1)\n        qzeros = qzeros.to(device=weights.device)\n\n        loader = weights.weights_loader\n        assert isinstance(loader, GPTQWeightsLoader)\n        loader._get_gptq_params(weights)\n        if loader.quant_method == \"gptq\":\n            g_idx = weights.get_tensor(f\"{prefix}.c_attn.g_idx\")\n            g_idx = g_idx.to(device=weights.device)\n        elif loader.quant_method == \"awq\":\n            g_idx = None\n            from text_generation_server.layers.awq.conversion_utils import (\n                fast_awq_to_gptq,\n            )\n\n            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)\n\n        from text_generation_server.layers.gptq import HAS_EXLLAMA\n\n        weight = GPTQWeight(\n            qweight=qweight,\n            qzeros=qzeros,\n            scales=scales,\n            g_idx=g_idx,\n            bits=loader.bits,\n            groupsize=loader.groupsize,\n  
          use_awq_kernel=loader.quantize == \"awq\",\n            use_exllama=HAS_EXLLAMA,\n        )\n\n        if bias:\n            slice_ = weights._get_slice(f\"{prefix}.c_attn.bias\")\n            shape = slice_.get_shape()\n            block_size = (shape[0] - 2 * head_size) // world_size\n            assert (shape[0] - 2 * head_size) % world_size == 0\n            q_tensor = slice_[start:stop]\n            start = rank * block_size\n            stop = (rank + 1) * block_size\n            q_tensor = slice_[start:stop]\n            kv_tensor = slice_[-2 * head_size :]\n            bias = torch.cat([q_tensor, kv_tensor], dim=0)\n            bias = bias.to(device=weights.device)\n\n        return TensorParallelColumnLinear(get_linear(weight, bias))\n    else:\n        raise NotImplementedError(\"Gptq loading with santacoder is not implemented\")\n\n\ndef _load_multi_mqa(\n    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size\n):\n    if any(\"c_attn\" in k for k in weights.routing.keys()):\n        slice_ = weights._get_slice(f\"{prefix}.c_attn.weight\")\n        shape = slice_.get_shape()\n        world_size = weights.process_group.size()\n        rank = weights.process_group.rank()\n        if config.transpose:\n            block_size = (shape[1] - 2 * head_size) // world_size\n            start = rank * block_size\n            stop = (rank + 1) * block_size\n            assert (shape[1] - 2 * head_size) % world_size == 0\n            q_tensor = slice_[:, start:stop]\n            kv_tensor = slice_[:, -2 * head_size :]\n            weight = torch.cat([q_tensor, kv_tensor], dim=1).T\n        else:\n            block_size = (shape[0] - 2 * head_size) // world_size\n            start = rank * block_size\n            stop = (rank + 1) * block_size\n            assert (shape[0] - 2 * head_size) % world_size == 0\n            q_tensor = slice_[start:stop]\n            kv_tensor = slice_[-2 * head_size :]\n            weight = 
torch.cat([q_tensor, kv_tensor], dim=0)\n        if bias:\n            slice_ = weights._get_slice(f\"{prefix}.c_attn.bias\")\n            shape = slice_.get_shape()\n            block_size = (shape[0] - 2 * head_size) // world_size\n            assert (shape[0] - 2 * head_size) % world_size == 0\n            start = rank * block_size\n            stop = (rank + 1) * block_size\n            q_tensor = slice_[start:stop]\n            kv_tensor = slice_[-2 * head_size :]\n            bias = torch.cat([q_tensor, kv_tensor], dim=0)\n    else:\n        if config.transpose:\n            w = [\n                weights.get_sharded(f\"{prefix}.q_attn.weight\", dim=1).T,\n                weights.get_tensor(f\"{prefix}.kv_attn.weight\").T,\n            ]\n            weight = torch.cat(w, dim=0)\n        else:\n            w = [\n                weights.get_sharded(f\"{prefix}.q_attn.weight\", dim=0),\n                weights.get_tensor(f\"{prefix}.kv_attn.weight\"),\n            ]\n            weight = torch.cat(w, dim=1)\n\n        if bias:\n            b = [\n                weights.get_sharded(f\"{prefix}.q_attn.bias\", dim=0),\n                weights.get_tensor(f\"{prefix}.kv_attn.bias\"),\n            ]\n            bias = torch.cat(b, dim=0)\n        else:\n            bias = None\n\n    weight = weight.to(dtype=weights.dtype).to(device=weights.device)\n    assert list(weight.shape) == [\n        (num_heads + 2) * head_size,\n        hidden_size,\n    ], f\"{weight.shape} != {[(num_heads + 2) * head_size, hidden_size]}\"\n    if bias is not None:\n        bias = bias.to(dtype=weights.dtype).to(device=weights.device)\n        assert list(bias.shape) == [\n            (num_heads + 2) * head_size\n        ], f\"{weight.shape} != {[(num_heads + 2) * head_size]}\"\n    return TensorParallelColumnLinear(get_linear(weight, bias))\n\n\ndef load_col(config, prefix: str, weights, bias: bool):\n    if config.transpose:\n        weight = weights.get_sharded(f\"{prefix}.weight\", 
dim=1).T\n    else:\n        weight = weights.get_multi_weights_col([prefix], dim=0)\n\n    if bias:\n        bias = weights.get_sharded(f\"{prefix}.bias\", dim=0)\n    else:\n        bias = None\n    return TensorParallelColumnLinear(get_linear(weight, bias))\n\n\ndef load_row(config, prefix: str, weights, bias: bool):\n    if config.transpose:\n        weight = weights.get_sharded(f\"{prefix}.weight\", dim=0).T\n    else:\n        weight = weights.get_weights_row(prefix)\n\n    if bias and weights.process_group.rank() == 0:\n        # Rank is only on the first rank process\n        bias = weights.get_tensor(f\"{prefix}.bias\")\n    else:\n        bias = None\n    return TensorParallelRowLinear(\n        get_linear(weight, bias), process_group=weights.process_group\n    )\n\n\nclass FlashMQAttention(torch.nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        num_heads = config.num_attention_heads\n        hidden_size = config.hidden_size\n\n        self.num_heads = num_heads\n        self.hidden_size = hidden_size\n        self.head_size = hidden_size // num_heads\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n\n        self.softmax_scale = self.head_size ** (-0.5)\n\n        self.c_attn = load_multi_mqa(\n            config,\n            prefix=prefix,\n            weights=weights,\n            bias=True,\n            head_size=self.head_size,\n            hidden_size=hidden_size,\n            num_heads=self.num_heads,\n        )\n        self.c_proj = load_row(\n            config, prefix=f\"{prefix}.c_proj\", weights=weights, bias=True\n        )\n        self.kv_scales = get_kv_scales(weights, 
f\"{prefix}\")\n        self.kv_head_mapping = torch.zeros(\n            self.num_heads, dtype=torch.int32, device=weights.device\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n    ):\n        qkv = self.c_attn(hidden_states)\n\n        # Split query from key_value\n        query, key_value = qkv.split(\n            [self.head_size * self.num_heads, 2 * self.head_size], dim=1\n        )\n\n        # Prepare query and key_value for indexing\n        query = query.view(-1, self.num_heads, self.head_size)\n        key_value = key_value.view(-1, 2, 1, self.head_size)\n\n        kv_cache.store(\n            key=key_value[:, 0],\n            value=key_value[:, 1],\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # flash attention\n            attn_output = attention(\n                query=query,\n                key=key_value[:, 0],\n                value=key_value[:, 1],\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                block_tables=block_tables,\n                softmax_scale=self.softmax_scale,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n                block_tables,\n                seqlen,\n                max_s,\n                kv_scales=self.kv_scales,\n            )\n\n        return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))\n\n\nclass MLP(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        act = config.activation_function\n        self.act = (\n            ACT2FN[act]\n            if \"gelu\" not 
in act\n            else lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\" if act in [\"gelu_fast\", \"gelu_pytorch_tanh\"] else \"none\"\n                ),\n            )\n        )\n\n        self.c_fc = load_col(\n            config, prefix=f\"{prefix}.c_fc\", weights=weights, bias=True\n        )\n        self.c_proj = load_row(\n            config, prefix=f\"{prefix}.c_proj\", weights=weights, bias=True\n        )\n\n    def forward(self, hidden_states):\n        hidden_states = self.c_fc(hidden_states)\n        hidden_states = self.act(hidden_states)\n        hidden_states = self.c_proj(hidden_states)\n        return hidden_states\n\n\nclass Block(nn.Module):\n    def __init__(self, prefix: str, layer_id, config, weights):\n        super().__init__()\n        prefix = f\"{prefix}.h.{layer_id}\"\n        self.ln_1 = FastLayerNorm.load(\n            prefix=f\"{prefix}.ln_1\", weights=weights, eps=config.layer_norm_epsilon\n        )\n        self.ln_2 = FastLayerNorm.load(\n            prefix=f\"{prefix}.ln_2\", weights=weights, eps=config.layer_norm_epsilon\n        )\n        self.self_attn = FlashMQAttention(\n            prefix=f\"{prefix}.attn\",\n            config=config,\n            weights=weights,\n        )\n        self.mlp = MLP(\n            prefix=f\"{prefix}.mlp\",\n            config=config,\n            weights=weights,\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n    ):\n        hidden_states, residual = self.ln_1(hidden_states, residual)\n        hidden_states = self.self_attn(\n            hidden_states,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n        )\n\n        hidden_states, residual = 
self.ln_2(hidden_states, residual)\n\n        mlp_output = self.mlp(hidden_states)\n\n        return mlp_output, residual\n\n\nclass FlashSantacoderModel(nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n        self.config = config\n\n        self.process_group = weights.process_group\n        self.wte = TensorParallelEmbedding(\n            prefix=f\"{prefix}.wte\",\n            weights=weights,\n            reduce=False,\n        )\n        self.wpe = TensorParallelEmbedding(\n            prefix=f\"{prefix}.wpe\",\n            weights=weights,\n            reduce=False,\n        )\n\n        self.layers = nn.ModuleList(\n            [\n                Block(\n                    prefix,\n                    layer_id,\n                    config,\n                    weights,\n                )\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n        self.ln_f = FastLayerNorm.load(\n            prefix=\"transformer.ln_f\", weights=weights, eps=config.layer_norm_epsilon\n        )\n\n        self.head_size = self.layers[0].self_attn.head_size\n        self.num_heads = self.layers[0].self_attn.num_heads\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n    ) -> torch.Tensor:\n        hidden_states = self.wte(input_ids) + self.wpe(position_ids)\n\n        if self.process_group.size() > 1:\n            torch.distributed.all_reduce(hidden_states, group=self.process_group)\n\n        residual = None\n        for i, layer in enumerate(self.layers):\n            hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                cu_seqlen_prefill,\n                
kv_cache[i],\n                block_tables,\n                slots,\n                seqlen,\n                max_s,\n            )\n\n        hidden_states, _ = self.ln_f(hidden_states, residual)\n\n        return hidden_states\n\n\nclass FlashSantacoderForCausalLM(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n\n        if not prefix:\n            prefix = \"transformer\"\n        else:\n            prefix = f\"{prefix}.transformer\"\n\n        config.transpose = config.architectures[0].startswith(\"GPT2\")\n        self.model = FlashSantacoderModel(prefix, config, weights)\n        self.lm_head = SpeculativeHead.load(\n            config, prefix=f\"{prefix}.wte\", weights=weights\n        )\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        hidden_states = self.model(\n            input_ids,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits = self.lm_head(hidden_states)\n        return logits\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2024 Starcoder2 AI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.distributed\n\nfrom torch import nn\nfrom transformers.activations import ACT2FN\nfrom transformers.configuration_utils import PretrainedConfig\nfrom typing import Optional, List, Tuple\n\nfrom text_generation_server.layers.attention import (\n    paged_attention,\n    attention,\n    Seqlen,\n)\nfrom text_generation_server.layers import (\n    TensorParallelMultiAdapterLinear,\n    TensorParallelAdapterRowLinear,\n    TensorParallelRowLinear,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    SpeculativeHead,\n    get_linear,\n)\nfrom text_generation_server.layers.attention.kv_cache import get_kv_scales\nfrom text_generation_server.layers.layernorm import (\n    FastLayerNorm,\n    FastRMSNorm,\n)\nfrom text_generation_server.layers.rotary import (\n    PositionRotaryEmbedding,\n)\nfrom text_generation_server.utils.weights import UnquantizedWeight\n\n\nclass Starcoder2Config(PretrainedConfig):\n    model_type = \"starcoder2\"\n\n    def __init__(\n        self,\n   
     vocab_size=49152,\n        hidden_size=3072,\n        intermediate_size=12288,\n        num_hidden_layers=30,\n        num_attention_heads=24,\n        num_key_value_heads=2,\n        mlp_type=\"default\",\n        hidden_act=\"gelu_pytorch_tanh\",\n        max_position_embeddings=4096,\n        initializer_range=0.018042,\n        norm_type=\"layer_norm\",\n        norm_epsilon=1e-5,\n        use_cache=True,\n        bos_token_id=50256,\n        eos_token_id=50256,\n        rope_theta=10000.0,\n        sliding_window=None,\n        attention_dropout=0.0,\n        residual_dropout=0.0,\n        embedding_dropout=0.0,\n        use_bias: bool = True,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.sliding_window = sliding_window\n        self.use_bias = use_bias\n\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.mlp_type = mlp_type\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.norm_type = norm_type\n        self.norm_epsilon = norm_epsilon\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.attention_dropout = attention_dropout\n        self.residual_dropout = residual_dropout\n        self.embedding_dropout = embedding_dropout\n\n        super().__init__(\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            **kwargs,\n        )\n\n\ndef load_attention(config, prefix, weights, layer_id):\n    prefixes = [f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"]\n    head_size = 
config.hidden_size // config.num_attention_heads\n    sizes = [\n        head_size * config.num_attention_heads,\n        head_size * config.num_key_value_heads,\n        head_size * config.num_key_value_heads,\n    ]\n    if config.num_attention_heads != config.num_key_value_heads:\n        base_layer = _load_gqa(config, prefix, weights)\n    else:\n        base_layer = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=prefixes,\n            dim=0,\n            weights=weights,\n            bias=config.use_bias,\n        )\n    return TensorParallelMultiAdapterLinear.load(\n        base_layer=base_layer,\n        layer_id=layer_id,\n        layer_names=prefixes,\n        sizes=sizes,\n        process_group=weights.process_group,\n    )\n\n\ndef _load_gqa(config, prefix: str, weights):\n    assert config.hidden_size % config.num_attention_heads == 0\n    assert config.num_attention_heads % weights.process_group.size() == 0\n\n    weight = weights.get_multi_weights_col(\n        prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n        dim=0,\n    )\n\n    if isinstance(weight, UnquantizedWeight):\n        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)\n\n        head_size = config.hidden_size // config.num_attention_heads\n        num_heads = config.num_attention_heads // weights.process_group.size()\n        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()\n        assert list(weight.weight.shape) == [\n            (num_heads + 2 * num_key_value_heads) * head_size,\n            config.hidden_size,\n        ], f\"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}\"\n\n    if config.use_bias:\n        w = [\n            weights.get_sharded(f\"{p}.bias\", dim=0)\n            for p in [f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"]\n        ]\n        bias = 
torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device)\n    else:\n        bias = None\n\n    return TensorParallelColumnLinear(get_linear(weight, bias=bias))\n\n\nclass Starcoder2Attention(torch.nn.Module):\n    def __init__(\n        self,\n        index: int,\n        prefix: str,\n        config,\n        weights,\n    ):\n        super().__init__()\n        self.max_past = (\n            config.sliding_window if config.sliding_window is not None else -1\n        )\n        self.num_heads = config.num_attention_heads\n        self.hidden_size = config.hidden_size\n        self.head_size = self.hidden_size // self.num_heads\n\n        self.rotary_emb = PositionRotaryEmbedding.static(\n            config=config,\n            dim=self.head_size,\n            base=config.rope_theta,\n            device=weights.device,\n        )\n\n        self.softmax_scale = self.head_size**-0.5\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.num_key_value_heads = (\n            config.num_key_value_heads // weights.process_group.size()\n        )\n\n        self.query_key_value = load_attention(config, prefix, weights, index)\n        self.kv_scales = get_kv_scales(weights, f\"{prefix}\")\n\n        o_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.o_proj\",\n            weights=weights,\n            bias=getattr(config, \"use_bias\", False),\n        )\n\n        self.o_proj = TensorParallelAdapterRowLinear.load(\n            o_proj,\n            index,\n            \"o_proj\",\n            process_group=weights.process_group,\n        )\n\n        self.num_groups = self.num_heads // 
self.num_key_value_heads\n        self.kv_head_mapping = torch.arange(\n            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device\n        ).repeat_interleave(self.num_groups)\n\n    def forward(\n        self,\n        hidden_states,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n        prefill_cache_indices,\n        adapter_data,\n    ):\n        qkv = self.query_key_value(hidden_states, adapter_data)\n        query, kv = qkv.split(\n            [\n                self.head_size * self.num_heads,\n                2 * self.head_size * self.num_key_value_heads,\n            ],\n            dim=1,\n        )\n        query = query.view(-1, self.num_heads, self.head_size)\n        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)\n\n        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)\n\n        if prefill_cache_indices is not None:\n            kv_to_cache = kv[prefill_cache_indices]\n        else:\n            kv_to_cache = kv\n\n        kv_cache.store(\n            key=kv_to_cache[:, 0],\n            value=kv_to_cache[:, 1],\n            slots=slots,\n            kv_scales=self.kv_scales,\n        )\n\n        # Prefill\n        if cu_seqlen_prefill is not None:\n            # flash attention\n            attn_output = attention(\n                query=query,\n                key=kv_to_cache[:, 0],\n                value=kv_to_cache[:, 1],\n                kv_cache=kv_cache,\n                kv_scales=self.kv_scales,\n                seqlen=seqlen,\n                block_tables=block_tables,\n                softmax_scale=self.softmax_scale,\n                window_size_left=self.max_past,\n            )\n        # Decode\n        else:\n            attn_output = paged_attention(\n                query,\n                kv_cache,\n                self.kv_head_mapping,\n                self.softmax_scale,\n 
               block_tables,\n                seqlen,\n                max_s,\n                kv_scales=self.kv_scales,\n                window_size_left=self.max_past,\n            )\n\n        return self.o_proj(\n            attn_output.view(-1, self.num_heads * self.head_size), adapter_data\n        )\n\n\nclass Starcoder2MLP(nn.Module):\n    def __init__(self, prefix, config, weights, index):\n        super().__init__()\n        act = config.hidden_act\n        self.act = (\n            ACT2FN[act]\n            if \"gelu\" not in act\n            else lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\" if act in [\"gelu_fast\", \"gelu_pytorch_tanh\"] else \"none\"\n                ),\n            )\n        )\n        # Fuse gate and up proj\n        c_fc = TensorParallelColumnLinear.load(\n            config,\n            prefix=f\"{prefix}.c_fc\",\n            weights=weights,\n            bias=config.use_bias,\n        )\n        c_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.c_proj\",\n            weights=weights,\n            bias=config.use_bias,\n        )\n\n        self.c_fc = TensorParallelMultiAdapterLinear.load(\n            c_fc,\n            layer_id=index,\n            layer_names=[f\"{prefix}.c_fc\"],\n            sizes=[config.intermediate_size, config.intermediate_size],\n            process_group=weights.process_group,\n        )\n\n        self.c_proj = TensorParallelAdapterRowLinear.load(\n            c_proj,\n            index,\n            \"c_proj\",\n            process_group=weights.process_group,\n        )\n\n    def forward(self, hidden_states, adapter_data):\n        hidden_states = self.c_fc(hidden_states, adapter_data)\n        hidden_states = self.act(hidden_states)\n        return self.c_proj(hidden_states, adapter_data)\n\n\nclass Starcoder2GatedMLP(nn.Module):\n    def __init__(self, index, prefix, config, 
weights):\n        super().__init__()\n        act = config.hidden_act\n        self.act = (\n            ACT2FN[act]\n            if \"gelu\" not in act\n            else lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\" if act in [\"gelu_fast\", \"gelu_pytorch_tanh\"] else \"none\"\n                ),\n            )\n        )\n        # Fuse gate and up proj\n        prefixes = [f\"{prefix}.gate_proj\", f\"{prefix}.up_proj\"]\n        sizes = [\n            config.intermediate_size,\n            config.intermediate_size,\n        ]\n        gate_up_proj = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=prefixes,\n            weights=weights,\n            dim=0,\n            bias=config.use_bias,\n        )\n        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(\n            gate_up_proj,\n            index,\n            layer_names=prefixes,\n            sizes=sizes,\n            process_group=weights.process_group,\n        )\n        down_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.down_proj\",\n            weights=weights,\n            bias=config.use_bias,\n        )\n        self.down_proj = TensorParallelAdapterRowLinear.load(\n            down_proj,\n            index,\n            \"down_proj\",\n            process_group=weights.process_group,\n        )\n        self.intermediate_size = (\n            config.intermediate_size // weights.process_group.size()\n        )\n\n    def forward(self, hidden_states, adapter_data):\n        gate_up_states = self.gate_up_proj(hidden_states, adapter_data)\n        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)\n        return self.down_proj(\n            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data\n        )\n\n\nSTARCODER2_NORMALIZATION_CLASSES = {\n    \"layer_norm\": FastLayerNorm,\n    \"rms_norm\": 
FastRMSNorm,\n}\n\nSTARCODER2_MLP_CLASSES = {\n    \"default\": Starcoder2MLP,\n    \"gated\": Starcoder2GatedMLP,\n}\n\n\nclass Starcoder2Layer(nn.Module):\n    def __init__(self, layer_id, config, weights):\n        super().__init__()\n        prefix = f\"model.layers.{layer_id}\"\n        self.self_attn = Starcoder2Attention(\n            prefix=f\"{prefix}.self_attn\", config=config, weights=weights, index=layer_id\n        )\n\n        self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](\n            prefix=f\"{prefix}.mlp\", config=config, weights=weights, index=layer_id\n        )\n\n        self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(\n            prefix=f\"{prefix}.input_layernorm\", weights=weights, eps=config.norm_epsilon\n        )\n        self.post_attention_layernorm = STARCODER2_NORMALIZATION_CLASSES[\n            config.norm_type\n        ].load(\n            prefix=f\"{prefix}.post_attention_layernorm\",\n            weights=weights,\n            eps=config.norm_epsilon,\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n        prefill_cache_indices,\n        adapter_data,\n    ):\n        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)\n\n        # Self Attention\n        attn_output = self.self_attn(\n            normed_hidden_states,\n            cos,\n            sin,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n            prefill_cache_indices,\n            adapter_data,\n        )\n\n        # faster post attention rms norm\n        normed_attn_res_output, attn_res = self.post_attention_layernorm(\n            attn_output, res\n        )\n\n        mlp_output = self.mlp(normed_attn_res_output, 
adapter_data)\n\n        return mlp_output, attn_res\n\n\nclass Starcoder2Model(torch.nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n\n        process_group = weights.process_group\n        self.tp_rank = process_group.rank()\n        self.tp_world_size = process_group.size()\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=f\"{prefix}.embed_tokens\", weights=weights\n        )\n        self.layers = nn.ModuleList(\n            [\n                Starcoder2Layer(\n                    layer_id,\n                    config,\n                    weights,\n                )\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n        self.norm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(\n            prefix=f\"{prefix}.norm\", weights=weights, eps=config.norm_epsilon\n        )\n\n        self.gradient_checkpointing = False\n\n        self.head_size = self.layers[0].self_attn.head_size\n        self.num_heads = self.layers[0].self_attn.num_heads\n        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        true_max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor],\n        adapter_data,\n    ) -> torch.Tensor:\n        hidden_states = self.embed_tokens(input_ids)\n\n        # Get rotary cos and sin for this forward\n        # Avoid to index in each layer\n        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(\n            position_ids, true_max_s, hidden_states.dtype\n        )\n\n        residual = None\n        for i, layer in enumerate(self.layers):\n            
hidden_states, residual = layer(\n                hidden_states,\n                residual,\n                cos,\n                sin,\n                cu_seqlen_prefill,\n                kv_cache[i],\n                block_tables,\n                slots,\n                seqlen,\n                max_s,\n                prefill_cache_indices,\n                adapter_data,\n            )\n\n        hidden_states, _ = self.norm(hidden_states, residual)\n\n        return hidden_states\n\n\nclass FlashStarcoder2ForCausalLM(torch.nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n\n        if not prefix:\n            prefix = \"model\"\n        else:\n            prefix = f\"{prefix}.model\"\n\n        self.model = Starcoder2Model(prefix, config, weights)\n        try:\n            self.lm_head = SpeculativeHead.load(\n                config,\n                prefix=\"lm_head\",\n                weights=weights,\n            )\n        except RuntimeError:\n            self.lm_head = SpeculativeHead.load(\n                config,\n                prefix=f\"{prefix}.embed_tokens\",\n                weights=weights,\n            )\n\n        self.max_past = config.sliding_window\n        self.max_past_tensor = (\n            torch.tensor(config.sliding_window, device=weights.device)\n            if self.max_past is not None\n            else None\n        )\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        true_max_s = max_s\n        if prefill_cache_indices 
is not None:\n            # Slots also need to be sliced as it has the same size as the whole kv tensor\n            slots = slots[prefill_cache_indices]\n        elif self.max_past is not None:\n            # Clamp in decode mode as paged attention requires clamped values whereas the flash attention\n            # kernel requires the true values\n            seqlen = seqlen.clamp(max=self.max_past_tensor)\n\n        hidden_states = self.model(\n            input_ids,\n            position_ids,\n            cu_seqlen_prefill,\n            kv_cache,\n            block_tables,\n            slots,\n            seqlen,\n            max_s,\n            true_max_s,\n            prefill_cache_indices,\n            adapter_data,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits = self.lm_head(hidden_states)\n        return logits\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/gemma3/configuration_gemma3.py",
    "content": "#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨\n#           This file was automatically generated from src/transformers/models/gemma3/modular_gemma3.py.\n#               Do NOT edit this file manually as any edits will be overwritten by the generation of\n#             the file from the modular. If any change should be done, please apply the change to the\n#                          modular_gemma3.py file directly. One of our CI enforces this.\n#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨\n# coding=utf-8\n# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.\n#\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom typing import Optional\n\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.modeling_rope_utils import rope_config_validation\nfrom transformers.utils import logging\nfrom transformers import SiglipVisionConfig\n\nlogger = logging.get_logger(__name__)\n\n\nclass Gemma3TextConfig(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`Gemma3Model`]. It is used to instantiate a Gemma3\n    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the\n    defaults will yield a similar configuration to that of the Gemma3-4B.\n    e.g. 
[google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b)\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the\n    documentation from [`PretrainedConfig`] for more information.\n\n    Args:\n        vocab_size (`int`, *optional*, defaults to 262144):\n            Vocabulary size of the Gemma3 model. Defines the number of different tokens that can be represented by the\n            `inputs_ids` passed when calling [`Gemma3Model`]\n        hidden_size (`int`, *optional*, defaults to 2304):\n            Dimension of the hidden representations.\n        intermediate_size (`int`, *optional*, defaults to 9216):\n            Dimension of the MLP representations.\n        num_hidden_layers (`int`, *optional*, defaults to 26):\n            Number of hidden layers in the Transformer decoder.\n        num_attention_heads (`int`, *optional*, defaults to 8):\n            Number of attention heads for each attention layer in the Transformer decoder.\n        num_key_value_heads (`int`, *optional*, defaults to 4):\n            This is the number of key_value heads that should be used to implement Grouped Query Attention. If\n            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if\n            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When\n            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed\n            by meanpooling all the original heads within that group. For more details checkout [this\n            paper](https://arxiv.org/pdf/2305.13245.pdf). 
If it is not specified, will default to\n            `num_attention_heads`.\n        head_dim (`int`, *optional*, defaults to 256):\n            The attention head dimension.\n        sliding_window (`int`, *optional*, defaults to 4096): in Gemma3, every other layer uses sliding window\n            attention. This is the size of the sliding window.\n        query_pre_attn_scalar (`float`, *optional*):\n            The scaling factor used on the attention scores.\n        rope_theta (`float`, *optional*, defaults to 1000000.0):\n            The base period of the RoPE embeddings used for global attention.\n        rope_scaling (`Dict`, *optional*):\n            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type\n            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value\n            accordingly.\n            Expected contents:\n                `rope_type` (`str`):\n                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',\n                    'llama3'], with 'default' being the original RoPE implementation.\n                `factor` (`float`, *optional*):\n                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In\n                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *\n                    original maximum pre-trained length.\n                `original_max_position_embeddings` (`int`, *optional*):\n                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during\n                    pretraining.\n                `attention_factor` (`float`, *optional*):\n                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention\n                    computation. 
If unspecified, it defaults to value recommended by the implementation, using the\n                    `factor` field to infer the suggested value.\n                `beta_fast` (`float`, *optional*):\n                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear\n                    ramp function. If unspecified, it defaults to 32.\n                `beta_slow` (`float`, *optional*):\n                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear\n                    ramp function. If unspecified, it defaults to 1.\n                `short_factor` (`List[float]`, *optional*):\n                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<\n                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden\n                    size divided by the number of attention heads divided by 2\n                `long_factor` (`List[float]`, *optional*):\n                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<\n                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden\n                    size divided by the number of attention heads divided by 2\n                `low_freq_factor` (`float`, *optional*):\n                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE\n                `high_freq_factor` (`float`, *optional*):\n                    Only used with 'llama3'. 
Scaling factor applied to high frequency components of the RoPE\n        rope_local_base_freq (float, *optional*, defaults to 10000.0):\n            The base period of the RoPE embeddings for local attention.\n        sliding_window_pattern (`int`, *optional*, defaults to 6):\n            Pattern for the sliding window attention.\n        rms_norm_eps (`float`, *optional*, defaults to 1e-06):\n            The epsilon used by the rms normalization layers.\n        hidden_activation (`str` or `function`, *optional*, defaults to `\"gelu_pytorch_tanh\"`):\n            The non-linear activation function (function or string) in the decoder. Will default to\n            `\"gelu_pytorch_tanh\"` if not specified. `\"gelu_pytorch_tanh\"` uses an approximation of the `\"gelu\"`\n            activation function.\n        pad_token_id (`int`, *optional*, defaults to 0):\n            Padding token id.\n        eos_token_id (`int`, *optional*, defaults to 1):\n            End of stream token id.\n        bos_token_id (`int`, *optional*, defaults to 2):\n            Beginning of stream token id.\n        tie_word_embeddings (`bool`, *optional*, defaults to `True`):\n            Whether to tie weight embeddings\n        max_position_embeddings (`int`, *optional*, defaults to 131072):\n            The maximum sequence length that this model might ever be used with.\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        attention_bias (`bool`, *optional*, defaults to `False`):\n            Whether to use a bias in the query, key, value and output projection layers during self-attention.\n        attention_dropout (`float`, *optional*, defaults to 0.0):\n            The dropout ratio for the attention probabilities.\n        use_cache (`bool`, *optional*, defaults to `True`):\n            Whether or not the model should return the last key/values 
attentions (not used by all models). Only\n            relevant if `config.is_decoder=True`.\n        final_logit_softcapping (`bool`, *optional*, defaults to `True`):\n            Whether to apply logit softcapping or not\n        attn_logit_softcapping (`float`, *optional*, defaults to 50.0):\n            Scaling factor when applying tanh soft-capping on the attention scores.\n        cache_implementation (`str`, *optional*, defaults to `\"hybrid\"`):\n            The cache type to be used with `generate`.\n\n    ```python\n    >>> from transformers import Gemma3Model, Gemma3TextConfig\n    >>> # Initializing a Gemma3 gemma3-4b style configuration\n    >>> configuration = Gemma3TextConfig()\n    >>> # Initializing a model from the gemma3-4b style configuration\n    >>> model = Gemma3Model(configuration)\n    >>> # Accessing the model configuration\n    >>> configuration = model.config\n    ```\"\"\"\n\n    model_type = \"gemma3_text\"\n\n    def __init__(\n        self,\n        vocab_size: int = 262_144,\n        hidden_size: int = 2304,\n        intermediate_size: int = 9216,\n        num_hidden_layers: int = 26,\n        num_attention_heads: int = 8,\n        num_key_value_heads: int = 4,\n        head_dim: int = 256,\n        sliding_window: int = 4096,\n        query_pre_attn_scalar: Optional[float] = 256,\n        rope_theta: float = 1_000_000.0,\n        rope_scaling=None,\n        rope_local_base_freq: float = 10_000.0,\n        sliding_window_pattern: int = 6,\n        rms_norm_eps: float = 1e-6,\n        hidden_activation: str = \"gelu_pytorch_tanh\",\n        pad_token_id: int = 0,\n        eos_token_id: int = 1,\n        bos_token_id: int = 2,\n        tie_word_embeddings: bool = True,\n        max_position_embeddings: int = 131_072,\n        initializer_range: float = 0.02,\n        attention_bias: bool = False,\n        attention_dropout: float = 0.0,\n        use_cache: bool = True,\n        final_logit_softcapping=None,\n        
attn_logit_softcapping=None,\n        cache_implementation: str = \"hybrid\",\n        **kwargs,\n    ):\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.head_dim = head_dim\n        self.num_key_value_heads = num_key_value_heads\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.rope_local_base_freq = rope_local_base_freq\n        # For configuring HybridCache to work with 5:1 attention pattern\n        self.sliding_window_pattern = sliding_window_pattern\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n        self.hidden_activation = hidden_activation\n        self.query_pre_attn_scalar = query_pre_attn_scalar\n        self.sliding_window = sliding_window\n        self.final_logit_softcapping = final_logit_softcapping\n        self.attn_logit_softcapping = attn_logit_softcapping\n        self.cache_implementation = cache_implementation\n        rope_config_validation(self)\n\n\nclass Gemma3Config(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`Gemma3ForConditionalGeneration`]. It is used to instantiate an\n    Gemma3ForConditionalGeneration according to the specified arguments, defining the model architecture. 
Instantiating a configuration\n    with the defaults will yield a similar configuration to that of the PaliGemma-2B.\n\n    e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b)\n\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the\n    documentation from [`PretrainedConfig`] for more information.\n\n    Args:\n        text_config (`Union[Gemma3TextConfig, dict]`, *optional*):\n            The config object of the text backbone.\n        vision_config (`Union[AutoConfig, dict]`,  *optional*):\n            Custom vision config or dict.\n        mm_tokens_per_image (`int`, *optional*, defaults to 256):\n            The number of tokens per image embedding.\n        boi_token_index (`int`, *optional*, defaults to 255999):\n            The begin-of-image token index to wrap the image prompt.\n        eoi_token_index (`int`, *optional*, defaults to 256000):\n            The end-of-image token index to wrap the image prompt.\n        image_token_index (`int`, *optional*, defaults to 262144):\n            The image token index to encode the image prompt.\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n\n    Example:\n\n    ```python\n    >>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig\n\n    >>> # Initializing a Siglip-like vision config\n    >>> vision_config = SiglipVisionConfig()\n\n    >>> # Initializing a Gemma3 Text config\n    >>> text_config = Gemma3TextConfig()\n\n    >>> # Initializing a Gemma3 gemma-3-4b style configuration\n    >>> configuration = Gemma3Config(text_config, vision_config)\n\n    >>> # Initializing a model from the gemma-3-4b style configuration\n    >>> model = Gemma3ForConditionalGeneration(configuration)\n\n    >>> # Accessing the model configuration\n    >>> configuration = model.config\n 
   ```\"\"\"\n\n    model_type = \"gemma3\"\n    sub_configs = {\n        \"text_config\": Gemma3TextConfig,\n        \"vision_config\": SiglipVisionConfig,\n    }\n\n    def __init__(\n        self,\n        text_config: Optional[Gemma3TextConfig] = None,\n        vision_config: Optional[SiglipVisionConfig] = None,\n        mm_tokens_per_image: int = 256,\n        boi_token_index: int = 255_999,\n        eoi_token_index: int = 256_000,\n        image_token_index: int = 262_144,\n        initializer_range: float = 0.02,\n        **kwargs,\n    ):\n        if text_config is None:\n            text_config = Gemma3TextConfig()\n            logger.info(\n                \"text_config is None, using default Gemma3TextConfig vision config.\"\n            )\n        elif isinstance(text_config, dict):\n            text_config = Gemma3TextConfig(**text_config)\n\n        if isinstance(vision_config, dict):\n            vision_config = SiglipVisionConfig(**vision_config)\n        else:\n            vision_config = SiglipVisionConfig()\n            logger.info(\n                \"vision_config is None or incompatible with Gemma3VisionConfig intialization. Gemma3 will be limited \"\n                \"to text tasks.\"\n            )\n\n        self.text_config = text_config\n        self.vision_config = vision_config\n        self.mm_tokens_per_image = mm_tokens_per_image\n        self.boi_token_index = boi_token_index\n        self.eoi_token_index = eoi_token_index\n        self.image_token_index = image_token_index\n        self.initializer_range = initializer_range\n\n        super().__init__(**kwargs)\n\n\n__all__ = [\"Gemma3Config\", \"Gemma3TextConfig\"]\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/gemma3/image_processing_gemma3.py",
    "content": "# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Image processor class for Gemma3.\"\"\"\n\nimport itertools\nimport math\nfrom typing import Dict, List, Optional, Union\n\nimport numpy as np\n\nfrom transformers.image_processing_utils import (\n    BaseImageProcessor,\n    BatchFeature,\n    get_size_dict,\n)\nfrom transformers.image_transforms import (\n    convert_to_rgb,\n    resize,\n    to_channel_dimension_format,\n)\nfrom transformers.image_utils import (\n    IMAGENET_STANDARD_MEAN,\n    IMAGENET_STANDARD_STD,\n    ChannelDimension,\n    ImageInput,\n    PILImageResampling,\n    get_image_size,\n    infer_channel_dimension_format,\n    is_scaled_image,\n    to_numpy_array,\n    valid_images,\n    validate_preprocess_arguments,\n)\nfrom transformers.utils import (\n    TensorType,\n    filter_out_non_signature_kwargs,\n    is_vision_available,\n    logging,\n)\n\nfrom .utils import make_nested_list_of_images\n\n\nlogger = logging.get_logger(__name__)\n\n\nif is_vision_available():\n    import PIL\n\n\nclass Gemma3ImageProcessor(BaseImageProcessor):\n    r\"\"\"\n    Constructs a SigLIP image processor.\n\n    Args:\n        do_resize (`bool`, *optional*, defaults to `True`):\n            Whether to resize the image's (height, width) dimensions to the specified `size`. 
Can be overridden by\n            `do_resize` in the `preprocess` method.\n        size (`Dict[str, int]` *optional*, defaults to `{\"height\": 224, \"width\": 224}`):\n            Size of the image after resizing. Can be overridden by `size` in the `preprocess` method.\n        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):\n            Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.\n        do_rescale (`bool`, *optional*, defaults to `True`):\n            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in\n            the `preprocess` method.\n        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):\n            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`\n            method.\n        do_normalize (`bool`, *optional*, defaults to `True`):\n            Whether to normalize the image by the specified mean and standard deviation. Can be overridden by\n            `do_normalize` in the `preprocess` method.\n        image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):\n            Mean to use if normalizing the image. This is a float or list of floats the length of the number of\n            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.\n        image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):\n            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the\n            number of channels in the image. 
Can be overridden by the `image_std` parameter in the `preprocess` method.\n        do_convert_rgb (`bool`, *optional*, defaults to `True`):\n            Whether to convert the image to RGB.\n        do_pan_and_scan (`bool`, *optional*):\n            Whether to apply `pan_and_scan` to images.\n    \"\"\"\n\n    model_input_names = [\"pixel_values\", \"num_crops\"]\n\n    def __init__(\n        self,\n        do_resize: bool = True,\n        size: Dict[str, int] = None,\n        resample: PILImageResampling = PILImageResampling.BILINEAR,\n        do_rescale: bool = False,\n        rescale_factor: Union[int, float] = 1 / 255,\n        do_normalize: bool = True,\n        image_mean: Optional[Union[float, List[float]]] = None,\n        image_std: Optional[Union[float, List[float]]] = None,\n        do_convert_rgb: bool = None,\n        do_pan_and_scan: bool = None,\n        pan_and_scan_min_crop_size: int = None,\n        pan_and_scan_max_num_crops: int = None,\n        pan_and_scan_min_ratio_to_activate: float = None,\n        **kwargs,\n    ) -> None:\n        super().__init__(**kwargs)\n        size = size if size is not None else {\"height\": 224, \"width\": 224}\n        image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN\n        image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD\n\n        self.do_resize = do_resize\n        self.size = size\n        self.resample = resample\n        self.do_rescale = do_rescale\n        self.rescale_factor = rescale_factor\n        self.do_normalize = do_normalize\n        self.image_mean = image_mean\n        self.image_std = image_std\n        self.do_convert_rgb = do_convert_rgb\n        self.do_pan_and_scan = do_pan_and_scan\n        self.pan_and_scan_min_crop_size = pan_and_scan_min_crop_size\n        self.pan_and_scan_max_num_crops = pan_and_scan_max_num_crops\n        
self.pan_and_scan_min_ratio_to_activate = pan_and_scan_min_ratio_to_activate\n\n    def pan_and_scan(\n        self,\n        image: np.ndarray,\n        pan_and_scan_min_crop_size: int,\n        pan_and_scan_max_num_crops: int,\n        pan_and_scan_min_ratio_to_activate: float,\n        data_format: Optional[Union[str, ChannelDimension]] = None,\n        input_data_format: Optional[Union[str, ChannelDimension]] = None,\n    ):\n        \"\"\"\n        Pan and scan an image: split it into crops along its longer side when its aspect ratio is at least\n        `pan_and_scan_min_ratio_to_activate`. Returns an empty list when pan and scan does not apply.\n\n        Args:\n            image (`np.ndarray`):\n                Image to crop.\n            pan_and_scan_min_crop_size (`int`):\n                Minimum size, in pixels, of each side of a crop.\n            pan_and_scan_max_num_crops (`int`):\n                Maximum number of crops to produce for the image.\n            pan_and_scan_min_ratio_to_activate (`float`):\n                Minimum ratio of the longer side to the shorter side required to activate pan and scan.\n            data_format (`str` or `ChannelDimension`, *optional*):\n                The channel dimension format of the image. If not provided, it will be the same as the input image.\n            input_data_format (`ChannelDimension` or `str`, *optional*):\n                The channel dimension format of the input image. 
If not provided, it will be inferred.\n        \"\"\"\n        height, width = get_image_size(image)\n\n        # Square or landscape image.\n        if width >= height:\n            # Only apply PaS if the image is sufficiently exaggerated\n            if width / height < pan_and_scan_min_ratio_to_activate:\n                return []\n\n            # Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.\n            num_crops_w = int(\n                math.floor(width / height + 0.5)\n            )  # Half round up rounding.\n            num_crops_w = min(\n                int(math.floor(width / pan_and_scan_min_crop_size)), num_crops_w\n            )\n\n            # Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].\n            num_crops_w = max(2, num_crops_w)\n            num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w)\n            num_crops_h = 1\n\n        # Portrait image.\n        else:\n            # Only apply PaS if the image is sufficiently exaggerated\n            if height / width < pan_and_scan_min_ratio_to_activate:\n                return []\n\n            # Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.\n            num_crops_h = int(math.floor(height / width + 0.5))\n            num_crops_h = min(\n                int(math.floor(height / pan_and_scan_min_crop_size)), num_crops_h\n            )\n\n            # Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].\n            num_crops_h = max(2, num_crops_h)\n            num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h)\n            num_crops_w = 1\n\n        crop_size_w = int(math.ceil(width / num_crops_w))\n        crop_size_h = int(math.ceil(height / num_crops_h))\n\n        # Don't apply PaS if crop size is too small.\n        if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size:\n            return []\n\n        
crop_positions_w = [crop_size_w * i for i in range(num_crops_w)]\n        crop_positions_h = [crop_size_h * i for i in range(num_crops_h)]\n\n        if input_data_format == ChannelDimension.LAST:\n            image_crops = [\n                image[pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w]\n                for pos_h, pos_w in itertools.product(\n                    crop_positions_h, crop_positions_w\n                )\n            ]\n        else:\n            image_crops = [\n                image[:, pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w]\n                for pos_h, pos_w in itertools.product(\n                    crop_positions_h, crop_positions_w\n                )\n            ]\n\n        return image_crops\n\n    def _process_images_for_pas(\n        self,\n        images: List[np.ndarray],\n        do_pan_and_scan: bool,\n        pan_and_scan_min_crop_size: int,\n        pan_and_scan_max_num_crops: int,\n        pan_and_scan_min_ratio_to_activate: float,\n        data_format: Optional[Union[str, ChannelDimension]] = None,\n        input_data_format: Optional[Union[str, ChannelDimension]] = None,\n    ):\n        pas_images_list = []\n        num_crops = []\n        for image in images:\n            pas_images = self.pan_and_scan(\n                image=image,\n                pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,\n                pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,\n                pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,\n                data_format=data_format,\n                input_data_format=input_data_format,\n            )\n            pas_images_list.extend([image] + pas_images)\n            num_crops.append(len(pas_images))\n        return pas_images_list, num_crops\n\n    @filter_out_non_signature_kwargs()\n    def preprocess(\n        self,\n        images: ImageInput,\n        do_resize: bool = None,\n        size: Dict[str, int] = None,\n        
resample: PILImageResampling = None,\n        do_rescale: bool = None,\n        rescale_factor: float = None,\n        do_normalize: bool = None,\n        image_mean: Optional[Union[float, List[float]]] = None,\n        image_std: Optional[Union[float, List[float]]] = None,\n        return_tensors: Optional[Union[str, TensorType]] = None,\n        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,\n        input_data_format: Optional[Union[str, ChannelDimension]] = None,\n        do_convert_rgb: bool = True,\n        do_pan_and_scan: bool = None,\n        pan_and_scan_min_crop_size: int = None,\n        pan_and_scan_max_num_crops: int = None,\n        pan_and_scan_min_ratio_to_activate: float = None,\n    ) -> PIL.Image.Image:\n        \"\"\"\n        Preprocess an image or batch of images.\n\n        Args:\n            images (`ImageInput`):\n                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If\n                passing in images with pixel values between 0 and 1, set `do_rescale=False`.\n            do_resize (`bool`, *optional*, defaults to `self.do_resize`):\n                Whether to resize the image.\n            size (`Dict[str, int]`, *optional*, defaults to `self.size`):\n                Size of the image after resizing.\n            resample (`int`, *optional*, defaults to `self.resample`):\n                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. 
Only\n                has an effect if `do_resize` is set to `True`.\n            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):\n                Whether to rescale the image.\n            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):\n                Rescale factor to rescale the image by if `do_rescale` is set to `True`.\n            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):\n                Whether to normalize the image.\n            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):\n                Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.\n            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):\n                Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to\n                `True`.\n            return_tensors (`str` or `TensorType`, *optional*):\n                The type of tensors to return. Can be one of:\n                - Unset: Return a list of `np.ndarray`.\n                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.\n                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.\n                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.\n                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.\n            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):\n                The channel dimension format for the output image. 
Can be one of:\n                - `\"channels_first\"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.\n                - `\"channels_last\"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.\n                - Unset: Use the channel dimension format of the input image.\n            input_data_format (`ChannelDimension` or `str`, *optional*):\n                The channel dimension format for the input image. If unset, the channel dimension format is inferred\n                from the input image. Can be one of:\n                - `\"channels_first\"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.\n                - `\"channels_last\"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.\n                - `\"none\"` or `ChannelDimension.NONE`: image in (height, width) format.\n            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):\n                Whether to convert the image to RGB.\n            do_pan_and_scan (`bool`, *optional*, defaults to `self.do_pan_and_scan`):\n                Whether to apply `pan_and_scan` to images.\n        \"\"\"\n        do_resize = do_resize if do_resize is not None else self.do_resize\n        size = size if size is not None else self.size\n        size = get_size_dict(size, param_name=\"size\", default_to_square=False)\n        resample = resample if resample is not None else self.resample\n        do_rescale = do_rescale if do_rescale is not None else self.do_rescale\n        rescale_factor = (\n            rescale_factor if rescale_factor is not None else self.rescale_factor\n        )\n        do_normalize = do_normalize if do_normalize is not None else self.do_normalize\n        image_mean = image_mean if image_mean is not None else self.image_mean\n        image_std = image_std if image_std is not None else self.image_std\n        do_convert_rgb = (\n            do_convert_rgb if do_convert_rgb 
is not None else self.do_convert_rgb\n        )\n        do_pan_and_scan = (\n            do_pan_and_scan if do_pan_and_scan is not None else self.do_pan_and_scan\n        )\n        pan_and_scan_min_crop_size = (\n            pan_and_scan_min_crop_size\n            if pan_and_scan_min_crop_size is not None\n            else self.pan_and_scan_min_crop_size\n        )\n        pan_and_scan_max_num_crops = (\n            pan_and_scan_max_num_crops\n            if pan_and_scan_max_num_crops is not None\n            else self.pan_and_scan_max_num_crops\n        )\n        pan_and_scan_min_ratio_to_activate = (\n            pan_and_scan_min_ratio_to_activate\n            if pan_and_scan_min_ratio_to_activate is not None\n            else self.pan_and_scan_min_ratio_to_activate\n        )\n\n        images_list = make_nested_list_of_images(images)\n\n        if not valid_images(images_list[0]):\n            raise ValueError(\n                \"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, \"\n                \"torch.Tensor, tf.Tensor or jax.ndarray.\"\n            )\n\n        validate_preprocess_arguments(\n            do_rescale=do_rescale,\n            rescale_factor=rescale_factor,\n            do_normalize=do_normalize,\n            image_mean=image_mean,\n            image_std=image_std,\n            do_resize=do_resize,\n            size=size,\n            resample=resample,\n        )\n        if do_convert_rgb:\n            images_list = [\n                [convert_to_rgb(image) for image in images] for images in images_list\n            ]\n\n        # All transformations expect numpy arrays.\n        images_list = [\n            [to_numpy_array(image) for image in images] for images in images_list\n        ]\n\n        if do_rescale and is_scaled_image(images_list[0][0]):\n            logger.warning_once(\n                \"It looks like you are trying to rescale already rescaled images. 
If the input\"\n                \" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again.\"\n            )\n\n        if input_data_format is None:\n            # We assume that all images have the same channel dimension format.\n            input_data_format = infer_channel_dimension_format(images_list[0][0])\n\n        if do_pan_and_scan:\n            images_list_and_num_crops = [\n                self._process_images_for_pas(\n                    images=images,\n                    do_pan_and_scan=do_pan_and_scan,\n                    pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,\n                    pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,\n                    pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,\n                    data_format=data_format,\n                    input_data_format=input_data_format,\n                )\n                for images in images_list\n            ]\n            images_list = [images for images, _ in images_list_and_num_crops]\n            num_crops = [num_crops for _, num_crops in images_list_and_num_crops]\n        else:\n            num_crops = [[0] for images in images_list]\n\n        if do_resize:\n            height, width = size[\"height\"], size[\"width\"]\n            images_list = [\n                [\n                    resize(\n                        image=image,\n                        size=(height, width),\n                        resample=resample,\n                        input_data_format=input_data_format,\n                    )\n                    for image in images\n                ]\n                for images in images_list\n            ]\n\n        if do_rescale:\n            images_list = [\n                [\n                    self.rescale(\n                        image=image,\n                        scale=rescale_factor,\n                        input_data_format=input_data_format,\n                    
)\n                    for image in images\n                ]\n                for images in images_list\n            ]\n\n        if do_normalize:\n            images_list = [\n                [\n                    self.normalize(\n                        image=image,\n                        mean=image_mean,\n                        std=image_std,\n                        input_data_format=input_data_format,\n                    )\n                    for image in images\n                ]\n                for images in images_list\n            ]\n\n        images = [\n            to_channel_dimension_format(\n                image, data_format, input_channel_dim=input_data_format\n            )\n            for images in images_list\n            for image in images\n        ]\n\n        data = {\"pixel_values\": images, \"num_crops\": num_crops}\n        return BatchFeature(data=data, tensor_type=return_tensors)\n\n\n__all__ = [\"Gemma3ImageProcessor\"]\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/gemma3/processing_gemma3.py",
    "content": "# coding=utf-8\n# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.\n#\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport re\nfrom typing import List, Optional, Union\n\nfrom transformers.feature_extraction_utils import BatchFeature\nfrom transformers.image_utils import ImageInput\nfrom transformers.processing_utils import (\n    ImagesKwargs,\n    ProcessingKwargs,\n    ProcessorMixin,\n    Unpack,\n)\nfrom transformers.tokenization_utils_base import PreTokenizedInput, TextInput\nfrom transformers.utils import to_py_obj\nfrom text_generation_server.models.custom_modeling.gemma3.image_processing_gemma3 import (\n    Gemma3ImageProcessor,\n)\n\nfrom transformers.image_utils import PILImageResampling\n\nfrom .utils import make_nested_list_of_images\n\n\nclass Gemma3ImagesKwargs(ImagesKwargs):\n    do_pan_and_scan: Optional[bool]\n    pan_and_scan_min_crop_size: Optional[int]\n    pan_and_scan_max_num_crops: Optional[int]\n    pan_and_scan_min_ratio_to_activate: Optional[float]\n    do_convert_rgb: Optional[bool]\n\n\nclass Gemma3ProcessorKwargs(ProcessingKwargs, total=False):\n    _defaults = {\n        \"text_kwargs\": {\n            \"padding\": False,\n        },\n        \"images_kwargs\": {\n            \"do_pan_and_scan\": False,\n            \"pan_and_scan_min_crop_size\": 256,\n            \"pan_and_scan_max_num_crops\": 4,\n            \"pan_and_scan_min_ratio_to_activate\": 1.2,\n        },\n    }\n\n\nclass 
Gemma3Processor(ProcessorMixin):\n    attributes = [\"image_processor\", \"tokenizer\"]\n    valid_kwargs = [\"chat_template\"]\n    # # image_processor_class = \"Gemma3ImageProcessor\"\n    image_processor_class = \"AutoProcessor\"\n    tokenizer_class = \"AutoTokenizer\"\n\n    def __init__(\n        self,\n        image_processor,\n        tokenizer,\n        chat_template=None,\n        num_mm_soft_tokens_per_image: int = 256,\n        **kwargs,\n    ):\n        num_mm_soft_tokens_per_image = 256\n        chat_template = None\n\n        image_processor = Gemma3ImageProcessor(\n            image_mean=(127.5,) * 3,\n            image_std=(127.5,) * 3,\n            size={\"height\": 896, \"width\": 896},\n            do_rescale=False,\n            resample=PILImageResampling.BILINEAR,\n        )\n\n        self.image_token_id = tokenizer.image_token_id\n        image_tokens_expanded = \"\".join(\n            [tokenizer.image_token] * num_mm_soft_tokens_per_image\n        )\n        self.full_image_sequence = (\n            f\"\\n\\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\\n\\n\"\n        )\n\n        self.image_processor = image_processor\n        self.tokenizer = tokenizer\n        self.chat_template = chat_template\n\n        # super().__init__(\n        #     image_processor=image_processor,\n        #     tokenizer=tokenizer,\n        #     chat_template=chat_template,\n        #     **kwargs,\n        # )\n\n    def __call__(\n        self,\n        images: ImageInput = None,\n        text: Union[\n            TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]\n        ] = None,\n        videos=None,\n        audio=None,\n        **kwargs: Unpack[Gemma3ProcessorKwargs],\n    ) -> BatchFeature:\n        if text is None and images is None:\n            raise ValueError(\"Provide at least one of `text` or `images`.\")\n\n        output_kwargs = self._merge_kwargs(\n            Gemma3ProcessorKwargs,\n            
tokenizer_init_kwargs=self.tokenizer.init_kwargs,\n            **kwargs,\n        )\n\n        if isinstance(text, str):\n            text = [text]\n        elif not isinstance(text, list) and not isinstance(text[0], str):\n            raise ValueError(\n                \"Invalid input text. Please provide a string, or a list of strings\"\n            )\n\n        image_inputs = {}\n        if images is not None:\n            batched_images = make_nested_list_of_images(images)\n            image_inputs = self.image_processor(\n                batched_images, **output_kwargs[\"images_kwargs\"]\n            )\n\n            # Create empty text to be replaced with placeholders\n            if not text:\n                text = [\n                    \" \".join([\"<image>\"] * len(images)) for images in batched_images\n                ]\n\n            if len(batched_images) != len(text):\n                raise ValueError(\n                    f\"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)}).\"\n                )\n\n            # Replace image tokens by the full expanded sequence\n            batch_num_crops = to_py_obj(image_inputs.pop(\"num_crops\"))\n            for prompt, images, num_crops in zip(text, batched_images, batch_num_crops):\n                image_indexes = [m.start() for m in re.finditer(\"<image>\", prompt)]\n\n                if len(images) != len(image_indexes):\n                    raise ValueError(\n                        f\"Prompt contained {len(image_indexes)} image tokens but received {len(images)} images.\"\n                    )\n\n                # Insert additional image tokens for Pan-and-Scan crops\n                for num, idx in reversed(list(zip(num_crops, image_indexes))):\n                    if num:\n                        formatted_image_text = (\n                            \"Here is the original image <image> and here are some crops to help you see better \"\n                       
     + \" \".join([\"<image>\"] * num)\n                        )\n                        prompt = (\n                            prompt[:idx]\n                            + formatted_image_text\n                            + prompt[idx + len(\"<image>\") :]\n                        )\n\n            # Expand placeholder image tokens to the full image token sequence\n            text = [\n                prompt.replace(\"<image>\", self.full_image_sequence) for prompt in text\n            ]\n\n        text_input = self.tokenizer(text=text, **output_kwargs[\"text_kwargs\"])\n        return BatchFeature(data={**text_input, **image_inputs})\n\n    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma\n    def batch_decode(self, *args, **kwargs):\n        \"\"\"\n        This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please\n        refer to the docstring of this method for more information.\n        \"\"\"\n        return self.tokenizer.batch_decode(*args, **kwargs)\n\n    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma\n    def decode(self, *args, **kwargs):\n        \"\"\"\n        This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to\n        the docstring of this method for more information.\n        \"\"\"\n        return self.tokenizer.decode(*args, **kwargs)\n\n    @property\n    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->PaliGemma\n    def model_input_names(self):\n        tokenizer_input_names = self.tokenizer.model_input_names\n        image_processor_input_names = self.image_processor.model_input_names\n        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))\n\n\n__all__ = [\"Gemma3Processor\"]\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/gemma3/utils.py",
    "content": "# coding=utf-8\n# Copyright 2025 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import List, Union\n\n\nfrom transformers.image_utils import ImageInput, is_valid_image, is_pil_image\n\n\ndef is_valid_list_of_images(images: List):\n    return images and all(is_valid_image(image) for image in images)\n\n\ndef make_nested_list_of_images(\n    images: Union[List[ImageInput], ImageInput],\n) -> ImageInput:\n    \"\"\"\n    Ensure that the output is a nested list of images.\n    Args:\n        images (`Union[List[ImageInput], ImageInput]`):\n            The input image.\n    Returns:\n        list: A list of list of images or a list of 4d array of images.\n    \"\"\"\n    # If it's a list of batches, it's already in the right format\n    if (\n        isinstance(images, (list, tuple))\n        and all(isinstance(images_i, (list, tuple)) for images_i in images)\n        and all(is_valid_list_of_images(images_i) for images_i in images)\n    ):\n        return images\n\n    # If it's a list of images, it's a single batch, so convert it to a list of lists\n    if isinstance(images, (list, tuple)) and is_valid_list_of_images(images):\n        if is_pil_image(images[0]) or images[0].ndim == 3:\n            return [images]\n        if images[0].ndim == 4:\n            return [list(image) for image in images]\n\n    # If it's a single image, convert it to a list of lists\n    if 
is_valid_image(images):\n        if is_pil_image(images) or images.ndim == 3:\n            return [[images]]\n        if images.ndim == 4:\n            return [list(images)]\n\n    raise ValueError(\n        \"Invalid input type. Must be a single image, a list of images, or a list of batches of images.\"\n    )\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/idefics2.py",
    "content": "# coding=utf-8\n# Copyright 2024 the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch Idefics2 model.\"\"\"\n\nfrom typing import List, Optional, Tuple\n\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\nimport math\n\nfrom transformers.activations import ACT2FN\nfrom text_generation_server.models.custom_modeling.vlm import (\n    load_text_model,\n)\nfrom text_generation_server.layers.attention import Seqlen\nfrom transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask\n\nfrom text_generation_server.layers import (\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    TensorParallelRowLinear,\n)\nfrom text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(\n        batch, num_key_value_heads, n_rep, slen, head_dim\n    )\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\nclass Idefics2VisionEmbeddings(nn.Module):\n    \"\"\"\n    This is a modified version of `siglip.modeling_siglip.SiglipVisionEmbeddings` to enable images of variable\n    resolution.\n\n    The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)\n    which allows treating images in their native aspect ratio and without the need to resize them to the same\n    fixed size. In particular, we start from the original pre-trained SigLIP model\n    (which uses fixed-size square images) and adapt it by training on images of variable resolutions.\n    \"\"\"\n\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.embed_dim = config.hidden_size\n        self.image_size = config.image_size\n        self.patch_size = config.patch_size\n\n        self.patch_embedding = nn.Conv2d(\n            in_channels=config.num_channels,\n            out_channels=self.embed_dim,\n            kernel_size=self.patch_size,\n            stride=self.patch_size,\n            padding=\"valid\",\n        )\n        self.patch_embedding.weight = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.patch_embedding.weight\"), requires_grad=False\n        )\n        self.patch_embedding.bias = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.patch_embedding.bias\"), requires_grad=False\n        )\n\n        self.num_patches_per_side = self.image_size // self.patch_size\n        
self.num_patches = self.num_patches_per_side**2\n        self.num_positions = self.num_patches\n        self.position_embedding = TensorParallelEmbedding(\n            prefix=f\"{prefix}.position_embedding\", weights=weights\n        )\n\n    def forward(\n        self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor\n    ) -> torch.Tensor:\n        batch_size, _, max_im_h, max_im_w = pixel_values.shape\n\n        patch_embeds = self.patch_embedding(pixel_values)\n        embeddings = patch_embeds.flatten(2).transpose(1, 2)\n\n        max_nb_patches_h, max_nb_patches_w = (\n            max_im_h // self.patch_size,\n            max_im_w // self.patch_size,\n        )\n        boundaries = torch.arange(\n            1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side\n        )\n        position_ids = torch.full(\n            size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0\n        )\n\n        for batch_idx, p_attn_mask in enumerate(patch_attention_mask):\n            nb_patches_h = p_attn_mask[:, 0].sum()\n            nb_patches_w = p_attn_mask[0].sum()\n\n            fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)\n            fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)\n\n            bucket_coords_h = torch.bucketize(\n                fractional_coords_h, boundaries, right=True\n            )\n            bucket_coords_w = torch.bucketize(\n                fractional_coords_w, boundaries, right=True\n            )\n\n            pos_ids = (\n                bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w\n            ).flatten()\n            position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids\n\n        position_ids = position_ids.to(self.position_embedding.weight.device)\n        embeddings = embeddings + self.position_embedding(position_ids)\n        return embeddings\n\n\nclass Idefics2VisionAttention(nn.Module):\n    def 
__init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_size = self.embed_dim // self.num_heads\n        if self.head_size * self.num_heads != self.embed_dim:\n            raise ValueError(\n                f\"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:\"\n                f\" {self.num_heads}).\"\n            )\n        self.scale = self.head_size**-0.5\n        self.dropout = config.attention_dropout\n\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.embed_dim = self.embed_dim // weights.process_group.size()\n\n        self.qkv = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n            dim=0,\n            weights=weights,\n            bias=True,\n        )\n        self.out_proj = TensorParallelRowLinear.load(\n            config=config, prefix=f\"{prefix}.out_proj\", weights=weights, bias=True\n        )\n        self.is_causal = False\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        batch_size, q_len, _ = hidden_states.size()\n\n        qkv = self.qkv(hidden_states)\n        query_states, key_states, value_states = qkv.split(\n            [\n                self.head_size * self.num_heads,\n                self.head_size * self.num_heads,\n                self.head_size * self.num_heads,\n            ],\n            dim=2,\n        )\n\n        query_states = query_states.view(\n            batch_size, q_len, self.num_heads, self.head_size\n        ).transpose(1, 2)\n        key_states = key_states.view(\n            batch_size, q_len, self.num_heads, self.head_size\n        ).transpose(1, 2)\n        
value_states = value_states.view(\n            batch_size, q_len, self.num_heads, self.head_size\n        ).transpose(1, 2)\n\n        k_v_seq_len = key_states.shape[-2]\n        attn_weights = (\n            torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale\n        )\n\n        if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is\"\n                f\" {attn_weights.size()}\"\n            )\n\n        if attention_mask is not None:\n            if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}\"\n                )\n            attn_weights = attn_weights + attention_mask\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(\n            attn_weights, dim=-1, dtype=torch.float32\n        ).to(query_states.dtype)\n        attn_weights = nn.functional.dropout(\n            attn_weights, p=self.dropout, training=self.training\n        )\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_size):\n            raise ValueError(\n                f\"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_size)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)\n\n        attn_output = self.out_proj(attn_output)\n\n        return attn_output\n\n\nclass Idefics2VisionMLP(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        
self.activation_fn = ACT2FN[config.hidden_act]\n        self.fc1 = TensorParallelColumnLinear.load(\n            prefix=f\"{prefix}.fc1\", config=config, weights=weights, bias=True\n        )\n        self.fc2 = TensorParallelRowLinear.load(\n            prefix=f\"{prefix}.fc2\", config=config, weights=weights, bias=True\n        )\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = self.activation_fn(hidden_states)\n        hidden_states = self.fc2(hidden_states)\n        return hidden_states\n\n\nclass Idefics2EncoderLayer(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.embed_dim = config.hidden_size\n        self.self_attn = Idefics2VisionAttention(\n            prefix=f\"{prefix}.self_attn\", config=config, weights=weights\n        )\n        self.layer_norm1 = nn.LayerNorm.load(\n            prefix=f\"{prefix}.layer_norm1\", eps=config.layer_norm_eps, weights=weights\n        )\n        self.layer_norm2 = nn.LayerNorm.load(\n            prefix=f\"{prefix}.layer_norm2\", eps=config.layer_norm_eps, weights=weights\n        )\n        self.mlp = Idefics2VisionMLP(\n            prefix=f\"{prefix}.mlp\", config=config, weights=weights\n        )\n\n    # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: torch.Tensor,\n    ) -> torch.Tensor:\n        residual = hidden_states\n\n        hidden_states = self.layer_norm1(hidden_states)\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n        )\n        hidden_states = residual + hidden_states\n\n        residual = hidden_states\n        hidden_states = self.layer_norm2(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + 
hidden_states\n\n        return hidden_states\n\n\nclass Idefics2Encoder(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        self.layers = nn.ModuleList(\n            [\n                Idefics2EncoderLayer(\n                    prefix=f\"{prefix}.layers.{i}\", config=config, weights=weights\n                )\n                for i in range(config.num_hidden_layers)\n            ]\n        )\n\n    # Ignore copy\n    def forward(\n        self,\n        inputs_embeds,\n        attention_mask: Optional[torch.Tensor] = None,\n    ):\n        hidden_states = inputs_embeds\n        for encoder_layer in self.layers:\n            hidden_states = encoder_layer(\n                hidden_states,\n                attention_mask,\n            )\n        return hidden_states\n\n\nclass Idefics2VisionTransformer(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        self.embeddings = Idefics2VisionEmbeddings(\n            prefix=f\"{prefix}.embeddings\", config=config, weights=weights\n        )\n        self.encoder = Idefics2Encoder(\n            prefix=f\"{prefix}.encoder\", config=config, weights=weights\n        )\n        self.post_layernorm = nn.LayerNorm.load(\n            prefix=f\"{prefix}.post_layernorm\",\n            weights=weights,\n            eps=config.layer_norm_eps,\n        )\n\n    def forward(\n        self,\n        pixel_values,\n        patch_attention_mask: Optional[torch.BoolTensor] = None,\n    ):\n        batch_size = pixel_values.size(0)\n        if patch_attention_mask is None:\n            patch_size = self.config.patch_size\n            patch_attention_mask = torch.ones(\n                (\n                    batch_size,\n                    pixel_values.size(2) // patch_size,\n                    pixel_values.size(3) // patch_size,\n                )\n            )\n            
patch_attention_mask = patch_attention_mask.to(\n                dtype=torch.bool, device=pixel_values.device\n            )\n\n        hidden_states = self.embeddings(\n            pixel_values=pixel_values, patch_attention_mask=patch_attention_mask\n        )\n\n        patch_attention_mask = patch_attention_mask.view(batch_size, -1)\n        # The call to `_upad_input` in `_flash_attention_forward` is expensive\n        # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),\n        # avoiding passing the attention_mask, which is equivalent to attending to the full sequence\n        if not torch.any(~patch_attention_mask):\n            patch_attention_mask = None\n        else:\n            patch_attention_mask = _prepare_4d_attention_mask(\n                patch_attention_mask, hidden_states.dtype\n            )\n\n        encoder_outputs = self.encoder(\n            inputs_embeds=hidden_states,\n            attention_mask=patch_attention_mask,\n        )\n\n        last_hidden_state = encoder_outputs\n        last_hidden_state = self.post_layernorm(last_hidden_state)\n\n        return last_hidden_state\n\n\nclass Idefics2MLP(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        act = config.text_config.hidden_act\n        self.act = (\n            ACT2FN[act]\n            if \"gelu\" not in act\n            else lambda x: torch.nn.functional.gelu(\n                x,\n                approximate=(\n                    \"tanh\" if act in [\"gelu_fast\", \"gelu_pytorch_tanh\"] else \"none\"\n                ),\n            )\n        )\n        self.gate_up_proj = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.gate_proj\", f\"{prefix}.up_proj\"],\n            weights=weights,\n            dim=0,\n            bias=False,\n        )\n        self.down_proj = TensorParallelRowLinear.load(\n            config,\n            
prefix=f\"{prefix}.down_proj\",\n            weights=weights,\n            bias=False,\n        )\n\n    def forward(self, hidden_states):\n        start_shape = hidden_states.shape[:-1]\n        gate_up_states = self.gate_up_proj(hidden_states)\n        intermediate_size = gate_up_states.shape[-1] // 2\n        gate_up_states = gate_up_states.view(-1, 2, intermediate_size)\n        return self.down_proj(\n            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]\n        ).view(*start_shape, -1)\n\n\nclass Idefics2RMSNorm(nn.Module):\n    def __init__(self, prefix, weights, eps):\n        \"\"\"\n        Idefics2RMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        self.weight = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.weight\"), requires_grad=False\n        )\n        self.variance_epsilon = eps\n\n    def forward(self, hidden_states):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n\nclass Idefics2PerceiverAttention(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n\n        self.layer_idx = None\n        self.hidden_size = config.text_config.hidden_size\n        self.num_heads = config.perceiver_config.resampler_n_heads\n        self.head_size = config.perceiver_config.resampler_head_dim\n        self.num_key_value_heads = config.perceiver_config.num_key_value_heads\n        self.num_key_value_groups = self.num_heads // self.num_key_value_heads\n        self.attention_dropout = config.perceiver_config.attention_dropout\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.num_key_value_heads = (\n            self.num_key_value_heads // weights.process_group.size()\n   
     )\n\n        self.q_proj = TensorParallelColumnLinear.load(\n            config,\n            prefix=f\"{prefix}.q_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.kv = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n            dim=0,\n            weights=weights,\n            bias=False,\n        )\n        self.o_proj = TensorParallelRowLinear.load(\n            config=config, prefix=f\"{prefix}.o_proj\", weights=weights, bias=False\n        )\n\n        self.is_causal = False\n\n    def forward(\n        self,\n        latents: torch.Tensor,\n        context: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        bsz, q_len, _ = latents.size()\n        kv_seq_len = q_len + context.size()[1]\n\n        hidden_states = torch.concat([context, latents], dim=-2)\n        query_states = self.q_proj(latents)\n        kv = self.kv(hidden_states)\n        key_states, value_states = kv.split(\n            [\n                self.head_size * self.num_key_value_heads,\n                self.head_size * self.num_key_value_heads,\n            ],\n            dim=2,\n        )\n\n        query_states = query_states.view(\n            bsz, q_len, self.num_heads, self.head_size\n        ).transpose(1, 2)\n        key_states = key_states.view(\n            bsz, kv_seq_len, self.num_key_value_heads, self.head_size\n        ).transpose(1, 2)\n        value_states = value_states.view(\n            bsz, kv_seq_len, self.num_key_value_heads, self.head_size\n        ).transpose(1, 2)\n\n        # repeat k/v heads if n_kv_heads < n_heads\n        key_states = repeat_kv(key_states, self.num_key_value_groups)\n        value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n        attn_weights = torch.matmul(\n            query_states, 
key_states.transpose(2, 3)\n        ) / math.sqrt(self.head_size)\n\n        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is\"\n                f\" {attn_weights.size()}\"\n            )\n\n        if attention_mask is not None:\n            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}\"\n                )\n\n            attn_weights = attn_weights + attention_mask\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(\n            attn_weights, dim=-1, dtype=torch.float32\n        ).to(query_states.dtype)\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_size):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_size)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_size)\n\n        attn_output = self.o_proj(attn_output)\n\n        return attn_output\n\n\nclass Idefics2PerceiverLayer(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.hidden_size = config.text_config.hidden_size\n        self.n_latents = config.perceiver_config.resampler_n_latents\n        self.depth = config.perceiver_config.resampler_depth\n        self.rms_norm_eps = config.text_config.rms_norm_eps\n\n        self.input_latents_norm = Idefics2RMSNorm(\n            prefix=f\"{prefix}.input_latents_norm\",\n            weights=weights,\n            eps=self.rms_norm_eps,\n    
    )\n        self.input_context_norm = Idefics2RMSNorm(\n            prefix=f\"{prefix}.input_context_norm\",\n            weights=weights,\n            eps=self.rms_norm_eps,\n        )\n        self.self_attn = Idefics2PerceiverAttention(\n            prefix=f\"{prefix}.self_attn\", config=config, weights=weights\n        )\n        self.post_attention_layernorm = Idefics2RMSNorm(\n            prefix=f\"{prefix}.post_attention_layernorm\",\n            weights=weights,\n            eps=self.rms_norm_eps,\n        )\n        self.mlp = Idefics2MLP(prefix=f\"{prefix}.mlp\", config=config, weights=weights)\n\n    def forward(\n        self,\n        latents: torch.Tensor,\n        context: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n    ):\n        \"\"\"\n        Args:\n            latents (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n                `(batch, sequence_length)` where padding elements are indicated by 0.\n        \"\"\"\n        residual = latents\n\n        latents = self.input_latents_norm(latents)\n        context = self.input_context_norm(context)\n\n        latents = self.self_attn(\n            latents=latents,\n            context=context,\n            attention_mask=attention_mask,\n        )\n        latents = residual + latents\n        residual = latents\n\n        latents = self.post_attention_layernorm(latents)\n        latents = self.mlp(latents)\n        latents = residual + latents\n\n        return latents\n\n\nclass Idefics2PerceiverResampler(nn.Module):\n    def __init__(self, prefix, config, weights) -> None:\n        super().__init__()\n        self.hidden_size = config.text_config.hidden_size\n        self.hidden_act = config.perceiver_config.hidden_act\n        self.n_latents = 
config.perceiver_config.resampler_n_latents\n        self.depth = config.perceiver_config.resampler_depth\n        self.rms_norm_eps = config.text_config.rms_norm_eps\n\n        # Create Latents for Perceiver\n        self.latents = weights.get_tensor(f\"{prefix}.latents\")\n\n        # Create Transformer Blocks\n        self.layers = nn.ModuleList(\n            [\n                Idefics2PerceiverLayer(\n                    prefix=f\"{prefix}.layers.{idx}\", config=config, weights=weights\n                )\n                for idx in range(self.depth)\n            ]\n        )\n        self.norm = Idefics2RMSNorm(\n            prefix=f\"{prefix}.norm\",\n            weights=weights,\n            eps=config.text_config.rms_norm_eps,\n        )\n\n    def forward(\n        self,\n        context: torch.Tensor,\n        attention_mask,\n    ) -> torch.Tensor:\n        # seq embed -> bsz seq embed\n        latents = self.latents.unsqueeze(0).expand(\n            (context.shape[0], *self.latents.size())\n        )\n\n        latent_attention_mask = torch.ones(\n            (attention_mask.size(0), latents.size(1)),\n            dtype=attention_mask.dtype,\n            device=attention_mask.device,\n        )\n        attention_mask = torch.cat([attention_mask, latent_attention_mask], dim=-1)\n        attention_mask = _prepare_4d_attention_mask(\n            attention_mask, latents.dtype, tgt_len=self.n_latents\n        )\n\n        compressed_context = latents\n        for perceiver_layer in self.layers:\n            compressed_context = perceiver_layer(\n                compressed_context,\n                context,\n                attention_mask=attention_mask,\n            )\n        compressed_context = self.norm(compressed_context)\n\n        return compressed_context\n\n\nclass Idefics2Connector(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.modality_projection = Idefics2MLP(\n            
prefix=f\"{prefix}.modality_projection\", config=config, weights=weights\n        )\n        self.perceiver_resampler = Idefics2PerceiverResampler(\n            prefix=f\"{prefix}.perceiver_resampler\", config=config, weights=weights\n        )\n\n    def forward(self, image_hidden_states, attention_mask):\n        image_hidden_states = self.modality_projection(image_hidden_states)\n        image_hidden_states = self.perceiver_resampler(\n            context=image_hidden_states, attention_mask=attention_mask\n        )\n        return image_hidden_states\n\n\nclass Idefics2ForConditionalGeneration(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        config.vision_config.quantize = None\n        config.vision_config.speculator = config.speculator\n        config.text_config.quantize = config.quantize\n        config.text_config.speculator = config.speculator\n\n        vision_config = config.vision_config\n        self.text_model = load_text_model(\n            prefix=\"model\" if not prefix else f\"{prefix}.model\",\n            config=config.text_config,\n            weights=weights,\n            name=\"text_model\",\n        )\n        self.dtype = weights.dtype\n\n        # The vision and connector models are not quantized.\n        with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)):\n            self.vision_model = Idefics2VisionTransformer(\n                prefix=(\n                    f\"{prefix}.model.vision_model\" if prefix else \"model.vision_model\"\n                ),\n                config=vision_config,\n                weights=weights,\n            )\n\n            config.quantize = None\n            self.connector = Idefics2Connector(\n                prefix=f\"{prefix}.model.connector\" if prefix else \"model.connector\",\n                config=config,\n                weights=weights,\n            )\n\n        self.config = config\n        self.image_seq_len = 
config.perceiver_config.resampler_n_latents\n        self.image_token_id = config.image_token_id\n        self.pad_token_id = (\n            config.pad_token_id if config.pad_token_id is not None else -1\n        )\n\n    def _merge_input_ids_with_image_features(\n        self,\n        input_ids: torch.Tensor,\n        inputs_embeds: torch.Tensor,\n        image_features: torch.Tensor,\n    ):\n        \"\"\"In place merges in vision_embeddings with inputs_embeds.\"\"\"\n        # mask = input_ids == self.config.image_token_index\n        mask = input_ids == self.config.image_token_id\n        # Let's pray we have enabled enough slots !\n        inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])\n        return inputs_embeds\n\n    def get_vision_embeds(\n        self,\n        pixel_values: torch.FloatTensor,\n        pixel_attention_mask: Optional[torch.FloatTensor] = None,\n        image_sizes: Optional[torch.Tensor] = None,\n        image_grid_thw: Optional[torch.LongTensor] = None,\n    ):\n        assert pixel_values is not None\n        batch_size, num_images, num_channels, height, width = pixel_values.shape\n        all_states = []\n        all_pixel_values = pixel_values\n        all_pixel_mask = pixel_attention_mask\n        for i in range(batch_size):\n            pixel_values = all_pixel_values.to(dtype=self.dtype)  # fp16 compatibility\n            pixel_values = pixel_values[i : i + 1]\n            pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])\n\n            # Remove padding images - padding images are full 0.\n            nb_values_per_image = pixel_values.shape[1:].numel()\n            real_images_inds = (pixel_values == 0.0).sum(\n                dim=(-1, -2, -3)\n            ) != nb_values_per_image\n            pixel_values = pixel_values[real_images_inds].contiguous()\n\n            # Handle the vision attention mask\n            if pixel_attention_mask is None:\n                pixel_attention_mask 
= torch.ones(\n                    size=(\n                        pixel_values.size(0),\n                        pixel_values.size(2),\n                        pixel_values.size(3),\n                    ),\n                    dtype=torch.bool,\n                    device=pixel_values.device,\n                )\n            else:\n                # Remove padding images from the mask\n                pixel_attention_mask = all_pixel_mask[i : i + 1]\n                pixel_attention_mask = pixel_attention_mask.view(\n                    1 * num_images, *pixel_attention_mask.shape[2:]\n                )\n                pixel_attention_mask = pixel_attention_mask[\n                    real_images_inds\n                ].contiguous()\n\n            patch_size = self.config.vision_config.patch_size\n            patches_subgrid = pixel_attention_mask.unfold(\n                dimension=1, size=patch_size, step=patch_size\n            )\n            patches_subgrid = patches_subgrid.unfold(\n                dimension=2, size=patch_size, step=patch_size\n            )\n            patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()\n\n            # Get sequence from the vision encoder\n            image_hidden_states = self.vision_model(\n                pixel_values=pixel_values,\n                patch_attention_mask=patch_attention_mask,\n            )\n\n            # Modality projection & resampling\n            image_hidden_states = self.connector(\n                image_hidden_states,\n                attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),\n            )\n            all_states.append(image_hidden_states)\n        image_hidden_states = torch.stack(all_states, dim=0)\n        return image_hidden_states.view(-1, image_hidden_states.shape[-1])\n\n    def get_inputs_embeds(\n        self,\n        input_ids: torch.Tensor,\n        vision_embeds: torch.Tensor = None,\n    ):\n        inputs_embeds = 
self.text_model.embed_tokens(input_ids)\n\n        if vision_embeds is not None:\n            # When we generate, we don't want to replace the potential image_token_id that we generated by images\n            # that simply don't exist\n            inputs_embeds = self._merge_input_ids_with_image_features(\n                input_ids, inputs_embeds, vision_embeds\n            )\n        return inputs_embeds\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        # Unused here\n        attention_mask: Optional[torch.BoolTensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ):\n        hidden_states = self.text_model.model(\n            inputs_embeds=inputs_embeds,\n            position_ids=position_ids,\n            cu_seqlen_prefill=cu_seqlen_prefill,\n            kv_cache=kv_cache,\n            block_tables=block_tables,\n            slots=slots,\n            seqlen=seqlen,\n            max_s=max_s,\n            true_max_s=max_s,\n            prefill_cache_indices=None,\n            adapter_data=adapter_data,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.text_model.lm_head(hidden_states)\n        return logits, speculative_logits\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/idefics3.py",
    "content": "# coding=utf-8\n# Copyright 2024 the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch Idefics3 model.\"\"\"\n\nfrom typing import List, Optional, Tuple\n\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\n\nfrom transformers.activations import ACT2FN\nfrom text_generation_server.models.custom_modeling.vlm import (\n    load_text_model,\n)\nfrom text_generation_server.layers.attention import Seqlen\nfrom transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask\n\nfrom text_generation_server.layers import (\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    TensorParallelRowLinear,\n)\nfrom text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(\n        batch, num_key_value_heads, n_rep, slen, head_dim\n    )\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\nclass Idefics3VisionEmbeddings(nn.Module):\n    \"\"\"\n    This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable\n    resolution.\n\n    The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)\n    which allows treating images in their native aspect ratio and without the need to resize them to the same\n    fixed size. In particular, we start from the original pre-trained SigLIP model\n    (which uses images of fixed-size square images) and adapt it by training on images of variable resolutions.\n    \"\"\"\n\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.embed_dim = config.hidden_size\n        self.image_size = config.image_size\n        self.patch_size = config.patch_size\n\n        self.patch_embedding = nn.Conv2d(\n            in_channels=config.num_channels,\n            out_channels=self.embed_dim,\n            kernel_size=self.patch_size,\n            stride=self.patch_size,\n            padding=\"valid\",\n        )\n        self.patch_embedding.weight = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.patch_embedding.weight\"), requires_grad=False\n        )\n        self.patch_embedding.bias = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.patch_embedding.bias\"), requires_grad=False\n        )\n\n        self.num_patches_per_side = self.image_size // self.patch_size\n        
self.num_patches = self.num_patches_per_side**2\n        self.num_positions = self.num_patches\n        self.position_embedding = TensorParallelEmbedding(\n            prefix=f\"{prefix}.position_embedding\", weights=weights\n        )\n\n    def forward(\n        self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor\n    ) -> torch.Tensor:\n        batch_size, _, max_im_h, max_im_w = pixel_values.shape\n\n        patch_embeds = self.patch_embedding(pixel_values)\n        embeddings = patch_embeds.flatten(2).transpose(1, 2)\n\n        max_nb_patches_h, max_nb_patches_w = (\n            max_im_h // self.patch_size,\n            max_im_w // self.patch_size,\n        )\n        boundaries = torch.arange(\n            1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side\n        )\n        position_ids = torch.full(\n            size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0\n        )\n\n        for batch_idx, p_attn_mask in enumerate(patch_attention_mask):\n            nb_patches_h = p_attn_mask[:, 0].sum()\n            nb_patches_w = p_attn_mask[0].sum()\n\n            fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)\n            fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)\n\n            bucket_coords_h = torch.bucketize(\n                fractional_coords_h, boundaries, right=True\n            )\n            bucket_coords_w = torch.bucketize(\n                fractional_coords_w, boundaries, right=True\n            )\n\n            pos_ids = (\n                bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w\n            ).flatten()\n            position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids\n\n        position_ids = position_ids.to(self.position_embedding.weight.device)\n        embeddings = embeddings + self.position_embedding(position_ids)\n        return embeddings\n\n\nclass Idefics3VisionAttention(nn.Module):\n    def 
__init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_size = self.embed_dim // self.num_heads\n        if self.head_size * self.num_heads != self.embed_dim:\n            raise ValueError(\n                f\"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:\"\n                f\" {self.num_heads}).\"\n            )\n        self.scale = self.head_size**-0.5\n        self.dropout = config.attention_dropout\n\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.embed_dim = self.embed_dim // weights.process_group.size()\n\n        self.qkv = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n            dim=0,\n            weights=weights,\n            bias=True,\n        )\n        self.out_proj = TensorParallelRowLinear.load(\n            config=config, prefix=f\"{prefix}.out_proj\", weights=weights, bias=True\n        )\n        self.is_causal = False\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        batch_size, q_len, _ = hidden_states.size()\n\n        qkv = self.qkv(hidden_states)\n        query_states, key_states, value_states = qkv.split(\n            [\n                self.head_size * self.num_heads,\n                self.head_size * self.num_heads,\n                self.head_size * self.num_heads,\n            ],\n            dim=2,\n        )\n\n        query_states = query_states.view(\n            batch_size, q_len, self.num_heads, self.head_size\n        ).transpose(1, 2)\n        key_states = key_states.view(\n            batch_size, q_len, self.num_heads, self.head_size\n        ).transpose(1, 2)\n        
value_states = value_states.view(\n            batch_size, q_len, self.num_heads, self.head_size\n        ).transpose(1, 2)\n\n        k_v_seq_len = key_states.shape[-2]\n        attn_weights = (\n            torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale\n        )\n\n        if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is\"\n                f\" {attn_weights.size()}\"\n            )\n\n        if attention_mask is not None:\n            if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}\"\n                )\n            attn_weights = attn_weights + attention_mask\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(\n            attn_weights, dim=-1, dtype=torch.float32\n        ).to(query_states.dtype)\n        attn_weights = nn.functional.dropout(\n            attn_weights, p=self.dropout, training=self.training\n        )\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_size):\n            raise ValueError(\n                f\"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_size)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)\n\n        attn_output = self.out_proj(attn_output)\n\n        return attn_output\n\n\nclass Idefics3VisionMLP(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        
self.activation_fn = ACT2FN[config.hidden_act]\n        self.fc1 = TensorParallelColumnLinear.load(\n            prefix=f\"{prefix}.fc1\", config=config, weights=weights, bias=True\n        )\n        self.fc2 = TensorParallelRowLinear.load(\n            prefix=f\"{prefix}.fc2\", config=config, weights=weights, bias=True\n        )\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = self.activation_fn(hidden_states)\n        hidden_states = self.fc2(hidden_states)\n        return hidden_states\n\n\nclass Idefics3EncoderLayer(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.embed_dim = config.hidden_size\n        self.self_attn = Idefics3VisionAttention(\n            prefix=f\"{prefix}.self_attn\", config=config, weights=weights\n        )\n        self.layer_norm1 = nn.LayerNorm.load(\n            prefix=f\"{prefix}.layer_norm1\", eps=config.layer_norm_eps, weights=weights\n        )\n        self.layer_norm2 = nn.LayerNorm.load(\n            prefix=f\"{prefix}.layer_norm2\", eps=config.layer_norm_eps, weights=weights\n        )\n        self.mlp = Idefics3VisionMLP(\n            prefix=f\"{prefix}.mlp\", config=config, weights=weights\n        )\n\n    # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: torch.Tensor,\n    ) -> torch.Tensor:\n        residual = hidden_states\n\n        hidden_states = self.layer_norm1(hidden_states)\n        hidden_states = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n        )\n        hidden_states = residual + hidden_states\n\n        residual = hidden_states\n        hidden_states = self.layer_norm2(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + 
hidden_states\n\n        return hidden_states\n\n\nclass Idefics3Encoder(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        self.layers = nn.ModuleList(\n            [\n                Idefics3EncoderLayer(\n                    prefix=f\"{prefix}.layers.{i}\", config=config, weights=weights\n                )\n                for i in range(config.num_hidden_layers)\n            ]\n        )\n\n    # Ignore copy\n    def forward(\n        self,\n        inputs_embeds,\n        attention_mask: Optional[torch.Tensor] = None,\n    ):\n        hidden_states = inputs_embeds\n        for encoder_layer in self.layers:\n            hidden_states = encoder_layer(\n                hidden_states,\n                attention_mask,\n            )\n        return hidden_states\n\n\nclass Idefics3VisionTransformer(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        self.embeddings = Idefics3VisionEmbeddings(\n            prefix=f\"{prefix}.embeddings\", config=config, weights=weights\n        )\n        self.encoder = Idefics3Encoder(\n            prefix=f\"{prefix}.encoder\", config=config, weights=weights\n        )\n        self.post_layernorm = nn.LayerNorm.load(\n            prefix=f\"{prefix}.post_layernorm\",\n            weights=weights,\n            eps=config.layer_norm_eps,\n        )\n\n    def forward(\n        self,\n        pixel_values,\n        patch_attention_mask: Optional[torch.BoolTensor] = None,\n    ):\n        batch_size = pixel_values.size(0)\n        if patch_attention_mask is None:\n            patch_size = self.config.patch_size\n            patch_attention_mask = torch.ones(\n                (\n                    batch_size,\n                    pixel_values.size(2) // patch_size,\n                    pixel_values.size(3) // patch_size,\n                )\n            )\n            
patch_attention_mask = patch_attention_mask.to(\n                dtype=torch.bool, device=pixel_values.device\n            )\n\n        hidden_states = self.embeddings(\n            pixel_values=pixel_values, patch_attention_mask=patch_attention_mask\n        )\n\n        patch_attention_mask = patch_attention_mask.view(batch_size, -1)\n        # The call to `_upad_input` in `_flash_attention_forward` is expensive\n        # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),\n        # avoiding passing the attention_mask, which is equivalent to attending to the full sequence\n        if not torch.any(~patch_attention_mask):\n            patch_attention_mask = None\n        else:\n            patch_attention_mask = _prepare_4d_attention_mask(\n                patch_attention_mask, hidden_states.dtype\n            )\n\n        encoder_outputs = self.encoder(\n            inputs_embeds=hidden_states,\n            attention_mask=patch_attention_mask,\n        )\n\n        last_hidden_state = encoder_outputs\n        last_hidden_state = self.post_layernorm(last_hidden_state)\n\n        return last_hidden_state\n\n\nclass Idefics3SimpleMLP(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        input_size = config.vision_config.hidden_size * (config.scale_factor**2)\n        output_size = config.text_config.hidden_size\n        proj = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.modality_projection.proj.weight\"),\n            requires_grad=False,\n        ).to(weights.dtype)\n        self.proj = nn.Linear(input_size, output_size, bias=False)\n        self.proj.weight = proj\n\n    def forward(self, x):\n        return self.proj(x)\n\n\nclass Idefics3Connector(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.modality_projection = Idefics3SimpleMLP(prefix, config, weights)\n        self.scale_factor = config.scale_factor\n\n  
  def pixel_shuffle(self, x, scale_factor=2):\n        bsz, seq, embed_dim = x.size()\n        height = width = int(seq**0.5)\n        x = x.view(bsz, height, width, embed_dim)\n        x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor)\n        x = x.permute(0, 2, 1, 3)\n        x = x.reshape(\n            bsz,\n            int(width / scale_factor),\n            int(height / scale_factor),\n            embed_dim * (scale_factor**2),\n        )\n        x = x.permute(0, 2, 1, 3)\n        x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))\n        return x\n\n    def forward(self, image_hidden_states):\n        image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor)\n        image_hidden_states = self.modality_projection(image_hidden_states)\n        return image_hidden_states\n\n\nclass Idefics3ForConditionalGeneration(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        config.vision_config.quantize = None\n        config.vision_config.speculator = config.speculator\n        config.text_config.quantize = config.quantize\n        config.text_config.speculator = config.speculator\n        # set tie_word_embeddings to True to load `.embed_tokens.weight` instead of `.lm_head.weight`\n        # since Idefics3 uses the `embed_tokens` for the final prediction\n        # config.text_config.tie_word_embeddings = True\n\n        vision_config = config.vision_config\n        self.text_model = load_text_model(\n            prefix=\"model\" if not prefix else f\"{prefix}.model\",\n            config=config.text_config,\n            weights=weights,\n            name=\"text_model\",\n        )\n        self.dtype = weights.dtype\n\n        # The vision and connector models are not quantized.\n        with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)):\n            self.vision_model = Idefics3VisionTransformer(\n                prefix=(\n   
                 f\"{prefix}.model.vision_model\" if prefix else \"model.vision_model\"\n                ),\n                config=vision_config,\n                weights=weights,\n            )\n\n            config.quantize = None\n            self.connector = Idefics3Connector(\n                prefix=f\"{prefix}.model.connector\" if prefix else \"model.connector\",\n                config=config,\n                weights=weights,\n            )\n\n        self.config = config\n        self.image_token_id = config.image_token_id\n        self.pad_token_id = (\n            config.pad_token_id if config.pad_token_id is not None else -1\n        )\n\n    def _merge_input_ids_with_image_features(\n        self,\n        input_ids: torch.Tensor,\n        inputs_embeds: torch.Tensor,\n        image_features: torch.Tensor,\n    ):\n        \"\"\"In place merges in vision_embeddings with inputs_embeds.\"\"\"\n        # mask = input_ids == self.config.image_token_index\n        mask = input_ids == self.config.image_token_id\n        # Let's pray we have enabled enough slots !\n        inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])\n        return inputs_embeds\n\n    def get_vision_embeds(\n        self,\n        pixel_values: torch.FloatTensor,\n        pixel_attention_mask: Optional[torch.FloatTensor] = None,\n        image_sizes: Optional[torch.Tensor] = None,\n        image_grid_thw: Optional[torch.LongTensor] = None,\n    ):\n        batch_size, num_images, num_channels, height, width = pixel_values.shape\n        all_states = []\n        all_pixel_values = pixel_values\n        all_pixel_mask = pixel_attention_mask\n        for i in range(batch_size):\n            pixel_values = all_pixel_values.to(dtype=self.dtype)  # fp16 compatibility\n            pixel_values = pixel_values[i : i + 1]\n            pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])\n\n            # Remove padding images - padding images are full 0.\n  
          nb_values_per_image = pixel_values.shape[1:].numel()\n            real_images_inds = (pixel_values == 0.0).sum(\n                dim=(-1, -2, -3)\n            ) != nb_values_per_image\n            pixel_values = pixel_values[real_images_inds].contiguous()\n            # Handle the vision attention mask\n            if pixel_attention_mask is None:\n                pixel_attention_mask = torch.ones(\n                    size=(\n                        pixel_values.size(0),\n                        pixel_values.size(2),\n                        pixel_values.size(3),\n                    ),\n                    dtype=torch.bool,\n                    device=pixel_values.device,\n                )\n            else:\n                # Remove padding images from the mask/pP p\n                pixel_attention_mask = all_pixel_mask[i : i + 1]\n                pixel_attention_mask = pixel_attention_mask.view(\n                    1 * num_images, *pixel_attention_mask.shape[2:]\n                )\n                pixel_attention_mask = pixel_attention_mask[\n                    real_images_inds\n                ].contiguous()\n\n            patch_size = self.config.vision_config.patch_size\n            patches_subgrid = pixel_attention_mask.unfold(\n                dimension=1, size=patch_size, step=patch_size\n            )\n            patches_subgrid = patches_subgrid.unfold(\n                dimension=2, size=patch_size, step=patch_size\n            )\n            patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()\n\n            # Get sequence from the vision encoder\n            image_hidden_states = self.vision_model(\n                pixel_values=pixel_values,\n                patch_attention_mask=patch_attention_mask,\n            )\n\n            # Modality projection & resampling\n            image_hidden_states = self.connector(\n                image_hidden_states,\n            )\n\n            all_states.append(image_hidden_states)\n 
       image_hidden_states = torch.stack(all_states, dim=0)\n\n        return image_hidden_states.view(-1, image_hidden_states.shape[-1])\n\n    def get_inputs_embeds(\n        self,\n        input_ids: torch.Tensor,\n        vision_embeds: torch.Tensor = None,\n    ):\n        inputs_embeds = self.text_model.embed_tokens(input_ids)\n\n        if vision_embeds is not None:\n            # When we generate, we don't want to replace the potential image_token_id that we generated by images\n            # that simply don't exist\n            inputs_embeds = self._merge_input_ids_with_image_features(\n                input_ids, inputs_embeds, vision_embeds\n            )\n        return inputs_embeds\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        # Unused here\n        attention_mask: Optional[torch.BoolTensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n        image_indices=None,\n    ):\n        hidden_states = self.text_model.model(\n            inputs_embeds=inputs_embeds,\n            position_ids=position_ids,\n            cu_seqlen_prefill=cu_seqlen_prefill,\n            kv_cache=kv_cache,\n            block_tables=block_tables,\n            slots=slots,\n            seqlen=seqlen,\n            max_s=max_s,\n            true_max_s=max_s,\n            prefill_cache_indices=None,\n            adapter_data=adapter_data,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.text_model.lm_head(hidden_states)\n        return logits, 
speculative_logits\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/idefics_config.py",
    "content": "# coding=utf-8\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Idefics model configuration\"\"\"\nimport copy\n\nfrom transformers import PretrainedConfig\n\nIDEFICS_PRETRAINED_CONFIG_ARCHIVE_MAP = {\n    \"HuggingFaceM4/idefics-9b\": \"https://huggingface.co/HuggingFaceM4/idefics-9b/blob/main/config.json\",\n    \"HuggingFaceM4/idefics-80b\": \"https://huggingface.co/HuggingFaceM4/idefics-80b/blob/main/config.json\",\n}\n\n\nclass IdeficsVisionConfig(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`IdeficsModel`]. It is used to instantiate an\n    Idefics model according to the specified arguments, defining the model architecture. Instantiating a configuration\n    with the defaults will yield a similar configuration to that of the Idefics-9B.\n    e.g. [HuggingFaceM4/idefics-9b](https://huggingface.co/HuggingFaceM4/idefics-9b)\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
Read the\n    documentation from [`PretrainedConfig`] for more information.\n    Args:\n        hidden_size (`int`, *optional*, defaults to 768):\n            Dimensionality of the encoder layers and the pooler layer. (elsewhere referred to as `hidden_size`)\n        image_size (`int`, *optional*, defaults to 224):\n            The size (resolution) of each image.\n        intermediate_size (`int`, *optional*, defaults to 5120):\n            Dimensionality of the \"intermediate\" (i.e., feed-forward) layer in the Transformer encoder.\n        patch_size (`int`, *optional*, defaults to 14):\n            The size (resolution) of each patch.\n        num_hidden_layers (`int`, *optional*, defaults to 32):\n            Number of hidden layers in the Transformer encoder.\n        num_attention_heads (`int`, *optional*, defaults to 16):\n            Number of attention heads for each attention layer in the Transformer encoder.\n        image_num_channels (`int`, *optional*, defaults to `3`):\n            Number of image channels.\n        hidden_act (`str` or `function`, *optional*, defaults to `\"gelu\"`):\n            The non-linear activation function (function or string) in the encoder and pooler. 
If string, `\"gelu\"`,\n            `\"relu\"`, `\"selu\"` and `\"gelu_new\"` ``\"quick_gelu\"` are supported.\n        layer_norm_eps (`float`, *optional*, defaults to 1e-5):\n            The epsilon used by the layer normalization layers.\n        attention_dropout (`float`, *optional*, defaults to 0.0):\n            The dropout ratio for the attention probabilities.\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        initializer_factor (`float`, *optional*, defaults to 1.0):\n            A factor for initializing all weight matrices (should be kept to 1.0, used internally for initialization\n            testing).\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n    \"\"\"\n\n    model_type = \"idefics\"\n    attribute_map = {\n        \"hidden_size\": \"embed_dim\",\n    }\n\n    def __init__(\n        self,\n        embed_dim=768,\n        image_size=224,\n        intermediate_size=5120,\n        patch_size=14,\n        num_hidden_layers=32,\n        num_attention_heads=16,\n        num_channels=3,\n        hidden_act=\"gelu\",\n        layer_norm_eps=1e-5,\n        attention_dropout=0.0,\n        initializer_range=0.02,\n        initializer_factor=1.0,\n        **kwargs,\n    ):\n        self.embed_dim = embed_dim\n        self.image_size = image_size\n        self.intermediate_size = intermediate_size\n        self.patch_size = patch_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.num_channels = num_channels\n        self.layer_norm_eps = layer_norm_eps\n        self.attention_dropout = attention_dropout\n        self.initializer_range = initializer_range\n        self.initializer_factor = initializer_factor\n        
self.hidden_act = hidden_act\n\n        super().__init__(**kwargs)\n\n\nclass IdeficsPerceiverConfig(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`IdeficsModel`]. It is used to instantiate an\n    Idefics model according to the specified arguments, defining the model architecture. Instantiating a configuration\n    with the defaults will yield a similar configuration to that of the Idefics-9B.\n    e.g. [HuggingFaceM4/idefics-9b](https://huggingface.co/HuggingFaceM4/idefics-9b)\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the\n    documentation from [`PretrainedConfig`] for more information.\n    Args:\n        use_resampler (`bool`, *optional*, defaults to `False`):\n            Whether or not to use the resampler\n        resampler_n_latents (`int`, *optional*, defaults to 64):\n            Number of latent embeddings to resample (\"compress\") the input sequence to (usually < 128).\n        resampler_depth (`int`, *optional*, defaults to 6):\n            Depth of the Perceiver Resampler (Transformer w/ cross attention). 
Should be shallow (< 3).\n        resampler_n_heads (`int`, *optional*, defaults to 16):\n            Number of heads in each Transformer block (for multi-headed self-attention).\n        resampler_head_dim (`int`, *optional*, defaults to 96):\n            Dimensionality of each head projection in the Transformer block.\n        qk_layer_norms_perceiver (`bool`, *optional*, defaults to `False`):\n            Whether or not to use qk layer norms in perceiver\n    \"\"\"\n\n    model_type = \"idefics\"\n\n    def __init__(\n        self,\n        use_resampler=False,\n        resampler_n_latents=64,\n        resampler_depth=6,\n        resampler_n_heads=16,\n        resampler_head_dim=96,\n        qk_layer_norms_perceiver=False,\n        **kwargs,\n    ):\n        self.use_resampler = use_resampler\n        self.resampler_n_latents = resampler_n_latents\n        self.resampler_depth = resampler_depth\n        self.resampler_n_heads = resampler_n_heads\n        self.resampler_head_dim = resampler_head_dim\n        self.qk_layer_norms_perceiver = qk_layer_norms_perceiver\n\n        super().__init__(**kwargs)\n\n\nclass IdeficsConfig(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`IdeficsModel`]. It is used to instantiate an\n    Idefics model according to the specified arguments, defining the model architecture. Instantiating a configuration\n    with the defaults will yield a similar configuration to that of the Idefics-9B.\n    e.g. [HuggingFaceM4/idefics-9b](https://huggingface.co/HuggingFaceM4/idefics-9b)\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the\n    documentation from [`PretrainedConfig`] for more information.\n    Args:\n        additional_vocab_size (`int`, *optional*, defaults to 0):\n            Additional vocabulary size of the model, typically for the special \"<img>\" token. 
Additional vocab tokens\n            are always trainable whereas regular vocab tokens can be frozen or not.\n        vocab_size (`int`, *optional*, defaults to 32000):\n            Vocabulary size of the Idefics model. Defines the number of different tokens that can be represented by the\n            `inputs_ids` passed when calling [`~IdeficsModel`]\n        hidden_size (`int`, *optional*, defaults to 4096):\n            Dimension of the hidden representations.\n        intermediate_size (`int`, *optional*, defaults to 11008):\n            Dimension of the MLP representations.\n        num_hidden_layers (`int`, *optional*, defaults to 32):\n            Number of hidden layers in the Transformer encoder.\n        num_attention_heads (`int`, *optional*, defaults to 32):\n            Number of attention heads for each attention layer in the Transformer encoder.\n        dropout (`float`, *optional*, defaults to 0.0):\n            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.\n        hidden_act (`str` or `function`, *optional*, defaults to `\"silu\"`):\n            The non-linear activation function (function or string) in the decoder.\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        alpha_initializer (`str`, *optional*, defaults to `\"zeros\"`):\n            Initialization type for the alphas.\n        alphas_initializer_range (`float`, *optional*, defaults to 0.0):\n            The standard deviation of the truncated_normal_initializer for initializing the alphas in the Gated Cross\n            Attention.\n        alpha_type (`str`, *optional*, defaults to `\"float\"`):\n            Whether the gating alphas should be vectors or single floats.\n        rms_norm_eps (`float`, *optional*, defaults to 1e-6):\n            The epsilon used by the rms normalization layers.\n        
use_cache (`bool`, *optional*, defaults to `True`):\n            Whether or not the model should return the last key/values attentions (not used by all models). Only\n            relevant if `config.is_decoder=True`.\n        pad_token_id (`int`, *optional*, defaults to 0)\n            Padding token id.\n        bos_token_id (`int`, *optional*, defaults to 1)\n            Beginning of stream token id.\n        eos_token_id (`int`, *optional*, defaults to 2)\n            End of stream token id.\n        tie_word_embeddings(`bool`, *optional*, defaults to `False`):\n            Whether to tie weight embeddings\n        cross_layer_interval (`int`, *optional*, default to 1)\n            Interval for cross attention (from text to image) layers.\n        qk_layer_norms (`bool`, *optional*, defaults to `False`): Whether to add layer norm after q and k\n        freeze_text_layers (`bool`, *optional*, defaults to `True`): Whether to freeze text layers\n        freeze_text_module_exceptions (`bool`, *optional*, defaults to `[]`):\n            Exceptions to freezing text layers when `freeze_text_layers` is `True`\n        freeze_lm_head (`bool`, *optional*, defaults to `False`): Whether to freeze lm head\n        freeze_vision_layers (`bool`, *optional*, defaults to `True`):  Whether to freeze vision layers\n        freeze_vision_module_exceptions (`bool`, *optional*, defaults to `[]`):\n            Exceptions to freezing vision layers when `freeze_vision_layers` is `True`\n        use_resampler (`bool`, *optional*, defaults to `False`): Whether to use the Resampler\n        vision_config (`IdeficsVisionConfig`,  *optional*): Custom vision config or dict\n        perceiver_config (`IdeficsPerceiverConfig`,  *optional*): Custom perceiver config or dict\n    Example:\n    ```python\n    >>> from transformers import IdeficsModel, IdeficsConfig\n    >>> # Initializing a Idefics idefics-9b style configuration\n    >>> configuration = IdeficsConfig()\n    >>> # Initializing a 
model from the idefics-9b style configuration\n    >>> model = IdeficsModel(configuration)\n    >>> # Accessing the model configuration\n    >>> configuration = model.config\n    ```\"\"\"\n\n    model_type = \"idefics\"\n    is_composition = True\n\n    def __init__(\n        self,\n        vocab_size=32000,\n        additional_vocab_size=0,\n        hidden_size=4096,\n        intermediate_size=11008,\n        num_hidden_layers=32,\n        num_attention_heads=32,\n        dropout=0.0,\n        hidden_act=\"silu\",\n        initializer_range=0.02,\n        alpha_initializer=\"zeros\",\n        alphas_initializer_range=0.0,\n        alpha_type=\"float\",\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        pad_token_id=0,\n        bos_token_id=1,\n        eos_token_id=2,\n        tie_word_embeddings=False,\n        cross_layer_interval=1,\n        qk_layer_norms=False,\n        freeze_text_layers=True,\n        freeze_text_module_exceptions=[],\n        freeze_lm_head=False,\n        freeze_vision_layers=True,\n        freeze_vision_module_exceptions=[],\n        use_resampler=False,\n        vision_config=None,\n        perceiver_config=None,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.additional_vocab_size = additional_vocab_size\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.dropout = dropout\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.alpha_initializer = alpha_initializer\n        self.alphas_initializer_range = alphas_initializer_range\n        self.alpha_type = alpha_type\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n\n        self.cross_layer_interval = cross_layer_interval\n        self.qk_layer_norms = qk_layer_norms\n        self.freeze_vision_layers = 
freeze_vision_layers\n\n        self.freeze_text_layers = freeze_text_layers\n        self.freeze_text_module_exceptions = freeze_text_module_exceptions\n        self.freeze_vision_module_exceptions = freeze_vision_module_exceptions\n        self.freeze_lm_head = freeze_lm_head\n\n        self.use_resampler = use_resampler\n\n        if perceiver_config is None:\n            self.perceiver_config = IdeficsPerceiverConfig()\n        elif isinstance(perceiver_config, dict):\n            self.perceiver_config = IdeficsPerceiverConfig(**perceiver_config)\n        elif isinstance(perceiver_config, IdeficsPerceiverConfig):\n            self.perceiver_config = perceiver_config\n\n        if vision_config is None:\n            self.vision_config = IdeficsVisionConfig()\n        elif isinstance(vision_config, dict):\n            self.vision_config = IdeficsVisionConfig(**vision_config)\n        elif isinstance(vision_config, IdeficsVisionConfig):\n            self.vision_config = vision_config\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n        # IMPORTANT: Do not do any __init__ args-based checks in the constructor, since\n        # PretrainedConfig.from_dict first instantiates the class with the config dict and only then\n        # updates the config object with `kwargs` from from_pretrained, so during the instantiation\n        # of this object many attributes have default values and haven't yet been overridden.\n        # Do any required checks inside `from_pretrained` once the superclass' `from_pretrained` was run.\n\n    def to_dict(self):\n        \"\"\"\n        Serializes this instance to a Python dictionary. 
Override the default [`~PretrainedConfig.to_dict`].\n        Returns:\n            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance,\n            with the nested `vision_config` and `perceiver_config` serialized to dictionaries as well.\n        \"\"\"\n        output = copy.deepcopy(self.__dict__)\n\n        output[\"vision_config\"] = self.vision_config.to_dict()\n        output[\"perceiver_config\"] = self.perceiver_config.to_dict()\n        output[\"model_type\"] = self.__class__.model_type\n\n        return output\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/idefics_image_processing.py",
    "content": "# coding=utf-8\n# Copyright 2022 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Image processor class for Idefics.\"\"\"\n\nfrom typing import Callable, Dict, List, Optional, Union, Iterable\nimport numpy as np\n\nfrom PIL import Image\n\nimport transformers\nfrom transformers.image_processing_utils import BaseImageProcessor, BatchFeature\nfrom transformers.image_transforms import (\n    resize,\n    to_channel_dimension_format,\n    rescale,\n    normalize,\n)\nfrom transformers.image_utils import (\n    ChannelDimension,\n    ImageInput,\n    PILImageResampling,\n    make_list_of_images,\n    to_numpy_array,\n    valid_images,\n)\nfrom io import BytesIO\nimport base64\nimport requests\nfrom transformers import TensorType, is_torch_available\n\n\nIDEFICS_STANDARD_MEAN = [0.48145466, 0.4578275, 0.40821073]\nIDEFICS_STANDARD_STD = [0.26862954, 0.26130258, 0.27577711]\n\n\ndef convert_to_rgb(image):\n    # `image.convert(\"RGB\")` would only work for .jpg images, as it creates a wrong background\n    # for transparent images. 
The call to `alpha_composite` handles this case\n    if image.mode == \"RGB\":\n        return image\n\n    image_rgba = image.convert(\"RGBA\")\n    background = Image.new(\"RGBA\", image_rgba.size, (255, 255, 255))\n    alpha_composite = Image.alpha_composite(background, image_rgba)\n    alpha_composite = alpha_composite.convert(\"RGB\")\n    return alpha_composite\n\n\nclass IdeficsImageProcessor(BaseImageProcessor):\n    r\"\"\"\n    Constructs a Idefics image processor.\n    Args:\n        image_size (`int`, *optional*, defaults to `224`):\n            Resize to image size\n        image_num_channels (`int`, *optional*, defaults to `3`):\n            Number of image channels.\n        image_mean (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_MEAN`):\n            Mean to use if normalizing the image. This is a float or list of floats the length of the number of\n            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be\n            overridden by the `image_mean` parameter in the `preprocess` method.\n        image_std (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_STD`):\n            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the\n            number of channels in the image. 
Can be overridden by the `image_std` parameter in the `preprocess` method.\n    \"\"\"\n\n    model_input_names = [\"pixel_values\"]\n\n    def __init__(\n        self,\n        image_size: int = 224,\n        image_mean: Optional[Union[float, List[float]]] = None,\n        image_std: Optional[Union[float, List[float]]] = None,\n        image_num_channels: Optional[int] = 3,\n        **kwargs,\n    ) -> None:\n        super().__init__(**kwargs)\n\n        self.image_size = image_size\n        self.image_num_channels = image_num_channels\n        self.image_mean = image_mean\n        self.image_std = image_std\n\n    def preprocess(\n        self,\n        images: ImageInput,\n        image_num_channels: Optional[int] = 3,\n        image_size: Optional[Dict[str, int]] = None,\n        image_mean: Optional[Union[float, List[float]]] = None,\n        image_std: Optional[Union[float, List[float]]] = None,\n        transform: Callable = None,\n        **kwargs,\n    ) -> TensorType.PYTORCH:\n        \"\"\"\n        Preprocess a batch of images.\n        Args:\n            images (`ImageInput`):\n                A list of images to preprocess.\n            image_size (`int`, *optional*, defaults to `self.image_size`):\n                Resize to image size\n            image_num_channels (`int`, *optional*, defaults to `self.image_num_channels`):\n                Number of image channels.\n            image_mean (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_MEAN`):\n                Mean to use if normalizing the image. This is a float or list of floats the length of the number of\n                channels in the image. 
Can\n                be overridden by the `image_mean` parameter in the `preprocess` method.\n            image_std (`float` or `List[float]`, *optional*, defaults to `IDEFICS_STANDARD_STD`):\n                Standard deviation to use if normalizing the image. This is a float or list of floats the length of the\n                number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess`\n                method. Can be overridden by the `image_std` parameter in the `preprocess` method.\n            transform (`Callable`, *optional*, defaults to `None`):\n                A custom transform function that accepts a single image can be passed for training. For example,\n                `torchvision.Compose` can be used to compose multiple transforms. If `None` - an inference mode is\n                assumed - and then a preset of inference-specific transforms will be applied to the images\n        Returns:\n            a PyTorch tensor of the processed images\n        \"\"\"\n        image_size = image_size if image_size is not None else self.image_size\n        image_num_channels = (\n            image_num_channels\n            if image_num_channels is not None\n            else self.image_num_channels\n        )\n        image_mean = image_mean if image_mean is not None else self.image_mean\n        image_std = image_std if image_std is not None else self.image_std\n        size = (image_size, image_size)\n\n        if len(images) == 0:\n            return []\n\n        images = make_list_of_images(images)\n\n        if not valid_images(images):\n            raise ValueError(\n                \"Invalid image type. 
Must be of type PIL.Image.Image, numpy.ndarray, \"\n                \"torch.Tensor, tf.Tensor or jax.ndarray.\"\n            )\n\n        # For training a user needs to pass their own set of transforms as a Callable.\n        # For reference this is what was used in the original IDEFICS training:\n        # transform = transforms.Compose([\n        #     convert_to_rgb,\n        #     transforms.RandomResizedCrop((size, size), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),\n        #     transforms.ToTensor(),\n        #     transforms.Normalize(mean=image_mean, std=image_std),\n        # ])\n        if transform is not None:\n            if not is_torch_available():\n                raise ImportError(\"To pass in `transform` torch must be installed\")\n            import torch\n\n            images = [transform(x) for x in images]\n            return torch.stack(images)\n\n        # for inference we do the exact transforms that were used to train IDEFICS\n        images = [convert_to_rgb(x) for x in images]\n        # further transforms expect numpy arrays\n        images = [to_numpy_array(x) for x in images]\n        images = [resize(x, size, resample=PILImageResampling.BICUBIC) for x in images]\n        images = [self.rescale(image=image, scale=1 / 255) for image in images]\n        images = [self.normalize(x, mean=image_mean, std=image_std) for x in images]\n        images = [\n            to_channel_dimension_format(x, ChannelDimension.FIRST) for x in images\n        ]\n        # TODO: this converts to torch tensors - switch to convert_to_tensors once it becomes available\n        images = BatchFeature(\n            data={\"pixel_values\": images}, tensor_type=TensorType.PYTORCH\n        )[\"pixel_values\"]\n\n        return images\n\n    def fetch_images(self, image_url_or_urls: Union[str, List[str]]):\n        \"\"\"\n        Convert a single or a list of urls into the corresponding `PIL.Image` objects.\n        If a single url is 
passed, the return value will be a single object. If a list is passed a list of objects is\n        returned.\n        \"\"\"\n        headers = {\n            \"User-Agent\": (\n                \"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0\"\n                \" Safari/537.36\"\n            )\n        }\n        if isinstance(image_url_or_urls, list):\n            return [self.fetch_images(x) for x in image_url_or_urls]\n        elif isinstance(image_url_or_urls, str):\n            image = image_url_or_urls\n\n            if image.startswith(\"http://\") or image.startswith(\"https://\"):\n                response = requests.get(\n                    image_url_or_urls, stream=True, headers=headers, timeout=(1, 5)\n                )\n                response.raise_for_status()\n                content = response.content\n            elif image.startswith(\"data:\"):\n                # https://stackoverflow.com/questions/17090571/is-there-a-way-to-set-background-image-as-a-base64-encoded-image\n                # data:image/png;base64,xxx\n                image = image.split(\",\")[-1]\n                content = base64.b64decode(image)\n            else:\n                raise ValueError(f\"Unrecognized image {image}\")\n\n            try:\n                image = Image.open(BytesIO(content))\n                # image.verify()\n            except Exception:\n                raise ValueError(f\"Could not load image from url {image_url_or_urls}\")\n            return image\n        else:\n            raise ValueError(\n                f\"only a single or a list of entries is supported but got type={type(image_url_or_urls)}\"\n            )\n\n    def rescale(\n        self,\n        image: np.ndarray,\n        scale: float,\n        data_format: Optional[Union[str, ChannelDimension]] = None,\n        input_data_format: Optional[Union[str, ChannelDimension]] = None,\n        **kwargs,\n    ) -> np.ndarray:\n      
  \"\"\"\n        Rescale an image by a scale factor. image = image * scale.\n\n        Args:\n            image (`np.ndarray`):\n                Image to rescale.\n            scale (`float`):\n                The scaling factor to rescale pixel values by.\n            data_format (`str` or `ChannelDimension`, *optional*):\n                The channel dimension format for the output image. If unset, the channel dimension format of the input\n                image is used. Can be one of:\n                - `\"channels_first\"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.\n                - `\"channels_last\"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.\n            input_data_format (`ChannelDimension` or `str`, *optional*):\n                The channel dimension format for the input image. If unset, the channel dimension format is inferred\n                from the input image. Can be one of:\n                - `\"channels_first\"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.\n                - `\"channels_last\"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.\n\n        Returns:\n            `np.ndarray`: The rescaled image.\n        \"\"\"\n        # return rescale(image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs)\n        # requires 4.32\n        return rescale(image, scale=scale, data_format=data_format, **kwargs)\n\n    def normalize(\n        self,\n        image: np.ndarray,\n        mean: Union[float, Iterable[float]],\n        std: Union[float, Iterable[float]],\n        data_format: Optional[Union[str, ChannelDimension]] = None,\n        input_data_format: Optional[Union[str, ChannelDimension]] = None,\n        **kwargs,\n    ) -> np.ndarray:\n        \"\"\"\n        Normalize an image. 
image = (image - image_mean) / image_std.\n\n        Args:\n            image (`np.ndarray`):\n                Image to normalize.\n            mean (`float` or `Iterable[float]`):\n                Image mean to use for normalization.\n            std (`float` or `Iterable[float]`):\n                Image standard deviation to use for normalization.\n            data_format (`str` or `ChannelDimension`, *optional*):\n                The channel dimension format for the output image. If unset, the channel dimension format of the input\n                image is used. Can be one of:\n                - `\"channels_first\"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.\n                - `\"channels_last\"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.\n            input_data_format (`ChannelDimension` or `str`, *optional*):\n                The channel dimension format for the input image. If unset, the channel dimension format is inferred\n                from the input image. Can be one of:\n                - `\"channels_first\"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.\n                - `\"channels_last\"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.\n\n        Returns:\n            `np.ndarray`: The normalized image.\n        \"\"\"\n        # TODO 4.32\n        return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)\n\n\ntransformers.IdeficsImageProcessor = IdeficsImageProcessor\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/idefics_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch Idefics model.\"\"\"\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\n\nfrom transformers import PreTrainedModel\nfrom transformers.activations import ACT2FN\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPast,\n    CausalLMOutputWithPast,\n    dataclass,\n)\nfrom text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig\nfrom text_generation_server.models.custom_modeling.idefics_vision import (\n    IdeficsVisionTransformer,\n)\nfrom text_generation_server.models.custom_modeling.idefics_perceiver import (\n    IdeficsPerceiverResampler,\n)\nfrom text_generation_server.layers import (\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    TensorParallelRowLinear,\n    SpeculativeHead,\n    FastLinear,\n)\nfrom text_generation_server.layers.rotary import PositionRotaryEmbedding\nfrom text_generation_server.utils.import_utils import SYSTEM\nfrom loguru import logger\n\nif 
SYSTEM == \"cuda\":\n    import dropout_layer_norm\nelif SYSTEM == \"rocm\":\n    import vllm._custom_ops as ops\nelse:\n    dropout_layer_norm = None\n\n\n@dataclass\nclass BaseModelOutputWithPastImage(BaseModelOutputWithPast):\n    image_hidden_states: Optional[torch.FloatTensor] = None\n\n\n@dataclass\nclass CausalLMOutputWithPastImage(CausalLMOutputWithPast):\n    image_hidden_states: Optional[torch.FloatTensor] = None\n\n\n# logger = logging.get_logger(__name__)\n\n# _CONFIG_FOR_DOC = \"IdeficsConfig\"\n\n# IDEFICS_PRETRAINED_MODEL_ARCHIVE_LIST = [\n#     \"HuggingFaceM4/idefics-9b\",\n#     \"HuggingFaceM4/idefics-80b\",\n#     # See all Idefics models at https://huggingface.co/models?filter=idefics\n# ]\n\n\ndef expand_inputs_for_generation(\n    input_ids,\n    expand_size=1,\n    is_encoder_decoder=False,\n    attention_mask=None,\n    encoder_outputs=None,\n    **model_kwargs,\n):\n    expanded_return_idx = (\n        torch.arange(input_ids.shape[0])\n        .view(-1, 1)\n        .repeat(1, expand_size)\n        .view(-1)\n        .to(input_ids.device)\n    )\n    input_ids = input_ids.index_select(0, expanded_return_idx)\n\n    if \"token_type_ids\" in model_kwargs:\n        token_type_ids = model_kwargs[\"token_type_ids\"]\n        model_kwargs[\"token_type_ids\"] = token_type_ids.index_select(\n            0, expanded_return_idx\n        )\n\n    if attention_mask is not None:\n        model_kwargs[\"attention_mask\"] = attention_mask.index_select(\n            0, expanded_return_idx\n        )\n        model_kwargs[\"image_attention_mask\"] = model_kwargs[\n            \"image_attention_mask\"\n        ].index_select(0, expanded_return_idx)\n        model_kwargs[\"pixel_values\"] = model_kwargs[\"pixel_values\"].index_select(\n            0, expanded_return_idx\n        )\n\n    if is_encoder_decoder:\n        if encoder_outputs is None:\n            raise ValueError(\n                \"If `is_encoder_decoder` is True, make sure that 
`encoder_outputs` is defined.\"\n            )\n        encoder_outputs[\"last_hidden_state\"] = (\n            encoder_outputs.last_hidden_state.index_select(\n                0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)\n            )\n        )\n        model_kwargs[\"encoder_outputs\"] = encoder_outputs\n    return input_ids, model_kwargs\n\n\ndef update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False):\n    # must have this key set to at least None\n    model_kwargs[\"past_key_values\"] = model_kwargs.get(\"past_key_values\", None)\n\n    # update past\n    if \"past_key_values\" in outputs:\n        model_kwargs[\"past\"] = outputs.past_key_values\n    elif \"mems\" in outputs:\n        model_kwargs[\"past\"] = outputs.mems\n    elif \"past_buckets_states\" in outputs:\n        model_kwargs[\"past\"] = outputs.past_buckets_states\n    else:\n        model_kwargs[\"past\"] = None\n\n    # update token_type_ids with last value\n    if \"token_type_ids\" in model_kwargs:\n        token_type_ids = model_kwargs[\"token_type_ids\"]\n        model_kwargs[\"token_type_ids\"] = torch.cat(\n            [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1\n        )\n\n    # update attention masks\n    if not is_encoder_decoder:\n        if \"attention_mask\" in model_kwargs:\n            attention_mask = model_kwargs[\"attention_mask\"]\n            model_kwargs[\"attention_mask\"] = torch.cat(\n                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],\n                dim=-1,\n            )\n        if \"image_attention_mask\" in model_kwargs:\n            image_attention_mask = model_kwargs[\"image_attention_mask\"]\n            last_mask = image_attention_mask[:, -1, :].unsqueeze(1)\n            model_kwargs[\"image_attention_mask\"] = last_mask\n\n    return model_kwargs\n\n\ndef prepare_inputs_for_generation(input_ids, past=None, **kwargs):\n    token_type_ids = 
kwargs.get(\"token_type_ids\", None)\n    # only last token for inputs_ids if past is defined in kwargs\n    if past:\n        input_ids = input_ids[:, -1].unsqueeze(-1)\n        if token_type_ids is not None:\n            token_type_ids = token_type_ids[:, -1].unsqueeze(-1)\n\n    attention_mask = kwargs.get(\"attention_mask\", None)\n    position_ids = kwargs.get(\"position_ids\", None)\n\n    if attention_mask is not None and position_ids is None:\n        # create position_ids on the fly for batch generation\n        position_ids = attention_mask.long().cumsum(-1) - 1\n        position_ids.masked_fill_(attention_mask == 0, 1)\n        if past:\n            position_ids = position_ids[:, -1].unsqueeze(-1)\n\n    pixel_values = kwargs.get(\"pixel_values\", None)\n    image_attention_mask = kwargs.get(\"image_attention_mask\", None)\n    # if pixel_values is None or image_attention_mask is None:\n    #     raise ValueError(\"pixel values and image attention mask cannot be None\")\n\n    return {\n        \"input_ids\": input_ids,\n        \"past_key_values\": past,\n        \"use_cache\": kwargs.get(\"use_cache\"),\n        \"position_ids\": position_ids,\n        \"attention_mask\": attention_mask,\n        \"token_type_ids\": token_type_ids,\n        \"pixel_values\": pixel_values,\n        \"image_attention_mask\": image_attention_mask,\n    }\n\n\ndef freeze_model(model, module_exceptions=[]):\n    mapping = {\n        \"LayerNorm\": nn.LayerNorm,\n        \"Linear\": nn.Linear,\n        \"Embedding\": nn.Embedding,\n    }\n    module_exceptions_mapped = [mapping[m] for m in module_exceptions]\n    for module in model.modules():\n        if module_exceptions and any(\n            [isinstance(module, t) for t in module_exceptions_mapped]\n        ):\n            module.requires_grad_(\n                True\n            )  # Explicitely setting it to true to avoid any mistakes\n        else:\n            module.requires_grad_(False)\n    return model\n\n\nclass 
IdeficsDecoupledPartialTPEmbedding(nn.Module):\n    def __init__(\n        self,\n        config,\n        weights,\n    ):\n        super().__init__()\n        self.num_embeddings = config.vocab_size\n        self.weight = TensorParallelEmbedding(\n            prefix=\"model.embed_tokens\", weights=weights\n        )\n        self.additional_weight = nn.Parameter(\n            weights.get_tensor(\"model.embed_tokens.additional_embedding.weight\")\n        )\n\n    def forward(self, input_ids):\n        # Clone so that we don't modify the original input_ids later on\n        input_ids = input_ids.clone()\n        additional_vocab_indices = torch.where(input_ids >= self.num_embeddings)\n        input_ids_additional_vocab = input_ids[additional_vocab_indices]\n        additional_embeddings = torch.nn.functional.embedding(\n            input_ids_additional_vocab - self.num_embeddings, self.additional_weight\n        )\n\n        # for successful lookup replace input_ids with 0, the results of these will be discarded anyway\n        input_ids[additional_vocab_indices] = 0\n        full_vector = self.weight(input_ids)\n\n        # overwrite the records with high indices\n        full_vector[additional_vocab_indices] = additional_embeddings\n\n        return full_vector\n\n\nclass IdeficsDecoupledTensorParallelLinear(nn.Module):\n    # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear\n    \"\"\"\n    Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practise, the\n    regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `out_additional_features` > 0,\n    then it will create `out_additional_features * in_features` additional parameters that are always trained. 
If\n    `out_additional_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.\n    \"\"\"\n\n    def __init__(\n        self,\n        config,\n        weights,\n    ) -> None:\n        super().__init__()\n        self.fc = SpeculativeHead.load(config=config, prefix=\"lm_head\", weights=weights)\n        self.additional_fc = FastLinear.load(\n            config=config,\n            prefix=\"lm_head.additional_fc\",\n            weights=weights,\n            bias=False,\n        )\n\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        output, speculative_logits = self.fc(input)\n        additional_features = self.additional_fc(input)\n        output = torch.cat((output, additional_features), -1)\n\n        return output, speculative_logits\n\n    def extra_repr(self) -> str:\n        \"\"\"Overwriting `nn.Linear.extra_repr` to include new parameters.\"\"\"\n        return \"in_features={}, out_features={}, out_additional_features={}, bias={}, partially_freeze={}\".format(\n            self.in_features,\n            self.out_features,\n            self.out_additional_features,\n            self.bias is not None,\n            self.partially_freeze,\n        )\n\n\n# Copied from transformers.models.bart.modeling_bart._make_causal_mask\ndef _make_causal_mask(\n    input_ids_shape: torch.Size,\n    dtype: torch.dtype,\n    device: torch.device,\n    past_key_values_length: int = 0,\n):\n    \"\"\"\n    Make causal mask used for bi-directional self-attention.\n    \"\"\"\n    bsz, tgt_len = input_ids_shape\n    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)\n    mask_cond = torch.arange(mask.size(-1), device=device)\n    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)\n    mask = mask.to(dtype)\n\n    if past_key_values_length > 0:\n        mask = torch.cat(\n            [\n                torch.zeros(\n                    tgt_len, past_key_values_length, dtype=dtype, 
device=device\n                ),\n                mask,\n            ],\n            dim=-1,\n        )\n    return mask[None, None, :, :].expand(\n        bsz, 1, tgt_len, tgt_len + past_key_values_length\n    )\n\n\ndef _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):\n    \"\"\"\n    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.\n    \"\"\"\n    bsz, src_len = mask.size()\n    tgt_len = tgt_len if tgt_len is not None else src_len\n\n    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)\n\n    inverted_mask = 1.0 - expanded_mask\n\n    return inverted_mask.masked_fill(\n        inverted_mask.to(torch.bool), torch.finfo(dtype).min\n    )\n\n\nclass IdeficsRMSNorm(nn.Module):\n    def __init__(self, prefix, weights, eps=1e-6):\n        \"\"\"\n        LlamaRMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n\n        weight = weights.get_tensor(f\"{prefix}.weight\")\n        self.weight = nn.Parameter(weight)\n        self.variance_epsilon = eps\n\n    def forward(self, hidden_states, residual=None):\n        if SYSTEM == \"ipex\":\n            import intel_extension_for_pytorch as ipex\n\n            out = ipex.llm.functional.add_rms_norm(\n                residual,\n                hidden_states,\n                self.weight,\n                None,\n                self.variance_epsilon,\n                residual is not None,\n            )\n            return out\n        elif hidden_states.shape[-1] > 8192:\n            if residual is not None:\n                hidden_states += residual\n            residual = hidden_states\n\n            hidden_states = hidden_states.to(torch.float32)\n            variance = hidden_states.pow(2).mean(-1, keepdim=True)\n            hidden_states = hidden_states * torch.rsqrt(\n                variance + self.variance_epsilon\n            )\n\n            # convert into half-precision if 
necessary\n            if self.weight.dtype in [torch.float16, torch.bfloat16]:\n                hidden_states = hidden_states.to(self.weight.dtype)\n\n            return self.weight * hidden_states\n        elif SYSTEM == \"cuda\":\n            # faster post attention rms norm\n            unwrap = False\n            if len(hidden_states.shape) > 2:\n                unwrap = True\n                shape = hidden_states.shape\n                hidden_states = hidden_states.reshape(-1, shape[-1])\n\n            normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(\n                hidden_states,\n                residual,\n                self.weight,\n                None,\n                None,\n                None,\n                None,\n                None,\n                0.0,\n                self.variance_epsilon,\n                1.0,\n                0,\n                None,\n                False,\n                True,  # Activate RMSNorm\n            )\n            if res is None:\n                res = hidden_states\n\n            if unwrap:\n                normed_hidden_states = normed_hidden_states.view(*shape)\n\n            return normed_hidden_states\n        elif SYSTEM == \"rocm\":\n            # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.\n            if residual is not None:\n                hidden_states += residual\n            residual = hidden_states\n\n            unwrap = False\n            if len(hidden_states.shape) > 2:\n                unwrap = True\n                shape = hidden_states.shape\n                hidden_states = hidden_states.reshape(-1, shape[-1])\n\n            out = torch.empty_like(hidden_states)\n            ops.rms_norm(\n                out,\n                hidden_states,\n                self.weight.data,\n                self.variance_epsilon,\n            )\n\n            if unwrap:\n                out = out.view(*shape)\n\n   
         return out\n        else:\n            raise ValueError(\n                \"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.\"\n            )\n\n\n# this was adapted from LlamaMLP\nclass IdeficsMLP(nn.Module):\n    def __init__(\n        self,\n        config,\n        prefix,\n        weights,\n    ):\n        super().__init__()\n        self.gate_up_proj = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.gate_proj\", f\"{prefix}.up_proj\"],\n            weights=weights,\n            dim=0,\n            bias=False,\n        )\n        self.down_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.down_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, hidden_states):\n        gate_up_states = self.gate_up_proj(hidden_states)\n        shape = gate_up_states.shape\n        gate_up_states = gate_up_states.view(*shape[:-1], 2, shape[-1] // 2)\n        return self.down_proj(\n            self.act_fn(gate_up_states[:, :, 0]) * gate_up_states[:, :, 1]\n        )\n\n\n# this was adapted from LlamaAttention\nclass IdeficsAttention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(\n        self,\n        config,\n        prefix,\n        weights,\n        qk_layer_norms: bool = False,\n        is_cross_attention: bool = False,\n    ):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.hidden_size // self.num_heads\n        self.dropout = config.dropout\n\n        if (self.head_dim * self.num_heads) != self.hidden_size:\n            raise ValueError(\n                f\"hidden_size must be divisible 
by num_heads (got `hidden_size`: {self.hidden_size}\"\n                f\" and `num_heads`: {self.num_heads}).\"\n            )\n\n        self.is_cross_attention = is_cross_attention\n\n        # if not hasattr(nn.functional, \"scaled_dot_product_attention\"):\n        #     raise ValueError(\"this model requires pytorch 2.0 or higher\")\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads //= weights.process_group.size()\n\n        if self.is_cross_attention:\n            # kv_input_dim = (\n            #     self.hidden_size if not hasattr(config.vision_config, \"embed_dim\") else config.vision_config.embed_dim\n            # )\n            self.q_proj = TensorParallelColumnLinear.load(\n                config, prefix=f\"{prefix}.q_proj\", weights=weights, bias=False\n            )\n            self.k_proj = TensorParallelColumnLinear.load(\n                config, prefix=f\"{prefix}.k_proj\", weights=weights, bias=False\n            )\n            self.v_proj = TensorParallelColumnLinear.load(\n                config, prefix=f\"{prefix}.v_proj\", weights=weights, bias=False\n            )\n        else:\n            self.qkv = TensorParallelColumnLinear.load_multi(\n                config,\n                prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n                dim=0,\n                weights=weights,\n                bias=False,\n            )\n        self.o_proj = TensorParallelRowLinear.load(\n            config, prefix=f\"{prefix}.o_proj\", weights=weights, bias=False\n        )\n        self.rotary_emb = PositionRotaryEmbedding.static(\n            config=config, dim=self.head_dim, base=10000.0, device=weights.device\n        )\n        self.qk_layer_norms = 
qk_layer_norms\n        if self.qk_layer_norms:\n            self.q_layer_norm = IdeficsRMSNorm(\n                prefix=f\"{prefix}.q_layer_norm\",\n                weights=weights,\n                eps=config.rms_norm_eps,\n            )\n            self.k_layer_norm = IdeficsRMSNorm(\n                prefix=f\"{prefix}.q_layer_norm\",\n                weights=weights,\n                eps=config.rms_norm_eps,\n            )\n\n    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n        return (\n            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)\n            .transpose(1, 2)\n            .contiguous()\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        key_value_states: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        # if key_value_states are provided this layer is used as a cross-attention layer\n        is_cross_attention = self.is_cross_attention or key_value_states is not None\n\n        bsz, q_len, _ = hidden_states.size()\n\n        if is_cross_attention:\n            query_states = self.q_proj(hidden_states).view(\n                bsz, q_len, self.num_heads, self.head_dim\n            )  # .transpose(1, 2)\n            query_states = query_states.transpose(1, 2)\n            (\n                _,\n                kv_len,\n                _,\n            ) = (\n                key_value_states.size()\n            )  # Note that, in this case, `kv_len` == `kv_seq_len`\n            key_states = (\n                self.k_proj(key_value_states)\n                .view(bsz, kv_len, self.num_heads, self.head_dim)\n                .transpose(1, 2)\n            )\n 
           value_states = (\n                self.v_proj(key_value_states)\n                .view(bsz, kv_len, self.num_heads, self.head_dim)\n                .transpose(1, 2)\n            )\n        else:\n            qkv = self.qkv(hidden_states)\n            query_states, key_states, value_states = qkv.split(\n                self.num_heads * self.head_dim, dim=2\n            )\n\n            query_states = query_states.view(\n                bsz, q_len, self.num_heads, self.head_dim\n            )  # .transpose(1, 2)\n            key_states = key_states.view(\n                bsz, q_len, self.num_heads, self.head_dim\n            )  # . transpose(1, 2)\n            value_states = value_states.view(\n                bsz, q_len, self.num_heads, self.head_dim\n            )  # .transpose(1, 2)\n            kv_seq_len = q_len\n            if past_key_value is not None:\n                kv_seq_len += past_key_value[0].shape[-2]\n            max_s = max(kv_seq_len, q_len)\n            cos, sin = self.rotary_emb.get_cos_sin(\n                position_ids.view(-1), max_s, hidden_states.dtype\n            )\n\n            query_shape = query_states.shape\n            key_shape = key_states.shape\n            self.rotary_emb(\n                query_states.view(-1, *query_shape[2:]),\n                key_states.reshape(-1, *key_shape[2:]),\n                cos,\n                sin,\n            )\n\n            query_states = query_states.view(query_shape)\n            key_states = key_states.view(key_shape)\n\n            query_states = query_states.transpose(1, 2)\n            key_states = key_states.transpose(1, 2)\n            value_states = value_states.transpose(1, 2)\n\n        kv_seq_len = key_states.shape[-2]\n        if past_key_value is not None:\n            kv_seq_len += past_key_value[0].shape[-2]\n        # [bsz, nh, t, hd]\n\n        if past_key_value is not None:\n            # reuse k, v, self_attention\n            key_states = 
torch.cat([past_key_value[0], key_states], dim=2)\n            value_states = torch.cat([past_key_value[1], value_states], dim=2)\n\n        past_key_value = (key_states, value_states) if use_cache else None\n\n        if self.qk_layer_norms:\n            query_states = self.q_layer_norm(query_states)\n            key_states = self.k_layer_norm(key_states)\n\n        if attention_mask is not None:\n            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}\"\n                )\n\n        attn_output = nn.functional.scaled_dot_product_attention(\n            query_states,\n            key_states,\n            value_states,\n            attn_mask=attention_mask,\n            dropout_p=self.dropout,\n        )\n\n        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.transpose(1, 2)\n        attn_output = attn_output.reshape(bsz, q_len, -1)\n\n        attn_output = self.o_proj(attn_output)\n\n        attn_weights = None\n        if output_attentions:\n            logger.warning_once(\n                \"attn_weights are not extracted in scaled_dot_product_attention. 
The model returns None instead\"\n            )\n\n        return attn_output, attn_weights, past_key_value\n\n\n# this was adapted from LlamaDecoderLayer\nclass IdeficsDecoderLayer(nn.Module):\n    def __init__(self, layer_id: int, config: IdeficsConfig, weights):\n        super().__init__()\n        self.process_group = weights.process_group\n        self.hidden_size = config.hidden_size\n        prefix = f\"model.layers.{layer_id}\"\n        self.self_attn = IdeficsAttention(\n            config=config,\n            prefix=f\"{prefix}.self_attn\",\n            weights=weights,\n            qk_layer_norms=False,\n            is_cross_attention=False,\n        )\n        self.mlp = IdeficsMLP(\n            config=config,\n            prefix=f\"{prefix}.mlp\",\n            weights=weights,\n        )\n        self.input_layernorm = IdeficsRMSNorm(\n            prefix=f\"{prefix}.input_layernorm\", weights=weights, eps=config.rms_norm_eps\n        )\n        self.post_attention_layernorm = IdeficsRMSNorm(\n            prefix=f\"{prefix}.post_attention_layernorm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n        self.dropout = config.dropout\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        output_attentions: Optional[bool] = False,\n        use_cache: Optional[bool] = False,\n    ) -> Tuple[\n        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]\n    ]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n       
     output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n        \"\"\"\n\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_value=past_key_value,\n            output_attentions=output_attentions,\n            use_cache=use_cache,\n        )\n        # hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        # hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)\n        hidden_states = residual + hidden_states\n\n        outputs = (hidden_states,)\n\n        if output_attentions:\n            outputs += (self_attn_weights,)\n\n        if use_cache:\n            outputs += (present_key_value,)\n\n        return outputs\n\n\nclass IdeficsGatedCrossAttentionLayer(nn.Module):\n    def __init__(self, layer_id, config: IdeficsConfig, weights):\n        super().__init__()\n        self.process_group = weights.process_group\n        self.hidden_size = config.hidden_size\n        prefix = 
f\"model.gated_cross_attn_layers.{layer_id}\"\n        self.cross_attn = IdeficsAttention(\n            config=config,\n            prefix=f\"{prefix}.cross_attn\",\n            weights=weights,\n            qk_layer_norms=True,\n            is_cross_attention=True,\n        )\n        self.mlp = IdeficsMLP(\n            config=config,\n            prefix=f\"{prefix}.mlp\",\n            weights=weights,\n        )\n        self.input_layernorm = IdeficsRMSNorm(\n            prefix=f\"{prefix}.input_layernorm\", weights=weights, eps=config.rms_norm_eps\n        )\n        self.post_attention_layernorm = IdeficsRMSNorm(\n            prefix=f\"{prefix}.post_attention_layernorm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n        self.config = config.dropout\n\n        self.act_cross_attn = nn.Tanh()\n        self.act_dense = nn.Tanh()\n\n        self.alpha_cross_attn = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.alpha_cross_attn\")\n        )\n        self.alpha_dense = nn.Parameter(weights.get_tensor(f\"{prefix}.alpha_dense\"))\n\n        if not (hasattr(self, \"alpha_cross_attn\") and hasattr(self, \"alpha_dense\")):\n            raise ValueError(\"Alpha parameters not initialized correctly!\")\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        image_hidden_states: Optional[torch.Tensor] = None,\n        image_attention_mask: Optional[torch.Tensor] = None,\n        output_attentions: Optional[bool] = False,\n        use_cache: Optional[bool] = False,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        no_images: Optional[bool] = False,\n    ) -> Tuple[\n        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]\n    ]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask 
(`torch.FloatTensor`, *optional*): attention mask of size\n                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n            no_images (`bool`, *optional*, defaults to `False`): If `True` the vision part is ignored\n        \"\"\"\n        if image_hidden_states is None:\n            raise ValueError(\n                \"`image_hidden_states` is required for Idefics cross attention module which are visual features to be\"\n                \" conditioned on.\"\n            )\n\n        if past_key_value is not None:\n            raise NotImplementedError(\n                \"Past key value states are not implemented for Idefics cross attention module.\"\n            )\n\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        hidden_states, self_attn_weights, present_key_value = self.cross_attn(\n            hidden_states=hidden_states,\n            key_value_states=image_hidden_states,\n            attention_mask=image_attention_mask,\n            output_attentions=output_attentions,\n        )\n        # hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)\n        # when there are no images the model is used in pure language mode\n        gate = 0 if no_images else 1\n        hidden_states = (\n            residual + gate * 
self.act_cross_attn(self.alpha_cross_attn) * hidden_states\n        )\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        # hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)\n        hidden_states = residual + self.act_dense(self.alpha_dense) * hidden_states\n\n        outputs = (hidden_states,)\n\n        if output_attentions:\n            outputs += (self_attn_weights,)\n\n        if use_cache:\n            outputs += (present_key_value,)\n\n        return outputs\n\n\nLLAMA_START_DOCSTRING = r\"\"\"\n    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the\n    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads\n    etc.)\n\n    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.\n    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage\n    and behavior.\n\n    Parameters:\n        config ([`IdeficsConfig`]):\n            Model configuration class with all the parameters of the model. Initializing with a config file does not\n            load the weights associated with the model, only the configuration. 
Check out the\n            [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n\"\"\"\n\n\n# @add_start_docstrings(\n#     \"The bare LLaMA Model outputting raw hidden-states without any specific head on top.\",\n#     LLAMA_START_DOCSTRING,\n# )\nclass IdeficsPreTrainedModel(PreTrainedModel):\n    config_class = IdeficsConfig\n    # base_model_prefix = \"model\"\n    # supports_gradient_checkpointing = True\n    # _no_split_modules = [\"IdeficsDecoderLayer\", \"IdeficsGatedCrossAttentionLayer\"]\n\n    # def _init_weights(self, module):\n    #     # important: this ported version of Idefics isn't meant for training from scratch - only\n    #     # inference and fine-tuning - so the proper init weights code has been removed - the m4 code\n    #     # base should be used for training from scratch and it contains the correct code.\n    #     std = self.config.initializer_range\n    #     if isinstance(module, nn.Linear):\n    #         module.weight.data.normal_(mean=0.0, std=std)\n    #         if module.bias is not None:\n    #             module.bias.data.zero_()\n    #     elif isinstance(module, nn.Embedding):\n    #         module.weight.data.normal_(mean=0.0, std=std)\n    #         if module.padding_idx is not None:\n    #             module.weight.data[module.padding_idx].zero_()\n\n    # def _set_gradient_checkpointing(self, module, value=False):\n    #     if isinstance(module, IdeficsModel):\n    #         module.gradient_checkpointing = value\n\n\n# LLAMA_INPUTS_DOCSTRING = r\"\"\"\n#     Args:\n#         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n#             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide\n#             it.\n\n#             Indices can be obtained using [`AutoTokenizer`]. 
See [`PreTrainedTokenizer.encode`] and\n#             [`PreTrainedTokenizer.__call__`] for details.\n\n#             [What are input IDs?](../glossary#input-ids)\n#         attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n#             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n#             - 1 for tokens that are **not masked**,\n#             - 0 for tokens that are **masked**.\n\n#             [What are attention masks?](../glossary#attention-mask)\n\n#             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n#             [`PreTrainedTokenizer.__call__`] for details.\n\n#             If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see\n#             `past_key_values`).\n\n#             If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]\n#             and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more\n#             information on the default strategy.\n\n#             - 1 indicates the head is **not masked**,\n#             - 0 indicates the head is **masked**.\n#         position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n#             Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,\n#             config.n_positions - 1]`. 
[What are position IDs?](../glossary#position-ids)\n#         past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):\n#             Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape\n#             `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape\n#             `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.\n\n#             Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention\n#             blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.\n\n#             If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that\n#             don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all\n#             `decoder_input_ids` of shape `(batch_size, sequence_length)`.\n#         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n#             Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This\n#             is useful if you want more control over how to convert `input_ids` indices into associated vectors than the\n#             model's internal embedding lookup matrix.\n#         use_cache (`bool`, *optional*):\n#             If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n#             `past_key_values`).\n#         output_attentions (`bool`, *optional*):\n#             Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under returned\n#             tensors for more detail.\n#         output_hidden_states (`bool`, *optional*):\n#             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for\n#             more detail.\n#         return_dict (`bool`, *optional*):\n#             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n# \"\"\"\n\n\n# @add_start_docstrings(\n#     \"The bare LLaMA Model outputting raw hidden-states without any specific head on top.\",\n#     LLAMA_START_DOCSTRING,\n# )\nclass IdeficsModel(IdeficsPreTrainedModel):\n    # \"\"\"\n    # Transformer decoder consisting of `config.num_hidden_layers` layers. Each layer is a [`IdeficsDecoderLayer`]\n\n    # Args:\n    #     config: IdeficsConfig\n    # \"\"\"\n\n    def __init__(self, config: IdeficsConfig, weights):\n        super().__init__(config)\n        self.config = config\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = IdeficsDecoupledPartialTPEmbedding(\n            config=config,\n            weights=weights,\n        )\n\n        self.image_size = config.vision_config.image_size\n        self.vision_config = config.vision_config\n        self.vision_model = IdeficsVisionTransformer(\n            prefix=\"model.vision_model\",\n            config=config.vision_config,\n            weights=weights,\n        )\n\n        # Perceiver Resampler\n        if config.use_resampler:\n            perceiver_config = config.perceiver_config\n            self.perceiver_resampler = IdeficsPerceiverResampler(\n                prefix=\"model.perceiver_resampler\",\n                config=config,\n                embed_dim=config.vision_config.embed_dim,\n                depth=perceiver_config.resampler_depth,\n                n_heads=perceiver_config.resampler_n_heads,\n                head_dim=perceiver_config.resampler_head_dim,\n                
n_latents=perceiver_config.resampler_n_latents,\n                weights=weights,\n            )\n\n        self.layers = nn.ModuleList(\n            [\n                IdeficsDecoderLayer(layer_id, config, weights)\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n\n        self.cross_layer_interval = config.cross_layer_interval\n        num_cross_layers = config.num_hidden_layers // self.cross_layer_interval\n        self.gated_cross_attn_layers = nn.ModuleList(\n            [\n                IdeficsGatedCrossAttentionLayer(layer_id, config, weights)\n                for layer_id in range(num_cross_layers)\n            ]\n        )\n        # self.gradient_checkpointing = False\n\n        self.norm = IdeficsRMSNorm(\n            prefix=\"model.norm\", weights=weights, eps=config.rms_norm_eps\n        )\n\n        # self.gradient_checkpointing = False\n        # Initialize weights and apply final processing\n        # self.post_init()\n\n        # self.freeze_relevant_params(config)\n\n    # def freeze_relevant_params(self, config=None):\n    #     if config is None:\n    #         config = self.config\n\n    #     if config.freeze_text_layers:\n    #         self.freeze_text_layers(config.freeze_text_module_exceptions)\n\n    #     if config.freeze_vision_layers:\n    #         freeze_model(self.vision_model, module_exceptions=config.freeze_vision_module_exceptions)\n\n    # def freeze_text_layers(self, module_exceptions=[]):\n    #     for module in [self.layers, self.norm]:\n    #         freeze_model(module, module_exceptions=module_exceptions)\n\n    # def freeze_vision_layers(self, module_exceptions=[]):\n    #     freeze_model(self.vision_model, module_exceptions=module_exceptions)\n\n    # def get_input_embeddings(self):\n    #     return self.embed_tokens\n\n    # def set_input_embeddings(self, value):\n    #     self.embed_tokens = value\n\n    # Copied from 
transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask\n    def _prepare_decoder_attention_mask(\n        self, attention_mask, input_shape, inputs_embeds, past_key_values_length\n    ):\n        # create causal mask\n        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n        combined_attention_mask = None\n        if input_shape[-1] > 1:\n            combined_attention_mask = _make_causal_mask(\n                input_shape,\n                inputs_embeds.dtype,\n                device=inputs_embeds.device,\n                past_key_values_length=past_key_values_length,\n            )\n\n        if attention_mask is not None:\n            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n            expanded_attn_mask = _expand_mask(\n                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]\n            ).to(inputs_embeds.device)\n            combined_attention_mask = (\n                expanded_attn_mask\n                if combined_attention_mask is None\n                else expanded_attn_mask + combined_attention_mask\n            )\n\n        return combined_attention_mask\n\n    # @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        pixel_values: Optional[torch.FloatTensor] = None,\n        image_hidden_states: Optional[torch.FloatTensor] = None,\n        image_embeddings: Optional[torch.FloatTensor] = None,\n        image_attention_mask: Optional[torch.Tensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n  
  ) -> Union[Tuple, BaseModelOutputWithPastImage]:\n        device = input_ids.device if input_ids is not None else inputs_embeds.device\n\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        # retrieve input_ids and inputs_embeds\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\n                \"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time\"\n            )\n        elif input_ids is not None:\n            batch_size, seq_length = input_ids.shape\n        elif inputs_embeds is not None:\n            batch_size, seq_length, _ = inputs_embeds.shape\n        else:\n            raise ValueError(\n                \"You have to specify either decoder_input_ids or decoder_inputs_embeds\"\n            )\n\n        seq_length_with_past = seq_length\n        past_key_values_length = 0\n\n        if past_key_values is not None:\n            past_key_values_length = past_key_values[0][0].shape[2]\n            seq_length_with_past = seq_length_with_past + past_key_values_length\n\n        if attention_mask is not None and position_ids is None:\n            # create position_ids on the fly for batch generation\n            position_ids = attention_mask.long().cumsum(-1) - 1\n            position_ids.masked_fill_(attention_mask == 0, 1)\n        elif position_ids is None:\n            device = input_ids.device if input_ids is not None else inputs_embeds.device\n            position_ids = 
torch.arange(\n                past_key_values_length,\n                seq_length + past_key_values_length,\n                dtype=torch.long,\n                device=device,\n            )\n            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)\n        else:\n            position_ids = position_ids.view(-1, seq_length).long()\n\n        no_images = False\n\n        if image_hidden_states is None:\n            if pixel_values is None and image_embeddings is None:\n                raise ValueError(\n                    \"Either pixel_values and image_embeddings have to be not-None.\"\n                )\n\n            elif pixel_values is not None and image_embeddings is not None:\n                raise ValueError(\n                    \"You cannot specify both pixel_values and image_embeddings at the same time\"\n                )\n\n            elif pixel_values is not None:\n                no_images = len(torch.nonzero(pixel_values)) == 0\n                pixel_values = pixel_values.to(\n                    dtype=self.dtype, device=device\n                )  # fp16 compatibility\n                batch_size, num_images = pixel_values.shape[:2]\n                pixel_values = pixel_values.contiguous().view(\n                    batch_size * num_images, *pixel_values.shape[2:]\n                )\n\n                # Get sequence from the vision encoder\n                image_hidden_states = self.vision_model(\n                    pixel_values=pixel_values\n                ).last_hidden_state\n\n            elif image_embeddings is not None:\n                (\n                    batch_size,\n                    num_images,\n                    image_seq_len,\n                    image_hidden_size,\n                ) = image_embeddings.size()\n                image_hidden_states = image_embeddings.to(\n                    dtype=self.dtype, device=input_ids.device\n                )\n                image_hidden_states = 
image_hidden_states.view(\n                    batch_size * num_images, image_seq_len, image_hidden_size\n                )\n\n            if self.config.use_resampler:\n                image_hidden_states = self.perceiver_resampler(image_hidden_states)\n            image_seq_len, image_hidden_size = image_hidden_states.size(\n                1\n            ), image_hidden_states.size(2)\n            image_hidden_states = image_hidden_states.view(\n                batch_size, num_images * image_seq_len, image_hidden_size\n            )\n        else:\n            no_images = False\n            num_images = pixel_values.shape[1]\n            image_seq_len = image_hidden_states.shape[1] // num_images\n\n        # # Hack to use the model in full language modeling mode\n        # image_attention_mask = torch.zeros(batch_size, seq_length, 1, dtype=torch.long, device=image_hidden_states.device)\n        # Make image_attention_mask compatible with hidden states\n        text_seq_len = image_attention_mask.size(1)\n        image_attention_mask = image_attention_mask.unsqueeze(-1)\n        image_attention_mask = image_attention_mask.repeat(1, 1, 1, image_seq_len)\n        image_attention_mask = image_attention_mask.view(\n            batch_size, text_seq_len, num_images * image_seq_len\n        )\n        image_batch_size, image_sequence_length, _ = image_hidden_states.size()\n        image_hidden_shape = (image_batch_size, image_sequence_length)\n        if image_attention_mask is None:\n            image_attention_mask = torch.ones(image_hidden_shape, device=device)\n        image_attention_mask = self.invert_attention_mask(image_attention_mask)\n\n        # if list(image_attention_mask.shape) != [4, 1, 1024, 64]:\n        #     raise ValueError(f\"Image hidden_states {image_hidden_states.shape} - mask {image_attention_mask.shape} {num_images} {image_seq_len} {text_seq_len}\")\n\n        # if image_hidden_states is not None:\n        # else:\n        #     
image_attention_mask = None\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n        # embed positions\n        if attention_mask is None:\n            attention_mask = torch.ones(\n                (batch_size, seq_length_with_past),\n                dtype=torch.bool,\n                device=inputs_embeds.device,\n            )\n        attention_mask = self._prepare_decoder_attention_mask(\n            attention_mask,\n            (batch_size, seq_length),\n            inputs_embeds,\n            past_key_values_length,\n        )\n\n        hidden_states = inputs_embeds\n\n        # if self.gradient_checkpointing and self.training:\n        #     if use_cache:\n        #         logger.warning_once(\n        #             \"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\"\n        #         )\n        #         use_cache = False\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        next_decoder_cache = () if use_cache else None\n\n        for idx, decoder_layer in enumerate(self.layers):\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            past_key_value = (\n                past_key_values[idx] if past_key_values is not None else None\n            )\n\n            def vblock(\n                main_block,\n                hidden_states,\n                attention_mask,\n                position_ids,\n                past_key_value,\n                image_hidden_states,\n                image_attention_mask,\n                output_attentions,\n                use_cache,\n                no_images,\n                layer_idx,\n                cross_layer_interval,\n                gated_cross_attn_layers,\n            ):\n                # TODO(ls): Add cross attention values to respective lists\n                
if layer_idx % cross_layer_interval == 0:\n                    xblock = gated_cross_attn_layers[layer_idx // cross_layer_interval]\n                    outputs = xblock(\n                        hidden_states,\n                        attention_mask=attention_mask,\n                        image_hidden_states=image_hidden_states,\n                        image_attention_mask=image_attention_mask,\n                        output_attentions=output_attentions,\n                        use_cache=use_cache,\n                        past_key_value=None,  # not implemented\n                        no_images=no_images,\n                    )\n                    hidden_states = outputs[0]\n\n                layer_outputs = main_block(\n                    hidden_states,\n                    attention_mask=attention_mask,\n                    position_ids=position_ids,\n                    past_key_value=past_key_value,\n                    output_attentions=output_attentions,\n                    use_cache=use_cache,\n                )\n\n                return layer_outputs\n\n            # if self.gradient_checkpointing and self.training:\n            #     past_key_value = None\n            #     if use_cache:\n            #         logger.warning_once(\n            #             \"`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...\"\n            #         )\n            #         use_cache = False\n\n            #     layer_outputs = torch.utils.checkpoint.checkpoint(\n            #         vblock,\n            #         decoder_layer,\n            #         hidden_states,\n            #         attention_mask,\n            #         position_ids,\n            #         past_key_value,\n            #         image_hidden_states,\n            #         image_attention_mask,\n            #         output_attentions,\n            #         use_cache,\n            #         no_images,\n            #         idx,\n            #         self.cross_layer_interval,\n            #         self.gated_cross_attn_layers,\n            #     )\n            # else:\n            layer_outputs = vblock(\n                decoder_layer,\n                hidden_states,\n                attention_mask=attention_mask,\n                position_ids=position_ids,\n                past_key_value=past_key_value,\n                image_hidden_states=image_hidden_states,\n                image_attention_mask=image_attention_mask,\n                output_attentions=output_attentions,\n                use_cache=use_cache,\n                no_images=no_images,\n                layer_idx=idx,\n                cross_layer_interval=self.cross_layer_interval,\n                gated_cross_attn_layers=self.gated_cross_attn_layers,\n            )\n\n            hidden_states = layer_outputs[0]\n\n            if use_cache:\n                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n        hidden_states = self.norm(hidden_states)\n\n        # add hidden states from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        next_cache = next_decoder_cache if use_cache else None\n        if not return_dict:\n            
return tuple(\n                v\n                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]\n                if v is not None\n            )\n        return BaseModelOutputWithPastImage(\n            last_hidden_state=hidden_states,\n            past_key_values=next_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n            image_hidden_states=image_hidden_states,\n        )\n\n\nclass IdeficsForVisionText2Text(IdeficsPreTrainedModel):\n    def __init__(\n        self,\n        config,\n        weights,\n    ):\n        super().__init__(config)\n        self.model = IdeficsModel(\n            config=config,\n            weights=weights,\n        )\n\n        self.lm_head = IdeficsDecoupledTensorParallelLinear(\n            config=config,\n            weights=weights,\n        )\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        pixel_values: Optional[torch.FloatTensor] = None,\n        image_embeddings: Optional[torch.FloatTensor] = None,\n        image_hidden_states: Optional[torch.FloatTensor] = None,\n        image_attention_mask: Optional[torch.Tensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, CausalLMOutputWithPastImage]:\n        r\"\"\"\n        Args:\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, LlamaForCausalLM\n\n        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)\n        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)\n\n        >>> prompt = \"Hey, are you consciours? Can you talk to me?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Hey, are you consciours? Can you talk to me?\\nI'm not consciours, but I can talk to you.\"\n        ```\"\"\"\n\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            pixel_values=pixel_values,\n            image_embeddings=image_embeddings,\n            image_hidden_states=image_hidden_states,\n            
image_attention_mask=image_attention_mask,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        hidden_states = outputs[0]\n        logits, speculative_logits = self.lm_head(hidden_states)\n\n        loss = None\n\n        return (\n            CausalLMOutputWithPastImage(\n                loss=loss,\n                logits=logits,\n                past_key_values=outputs.past_key_values,\n                hidden_states=outputs.hidden_states,\n                attentions=outputs.attentions,\n                image_hidden_states=outputs.image_hidden_states,\n            ),\n            speculative_logits,\n        )\n\n    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):\n        inputs = prepare_inputs_for_generation(input_ids, past=past, **kwargs)\n        unwanted_kwargs = [\"token_type_ids\"]\n        for kwarg in unwanted_kwargs:\n            inputs.pop(kwarg, None)\n        return inputs\n\n    @staticmethod\n    def _expand_inputs_for_generation(\n        *args,\n        **model_kwargs,\n    ):\n        return expand_inputs_for_generation(*args, **model_kwargs)\n\n    @staticmethod\n    def _update_model_kwargs_for_generation(\n        outputs, model_kwargs, is_encoder_decoder=False\n    ):\n        return update_model_kwargs_for_generation(\n            outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder\n        )\n\n    @staticmethod\n    def _reorder_cache(past, beam_idx):\n        reordered_past = ()\n        for layer_past in past:\n            reordered_past += (\n                tuple(\n                    past_state.index_select(0, beam_idx) for past_state in layer_past\n                ),\n            )\n        return reordered_past\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/idefics_perceiver.py",
    "content": "# This code was adapted from https://github.com/lucidrains/flamingo-pytorch licensed under the MIT License.\n#\n# MIT License\n#\n# Copyright (c) 2020  The Google AI Language Team Authors, The HuggingFace Inc. team and github/lonePatient\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\n\n\"\"\"\n\nGeneric interface to various configurations of the Perceiver Resampler, that simply takes in a series of (potentially\ntime-indexed) contextual embeddings, and \"resamples\" (compresses) them down to a pre-specified number of latents! 
Note\nthat the Perceiver in general resamples based solely off the *long-range* context; there's a nice opportunity here to\nprime the Perceiver Resampler with say a single layer's worth of language embeddings (the target domain), and use that\nto softly \"retrieve & compress\" what we need --> this would be a novel contribution we should explore.\n\nReferences:\n    - DeepMind's Flamingo: https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model\n    - Code borrowed w/ love from: https://github.com/lucidrains/flamingo-pytorch\n\n\"\"\"\nfrom typing import Optional, Tuple\n\nimport torch\nimport torch.nn as nn\n\nfrom text_generation_server.layers import (\n    TensorParallelColumnLinear,\n    TensorParallelRowLinear,\n)\n\nEPS = 1e-5\n\n\nclass IdeficsPerceiverResampler(nn.Module):\n    def __init__(\n        self,\n        prefix,\n        config,\n        embed_dim: int,\n        depth: int,\n        n_heads: int,\n        head_dim: int,\n        n_latents: int,\n        weights,\n    ) -> None:\n        \"\"\"\n        Instantiates a Perceiver Resampler that operates over a sequence of embeddings (say from a ResNet or ViT or\n        MAE) of a given dimension, performs `depth` blocks of cross-attention with a fixed `n_latents` inputs, then\n        returns a Tensor of shape [bsz, n_latents, embed_dim]. :param embed_dim: Dimensionality of embeddings being fed\n        to the Perceiver Resampler (also dimensionality of latent embeddings *returned* by the Perceiver Resampler.\n        Could be e.g., VIT embed_dim, ResNet pool dim, and so on.\n\n        Args:\n            config (`IdeficsConfig`): config object\n            embed_dim (`int`): The size of each embedding vector\n            depth (`int`): Depth of the Perceiver Resampler (Transformer w/ cross attention). 
Should be shallow (< 3).\n            n_heads (`int`): Number of heads in each Transformer block (for multi-headed self-attention).\n            head_dim (`int`): Dimensionality of each head projection in the Transformer block.\n            n_latents (`int`):\n                Number of latent embeddings to resample (\"compress\") the input sequence to (usually < 128).\n\n        \"\"\"\n        super().__init__()\n        self.embed_dim, self.n_heads, self.head_dim, self.n_latents = (\n            embed_dim,\n            n_heads,\n            head_dim,\n            n_latents,\n        )\n        self.qk_layer_norms = config.perceiver_config.qk_layer_norms_perceiver\n\n        # Create Latents for Perceiver\n        self.latents = nn.Parameter(weights.get_tensor(f\"{prefix}.latents\"))\n\n        self.intermediate_dim = (\n            self.embed_dim * 4\n            if not hasattr(config.vision_config, \"embed_dim\")\n            else config.vision_config.embed_dim * 4\n        )\n        # Create Transformer Blocks\n        self.blocks = nn.ModuleList(\n            [\n                nn.ModuleList(\n                    [\n                        IdeficsPerceiverAttention(\n                            prefix=f\"{prefix}.blocks.{layer_id}.0\",\n                            config=config,\n                            embed_dim=self.embed_dim,\n                            n_heads=self.n_heads,\n                            head_dim=self.head_dim,\n                            qk_layer_norms=self.qk_layer_norms,\n                            weights=weights,\n                        ),\n                        IdeficsMLP(\n                            prefix=f\"{prefix}.blocks.{layer_id}.1\",\n                            intermediate_size=self.intermediate_dim,\n                            config=config,\n                            weights=weights,\n                        ),\n                    ]\n                )\n                for layer_id in range(depth)\n           
 ]\n        )\n        self.layer_norm = nn.LayerNorm.load(\n            prefix=f\"{prefix}.layer_norm\", weights=weights, eps=EPS\n        )\n\n    def forward(self, context: torch.Tensor) -> torch.Tensor:\n        \"\"\"Resample arbitrary length context & *compress* down to self.n_latents latent embeddings\"\"\"\n        # einsum.repeat(self.latents, \"seq embed -> bsz seq embed\", bsz=context.shape[0])\n        latents = self.latents.repeat(context.shape[0], 1, 1)\n\n        # Feed through Perceiver Attention blocks...\n        for attn, ff in self.blocks:\n            latents = attn(context, latents) + latents\n            latents = ff(latents) + latents\n\n        return self.layer_norm(latents)\n\n\nclass IdeficsPerceiverAttention(nn.Module):\n    def __init__(\n        self,\n        prefix,\n        config,\n        embed_dim: int,\n        n_heads: int,\n        head_dim: int,\n        qk_layer_norms: bool,\n        weights,\n    ) -> None:\n        \"\"\"Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`\"\"\"\n        super().__init__()\n        self.embed_dim, self.n_heads, self.head_dim = embed_dim, n_heads, head_dim\n        self.qk_layer_norms = qk_layer_norms\n        # Normalization & Scaling\n        self.context_layer_norm = nn.LayerNorm.load(\n            prefix=f\"{prefix}.context_layer_norm\", weights=weights, eps=EPS\n        )\n        self.latents_layer_norm = nn.LayerNorm.load(\n            prefix=f\"{prefix}.latents_layer_norm\", weights=weights, eps=EPS\n        )\n        if self.qk_layer_norms:\n            self.q_layer_norm = nn.LayerNorm.load(\n                prefix=f\"{prefix}.q_layer_norm\", weights=weights, eps=EPS\n            )\n            self.k_layer_norm = nn.LayerNorm.load(\n                prefix=f\"{prefix}.k_layer_norm\", weights=weights, eps=EPS\n            )\n\n        self.qk_scale = self.head_dim**-0.5\n\n        if n_heads % weights.process_group.size() != 
0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {n_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.n_heads //= weights.process_group.size()\n\n        # Q, K, V Projection (no bias -- detail from Perceiver/Flamingo Papers).\n        self.q_proj = TensorParallelColumnLinear.load(\n            config=config, prefix=f\"{prefix}.q_proj\", weights=weights, bias=False\n        )\n        self.k_proj = TensorParallelColumnLinear.load(\n            config=config, prefix=f\"{prefix}.k_proj\", weights=weights, bias=False\n        )\n        self.v_proj = TensorParallelColumnLinear.load(\n            config=config, prefix=f\"{prefix}.v_proj\", weights=weights, bias=False\n        )\n\n        self.output_proj = TensorParallelRowLinear.load(\n            config=config, prefix=f\"{prefix}.output_proj\", weights=weights, bias=False\n        )\n\n    def forward(self, context: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Runs Perceiver Self-Attention, with special (context, latents) appended along the `seq` dimension!\n\n        Args:\n            context (`torch.Tensor`):\n                Tensor of shape `[bsz, seq, embed_dim]` representing long-form context to resample.\n            latents (`torch.Tensor`):\n                Tensor of shape `[bsz, n_latents, embed_dim]` representing fixed length latents to compress to.\n\n        Returns:\n            `torch.Tensor`: Tensor of shape `[bsz, n_latents, embed_dim]` representing attention over latents w/ cross\n            from context.\n        \"\"\"\n        context = self.context_layer_norm(context)\n        latents = self.latents_layer_norm(latents)\n        batch_size, seq_length, embed_dim = context.shape[:3]\n\n        # Query, Key, Value Projections --> Note that in Flamingo, latents are *concatenated* with context prior to attn!\n        #   Note: This 
results in queries w/ `seq = n_latents`, and keys, values with `seq = len(context) + n_latents`\n        q = self.q_proj(latents)\n        k = self.k_proj(torch.cat([context, latents], dim=-2))\n        v = self.v_proj(torch.cat([context, latents], dim=-2))\n\n        # Multiheaded Self-Attention w/ stable softmax (subtract per-row max -- `amax` -- before softmax call)\n        #   =>> `attn` should be a 2D matrix of shape [n_latents x (context + n_latents)]\n        # einsum.rearrange(x, \"bsz seq (heads embed) -> bsz heads seq embed\", heads=self.n_heads)\n        q, k, v = [\n            x.reshape(batch_size, x.shape[1], self.n_heads, self.head_dim).transpose(\n                1, 2\n            )\n            for x in (q, k, v)\n        ]\n\n        if self.qk_layer_norms:\n            q = self.q_layer_norm(q)\n            k = self.k_layer_norm(k)\n\n        scores = torch.einsum(\"... i d, ... j d -> ... i j\", q * self.qk_scale, k)\n        stabilized_scores = scores - (scores.amax(dim=-1, keepdim=True).detach())\n        attn = stabilized_scores.softmax(dim=-1)\n\n        # Attend & project back to output...\n        resampled = torch.einsum(\"... i j, ... j d -> ... 
i d\", attn, v)\n        # einsum.rearrange(resampled, \"bsz heads seq embed -> bsz seq (heads embed)\", heads=self.n_heads)\n        return self.output_proj(resampled.transpose(1, 2).flatten(-2))\n\n\nclass IdeficsMLP(nn.Module):\n    def __init__(\n        self,\n        prefix,\n        intermediate_size,\n        config,\n        weights,\n    ):\n        \"\"\"Simple MLP block with intermediate_size and embedding size\"\"\"\n        super().__init__()\n        self.embed_dim = config.vision_config.embed_dim\n        self.ln = nn.LayerNorm.load(prefix=f\"{prefix}.ln\", weights=weights, eps=EPS)\n        self.fc = TensorParallelColumnLinear.load(\n            config=config,\n            prefix=f\"{prefix}.fc\",\n            weights=weights,\n            bias=False,\n        )\n        self.act = nn.ReLU()\n        self.c_proj = TensorParallelRowLinear.load(\n            config=config,\n            prefix=f\"{prefix}.c_proj\",\n            weights=weights,\n            bias=False,\n        )\n\n    def forward(\n        self, hidden_states: Optional[Tuple[torch.FloatTensor]]\n    ) -> torch.FloatTensor:\n        hidden_states = self.ln(hidden_states)\n        hidden_states = self.fc(hidden_states)\n        hidden_states = self.act(hidden_states)\n        hidden_states = self.c_proj(hidden_states)\n\n        return hidden_states\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/idefics_processing.py",
    "content": "# coding=utf-8\n# Copyright 2022 The HuggingFace Inc. team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nProcessor class for IDEFICS.\n\"\"\"\n\nfrom typing import Callable, List, Optional, Union\nfrom urllib.parse import urlparse\n\nfrom transformers.feature_extraction_utils import BatchFeature\nfrom transformers.processing_utils import ProcessorMixin\nfrom transformers.tokenization_utils_base import (\n    BatchEncoding,\n    PaddingStrategy,\n    TextInput,\n    TruncationStrategy,\n)\nfrom transformers.utils import TensorType, is_torch_available\n\n\nif is_torch_available():\n    import torch\n\n\nIMAGE_TOKEN = \"<image>\"\n\n\n# copied from m4.training.packing\ndef incremental_to_binary_attention_mask(incremental_mask, num_classes=-1):\n    # This function converts: [-1, 0, 1] => [[0, 0], [1, 0], [0, 1]]\n\n    # If any of images index are more than num_classes, set them to -1.\n    # Words after the max number of images allowed have been seen don't attend on anything\n    if num_classes != -1:\n        incremental_mask[incremental_mask >= num_classes] = -1\n\n    negatives = incremental_mask == -1\n    incremental_mask[negatives] = 0\n    attn_mask = torch.nn.functional.one_hot(incremental_mask, num_classes=num_classes)\n    attn_mask[negatives, :] = 0\n    return attn_mask\n\n\n# copied from m4.training.packing\ndef image_attention_mask_for_packed_input_ids(input_ids, tokenizer):\n    image_attention_mask = 
torch.full_like(input_ids, fill_value=-1)\n    next_image_attention_mask = torch.full_like(input_ids, fill_value=-1)\n    image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)\n    eod_token_id = tokenizer.eos_token_id\n    for batch_idx in range(input_ids.size(0)):\n        count = -1\n        seen_eod = False\n        for idx, token_id in enumerate(input_ids[batch_idx]):\n            if token_id == image_token_id:\n                count += 1\n                image_attention_mask[batch_idx][idx] = count\n                seen_eod = False\n            else:\n                image_attention_mask[batch_idx][idx] = count\n\n            if seen_eod:\n                image_attention_mask[batch_idx][idx] = -1\n\n            if token_id == eod_token_id:\n                seen_eod = True\n\n    for batch_idx in range(input_ids.size(0)):\n        count = -1\n        seen_eod = False\n        for idx in range(input_ids[batch_idx].size(0) - 1, -1, -1):\n            token_id = input_ids[batch_idx][idx]\n            if token_id == image_token_id:\n                count += 1\n                next_image_attention_mask[batch_idx][idx] = count\n                seen_eod = False\n            else:\n                next_image_attention_mask[batch_idx][idx] = count\n\n            if token_id == eod_token_id:\n                seen_eod = True\n\n            if seen_eod:\n                next_image_attention_mask[batch_idx][idx] = -1\n\n        non_negative_indices = next_image_attention_mask[batch_idx] != -1\n        next_image_attention_mask[batch_idx][non_negative_indices] -= count\n        next_image_attention_mask[batch_idx][non_negative_indices] *= -1\n\n    return image_attention_mask, next_image_attention_mask\n\n\ndef is_url(string):\n    \"\"\"Checks if the passed string contains a valid url and nothing else. e.g. 
if a space is included the url is immediately\n    invalidated\"\"\"\n    if \" \" in string:\n        return False\n    result = urlparse(string)\n    return all([result.scheme, result.netloc])\n\n\ndef is_image(string):\n    \"\"\"Checks if the passed string refers to an image: either a valid url (see `is_url`) or a string starting with\n    `data:`\"\"\"\n    return is_url(string) or string.startswith(\"data:\")\n\n\nclass IdeficsProcessor(ProcessorMixin):\n    r\"\"\"\n    Constructs an IDEFICS processor which wraps a Llama tokenizer and IDEFICS image processor into a single processor.\n\n    [`IdeficsProcessor`] offers all the functionalities of [`IdeficsImageProcessor`] and [`LlamaTokenizerFast`]. See\n    the docstring of [`~IdeficsProcessor.__call__`] and [`~IdeficsProcessor.decode`] for more information.\n\n    Args:\n        image_processor (`IdeficsImageProcessor`):\n            An instance of [`IdeficsImageProcessor`]. The image processor is a required input.\n        tokenizer (`LlamaTokenizerFast`):\n            An instance of [`LlamaTokenizerFast`]. 
The tokenizer is a required input.\n        image_size (`int`, *optional*, defaults to 224): Image size (assuming a square image)\n    \"\"\"\n\n    attributes = [\"image_processor\", \"tokenizer\"]\n    image_processor_class = \"IdeficsImageProcessor\"\n    tokenizer_class = \"LlamaTokenizerFast\"\n\n    def __init__(\n        self,\n        image_processor,\n        tokenizer=None,\n        image_size=224,\n        add_end_of_utterance_token=None,\n        **kwargs,\n    ):\n        if image_processor is None:\n            raise ValueError(\"You need to specify an `image_processor`.\")\n        if tokenizer is None:\n            raise ValueError(\"You need to specify a `tokenizer`.\")\n\n        super().__init__(image_processor, tokenizer)\n        self.current_processor = self.image_processor\n        self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)\n\n        self.default_image_dims = (\n            self.image_processor.image_num_channels,\n            self.image_processor.image_size,\n            self.image_processor.image_size,\n        )\n\n        self.tokenizer_was_trained_with_end_of_utterance_token = (\n            True\n            if \"<end_of_utterance>\"\n            in self.tokenizer.special_tokens_map.get(\"additional_special_tokens\", [])\n            else False\n        )\n\n    def __call__(\n        self,\n        prompts: Union[List[TextInput], List[List[TextInput]]],\n        padding: Union[bool, str, PaddingStrategy] = False,\n        truncation: Union[bool, str, TruncationStrategy] = None,\n        max_length: Optional[int] = None,\n        transform: Callable = None,\n        add_eos_token=False,\n        add_end_of_utterance_token=None,\n        debug=False,\n        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,\n    ) -> BatchEncoding:\n        \"\"\"This method takes batched or non-batched prompts made of text and images and converts them into prompts that\n        the model was trained on and 
prepares the image pixel values for the model to process.\n\n        Args:\n            prompts (`Union[List[TextInput], [List[List[TextInput]]]]`):\n                either a single prompt or a batched list of prompts - see the detailed description immediately after\n                the end of the arguments doc section.\n            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):\n                Select a strategy to pad the returned sequences (according to the model's padding side and padding\n                index) among:\n                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single\n                  sequence is provided).\n                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum\n                  acceptable input length for the model if that argument is not provided.\n                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different\n                  lengths).\n            max_length (`int`, *optional*):\n                Maximum length of the returned list and optionally padding length (see above).\n            truncation (`bool`, *optional*):\n                Activates truncation to cut input sequences longer than `max_length` to `max_length`.\n            transform (`Callable`, *optional*):\n                A custom transform function that accepts a single image can be passed for training. For example,\n                `torchvision.Compose` can be used to compose multiple functions. 
If `None` a preset inference-specific\n                set of transforms will be applied to the images\n            add_eos_token (`bool`, *optional*, defaults to `False`):\n                Adds `eos_token` at the end of the final prompt if True`\n            add_end_of_utterance_token (`bool`, *optional*)\n                Whether to automatically add `<end_of_utterance>` after each prompt's text input (unless followed by an\n                image). If `None` the tokenizer will be checked instead and if this token is found in\n                `additional_special_tokens` then the value will be `True`.\n            debug (`bool`, *optional*, defaults to `False`):\n                `True` value will help debug prompt generation by dumping useful information\n            return_tensors (`str` or `TensorType`, *optional*, defaults to `TensorType.PYTORCH`):\n                The type of tensors to return. Can be one of:\n                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.\n\n        Returns:\n            a dict with entries: `input_ids`, `attention_mask`, `pixel_values`, `image_attention_mask` which can be\n            directly passed to `model.generate`\n\n        Detailed explanation:\n\n        Each entry in `prompts` is either a text to be passed as is or an image that will be processed.\n\n        An image can be either an image object (`PIL.Image`) or a url from which the image can be retrieved.\n\n        When the processor encounters an image it'll inject `<fake_token_around_image><image><fake_token_around_image>`\n        entry into the prompt.\n\n        Example:\n\n        ```python\n        checkpoint = \"HuggingFaceM4/idefics-9b\"\n        processor = AutoProcessor.from_pretrained(checkpoint)\n        url = \"https://hips.hearstapps.com/hmg-prod/images/cute-photos-of-cats-in-grass-1593184777.jpg\"\n        img = processor.image_processor.fetch_images([url])[0]\n\n        prompts = [\n            \"User:\",\n            
img,\n            \"Describe this image.\\nAssistant: An image of two kittens in grass.\\n\",\n            \"User:\",\n            \"https://hips.hearstapps.com/hmg-prod/images/dog-puns-1581708208.jpg\",\n            \"Describe this image.\\nAssistant:\",\n        ]\n\n        inputs = processor(prompts, return_tensors=\"pt\")\n        generated_ids = model.generate(**inputs, max_length=100)\n        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]\n        ```\n\n        In this example the `prompts` will be converted into:\n\n        ```\n        <s>User:<fake_token_around_image><image><fake_token_around_image>Describe this image.\n        Assistant: An image of two kittens in grass.\n        User:<fake_token_around_image><image><fake_token_around_image>Describe this image.\n        Assistant:'\n        ```\n\n        and the two images will be massaged using [`IdeficsImageProcessor.__call__`] method and placed inside the\n        `pixel_values` dict entry of the return value.\n\n        This example also exemplifies that images can be passed as objects or as text urls. 
It can be seen that the\n        first image is passed as object and the second one as a url.\n\n        To do training do:\n\n        ```python\n        image_transform = transforms.Compose(\n            [\n                transforms.RandomResizedCrop(\n                    (w, h), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC\n                ),\n                transforms.ToTensor(),\n                transforms.Normalize(mean=self.image_mean, std=self.image_std),\n            ]\n        )\n        inputs = processor(prompts, transform=image_transform, return_tensors=\"pt\")\n        ```\n\n        In order to help debug prompt generation enable `debug=True` which will show you what's happening.\n\n        \"\"\"\n\n        # if the value isn't overriden by the user, check if the tokenizer was trained with this token and then use it\n        if add_end_of_utterance_token is None:\n            add_end_of_utterance_token = (\n                self.tokenizer_was_trained_with_end_of_utterance_token\n            )\n\n        # turn non-batched prompts into batched\n        if not any(isinstance(i, list) for i in prompts):\n            prompts = [prompts]\n\n        fake_token = \"<fake_token_around_image>\"\n        image_token = \"<image>\"\n        end_of_utterance_token = \"<end_of_utterance>\"\n\n        def image_tokens(last_was_image):\n            if last_was_image:\n                return image_token + fake_token\n            else:\n                return fake_token + image_token + fake_token\n\n        all_texts = []\n        all_images = []\n        for sample in prompts:\n            # the model was trained on samples starting with <s>\n            full_text = f\"{self.tokenizer.bos_token}\"\n\n            # an image can either be an image object in the item or the url, everything else is a verbatim prompt text\n            image_objects = []\n            last_was_image = False\n            last_was_text = False\n            for i, item 
in enumerate(sample):\n                if i > 0:\n                    last_was_text = True if not last_was_image else False\n\n                if isinstance(item, str):\n                    item = item.strip(\" \")\n                    if is_image(item):\n                        image = self.image_processor.fetch_images(item)\n                        full_text += image_tokens(last_was_image)\n                        image_objects.append(image)\n                        last_was_image = True\n                    else:\n                        # we add end_of_utterance_token between each subsequent text prompts (but not at the last one!)\n                        if add_end_of_utterance_token and last_was_text:\n                            full_text += end_of_utterance_token\n                        full_text += item\n                        last_was_image = False\n                else:\n                    # must be an image obj\n                    full_text += image_tokens(last_was_image)\n                    image_objects.append(item)\n                    last_was_image = True\n\n            if add_eos_token:\n                full_text += self.tokenizer.eos_token\n\n            if debug is True:\n                print(f\"{full_text=}\")\n\n            image_objects = self.image_processor(image_objects, transform=transform)\n\n            text_encoding = self.tokenizer(\n                text=full_text,\n                add_special_tokens=False,\n                padding=padding,\n                truncation=truncation,\n                max_length=max_length,\n            )\n\n            all_texts.append(text_encoding[\"input_ids\"])\n            all_images.append(image_objects)\n\n        max_seq_len = max(len(x) for x in all_texts)\n\n        # max_num_images has to be at least 1 even when there are no images\n        max_num_images = max(len(x) for x in all_images)\n        max_num_images = max(1, max_num_images)\n\n        at_least_one_image = sum(len(x) for x in 
all_images) > 0\n        output_input_ids = []\n        output_images = []\n        output_attention_masks = []\n        for text, images in zip(all_texts, all_images):\n            padded_input_ids = [self.tokenizer.pad_token_id] * max_seq_len\n            unpadded_seq_len = len(text)\n            start = max_seq_len - unpadded_seq_len\n            padded_input_ids[start:] = text[:max_seq_len]\n\n            attention_mask = torch.zeros((max_seq_len,), dtype=torch.long)\n            attention_mask[start:] = 1\n\n            image_count = padded_input_ids.count(self.image_token_id)\n            local_max_num_images = min(image_count, max_num_images)\n\n            current_images = images[:local_max_num_images]\n\n            if len(current_images) > 0:\n                padded_image_tensor = torch.zeros(\n                    max_num_images, *current_images.size()[1:]\n                )\n                padded_image_tensor[: current_images.size(0)] = current_images\n            else:\n                padded_image_tensor = torch.zeros(\n                    max_num_images, *self.default_image_dims\n                )\n\n            output_images.append(padded_image_tensor)\n            output_input_ids.append(torch.tensor(padded_input_ids))\n\n            output_attention_masks.append(attention_mask)\n\n        output_input_ids = torch.stack(output_input_ids)\n        output_images = torch.stack(output_images)\n        output_attention_masks = torch.stack(output_attention_masks)\n\n        if at_least_one_image:\n            image_attention_mask, _ = image_attention_mask_for_packed_input_ids(\n                output_input_ids, self.tokenizer\n            )\n            image_attention_mask = incremental_to_binary_attention_mask(\n                image_attention_mask, num_classes=max_num_images\n            )\n        else:\n            # in full language mode we set the image mask to all-0s\n            image_attention_mask = torch.zeros(\n                
output_input_ids.shape[0],\n                output_input_ids.shape[1],\n                1,\n                dtype=torch.bool,\n            )\n\n        return BatchFeature(\n            data={\n                \"input_ids\": output_input_ids,\n                \"attention_mask\": output_attention_masks,\n                \"pixel_values\": output_images,\n                \"image_attention_mask\": image_attention_mask,\n            }\n        )\n\n    def batch_decode(self, *args, **kwargs):\n        \"\"\"\n        This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please\n        refer to the docstring of this method for more information.\n        \"\"\"\n        return self.tokenizer.batch_decode(*args, **kwargs)\n\n    def decode(self, *args, **kwargs):\n        \"\"\"\n        This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to\n        the docstring of this method for more information.\n        \"\"\"\n        return self.tokenizer.decode(*args, **kwargs)\n\n    @property\n    def model_input_names(self):\n        tokenizer_input_names = self.tokenizer.model_input_names\n        image_processor_input_names = self.image_processor.model_input_names\n        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/idefics_vision.py",
    "content": "# coding=utf-8\n# Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch IdeficsVision model: a copy of CLIPVisionModel using a simpler config object\"\"\"\n\n\nfrom dataclasses import dataclass\nfrom typing import Optional, Tuple, Union\n\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\n\nfrom transformers.activations import ACT2FN\nfrom transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling\nfrom transformers.utils import (\n    ModelOutput,\n    logging,\n)\nfrom text_generation_server.layers import (\n    TensorParallelColumnLinear,\n    TensorParallelRowLinear,\n    TensorParallelEmbedding,\n)\n\nlogger = logging.get_logger(__name__)\n\n\n@dataclass\nclass IdeficsVisionModelOutput(ModelOutput):\n    \"\"\"\n    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.\n\n    Args:\n        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):\n            The image embeddings obtained by applying the projection layer to the pooler_output.\n        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):\n            Sequence of hidden-states at the output of the last layer of the model.\n        hidden_states 
(`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):\n            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +\n            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.\n\n            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.\n        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):\n            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,\n            sequence_length)`.\n\n            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention\n            heads.\n    \"\"\"\n\n    image_embeds: Optional[torch.FloatTensor] = None\n    last_hidden_state: torch.FloatTensor = None\n    hidden_states: Optional[Tuple[torch.FloatTensor]] = None\n    attentions: Optional[Tuple[torch.FloatTensor]] = None\n\n\n# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Idefics\nclass IdeficsVisionEmbeddings(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.image_size = config.image_size\n        self.patch_size = config.patch_size\n\n        self.class_embedding = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.class_embedding\")\n        )\n\n        self.patch_embedding = nn.Conv2d.load_no_bias(\n            prefix=f\"{prefix}.patch_embedding\",\n            weights=weights,\n            in_channels=config.num_channels,\n            out_channels=self.embed_dim,\n            kernel_size=self.patch_size,\n            stride=self.patch_size,\n        )\n\n        
self.num_patches = (self.image_size // self.patch_size) ** 2\n        self.num_positions = self.num_patches + 1\n        self.position_embedding = TensorParallelEmbedding(\n            prefix=\"model.vision_model.embeddings.position_embedding\", weights=weights\n        )\n        self.position_ids = (\n            torch.arange(self.num_positions).expand((1, -1)).to(device=weights.device)\n        )\n\n    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:\n        batch_size = pixel_values.shape[0]\n        target_dtype = self.patch_embedding.weight.dtype\n        patch_embeds = self.patch_embedding(\n            pixel_values.to(dtype=target_dtype)\n        )  # shape = [*, width, grid, grid]\n        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)\n\n        class_embeds = self.class_embedding.expand(batch_size, 1, -1)\n        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)\n        embeddings = embeddings + self.position_embedding(self.position_ids)\n        return embeddings\n\n\n# Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->IdeficsVision\nclass IdeficsVisionAttention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.embed_dim // self.num_heads\n        if self.head_dim * self.num_heads != self.embed_dim:\n            raise ValueError(\n                f\"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:\"\n                f\" {self.num_heads}).\"\n            )\n        self.scale = self.head_dim**-0.5\n        self.dropout = config.attention_dropout\n\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` 
must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.embed_dim = self.embed_dim // weights.process_group.size()\n\n        self.k_proj = TensorParallelColumnLinear.load(\n            config, prefix=f\"{prefix}.k_proj\", weights=weights, bias=True\n        )\n        self.v_proj = TensorParallelColumnLinear.load(\n            config, prefix=f\"{prefix}.v_proj\", weights=weights, bias=True\n        )\n        self.q_proj = TensorParallelColumnLinear.load(\n            config, prefix=f\"{prefix}.q_proj\", weights=weights, bias=True\n        )\n        self.out_proj = TensorParallelRowLinear.load(\n            config, prefix=f\"{prefix}.out_proj\", weights=weights, bias=True\n        )\n\n    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n        return (\n            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)\n            .transpose(1, 2)\n            .contiguous()\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        causal_attention_mask: Optional[torch.Tensor] = None,\n        output_attentions: Optional[bool] = False,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        \"\"\"Input shape: Batch x Time x Channel\"\"\"\n\n        bsz, tgt_len, _ = hidden_states.size()\n\n        # get query proj\n        query_states = self.q_proj(hidden_states) * self.scale\n        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)\n        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)\n\n        proj_shape = (bsz * self.num_heads, -1, self.head_dim)\n        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)\n        key_states = key_states.view(*proj_shape)\n        
value_states = value_states.view(*proj_shape)\n\n        src_len = key_states.size(1)\n        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))\n\n        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is\"\n                f\" {attn_weights.size()}\"\n            )\n\n        # apply the causal_attention_mask first\n        if causal_attention_mask is not None:\n            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is\"\n                    f\" {causal_attention_mask.size()}\"\n                )\n            attn_weights = (\n                attn_weights.view(bsz, self.num_heads, tgt_len, src_len)\n                + causal_attention_mask\n            )\n            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)\n\n        if attention_mask is not None:\n            if attention_mask.size() != (bsz, 1, tgt_len, src_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}\"\n                )\n            attn_weights = (\n                attn_weights.view(bsz, self.num_heads, tgt_len, src_len)\n                + attention_mask\n            )\n            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)\n\n        attn_weights = nn.functional.softmax(attn_weights, dim=-1)\n\n        if output_attentions:\n            # this operation is a bit awkward, but it's required to\n            # make sure that attn_weights keeps its gradient.\n            # In order to do so, attn_weights have to be reshaped\n            # twice and have to be reused in the following\n            attn_weights_reshaped = 
attn_weights.view(\n                bsz, self.num_heads, tgt_len, src_len\n            )\n            attn_weights = attn_weights_reshaped.view(\n                bsz * self.num_heads, tgt_len, src_len\n            )\n        else:\n            attn_weights_reshaped = None\n\n        attn_probs = nn.functional.dropout(\n            attn_weights, p=self.dropout, training=self.training\n        )\n\n        attn_output = torch.bmm(attn_probs, value_states)\n\n        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)\n        attn_output = attn_output.transpose(1, 2)\n        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)\n\n        attn_output = self.out_proj(attn_output)\n\n        return attn_output, attn_weights_reshaped\n\n\n# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->IdeficsVision\nclass IdeficsVisionMLP(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        self.activation_fn = ACT2FN[config.hidden_act]\n        self.fc1 = TensorParallelColumnLinear.load(\n            config, prefix=f\"{prefix}.fc1\", weights=weights, bias=True\n        )\n        self.fc2 = TensorParallelRowLinear.load(\n            config, prefix=f\"{prefix}.fc2\", weights=weights, bias=True\n        )\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = self.activation_fn(hidden_states)\n        hidden_states = self.fc2(hidden_states)\n        return hidden_states\n\n\n# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->IdeficsVision\nclass 
IdeficsVisionEncoderLayer(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.embed_dim = config.hidden_size\n        self.self_attn = IdeficsVisionAttention(\n            prefix=f\"{prefix}.self_attn\", config=config, weights=weights\n        )\n        self.layer_norm1 = nn.LayerNorm.load(\n            prefix=f\"{prefix}.layer_norm1\", weights=weights, eps=config.layer_norm_eps\n        )\n        self.mlp = IdeficsVisionMLP(\n            prefix=f\"{prefix}.mlp\", config=config, weights=weights\n        )\n        self.layer_norm2 = nn.LayerNorm.load(\n            prefix=f\"{prefix}.layer_norm2\", weights=weights, eps=config.layer_norm_eps\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: torch.Tensor,\n        causal_attention_mask: torch.Tensor,\n        output_attentions: Optional[bool] = False,\n    ) -> Tuple[torch.FloatTensor]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`): attention mask of size\n                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n                `(config.encoder_attention_heads,)`.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under\n                returned tensors for more detail.\n        \"\"\"\n        residual = hidden_states\n\n        hidden_states = self.layer_norm1(hidden_states)\n        hidden_states, attn_weights = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            causal_attention_mask=causal_attention_mask,\n            output_attentions=output_attentions,\n        )\n        hidden_states = residual + hidden_states\n\n        residual = hidden_states\n        hidden_states = self.layer_norm2(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        outputs = (hidden_states,)\n\n        if output_attentions:\n            outputs += (attn_weights,)\n\n        return outputs\n\n\n# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->IdeficsVision\nclass IdeficsVisionEncoder(nn.Module):\n    \"\"\"\n    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. 
Each layer is a\n    [`IdeficsVisionEncoderLayer`].\n\n    Args:\n        config: IdeficsVisionConfig\n    \"\"\"\n\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        self.layers = nn.ModuleList(\n            [\n                IdeficsVisionEncoderLayer(\n                    prefix=f\"{prefix}.encoder.layers.{layer_id}\",\n                    config=config,\n                    weights=weights,\n                )\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n        # self.gradient_checkpointing = False\n\n    def forward(\n        self,\n        inputs_embeds,\n        attention_mask: Optional[torch.Tensor] = None,\n        causal_attention_mask: Optional[torch.Tensor] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, BaseModelOutput]:\n        r\"\"\"\n        Args:\n            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):\n                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.\n                This is useful if you want more control over how to convert `input_ids` indices into associated vectors\n                than the model's internal embedding lookup matrix.\n            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n                - 1 for tokens that are **not masked**,\n                - 0 for tokens that are **masked**.\n\n                [What are attention masks?](../glossary#attention-mask)\n            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Causal mask for the text model. 
Mask values selected in `[0, 1]`:\n\n                - 1 for tokens that are **not masked**,\n                - 0 for tokens that are **masked**.\n\n                [What are attention masks?](../glossary#attention-mask)\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n            output_hidden_states (`bool`, *optional*):\n                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors\n                for more detail.\n            return_dict (`bool`, *optional*):\n                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n        \"\"\"\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        encoder_states = () if output_hidden_states else None\n        all_attentions = () if output_attentions else None\n\n        hidden_states = inputs_embeds\n        for idx, encoder_layer in enumerate(self.layers):\n            if output_hidden_states:\n                encoder_states = encoder_states + (hidden_states,)\n            # if self.gradient_checkpointing and self.training:\n\n            #     def create_custom_forward(module):\n            #         def custom_forward(*inputs):\n            #             return module(*inputs, output_attentions)\n\n            #         return custom_forward\n\n            #     layer_outputs = torch.utils.checkpoint.checkpoint(\n            #         
create_custom_forward(encoder_layer),\n            #         hidden_states,\n            #         attention_mask,\n            #         causal_attention_mask,\n            #     )\n            # else:\n            layer_outputs = encoder_layer(\n                hidden_states,\n                attention_mask,\n                causal_attention_mask,\n                output_attentions=output_attentions,\n            )\n\n            hidden_states = layer_outputs[0]\n\n            if output_attentions:\n                all_attentions = all_attentions + (layer_outputs[1],)\n\n        if output_hidden_states:\n            encoder_states = encoder_states + (hidden_states,)\n\n        if not return_dict:\n            return tuple(\n                v\n                for v in [hidden_states, encoder_states, all_attentions]\n                if v is not None\n            )\n        return BaseModelOutput(\n            last_hidden_state=hidden_states,\n            hidden_states=encoder_states,\n            attentions=all_attentions,\n        )\n\n\n# Adapted from transformers.models.clip.modeling_clip.CLIPVisionTransformer\nclass IdeficsVisionTransformer(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n\n        self.embeddings = IdeficsVisionEmbeddings(\n            prefix=f\"{prefix}.embeddings\", config=config, weights=weights\n        )\n        self.pre_layrnorm = nn.LayerNorm.load(\n            prefix=f\"{prefix}.pre_layrnorm\", weights=weights, eps=config.layer_norm_eps\n        )\n        self.encoder = IdeficsVisionEncoder(\n            prefix=prefix, config=config, weights=weights\n        )\n        self.post_layernorm = nn.LayerNorm.load(\n            prefix=f\"{prefix}.post_layernorm\",\n            weights=weights,\n            eps=config.layer_norm_eps,\n        )\n\n    # copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward\n    def forward(\n        self,\n    
    pixel_values: Optional[torch.FloatTensor] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, BaseModelOutputWithPooling]:\n        r\"\"\"\n        Returns:\n\n        \"\"\"\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        if pixel_values is None:\n            raise ValueError(\"You have to specify pixel_values\")\n\n        hidden_states = self.embeddings(pixel_values)\n        hidden_states = self.pre_layrnorm(hidden_states)\n\n        encoder_outputs = self.encoder(\n            inputs_embeds=hidden_states,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        last_hidden_state = encoder_outputs[0]\n        pooled_output = last_hidden_state[:, 0, :]\n        pooled_output = self.post_layernorm(pooled_output)\n\n        if not return_dict:\n            return (last_hidden_state, pooled_output) + encoder_outputs[1:]\n\n        return BaseModelOutputWithPooling(\n            last_hidden_state=last_hidden_state,\n            pooler_output=pooled_output,\n            hidden_states=encoder_outputs.hidden_states,\n            attentions=encoder_outputs.attentions,\n        )\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/llava_next.py",
    "content": "# coding=utf-8\n# Copyright 2024 the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch Llava-NeXT model.\"\"\"\n\nfrom typing import List, Optional, Tuple\n\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\n\nfrom transformers.activations import ACT2FN\nfrom transformers.image_processing_utils import select_best_resolution\n\nfrom text_generation_server.layers.attention import Seqlen\nfrom text_generation_server.models.custom_modeling.vlm import (\n    load_text_model,\n    load_vision_model,\n)\nfrom text_generation_server.layers import (\n    TensorParallelColumnLinear,\n    TensorParallelRowLinear,\n)\n\n\ndef get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):\n    \"\"\"\n    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.\n\n    Args:\n        image_size (`tuple`):\n            The size of the input image in the format (height, width).\n        grid_pinpoints (`List`):\n            A list containing possible resolutions. 
Each item in the list should be a tuple or list\n            of the form `(height, width)`.\n        patch_size (`int`):\n            The size of each image patch.\n\n    Returns:\n        tuple: The shape of the image patch grid in the format (height, width).\n    \"\"\"\n    if not isinstance(grid_pinpoints, list):\n        raise ValueError(\"grid_pinpoints should be a list of tuples or lists\")\n\n    height, width = select_best_resolution(image_size, grid_pinpoints)\n    return height // patch_size, width // patch_size\n\n\ndef unpad_image(tensor, original_size):\n    \"\"\"\n    Unpads a PyTorch tensor of a padded and resized image.\n\n    Args:\n        tensor (`torch.Tensor`):\n            The image tensor, assumed to be of shape (num_channels, height, width).\n        original_size (`tuple`):\n            The original size of the image (height, width).\n\n    Returns:\n        `torch.Tensor`: The unpadded image tensor.\n    \"\"\"\n    original_height, original_width = original_size\n    current_height, current_width = tensor.shape[1:]\n\n    original_aspect_ratio = original_width / original_height\n    current_aspect_ratio = current_width / current_height\n\n    if original_aspect_ratio > current_aspect_ratio:\n        scale_factor = current_width / original_width\n        new_height = int(original_height * scale_factor)\n        padding = (current_height - new_height) // 2\n        unpadded_tensor = tensor[:, padding : current_height - padding, :]\n    else:\n        scale_factor = current_height / original_height\n        new_width = int(original_width * scale_factor)\n        padding = (current_width - new_width) // 2\n        unpadded_tensor = tensor[:, :, padding : current_width - padding]\n\n    return unpadded_tensor\n\n\n# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext\nclass LlavaNextMultiModalProjector(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n\n 
       self.linear_1 = TensorParallelColumnLinear.load(\n            prefix=f\"{prefix}.linear_1\", config=config, weights=weights, bias=True\n        )\n        self.act = ACT2FN[config.projector_hidden_act]\n        self.linear_2 = TensorParallelRowLinear.load(\n            prefix=f\"{prefix}.linear_2\", config=config, weights=weights, bias=True\n        )\n\n    def forward(self, image_features):\n        hidden_states = self.linear_1(image_features)\n        hidden_states = self.act(hidden_states)\n        hidden_states = self.linear_2(hidden_states)\n        return hidden_states\n\n\nclass LlavaNextForConditionalGeneration(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        config.vision_config.quantize = config.quantize\n        vision_config = config.vision_config\n        # Instead of selecting in hidden_states[-2].\n        # Instead compute only the n -2 + 1 layers and don't pool\n        if config.vision_feature_layer < 0:\n            vision_config.num_hidden_layers += config.vision_feature_layer + 1\n        else:\n            vision_config.num_hidden_layers = config.vision_feature_layer + 1\n        self.vision_tower = load_vision_model(\n            prefix=\"vision_tower\" if not prefix else f\"{prefix}.vision_tower\",\n            config=config.vision_config,\n            weights=weights,\n        )\n\n        self.multi_modal_projector = LlavaNextMultiModalProjector(\n            prefix=\"multi_modal_projector\", config=config, weights=weights\n        )\n\n        self.image_newline = weights.get_tensor(\"image_newline\")\n\n        self.vocab_size = config.text_config.vocab_size\n        self.config = config\n        config.text_config.quantize = config.quantize\n        config.text_config.speculator = config.speculator\n        self.text_model = load_text_model(\n            prefix=\"language_model\" if not prefix else f\"{prefix}.language_model\",\n            config=config.text_config,\n           
 weights=weights,\n        )\n        self.pad_token_id = (\n            config.pad_token_id if config.pad_token_id is not None else -1\n        )\n\n    def _merge_input_ids_with_image_features(\n        self,\n        input_ids: torch.Tensor,\n        inputs_embeds: torch.Tensor,\n        image_features: torch.Tensor,\n    ):\n        \"\"\"In place merges in vision_embeddings with inputs_embeds.\"\"\"\n        mask = input_ids == self.config.image_token_index\n        # Let's pray we have enabled enough slots !\n        try:\n            inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])\n        except Exception as e:\n            raise RuntimeError(\n                f\"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens`  to handle images. If error happens at regular runtime, please fill in an issue: {e}\"\n            )\n        return inputs_embeds\n\n    def get_vision_embeds(\n        self,\n        pixel_values: torch.FloatTensor,\n        pixel_attention_mask: Optional[torch.FloatTensor] = None,\n        image_sizes: Optional[torch.Tensor] = None,\n        image_grid_thw: Optional[torch.LongTensor] = None,\n    ):\n        # num_special_image_tokens = (input_ids == self.config.image_token_index).sum()\n        # assert num_special_image_tokens == len(pixel_values), f\"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid\"\n        # 1. Extract the input embeddings\n\n        # 2. 
Merge text and images\n        num_images, num_patches, channels, height, width = pixel_values.shape\n        pixel_values = pixel_values.view(\n            num_images * num_patches, channels, height, width\n        )\n        image_features = self.vision_tower(pixel_values)\n\n        # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]\n        # Already done within the clip model\n        selected_image_feature = image_features.last_hidden_state\n\n        if self.config.vision_feature_select_strategy == \"default\":\n            selected_image_feature = selected_image_feature[:, 1:]\n        elif self.config.vision_feature_select_strategy == \"full\":\n            selected_image_feature = selected_image_feature\n        else:\n            raise RuntimeError(\n                f\"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid.\"\n            )\n\n        image_features = self.multi_modal_projector(selected_image_feature)\n\n        # split up image_features for each of the individual images\n        # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)\n        # if we assume each image has 5 image features (base image + 4 patches)\n        split_sizes = [num_patches] * num_images\n        image_features = torch.split(image_features, split_sizes, dim=0)\n\n        # NOTE we only support multimodal_patch_merge_type == \"spatial_unpad\"\n        height = width = (\n            self.config.vision_config.image_size // self.config.vision_config.patch_size\n        )\n\n        new_image_features = []\n        for image_idx, image_feature in enumerate(image_features):\n            if image_feature.shape[0] > 1:\n                base_image_feature = image_feature[0]\n                image_feature = image_feature[1:]\n\n                if height * width != base_image_feature.shape[0]:\n                    raise ValueError(\n                        \"The number of patches 
is not consistent with the image size.\"\n                    )\n\n                # Dimensions are intentionally swapped to be bug-compatible with\n                # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59\n                num_patch_width, num_patch_height = get_anyres_image_grid_shape(\n                    image_sizes[image_idx],\n                    self.config.image_grid_pinpoints,\n                    self.config.vision_config.image_size,\n                )\n                image_feature = image_feature.view(\n                    num_patch_height, num_patch_width, height, width, -1\n                )\n                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()\n                image_feature = image_feature.flatten(1, 2).flatten(2, 3)\n                image_feature = unpad_image(image_feature, image_sizes[image_idx])\n                image_feature = torch.cat(\n                    (\n                        image_feature,\n                        self.image_newline[:, None, None].expand(\n                            *image_feature.shape[:-1], 1\n                        ),\n                    ),\n                    dim=-1,\n                )\n                image_feature = image_feature.flatten(1, 2).transpose(0, 1)\n                image_feature = torch.cat((base_image_feature, image_feature), dim=0)\n            else:\n                image_feature = image_feature[0]\n                image_feature = torch.cat(\n                    (image_feature, self.image_newline[None]), dim=0\n                )\n            new_image_features.append(image_feature)\n        image_features = torch.stack(new_image_features, dim=0)\n        return image_features.view(-1, image_features.shape[-1])\n\n    def get_inputs_embeds(\n        self,\n        input_ids: torch.Tensor,\n        vision_embeds: torch.Tensor = None,\n        pixel_values: torch.FloatTensor = None,\n        image_sizes: Optional[torch.LongTensor] = None,\n    ):\n    
    inputs_embeds = self.text_model.embed_tokens(input_ids)\n\n        if vision_embeds is not None:\n            # When we generate, we don't want to replace the potential image_token_id that we generated by images\n            # that simply don't exist\n            inputs_embeds = self._merge_input_ids_with_image_features(\n                input_ids, inputs_embeds, vision_embeds\n            )\n        return inputs_embeds\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor],\n        lm_head_indices: Optional[torch.Tensor] = None,\n        # Unused for this model\n        attention_mask: Optional[torch.BoolTensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n    ):\n        hidden_states = self.text_model.model(\n            inputs_embeds=inputs_embeds,\n            position_ids=position_ids,\n            cu_seqlen_prefill=cu_seqlen_prefill,\n            kv_cache=kv_cache,\n            block_tables=block_tables,\n            slots=slots,\n            seqlen=seqlen,\n            max_s=max_s,\n            true_max_s=max_s,\n            prefill_cache_indices=None,\n            adapter_data=adapter_data,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.text_model.lm_head(hidden_states)\n        return logits, speculative_logits\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/mamba_modeling.py",
    "content": "import torch\nimport torch.distributed\n\nfrom mamba_ssm.ops.triton.selective_state_update import selective_state_update\nfrom mamba_ssm.ops.selective_scan_interface import selective_scan_fn\nfrom torch import nn\nfrom typing import Optional, Tuple, Any\nfrom transformers.configuration_utils import PretrainedConfig\nimport torch.nn.functional as F\n\nfrom text_generation_server.layers import (\n    SpeculativeHead,\n    TensorParallelEmbedding,\n    FastLinear,\n)\nfrom text_generation_server.layers.layernorm import FastRMSNorm\n\nfrom einops import rearrange\nfrom causal_conv1d import causal_conv1d_fn, causal_conv1d_update\nimport math\nfrom dataclasses import dataclass\n\n\n@dataclass\nclass InferenceParams:\n    \"\"\"Inference parameters that are passed to the main model in order\n    to efficienly calculate and store the context during inference.\"\"\"\n\n    max_seqlen: int\n    max_batch_size: int\n    conv_states: torch.Tensor\n    ssm_states: torch.Tensor\n    seqlen_offset: int\n\n\nclass MambaConfig(PretrainedConfig):\n    def __init__(\n        self,\n        vocab_size=50280,\n        d_model=768,\n        d_state=16,\n        n_layer=32,\n        layer_norm_epsilon=1e-5,\n        tie_word_embeddings=False,\n        pad_token_id=0,\n        bos_token_id=1,\n        eos_token_id=2,\n        expand=2,\n        dt_rank=\"auto\",\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.n_layer = n_layer\n        self.layer_norm_epsilon = layer_norm_epsilon\n        self.d_model = d_model\n        self.d_inner = d_model * 2\n        self.d_conv = 4\n        self.d_state = d_state\n        self.expand = expand\n        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == \"auto\" else dt_rank\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n 
       )\n\n\nclass MambaBlock(nn.Module):\n    def __init__(self, prefix, config, weights, layer_id):\n        super().__init__()\n        self.layer_id = layer_id\n        self.in_proj = FastLinear.load(config, f\"{prefix}.in_proj\", weights, bias=False)\n        self.x_proj = FastLinear.load(config, f\"{prefix}.x_proj\", weights, bias=False)\n        self.dt_proj = FastLinear.load(config, f\"{prefix}.dt_proj\", weights, bias=True)\n        self.dt_proj_no_bias = FastLinear.load(\n            config, f\"{prefix}.dt_proj\", weights, bias=False\n        )\n        self.out_proj = FastLinear.load(\n            config, f\"{prefix}.out_proj\", weights, bias=False\n        )\n        self.conv1d = FastLinear.load(config, f\"{prefix}.conv1d\", weights, bias=True)\n        self.negA = -torch.exp(weights.get_tensor(f\"{prefix}.A_log\").float())\n        self.D = weights.get_tensor(f\"{prefix}.D\")\n        self.activation = \"silu\"\n        self.dt_rank = config.dt_rank\n        self.d_state = config.d_state\n        self.d_conv = config.d_conv\n        self.act = nn.SiLU()\n\n    # inference_params\n    def forward(self, hidden_states: torch.Tensor, inference_params=None):\n        if inference_params.seqlen_offset > 0:\n            conv_state = inference_params.conv_states[self.layer_id]\n            ssm_state = inference_params.ssm_states[self.layer_id]\n            out, conv_state, ssm_state = self.step(hidden_states, conv_state, ssm_state)\n            return out, conv_state, ssm_state\n\n        _, seqlen, _ = hidden_states.shape\n        projected_states = self.in_proj(hidden_states).transpose(1, 2)\n        # assert projected_states.shape == [batch_size, 2 * dstate, seqlen], f\"{projected_states.shape} [{batch_size}, {dstate}, {seqlen}]\"\n        x, z = projected_states.chunk(2, dim=1)\n        conv_state = F.pad(x, (self.d_conv - seqlen, 0))\n        x = causal_conv1d_fn(\n            x=x,\n            weight=self.conv1d.weight.squeeze(1),\n            
bias=self.conv1d.bias,\n            activation=self.activation,\n        )\n\n        # We're careful here about the layout, to avoid extra transposes.\n        # We want dt to have d as the slowest moving dimension\n        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.\n        x_dbl = self.x_proj(rearrange(x, \"b d l -> (b l) d\"))  # (bl d)\n        dt, B, C = torch.split(\n            x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1\n        )\n        dt = self.dt_proj.weight @ dt.t()\n        dt = rearrange(dt, \"d (b l) -> b d l\", l=seqlen)\n        B = rearrange(B, \"(b l) dstate -> b dstate l\", l=seqlen).contiguous()\n        C = rearrange(C, \"(b l) dstate -> b dstate l\", l=seqlen).contiguous()\n        y, last_state = selective_scan_fn(\n            x,\n            dt,\n            self.negA,\n            B,\n            C,\n            self.D.float(),\n            z=z,\n            delta_bias=self.dt_proj.bias.float(),\n            delta_softplus=True,\n            return_last_state=True,\n        )\n        y = rearrange(y, \"b d l -> b l d\")\n        attn_outputs = self.out_proj(y)\n        return attn_outputs, conv_state, last_state\n\n    def step(self, hidden_states, conv_state, ssm_state):\n        xz = self.in_proj(hidden_states.squeeze(1))\n        x, z = xz.chunk(2, dim=-1)  # (B D)\n        x = causal_conv1d_update(\n            x,\n            conv_state,\n            self.conv1d.weight.squeeze(1),\n            self.conv1d.bias,\n            self.activation,\n        )\n        x_db = self.x_proj(x)  # (B dt_rank+2*d_state)\n        dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)\n        dt = F.linear(dt, self.dt_proj.weight)\n        A = self.negA\n        y = selective_state_update(\n            ssm_state,\n            x,\n            dt,\n            A,\n            B,\n            C,\n            self.D,\n            z=z,\n            
dt_bias=self.dt_proj.bias,\n            dt_softplus=True,\n        )\n        out = self.out_proj(y)\n        return out.unsqueeze(1), conv_state.clone(), ssm_state.clone()\n\n\nclass ResidualBlock(nn.Module):\n    def __init__(self, prefix, config, weights, layer_id):\n        super().__init__()\n        self.mamba_block = MambaBlock(\n            prefix=f\"{prefix}.mixer\", config=config, weights=weights, layer_id=layer_id\n        )\n        self.layer_norm = FastRMSNorm.load(\n            prefix=f\"{prefix}.norm\", weights=weights, eps=config.layer_norm_epsilon\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        residual: Optional[torch.Tensor] = None,\n        inference_params: Optional[Any] = None,\n    ):\n        residual = (hidden_states + residual) if residual is not None else hidden_states\n        shape = residual.shape\n        hidden_states, _ = self.layer_norm(residual.view(-1, shape[-1]))\n        hidden_states, conv_state, last_ssm_state = self.mamba_block(\n            hidden_states.view(*shape), inference_params\n        )\n        return hidden_states, residual, conv_state, last_ssm_state\n\n\nclass MambaModel(nn.Module):\n    def __init__(self, config, weights):\n        super().__init__()\n        prefix = \"backbone\"\n        try:\n            self.embed_tokens = TensorParallelEmbedding(f\"{prefix}.embeddings\", weights)\n        except RuntimeError:\n            self.embed_tokens = TensorParallelEmbedding(f\"{prefix}.embedding\", weights)\n        self.blocks = nn.ModuleList(\n            [\n                ResidualBlock(f\"{prefix}.layers.{i}\", config, weights, layer_id=i)\n                for i in range(config.n_layer)\n            ]\n        )\n        self.norm_f = FastRMSNorm.load(\n            f\"{prefix}.norm_f\", weights, eps=config.layer_norm_epsilon\n        )\n        try:\n            self.lm_head = SpeculativeHead.load(config, f\"{prefix}.embeddings\", weights)\n        except 
RuntimeError:\n            self.lm_head = SpeculativeHead.load(config, f\"{prefix}.embedding\", weights)\n        self.config = config\n\n    def forward(\n        self, input_ids: torch.Tensor, inference_params=None, residual=None\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        hidden_states = self.embed_tokens(input_ids)\n        for i, block in enumerate(self.blocks):\n            hidden_states, residual, conv_state, ssm_state = block(\n                hidden_states, residual, inference_params\n            )\n            inference_params.conv_states[i].copy_(conv_state)\n            inference_params.ssm_states[i].copy_(ssm_state)\n\n        hidden_states = (\n            hidden_states + residual if residual is not None else hidden_states\n        )\n        hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1)))\n        hidden_states = hidden_states.view(residual.shape)\n        logits, speculative_logits = self.lm_head(hidden_states)\n\n        # update the offset for the next inference using these params\n        inference_params.seqlen_offset += input_ids.size(1)\n        return logits, speculative_logits\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/mllama.py",
    "content": "# coding=utf-8\n# Copyright 2024 the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch Mllama model.\"\"\"\n\nfrom typing import Optional, Tuple, List\n\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom text_generation_server.utils.import_utils import SYSTEM\n\nif SYSTEM == \"ipex\":\n    import intel_extension_for_pytorch as ipex\nelse:\n    import flash_attn_2_cuda\n\nfrom transformers.activations import ACT2FN\nimport torch.nn.functional as F\n\nfrom text_generation_server.layers import (\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    TensorParallelRowLinear,\n    FastLinear,\n)\nfrom text_generation_server.layers.attention import (\n    Seqlen,\n)\nfrom text_generation_server.models.custom_modeling.flash_llama_modeling import (\n    FlashLlamaForCausalLM,\n)\n\n\ndef _prepare_aspect_ratio_attention_mask(\n    aspect_ratio_mask: torch.Tensor,\n    num_patches: int,\n    target_length: int,\n    dtype: torch.dtype,\n) -> torch.Tensor:\n    # Expand aspect ratio mask to target_length\n    batch_size, max_num_tiles = aspect_ratio_mask.shape\n    attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype)\n    attention_mask = attention_mask.repeat(1, 1, target_length, 1)\n\n    # Mask padding patches\n    pad_patches = target_length - num_patches\n    attention_mask[:, :, -pad_patches:] = 0\n\n    # Invert the mask (0 -> 1, 
1 -> 0)\n    attention_mask = 1 - attention_mask\n\n    # Reshape to 2D and create 4D attention mask\n    # (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length)\n    attention_mask = attention_mask.reshape(\n        batch_size, max_num_tiles * target_length, 1\n    )\n    attention_mask = (\n        attention_mask @ attention_mask.transpose(-1, -2) * torch.finfo(dtype).min\n    )\n    attention_mask = attention_mask.unsqueeze(1)\n\n    return attention_mask\n\n\n# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position\ndef _prepare_4d_causal_attention_mask_with_cache_position(\n    attention_mask: torch.Tensor,\n    sequence_length: int,\n    target_length: int,\n    dtype: torch.dtype,\n    device: torch.device,\n    min_dtype: float,\n    cache_position: torch.Tensor,\n    batch_size: int,\n):\n    \"\"\"\n    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape\n    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.\n\n    Args:\n        attention_mask (`torch.Tensor`):\n            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.\n        sequence_length (`int`):\n            The sequence length being processed.\n        target_length (`int`):\n            The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.\n        dtype (`torch.dtype`):\n            The dtype to use for the 4D attention mask.\n        device (`torch.device`):\n            The device to place the 4D attention mask on.\n        min_dtype (`float`):\n            The minimum value representable with the dtype `dtype`.\n        cache_position (`torch.Tensor`):\n            Indices depicting the position of 
the input sequence tokens in the sequence.\n        batch_size (`torch.Tensor`):\n            Batch size.\n    \"\"\"\n    if attention_mask is not None and attention_mask.dim() == 4:\n        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.\n        causal_mask = attention_mask\n    else:\n        causal_mask = torch.full(\n            (sequence_length, target_length),\n            fill_value=min_dtype,\n            dtype=dtype,\n            device=device,\n        )\n        if sequence_length != 1:\n            causal_mask = torch.triu(causal_mask, diagonal=1)\n        causal_mask *= torch.arange(\n            target_length, device=device\n        ) > cache_position.reshape(-1, 1)\n        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)\n        if attention_mask is not None:\n            causal_mask = (\n                causal_mask.clone()\n            )  # copy to contiguous memory for in-place edit\n            mask_length = attention_mask.shape[-1]\n            padding_mask = (\n                causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]\n            )\n            padding_mask = padding_mask == 0\n            causal_mask[:, :, :, :mask_length] = causal_mask[\n                :, :, :, :mask_length\n            ].masked_fill(padding_mask, min_dtype)\n\n    return causal_mask\n\n\ndef _prepare_cross_attention_mask(\n    cross_attention_mask: torch.Tensor,\n    num_vision_tokens: int,\n    dtype: str,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    # reshape so it can be used by attn module\n    batch_size, text_total_length, *_ = cross_attention_mask.shape\n    cross_attention_mask = cross_attention_mask.repeat_interleave(\n        num_vision_tokens, dim=3\n    )\n    cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1)\n    cross_attention_mask = cross_attention_mask.unsqueeze(1)\n\n    # invert the mask\n    
inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype)\n    cross_attention_mask = inverted_cross_attn_mask.masked_fill(\n        inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min\n    )\n\n    # apply full-row bias, which returns a 4D tensor of shape [B, H, S1, 1] where the value is 0 if a full row in the cross attn mask's\n    # last dimension contains only negative infinity values, otherwise it's 1\n    negative_inf_value = torch.finfo(dtype).min\n    full_text_row_masked_out_mask = (\n        (cross_attention_mask != negative_inf_value)\n        .any(dim=-1)\n        .type_as(cross_attention_mask)[..., None]\n    )\n    cross_attention_mask *= full_text_row_masked_out_mask\n\n    return cross_attention_mask, full_text_row_masked_out_mask\n\n\n# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision\nclass MllamaVisionMLP(nn.Module):\n    def __init__(self, *, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        self.activation_fn = ACT2FN[config.hidden_act]\n        self.fc1 = TensorParallelColumnLinear.load(\n            prefix=f\"{prefix}.fc1\", weights=weights, config=config, bias=True\n        )\n        self.fc2 = TensorParallelRowLinear.load(\n            prefix=f\"{prefix}.fc2\", weights=weights, config=config, bias=True\n        )\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = self.activation_fn(hidden_states)\n        hidden_states = self.fc2(hidden_states)\n        return hidden_states\n\n\nclass MllamaVisionSdpaAttention(nn.Module):\n    def __init__(self, *, prefix, config, weights):\n        super().__init__()\n\n        self.embed_dim = config.hidden_size\n        self.head_dim = config.hidden_size // config.attention_heads\n        self.num_heads = config.attention_heads // weights.process_group.size()\n\n        self.qkv_proj = TensorParallelColumnLinear.load_multi(\n   
         config,\n            prefixes=[f\"{prefix}.q_proj\", f\"{prefix}.k_proj\", f\"{prefix}.v_proj\"],\n            dim=0,\n            weights=weights,\n            bias=False,\n        )\n        self.o_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.o_proj\",\n            weights=weights,\n            bias=False,\n        )\n\n    def forward(\n        self,\n        hidden_state: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        qkv = self.qkv_proj(hidden_state)\n        query, key, value = qkv.split(\n            [\n                self.head_dim * self.num_heads,\n                self.head_dim * self.num_heads,\n                self.head_dim * self.num_heads,\n            ],\n            dim=2,\n        )\n\n        batch_size, q_seq_len, _ = query.shape\n        _, kv_seq_len, _ = key.shape\n\n        query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim)\n        key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)\n        value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)\n\n        query = query.transpose(1, 2)\n        key = key.transpose(1, 2)\n        value = value.transpose(1, 2)\n\n        attn_output = F.scaled_dot_product_attention(\n            query, key, value, attn_mask=attention_mask\n        )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.reshape(batch_size, q_seq_len, -1)\n\n        output = self.o_proj(attn_output)\n        return output\n\n\nclass MllamaVisionEncoderLayer(nn.Module):\n    def __init__(self, *, prefix, config, weights, is_gated: bool):\n        super().__init__()\n\n        self.hidden_size = config.hidden_size\n        self.num_attention_heads = config.attention_heads\n        self.is_gated = is_gated\n        self.intermediate_size = config.intermediate_size\n\n        self.self_attn = MllamaVisionSdpaAttention(\n 
           prefix=f\"{prefix}.self_attn\", config=config, weights=weights\n        )\n        self.mlp = MllamaVisionMLP(\n            prefix=f\"{prefix}.mlp\", config=config, weights=weights\n        )\n\n        self.input_layernorm = nn.LayerNorm.load(\n            prefix=f\"{prefix}.input_layernorm\", weights=weights, eps=1e-05\n        )\n        self.post_attention_layernorm = nn.LayerNorm.load(\n            prefix=f\"{prefix}.post_attention_layernorm\", weights=weights, eps=1e-05\n        )\n\n        # there used to be an if else here, no code path\n        if is_gated:\n            self.gate_attn = nn.Parameter(\n                weights.get_tensor(f\"{prefix}.gate_attn\"), requires_grad=False\n            )\n            self.gate_ffn = nn.Parameter(\n                weights.get_tensor(f\"{prefix}.gate_ffn\"), requires_grad=False\n            )\n\n    def forward(\n        self,\n        hidden_state: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n    ):\n        # Self Attention\n        residual = hidden_state\n        hidden_state = self.input_layernorm(hidden_state)\n        hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask)\n        gate_attn = 1 if not self.is_gated else self.gate_attn.tanh()\n        hidden_state = residual + gate_attn * hidden_state\n\n        # Feed forward\n        residual = hidden_state\n        hidden_state = self.post_attention_layernorm(hidden_state)\n        hidden_state = self.mlp(hidden_state)\n        gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh()\n        hidden_state = residual + gate_ffn * hidden_state\n        return hidden_state\n\n\nclass MllamaVisionEncoder(nn.Module):\n    def __init__(self, *, prefix, config, weights, is_gated: bool, num_layers: int):\n        super().__init__()\n        self.config = config\n        self.layers = [\n            MllamaVisionEncoderLayer(\n                prefix=f\"{prefix}.layers.{i}\",\n                config=config,\n 
               weights=weights,\n                is_gated=is_gated,\n            )\n            for i in range(num_layers)\n        ]\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n    ):\n        encoder_states = [hidden_states]\n        for encoder_layer in self.layers:\n            layer_outputs = encoder_layer(\n                hidden_states,\n                attention_mask,\n            )\n\n            hidden_states = layer_outputs\n            encoder_states.append(hidden_states)\n\n        return hidden_states, encoder_states\n\n\nclass MllamaPrecomputedAspectRatioEmbedding(nn.Module):\n    def __init__(self, *, prefix, config, weights):\n        super().__init__()\n        self.max_num_tiles = config.max_num_tiles\n        self.hidden_size = config.hidden_size\n        self.max_aspect_ratio_id = config.max_aspect_ratio_id\n\n        self.embedding = TensorParallelEmbedding(\n            prefix=f\"{prefix}.embedding\", weights=weights\n        )\n        self.gate = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.gate\"), requires_grad=False\n        )\n\n    def forward(\n        self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor\n    ) -> torch.Tensor:\n        embeddings = self.embedding(aspect_ratio_ids)\n        embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size)\n\n        # Always gated.\n        embeddings = embeddings * self.gate.tanh()\n\n        hidden_state = hidden_state + embeddings\n        return hidden_state\n\n\nclass MllamaPrecomputedPositionEmbedding(nn.Module):\n    def __init__(self, *, prefix, config, weights):\n        super().__init__()\n        self.max_num_tiles = config.max_num_tiles\n        self.max_aspect_ratio_id = config.max_aspect_ratio_id\n        self.num_patches = (config.image_size // config.patch_size) ** 2 + 1\n        self.hidden_size = config.hidden_size\n        self.scale = 
config.hidden_size**-0.5\n\n        self.gate = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.gate\"), requires_grad=False\n        )\n\n        # position embedding\n        embedding = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.embedding\"), requires_grad=False\n        )\n        self.gated_position_embedding = (1 - self.gate.tanh()) * embedding\n        self.tile_embedding = TensorParallelEmbedding(\n            prefix=f\"{prefix}.tile_embedding\", weights=weights\n        )\n\n    def forward(\n        self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor\n    ) -> torch.Tensor:\n        # position embeddings\n        hidden_state = hidden_state + self.gated_position_embedding.view(\n            1, 1, self.num_patches, self.hidden_size\n        )\n\n        # precomputed tile position embeddings\n        tile_position_embedding = self.tile_embedding(aspect_ratio_ids)\n        batch_size = hidden_state.shape[0]\n        tile_position_embedding = tile_position_embedding.reshape(\n            batch_size, self.max_num_tiles, self.num_patches, self.hidden_size\n        )\n        gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding\n        hidden_state = hidden_state + gated_tile_position_embedding\n\n        return hidden_state\n\n\nclass MllamaVisionModel(nn.Module):\n    def __init__(self, *, prefix, config, weights):\n        super().__init__()\n        self.image_size = config.image_size\n        self.patch_size = config.patch_size\n        self.max_num_tiles = config.max_num_tiles\n        self.hidden_size = config.hidden_size\n        self.num_channels = config.num_channels\n        self.intermediate_layers_indices = config.intermediate_layers_indices\n\n        self.num_patches = (self.image_size // self.patch_size) ** 2 + 1\n        self.scale = config.hidden_size**-0.5\n        self.dtype = weights.dtype\n\n        self.patch_embedding = nn.Conv2d(\n            in_channels=config.num_channels,\n 
           out_channels=self.hidden_size,\n            kernel_size=self.patch_size,\n            stride=self.patch_size,\n            padding=\"valid\",\n            bias=False,\n        )\n        self.patch_embedding.weight = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.patch_embedding.weight\"), requires_grad=False\n        )\n\n        self.class_embedding = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.class_embedding\"), requires_grad=False\n        )\n\n        self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(\n            prefix=f\"{prefix}.gated_positional_embedding\",\n            config=config,\n            weights=weights,\n        )\n\n        self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(\n            prefix=f\"{prefix}.pre_tile_positional_embedding\",\n            config=config,\n            weights=weights,\n        )\n        self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(\n            prefix=f\"{prefix}.post_tile_positional_embedding\",\n            config=config,\n            weights=weights,\n        )\n\n        ## layer norms\n        self.layernorm_pre = nn.LayerNorm.load(\n            prefix=f\"{prefix}.layernorm_pre\",\n            weights=weights,\n            # torch default\n            eps=1e-05,\n        )\n        self.layernorm_post = nn.LayerNorm.load(\n            prefix=f\"{prefix}.layernorm_post\",\n            weights=weights,\n            # torch default\n            eps=1e-05,\n        )\n\n        ## encoders\n        self.transformer = MllamaVisionEncoder(\n            prefix=f\"{prefix}.transformer\",\n            config=config,\n            weights=weights,\n            is_gated=False,\n            num_layers=config.num_hidden_layers,\n        )\n        self.global_transformer = MllamaVisionEncoder(\n            prefix=f\"{prefix}.global_transformer\",\n            config=config,\n            weights=weights,\n           
 is_gated=True,\n            num_layers=config.num_global_layers,\n        )\n\n    def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:\n        batch_size, _, hidden_size = hidden_state.shape\n        class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)\n        hidden_state = torch.cat([class_embedding, hidden_state], dim=1)\n        return hidden_state\n\n    def forward(\n        self,\n        pixel_values: torch.Tensor,\n        aspect_ratio_ids: torch.Tensor,\n        attention_mask: torch.Tensor,\n    ) -> torch.Tensor:\n        (\n            batch_size,\n            num_concurrent_media,\n            num_tiles,\n            num_channels,\n            height,\n            width,\n        ) = pixel_values.shape\n\n        pixel_values = pixel_values.reshape(\n            batch_size * num_concurrent_media * num_tiles, num_channels, height, width\n        )\n        aspect_ratio_ids = aspect_ratio_ids.reshape(\n            batch_size * num_concurrent_media, -1\n        )\n\n        # patch embedding\n        patch_embeds = self.patch_embedding(pixel_values)\n        hidden_state = patch_embeds.flatten(2).transpose(1, 2)\n\n        # tile embeddings\n        _, num_patches, dim = hidden_state.shape\n        hidden_state = hidden_state.reshape(\n            batch_size * num_concurrent_media, num_tiles, -1, dim\n        )\n        hidden_state = self.pre_tile_positional_embedding(\n            hidden_state, aspect_ratio_ids\n        )\n\n        # apply cls token\n        hidden_state = hidden_state.reshape(\n            batch_size * num_concurrent_media * num_tiles, num_patches, dim\n        )\n        hidden_state = self.apply_class_embedding(hidden_state)\n        num_patches += 1\n\n        # apply position embeddings\n        hidden_state = hidden_state.reshape(\n            batch_size * num_concurrent_media, num_tiles, num_patches, dim\n        )\n        hidden_state = 
self.gated_positional_embedding(hidden_state, aspect_ratio_ids)\n\n        # apply encoder\n        hidden_state = self.layernorm_pre(hidden_state)\n\n        # Compute the number of tokens to pad\n        num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8\n        # Compute padding tuple for pad function\n        padding = (\n            0,\n            0,\n            0,\n            num_padding_patches,\n        )  # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)\n        # Pad the tensor\n        hidden_state = F.pad(hidden_state, padding, mode=\"constant\", value=0)\n        slice_index = -num_padding_patches if num_padding_patches > 0 else None\n\n        if attention_mask is not None:\n            attention_mask = attention_mask.reshape(\n                batch_size * num_concurrent_media, -1\n            )\n            attention_mask = _prepare_aspect_ratio_attention_mask(\n                aspect_ratio_mask=attention_mask,\n                num_patches=self.num_patches,\n                target_length=hidden_state.shape[2],\n                dtype=self.dtype,\n            )\n\n        hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim)\n        hidden_state, all_intermediate_hidden_states = self.transformer(\n            hidden_state,\n            attention_mask=attention_mask,\n        )\n        intermediate_hidden_states = [\n            hidden_state\n            for idx, hidden_state in enumerate(all_intermediate_hidden_states)\n            if idx in self.intermediate_layers_indices\n        ]\n        intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1)\n\n        # apply global encoder\n        hidden_state = self.layernorm_post(hidden_state)\n        hidden_state = hidden_state.reshape(\n            batch_size * num_concurrent_media,\n            num_tiles,\n            num_patches + num_padding_patches,\n            dim,\n        )\n        hidden_state = 
self.post_tile_positional_embedding(\n            hidden_state, aspect_ratio_ids\n        )\n        hidden_state = hidden_state.reshape(\n            batch_size * num_concurrent_media,\n            num_tiles * (num_patches + num_padding_patches),\n            dim,\n        )\n        hidden_state, _ = self.global_transformer(\n            hidden_state, attention_mask=attention_mask\n        )\n        hidden_state = hidden_state.reshape(\n            batch_size * num_concurrent_media,\n            num_tiles,\n            num_patches + num_padding_patches,\n            dim,\n        )\n        hidden_state = hidden_state[:, :, :slice_index]\n\n        # adding intermediate layer outputs\n        hidden_state = hidden_state.reshape(\n            batch_size, num_concurrent_media, num_tiles, num_patches, dim\n        )\n        intermediate_hidden_states = intermediate_hidden_states.reshape(\n            batch_size * num_concurrent_media,\n            num_tiles,\n            num_patches + num_padding_patches,\n            -1,\n        )\n        intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index]\n        intermediate_hidden_states = intermediate_hidden_states.reshape(\n            batch_size, num_concurrent_media, num_tiles, num_patches, -1\n        )\n        hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1)\n        return hidden_state\n\n\nclass MllamaTextCrossAttention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, *, prefix, config, weights, layer_idx):\n        super().__init__()\n        self.config = config\n        self.num_heads = self.config.num_attention_heads\n        self.num_key_value_heads = self.config.num_key_value_heads\n        self.dropout = config.dropout\n        self.hidden_size = config.hidden_size\n        self.head_size = config.hidden_size // self.num_heads\n        self.num_key_value_groups = self.num_heads // 
self.num_key_value_heads\n        self.layer_idx = layer_idx\n\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.num_key_value_heads = (\n            self.num_key_value_heads // weights.process_group.size()\n        )\n\n        self.q_proj = TensorParallelColumnLinear.load(\n            config,\n            prefix=f\"{prefix}.q_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.k_proj = TensorParallelColumnLinear.load(\n            config,\n            prefix=f\"{prefix}.k_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.v_proj = TensorParallelColumnLinear.load(\n            config,\n            prefix=f\"{prefix}.v_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.o_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.o_proj\",\n            weights=weights,\n            bias=False,\n        )\n\n        self.q_norm = MllamaTextRMSNorm.load(\n            prefix=f\"{prefix}.q_norm\", weights=weights, eps=config.rms_norm_eps\n        )\n        self.k_norm = MllamaTextRMSNorm.load(\n            prefix=f\"{prefix}.k_norm\", weights=weights, eps=config.rms_norm_eps\n        )\n        self.softmax_scale = self.head_size**-0.5\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        cross_attention_states: Optional[torch.Tensor] = None,\n        # past_key_value=None,\n        # attention_mask: Optional[torch.Tensor] = None,\n        # cache_position: Optional[torch.LongTensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        \"\"\"Input shape: Batch x Time x Channel\"\"\"\n        # hidden_states = hidden_states.unsqueeze(0)\n        # bsz, q_len, _ = hidden_states.size()\n        query_states = self.q_proj(hidden_states)\n        query_states = query_states.view(-1, self.num_heads, self.head_size)\n       
 query_states = self.q_norm(query_states)\n\n        (\n            cross_attention_states,\n            cu_seqlen_q,\n            cu_seqlen_k,\n            max_q,\n            max_k,\n            indices,\n        ) = cross_attention_states\n\n        key_states = self.k_proj(cross_attention_states)\n        value_states = self.v_proj(cross_attention_states)\n        key_states = key_states.view(-1, self.num_key_value_heads, self.head_size)\n        value_states = value_states.view(-1, self.num_key_value_heads, self.head_size)\n        key_states = self.k_norm(key_states)\n\n        # key_states = key_states.repeat(1, self.num_key_value_groups, 1)\n        # value_states = value_states.repeat(1, self.num_key_value_groups, 1)\n\n        causal = False\n        # logger.info(\n        #     f\"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}\"\n        # )\n        if SYSTEM == \"ipex\":\n            attn_output = torch.empty_like(query_states)\n            if query_states.device.type == \"xpu\":\n                ipex.llm.functional.varlen_attention(\n                    query_states.contiguous(),\n                    key_states.contiguous(),\n                    value_states.contiguous(),\n                    attn_output,\n                    cu_seqlen_q,\n                    cu_seqlen_k,\n                    None,\n                    max_q,\n                    max_k,\n                    0.0,\n                    self.softmax_scale,\n                    False,\n                    causal,\n                    False,\n                    None,\n                )\n            else:\n                ipex.llm.functional.varlen_attention(\n                    query_states,\n                    key_states,\n                    value_states,\n                    attn_output,\n                    cu_seqlen_q,\n                    cu_seqlen_k,\n                    max_q,\n                    max_k,\n                    0.0,\n                    
self.softmax_scale,\n                    False,\n                    causal,\n                    False,\n                    None,\n                )\n        else:\n            attn_output = flash_attn_2_cuda.varlen_fwd(\n                query_states,\n                key_states,\n                value_states,\n                None,\n                cu_seqlen_q,\n                cu_seqlen_k,\n                None,\n                None,\n                None,  # block_tables\n                None,\n                max_q,\n                max_k,\n                0.0,\n                self.softmax_scale,\n                False,\n                causal,  # Causal\n                -1,  # window_size_left,\n                -1,\n                0.0,  # softcap\n                False,\n                None,\n            )[0]\n        attn_output = self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))\n\n        return attn_output\n\n\n# Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText\nclass MllamaTextMLP(nn.Module):\n    def __init__(self, *, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        self.intermediate_size = (\n            config.intermediate_size // weights.process_group.size()\n        )\n        self.gate_up_proj = TensorParallelColumnLinear.load_multi(\n            config,\n            prefixes=[f\"{prefix}.gate_proj\", f\"{prefix}.up_proj\"],\n            weights=weights,\n            dim=0,\n            bias=False,\n        )\n        self.down_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.down_proj\",\n            weights=weights,\n            bias=False,\n        )\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x):\n        shape = x.shape\n        gate_up_states = self.gate_up_proj(x)\n        gate_up_states = gate_up_states.view(*shape[:-1], 
2, self.intermediate_size)\n        result = self.down_proj(\n            self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1]\n        )\n        return result\n\n\nclass FlashLlamaCrossLayer(torch.nn.Module):\n    \"\"\"Cross-attention transformer block with tanh-gated attention and feedforward.\"\"\"\n\n    def __init__(self, *, prefix, config, weights, index) -> None:\n        layer_idx = index\n        super().__init__()\n        self.cross_attn = MllamaTextCrossAttention(\n            prefix=f\"{prefix}.cross_attn\",\n            config=config,\n            weights=weights,\n            layer_idx=layer_idx,\n        )\n\n        self.input_layernorm = MllamaTextRMSNorm.load(\n            prefix=f\"{prefix}.input_layernorm\", weights=weights, eps=config.rms_norm_eps\n        )\n        self.cross_attn_attn_gate = torch.nn.Parameter(\n            weights.get_tensor(f\"{prefix}.cross_attn_attn_gate\"), requires_grad=False\n        )\n\n        self.mlp = MllamaTextMLP(prefix=f\"{prefix}.mlp\", config=config, weights=weights)\n        self.post_attention_layernorm = MllamaTextRMSNorm.load(\n            prefix=f\"{prefix}.post_attention_layernorm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n        self.cross_attn_mlp_gate = torch.nn.Parameter(\n            weights.get_tensor(f\"{prefix}.cross_attn_mlp_gate\"), requires_grad=False\n        )\n        self.layer_idx = layer_idx\n\n    def forward(\n        self,\n        hidden_states,\n        residual,\n        cos,\n        sin,\n        cu_seqlen_prefill,\n        kv_cache,\n        block_tables,\n        slots,\n        seqlen,\n        max_s,\n        adapter_data,\n        cross_attention_states,  # [ IB, ...]\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        if cross_attention_states is None:\n            return hidden_states, residual\n        if residual is not None:\n            hidden_states += residual\n\n        indices = cross_attention_states[-1]\n       
 out_hidden_states = hidden_states[:]\n        if len(indices) > 0:\n            assert max(indices) < hidden_states.shape[0]\n        hidden_states = hidden_states[indices]\n        residual = hidden_states\n        hidden_states = self.input_layernorm(hidden_states)\n\n        hidden_states = self.cross_attn(\n            hidden_states=hidden_states,\n            # attention_mask=cross_attention_mask,\n            cross_attention_states=cross_attention_states,\n        )\n        hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states\n\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states\n\n        out_hidden_states[indices] = hidden_states\n        hidden_states = out_hidden_states\n\n        return hidden_states, None\n\n\n# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText\nclass MllamaTextRMSNorm(nn.Module):\n    def __init__(self, weight, eps):\n        super().__init__()\n        self.weight = weight\n        self.variance_epsilon = eps\n\n    @classmethod\n    def load(cls, *, prefix, weights, eps):\n        weight = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.weight\"), requires_grad=False\n        )\n        return cls(weight=weight, eps=eps)\n\n    def forward(self, hidden_states):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n    def extra_repr(self):\n        return f\"{tuple(self.weight.shape)}, eps={self.variance_epsilon}\"\n\n\nclass MllamaForConditionalGeneration(nn.Module):\n    def __init__(self, prefix, config, weights):\n       
 super().__init__()\n        config.vision_config.quantize = None\n        config.vision_config.speculator = config.speculator\n        config.text_config.quantize = config.quantize\n        config.text_config.speculator = config.speculator\n        config.text_config._attn_implementation = \"sdpa\"\n        self.hidden_size = config.text_config.hidden_size\n        self.vision_model = MllamaVisionModel(\n            prefix=\"vision_model\", config=config.vision_config, weights=weights\n        )\n        self.multi_modal_projector = FastLinear.load(\n            prefix=\"multi_modal_projector\", config=config, weights=weights, bias=True\n        )\n        self.text_model = FlashLlamaForCausalLM(\n            prefix=\"language_model\", config=config.text_config, weights=weights\n        )\n        self.config = config\n        self.dtype = weights.dtype\n        self.device = weights.device\n\n    def vision_forward(self, pixel_values, aspect_ratio_ids, aspect_ratio_mask):\n        if aspect_ratio_ids is None:\n            raise ValueError(\n                \"`aspect_ratio_ids` must be provided if `pixel_values` is provided\"\n            )\n        # logger.info(f\"PIxel values {pixel_values.shape}\")\n        batch_size = pixel_values.shape[0]\n        vision_states = self.vision_model(\n            pixel_values, aspect_ratio_ids, aspect_ratio_mask\n        )\n        cross_attention_states = self.multi_modal_projector(vision_states).reshape(\n            -1, vision_states.shape[-2], self.hidden_size\n        )\n        _, _, h = cross_attention_states.shape\n        cross_attention_states = cross_attention_states.view(batch_size, -1, h)\n        # logger.info(f\"cross {cross_attention_states.shape}\")\n        return cross_attention_states\n\n    def forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n       
 block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor],\n        lm_head_indices: Optional[torch.Tensor],\n        adapter_data: Optional[torch.Tensor] = None,\n        # XXX: Putting these as optional so that the cuda warmup calls can go through.\n        cross_attention_states: Optional[torch.Tensor] = None,\n        image_indices=None,\n    ):\n        if cross_attention_states is not None:\n            seqlen_q = len(image_indices)\n            n_images = cross_attention_states.shape[0]\n            seqlen_k = cross_attention_states.shape[1]\n            device = cross_attention_states.device\n            if cu_seqlen_prefill is not None:\n                offset = 0\n                cu_q = []\n                indices = []\n                for index in image_indices:\n                    cu_q.append(offset)\n                    length = seqlen.input_lengths[index].item()\n                    assert index < seqlen.cu_seqlen_q.shape[0]\n                    input_ids_offset = seqlen.cu_seqlen_q[index]\n                    indices.extend(range(input_ids_offset, input_ids_offset + length))\n                    offset += length\n                cu_q.append(offset)\n                cu_seqlen_q = torch.Tensor(cu_q).to(device=device, dtype=torch.int32)\n\n                assert max(indices) < input_ids.shape[0]\n\n                cu_seqlen_k = (\n                    torch.arange(\n                        n_images + 1,\n                        device=device,\n                        dtype=torch.int32,\n                    )\n                    * seqlen_k\n                )\n                max_q = cu_seqlen_q[-1].item()\n                max_k = seqlen_k\n            else:\n                cu_seqlen_q = torch.arange(\n                    seqlen_q + 1, device=device, dtype=torch.int32\n                )\n                seqlen_k = cross_attention_states.shape[1]\n       
         n_images = cross_attention_states.shape[0]\n                cu_seqlen_k = (\n                    torch.arange(\n                        n_images + 1,\n                        device=device,\n                        dtype=torch.int32,\n                    )\n                    * seqlen_k\n                )\n                max_q = seqlen_q\n                max_k = seqlen_k\n                indices = image_indices[:]\n\n            cross_attention_states = (\n                cross_attention_states,\n                cu_seqlen_q,\n                cu_seqlen_k,\n                max_q,\n                max_k,\n                indices,\n            )\n\n        outputs = self.text_model(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            cu_seqlen_prefill=cu_seqlen_prefill,\n            kv_cache=kv_cache,\n            block_tables=block_tables,\n            slots=slots,\n            seqlen=seqlen,\n            max_s=max_s,\n            prefill_cache_indices=prefill_cache_indices,\n            lm_head_indices=lm_head_indices,\n            adapter_data=adapter_data,\n            cross_attention_states=cross_attention_states,\n        )\n\n        return outputs\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/mpt_modeling.py",
    "content": "\"\"\"A simple, flexible implementation of a GPT model.\n\nInspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py\n\"\"\"\n\nimport math\nimport warnings\nfrom typing import List, Optional, Tuple, Union\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPast,\n    CausalLMOutputWithPast,\n)\nfrom einops import rearrange\nfrom packaging import version\nfrom text_generation_server.layers import (\n    TensorParallelEmbedding,\n    TensorParallelColumnLinear,\n    TensorParallelRowLinear,\n    SpeculativeHead,\n    get_linear,\n)\n\nEPS = 1e-5\n\n\ndef load_col(config, prefix, weights, bias):\n    assert config.quantize != \"gptq\", NotImplementedError\n    slice_ = weights._get_slice(f\"{prefix}.weight\")\n    rank = weights.process_group.rank()\n    size = weights.process_group.size()\n\n    h3, h = slice_.get_shape()\n    block_size = h // size\n\n    q_part = slice_[rank * block_size : (rank + 1) * block_size]\n    k_part = slice_[h + rank * block_size : h + (rank + 1) * block_size]\n    v_part = slice_[2 * h + rank * block_size : 2 * h + (rank + 1) * block_size]\n\n    weight = torch.cat([q_part, k_part, v_part], dim=0)\n    if weight.dtype != torch.int32:\n        weight = weight.to(dtype=weights.dtype)\n    weight = weight.to(device=weights.device)\n\n    if bias:\n        bias_slice_ = weights._get_slice(f\"{prefix}.bias\")\n        bias_rank = weights.process_group.rank()\n        bias_size = weights.process_group.size()\n\n        bias_h = bias_slice_.get_shape()\n        bias_h = bias_h[0]\n        bias_block_size = bias_h // bias_size\n\n        bias_q_part = bias_slice_[\n            bias_rank * bias_block_size : (bias_rank + 1) * bias_block_size\n        ]\n        bias_k_part = bias_slice_[\n            bias_h\n            + bias_rank * 
bias_block_size : bias_h\n            + (bias_rank + 1) * bias_block_size\n        ]\n        bias_v_part = bias_slice_[\n            2 * bias_h\n            + bias_rank * bias_block_size : 2 * bias_h\n            + (bias_rank + 1) * bias_block_size\n        ]\n\n        bias = torch.cat([bias_q_part, bias_k_part, bias_v_part], dim=0)\n        if bias.dtype != torch.int32:\n            bias = bias.to(dtype=weights.dtype)\n        bias = bias.to(device=weights.device)\n    else:\n        bias = None\n    linear = get_linear(weight, bias)\n    return TensorParallelColumnLinear(linear)\n\n\ndef _reset_is_causal(\n    num_query_tokens: int, num_key_tokens: int, original_is_causal: bool\n):\n    if original_is_causal and num_query_tokens != num_key_tokens:\n        if num_query_tokens != 1:\n            raise NotImplementedError(\n                \"MPT does not support query and key with different number of tokens, unless number of query tokens is 1.\"\n            )\n        else:\n            return False\n    return original_is_causal\n\n\ndef scaled_multihead_dot_product_attention(\n    query,\n    key,\n    value,\n    n_heads,\n    past_key_value=None,\n    softmax_scale=None,\n    attn_bias=None,\n    key_padding_mask=None,\n    is_causal=False,\n    dropout_p=0.0,\n    training=False,\n    needs_weights=False,\n    multiquery=False,\n):\n    q = rearrange(query, \"b s (h d) -> b h s d\", h=n_heads)\n    kv_n_heads = 1 if multiquery else n_heads\n    k = rearrange(key, \"b s (h d) -> b h d s\", h=kv_n_heads)\n    v = rearrange(value, \"b s (h d) -> b h s d\", h=kv_n_heads)\n    if past_key_value is not None:\n        if len(past_key_value) != 0:\n            k = torch.cat([past_key_value[0], k], dim=3)\n            v = torch.cat([past_key_value[1], v], dim=2)\n        past_key_value = (k, v)\n    (b, _, s_q, d) = q.shape\n    s_k = k.size(-1)\n    attn_weight = q.matmul(k) * softmax_scale\n    if attn_bias is not None:\n        _s_q = max(0, attn_bias.size(2) - 
s_q)\n        _s_k = max(0, attn_bias.size(3) - s_k)\n        attn_bias = attn_bias[:, :, _s_q:, _s_k:]\n        if (\n            attn_bias.size(-1) != 1\n            and attn_bias.size(-1) != s_k\n            or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q)\n        ):\n            raise RuntimeError(\n                f\"attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.\"\n            )\n        attn_weight = attn_weight + attn_bias\n    min_val = torch.finfo(q.dtype).min\n    if key_padding_mask is not None:\n        if attn_bias is not None:\n            warnings.warn(\n                \"Propagating key_padding_mask to the attention module \"\n                + \"and applying it within the attention module can cause \"\n                + \"unnecessary computation/memory usage. Consider integrating \"\n                + \"into attn_bias once and passing that to each attention \"\n                + \"module instead.\"\n            )\n        attn_weight = attn_weight.masked_fill(\n            ~key_padding_mask.view((b, 1, 1, s_k)), min_val\n        )\n    if is_causal and (not q.size(2) == 1):\n        s = max(s_q, s_k)\n        causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)\n        causal_mask = causal_mask.tril()\n        causal_mask = causal_mask.to(torch.bool)\n        causal_mask = ~causal_mask\n        causal_mask = causal_mask[-s_q:, -s_k:]\n        attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)\n    attn_weight = torch.softmax(attn_weight, dim=-1)\n    if dropout_p:\n        attn_weight = torch.nn.functional.dropout(\n            attn_weight, p=dropout_p, training=training, inplace=True\n        )\n    out = attn_weight.to(v.dtype).matmul(v)\n    out = rearrange(out, \"b h s d -> b s (h d)\")\n    if needs_weights:\n        return (out, attn_weight, past_key_value)\n    return (out, None, past_key_value)\n\n\ndef check_valid_inputs(*tensors, 
valid_dtypes=[torch.float16, torch.bfloat16]):\n    for tensor in tensors:\n        if tensor.dtype not in valid_dtypes:\n            raise TypeError(\n                f\"tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}.\"\n            )\n        if not tensor.is_cuda:\n            raise TypeError(\n                f\"Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).\"\n            )\n\n\ndef flash_attn_fn(\n    query,\n    key,\n    value,\n    n_heads,\n    past_key_value=None,\n    softmax_scale=None,\n    attn_bias=None,\n    key_padding_mask=None,\n    is_causal=False,\n    dropout_p=0.0,\n    training=False,\n    needs_weights=False,\n    multiquery=False,\n):\n    try:\n        from flash_attn import bert_padding, flash_attn_interface\n    except Exception:\n        raise RuntimeError(\"Please install flash-attn==1.0.3.post0\")\n    check_valid_inputs(query, key, value)\n    if past_key_value is not None:\n        if len(past_key_value) != 0:\n            key = torch.cat([past_key_value[0], key], dim=1)\n            value = torch.cat([past_key_value[1], value], dim=1)\n        past_key_value = (key, value)\n    if attn_bias is not None:\n        _s_q = max(0, attn_bias.size(2) - query.size(1))\n        _s_k = max(0, attn_bias.size(3) - key.size(1))\n        attn_bias = attn_bias[:, :, _s_q:, _s_k:]\n    if attn_bias is not None:\n        raise NotImplementedError(\"attn_bias not implemented for flash attn.\")\n    (batch_size, seqlen) = query.shape[:2]\n    if key_padding_mask is None:\n        key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)\n    query_padding_mask = key_padding_mask[:, -query.size(1) :]\n    (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(\n        query, query_padding_mask\n    )\n    query_unpad = rearrange(query_unpad, \"nnz (h d) -> nnz h d\", h=n_heads)\n    (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(\n        key, 
key_padding_mask\n    )\n    key_unpad = rearrange(\n        key_unpad, \"nnz (h d) -> nnz h d\", h=1 if multiquery else n_heads\n    )\n    (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask)\n    value_unpad = rearrange(\n        value_unpad, \"nnz (h d) -> nnz h d\", h=1 if multiquery else n_heads\n    )\n    if multiquery:\n        key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))\n        value_unpad = value_unpad.expand(\n            value_unpad.size(0), n_heads, value_unpad.size(-1)\n        )\n    dropout_p = dropout_p if training else 0.0\n    reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)\n    output_unpad = flash_attn_interface.flash_attn_unpadded_func(\n        query_unpad,\n        key_unpad,\n        value_unpad,\n        cu_seqlens_q,\n        cu_seqlens_k,\n        max_seqlen_q,\n        max_seqlen_k,\n        dropout_p,\n        softmax_scale=softmax_scale,\n        causal=reset_is_causal,\n        return_attn_probs=needs_weights,\n    )\n    output = bert_padding.pad_input(\n        rearrange(output_unpad, \"nnz h d -> nnz (h d)\"), indices_q, batch_size, seqlen\n    )\n    return (output, None, past_key_value)\n\n\ndef triton_flash_attn_fn(\n    query,\n    key,\n    value,\n    n_heads,\n    past_key_value=None,\n    softmax_scale=None,\n    attn_bias=None,\n    key_padding_mask=None,\n    is_causal=False,\n    dropout_p=0.0,\n    training=False,\n    needs_weights=False,\n    multiquery=False,\n):\n    try:\n        from .flash_attn_triton import flash_attn_func\n    except Exception:\n        _installed = False\n        if version.parse(torch.__version__) < version.parse(\"2.0.0\"):\n            _installed = True\n            try:\n                from flash_attn.flash_attn_triton import flash_attn_func\n            except Exception:\n                _installed = False\n        if not _installed:\n            raise RuntimeError(\n                \"Requirements for 
`attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed.\"\n            )\n    check_valid_inputs(query, key, value)\n    if past_key_value is not None:\n        if len(past_key_value) != 0:\n            key = torch.cat([past_key_value[0], key], dim=1)\n            value = torch.cat([past_key_value[1], value], dim=1)\n        past_key_value = (key, value)\n    if attn_bias is not None:\n        _s_q = max(0, attn_bias.size(2) - query.size(1))\n        _s_k = max(0, attn_bias.size(3) - key.size(1))\n        attn_bias = attn_bias[:, :, _s_q:, _s_k:]\n    if dropout_p:\n        raise NotImplementedError(\"Dropout not implemented for attn_impl: triton.\")\n    if needs_weights:\n        raise NotImplementedError(\"attn_impl: triton cannot return attn weights.\")\n    if key_padding_mask is not None:\n        warnings.warn(\n            \"Propagating key_padding_mask to the attention module \"\n            + \"and applying it within the attention module can cause \"\n            + \"unnecessary computation/memory usage. 
Consider integrating \"\n            + \"into attn_bias once and passing that to each attention \"\n            + \"module instead.\"\n        )\n        (b_size, s_k) = key_padding_mask.shape[:2]\n        if attn_bias is None:\n            attn_bias = query.new_zeros(b_size, 1, 1, s_k)\n        attn_bias = attn_bias.masked_fill(\n            ~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min\n        )\n    query = rearrange(query, \"b s (h d) -> b s h d\", h=n_heads)\n    key = rearrange(key, \"b s (h d) -> b s h d\", h=1 if multiquery else n_heads)\n    value = rearrange(value, \"b s (h d) -> b s h d\", h=1 if multiquery else n_heads)\n    if multiquery:\n        key = key.expand(*key.shape[:2], n_heads, key.size(-1))\n        value = value.expand(*value.shape[:2], n_heads, value.size(-1))\n    reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)\n    attn_output = flash_attn_func(\n        query, key, value, attn_bias, reset_is_causal, softmax_scale\n    )\n    output = attn_output.view(*attn_output.shape[:2], -1)\n    return (output, None, past_key_value)\n\n\nclass MultiheadAttention(nn.Module):\n    \"\"\"Multi-head self attention.\n\n    Using torch or triton attention implementation enables user to also use\n    additive bias.\n    \"\"\"\n\n    def __init__(\n        self,\n        config,\n        prefix,\n        weights,\n    ):\n        super().__init__()\n        attn_impl = config.attn_config.attn_impl\n        self.attn_impl = config.attn_config.attn_impl\n        self.clip_qkv = config.attn_config.clip_qkv\n        self.qk_ln = config.attn_config.qk_ln\n        self.d_model = config.d_model\n        d_model = config.d_model\n        self.n_heads = config.n_heads\n        self.softmax_scale = config.attn_config.softmax_scale\n        if self.softmax_scale is None:\n            self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)\n        self.attn_dropout_p = config.attn_config.attn_pdrop\n\n   
     if self.n_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.n_heads = self.n_heads // weights.process_group.size()\n        self.Wqkv = load_col(\n            config, prefix=f\"{prefix}.Wqkv\", weights=weights, bias=not config.no_bias\n        )\n        if self.qk_ln:\n            bias = not config.no_bias\n            hidden_size = config.d_model\n            head_dim = hidden_size // self.n_heads\n\n            self.q_ln = LPLayerNorm(\n                d_model, bias=bias, prefix=f\"{prefix}.q_ln\", weights=weights\n            )\n            self.k_ln = LPLayerNorm(\n                self.n_heads * head_dim, prefix=f\"{prefix}.k_ln\", weights=weights\n            )\n        if self.attn_impl == \"flash\":\n            self.attn_fn = flash_attn_fn\n        elif self.attn_impl == \"triton\":\n            self.attn_fn = triton_flash_attn_fn\n        elif self.attn_impl == \"torch\":\n            self.attn_fn = scaled_multihead_dot_product_attention\n        else:\n            raise ValueError(f\"attn_impl={attn_impl!r} is an invalid setting.\")\n        self.out_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.out_proj\",\n            weights=weights,\n            bias=not config.no_bias,\n        )\n\n    def forward(\n        self,\n        x,\n        past_key_value=None,\n        attn_bias=None,\n        attention_mask=None,\n        is_causal=True,\n        needs_weights=False,\n    ):\n        qkv = self.Wqkv(x)\n        if self.clip_qkv:\n            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)\n        (query, key, value) = qkv.chunk(3, dim=2)\n\n        key_padding_mask = attention_mask\n        if self.qk_ln:\n            dtype = query.dtype\n            query = 
self.q_ln(query).to(dtype)\n            key = self.k_ln(key).to(dtype)\n        (context, attn_weights, past_key_value) = self.attn_fn(\n            query,\n            key,\n            value,\n            self.n_heads,\n            past_key_value=past_key_value,\n            softmax_scale=self.softmax_scale,\n            attn_bias=attn_bias,\n            key_padding_mask=key_padding_mask,\n            is_causal=is_causal,\n            dropout_p=self.attn_dropout_p,\n            training=self.training,\n            needs_weights=needs_weights,\n        )\n        out = self.out_proj(context)\n        return (out, attn_weights, past_key_value)\n\n\nclass MultiQueryAttention(nn.Module):\n    \"\"\"Multi-Query self attention.\n\n    Using torch or triton attention implementation enables user to also use\n    additive bias.\n    \"\"\"\n\n    def __init__(self, config, prefix, weights, verbose=False):\n        super().__init__()\n        attn_impl = config.attn_config.attn_impl\n        self.attn_impl = config.attn_config.attn_impl\n        self.clip_qkv = config.attn_config.clip_qkv\n        self.qk_ln = config.attn_config.qk_ln\n        self.d_model = config.d_model\n        d_model = config.d_model\n        self.n_heads = config.n_heads\n        self.softmax_scale = config.attn_config.softmax_scale\n        if self.softmax_scale is None:\n            self.softmax_scale = 1 / math.sqrt(self.head_dim)\n        self.attn_dropout_p = config.attn_config.attn_pdrop\n        # self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)\n        self.Wqkv = TensorParallelColumnLinear.load(\n            config, prefix=f\"{prefix}.Wqkv\", weights=weights, bias=not config.no_bias\n        )\n        (d_model, d_model + self.head_dim)\n        if self.qk_ln:\n            raise NotImplementedError(\"qk_ln not supported\")\n        if self.attn_impl == \"flash\":\n            self.attn_fn = flash_attn_fn\n        elif self.attn_impl == \"triton\":\n            
self.attn_fn = triton_flash_attn_fn\n            if verbose:\n                warnings.warn(\n                    \"While `attn_impl: triton` can be faster than `attn_impl: flash` \"\n                    + \"it uses more memory. When training larger models this can trigger \"\n                    + \"alloc retries which hurts performance. If encountered, we recommend \"\n                    + \"using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.\"\n                )\n        elif self.attn_impl == \"torch\":\n            self.attn_fn = scaled_multihead_dot_product_attention\n            if torch.cuda.is_available() and verbose:\n                warnings.warn(\n                    \"Using `attn_impl: torch`. If your model does not use `alibi` or \"\n                    + \"`prefix_lm` we recommend using `attn_impl: flash` otherwise \"\n                    + \"we recommend using `attn_impl: triton`.\"\n                )\n        else:\n            raise ValueError(f\"attn_impl={attn_impl!r} is an invalid setting.\")\n        self.out_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.out_proj\",\n            weights=weights,\n            bias=not config.no_bias,\n        )\n        # self.out_proj._is_residual = True\n\n    def forward(\n        self,\n        x,\n        past_key_value=None,\n        attn_bias=None,\n        attention_mask=None,\n        is_causal=True,\n        needs_weights=False,\n    ):\n        qkv = self.Wqkv(x)\n        if self.clip_qkv:\n            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)\n        (query, key, value) = qkv.split(\n            [self.d_model, self.head_dim, self.head_dim], dim=2\n        )\n        key_padding_mask = attention_mask\n        if self.qk_ln:\n            dtype = query.dtype\n            query = self.q_ln(query).to(dtype)\n            key = self.k_ln(key).to(dtype)\n        (context, attn_weights, past_key_value) = self.attn_fn(\n            
query,\n            key,\n            value,\n            self.n_heads,\n            past_key_value=past_key_value,\n            softmax_scale=self.softmax_scale,\n            attn_bias=attn_bias,\n            key_padding_mask=key_padding_mask,\n            is_causal=is_causal,\n            dropout_p=self.attn_dropout_p,\n            training=self.training,\n            needs_weights=needs_weights,\n            multiquery=True,\n        )\n        return (self.out_proj(context), attn_weights, past_key_value)\n\n\ndef attn_bias_shape(\n    attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id\n):\n    if attn_impl == \"flash\":\n        return None\n    elif attn_impl in [\"torch\", \"triton\"]:\n        if alibi:\n            if (prefix_lm or not causal) or use_sequence_id:\n                return (1, n_heads, seq_len, seq_len)\n            return (1, n_heads, 1, seq_len)\n        elif prefix_lm or use_sequence_id:\n            return (1, 1, seq_len, seq_len)\n        return None\n    else:\n        raise ValueError(f\"attn_impl={attn_impl!r} is an invalid setting.\")\n\n\ndef build_attn_bias(\n    attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8\n):\n    if attn_impl == \"flash\":\n        return None\n    elif attn_impl in [\"torch\", \"triton\"]:\n        if alibi:\n            (device, dtype) = (attn_bias.device, attn_bias.dtype)\n            attn_bias = attn_bias.add(\n                build_alibi_bias(\n                    n_heads,\n                    seq_len,\n                    full=not causal,\n                    alibi_bias_max=alibi_bias_max,\n                    device=device,\n                    dtype=dtype,\n                )\n            )\n        return attn_bias\n    else:\n        raise ValueError(f\"attn_impl={attn_impl!r} is an invalid setting.\")\n\n\ndef gen_slopes(n_heads, alibi_bias_max=8, device=None):\n    _n_heads = 2 ** math.ceil(math.log2(n_heads))\n    m = torch.arange(1, _n_heads 
+ 1, dtype=torch.float32, device=device)\n    m = m.mul(alibi_bias_max / _n_heads)\n    slopes = 1.0 / torch.pow(2, m)\n    if _n_heads != n_heads:\n        slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]\n    return slopes.view(1, n_heads, 1, 1)\n\n\ndef build_alibi_bias(\n    n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None\n):\n    alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(\n        1, 1, 1, seq_len\n    )\n    if full:\n        alibi_bias = alibi_bias - torch.arange(\n            1 - seq_len, 1, dtype=torch.int32, device=device\n        ).view(1, 1, seq_len, 1)\n        alibi_bias = alibi_bias.abs().mul(-1)\n    slopes = gen_slopes(n_heads, alibi_bias_max, device=device)\n    alibi_bias = alibi_bias * slopes\n    return alibi_bias.to(dtype=dtype)\n\n\nATTN_CLASS_REGISTRY = {\n    \"multihead_attention\": MultiheadAttention,\n    \"multiquery_attention\": MultiQueryAttention,\n}\n\n\"\"\"GPT Blocks used for the GPT Model.\"\"\"\n\n\nclass MPTMLP(nn.Module):\n    def __init__(self, config, prefix, weights):\n        super().__init__()\n        # self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device)\n        self.up_proj = TensorParallelColumnLinear.load(\n            config, prefix=f\"{prefix}.up_proj\", weights=weights, bias=not config.no_bias\n        )\n        self.act = nn.GELU(approximate=\"none\")\n        # self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device)\n        self.down_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.down_proj\",\n            weights=weights,\n            bias=not config.no_bias,\n        )\n        # self.down_proj._is_residual = True\n\n    def forward(self, x):\n        return self.down_proj(self.act(self.up_proj(x)))\n\n\nclass MPTBlock(nn.Module):\n    def __init__(self, config, prefix, weights):\n        super().__init__()\n        self.prefix = prefix\n        if 
config.attn_config.attn_type != \"multihead_attention\":\n            raise NotImplementedError(\n                f\"\"\"Not implemented attn {config.attn_config.attn_type}\"\"\"\n            )\n        resid_pdrop = config.resid_pdrop\n        if config.no_bias:\n            self.norm_1 = nn.LayerNorm.load_no_bias(\n                prefix=f\"{prefix}.norm_1\", weights=weights, eps=EPS\n            )\n            self.norm_2 = nn.LayerNorm.load_no_bias(\n                prefix=f\"{prefix}.norm_2\", weights=weights, eps=EPS\n            )\n        else:\n            self.norm_1 = nn.LayerNorm.load(\n                prefix=f\"{prefix}.norm_1\", weights=weights, eps=EPS\n            )\n            self.norm_2 = nn.LayerNorm.load(\n                prefix=f\"{prefix}.norm_2\", weights=weights, eps=EPS\n            )\n        self.attn = MultiheadAttention(config, prefix=f\"{prefix}.attn\", weights=weights)\n        self.ffn = MPTMLP(config, prefix=f\"{prefix}.ffn\", weights=weights)\n        self.resid_attn_dropout = nn.Dropout(resid_pdrop)\n        self.resid_ffn_dropout = nn.Dropout(resid_pdrop)\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attn_bias: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.ByteTensor] = None,\n        is_causal: bool = True,\n    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:\n        a = self.norm_1(x)\n        (b, attn_weights, past_key_value) = self.attn(\n            a,\n            past_key_value=past_key_value,\n            attn_bias=attn_bias,\n            attention_mask=attention_mask,\n            is_causal=is_causal,\n        )\n        x = x + self.resid_attn_dropout(b)\n        m = self.norm_2(x)\n        n = self.ffn(m)\n        x = x + self.resid_ffn_dropout(n)\n        return (x, attn_weights, past_key_value)\n\n\ndef _cast_if_autocast_enabled(tensor):\n    if torch.is_autocast_enabled():\n        if 
tensor.device.type == \"cuda\":\n            dtype = torch.get_autocast_gpu_dtype()\n        elif tensor.device.type == \"cpu\":\n            dtype = torch.get_autocast_cpu_dtype()\n        else:\n            raise NotImplementedError()\n        return tensor.to(dtype=dtype)\n    return tensor\n\n\nclass LPLayerNorm(torch.nn.LayerNorm):\n    def __init__(\n        self,\n        normalized_shape,\n        eps=1e-05,\n        elementwise_affine=True,\n        device=None,\n        dtype=None,\n        bias: Optional[bool] = True,\n        prefix=None,\n        weights=None,\n    ):\n        super().__init__(\n            normalized_shape=normalized_shape,\n            eps=eps,\n            elementwise_affine=elementwise_affine,\n            device=device,\n            dtype=dtype,\n            bias=bias,\n        )\n        if weights is not None:\n            self.weight = nn.Parameter(weights.get_sharded(f\"{prefix}.weight\", dim=0))\n            if bias:\n                self.bias = nn.Parameter(weights.get_sharded(f\"{prefix}.bias\", dim=0))\n            self.normalized_shape = self.weight.shape\n\n    def forward(self, x):\n        module_device = x.device\n        downcast_x = _cast_if_autocast_enabled(x)\n        downcast_weight = (\n            _cast_if_autocast_enabled(self.weight)\n            if self.weight is not None\n            else self.weight\n        )\n        downcast_bias = (\n            _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias\n        )\n        with torch.autocast(enabled=False, device_type=module_device.type):\n            return torch.nn.functional.layer_norm(\n                downcast_x,\n                self.normalized_shape,\n                downcast_weight,\n                downcast_bias,\n                self.eps,\n            )\n\n\ndef rms_norm(x, weight=None, eps=1e-05):\n    output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)\n    if weight is not None:\n        return output * 
weight\n    return output\n\n\nclass RMSNorm(torch.nn.Module):\n    def __init__(\n        self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None\n    ):\n        super().__init__()\n        self.eps = eps\n        if weight:\n            self.weight = torch.nn.Parameter(\n                torch.ones(normalized_shape, dtype=dtype, device=device)\n            )\n        else:\n            self.register_parameter(\"weight\", None)\n\n    def forward(self, x):\n        return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)\n\n\nclass LPRMSNorm(RMSNorm):\n    def __init__(\n        self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None\n    ):\n        super().__init__(\n            normalized_shape=normalized_shape,\n            eps=eps,\n            weight=weight,\n            dtype=dtype,\n            device=device,\n        )\n\n    def forward(self, x):\n        downcast_x = _cast_if_autocast_enabled(x)\n        downcast_weight = (\n            _cast_if_autocast_enabled(self.weight)\n            if self.weight is not None\n            else self.weight\n        )\n        with torch.autocast(enabled=False, device_type=x.device.type):\n            return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)\n\n\nNORM_CLASS_REGISTRY = {\n    \"layernorm\": torch.nn.LayerNorm,\n    \"low_precision_layernorm\": LPLayerNorm,\n    \"rmsnorm\": RMSNorm,\n    \"low_precision_rmsnorm\": LPRMSNorm,\n}\n\nTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]\n\n\nclass MPTPreTrainedModel(PreTrainedModel):\n    base_model_prefix = \"model\"\n    _no_split_modules = [\"MPTBlock\"]\n\n\nclass MPTModel(MPTPreTrainedModel):\n    def __init__(self, prefix: str, config, weights):\n        # config._validate_config()\n        super().__init__(config)\n        self.world_size = weights.process_group.size()\n        self.rank = weights.process_group.rank()\n        self.n_heads = config.n_heads\n        self.attn_impl 
= config.attn_config.attn_impl\n        self.prefix_lm = config.attn_config.prefix_lm\n        self.attn_uses_sequence_id = config.attn_config.attn_uses_sequence_id\n        self.alibi = config.attn_config.alibi\n        self.alibi_bias_max = config.attn_config.alibi_bias_max\n        if config.init_device == \"mixed\":\n            # TODO: reimplement mixed device initialization\n            # dist.get_local_rank() == 0:\n            if True:\n                config.init_device = \"cpu\"\n            else:\n                config.init_device = \"meta\"\n        if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():\n            norm_options = \" | \".join(NORM_CLASS_REGISTRY.keys())\n            raise NotImplementedError(\n                f\"Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).\"\n            )\n        if config.norm_type.lower() != \"low_precision_layernorm\":\n            raise NotImplementedError(\n                f\"Requested norm type ({config.norm_type}) is not implemented within this repo.\"\n            )\n\n        self.wte = TensorParallelEmbedding(f\"{prefix}.wte\", weights)\n\n        if not self.alibi:\n            self.wpe = TensorParallelEmbedding(f\"{prefix}.wpe\", weights)\n        self.blocks = nn.ModuleList(\n            [\n                MPTBlock(config, prefix=f\"{prefix}.blocks.{i}\", weights=weights)\n                for i in range(config.n_layers)\n            ]\n        )\n        if config.no_bias:\n            self.norm_f = nn.LayerNorm.load_no_bias(\n                prefix=\"transformer.norm_f\", weights=weights, eps=EPS\n            )\n        else:\n            self.norm_f = nn.LayerNorm.load(\n                prefix=\"transformer.norm_f\", weights=weights, eps=EPS\n            )\n        self.is_causal = not self.prefix_lm\n        self._attn_bias_initialized = False\n        self.attn_bias = None\n        self.attn_bias_shape = attn_bias_shape(\n            
self.attn_impl,\n            config.n_heads,\n            config.max_seq_len,\n            self.alibi,\n            prefix_lm=self.prefix_lm,\n            causal=self.is_causal,\n            use_sequence_id=self.attn_uses_sequence_id,\n        )\n        if config.no_bias:\n            for module in self.modules():\n                if hasattr(module, \"bias\") and isinstance(module.bias, nn.Parameter):\n                    if config.verbose:\n                        warnings.warn(f\"Removing bias ({module.bias}) from {module}.\")\n                    module.register_parameter(\"bias\", None)\n        if hasattr(self.config, \"verbose\"):\n            if config.verbose and config.verbose > 2:\n                print(self)\n        if \"verbose\" not in self.config.init_config:\n            self.config.init_config[\"verbose\"] = self.config.verbose\n        if self.config.init_config[\"verbose\"] > 1:\n            init_fn_name = self.config.init_config[\"name\"]\n            warnings.warn(f\"Using {init_fn_name} initialization.\")\n\n    @torch.no_grad()\n    def _attn_bias(\n        self,\n        device,\n        dtype,\n        attention_mask: Optional[torch.ByteTensor] = None,\n        prefix_mask: Optional[torch.ByteTensor] = None,\n        sequence_id: Optional[torch.LongTensor] = None,\n    ):\n        if not self._attn_bias_initialized:\n            if self.attn_bias_shape:\n                self.attn_bias = torch.zeros(\n                    self.attn_bias_shape, device=device, dtype=dtype\n                )\n                self.attn_bias = build_attn_bias(\n                    self.attn_impl,\n                    self.attn_bias,\n                    self.config.n_heads,\n                    self.config.max_seq_len,\n                    causal=self.is_causal,\n                    alibi=self.alibi,\n                    alibi_bias_max=self.alibi_bias_max,\n                )\n                assert self.n_heads % self.world_size == 0\n                block_size = 
self.n_heads // self.world_size\n                self.attn_bias = self.attn_bias[\n                    :, self.rank * block_size : (self.rank + 1) * block_size\n                ]\n            self._attn_bias_initialized = True\n        if self.attn_impl == \"flash\":\n            return (self.attn_bias, attention_mask)\n        if self.attn_bias is not None:\n            self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)\n        attn_bias = self.attn_bias\n        if self.prefix_lm:\n            assert isinstance(attn_bias, torch.Tensor)\n            assert isinstance(prefix_mask, torch.Tensor)\n            attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)\n        if self.attn_uses_sequence_id and sequence_id is not None:\n            assert isinstance(attn_bias, torch.Tensor)\n            attn_bias = self._apply_sequence_id(attn_bias, sequence_id)\n        if attention_mask is not None:\n            s_k = attention_mask.shape[-1]\n            if attn_bias is None:\n                attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)\n            else:\n                _s_k = max(0, attn_bias.size(-1) - s_k)\n                attn_bias = attn_bias[:, :, :, _s_k:]\n            if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:\n                raise ValueError(\n                    f\"attention_mask shape={attention_mask.shape} \"\n                    + f\"and prefix_mask shape={prefix_mask.shape} are not equal.\"\n                )\n            min_val = torch.finfo(attn_bias.dtype).min\n            attn_bias = attn_bias.masked_fill(\n                ~attention_mask.view(-1, 1, 1, s_k), min_val\n            )\n        return (attn_bias, None)\n\n    def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor):\n        (s_k, s_q) = attn_bias.shape[-2:]\n        if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len:\n            raise ValueError(\n                
\"attn_bias does not match the expected shape. \"\n                + f\"The last two dimensions should both be {self.config.max_length} \"\n                + f\"but are {s_k} and {s_q}.\"\n            )\n        seq_len = prefix_mask.shape[-1]\n        if seq_len > self.config.max_seq_len:\n            raise ValueError(\n                f\"prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}\"\n            )\n        attn_bias = attn_bias[..., :seq_len, :seq_len]\n        causal = torch.tril(\n            torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)\n        ).view(1, 1, seq_len, seq_len)\n        prefix = prefix_mask.view(-1, 1, 1, seq_len)\n        cannot_attend = ~torch.logical_or(causal, prefix.bool())\n        min_val = torch.finfo(attn_bias.dtype).min\n        attn_bias = attn_bias.masked_fill(cannot_attend, min_val)\n        return attn_bias\n\n    def _apply_sequence_id(\n        self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor\n    ):\n        seq_len = sequence_id.shape[-1]\n        if seq_len > self.config.max_seq_len:\n            raise ValueError(\n                f\"sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}\"\n            )\n        attn_bias = attn_bias[..., :seq_len, :seq_len]\n        cannot_attend = torch.logical_not(\n            torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))\n        ).unsqueeze(1)\n        min_val = torch.finfo(attn_bias.dtype).min\n        attn_bias = attn_bias.masked_fill(cannot_attend, min_val)\n        return attn_bias\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor,\n        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,\n        attention_mask: Optional[torch.ByteTensor] = None,\n        prefix_mask: Optional[torch.ByteTensor] = None,\n        sequence_id: Optional[torch.LongTensor] = None,\n        return_dict: Optional[bool] = None,\n        
output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        use_cache: Optional[bool] = None,\n    ):\n        return_dict = (\n            return_dict if return_dict is not None else self.config.return_dict\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        if attention_mask is not None:\n            attention_mask = attention_mask.bool()\n        if prefix_mask is not None:\n            prefix_mask = prefix_mask.bool()\n        if not return_dict:\n            raise NotImplementedError(\n                \"return_dict False is not implemented yet for MPT\"\n            )\n        if output_attentions:\n            if self.attn_impl != \"torch\":\n                raise NotImplementedError(\n                    \"output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.\"\n                )\n        if (\n            attention_mask is not None\n            and attention_mask[:, 0].sum() != attention_mask.shape[0]\n            and self.training\n        ):\n            raise NotImplementedError(\n                \"MPT does not support training with left padding.\"\n            )\n        if self.prefix_lm and prefix_mask is None:\n            raise ValueError(\n                \"prefix_mask is a required argument when MPT is configured with prefix_lm=True.\"\n            )\n        if self.training:\n            if self.attn_uses_sequence_id and sequence_id is None:\n                raise ValueError(\n                    \"sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True \"\n                    + \"and the model is in train mode.\"\n                )\n            elif self.attn_uses_sequence_id is False and sequence_id is not None:\n                warnings.warn(\n                    \"MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. 
\"\n                    + \"This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.\"\n                )\n        S = input_ids.size(1)\n        assert (\n            S <= self.config.max_seq_len\n        ), f\"Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}\"\n        tok_emb = self.wte(input_ids)\n        if self.alibi:\n            x = tok_emb\n        else:\n            past_position = 0\n            if past_key_values is not None:\n                if len(past_key_values) != self.config.n_layers:\n                    raise ValueError(\n                        \"past_key_values must provide a past_key_value for each attention \"\n                        + f\"layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).\"\n                    )\n                past_position = past_key_values[0][0].size(1)\n                if self.attn_impl == \"torch\":\n                    past_position = past_key_values[0][0].size(3)\n            if S + past_position > self.config.max_seq_len:\n                raise ValueError(\n                    f\"Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.\"\n                )\n            pos = torch.arange(\n                past_position,\n                S + past_position,\n                dtype=torch.long,\n                device=input_ids.device,\n            ).unsqueeze(0)\n            if attention_mask is not None:\n                pos = torch.clamp(\n                    pos\n                    - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[\n                        :, past_position:\n                    ],\n                    min=0,\n                )\n            pos_emb = self.wpe(pos)\n            x = tok_emb + pos_emb\n        
(attn_bias, attention_mask) = self._attn_bias(\n            device=x.device,\n            dtype=torch.float32,\n            attention_mask=attention_mask,\n            prefix_mask=prefix_mask,\n            sequence_id=sequence_id,\n        )\n        if use_cache and past_key_values is None:\n            past_key_values = [() for _ in range(self.config.n_layers)]\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        for b_idx, block in enumerate(self.blocks):\n            if output_hidden_states:\n                assert all_hidden_states is not None\n                all_hidden_states = all_hidden_states + (x,)\n            past_key_value = (\n                past_key_values[b_idx] if past_key_values is not None else None\n            )\n            (x, attn_weights, past_key_value) = block(\n                x,\n                past_key_value=past_key_value,\n                attn_bias=attn_bias,\n                attention_mask=attention_mask,\n                is_causal=self.is_causal,\n            )\n            if past_key_values is not None:\n                past_key_values[b_idx] = past_key_value\n            if output_attentions:\n                assert all_self_attns is not None\n                all_self_attns = all_self_attns + (attn_weights,)\n        x = self.norm_f(x)\n        if output_hidden_states:\n            assert all_hidden_states is not None\n            all_hidden_states = all_hidden_states + (x,)\n        return BaseModelOutputWithPast(\n            last_hidden_state=x,\n            past_key_values=past_key_values,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n        )\n\n\nclass MPTForCausalLM(MPTPreTrainedModel):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__(config)\n\n        if not prefix:\n            prefix = \"transformer\"\n        else:\n            prefix = 
f\"{prefix}.transformer\"\n\n        if not config.tie_word_embeddings:\n            raise ValueError(\"MPTForCausalLM only supports tied word embeddings\")\n        self.transformer = MPTModel(prefix, config, weights)\n        self.lm_head = SpeculativeHead.load(\n            config, prefix=f\"{prefix}.wte\", weights=weights\n        )\n        self.logit_scale = None\n        if config.logit_scale is not None:\n            logit_scale = config.logit_scale\n            if isinstance(logit_scale, str):\n                if logit_scale == \"inv_sqrt_d_model\":\n                    logit_scale = 1 / math.sqrt(config.d_model)\n                else:\n                    raise ValueError(\n                        f\"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.\"\n                    )\n            self.logit_scale = logit_scale\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor,\n        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,\n        attention_mask: Optional[torch.ByteTensor] = None,\n        prefix_mask: Optional[torch.ByteTensor] = None,\n        sequence_id: Optional[torch.LongTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        return_dict: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        use_cache: Optional[bool] = None,\n    ):\n        return_dict = (\n            return_dict if return_dict is not None else self.config.return_dict\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        outputs = self.transformer(\n            input_ids=input_ids,\n            past_key_values=past_key_values,\n            attention_mask=attention_mask,\n            prefix_mask=prefix_mask,\n            sequence_id=sequence_id,\n            return_dict=return_dict,\n            output_attentions=output_attentions,\n         
   output_hidden_states=output_hidden_states,\n            use_cache=use_cache,\n        )\n        logits, speculative_logits = self.lm_head(outputs.last_hidden_state)\n        if self.logit_scale is not None:\n            if self.logit_scale == 0:\n                warnings.warn(\n                    f\"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.\"\n                )\n            logits *= self.logit_scale\n        loss = None\n        if labels is not None:\n            labels = torch.roll(labels, shifts=-1)\n            labels[:, -1] = -100\n            loss = F.cross_entropy(\n                logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)\n            )\n        return (\n            CausalLMOutputWithPast(\n                loss=loss,\n                logits=logits,\n                past_key_values=outputs.past_key_values,\n                hidden_states=outputs.hidden_states,\n                attentions=outputs.attentions,\n            ),\n            speculative_logits,\n        )\n\n    def prepare_inputs_for_generation(\n        self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs\n    ):\n        if inputs_embeds is not None:\n            raise NotImplementedError(\"inputs_embeds is not implemented for MPT yet\")\n        attention_mask = kwargs[\"attention_mask\"].bool()\n        if attention_mask[:, -1].sum() != attention_mask.shape[0]:\n            raise NotImplementedError(\n                \"MPT does not support generation with right padding.\"\n            )\n        if self.transformer.attn_uses_sequence_id and self.training:\n            sequence_id = torch.zeros_like(input_ids[:1])\n        else:\n            sequence_id = None\n        if past_key_values is not None:\n            input_ids = input_ids[:, -1].unsqueeze(-1)\n        if self.transformer.prefix_lm:\n            prefix_mask = torch.ones_like(attention_mask)\n            if 
kwargs.get(\"use_cache\") is False:\n                raise NotImplementedError(\n                    \"MPT with prefix_lm=True does not support use_cache=False.\"\n                )\n        else:\n            prefix_mask = None\n        return {\n            \"input_ids\": input_ids,\n            \"attention_mask\": attention_mask,\n            \"prefix_mask\": prefix_mask,\n            \"sequence_id\": sequence_id,\n            \"past_key_values\": past_key_values,\n            \"use_cache\": kwargs.get(\"use_cache\", True),\n        }\n\n    @staticmethod\n    def _reorder_cache(past_key_values, beam_idx):\n        \"\"\"Used by HuggingFace generate when using beam search with kv-caching.\n\n        See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133\n        for an example in transformers.\n        \"\"\"\n        reordered_past = []\n        for layer_past in past_key_values:\n            reordered_past += [\n                tuple(\n                    (past_state.index_select(0, beam_idx) for past_state in layer_past)\n                )\n            ]\n        return reordered_past\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/neox_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch GPTNeoX model.\"\"\"\n\nfrom typing import Optional, Tuple, Union\n\nimport os\nimport torch\nimport torch.distributed\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom torch.nn import CrossEntropyLoss\n\nfrom transformers.activations import ACT2FN\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPast,\n    CausalLMOutputWithPast,\n)\nfrom transformers.modeling_utils import PreTrainedModel\nfrom text_generation_server.layers import (\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    TensorParallelRowLinear,\n    SpeculativeHead,\n)\n\n\nCUSTOM_KERNELS_ENABLED = False\nif (\n    torch.cuda.is_available()\n    and not os.environ.get(\"DISABLE_CUSTOM_KERNELS\", \"False\") == \"True\"\n):\n    try:\n        from custom_kernels import fused_attention_cuda\n\n        CUSTOM_KERNELS_ENABLED = True\n    except ImportError:\n        pass\n\n\ndef make_causal_mask(\n    input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int\n) -> torch.BoolTensor:\n    \"\"\"\n    Make causal mask used for self-attention.\n    \"\"\"\n    batch_size, target_length = input_ids_shape\n    mask = torch.ones(\n        (target_length, target_length + past_key_values_length),\n        dtype=torch.bool,\n        device=device,\n    )\n    mask = mask.triu(1 
+ past_key_values_length)\n\n    expanded_mask = mask.unsqueeze(0).expand(\n        batch_size, target_length, target_length + past_key_values_length\n    )\n    return expanded_mask\n\n\ndef expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:\n    \"\"\"\n    Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.\n    \"\"\"\n    batch_size, src_length = mask.shape\n    tgt_length = tgt_length if tgt_length is not None else src_length\n\n    expanded_mask = ~(mask[:, None, :].to(torch.bool))\n    return expanded_mask.expand(batch_size, tgt_length, src_length)\n\n\ndef prepare_attn_mask(\n    attention_mask: torch.Tensor,\n    input_shape: Tuple[int, int],\n    past_key_values_length: int,\n) -> torch.BoolTensor:\n    # create causal mask\n    # [batch_size, seq_length] -> [batch_size, tgt_length, src_length]\n    combined_attention_mask = None\n    device = attention_mask.device\n    _, src_length = input_shape\n\n    if src_length > 1:\n        combined_attention_mask = make_causal_mask(\n            input_shape, device=device, past_key_values_length=past_key_values_length\n        )\n\n    # [batch_size, seq_length] -> [batch_size, tgt_length, src_length]\n    expanded_attn_mask = expand_mask(attention_mask, tgt_length=src_length)\n    combined_attention_mask = (\n        expanded_attn_mask\n        if combined_attention_mask is None\n        else expanded_attn_mask | combined_attention_mask\n    )\n\n    return combined_attention_mask\n\n\nclass GPTNeoXPreTrainedModel(PreTrainedModel):\n    \"\"\"\n    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained\n    models.\n    \"\"\"\n\n\nclass GPTNeoXAttention(nn.Module):\n    def __init__(self, config, prefix, weights):\n        super().__init__()\n        self.num_attention_heads = config.num_attention_heads\n        self.hidden_size = config.hidden_size\n        self.head_size = 
self.hidden_size // self.num_attention_heads\n        self.rotary_ndims = int(self.head_size * config.rotary_pct)\n        # ??? TODO\n        # self.register_buffer(\n        #     \"bias\",\n        #     torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(\n        #         1, 1, max_positions, max_positions\n        #     ),\n        # )\n        # self.register_buffer(\"masked_bias\", torch.tensor(-1e9))\n        self.rotary_emb = RotaryEmbedding(\n            self.rotary_ndims,\n            config.max_position_embeddings,\n            base=config.rotary_emb_base,\n        )\n        self.rotary_emb.inv_freq = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.rotary_emb.inv_freq\")\n        )\n        self.inv_norm_factor = 1.0 / torch.sqrt(\n            torch.tensor(self.head_size, dtype=torch.float32)\n        ).to(torch.get_default_dtype())\n\n        if self.num_attention_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_attention_heads` must be divisible by `num_shards` \"\n                f\"(got `num_attention_heads`: {self.num_attention_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_attention_heads = (\n            self.num_attention_heads // weights.process_group.size()\n        )\n        self.query_key_value = TensorParallelColumnLinear.load(\n            config, prefix=f\"{prefix}.query_key_value\", weights=weights, bias=True\n        )\n        self.dense = TensorParallelRowLinear.load(\n            config, prefix=f\"{prefix}.dense\", weights=weights, bias=True\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        position_ids,\n        attention_mask,\n        head_mask=None,\n        layer_past=None,\n        use_cache=False,\n        output_attentions=False,\n    ):\n        has_layer_past = layer_past is not None\n\n        # Compute QKV\n        # Attention heads 
[batch, seq_len, hidden_size]\n        #   --> [batch, seq_len, (np * 3 * head_size)]\n        qkv = self.query_key_value(hidden_states)\n\n        # [batch, seq_len, (num_heads * 3 * head_size)]\n        #   --> [batch, seq_len, num_heads, 3 * head_size]\n        new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)\n        qkv = qkv.view(*new_qkv_shape).permute(0, 2, 1, 3)\n        # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]\n        query, key, value = qkv.split(self.head_size, -1)\n\n        # Compute token offset for rotary embeddings (when decoding)\n        seq_len = key.shape[-2]\n        if has_layer_past:\n            seq_len += layer_past[0].shape[-2]\n\n        # Compute rotary embeddings on rotary_ndims\n        query_rot = query[..., : self.rotary_ndims]\n        key_rot = key[..., : self.rotary_ndims]\n\n        query_rot, key_rot = self.rotary_emb(query_rot, key_rot, position_ids, seq_len)\n\n        query[..., : self.rotary_ndims] = query_rot\n        key[..., : self.rotary_ndims] = key_rot\n\n        if CUSTOM_KERNELS_ENABLED:\n            attn_output, present, attn_weights = fused_attention_cuda.forward(\n                query,\n                key,\n                value,\n                layer_past,\n                attention_mask,\n                head_mask,\n                self.inv_norm_factor,\n                self.num_attention_heads,\n                use_cache,\n            )\n        else:\n            # Cache QKV values\n            if has_layer_past:\n                past_key = layer_past[0]\n                past_value = layer_past[1]\n                key = torch.cat((past_key, key), dim=-2)\n                value = torch.cat((past_value, value), dim=-2)\n            present = (key, value) if use_cache else None\n\n            # Compute attention\n            attn_output, attn_weights = self._attn(\n                query, key, value, 
attention_mask, head_mask\n            )\n\n            # Reshape outputs\n            attn_output = self._merge_heads(\n                attn_output, self.num_attention_heads, self.head_size\n            )\n\n        attn_output = self.dense(attn_output)\n\n        outputs = (attn_output, present)\n        if output_attentions:\n            outputs += (attn_weights,)\n\n        return outputs\n\n    @classmethod\n    def _split_heads(cls, tensor, num_attention_heads, attn_head_size):\n        \"\"\"\n        Splits hidden dim into attn_head_size and num_attention_heads\n        \"\"\"\n        # tensor: [bs, seq_len, hidden_size]\n        new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)\n        # -> [bs, seq_len, num_attention_heads, attn_head_size]\n        tensor = tensor.view(new_shape)\n        # -> [bs, num_attention_heads, seq_len, attn_head_size]\n        tensor = tensor.permute(0, 2, 1, 3)\n        return tensor\n\n    @classmethod\n    def _merge_heads(cls, tensor, num_attention_heads, attn_head_size):\n        \"\"\"\n        Merges attn_head_size dim and num_attn_heads dim into hidden dim\n        \"\"\"\n        # tensor [bs, num_attention_heads, seq_len, attn_head_size]\n        tensor = tensor.permute(0, 2, 1, 3).contiguous()\n        # -> [bs, seq_len, num_attention_heads, attn_head_size]\n        tensor = tensor.view(\n            tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size\n        )\n        # -> [bs, seq_len, hidden_size]\n        return tensor\n\n    def _attn(self, query, key, value, attention_mask=None, head_mask=None):\n        # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]\n        # compute causal mask from causal mask buffer\n        batch_size, num_attention_heads, query_length, attn_head_size = query.size()\n        key_length = key.size(-2)\n\n        query = query.reshape(\n            batch_size * num_attention_heads, query_length, attn_head_size\n        )\n        key = 
key.reshape(batch_size * num_attention_heads, key_length, attn_head_size)\n        attn_scores = torch.zeros(\n            1,\n            dtype=query.dtype,\n            device=key.device,\n        ).expand(batch_size * num_attention_heads, query_length, key_length)\n        attn_scores = torch.baddbmm(\n            attn_scores,\n            query,\n            key.transpose(1, 2),\n            beta=1.0,\n            alpha=self.inv_norm_factor,\n        )\n\n        # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]\n        input_dtype = attn_scores.dtype\n        if input_dtype in [torch.float16, torch.bfloat16]:\n            attn_scores = attn_scores.to(torch.float)\n        attn_scores = torch.where(\n            attention_mask, torch.finfo(attn_scores.dtype).min, attn_scores\n        )\n        attn_scores = attn_scores.view(\n            batch_size, num_attention_heads, query_length, key_length\n        )\n\n        attn_weights = nn.functional.softmax(attn_scores, dim=-1)\n        attn_weights = attn_weights.to(value.dtype)\n\n        # Mask heads if we want to\n        if head_mask is not None:\n            attn_weights = attn_weights * head_mask\n\n        attn_output = torch.matmul(attn_weights, value)\n        return attn_output, attn_weights\n\n\nclass RotaryEmbedding(torch.nn.Module):\n    def __init__(self, dim, max_position_embeddings, base=10000, device=None):\n        super().__init__()\n        self.true_inv_freq = 1.0 / (\n            base ** (torch.arange(0, dim, 2).float().to(device) / dim)\n        )\n        self.register_buffer(\"inv_freq\", self.true_inv_freq)\n\n        # Build here to make `torch.jit.trace` work.\n        self.max_seq_len_cached = max_position_embeddings\n        self.cos_cached = None\n        self.sin_cached = None\n\n    @staticmethod\n    def rotate_half(x):\n        \"\"\"Rotates half the hidden dims of the input.\"\"\"\n        x1 = 
x[..., : x.shape[-1] // 2]\n        x2 = x[..., x.shape[-1] // 2 :]\n        return torch.cat((-x2, x1), dim=-1)\n\n    @staticmethod\n    def _create_cos_sin(inv_freq, max_position_embeddings, dtype, device):\n        t = torch.arange(\n            max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype\n        )\n        freqs = torch.einsum(\"i,j->ij\", t, inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        return emb.cos().to(device).to(dtype), emb.sin().to(device).to(dtype)\n\n    def forward(self, q, k, position_ids, seq_len=None):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        if (\n            seq_len > self.max_seq_len_cached\n            or self.cos_cached is None\n            or self.sin_cached is None\n        ):\n            if seq_len > self.max_seq_len_cached:\n                self.max_seq_len_cached = seq_len\n            self.cos_cached, self.sin_cached = self._create_cos_sin(\n                self.true_inv_freq, self.max_seq_len_cached, q.dtype, q.device\n            )\n        return rotary_forward(q, k, self.cos_cached, self.sin_cached, position_ids)\n\n\n@torch.jit.script\ndef rotary_forward(q, k, cos, sin, position_ids):\n    cos = cos[position_ids].unsqueeze(1)\n    sin = sin[position_ids].unsqueeze(1)\n\n    chunk_size = q.shape[-1] // 2\n    q1, q2 = q.split(chunk_size, -1)\n    q_rotated = torch.cat((-q2, q1), dim=-1)\n    k1, k2 = k.split(chunk_size, -1)\n    k_rotated = torch.cat((-k2, k1), dim=-1)\n\n    q_embed = (q * cos) + (q_rotated * sin)\n    k_embed = (k * cos) + (k_rotated * sin)\n    return q_embed, k_embed\n\n\nclass GPTNeoXMLP(nn.Module):\n    def __init__(self, config, prefix, weights):\n        super().__init__()\n        self.act = (\n            ACT2FN[config.hidden_act]\n            if \"gelu_fast\" not in config.hidden_act\n            else lambda x: 
torch.nn.functional.gelu(x, approximate=\"tanh\")\n        )\n\n        self.dense_h_to_4h = TensorParallelColumnLinear.load(\n            config, prefix=f\"{prefix}.dense_h_to_4h\", weights=weights, bias=True\n        )\n        self.dense_4h_to_h = TensorParallelRowLinear.load(\n            config, prefix=f\"{prefix}.dense_4h_to_h\", weights=weights, bias=True\n        )\n\n    def forward(self, hidden_states):\n        hidden_states = self.dense_h_to_4h(hidden_states)\n        hidden_states = self.act(hidden_states)\n        hidden_states = self.dense_4h_to_h(hidden_states)\n        return hidden_states\n\n\nclass GPTNeoXLayer(nn.Module):\n    def __init__(self, layer_id, prefix: str, config, weights):\n        super().__init__()\n        self.use_parallel_residual = config.use_parallel_residual\n        self.input_layernorm = nn.LayerNorm.load(\n            prefix=f\"{prefix}.layers.{layer_id}.input_layernorm\",\n            weights=weights,\n            eps=config.layer_norm_eps,\n        )\n        self.post_attention_layernorm = nn.LayerNorm.load(\n            prefix=f\"{prefix}.layers.{layer_id}.post_attention_layernorm\",\n            weights=weights,\n            eps=config.layer_norm_eps,\n        )\n        self.attention = GPTNeoXAttention(\n            config, prefix=f\"{prefix}.layers.{layer_id}.attention\", weights=weights\n        )\n        self.mlp = GPTNeoXMLP(\n            config, prefix=f\"{prefix}.layers.{layer_id}.mlp\", weights=weights\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        position_ids,\n        attention_mask=None,\n        head_mask=None,\n        use_cache=False,\n        layer_past=None,\n        output_attentions=False,\n    ):\n        attention_layer_outputs = self.attention(\n            self.input_layernorm(hidden_states),\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            layer_past=layer_past,\n            head_mask=head_mask,\n            
use_cache=use_cache,\n            output_attentions=output_attentions,\n        )\n        attn_output = attention_layer_outputs[\n            0\n        ]  # output_attn: attn_output, present, (attn_weights)\n        outputs = attention_layer_outputs[1:]\n\n        if self.use_parallel_residual:\n            # pseudocode:\n            # x = x + attn(ln1(x)) + mlp(ln2(x))\n            mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))\n            hidden_states = mlp_output + attn_output + hidden_states\n        else:\n            # pseudocode:\n            # x = x + attn(ln1(x))\n            # x = x + mlp(ln2(x))\n            attn_output = attn_output + hidden_states\n            mlp_output = self.mlp(self.post_attention_layernorm(attn_output))\n            hidden_states = mlp_output + attn_output\n\n        if use_cache:\n            outputs = (\n                hidden_states,\n            ) + outputs  # hidden_states, present, (attn_weights)\n        else:\n            outputs = (hidden_states,) + outputs[1:]  # hidden_states, (attn_weights)\n\n        return outputs\n\n\nclass GPTNeoXModel(GPTNeoXPreTrainedModel):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__(config)\n        self.config = config\n\n        self.num_attention_heads = config.num_attention_heads\n\n        self.embed_in = TensorParallelEmbedding(\n            prefix=f\"{prefix}.embed_in\", weights=weights\n        )\n        self.layers = nn.ModuleList(\n            [\n                GPTNeoXLayer(layer_id, prefix, config, weights)\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n        self.final_layer_norm = nn.LayerNorm.load(\n            prefix=f\"{prefix}.final_layer_norm\",\n            weights=weights,\n            eps=config.layer_norm_eps,\n        )\n        self.tp_world_size = weights.process_group.size()\n\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = 
None,\n        position_ids=None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        head_mask: Optional[torch.FloatTensor] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        r\"\"\"\n        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):\n            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.\n            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that\n            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all\n            `decoder_input_ids` of shape `(batch_size, sequence_length)`.\n        use_cache (`bool`, *optional*):\n            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n            `past_key_values`).\n        \"\"\"\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        if input_ids is not None and inputs_embeds is not 
None:\n            raise ValueError(\n                \"You cannot specify both input_ids and inputs_embeds at the same time\"\n            )\n        elif input_ids is not None:\n            input_shape = input_ids.size()\n        elif inputs_embeds is not None:\n            input_shape = inputs_embeds.size()[:-1]\n        else:\n            raise ValueError(\"You have to specify either input_ids or inputs_embeds\")\n\n        batch_size, seq_length = input_shape\n\n        if past_key_values is None:\n            past_length = 0\n            past_key_values = tuple([None] * self.config.num_hidden_layers)\n        else:\n            past_length = past_key_values[0][0].size(-2)\n\n        if position_ids is None:\n            device = input_ids.device if input_ids is not None else inputs_embeds.device\n            position_ids = torch.arange(\n                past_length, seq_length + past_length, dtype=torch.long, device=device\n            )\n            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)\n        else:\n            position_ids = position_ids.view(-1, seq_length).long()\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_in(input_ids)\n\n        hidden_states = inputs_embeds\n\n        # Attention mask.\n        seq_length_with_past = seq_length\n        past_key_values_length = 0\n        if past_key_values[0] is not None:\n            past_key_values_length = past_key_values[0][0].shape[-1]\n            seq_length_with_past = seq_length_with_past + past_key_values_length\n        if attention_mask is None:\n            attention_mask = torch.ones(\n                (batch_size, seq_length_with_past), device=hidden_states.device\n            )\n        else:\n            attention_mask = attention_mask.to(hidden_states.device)\n\n        causal_mask = prepare_attn_mask(\n            attention_mask,\n            input_shape=(batch_size, seq_length),\n            past_key_values_length=past_key_values_length,\n   
     )\n\n        assert self.num_attention_heads % self.tp_world_size == 0\n        block_size = self.num_attention_heads // self.tp_world_size\n        causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0)\n\n        # Prepare head mask if needed\n        # 1.0 in head_mask indicate we keep the head\n        # attention_probs has shape bsz x n_heads x N x N\n        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]\n        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]\n        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)\n\n        presents = () if use_cache else None\n        all_attentions = () if output_attentions else None\n        all_hidden_states = () if output_hidden_states else None\n        for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)):\n            if output_hidden_states:\n                all_hidden_states = all_hidden_states + (hidden_states,)\n\n            outputs = layer(\n                hidden_states,\n                position_ids=position_ids,\n                attention_mask=causal_mask,\n                head_mask=head_mask[i],\n                layer_past=layer_past,\n                use_cache=use_cache,\n                output_attentions=output_attentions,\n            )\n            hidden_states = outputs[0]\n            if use_cache is True:\n                presents = presents + (outputs[1],)\n            if output_attentions:\n                all_attentions = all_attentions + (outputs[2 if use_cache else 1],)\n\n        hidden_states = self.final_layer_norm(hidden_states)\n        # Add last hidden state\n        if output_hidden_states:\n            all_hidden_states = all_hidden_states + (hidden_states,)\n\n        if not return_dict:\n            return tuple(\n                v\n                for v in [hidden_states, presents, all_hidden_states, all_attentions]\n                
if v is not None\n            )\n\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=presents,\n            hidden_states=all_hidden_states,\n            attentions=all_attentions,\n        )\n\n\nclass GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):\n    _keys_to_ignore_on_load_missing = [r\"position_ids\", r\"predictions.decoder.bias\"]\n\n    def __init__(self, prefix: str, config, weights):\n        super().__init__(config)\n\n        if not prefix:\n            prefix = \"gpt_neox\"\n        else:\n            prefix = f\"{prefix}.gpt_neox\"\n\n        self.gpt_neox = GPTNeoXModel(prefix, config, weights)\n        self.embed_out = SpeculativeHead.load(\n            config, prefix=\"embed_out\", weights=weights\n        )\n\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        head_mask: Optional[torch.FloatTensor] = None,\n        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        r\"\"\"\n        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):\n            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape\n            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape\n            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 
The two additional tensors are\n            only required when the model is used as a decoder in a Sequence to Sequence model.\n\n            Contains pre-computed hidden-states (key and values in the self-attention blocks that can be used (see\n            `past_key_values` input) to speed up sequential decoding.\n\n            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that\n            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all\n            `decoder_input_ids` of shape `(batch_size, sequence_length)`.\n        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in\n            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are\n            ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`.\n        use_cache (`bool`, *optional*):\n            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n            `past_key_values`).\n\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, GPTNeoXForCausalLM, GPTNeoXConfig\n        >>> import torch\n\n        >>> tokenizer = AutoTokenizer.from_pretrained(\"EleutherAI/gpt-neox-20b\")\n        >>> config = GPTNeoXConfig.from_pretrained(\"EleutherAI/gpt-neox-20b\")\n        >>> config.is_decoder = True\n        >>> model = GPTNeoXForCausalLM.from_pretrained(\"EleutherAI/gpt-neox-20b\", config=config)\n\n        >>> inputs = tokenizer(\"Hello, my dog is cute\", return_tensors=\"pt\")\n        >>> outputs = model(**inputs)\n\n        >>> prediction_logits = outputs.logits\n        ```\"\"\"\n        return_dict = (\n            return_dict if 
return_dict is not None else self.config.use_return_dict\n        )\n\n        outputs = self.gpt_neox(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            past_key_values=past_key_values,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        hidden_states = outputs[0]\n        lm_logits, speculative_logits = self.embed_out(hidden_states)\n\n        lm_loss = None\n        if labels is not None:\n            # move labels to correct device to enable model parallelism\n            labels = labels.to(lm_logits.device)\n            # we are doing next-token prediction; shift prediction scores and input ids by one\n            shift_logits = lm_logits[:, :-1, :].contiguous()\n            labels = labels[:, 1:].contiguous()\n            loss_fct = CrossEntropyLoss()\n            lm_loss = loss_fct(\n                shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)\n            )\n\n        if not return_dict:\n            output = (lm_logits,) + outputs[1:]\n            return ((lm_loss,) + output) if lm_loss is not None else output\n\n        return (\n            CausalLMOutputWithPast(\n                loss=lm_loss,\n                logits=lm_logits,\n                past_key_values=outputs.past_key_values,\n                hidden_states=outputs.hidden_states,\n                attentions=outputs.attentions,\n            ),\n            speculative_logits,\n        )\n\n    def prepare_inputs_for_generation(\n        self,\n        input_ids,\n        past_key_values=None,\n        attention_mask=None,\n        inputs_embeds=None,\n        **kwargs,\n    ):\n        input_shape = input_ids.shape\n\n        # cut decoder_input_ids if past is used\n        if 
past_key_values and past_key_values[0] is not None:\n            input_ids = input_ids[:, -1:]\n\n        position_ids = kwargs.get(\"position_ids\", None)\n        if attention_mask is not None and position_ids is None:\n            # create position_ids on the fly for batch generation\n            position_ids = attention_mask.long().cumsum(-1) - 1\n            position_ids.masked_fill_(attention_mask == 0, 1)\n            if past_key_values:\n                position_ids = position_ids[:, -1].unsqueeze(-1)\n\n        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly\n        if attention_mask is None:\n            attention_mask = input_ids.new_ones(input_shape)\n\n        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step\n        if inputs_embeds is not None and past_key_values is None:\n            model_inputs = {\"inputs_embeds\": inputs_embeds}\n        else:\n            model_inputs = {\"input_ids\": input_ids}\n\n        model_inputs.update(\n            {\n                \"attention_mask\": attention_mask,\n                \"past_key_values\": past_key_values,\n                \"position_ids\": position_ids,\n            }\n        )\n\n        return model_inputs\n\n    def _reorder_cache(self, past_key_values, beam_idx):\n        reordered_past = ()\n        for layer_past in past_key_values:\n            reordered_past += (\n                tuple(\n                    past_state.index_select(0, beam_idx)\n                    for past_state in layer_past[:2]\n                )\n                + layer_past[2:],\n            )\n        return reordered_past\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/opt_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch OPT model.\"\"\"\n\nimport random\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\n\nfrom transformers.activations import ACT2FN\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPast,\n    CausalLMOutputWithPast,\n)\nfrom transformers.modeling_utils import PreTrainedModel\nfrom transformers import OPTConfig\nfrom text_generation_server.layers import (\n    FastLinear,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    TensorParallelRowLinear,\n    SpeculativeHead,\n)\n\nEPS = 1e-5\n\n\n# Copied from transformers.models.bart.modeling_bart._make_causal_mask\ndef _make_causal_mask(\n    input_ids_shape: torch.Size,\n    dtype: torch.dtype,\n    device: torch.device,\n    past_key_values_length: int = 0,\n):\n    \"\"\"\n    Make causal mask used for bi-directional self-attention.\n    \"\"\"\n    bsz, tgt_len = input_ids_shape\n    mask = torch.full(\n        (tgt_len, tgt_len),\n        torch.tensor(torch.finfo(dtype).min, device=device),\n        device=device,\n    )\n    mask_cond = torch.arange(mask.size(-1), device=device)\n    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)\n    mask = mask.to(dtype)\n\n    if past_key_values_length > 0:\n        
mask = torch.cat(\n            [\n                torch.zeros(\n                    tgt_len, past_key_values_length, dtype=dtype, device=device\n                ),\n                mask,\n            ],\n            dim=-1,\n        )\n    return mask[None, None, :, :].expand(\n        bsz, 1, tgt_len, tgt_len + past_key_values_length\n    )\n\n\ndef _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):\n    \"\"\"\n    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.\n    \"\"\"\n    bsz, src_len = mask.size()\n    tgt_len = tgt_len if tgt_len is not None else src_len\n\n    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)\n\n    inverted_mask = 1.0 - expanded_mask\n\n    return inverted_mask.masked_fill(\n        inverted_mask.to(torch.bool), torch.finfo(dtype).min\n    )\n\n\nclass OPTLearnedPositionalEmbedding(nn.Module):\n    \"\"\"\n    This module learns positional embeddings up to a fixed maximum size.\n    \"\"\"\n\n    def __init__(self, prefix: str, weights):\n        super().__init__()\n        self.offset = 2\n        self.weight = nn.Parameter(\n            weights.get_tensor(\n                f\"{prefix if prefix else ''}decoder.embed_positions.weight\"\n            )\n        )\n\n    def forward(\n        self, attention_mask: torch.LongTensor, past_key_values_length: int = 0\n    ):\n        \"\"\"`input_ids_shape` is expected to be [bsz x seqlen].\"\"\"\n        attention_mask = attention_mask.long()\n\n        # create positions depending on attention_mask\n        positions = (\n            torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask\n        ).long() - 1\n\n        # cut positions if `past_key_values_length` is > 0\n        positions = positions[:, past_key_values_length:]\n\n        return torch.nn.functional.embedding(positions + self.offset, self.weight)\n\n\nclass OPTAttention(nn.Module):\n    
\"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(\n        self,\n        config,\n        prefix,\n        weights,\n        is_decoder: bool = False,\n        bias: bool = True,\n        process_group=None,\n    ):\n        super().__init__()\n        hidden_size = config.hidden_size\n        num_heads = config.num_attention_heads\n\n        self.hidden_size = hidden_size\n        self.num_heads = num_heads\n        self.dropout = config.dropout\n        self.head_dim = hidden_size // num_heads\n\n        if (self.head_dim * num_heads) != self.hidden_size:\n            raise ValueError(\n                f\"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}\"\n                f\" and `num_heads`: {num_heads}).\"\n            )\n        self.scaling = self.head_dim**-0.5\n        self.is_decoder = is_decoder\n\n        process_group = weights.process_group\n        if self.num_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.num_heads = self.num_heads // process_group.size()\n        self.hidden_size = self.hidden_size // process_group.size()\n\n        self.q_proj = TensorParallelColumnLinear.load(\n            config, prefix=f\"{prefix}.q_proj\", weights=weights, bias=bias\n        )\n        self.k_proj = TensorParallelColumnLinear.load(\n            config, prefix=f\"{prefix}.k_proj\", weights=weights, bias=bias\n        )\n        self.v_proj = TensorParallelColumnLinear.load(\n            config, prefix=f\"{prefix}.v_proj\", weights=weights, bias=bias\n        )\n        self.out_proj = TensorParallelRowLinear.load(\n            config, prefix=f\"{prefix}.out_proj\", weights=weights, bias=bias\n        )\n\n    def _shape(self, tensor: torch.Tensor, seq_len: 
int, bsz: int):\n        return (\n            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)\n            .transpose(1, 2)\n            .contiguous()\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        key_value_states: Optional[torch.Tensor] = None,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        layer_head_mask: Optional[torch.Tensor] = None,\n        output_attentions: bool = False,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        \"\"\"Input shape: Batch x Time x Channel\"\"\"\n\n        # if key_value_states are provided this layer is used as a cross-attention layer\n        # for the decoder\n        is_cross_attention = key_value_states is not None\n\n        bsz, tgt_len, _ = hidden_states.size()\n\n        # get query proj\n        query_states = self.q_proj(hidden_states) * self.scaling\n        # get key, value proj\n        if is_cross_attention and past_key_value is not None:\n            # reuse k,v, cross_attentions\n            key_states = past_key_value[0]\n            value_states = past_key_value[1]\n        elif is_cross_attention:\n            # cross_attentions\n            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)\n            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)\n        elif past_key_value is not None:\n            # reuse k, v, self_attention\n            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)\n            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)\n            key_states = torch.cat([past_key_value[0], key_states], dim=2)\n            value_states = torch.cat([past_key_value[1], value_states], dim=2)\n        else:\n            # self_attention\n            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)\n            value_states = 
self._shape(self.v_proj(hidden_states), -1, bsz)\n\n        if self.is_decoder:\n            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.\n            # Further calls to cross_attention layer can then reuse all cross-attention\n            # key/value_states (first \"if\" case)\n            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of\n            # all previous decoder key/value_states. Further calls to uni-directional self-attention\n            # can concat previous decoder key/value_states to current projected key/value_states (third \"elif\" case)\n            # if encoder bi-directional self-attention `past_key_value` is always `None`\n            past_key_value = (key_states, value_states)\n\n        proj_shape = (bsz * self.num_heads, -1, self.head_dim)\n        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)\n        key_states = key_states.view(*proj_shape)\n        value_states = value_states.view(*proj_shape)\n\n        src_len = key_states.size(1)\n        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))\n\n        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is\"\n                f\" {attn_weights.size()}\"\n            )\n\n        if attention_mask is not None:\n            if attention_mask.size() != (bsz, 1, tgt_len, src_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}\"\n                )\n            attn_weights = (\n                attn_weights.view(bsz, self.num_heads, tgt_len, src_len)\n                + attention_mask\n            )\n            attn_weights = torch.max(\n                attn_weights, 
torch.tensor(torch.finfo(attn_weights.dtype).min)\n            )\n            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)\n\n        # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437\n        if attn_weights.dtype == torch.float16:\n            attn_weights = nn.functional.softmax(\n                attn_weights, dim=-1, dtype=torch.float32\n            ).to(torch.float16)\n        else:\n            attn_weights = nn.functional.softmax(attn_weights, dim=-1)\n\n        if layer_head_mask is not None:\n            if layer_head_mask.size() != (self.num_heads,):\n                raise ValueError(\n                    f\"Head mask for a single layer should be of size {(self.num_heads,)}, but is\"\n                    f\" {layer_head_mask.size()}\"\n                )\n            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(\n                bsz, self.num_heads, tgt_len, src_len\n            )\n            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)\n\n        if output_attentions:\n            # this operation is a bit awkward, but it's required to\n            # make sure that attn_weights keeps its gradient.\n            # In order to do so, attn_weights have to be reshaped\n            # twice and have to be reused in the following\n            attn_weights_reshaped = attn_weights.view(\n                bsz, self.num_heads, tgt_len, src_len\n            )\n            attn_weights = attn_weights_reshaped.view(\n                bsz * self.num_heads, tgt_len, src_len\n            )\n        else:\n            attn_weights_reshaped = None\n\n        attn_probs = nn.functional.dropout(\n            attn_weights, p=self.dropout, training=self.training\n        )\n\n        attn_output = torch.bmm(attn_probs, value_states)\n\n        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):\n            raise 
ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)\n        attn_output = attn_output.transpose(1, 2)\n\n        # Use the `hidden_size` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be\n        # partitioned aross GPUs when using tensor-parallelism.\n        attn_output = attn_output.reshape(bsz, tgt_len, self.hidden_size)\n\n        attn_output = self.out_proj(attn_output)\n\n        return attn_output, attn_weights_reshaped, past_key_value\n\n\nclass OPTDecoderLayer(nn.Module):\n    def __init__(self, layer_id: int, prefix: str, config: OPTConfig, weights):\n        super().__init__()\n        self.process_group = weights.process_group\n        self.hidden_size = config.hidden_size\n        self.self_attn = OPTAttention(\n            config,\n            prefix=f\"{prefix}.self_attn\",\n            weights=weights,\n            is_decoder=True,\n            bias=config.enable_bias,\n        )\n        self.do_layer_norm_before = config.do_layer_norm_before\n        self.dropout = config.dropout\n        self.activation_fn = ACT2FN[config.activation_function]\n\n        self.self_attn_layer_norm = nn.LayerNorm.load(\n            prefix=f\"{prefix}.self_attn_layer_norm\", weights=weights, eps=EPS\n        )\n        self.fc1 = TensorParallelColumnLinear.load(\n            config, prefix=f\"{prefix}.fc1\", weights=weights, bias=config.enable_bias\n        )\n        self.fc2 = TensorParallelRowLinear.load(\n            config, prefix=f\"{prefix}.fc2\", weights=weights, bias=config.enable_bias\n        )\n        self.final_layer_norm = nn.LayerNorm.load(\n            prefix=f\"{prefix}.final_layer_norm\", weights=weights, eps=EPS\n        )\n\n    def forward(\n        self,\n        hidden_states: 
torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        layer_head_mask: Optional[torch.Tensor] = None,\n        output_attentions: Optional[bool] = False,\n        use_cache: Optional[bool] = False,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n    ) -> Tuple[\n        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]\n    ]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`\n            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.\n            layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size\n                `(encoder_attention_heads,)`.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under\n                returned tensors for more detail.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n        \"\"\"\n\n        residual = hidden_states\n\n        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention\n        if self.do_layer_norm_before:\n            hidden_states = self.self_attn_layer_norm(hidden_states)\n\n        # Self Attention\n        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states=hidden_states,\n            past_key_value=past_key_value,\n            attention_mask=attention_mask,\n            layer_head_mask=layer_head_mask,\n            output_attentions=output_attentions,\n        )\n        hidden_states = nn.functional.dropout(\n            hidden_states, p=self.dropout, training=self.training\n        )\n        hidden_states = residual + hidden_states\n\n        # 350m applies layer norm AFTER attention\n        if not self.do_layer_norm_before:\n            hidden_states = self.self_attn_layer_norm(hidden_states)\n\n        # Fully Connected\n        hidden_states_shape = hidden_states.shape\n        hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))\n        residual = hidden_states\n\n        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention\n        if self.do_layer_norm_before:\n            hidden_states = self.final_layer_norm(hidden_states)\n\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = self.activation_fn(hidden_states)\n\n        hidden_states = self.fc2(hidden_states)\n        hidden_states = nn.functional.dropout(\n            hidden_states, p=self.dropout, training=self.training\n        )\n\n        hidden_states = 
(residual + hidden_states).view(hidden_states_shape)\n\n        # 350m applies layer norm AFTER attention\n        if not self.do_layer_norm_before:\n            hidden_states = self.final_layer_norm(hidden_states)\n\n        outputs = (hidden_states,)\n\n        if output_attentions:\n            outputs += (self_attn_weights,)\n\n        if use_cache:\n            outputs += (present_key_value,)\n\n        return outputs\n\n\nclass OPTPreTrainedModel(PreTrainedModel):\n    config_class = OPTConfig\n\n\nclass OPTDecoder(OPTPreTrainedModel):\n    def __init__(self, prefix: str, config: OPTConfig, weights):\n        super().__init__(config)\n        self.dropout = config.dropout\n        self.layerdrop = config.layerdrop\n        self.padding_idx = config.pad_token_id\n        self.max_target_positions = config.max_position_embeddings\n        self.vocab_size = config.vocab_size\n\n        prefix = prefix + \".\" if prefix else \"\"\n\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=f\"{prefix}decoder.embed_tokens\", weights=weights\n        )\n        self.embed_positions = OPTLearnedPositionalEmbedding(prefix, weights)\n\n        if config.word_embed_proj_dim != config.hidden_size:\n            self.project_out = FastLinear.load(\n                config,\n                prefix=f\"{prefix}decoder.project_out\",\n                weights=weights,\n                bias=False,\n            )\n        else:\n            self.project_out = None\n\n        if config.word_embed_proj_dim != config.hidden_size:\n            self.project_in = FastLinear.load(\n                config,\n                prefix=f\"{prefix}decoder.project_in\",\n                weights=weights,\n                bias=False,\n            )\n        else:\n            self.project_in = None\n\n        # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility\n        # with checkpoints that have been fine-tuned before transformers 
v4.20.1\n        # see https://github.com/facebookresearch/metaseq/pull/164\n        if config.do_layer_norm_before and not config._remove_final_layer_norm:\n            self.final_layer_norm = nn.LayerNorm.load(\n                prefix=f\"{prefix}decoder.final_layer_norm\", weights=weights, eps=EPS\n            )\n        else:\n            self.final_layer_norm = None\n\n        self.layers = nn.ModuleList(\n            [\n                OPTDecoderLayer(\n                    layer_id,\n                    prefix=f\"{prefix}decoder.layers.{layer_id}\",\n                    config=config,\n                    weights=weights,\n                )\n                for layer_id in range(config.num_hidden_layers)\n            ]\n        )\n\n    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask\n    def _prepare_decoder_attention_mask(\n        self, attention_mask, input_shape, inputs_embeds, past_key_values_length\n    ):\n        # create causal mask\n        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n        combined_attention_mask = None\n        if input_shape[-1] > 1:\n            combined_attention_mask = _make_causal_mask(\n                input_shape,\n                inputs_embeds.dtype,\n                device=inputs_embeds.device,\n                past_key_values_length=past_key_values_length,\n            )\n\n        if attention_mask is not None:\n            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n            expanded_attn_mask = _expand_mask(\n                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]\n            ).to(inputs_embeds.device)\n            combined_attention_mask = (\n                expanded_attn_mask\n                if combined_attention_mask is None\n                else expanded_attn_mask + combined_attention_mask\n            )\n\n        return combined_attention_mask\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor 
= None,\n        attention_mask: Optional[torch.Tensor] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        r\"\"\"\n        Args:\n            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you\n                provide it.\n\n                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n                [`PreTrainedTokenizer.__call__`] for details.\n\n                [What are input IDs?](../glossary#input-ids)\n            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n                - 1 for tokens that are **not masked**,\n                - 0 for tokens that are **masked**.\n\n                [What are attention masks?](../glossary#attention-mask)\n            head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):\n                Mask to nullify selected heads of the attention modules. 
Mask values selected in `[0, 1]`:\n\n                - 1 indicates the head is **not masked**,\n                - 0 indicates the head is **masked**.\n\n            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):\n                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of\n                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of\n\n                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the\n                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.\n\n                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those\n                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of\n                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.\n\n            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.\n                This is useful if you want more control over how to convert `input_ids` indices into associated vectors\n                than the model's internal embedding lookup matrix.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n            output_hidden_states (`bool`, *optional*):\n                Whether or not to return the hidden states of all layers. 
See `hidden_states` under returned tensors\n                for more detail.\n            return_dict (`bool`, *optional*):\n                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n        \"\"\"\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        # retrieve input_ids and inputs_embeds\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\n                \"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time\"\n            )\n        elif input_ids is not None:\n            input_shape = input_ids.size()\n            input_ids = input_ids.view(-1, input_shape[-1])\n        elif inputs_embeds is not None:\n            input_shape = inputs_embeds.size()[:-1]\n        else:\n            raise ValueError(\n                \"You have to specify either decoder_input_ids or decoder_inputs_embeds\"\n            )\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        batch_size, seq_length = input_shape\n        past_key_values_length = (\n            past_key_values[0][0].shape[2] if past_key_values is not None else 0\n        )\n        # required mask seq length can be calculated via length of past\n        mask_seq_length = past_key_values_length + seq_length\n\n        # embed positions\n        if attention_mask is None:\n            attention_mask = torch.ones(\n                batch_size, 
mask_seq_length, device=inputs_embeds.device\n            )\n        causal_attention_mask = self._prepare_decoder_attention_mask(\n            attention_mask, input_shape, inputs_embeds, past_key_values_length\n        )\n        pos_embeds = self.embed_positions(attention_mask, past_key_values_length)\n\n        if self.project_in is not None:\n            inputs_embeds = self.project_in(inputs_embeds)\n\n        hidden_states = inputs_embeds + pos_embeds\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        next_decoder_cache = () if use_cache else None\n\n        # check if head_mask has a correct number of layers specified if desired\n        for attn_mask, mask_name in zip([head_mask], [\"head_mask\"]):\n            if attn_mask is not None:\n                if attn_mask.size()[0] != (len(self.layers)):\n                    raise ValueError(\n                        f\"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for\"\n                        f\" {head_mask.size()[0]}.\"\n                    )\n\n        for idx, decoder_layer in enumerate(self.layers):\n            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            dropout_probability = random.uniform(0, 1)\n            if self.training and (dropout_probability < self.layerdrop):\n                continue\n\n            past_key_value = (\n                past_key_values[idx] if past_key_values is not None else None\n            )\n\n            layer_outputs = decoder_layer(\n                hidden_states,\n                attention_mask=causal_attention_mask,\n                layer_head_mask=(head_mask[idx] if head_mask is not None else None),\n                past_key_value=past_key_value,\n                
output_attentions=output_attentions,\n                use_cache=use_cache,\n            )\n\n            hidden_states = layer_outputs[0]\n\n            if use_cache:\n                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n        if self.final_layer_norm is not None:\n            hidden_states = self.final_layer_norm(hidden_states)\n\n        if self.project_out is not None:\n            hidden_states = self.project_out(hidden_states)\n\n        # add hidden states from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        next_cache = next_decoder_cache if use_cache else None\n        if not return_dict:\n            return tuple(\n                v\n                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]\n                if v is not None\n            )\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=next_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n        )\n\n\nclass OPTModel(OPTPreTrainedModel):\n    def __init__(self, prefix: str, config: OPTConfig, weights):\n        super().__init__(config)\n        self.decoder = OPTDecoder(prefix, config, weights)\n        # Initialize weights and apply final processing\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> 
Union[Tuple, BaseModelOutputWithPast]:\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)\n        decoder_outputs = self.decoder(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        if not return_dict:\n            return decoder_outputs\n\n        return BaseModelOutputWithPast(\n            last_hidden_state=decoder_outputs.last_hidden_state,\n            past_key_values=decoder_outputs.past_key_values,\n            hidden_states=decoder_outputs.hidden_states,\n            attentions=decoder_outputs.attentions,\n        )\n\n\nclass OPTForCausalLM(OPTPreTrainedModel):\n    def __init__(self, prefix, config, weights):\n        super().__init__(config)\n        if not prefix and any(s.startswith(\"model\") for s in weights.routing.keys()):\n            prefix = \"model\"\n\n        self.model = OPTModel(prefix, config, weights)\n\n        self.lm_head = SpeculativeHead.load(\n            config,\n            prefix=f\"{prefix + '.' 
if prefix else ''}decoder.embed_tokens\",\n            weights=weights,\n        )\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model.decoder(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            head_mask=head_mask,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        logits, speculative_logits = self.lm_head(outputs.last_hidden_state)\n\n        loss = None\n\n        return (\n            CausalLMOutputWithPast(\n                loss=loss,\n                logits=logits,\n                past_key_values=outputs.past_key_values,\n                
hidden_states=outputs.hidden_states,\n                attentions=outputs.attentions,\n            ),\n            speculative_logits,\n        )\n\n    def prepare_inputs_for_generation(\n        self,\n        input_ids,\n        past_key_values=None,\n        attention_mask=None,\n        inputs_embeds=None,\n        **kwargs,\n    ):\n        if past_key_values:\n            input_ids = input_ids[:, -1:]\n\n        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step\n        if inputs_embeds is not None and past_key_values is None:\n            model_inputs = {\"inputs_embeds\": inputs_embeds}\n        else:\n            model_inputs = {\"input_ids\": input_ids}\n\n        model_inputs.update(\n            {\n                \"past_key_values\": past_key_values,\n                \"use_cache\": kwargs.get(\"use_cache\"),\n                \"attention_mask\": attention_mask,\n            }\n        )\n        return model_inputs\n\n    @staticmethod\n    def _reorder_cache(past_key_values, beam_idx):\n        reordered_past = ()\n        for layer_past in past_key_values:\n            reordered_past += (\n                tuple(\n                    past_state.index_select(0, beam_idx) for past_state in layer_past\n                ),\n            )\n        return reordered_past\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/phi_modeling.py",
    "content": "# imlementation of the PhiModel and PhiForCausalLM classes\n\nimport torch\nimport torch.distributed\n\nimport math\nfrom torch import nn\nfrom typing import Optional, List, Tuple\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.modeling_outputs import CausalLMOutputWithPast\n\nfrom text_generation_server.layers import (\n    TensorParallelRowLinear,\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    SpeculativeHead,\n    FastLinear,\n)\n\n\n# PhiConfig is the configuration class for the PhiModel.\nclass PhiConfig(PretrainedConfig):\n    def __init__(\n        self,\n        vocab_size=51200,\n        n_positions=2048,\n        n_embd=2560,\n        n_layer=32,\n        n_inner=None,\n        n_head=32,\n        rotary_dim=32,\n        layer_norm_epsilon=1e-5,\n        tie_word_embeddings=False,\n        pad_vocab_size_multiple=64,\n        pad_token_id=0,\n        bos_token_id=1,\n        eos_token_id=2,\n        no_bias=False,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.n_positions = n_positions\n        self.n_embd = n_embd\n        self.n_layer = n_layer\n        self.n_inner = n_inner\n        self.n_head = n_head\n        self.rotary_dim = rotary_dim\n\n        self.layer_norm_epsilon = layer_norm_epsilon\n        self.tie_word_embeddings = tie_word_embeddings\n        self.pad_vocab_size_multiple = pad_vocab_size_multiple\n        self.pad_token_id = pad_token_id\n        self.bos_token_id = bos_token_id\n        self.eos_token_id = eos_token_id\n        self.no_bias = no_bias\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n\n# RotaryEmbedding is a class that implements the rotary embedding.\nclass RotaryEmbedding(nn.Module):\n    def __init__(self, dim, max_seq_len):\n  
      super().__init__()\n        inv_freq = [1.0 / 10000.0 ** (i / dim) for i in range(0, dim, 2)]\n        inv_freq_len = len(inv_freq)\n        inv_freq = torch.tensor(inv_freq).view(1, inv_freq_len)\n        t = torch.arange(0, max_seq_len, dtype=torch.float).view(max_seq_len, 1)\n        freqs = t.matmul(inv_freq)\n        self.sin = freqs.sin()\n        self.cos = freqs.cos()\n\n    def apply_rotary_emb_qkv(self, qkv, seqlen_offset):\n        b_size, seqlen, three, _, _headdim = qkv.shape\n        if three != 3:\n            raise Exception(\"unexpected shape for qkv\")\n        _, rotary_dim = self.cos.shape\n        rotary_dim = rotary_dim * 2\n        q_rot = qkv[:, :, 0, :, :rotary_dim]\n        q_pass = qkv[:, :, 0, :, rotary_dim:]\n        k_rot = qkv[:, :, 1, :, :rotary_dim]\n        k_pass = qkv[:, :, 1, :, rotary_dim:]\n        q12 = torch.chunk(q_rot, 2, dim=-1)\n        k12 = torch.chunk(k_rot, 2, dim=-1)\n        q1, q2 = q12[0], q12[1]\n        k1, k2 = k12[0], k12[1]\n        c = self.cos.narrow(0, seqlen_offset, seqlen).unsqueeze(1)\n        s = self.sin.narrow(0, seqlen_offset, seqlen).unsqueeze(1)\n        q_rot = torch.cat(\n            [\n                q1 * c - q2 * s,\n                q1 * s + q2 * c,\n            ],\n            dim=-1,\n        )\n        k_rot = torch.cat(\n            [\n                k1 * c - k2 * s,\n                k1 * s + k2 * c,\n            ],\n            dim=-1,\n        )\n        q = torch.cat([q_rot, q_pass], dim=-1)\n        k = torch.cat([k_rot, k_pass], dim=-1)\n        v = qkv[:, :, 2]\n        return q, k, v\n\n\n# PhiCausalLMHead is the head of the PhiModel. 
It is a linear layer with a layer norm.\nclass PhiCausalLMHead(nn.Module):\n    def __init__(self, config, weights):\n        super().__init__()\n        self.ln = nn.LayerNorm.load(\n            prefix=\"lm_head.ln\",\n            weights=weights,\n            eps=config.layer_norm_epsilon,\n        )\n        self.linear = SpeculativeHead.load(\n            config=config, prefix=\"lm_head.linear\", weights=weights\n        )\n\n    def forward(self, hidden_states):\n        hidden_states = self.ln(hidden_states)\n        hidden_states = self.linear(hidden_states)\n        return hidden_states\n\n\n# PhiMHA is a multi-head attention layer. This layer uses an attention mask to prevent tokens from attending to subsequent tokens.\nclass PhiMHA(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.Wqkv = TensorParallelColumnLinear.load(\n            config, prefix=f\"{prefix}.Wqkv\", weights=weights, bias=not config.no_bias\n        )\n        self.out_proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.out_proj\",\n            weights=weights,\n            bias=not config.no_bias,\n        )\n        self.op_size = config.n_embd\n        self.head_dim = int(config.n_embd / config.n_head)\n        self.num_heads = config.n_head\n        self.rotary_emb = RotaryEmbedding(\n            config.rotary_dim,\n            config.n_positions,\n        )\n        self.softmax_scale = 1.0 / math.sqrt(self.head_dim)\n\n    def forward(\n        self,\n        hidden_states,\n        past_kv_cache,\n        attention_mask=None,\n    ):\n        b_size, seq_len, _n_embd = hidden_states.shape\n        qkv = self.Wqkv(hidden_states)\n        qkv = qkv.view(b_size, seq_len, 3, self.num_heads, self.head_dim)\n        seqlen_offset = 0 if past_kv_cache is None else past_kv_cache[0].shape[1]\n        q, k, v = self.rotary_emb.apply_rotary_emb_qkv(qkv, seqlen_offset)\n\n        # if there is a 
kv_cache, then we need to concatenate\n        if past_kv_cache is not None:\n            prev_k, prev_v = past_kv_cache\n            k = torch.cat([prev_k, k], dim=1)\n            v = torch.cat([prev_v, v], dim=1)\n\n        past_kv_cache = [k, v]\n        attn_weights = torch.einsum(\"bthd,bshd->bhts\", q, k * self.softmax_scale)\n\n        if attention_mask is not None:\n            seqlen_k = k.shape[1]\n            seqlen_q = q.shape[1]\n            causal_mask = torch.triu(\n                torch.full((seqlen_q, seqlen_k), -10000.0, device=attn_weights.device),\n                1,\n            )\n            attn_weights = attn_weights + causal_mask.to(dtype=attn_weights.dtype)\n\n        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)\n        attn_output = attn_weights.matmul(v.transpose(1, 2)).squeeze(0)\n        attn_output = (\n            attn_output.view((b_size, self.num_heads, seq_len, self.head_dim))\n            .transpose(1, 2)\n            .flatten(-2)\n        )\n        return self.out_proj(attn_output), past_kv_cache\n\n\n# PhiMLP is a multi-layer perceptron. It contains two linear layers with a gelu activation function.\nclass PhiMLP(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n\n        self.n_inner = config.n_inner\n        self.fc1 = FastLinear.load(\n            config=config,\n            prefix=f\"{prefix}.fc1\",\n            weights=weights,\n            bias=False,\n        )\n        self.fc2 = FastLinear.load(\n            config=config,\n            prefix=f\"{prefix}.fc2\",\n            weights=weights,\n            bias=False,\n        )\n        self.activation = torch.nn.functional.gelu\n\n    def forward(self, hidden_states):\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = self.activation(hidden_states)\n        hidden_states = self.fc2(hidden_states)\n        return hidden_states\n\n\n# PhiBlock is a single transformer block. 
It contains a layer norm, a multi-head attention layer and a multi-layer perceptron.\nclass PhiBlock(nn.Module):\n    def __init__(self, layer_id, config, weights):\n        super().__init__()\n        self.layer_id = layer_id\n        self.layer_norm = nn.LayerNorm.load(\n            prefix=f\"{layer_id}.ln\", weights=weights, eps=config.layer_norm_epsilon\n        )\n        self.mixer = PhiMHA(prefix=f\"{layer_id}.mixer\", config=config, weights=weights)\n        self.mlp = PhiMLP(prefix=f\"{layer_id}.mlp\", config=config, weights=weights)\n\n    def forward(\n        self,\n        hidden_states,\n        kv_cache,\n        attention_mask,\n    ):\n        residual = hidden_states\n        hidden_states = self.layer_norm(hidden_states)\n        attn_outputs, past_kv_cache = self.mixer(\n            hidden_states, kv_cache, attention_mask\n        )\n        feed_forward_hidden_states = self.mlp(hidden_states)\n        out = attn_outputs + feed_forward_hidden_states + residual\n        return out, past_kv_cache\n\n\n# PhiModel implements the embedding layer and the transformer blocks.\nclass PhiModel(nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n        self.tp_rank = weights.process_group.rank()\n        self.tp_world_size = weights.process_group.size()\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=f\"{prefix}.embd.wte\", weights=weights\n        )\n        self.blocks = nn.ModuleList(\n            [\n                PhiBlock(f\"{prefix}.h.{layer_id}\", config, weights)\n                for layer_id in range(config.n_layer)\n            ]\n        )\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor,\n        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,\n        attention_mask: Optional[torch.ByteTensor] = None,\n        return_dict: Optional[bool] = None,\n        use_cache: Optional[bool] = None,\n    ) -> Tuple[torch.Tensor, 
List[Tuple[torch.Tensor, torch.Tensor]]]:\n        hidden_states = self.embed_tokens(input_ids)\n        seq_len = hidden_states.shape[1]\n        mask = None if seq_len <= 1 else attention_mask\n\n        past_key_values = (\n            [None] * len(self.blocks) if past_key_values is None else past_key_values\n        )\n\n        for index, block in enumerate(self.blocks):\n            hidden_states, new_key_values = block(\n                hidden_states, past_key_values[index], mask\n            )\n            past_key_values[index] = new_key_values\n\n        return hidden_states, past_key_values\n\n\n# PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object.\nclass PhiForCausalLM(torch.nn.Module):\n    def __init__(self, prefix: str, config, weights):\n        super().__init__()\n\n        if not prefix:\n            prefix = \"transformer\"\n        else:\n            prefix = f\"{prefix}.transformer\"\n\n        self.model = PhiModel(prefix, config, weights)\n        self.lm_head = PhiCausalLMHead(config, weights)\n\n    def forward(\n        self,\n        input_ids: torch.LongTensor,\n        past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,\n        attention_mask: Optional[torch.ByteTensor] = None,\n        return_dict: Optional[bool] = None,\n        use_cache: Optional[bool] = None,\n        labels: Optional[torch.LongTensor] = None,\n    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:\n        model_output = self.model(\n            input_ids, past_key_values, attention_mask, return_dict, use_cache\n        )\n        logits = self.lm_head(model_output[0])\n\n        loss = None\n        if labels is not None:\n            loss = nn.CrossEntropyLoss()(\n                logits[:, :-1].view(-1, logits.size(-1)), labels[:, 1:].view(-1)\n            )\n\n        if not return_dict:\n            return (\n                ((loss,) + (logits,) + model_output[1:])\n         
       if loss is not None\n                else (logits,) + model_output[1:]\n            )\n\n        return CausalLMOutputWithPast(\n            loss=loss,\n            logits=logits,\n            past_key_values=model_output[1],\n            hidden_states=None,\n            attentions=None,\n        )\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/qwen2_5_vl.py",
    "content": "# coding=utf-8\n# Copyright 2025 the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch Qwen2.5 VL model.\"\"\"\n\nfrom typing import Optional, Tuple, List\n\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom text_generation_server.utils.import_utils import SYSTEM\n\nif SYSTEM == \"ipex\":\n    import intel_extension_for_pytorch as ipex\nelse:\n    import flash_attn_2_cuda\n\nimport numpy as np\n\nfrom transformers.activations import ACT2FN\nfrom transformers.configuration_utils import PretrainedConfig\n\nimport torch.nn.functional as F\n\nfrom text_generation_server.layers.layernorm import FastRMSNorm\nfrom text_generation_server.layers import (\n    TensorParallelColumnLinear,\n    TensorParallelRowLinear,\n    TensorParallelEmbedding,\n    SpeculativeHead,\n)\nfrom text_generation_server.layers.attention import (\n    Seqlen,\n)\nfrom text_generation_server.models.custom_modeling.flash_qwen2_modeling import (\n    Qwen2Model,\n)\n\n# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py\nfrom typing import Union\nfrom transformers.feature_extraction_utils import BatchFeature\nfrom transformers.image_utils import ImageInput, VideoInput\nfrom transformers.processing_utils import (\n    ProcessingKwargs,\n    ProcessorMixin,\n    Unpack,\n    VideosKwargs,\n)\nfrom 
transformers.tokenization_utils_base import PreTokenizedInput, TextInput\n\n\nclass Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False):\n    fps: Union[List[float], float]\n\n\nclass Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):\n    videos_kwargs: Qwen2_5_VLVideosProcessorKwargs\n    _defaults = {\n        \"text_kwargs\": {\n            \"padding\": False,\n        },\n        \"videos_kwargs\": {\"fps\": 2.0},\n    }\n\n\nclass Qwen2_5_VLProcessor(ProcessorMixin):\n    r\"\"\"\n    Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor.\n    [`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the\n    [`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information.\n    Args:\n        image_processor ([`Qwen2VLImageProcessor`], *optional*):\n            The image processor is a required input.\n        tokenizer ([`Qwen2TokenizerFast`], *optional*):\n            The tokenizer is a required input.\n        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages\n            in a chat into a tokenizable string.\n    \"\"\"\n\n    attributes = [\"image_processor\", \"tokenizer\"]\n    valid_kwargs = [\"chat_template\"]\n\n    image_processor_class = \"AutoImageProcessor\"\n    tokenizer_class = (\"Qwen2Tokenizer\", \"Qwen2TokenizerFast\")\n\n    def __init__(\n        self, image_processor=None, tokenizer=None, chat_template=None, **kwargs\n    ):\n        self.image_token = (\n            \"<|image_pad|>\"\n            if not hasattr(tokenizer, \"image_token\")\n            else tokenizer.image_token\n        )\n        self.video_token = (\n            \"<|video_pad|>\"\n            if not hasattr(tokenizer, \"video_token\")\n            else tokenizer.video_token\n        )\n        super().__init__(image_processor, tokenizer, 
chat_template=chat_template)\n\n    def __call__(\n        self,\n        images: ImageInput = None,\n        text: Union[\n            TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]\n        ] = None,\n        videos: VideoInput = None,\n        **kwargs: Unpack[Qwen2_5_VLProcessorKwargs],\n    ) -> BatchFeature:\n        \"\"\"\n        Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`\n        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode\n        the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwargs` arguments to\n        Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.\n\n        Args:\n            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):\n                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch\n                tensor. Both channels-first and channels-last formats are supported.\n            text (`str`, `List[str]`, `List[List[str]]`):\n                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings\n                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set\n                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).\n            videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):\n                The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch\n                tensor, or a nested list of 3D frames. 
Both channels-first and channels-last formats are supported.\n            return_tensors (`str` or [`~utils.TensorType`], *optional*):\n                If set, will return tensors of a particular framework. Acceptable values are:\n                - `'tf'`: Return TensorFlow `tf.constant` objects.\n                - `'pt'`: Return PyTorch `torch.Tensor` objects.\n                - `'np'`: Return NumPy `np.ndarray` objects.\n                - `'jax'`: Return JAX `jnp.ndarray` objects.\n\n        Returns:\n            [`BatchFeature`]: A [`BatchFeature`] with the following fields:\n\n            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.\n            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when\n              `return_attention_mask=True` or if *\"attention_mask\"* is in `self.model_input_names` and if `text` is not\n              `None`).\n            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.\n            - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.\n            - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.\n            - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.\n            - **second_per_grid_ts** -- List of video seconds per time grid. 
Returned when `videos` is not `None`.\n        \"\"\"\n        output_kwargs = self._merge_kwargs(\n            Qwen2_5_VLProcessorKwargs,\n            tokenizer_init_kwargs=self.tokenizer.init_kwargs,\n            **kwargs,\n        )\n        if images is not None:\n            image_inputs = self.image_processor(\n                images=images, videos=None, **output_kwargs[\"images_kwargs\"]\n            )\n            image_grid_thw = image_inputs[\"image_grid_thw\"]\n        else:\n            image_inputs = {}\n            image_grid_thw = None\n\n        if videos is not None:\n            videos_inputs = self.image_processor(\n                images=None, videos=videos, **output_kwargs[\"images_kwargs\"]\n            )\n            video_grid_thw = videos_inputs[\"video_grid_thw\"]\n\n            fps = output_kwargs[\"videos_kwargs\"].pop(\"fps\", 2.0)\n            if isinstance(fps, (int, float)):\n                second_per_grid_ts = [\n                    self.image_processor.temporal_patch_size / fps\n                ] * len(video_grid_thw)\n            elif hasattr(fps, \"__len__\") and len(fps) == len(video_grid_thw):\n                second_per_grid_ts = [\n                    self.image_processor.temporal_patch_size / tmp for tmp in fps\n                ]\n            else:\n                raise ValueError(\n                    f\"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number.\"\n                )\n            videos_inputs.update({\"second_per_grid_ts\": second_per_grid_ts})\n\n        else:\n            videos_inputs = {}\n            video_grid_thw = None\n\n        if not isinstance(text, list):\n            text = [text]\n\n        if image_grid_thw is not None:\n            merge_length = self.image_processor.merge_size**2\n            index = 0\n            for i in range(len(text)):\n                while 
self.image_token in text[i]:\n                    text[i] = text[i].replace(\n                        self.image_token,\n                        \"<|placeholder|>\"\n                        * (image_grid_thw[index].prod() // merge_length),\n                        1,\n                    )\n                    index += 1\n                text[i] = text[i].replace(\"<|placeholder|>\", self.image_token)\n\n        if video_grid_thw is not None:\n            merge_length = self.image_processor.merge_size**2\n            index = 0\n            for i in range(len(text)):\n                while self.video_token in text[i]:\n                    text[i] = text[i].replace(\n                        self.video_token,\n                        \"<|placeholder|>\"\n                        * (video_grid_thw[index].prod() // merge_length),\n                        1,\n                    )\n                    index += 1\n                text[i] = text[i].replace(\"<|placeholder|>\", self.video_token)\n\n        text_inputs = self.tokenizer(text, **output_kwargs[\"text_kwargs\"])\n\n        return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})\n\n    def batch_decode(self, *args, **kwargs):\n        \"\"\"\n        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please\n        refer to the docstring of this method for more information.\n        \"\"\"\n        return self.tokenizer.batch_decode(*args, **kwargs)\n\n    def decode(self, *args, **kwargs):\n        \"\"\"\n        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. 
Please refer to\n        the docstring of this method for more information.\n        \"\"\"\n        return self.tokenizer.decode(*args, **kwargs)\n\n    def post_process_image_text_to_text(self, generated_outputs):\n        \"\"\"\n        Post-process the output of the model to decode the text.\n\n        Args:\n            generated_outputs (`torch.Tensor` or `np.ndarray`):\n                The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`\n                or `(sequence_length,)`.\n\n        Returns:\n            `List[str]`: The decoded text.\n        \"\"\"\n        return self.tokenizer.batch_decode(\n            generated_outputs,\n            skip_special_tokens=True,\n            clean_up_tokenization_spaces=False,\n        )\n\n    @property\n    def model_input_names(self):\n        tokenizer_input_names = self.tokenizer.model_input_names\n        image_processor_input_names = self.image_processor.model_input_names\n        names_from_processor = list(\n            dict.fromkeys(tokenizer_input_names + image_processor_input_names)\n        )\n        return names_from_processor + [\"second_per_grid_ts\"]\n\n\n# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py\nclass Qwen2_5_VLVisionConfig(PretrainedConfig):\n    model_type = \"qwen2_5_vl\"\n    base_config_key = \"vision_config\"\n\n    def __init__(\n        self,\n        depth=32,\n        hidden_size=3584,\n        hidden_act=\"silu\",\n        intermediate_size=3420,\n        num_heads=16,\n        in_channels=3,\n        patch_size=14,\n        spatial_merge_size=2,\n        spatial_patch_size=14,\n        temporal_patch_size=2,\n        tokens_per_second=4,\n        window_size=112,\n        out_hidden_size=3584,\n        fullatt_block_indexes=[7, 15, 23, 31],\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n\n        self.depth = 
depth\n        self.hidden_size = hidden_size\n        self.hidden_act = hidden_act\n        self.intermediate_size = intermediate_size\n        self.num_heads = num_heads\n        self.in_channels = in_channels\n        self.patch_size = patch_size\n        self.spatial_patch_size = spatial_patch_size\n        self.spatial_merge_size = spatial_merge_size\n        self.temporal_patch_size = temporal_patch_size\n        self.tokens_per_second = tokens_per_second\n        self.window_size = window_size\n        self.fullatt_block_indexes = fullatt_block_indexes\n        self.out_hidden_size = out_hidden_size\n\n\nclass Qwen2_5_VLConfig(PretrainedConfig):\n    def __init__(\n        self,\n        vocab_size=152064,\n        hidden_size=8192,\n        intermediate_size=29568,\n        num_hidden_layers=80,\n        num_attention_heads=64,\n        num_key_value_heads=8,\n        hidden_act=\"silu\",\n        max_position_embeddings=32768,\n        initializer_range=0.02,\n        rms_norm_eps=1e-05,\n        use_cache=True,\n        tie_word_embeddings=False,\n        rope_theta=1000000.0,\n        use_sliding_window=False,\n        sliding_window=4096,\n        max_window_layers=80,\n        attention_dropout=0.0,\n        vision_config=None,\n        rope_scaling=None,\n        **kwargs,\n    ):\n        if vision_config is not None:\n            self.vision_config = Qwen2_5_VLVisionConfig(**vision_config)\n\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.use_sliding_window = use_sliding_window\n        self.sliding_window = sliding_window\n        self.max_window_layers = max_window_layers\n\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads 
= num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.attention_dropout = attention_dropout\n        self.rope_scaling = rope_scaling\n\n        # Validate the correctness of rotary position embeddings parameters\n        # BC: if there is a 'type' field, move it to 'rope_type'.\n        # and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations\n        # one can set it to \"linear\"/\"dynamic\" etc. to have scaled RoPE\n        # TODO: @raushan update config in the hub\n        if self.rope_scaling is not None and \"type\" in self.rope_scaling:\n            if self.rope_scaling[\"type\"] == \"mrope\":\n                self.rope_scaling[\"type\"] = \"default\"\n            self.rope_scaling[\"rope_type\"] = self.rope_scaling[\"type\"]\n\n        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)\n\n\n# Copied from transformers.models.llama.modeling_llama.rotate_half\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb_vision(\n    tensor: torch.Tensor, freqs: torch.Tensor\n) -> torch.Tensor:\n    orig_dtype = tensor.dtype\n    tensor = tensor.float()\n    cos = freqs.cos()\n    sin = freqs.sin()\n    cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()\n    sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()\n    output = (tensor * cos) + (rotate_half(tensor) * sin)\n    output = output.to(orig_dtype)\n    return output\n\n\nclass Qwen2_5VLAttention(nn.Module):\n    def __init__(self, *, prefix, config, weights):\n        super().__init__()\n        self.embed_dim = config.hidden_size // 
weights.process_group.size()\n        self.head_dim = config.hidden_size // config.num_heads\n        self.num_heads = config.num_heads // weights.process_group.size()\n\n        self.qkv = TensorParallelColumnLinear.load_qkv(\n            config,\n            prefix=f\"{prefix}.qkv\",\n            weights=weights,\n            bias=False,\n            num_heads=self.num_heads,\n            num_key_value_heads=self.num_heads,\n        )\n        self.qkv.linear.bias = weights.get_sharded(f\"{prefix}.qkv.bias\", dim=0)\n\n        self.proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.proj\",\n            weights=weights,\n            bias=True,\n        )\n        self.softmax_scale = 1.0 / np.sqrt(self.embed_dim // self.num_heads)\n\n    def forward(\n        self,\n        hidden_state: torch.Tensor,\n        cu_seqlens: torch.Tensor,\n        rotary_pos_emb: torch.Tensor,\n        max_seqlen: int,\n    ) -> torch.Tensor:\n        # apply the qkv linear layer to the hidden state\n        qkv = self.qkv(hidden_state)\n        query, key, value = qkv.split(\n            [self.embed_dim, self.embed_dim, self.embed_dim], dim=1\n        )\n\n        # reshape the query, key, and value tensors\n        _shape = (\n            hidden_state.shape[0],\n            self.num_heads,\n            self.embed_dim // self.num_heads,\n        )\n        query = query.view(*_shape)\n        key = key.view(*_shape)\n        value = value.view(*_shape)\n\n        # apply rotary positional embeddings\n        query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze(\n            0\n        )\n        key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)\n\n        # calc maximum sequence length for any batch\n        query = query.contiguous()\n        key = key.contiguous()\n        value = value.contiguous()\n        causal = False\n\n        # execute flash attention\n        if SYSTEM == 
\"ipex\":\n            attn_output = torch.empty_like(query)\n            if query.device.type == \"xpu\":\n                ipex.llm.functional.varlen_attention(\n                    query.contiguous(),\n                    key.contiguous(),\n                    value.contiguous(),\n                    attn_output,\n                    cu_seqlens,\n                    cu_seqlens,\n                    None,\n                    max_seqlen,\n                    max_seqlen,\n                    0.0,\n                    self.softmax_scale,\n                    False,\n                    causal,\n                    False,\n                    None,\n                )\n            else:\n                ipex.llm.functional.varlen_attention(\n                    query,\n                    key,\n                    value,\n                    attn_output,\n                    cu_seqlens,\n                    cu_seqlens,\n                    max_seqlen,\n                    max_seqlen,\n                    0.0,\n                    self.softmax_scale,\n                    False,\n                    causal,\n                    False,\n                    None,\n                )\n        else:\n            attn_output = flash_attn_2_cuda.varlen_fwd(\n                query,\n                key,\n                value,\n                None,  # tmp buffer (auto-allocated)\n                cu_seqlens,  # cu_seqlens_q\n                cu_seqlens,  # cu_seqlens_k\n                None,  # max_seqlen_q (auto-computed)\n                None,  # max_seqlen_k (auto-computed)\n                None,  # block_tables\n                None,  # broadcast_mask\n                max_seqlen,  # max_seqlen\n                max_seqlen,  # max_seqlen\n                0.0,  # dropout_p\n                self.softmax_scale,\n                False,  # zero_tensors\n                causal,  # causal attention within each sequence\n                -1,  # window_size_left\n                -1,  # 
window_size_right\n                0.0,  # softmax_cap\n                False,  # deterministic\n                None,  # rng_state\n            )[0]\n\n        # reshape output to original dimensions\n        attn_output = attn_output.reshape(hidden_state.shape[0], -1)\n        attn_output = self.proj(attn_output)\n        return attn_output\n\n\nclass Qwen2_5VLVisionMLP(nn.Module):\n    def __init__(self, *, prefix, config, weights):\n        super().__init__()\n        self.activation_fn = ACT2FN[config.hidden_act]\n\n        self.intermediate_size = (\n            config.intermediate_size // weights.process_group.size()\n        )\n\n        self.up = TensorParallelColumnLinear.load(\n            prefix=f\"{prefix}.up_proj\", weights=weights, config=config, bias=True\n        )\n        self.gate = TensorParallelColumnLinear.load(\n            prefix=f\"{prefix}.gate_proj\", weights=weights, config=config, bias=True\n        )\n        self.down = TensorParallelRowLinear.load(\n            prefix=f\"{prefix}.down_proj\", weights=weights, config=config, bias=True\n        )\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        gate_states = self.gate(hidden_states)\n        up_states = self.up(hidden_states)\n        activated_states = self.activation_fn(gate_states) * up_states\n        down_states = self.down(activated_states)\n        return down_states\n\n\nclass Qwen2_5VLVisionBlock(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.attn = Qwen2_5VLAttention(\n            prefix=f\"{prefix}.attn\",\n            config=config,\n            weights=weights,\n        )\n        self.norm1 = FastRMSNorm.load(\n            prefix=f\"{prefix}.norm1\",\n            weights=weights,\n            eps=1e-6,\n        )\n        self.norm2 = FastRMSNorm.load(\n            prefix=f\"{prefix}.norm2\",\n            weights=weights,\n            eps=1e-6,\n        )\n        self.mlp = 
Qwen2_5VLVisionMLP(\n            prefix=f\"{prefix}.mlp\",\n            config=config,\n            weights=weights,\n        )\n\n    def forward(\n        self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen\n    ) -> torch.Tensor:\n        norm1_out, _ = self.norm1(hidden_states)\n        attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen)\n        hidden_states = hidden_states + attn_out\n        norm2_out, _ = self.norm2(hidden_states)\n        mlp_out = self.mlp(norm2_out)\n        hidden_states = hidden_states + mlp_out\n        return hidden_states\n\n\nclass Qwen2_5VLPatchMerger(nn.Module):\n    def __init__(self, *, prefix, config, weights):\n        super().__init__()\n        self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)\n        self.patch_merger_ln_q = FastRMSNorm.load(\n            prefix=f\"{prefix}.ln_q\",\n            weights=weights,\n            eps=1e-6,\n        )\n        self.fc1 = TensorParallelColumnLinear.load(\n            prefix=f\"{prefix}.mlp.0\", weights=weights, config=config, bias=True\n        )\n        self.fc2 = TensorParallelRowLinear.load(\n            prefix=f\"{prefix}.mlp.2\", weights=weights, config=config, bias=True\n        )\n\n    def forward(self, hidden_states) -> torch.Tensor:\n        hidden_states, _ = self.patch_merger_ln_q(hidden_states)\n        hidden_states = hidden_states.view(-1, self.hidden_size)\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = F.gelu(hidden_states)\n        hidden_states = self.fc2(hidden_states)\n        return hidden_states\n\n\nclass Qwen2_5VisionModel(nn.Module):\n    def __init__(self, *, prefix, config, weights):\n        super().__init__()\n\n        self.spatial_merge_size = config.spatial_merge_size\n        kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size]\n        self.patch_embedding = nn.Conv3d(\n            in_channels=config.in_channels,\n            
out_channels=config.hidden_size,\n            kernel_size=kernel_size,\n            stride=kernel_size,\n            bias=False,\n        )\n        self.patch_embedding.weight = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.patch_embed.proj.weight\"), requires_grad=False\n        )\n        head_dim = config.hidden_size // config.num_heads\n\n        theta = 10000.0\n        dim = head_dim // 2\n        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        self.blocks = nn.ModuleList(\n            [\n                Qwen2_5VLVisionBlock(\n                    prefix=f\"{prefix}.blocks.{i}\",\n                    config=config,\n                    weights=weights,\n                )\n                for i in range(config.depth)\n            ]\n        )\n        self.merger = Qwen2_5VLPatchMerger(\n            prefix=f\"{prefix}.merger\",\n            config=config,\n            weights=weights,\n        )\n\n        self.temporal_patch_size = config.temporal_patch_size\n        self.spatial_patch_size = config.spatial_patch_size\n        self.in_channels = config.in_channels\n        self.embed_dim = config.hidden_size\n        self.window_size = config.window_size\n        self.patch_size = config.patch_size\n        self.spatial_merge_unit = config.spatial_merge_size * config.spatial_merge_size\n        self.fullatt_block_indexes = config.fullatt_block_indexes\n\n    def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:\n        batch_size, _, hidden_size = hidden_state.shape\n        class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)\n        hidden_state = torch.cat([class_embedding, hidden_state], dim=1)\n        return hidden_state\n\n    def get_window_index(self, grid_thw):\n        window_index: list = []\n        cu_window_seqlens: list = [0]\n        window_index_id = 0\n        
vit_merger_window_size = (\n            self.window_size // self.spatial_merge_size // self.patch_size\n        )\n\n        for grid_t, grid_h, grid_w in grid_thw:\n            llm_grid_h, llm_grid_w = (\n                grid_h // self.spatial_merge_size,\n                grid_w // self.spatial_merge_size,\n            )\n            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(\n                grid_t, llm_grid_h, llm_grid_w\n            )\n            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size\n            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size\n            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size\n            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size\n            index_padded = F.pad(index, (0, pad_w, 0, pad_h), \"constant\", -100)\n            index_padded = index_padded.reshape(\n                grid_t,\n                num_windows_h,\n                vit_merger_window_size,\n                num_windows_w,\n                vit_merger_window_size,\n            )\n            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(\n                grid_t,\n                num_windows_h * num_windows_w,\n                vit_merger_window_size,\n                vit_merger_window_size,\n            )\n            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)\n            index_padded = index_padded.reshape(-1)\n            index_new = index_padded[index_padded != -100]\n            window_index.append(index_new + window_index_id)\n            cu_seqlens_tmp = (\n                seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]\n            )\n            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())\n            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()\n        window_index = torch.cat(window_index, dim=0)\n\n        return window_index, cu_window_seqlens\n\n    def forward(\n        self,\n   
     pixel_values: torch.Tensor,\n        grid_thw: Optional[torch.LongTensor] = None,\n    ) -> torch.Tensor:\n        # reshape the input tensor for processing\n        shape = (\n            -1,\n            self.in_channels,\n            self.temporal_patch_size,\n            self.spatial_patch_size,\n            self.spatial_patch_size,\n        )\n        pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype)\n        hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim)\n        # TODO: revisit to see if we can avoid some of these reshapes\n\n        # find the position ids for the input tensor based on the grid_thw\n        pos_ids = []\n        for t, h, w in grid_thw:\n            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)\n            hpos_ids = hpos_ids.reshape(\n                h // self.spatial_merge_size,\n                self.spatial_merge_size,\n                w // self.spatial_merge_size,\n                self.spatial_merge_size,\n            )\n            hpos_ids = hpos_ids.permute(0, 2, 1, 3)\n            hpos_ids = hpos_ids.flatten()\n\n            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)\n            wpos_ids = wpos_ids.reshape(\n                h // self.spatial_merge_size,\n                self.spatial_merge_size,\n                w // self.spatial_merge_size,\n                self.spatial_merge_size,\n            )\n            wpos_ids = wpos_ids.permute(0, 2, 1, 3)\n            wpos_ids = wpos_ids.flatten()\n            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))\n\n        pos_ids = torch.cat(pos_ids, dim=0)\n\n        max_grid_size = grid_thw[:, 1:].max()\n\n        # apply the positional embeddings to the position ids\n        seq = torch.arange(\n            max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype\n        )\n        rotary_pos_emb_full = torch.outer(seq, self.inv_freq)\n        rotary_pos_emb = 
rotary_pos_emb_full[pos_ids].flatten(1)\n        window_index, cu_window_seqlens = self.get_window_index(grid_thw)\n        seq_len = hidden_states.shape[0]\n        patch_shape = (seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)\n        og_shape = (seq_len, -1)\n\n        hidden_states = hidden_states.view(patch_shape)[window_index, :, :].view(\n            og_shape\n        )\n        rotary_pos_emb = rotary_pos_emb.view(patch_shape)[window_index, :, :].view(\n            og_shape\n        )\n\n        rotary_pos_emb = rotary_pos_emb.to(device=hidden_states.device)\n\n        cu_window_seqlens = torch.tensor(\n            cu_window_seqlens,\n            device=hidden_states.device,\n            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,\n        )\n        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)\n\n        # create a cu_seqlens tensor to be used in the attention mask\n        cu_seqlens = torch.repeat_interleave(\n            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]\n        ).cumsum(dim=0, dtype=torch.int32)\n        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)\n        max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])\n\n        # iterately apply the blocks to the hidden states\n        for layer_num, block in enumerate(self.blocks):\n            # NOTE: qwen2_5_vl.py has a concept of full attention blocks\n            # that are applied at specific layers.\n            if layer_num in self.fullatt_block_indexes:\n                cu_seqlens_now = cu_seqlens\n            else:\n                cu_seqlens_now = cu_window_seqlens\n\n            hidden_states = block(\n                hidden_states, cu_seqlens_now, rotary_pos_emb, max_seqlen\n            )\n\n        # apply the final patch merger to the hidden states\n        hidden_states = self.merger(hidden_states)\n        reverse_indices = torch.argsort(window_index)\n        hidden_states = hidden_states[reverse_indices, :]\n     
   return hidden_states\n\n\nclass Qwen2_5VLForConditionalGeneration(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        config.vision_config.quantize = None\n        config.vision_config.speculator = config.speculator\n        # set rope_scaling.type == \"mrope\" since AutoConfig.from_pretrained incorrectly\n        # returns rope_scaling.type == \"default\" for Qwen2_5-VL model at the moment\n        if (\n            hasattr(config, \"rope_scaling\")\n            and config.rope_scaling is not None\n            and config.rope_scaling.get(\"type\", None) == \"default\"\n        ):\n            config.rope_scaling.update({\"rope_type\": \"mrope\"})\n        self.hidden_size = config.hidden_size\n        self.vision_start_token_id = config.vision_start_token_id\n        self.vision_end_token_id = config.vision_end_token_id\n        self.image_token_id = config.image_token_id\n        self.video_token_id = config.video_token_id\n        self.spatial_merge_size = config.vision_config.spatial_merge_size\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=\"model.embed_tokens\", weights=weights\n        )\n        self.visual = Qwen2_5VisionModel(\n            prefix=\"visual\", config=config.vision_config, weights=weights\n        )\n        self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)\n        if config.tie_word_embeddings:\n            suffix = \"model.embed_tokens\"\n        else:\n            suffix = \"lm_head\"\n\n        self.lm_head = SpeculativeHead.load(\n            config,\n            prefix=suffix if not prefix else f\"{prefix}.{suffix}\",\n            weights=weights,\n        )\n        self.device = weights.device\n\n    # based on https://github.com/huggingface/transformers/blob/e284c7e954abe12c34b50461c17f8115a0afe115/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1391\n    # modified to first find segments then 
initialize position ids for each segment\n    # Steps:\n    #  locate all vision and text segments\n    #  calculate `vision_segment_lengths` for each vision segment to be use as offset\n    #  calculate `text_segment_lengths` for each text segment to be used as offset\n    #  create position ids for each vision segment based on the image grid\n    #  create position ids for each text segment\n    #  combine all the position ids\n    #  the final segment is the difference between the last vision segment and the end of the input\n    #  combine all the position ids and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3)\n    def get_position_ids(\n        self,\n        input_ids: torch.Tensor,\n        image_grid_thw: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        if image_grid_thw is None:\n            return (\n                torch.arange(input_ids.shape[0], device=input_ids.device)\n                .unsqueeze(1)\n                .repeat(1, 3)\n            )\n\n        spatial_merge_size = self.spatial_merge_size\n        vision_start_token_id = self.vision_start_token_id\n        vision_end_token_id = self.vision_end_token_id\n        device = input_ids.device\n        dtype = input_ids.dtype\n        input_ids_len = input_ids.shape[0]\n\n        vision_starts = torch.where(input_ids == vision_start_token_id)[0]\n        vision_ends = torch.where(input_ids == vision_end_token_id)[0]\n        vision_segments = torch.stack((vision_starts, vision_ends), dim=1)\n        prev_vision_end = torch.cat(\n            [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]]\n        )\n        text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1\n        vision_widths_max = torch.cat(\n            [\n                torch.zeros(1, device=image_grid_thw.device, dtype=dtype),\n                image_grid_thw[:-1, 2] // spatial_merge_size,\n            ]\n        )\n        vision_segment_lengths = 
vision_widths_max + text_lengths_between_vision\n        vision_segment_lengths = vision_segment_lengths.cumsum(dim=0)\n        text_segment_lengths = vision_segment_lengths - text_lengths_between_vision\n\n        # create position ids for each vision segment based on the image grid\n        llm_pos_ids_list = []\n        for i, _ in enumerate(vision_segments):\n            t, h, w = (\n                image_grid_thw[i][0],\n                image_grid_thw[i][1] // spatial_merge_size,\n                image_grid_thw[i][2] // spatial_merge_size,\n            )\n            t_indices = torch.arange(t, device=device).repeat_interleave(h * w)\n            h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t)\n            w_indices = torch.arange(w, device=device).repeat(t * h)\n            image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)\n\n            # offset by the position of the last vision segment\n            im = image_position_ids + vision_segment_lengths[i]\n            llm_pos_ids_list.append(im)\n\n        # create position ids for each text segment\n        text_ranges = [\n            torch.arange(seq_len, device=device).view(1, -1).expand(3, -1)\n            + text_segment_lengths[i]\n            for i, seq_len in enumerate(text_lengths_between_vision)\n        ]\n\n        full_llm_pos_ids_list = [\n            item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist\n        ]\n        # import ipdb\n\n        # ipdb.set_trace()\n        max_s = full_llm_pos_ids_list[-1].max() + 1\n        final_text_len = input_ids_len - vision_ends[-1]\n        if final_text_len > 0:\n            m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)\n            full_llm_pos_ids_list.append(m + max_s)\n\n        position_ids = (\n            torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)\n        )\n        return position_ids\n\n    def get_vision_embeds(\n    
    self,\n        pixel_values: torch.FloatTensor,\n        pixel_attention_mask: Optional[torch.FloatTensor] = None,\n        image_sizes: Optional[torch.Tensor] = None,\n        image_grid_thw: Optional[torch.LongTensor] = None,\n    ):\n        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)\n        return image_embeds\n\n    def get_inputs_embeds(\n        self,\n        input_ids: torch.Tensor,\n        vision_embeds: torch.Tensor = None,\n    ):\n        inputs_embeds = self.embed_tokens(input_ids)\n\n        # apply the visual model to the pixel values if they are provided\n        if vision_embeds is not None:\n            inputs_embeds[input_ids == self.image_token_id] = vision_embeds\n\n        return inputs_embeds\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor],\n        lm_head_indices: Optional[torch.Tensor],\n        # Unused in this model\n        attention_mask: Optional[torch.Tensor] = None,\n        adapter_data: Optional[torch.Tensor] = None,\n        image_indices=None,\n    ):\n        hidden_states = self.text_model(\n            inputs_embeds=inputs_embeds,\n            position_ids=position_ids,\n            cu_seqlen_prefill=cu_seqlen_prefill,\n            kv_cache=kv_cache,\n            block_tables=block_tables,\n            slots=slots,\n            seqlen=seqlen,\n            max_s=max_s,\n            true_max_s=max_s,\n            prefill_cache_indices=prefill_cache_indices,\n            adapter_data=adapter_data,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = 
self.lm_head(hidden_states)\n        return logits, speculative_logits\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/qwen2_vl.py",
    "content": "# coding=utf-8\n# Copyright 2024 the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch Qwen2 VL model.\"\"\"\n\nfrom typing import Optional, Tuple, List\n\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom text_generation_server.utils.import_utils import SYSTEM\n\nif SYSTEM == \"ipex\":\n    import intel_extension_for_pytorch as ipex\nelse:\n    import flash_attn_2_cuda\n\nimport numpy as np\n\nfrom transformers.activations import ACT2FN\nimport torch.nn.functional as F\n\nfrom text_generation_server.layers.layernorm import FastLayerNorm, FastRMSNorm\nfrom text_generation_server.layers import (\n    TensorParallelColumnLinear,\n    TensorParallelRowLinear,\n    TensorParallelEmbedding,\n    SpeculativeHead,\n)\nfrom text_generation_server.layers.attention import (\n    Seqlen,\n)\nfrom text_generation_server.models.custom_modeling.flash_qwen2_modeling import (\n    Qwen2Model,\n)\n\n\n# Copied from transformers.models.llama.modeling_llama.rotate_half\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb_vision(\n    tensor: torch.Tensor, freqs: torch.Tensor\n) -> torch.Tensor:\n    orig_dtype = tensor.dtype\n    tensor = tensor.float()\n    cos = freqs.cos()\n    sin = freqs.sin()\n    cos = 
cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()\n    sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()\n    output = (tensor * cos) + (rotate_half(tensor) * sin)\n    output = output.to(orig_dtype)\n    return output\n\n\nclass Qwen2VLAttention(nn.Module):\n    def __init__(self, *, prefix, config, weights):\n        super().__init__()\n        self.embed_dim = config.embed_dim // weights.process_group.size()\n        self.head_dim = config.hidden_size // config.num_heads\n        self.num_heads = config.num_heads // weights.process_group.size()\n\n        self.qkv = TensorParallelColumnLinear.load_qkv(\n            config,\n            prefix=f\"{prefix}.qkv\",\n            weights=weights,\n            bias=False,\n            num_heads=self.num_heads,\n            num_key_value_heads=self.num_heads,\n        )\n        self.qkv.linear.bias = weights.get_sharded(f\"{prefix}.qkv.bias\", dim=0)\n        self.proj = TensorParallelRowLinear.load(\n            config,\n            prefix=f\"{prefix}.proj\",\n            weights=weights,\n            bias=True,\n        )\n        self.softmax_scale = 1.0 / np.sqrt(self.embed_dim // self.num_heads)\n\n    def forward(\n        self,\n        hidden_state: torch.Tensor,\n        cu_seqlens: torch.Tensor,\n        rotary_pos_emb: torch.Tensor,\n        max_seqlen: int,\n    ) -> torch.Tensor:\n        # apply the qkv linear layer to the hidden state\n        qkv = self.qkv(hidden_state)\n        query, key, value = qkv.split(\n            [self.embed_dim, self.embed_dim, self.embed_dim], dim=1\n        )\n\n        # reshape the query, key, and value tensors\n        _shape = (\n            hidden_state.shape[0],\n            self.num_heads,\n            self.embed_dim // self.num_heads,\n        )\n        query = query.view(*_shape)\n        key = key.view(*_shape)\n        value = value.view(*_shape)\n\n        # apply rotary positional embeddings\n        query = 
apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze(\n            0\n        )\n        key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)\n\n        # calc maximum sequence length for any batch\n        query = query.contiguous()\n        key = key.contiguous()\n        value = value.contiguous()\n        causal = False\n\n        # execute flash attention\n        if SYSTEM == \"ipex\":\n            attn_output = torch.empty_like(query)\n            if query.device.type == \"xpu\":\n                ipex.llm.functional.varlen_attention(\n                    query.contiguous(),\n                    key.contiguous(),\n                    value.contiguous(),\n                    attn_output,\n                    cu_seqlens,\n                    cu_seqlens,\n                    None,\n                    max_seqlen,\n                    max_seqlen,\n                    0.0,\n                    self.softmax_scale,\n                    False,\n                    causal,\n                    False,\n                    None,\n                )\n            else:\n                ipex.llm.functional.varlen_attention(\n                    query,\n                    key,\n                    value,\n                    attn_output,\n                    cu_seqlens,\n                    cu_seqlens,\n                    max_seqlen,\n                    max_seqlen,\n                    0.0,\n                    self.softmax_scale,\n                    False,\n                    causal,\n                    False,\n                    None,\n                )\n        else:\n            attn_output = flash_attn_2_cuda.varlen_fwd(\n                query,\n                key,\n                value,\n                None,  # tmp buffer (auto-allocated)\n                cu_seqlens,  # cu_seqlens_q\n                cu_seqlens,  # cu_seqlens_k\n                None,  # max_seqlen_q (auto-computed)\n                None,  # 
max_seqlen_k (auto-computed)\n                None,  # block_tables\n                None,  # broadcast_mask\n                max_seqlen,  # max_seqlen\n                max_seqlen,  # max_seqlen\n                0.0,  # dropout_p\n                self.softmax_scale,\n                False,  # zero_tensors\n                causal,  # causal attention within each sequence\n                -1,  # window_size_left\n                -1,  # window_size_right\n                0.0,  # softmax_cap\n                False,  # deterministic\n                None,  # rng_state\n            )[0]\n\n        # reshape output to original dimensions\n        attn_output = attn_output.reshape(hidden_state.shape[0], -1)\n        attn_output = self.proj(attn_output)\n        return attn_output\n\n\nclass Qwen2VLVisionMLP(nn.Module):\n    def __init__(self, *, prefix, config, weights):\n        super().__init__()\n        self.activation_fn = ACT2FN[config.hidden_act]\n        self.fc1 = TensorParallelColumnLinear.load(\n            prefix=f\"{prefix}.fc1\", weights=weights, config=config, bias=True\n        )\n        self.fc2 = TensorParallelRowLinear.load(\n            prefix=f\"{prefix}.fc2\", weights=weights, config=config, bias=True\n        )\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = self.activation_fn(hidden_states)\n        hidden_states = self.fc2(hidden_states)\n        return hidden_states\n\n\nclass Qwen2VLVisionBlock(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.attn = Qwen2VLAttention(\n            prefix=f\"{prefix}.attn\",\n            config=config,\n            weights=weights,\n        )\n        self.norm1 = FastLayerNorm.load(\n            prefix=f\"{prefix}.norm1\",\n            weights=weights,\n            eps=1e-6,\n        )\n        self.norm2 = FastLayerNorm.load(\n            
prefix=f\"{prefix}.norm2\",\n            weights=weights,\n            eps=1e-6,\n        )\n        self.mlp = Qwen2VLVisionMLP(\n            prefix=f\"{prefix}.mlp\",\n            config=config,\n            weights=weights,\n        )\n\n    def forward(\n        self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen\n    ) -> torch.Tensor:\n        norm1_out, residual = self.norm1(hidden_states)\n        attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen)\n        hidden_states = attn_out + residual\n        norm2_out, residual = self.norm2(hidden_states)\n        hidden_states = hidden_states + self.mlp(norm2_out)\n        return hidden_states\n\n\nclass Qwen2VLPatchMerger(nn.Module):\n    def __init__(self, *, prefix, config, weights):\n        super().__init__()\n        self.hidden_size = config.embed_dim * (config.spatial_merge_size**2)\n        self.patch_merger_ln_q = FastLayerNorm.load(\n            prefix=f\"{prefix}.ln_q\",\n            weights=weights,\n            eps=1e-6,\n        )\n        self.fc1 = TensorParallelColumnLinear.load(\n            prefix=f\"{prefix}.mlp.0\", weights=weights, config=config, bias=True\n        )\n        self.fc2 = TensorParallelRowLinear.load(\n            prefix=f\"{prefix}.mlp.2\", weights=weights, config=config, bias=True\n        )\n\n    def forward(self, hidden_states) -> torch.Tensor:\n        hidden_states, _ = self.patch_merger_ln_q(hidden_states)\n        hidden_states = hidden_states.view(-1, self.hidden_size)\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = F.gelu(hidden_states)\n        hidden_states = self.fc2(hidden_states)\n        return hidden_states\n\n\nclass Qwen2VisionModel(nn.Module):\n    def __init__(self, *, prefix, config, weights):\n        super().__init__()\n        self.spatial_merge_size = config.spatial_merge_size\n        kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size]\n        self.patch_embedding = 
nn.Conv3d(\n            in_channels=config.in_chans,\n            out_channels=config.embed_dim,\n            kernel_size=kernel_size,\n            stride=kernel_size,\n            bias=False,\n        )\n        self.patch_embedding.weight = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.patch_embed.proj.weight\"), requires_grad=False\n        )\n        head_dim = config.embed_dim // config.num_heads\n        # TODO: replace with static positional embeddings once implemented\n        theta = 10000.0\n        dim = head_dim // 2\n        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        self.blocks = nn.ModuleList(\n            [\n                Qwen2VLVisionBlock(\n                    prefix=f\"{prefix}.blocks.{i}\",\n                    config=config,\n                    weights=weights,\n                )\n                for i in range(config.depth)\n            ]\n        )\n        self.merger = Qwen2VLPatchMerger(\n            prefix=f\"{prefix}.merger\",\n            config=config,\n            weights=weights,\n        )\n\n        self.temporal_patch_size = config.temporal_patch_size\n        self.spatial_patch_size = config.spatial_patch_size\n        self.in_channels = config.in_channels\n        self.embed_dim = config.embed_dim\n\n    def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:\n        batch_size, _, hidden_size = hidden_state.shape\n        class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)\n        hidden_state = torch.cat([class_embedding, hidden_state], dim=1)\n        return hidden_state\n\n    def forward(\n        self,\n        pixel_values: torch.Tensor,\n        grid_thw: Optional[torch.LongTensor] = None,\n    ) -> torch.Tensor:\n        # reshape the input tensor for processing\n        shape = (\n            -1,\n            self.in_channels,\n            
self.temporal_patch_size,\n            self.spatial_patch_size,\n            self.spatial_patch_size,\n        )\n        pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype)\n        hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim)\n        # TODO: revisit to see if we can avoid some of these reshapes\n\n        # find the position ids for the input tensor based on the grid_thw\n        pos_ids = []\n        for t, h, w in grid_thw:\n            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)\n            hpos_ids = hpos_ids.reshape(\n                h // self.spatial_merge_size,\n                self.spatial_merge_size,\n                w // self.spatial_merge_size,\n                self.spatial_merge_size,\n            )\n            hpos_ids = hpos_ids.permute(0, 2, 1, 3)\n            hpos_ids = hpos_ids.flatten()\n\n            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)\n            wpos_ids = wpos_ids.reshape(\n                h // self.spatial_merge_size,\n                self.spatial_merge_size,\n                w // self.spatial_merge_size,\n                self.spatial_merge_size,\n            )\n            wpos_ids = wpos_ids.permute(0, 2, 1, 3)\n            wpos_ids = wpos_ids.flatten()\n            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))\n\n        pos_ids = torch.cat(pos_ids, dim=0)\n        max_grid_size = grid_thw[:, 1:].max()\n\n        # apply the positional embeddings to the position ids\n        seq = torch.arange(\n            max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype\n        )\n        rotary_pos_emb_full = torch.outer(seq, self.inv_freq)\n        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)\n        rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, hidden_states.dtype)\n\n        # create a cu_seqlens tensor to be used in the attention mask\n        cu_seqlens = torch.repeat_interleave(\n         
   grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]\n        ).cumsum(dim=0, dtype=torch.int32)\n        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)\n        max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])\n        # iterately apply the blocks to the hidden states\n        for block in self.blocks:\n            hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen)\n\n        # apply the final patch merger to the hidden states\n        hidden_states = self.merger(hidden_states)\n        return hidden_states\n\n\nclass Qwen2VLForConditionalGeneration(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        config.vision_config.quantize = None\n        config.vision_config.speculator = config.speculator\n        # set rope_scaling.type == \"mrope\" since AutoConfig.from_pretrained incorrectly\n        # returns rope_scaling.type == \"default\" for Qwen2-VL model at the moment\n        if (\n            hasattr(config, \"rope_scaling\")\n            and config.rope_scaling is not None\n            and config.rope_scaling.get(\"type\", None) == \"default\"\n        ):\n            config.rope_scaling.update({\"rope_type\": \"mrope\"})\n        self.hidden_size = config.hidden_size\n        self.vision_start_token_id = config.vision_start_token_id\n        self.vision_end_token_id = config.vision_end_token_id\n        self.image_token_id = config.image_token_id\n        self.video_token_id = config.video_token_id\n        self.spatial_merge_size = config.vision_config.spatial_merge_size\n        self.embed_tokens = TensorParallelEmbedding(\n            prefix=\"model.embed_tokens\", weights=weights\n        )\n        self.visual = Qwen2VisionModel(\n            prefix=\"visual\", config=config.vision_config, weights=weights\n        )\n        self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)\n        if config.tie_word_embeddings:\n         
   suffix = \"model.embed_tokens\"\n        else:\n            suffix = \"lm_head\"\n\n        self.lm_head = SpeculativeHead.load(\n            config,\n            prefix=suffix if not prefix else f\"{prefix}.{suffix}\",\n            weights=weights,\n        )\n        self.norm = FastRMSNorm.load(\n            prefix=\"model.norm\",\n            weights=weights,\n            eps=config.rms_norm_eps,\n        )\n        self.device = weights.device\n\n    # based on https://github.com/huggingface/transformers/blob/e284c7e954abe12c34b50461c17f8115a0afe115/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1391\n    # modified to first find segments then initialize position ids for each segment\n    # Steps:\n    #  locate all vision and text segments\n    #  calculate `vision_segment_lengths` for each vision segment to be use as offset\n    #  calculate `text_segment_lengths` for each text segment to be used as offset\n    #  create position ids for each vision segment based on the image grid\n    #  create position ids for each text segment\n    #  combine all the position ids\n    #  the final segment is the difference between the last vision segment and the end of the input\n    #  combine all the position ids and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3)\n    def get_position_ids(\n        self,\n        input_ids: torch.Tensor,\n        image_grid_thw: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        if image_grid_thw is None:\n            return (\n                torch.arange(input_ids.shape[0], device=input_ids.device)\n                .unsqueeze(1)\n                .repeat(1, 3)\n            )\n\n        spatial_merge_size = self.spatial_merge_size\n        vision_start_token_id = self.vision_start_token_id\n        vision_end_token_id = self.vision_end_token_id\n        device = input_ids.device\n        dtype = input_ids.dtype\n        input_ids_len = input_ids.shape[0]\n\n        vision_starts = 
torch.where(input_ids == vision_start_token_id)[0]\n        vision_ends = torch.where(input_ids == vision_end_token_id)[0]\n        vision_segments = torch.stack((vision_starts, vision_ends), dim=1)\n        prev_vision_end = torch.cat(\n            [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]]\n        )\n        text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1\n        vision_widths_max = torch.cat(\n            [\n                torch.zeros(1, device=image_grid_thw.device, dtype=dtype),\n                image_grid_thw[:-1, 2] // spatial_merge_size,\n            ]\n        )\n        vision_segment_lengths = vision_widths_max + text_lengths_between_vision\n        vision_segment_lengths = vision_segment_lengths.cumsum(dim=0)\n        text_segment_lengths = vision_segment_lengths - text_lengths_between_vision\n\n        # create position ids for each vision segment based on the image grid\n        llm_pos_ids_list = []\n        for i, _ in enumerate(vision_segments):\n            t, h, w = (\n                image_grid_thw[i][0],\n                image_grid_thw[i][1] // spatial_merge_size,\n                image_grid_thw[i][2] // spatial_merge_size,\n            )\n            t_indices = torch.arange(t, device=device).repeat_interleave(h * w)\n            h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t)\n            w_indices = torch.arange(w, device=device).repeat(t * h)\n            image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)\n\n            # offset by the position of the last vision segment\n            im = image_position_ids + vision_segment_lengths[i]\n            llm_pos_ids_list.append(im)\n\n        # create position ids for each text segment\n        text_ranges = [\n            torch.arange(seq_len, device=device).view(1, -1).expand(3, -1)\n            + text_segment_lengths[i]\n            for i, seq_len in 
enumerate(text_lengths_between_vision)\n        ]\n\n        full_llm_pos_ids_list = [\n            item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist\n        ]\n        max_s = full_llm_pos_ids_list[-1].max() + 1\n        final_text_len = input_ids_len - vision_ends[-1]\n        if final_text_len > 0:\n            m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)\n            full_llm_pos_ids_list.append(m + max_s)\n\n        position_ids = (\n            torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)\n        )\n        return position_ids\n\n    def get_vision_embeds(\n        self,\n        pixel_values: torch.FloatTensor,\n        pixel_attention_mask: Optional[torch.FloatTensor] = None,\n        image_sizes: Optional[torch.Tensor] = None,\n        image_grid_thw: Optional[torch.LongTensor] = None,\n    ):\n        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)\n        return image_embeds\n\n    def get_inputs_embeds(\n        self,\n        input_ids: torch.Tensor,\n        vision_embeds: torch.Tensor = None,\n    ):\n        inputs_embeds = self.embed_tokens(input_ids)\n\n        # apply the visual model to the pixel values if they are provided\n        if vision_embeds is not None:\n            inputs_embeds[input_ids == self.image_token_id] = vision_embeds\n\n        return inputs_embeds\n\n    def forward(\n        self,\n        inputs_embeds: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        prefill_cache_indices: Optional[torch.Tensor],\n        lm_head_indices: Optional[torch.Tensor],\n        adapter_data: Optional[torch.Tensor] = None,\n        image_indices=None,\n        attention_mask=None,\n    ):\n        
hidden_states = self.text_model(\n            inputs_embeds=inputs_embeds,\n            position_ids=position_ids,\n            cu_seqlen_prefill=cu_seqlen_prefill,\n            kv_cache=kv_cache,\n            block_tables=block_tables,\n            slots=slots,\n            seqlen=seqlen,\n            max_s=max_s,\n            true_max_s=max_s,\n            prefill_cache_indices=prefill_cache_indices,\n            adapter_data=adapter_data,\n        )\n        if lm_head_indices is not None:\n            hidden_states = hidden_states[lm_head_indices]\n        logits, speculative_logits = self.lm_head(hidden_states)\n        return logits, speculative_logits\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/siglip.py",
    "content": "from typing import Optional, Tuple\nimport warnings\nimport math\nimport torch\nfrom torch import nn\n\nfrom transformers.activations import ACT2FN\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPooling,\n)\nfrom transformers import SiglipConfig, SiglipVisionConfig\nfrom torch.nn.init import _calculate_fan_in_and_fan_out\n\nfrom text_generation_server.layers.tensor_parallel import (\n    TensorParallelEmbedding,\n    TensorParallelColumnLinear,\n    TensorParallelRowLinear,\n)\n\n\nclass SiglipVisionEmbeddings(nn.Module):\n    def __init__(self, prefix, config: SiglipVisionConfig, weights):\n        super().__init__()\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.image_size = config.image_size\n        self.patch_size = config.patch_size\n        self.patch_embedding = nn.Conv2d(\n            in_channels=config.num_channels,\n            out_channels=self.embed_dim,\n            kernel_size=self.patch_size,\n            stride=self.patch_size,\n            padding=\"valid\",\n        )\n        self.patch_embedding.weight = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.patch_embedding.weight\"), requires_grad=False\n        )\n        self.patch_embedding.bias = nn.Parameter(\n            weights.get_tensor(f\"{prefix}.patch_embedding.bias\"), requires_grad=False\n        )\n        self.num_patches = (self.image_size // self.patch_size) ** 2\n        self.num_positions = self.num_patches\n        self.position_embedding = TensorParallelEmbedding(\n            prefix=f\"{prefix}.position_embedding\", weights=weights\n        )\n        self.register_buffer(\n            \"position_ids\",\n            torch.arange(self.num_positions, device=weights.device).expand((1, -1)),\n            persistent=False,\n        )\n\n    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:\n        patch_embeds = self.patch_embedding(\n            pixel_values\n        )  # 
shape = [*, width, grid, grid]\n        embeddings = patch_embeds.flatten(2).transpose(1, 2)\n\n        embeddings = embeddings + self.position_embedding(self.position_ids)\n        return embeddings\n\n\nclass SiglipAttention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.embed_dim // self.num_heads\n        self.head_size = self.head_dim\n        if self.head_dim * self.num_heads != self.embed_dim:\n            raise ValueError(\n                f\"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:\"\n                f\" {self.num_heads}).\"\n            )\n        self.num_heads = self.num_heads // weights.process_group.size()\n        self.embed_dim = self.embed_dim // weights.process_group.size()\n        self.scale = self.head_dim**-0.5\n        self.dropout = config.attention_dropout\n\n        self.k_proj = TensorParallelColumnLinear.load(\n            config, prefix=f\"{prefix}.k_proj\", weights=weights, bias=True\n        )\n        self.v_proj = TensorParallelColumnLinear.load(\n            config, prefix=f\"{prefix}.v_proj\", weights=weights, bias=True\n        )\n        self.q_proj = TensorParallelColumnLinear.load(\n            config, prefix=f\"{prefix}.q_proj\", weights=weights, bias=True\n        )\n        self.out_proj = TensorParallelRowLinear.load(\n            config, prefix=f\"{prefix}.out_proj\", weights=weights, bias=True\n        )\n\n    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n        return (\n            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)\n            .transpose(1, 2)\n            .contiguous()\n        )\n\n    def forward(\n        self,\n        hidden_states: 
torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        \"\"\"Input shape: Batch x Time x Channel\"\"\"\n\n        bsz, tgt_len, _ = hidden_states.size()\n        query_states = self.q_proj(hidden_states)\n        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)\n        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)\n        proj_shape = (bsz * self.num_heads, -1, self.head_dim)\n        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)\n        key_states = key_states.view(*proj_shape)\n        value_states = value_states.view(*proj_shape)\n\n        src_len = key_states.size(1)\n        # scale post matmul\n        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) * self.scale\n\n        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is\"\n                f\" {attn_weights.size()}\"\n            )\n\n        if attention_mask is not None:\n            if attention_mask.size() != (bsz, 1, tgt_len, src_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}\"\n                )\n            attn_weights = (\n                attn_weights.view(bsz, self.num_heads, tgt_len, src_len)\n                + attention_mask\n            )\n            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(\n            attn_weights, dim=-1, dtype=torch.float32\n        ).to(attn_weights.dtype)\n        attn_weights = nn.functional.dropout(\n            attn_weights, p=self.dropout, training=self.training\n        )\n        attn_output = 
torch.matmul(attn_weights, value_states)\n\n        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_size):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_size)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_size)\n        attn_output = attn_output.transpose(1, 2)\n        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)\n\n        attn_output = self.out_proj(attn_output)\n\n        return attn_output, attn_weights\n\n\nclass SiglipMLP(nn.Module):\n    def __init__(self, prefix, config, weights):\n        super().__init__()\n        self.config = config\n        self.activation_fn = ACT2FN[config.hidden_act]\n        self.fc1 = TensorParallelColumnLinear.load(  # config.hidden_size, config.intermediate_size\n            prefix=f\"{prefix}.fc1\", config=config, weights=weights, bias=True\n        )\n        self.fc2 = TensorParallelRowLinear.load(  # config.intermediate_size, config.hidden_size\n            prefix=f\"{prefix}.fc2\", config=config, weights=weights, bias=True\n        )\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = self.activation_fn(hidden_states)\n        hidden_states = self.fc2(hidden_states)\n        return hidden_states\n\n\nclass SiglipEncoderLayer(nn.Module):\n    def __init__(self, prefix, config: SiglipConfig, weights):\n        super().__init__()\n        self.embed_dim = config.hidden_size\n        self.self_attn = SiglipAttention(\n            prefix=f\"{prefix}.self_attn\", config=config, weights=weights\n        )\n        self.layer_norm1 = nn.LayerNorm.load(\n            prefix=f\"{prefix}.layer_norm1\", weights=weights, eps=config.layer_norm_eps\n        )\n        self.mlp = SiglipMLP(prefix=f\"{prefix}.mlp\", 
config=config, weights=weights)\n        self.layer_norm2 = nn.LayerNorm.load(\n            prefix=f\"{prefix}.layer_norm2\", weights=weights, eps=config.layer_norm_eps\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: torch.Tensor,\n    ) -> Tuple[torch.FloatTensor]:\n        residual = hidden_states\n        hidden_states = self.layer_norm1(hidden_states)\n        hidden_states, attn_weights = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n        )\n        hidden_states = residual + hidden_states\n        residual = hidden_states\n        hidden_states = self.layer_norm2(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n        return hidden_states, None\n\n\nclass SiglipMultiheadAttentionPoolingHead(nn.Module):\n    \"\"\"Multihead Attention Pooling.\"\"\"\n\n    def __init__(self, prefix, config: SiglipVisionConfig, weights):\n        super().__init__()\n\n        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))\n        self.attention = torch.nn.MultiheadAttention(\n            config.hidden_size, config.num_attention_heads, batch_first=True\n        )\n        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n        self.mlp = SiglipMLP(prefix, config, weights)\n\n    def forward(self, hidden_state):\n        batch_size = hidden_state.shape[0]\n        probe = self.probe.repeat(batch_size, 1, 1)\n\n        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]\n\n        residual = hidden_state\n        hidden_state = self.layernorm(hidden_state)\n        hidden_state = residual + self.mlp(hidden_state)\n\n        return hidden_state[:, 0]\n\n\ndef _trunc_normal_(tensor, mean, std, a, b):\n    # Cut & paste from PyTorch official master until it's in a few official releases - RW\n    # Method based on 
https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf\n    def norm_cdf(x):\n        # Computes standard normal cumulative distribution function\n        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0\n\n    if (mean < a - 2 * std) or (mean > b + 2 * std):\n        warnings.warn(\n            \"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. \"\n            \"The distribution of values may be incorrect.\",\n            stacklevel=2,\n        )\n\n    # Values are generated by using a truncated uniform distribution and\n    # then using the inverse CDF for the normal distribution.\n    # Get upper and lower cdf values\n    lower = norm_cdf((a - mean) / std)\n    upper = norm_cdf((b - mean) / std)\n\n    # Uniformly fill tensor with values from [l, u], then translate to\n    # [2l-1, 2u-1].\n    tensor.uniform_(2 * lower - 1, 2 * upper - 1)\n\n    # Use inverse cdf transform for normal distribution to get truncated\n    # standard normal\n    tensor.erfinv_()\n\n    # Transform to proper mean, std\n    tensor.mul_(std * math.sqrt(2.0))\n    tensor.add_(mean)\n\n    # Clamp to ensure it's in the proper range\n    tensor.clamp_(min=a, max=b)\n\n\ndef trunc_normal_tf_(\n    tensor: torch.Tensor,\n    mean: float = 0.0,\n    std: float = 1.0,\n    a: float = -2.0,\n    b: float = 2.0,\n) -> torch.Tensor:\n    \"\"\"Fills the input Tensor with values drawn from a truncated\n    normal distribution. The values are effectively drawn from the\n    normal distribution :math:`\\\\mathcal{N}(\\text{mean}, \\text{std}^2)`\n    with values outside :math:`[a, b]` redrawn until they are within\n    the bounds. 
The method used for generating the random values works\n    best when :math:`a \\\\leq \\text{mean} \\\\leq b`.\n\n    NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the\n    bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0\n    and the result is subsquently scaled and shifted by the mean and std args.\n\n    Args:\n        tensor: an n-dimensional `torch.Tensor`\n        mean: the mean of the normal distribution\n        std: the standard deviation of the normal distribution\n        a: the minimum cutoff value\n        b: the maximum cutoff value\n    \"\"\"\n    with torch.no_grad():\n        _trunc_normal_(tensor, 0, 1.0, a, b)\n        tensor.mul_(std).add_(mean)\n\n\ndef variance_scaling_(tensor, scale=1.0, mode=\"fan_in\", distribution=\"normal\"):\n    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)\n    if mode == \"fan_in\":\n        denom = fan_in\n    elif mode == \"fan_out\":\n        denom = fan_out\n    elif mode == \"fan_avg\":\n        denom = (fan_in + fan_out) / 2\n\n    variance = scale / denom\n\n    if distribution == \"truncated_normal\":\n        # constant is stddev of standard normal truncated to (-2, 2)\n        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)\n    elif distribution == \"normal\":\n        with torch.no_grad():\n            tensor.normal_(std=math.sqrt(variance))\n    elif distribution == \"uniform\":\n        bound = math.sqrt(3 * variance)\n        with torch.no_grad():\n            tensor.uniform_(-bound, bound)\n    else:\n        raise ValueError(f\"invalid distribution {distribution}\")\n\n\ndef lecun_normal_(tensor):\n    variance_scaling_(tensor, mode=\"fan_in\", distribution=\"truncated_normal\")\n\n\ndef default_flax_embed_init(tensor):\n    variance_scaling_(tensor, mode=\"fan_in\", distribution=\"normal\")\n\n\nclass SiglipEncoder(nn.Module):\n    \"\"\"\n    Transformer encoder consisting of `config.num_hidden_layers` 
self attention layers. Each layer is a\n    [`SiglipEncoderLayer`].\n\n    Args:\n        config: SiglipConfig\n    \"\"\"\n\n    def __init__(self, prefix, config: SiglipConfig, weights):\n        super().__init__()\n        self.config = config\n        self.layers = nn.ModuleList(\n            [\n                SiglipEncoderLayer(\n                    prefix=f\"{prefix}.layers.{i}\", config=config, weights=weights\n                )\n                for i in range(config.num_hidden_layers)\n            ]\n        )\n\n    def forward(\n        self,\n        inputs_embeds,\n        attention_mask: Optional[torch.Tensor] = None,\n    ):\n        hidden_states = inputs_embeds\n        for idx, encoder_layer in enumerate(self.layers):\n            hidden_states, _ = encoder_layer(\n                hidden_states,\n                attention_mask,\n            )\n\n        return hidden_states\n\n\nclass SiglipVisionTransformer(nn.Module):\n    def __init__(self, prefix, config: SiglipVisionConfig, weights):\n        super().__init__()\n        self.config = config\n\n        self.embeddings = SiglipVisionEmbeddings(\n            prefix=f\"{prefix}.embeddings\", config=config, weights=weights\n        )\n        self.encoder = SiglipEncoder(\n            prefix=f\"{prefix}.encoder\", config=config, weights=weights\n        )\n\n    def forward(\n        self,\n        pixel_values: Optional[torch.FloatTensor] = None,\n    ):\n        if pixel_values is None:\n            raise ValueError(\"You have to specify pixel_values\")\n\n        hidden_states = self.embeddings(pixel_values)\n\n        # NOTE: up until this point, the code logits are exactly\n        # the same as the transformers code. 
The values evaluate\n        # slightly differently in our encoder layer.\n        encoder_outputs = self.encoder(\n            inputs_embeds=hidden_states,\n        )\n        last_hidden_state = encoder_outputs\n\n        return BaseModelOutputWithPooling(\n            last_hidden_state=last_hidden_state,\n            # pooler_output=pooled_output,\n            # hidden_states=encoder_outputs,\n        )\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/t5_modeling.py",
    "content": "# coding=utf-8\n# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch T5 model.\"\"\"\n\nimport copy\nimport math\nimport warnings\nfrom typing import Optional, Tuple, Union\n\nfrom loguru import logger\n\nimport torch\nimport torch.distributed\nfrom torch import nn\nfrom torch.nn import CrossEntropyLoss\n\nfrom transformers.activations import ACT2FN\nfrom transformers.modeling_outputs import (\n    BaseModelOutput,\n    BaseModelOutputWithPastAndCrossAttentions,\n    Seq2SeqLMOutput,\n)\nfrom transformers.modeling_utils import PreTrainedModel\nfrom transformers.pytorch_utils import ALL_LAYERNORM_LAYERS\nfrom transformers.utils import (\n    is_torch_fx_proxy,\n)\nfrom transformers import T5Config\nfrom text_generation_server.layers import (\n    TensorParallelColumnLinear,\n    TensorParallelEmbedding,\n    TensorParallelRowLinear,\n    SpeculativeHead,\n)\n\n# copied from https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/t5/modeling_t5.py#L1316\n# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask\n__HEAD_MASK_WARNING_MSG = \"\"\"\nThe input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. 
Currently,\n`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.\nIf you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,\nnum_heads)`.\n\"\"\"\n\n\nclass PartialTPEmbedding(nn.Module):\n    def __init__(self, prefix: str, weights):\n        super().__init__()\n        weight = weights.get_sharded(f\"{prefix}.weight\", dim=1)\n        self.weight = nn.Parameter(weight)\n\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        return torch.nn.functional.embedding(input, self.weight)\n\n\n@torch.jit.script\ndef layer_norm(hidden_states, weight, epsilon):\n    # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean\n    # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated\n    # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for\n    # half-precision inputs is done in fp32\n\n    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)\n    hidden_states = hidden_states * torch.rsqrt(variance + epsilon)\n\n    # convert into half-precision if necessary\n    if weight.dtype in [torch.float16, torch.bfloat16]:\n        hidden_states = hidden_states.to(weight.dtype)\n\n    return weight * hidden_states\n\n\nclass T5LayerNorm(nn.Module):\n    def __init__(self, prefix, weights, eps=1e-6):\n        \"\"\"\n        Construct a layernorm module in the T5 style. 
No bias and no subtraction of mean.\n        \"\"\"\n        super().__init__()\n        weight = weights.get_tensor(f\"{prefix}.weight\")\n        self.weight = nn.Parameter(weight)\n        self.variance_epsilon = torch.tensor(eps)\n\n    def forward(self, hidden_states):\n        return layer_norm(hidden_states, self.weight, self.variance_epsilon)\n\n\ntry:\n    from apex.normalization import FusedRMSNorm\n\n    T5LayerNorm = FusedRMSNorm  # noqa\n\n    logger.info(\n        \"Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm\"\n    )\nexcept ImportError:\n    # using the normal T5LayerNorm\n    pass\nexcept Exception:\n    logger.warning(\"discovered apex but it failed to load, falling back to T5LayerNorm\")\n    pass\n\nALL_LAYERNORM_LAYERS.append(T5LayerNorm)\n\n\nclass T5DenseActDense(nn.Module):\n    def __init__(self, config: T5Config, prefix, weights):\n        super().__init__()\n        self.wi = TensorParallelColumnLinear.load(\n            config, prefix=f\"{prefix}.wi\", weights=weights, bias=False\n        )\n\n        ### XXX: T5 models do not handle well both f16 and quantization.\n        ### Overidding specifically this layer for that reason.\n        ### https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L316\n        ### https://github.com/huggingface/transformers/issues/20287\n        _q = config.quantize\n        _dtype = weights.dtype\n        weights.dtype = torch.float32\n        config.quantize = None\n        self.wo_cast = (torch.float32, _dtype)\n        self.wo = TensorParallelRowLinear.load(\n            config, prefix=f\"{prefix}.wo\", weights=weights, bias=False\n        )\n        weights.dtype = _dtype\n        config.quantize = _q\n\n        self.dropout = nn.Dropout(config.dropout_rate)\n        self.act = (\n            ACT2FN[config.dense_act_fn]\n            if \"gelu\" not in config.dense_act_fn\n            else lambda x: 
torch.nn.functional.gelu(x, approximate=\"tanh\")\n        )\n\n    def forward(self, hidden_states):\n        hidden_states = self.wi(hidden_states)\n        hidden_states = self.act(hidden_states)\n        hidden_states = self.dropout(hidden_states)\n\n        hidden_states = hidden_states.to(dtype=self.wo_cast[0])\n        hidden_states = self.wo(hidden_states)\n        # XXX: Recasting is already done within the layer norm.\n        # Casting back to float16 here modifies results\n        # hidden_states = hidden_states.to(dtype=self.wo_cast[1])\n        return hidden_states\n\n\nclass T5DenseGatedActDense(nn.Module):\n    def __init__(self, config: T5Config, prefix, weights):\n        super().__init__()\n        self.wi_0 = TensorParallelColumnLinear.load(\n            config, prefix=f\"{prefix}.wi_0\", weights=weights, bias=False\n        )\n        self.wi_1 = TensorParallelColumnLinear.load(\n            config, prefix=f\"{prefix}.wi_1\", weights=weights, bias=False\n        )\n        ### XXX: T5 models do not handle well both f16 and quantization.\n        ### Overidding specifically this layer for that reason.\n        ### https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L316\n        ### https://github.com/huggingface/transformers/issues/20287\n        _q = config.quantize\n        _dtype = weights.dtype\n        weights.dtype = torch.float32\n        config.quantize = None\n        self.wo_cast = (torch.float32, _dtype)\n        self.wo = TensorParallelRowLinear.load(\n            config, prefix=f\"{prefix}.wo\", weights=weights, bias=False\n        )\n        weights.dtype = _dtype\n        config.quantize = _q\n\n        self.dropout = nn.Dropout(config.dropout_rate)\n        self.act = (\n            ACT2FN[config.dense_act_fn]\n            if \"gelu\" not in config.dense_act_fn\n            else lambda x: torch.nn.functional.gelu(x, approximate=\"tanh\")\n        )\n\n    def forward(self, 
hidden_states):\n        hidden_gelu = self.act(self.wi_0(hidden_states))\n        hidden_linear = self.wi_1(hidden_states)\n        hidden_states = hidden_gelu * hidden_linear\n        hidden_states = self.dropout(hidden_states)\n\n        hidden_states = hidden_states.to(dtype=self.wo_cast[0])\n        hidden_states = self.wo(hidden_states)\n        # XXX: Recasting is already done within the layer norm.\n        # Casting back to float16 here modifies results\n        # hidden_states = hidden_states.to(dtype=self.wo_cast[1])\n        return hidden_states\n\n\nclass T5LayerFF(nn.Module):\n    def __init__(self, config: T5Config, prefix, weights):\n        super().__init__()\n        if config.is_gated_act:\n            self.DenseReluDense = T5DenseGatedActDense(\n                config, prefix=f\"{prefix}.DenseReluDense\", weights=weights\n            )\n        else:\n            self.DenseReluDense = T5DenseActDense(\n                config, prefix=f\"{prefix}.DenseReluDense\", weights=weights\n            )\n\n        self.layer_norm = T5LayerNorm(\n            prefix=f\"{prefix}.layer_norm\",\n            weights=weights,\n            eps=config.layer_norm_epsilon,\n        )\n        self.dropout = nn.Dropout(config.dropout_rate)\n\n    def forward(self, hidden_states):\n        forwarded_states = self.layer_norm(hidden_states)\n        forwarded_states = self.DenseReluDense(forwarded_states)\n        hidden_states = hidden_states + self.dropout(forwarded_states)\n        return hidden_states\n\n\nclass T5Attention(nn.Module):\n    def __init__(\n        self, config: T5Config, prefix, weights, has_relative_attention_bias=False\n    ):\n        super().__init__()\n        self.is_decoder = config.is_decoder\n        self.has_relative_attention_bias = has_relative_attention_bias\n        self.relative_attention_num_buckets = config.relative_attention_num_buckets\n        self.relative_attention_max_distance = config.relative_attention_max_distance\n        
self.d_model = config.d_model\n        self.key_value_proj_dim = config.d_kv\n        self.n_heads = config.num_heads\n        self.dropout = config.dropout_rate\n        self.inner_dim = self.n_heads * self.key_value_proj_dim\n\n        process_group = weights.process_group\n        # Mesh TensorFlow initialization to avoid scaling before softmax\n        assert self.n_heads % process_group.size() == 0\n        self.q = TensorParallelColumnLinear.load(\n            config, prefix=f\"{prefix}.q\", weights=weights, bias=False\n        )\n        self.k = TensorParallelColumnLinear.load(\n            config, prefix=f\"{prefix}.k\", weights=weights, bias=False\n        )\n        self.v = TensorParallelColumnLinear.load(\n            config, prefix=f\"{prefix}.v\", weights=weights, bias=False\n        )\n        self.o = TensorParallelRowLinear.load(\n            config, prefix=f\"{prefix}.o\", weights=weights, bias=False\n        )\n        if self.n_heads % weights.process_group.size() != 0:\n            raise ValueError(\n                f\"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} \"\n                f\"and `num_shards`: {weights.process_group.size()}\"\n            )\n        self.n_heads = self.n_heads // process_group.size()\n        self.inner_dim = self.inner_dim // process_group.size()\n\n        if self.has_relative_attention_bias:\n            self.relative_attention_bias = PartialTPEmbedding(\n                prefix=f\"{prefix}.relative_attention_bias\", weights=weights\n            )\n\n    @staticmethod\n    def _relative_position_bucket(\n        relative_position, bidirectional=True, num_buckets=32, max_distance=128\n    ):\n        \"\"\"\n        Adapted from Mesh Tensorflow:\n        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593\n\n        Translate relative position to a bucket number for relative attention. 
The relative position is defined as\n        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to\n        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for\n        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative\n        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.\n        This should allow for more graceful generalization to longer sequences than the model has been trained on\n\n        Args:\n            relative_position: an int32 Tensor\n            bidirectional: a boolean - whether the attention is bidirectional\n            num_buckets: an integer\n            max_distance: an integer\n\n        Returns:\n            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)\n        \"\"\"\n        relative_buckets = 0\n        if bidirectional:\n            num_buckets //= 2\n            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets\n            relative_position = torch.abs(relative_position)\n        else:\n            relative_position = -torch.min(\n                relative_position, torch.zeros_like(relative_position)\n            )\n        # now relative_position is in the range [0, inf)\n\n        # half of the buckets are for exact increments in positions\n        max_exact = num_buckets // 2\n        is_small = relative_position < max_exact\n\n        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance\n        relative_position_if_large = max_exact + (\n            torch.log(relative_position.float() / max_exact)\n            / math.log(max_distance / max_exact)\n            * (num_buckets - max_exact)\n        ).to(torch.long)\n        relative_position_if_large = torch.min(\n            
relative_position_if_large,\n            torch.full_like(relative_position_if_large, num_buckets - 1),\n        )\n\n        relative_buckets += torch.where(\n            is_small, relative_position, relative_position_if_large\n        )\n        return relative_buckets\n\n    def compute_bias(self, query_length, key_length, device=None):\n        \"\"\"Compute binned relative position bias\"\"\"\n        if device is None:\n            device = self.relative_attention_bias.weight.device\n        context_position = torch.arange(query_length, dtype=torch.long, device=device)[\n            :, None\n        ]\n        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[\n            None, :\n        ]\n        relative_position = (\n            memory_position - context_position\n        )  # shape (query_length, key_length)\n        relative_position_bucket = self._relative_position_bucket(\n            relative_position,  # shape (query_length, key_length)\n            bidirectional=(not self.is_decoder),\n            num_buckets=self.relative_attention_num_buckets,\n            max_distance=self.relative_attention_max_distance,\n        )\n        values = self.relative_attention_bias(\n            relative_position_bucket\n        )  # shape (query_length, key_length, num_heads)\n        values = values.permute([2, 0, 1]).unsqueeze(\n            0\n        )  # shape (1, num_heads, query_length, key_length)\n        return values\n\n    def forward(\n        self,\n        hidden_states,\n        mask=None,\n        key_value_states=None,\n        position_bias=None,\n        past_key_value=None,\n        layer_head_mask=None,\n        query_length=None,\n        use_cache=False,\n        output_attentions=False,\n    ):\n        \"\"\"\n        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).\n        \"\"\"\n        # Input is (batch_size, seq_length, dim)\n        # Mask is 
(batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)\n        # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)\n\n        batch_size, seq_length = hidden_states.shape[:2]\n\n        real_seq_length = seq_length\n\n        if past_key_value is not None:\n            assert (\n                len(past_key_value) == 2\n            ), f\"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states\"\n            real_seq_length += (\n                past_key_value[0].shape[2] if query_length is None else query_length\n            )\n\n        key_length = (\n            real_seq_length if key_value_states is None else key_value_states.shape[1]\n        )\n\n        def shape(states):\n            \"\"\"projection\"\"\"\n            return states.view(\n                batch_size, -1, self.n_heads, self.key_value_proj_dim\n            ).transpose(1, 2)\n\n        def unshape(states):\n            \"\"\"reshape\"\"\"\n            return (\n                states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)\n            )\n\n        def project(hidden_states, proj_layer, key_value_states, past_key_value):\n            \"\"\"projects hidden states correctly to key/query states\"\"\"\n            if key_value_states is None:\n                # self-attn\n                # (batch_size, n_heads, seq_length, dim_per_head)\n                hidden_states = shape(proj_layer(hidden_states))\n            elif past_key_value is None:\n                # cross-attn\n                # (batch_size, n_heads, seq_length, dim_per_head)\n                hidden_states = shape(proj_layer(key_value_states))\n\n            if past_key_value is not None:\n                if key_value_states is None:\n                    # self-attn\n                    # (batch_size, n_heads, key_length, dim_per_head)\n                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)\n       
         elif past_key_value.shape[2] != key_value_states.shape[1]:\n                    # checking that the `sequence_length` of the `past_key_value` is the same as\n                    # the provided `key_value_states` to support prefix tuning\n                    # cross-attn\n                    # (batch_size, n_heads, seq_length, dim_per_head)\n                    hidden_states = shape(proj_layer(key_value_states))\n                else:\n                    # cross-attn\n                    hidden_states = past_key_value\n            return hidden_states\n\n        # get query states\n        query_states = shape(\n            self.q(hidden_states)\n        )  # (batch_size, n_heads, seq_length, dim_per_head)\n\n        # get key/value states\n        key_states = project(\n            hidden_states,\n            self.k,\n            key_value_states,\n            past_key_value[0] if past_key_value is not None else None,\n        )\n        value_states = project(\n            hidden_states,\n            self.v,\n            key_value_states,\n            past_key_value[1] if past_key_value is not None else None,\n        )\n\n        # compute scores\n        scores = torch.matmul(\n            query_states, key_states.transpose(3, 2)\n        )  # equivalent of torch.einsum(\"bnqd,bnkd->bnqk\", query_states, key_states), compatible with onnx op>9\n\n        if position_bias is None:\n            if not self.has_relative_attention_bias:\n                position_bias = torch.zeros(\n                    (1, self.n_heads, real_seq_length, key_length),\n                    device=scores.device,\n                    dtype=scores.dtype,\n                )\n            else:\n                position_bias = self.compute_bias(\n                    real_seq_length, key_length, device=scores.device\n                )\n\n            # if key and values are already calculated\n            # we want only the last query position bias\n            if past_key_value is 
not None:\n                position_bias = position_bias[:, :, -hidden_states.size(1) :, :]\n\n            if mask is not None:\n                position_bias = (\n                    position_bias + mask\n                )  # (batch_size, n_heads, seq_length, key_length)\n\n        position_bias_masked = position_bias\n\n        scores += position_bias_masked\n        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(\n            scores\n        )  # (batch_size, n_heads, seq_length, key_length)\n        attn_weights = nn.functional.dropout(\n            attn_weights, p=self.dropout, training=self.training\n        )  # (batch_size, n_heads, seq_length, key_length)\n\n        # Mask heads if we want to\n        if layer_head_mask is not None:\n            attn_weights = attn_weights * layer_head_mask\n\n        attn_output = unshape(\n            torch.matmul(attn_weights, value_states)\n        )  # (batch_size, seq_length, dim)\n        attn_output = self.o(attn_output)\n\n        present_key_value_state = (\n            (key_states, value_states) if (self.is_decoder and use_cache) else None\n        )\n        outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)\n\n        if output_attentions:\n            outputs = outputs + (attn_weights,)\n        return outputs\n\n\nclass T5LayerSelfAttention(nn.Module):\n    def __init__(self, config, prefix, weights, has_relative_attention_bias=False):\n        super().__init__()\n        self.SelfAttention = T5Attention(\n            config,\n            prefix=f\"{prefix}.SelfAttention\",\n            weights=weights,\n            has_relative_attention_bias=has_relative_attention_bias,\n        )\n        self.layer_norm = T5LayerNorm(\n            prefix=f\"{prefix}.layer_norm\",\n            weights=weights,\n            eps=config.layer_norm_epsilon,\n        )\n        self.dropout = nn.Dropout(config.dropout_rate)\n\n    def forward(\n        self,\n        hidden_states,\n 
       attention_mask=None,\n        position_bias=None,\n        layer_head_mask=None,\n        past_key_value=None,\n        use_cache=False,\n        output_attentions=False,\n    ):\n        normed_hidden_states = self.layer_norm(hidden_states)\n        attention_output = self.SelfAttention(\n            normed_hidden_states,\n            mask=attention_mask,\n            position_bias=position_bias,\n            layer_head_mask=layer_head_mask,\n            past_key_value=past_key_value,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n        )\n        hidden_states = hidden_states + self.dropout(attention_output[0])\n        outputs = (hidden_states,) + attention_output[\n            1:\n        ]  # add attentions if we output them\n        return outputs\n\n\nclass T5LayerCrossAttention(nn.Module):\n    def __init__(self, config, prefix, weights):\n        super().__init__()\n        self.EncDecAttention = T5Attention(\n            config,\n            prefix=f\"{prefix}.EncDecAttention\",\n            weights=weights,\n            has_relative_attention_bias=False,\n        )\n        self.layer_norm = T5LayerNorm(\n            prefix=f\"{prefix}.layer_norm\",\n            weights=weights,\n            eps=config.layer_norm_epsilon,\n        )\n        self.dropout = nn.Dropout(config.dropout_rate)\n\n    def forward(\n        self,\n        hidden_states,\n        key_value_states,\n        attention_mask=None,\n        position_bias=None,\n        layer_head_mask=None,\n        past_key_value=None,\n        use_cache=False,\n        query_length=None,\n        output_attentions=False,\n    ):\n        normed_hidden_states = self.layer_norm(hidden_states)\n        attention_output = self.EncDecAttention(\n            normed_hidden_states,\n            mask=attention_mask,\n            key_value_states=key_value_states,\n            position_bias=position_bias,\n            layer_head_mask=layer_head_mask,\n            
past_key_value=past_key_value,\n            use_cache=use_cache,\n            query_length=query_length,\n            output_attentions=output_attentions,\n        )\n        layer_output = hidden_states + self.dropout(attention_output[0])\n        outputs = (layer_output,) + attention_output[\n            1:\n        ]  # add attentions if we output them\n        return outputs\n\n\nclass T5Block(nn.Module):\n    def __init__(self, config, prefix, weights, has_relative_attention_bias: bool):\n        super().__init__()\n        self.is_decoder = config.is_decoder\n        self.layer = nn.ModuleList()\n        self.layer.append(\n            T5LayerSelfAttention(\n                config,\n                prefix=f\"{prefix}.layer.0\",\n                weights=weights,\n                has_relative_attention_bias=has_relative_attention_bias,\n            )\n        )\n        if self.is_decoder:\n            i = 2\n            self.layer.append(\n                T5LayerCrossAttention(\n                    config, prefix=f\"{prefix}.layer.1\", weights=weights\n                )\n            )\n        else:\n            i = 1\n\n        self.layer.append(\n            T5LayerFF(config, prefix=f\"{prefix}.layer.{i}\", weights=weights)\n        )\n\n    def forward(\n        self,\n        hidden_states,\n        attention_mask=None,\n        position_bias=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        encoder_decoder_position_bias=None,\n        layer_head_mask=None,\n        cross_attn_layer_head_mask=None,\n        past_key_value=None,\n        use_cache=False,\n        output_attentions=False,\n        return_dict=True,\n    ):\n        if past_key_value is not None:\n            if not self.is_decoder:\n                logger.warning(\n                    \"`past_key_values` is passed to the encoder. 
Please make sure this is intended.\"\n                )\n            expected_num_past_key_values = 2 if encoder_hidden_states is None else 4\n\n            if len(past_key_value) != expected_num_past_key_values:\n                raise ValueError(\n                    f\"There should be {expected_num_past_key_values} past states. \"\n                    f\"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}\"\n                    f\"Got {len(past_key_value)} past key / value states\"\n                )\n\n            self_attn_past_key_value = past_key_value[:2]\n            cross_attn_past_key_value = past_key_value[2:]\n        else:\n            self_attn_past_key_value, cross_attn_past_key_value = None, None\n\n        self_attention_outputs = self.layer[0](\n            hidden_states,\n            attention_mask=attention_mask,\n            position_bias=position_bias,\n            layer_head_mask=layer_head_mask,\n            past_key_value=self_attn_past_key_value,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n        )\n        hidden_states, present_key_value_state = self_attention_outputs[:2]\n        attention_outputs = self_attention_outputs[\n            2:\n        ]  # Keep self-attention outputs and relative position weights\n\n        # clamp inf values to enable fp16 training\n        if hidden_states.dtype == torch.float16:\n            clamp_value = torch.where(\n                torch.isinf(hidden_states).any(),\n                torch.finfo(hidden_states.dtype).max - 1000,\n                torch.finfo(hidden_states.dtype).max,\n            )\n            hidden_states = torch.clamp(\n                hidden_states, min=-clamp_value, max=clamp_value\n            )\n\n        do_cross_attention = self.is_decoder and encoder_hidden_states is not None\n        if do_cross_attention:\n            # the actual query length is unknown for cross attention\n            # if using 
past key value states. Need to inject it here\n            if present_key_value_state is not None:\n                query_length = present_key_value_state[0].shape[2]\n            else:\n                query_length = None\n\n            cross_attention_outputs = self.layer[1](\n                hidden_states,\n                key_value_states=encoder_hidden_states,\n                attention_mask=encoder_attention_mask,\n                position_bias=encoder_decoder_position_bias,\n                layer_head_mask=cross_attn_layer_head_mask,\n                past_key_value=cross_attn_past_key_value,\n                query_length=query_length,\n                use_cache=use_cache,\n                output_attentions=output_attentions,\n            )\n            hidden_states = cross_attention_outputs[0]\n\n            # clamp inf values to enable fp16 training\n            if hidden_states.dtype == torch.float16:\n                clamp_value = torch.where(\n                    torch.isinf(hidden_states).any(),\n                    torch.finfo(hidden_states.dtype).max - 1000,\n                    torch.finfo(hidden_states.dtype).max,\n                )\n                hidden_states = torch.clamp(\n                    hidden_states, min=-clamp_value, max=clamp_value\n                )\n\n            # Combine self attn and cross attn key value states\n            if present_key_value_state is not None:\n                present_key_value_state = (\n                    present_key_value_state + cross_attention_outputs[1]\n                )\n\n            # Keep cross-attention outputs and relative position weights\n            attention_outputs = attention_outputs + cross_attention_outputs[2:]\n\n        # Apply Feed Forward layer\n        hidden_states = self.layer[-1](hidden_states)\n\n        # clamp inf values to enable fp16 training\n        if hidden_states.dtype == torch.float16:\n            clamp_value = torch.where(\n                
torch.isinf(hidden_states).any(),\n                torch.finfo(hidden_states.dtype).max - 1000,\n                torch.finfo(hidden_states.dtype).max,\n            )\n            hidden_states = torch.clamp(\n                hidden_states, min=-clamp_value, max=clamp_value\n            )\n\n        outputs = (hidden_states,)\n\n        if use_cache:\n            outputs = outputs + (present_key_value_state,) + attention_outputs\n        else:\n            outputs = outputs + attention_outputs\n\n        return outputs  # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)\n\n\nclass T5PreTrainedModel(PreTrainedModel):\n    \"\"\"\n    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained\n    models.\n    \"\"\"\n\n    config_class = T5Config\n\n    def _shift_right(self, input_ids):\n        decoder_start_token_id = self.config.decoder_start_token_id\n        pad_token_id = self.config.pad_token_id\n\n        assert decoder_start_token_id is not None, (\n            \"self.model.config.decoder_start_token_id has to be defined. 
In T5 it is usually set to the pad_token_id.\"\n            \" See T5 docs for more information\"\n        )\n\n        # shift inputs to the right\n        if is_torch_fx_proxy(input_ids):\n            # Item assignment is not supported natively for proxies.\n            shifted_input_ids = torch.full(\n                input_ids.shape[:-1] + (1,), decoder_start_token_id\n            )\n            shifted_input_ids = torch.cat(\n                [shifted_input_ids, input_ids[..., :-1]], dim=-1\n            )\n        else:\n            shifted_input_ids = input_ids.new_zeros(input_ids.shape)\n            shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()\n            shifted_input_ids[..., 0] = decoder_start_token_id\n\n        assert (\n            pad_token_id is not None\n        ), \"self.model.config.pad_token_id has to be defined.\"\n        # replace possible -100 values in labels by `pad_token_id`\n        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)\n\n        return shifted_input_ids\n\n\nclass T5Stack(T5PreTrainedModel):\n    def __init__(self, config, prefix, weights, embed_tokens):\n        super().__init__(config)\n\n        self.is_decoder = config.is_decoder\n\n        self.embed_tokens = embed_tokens\n        self.block = nn.ModuleList(\n            [\n                T5Block(\n                    config,\n                    prefix=f\"{prefix}.block.{layer_id}\",\n                    weights=weights,\n                    has_relative_attention_bias=(layer_id == 0),\n                )\n                for layer_id in range(config.num_layers)\n            ]\n        )\n        self.final_layer_norm = T5LayerNorm(\n            prefix=f\"{prefix}.final_layer_norm\",\n            weights=weights,\n            eps=config.layer_norm_epsilon,\n        )\n        self.dropout = nn.Dropout(config.dropout_rate)\n\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        
encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        inputs_embeds=None,\n        head_mask=None,\n        cross_attn_head_mask=None,\n        past_key_values=None,\n        use_cache=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ):\n        # Model parallel\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        if input_ids is not None and inputs_embeds is not None:\n            err_msg_prefix = \"decoder_\" if self.is_decoder else \"\"\n            raise ValueError(\n                f\"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time\"\n            )\n        elif input_ids is not None:\n            input_shape = input_ids.size()\n            input_ids = input_ids.view(-1, input_shape[-1])\n        elif inputs_embeds is not None:\n            input_shape = inputs_embeds.size()[:-1]\n        else:\n            err_msg_prefix = \"decoder_\" if self.is_decoder else \"\"\n            raise ValueError(\n                f\"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds\"\n            )\n\n        if inputs_embeds is None:\n            assert (\n                self.embed_tokens is not None\n            ), \"You have to initialize the model with valid token embeddings\"\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        batch_size, seq_length = input_shape\n\n        # required mask 
seq length can be calculated via length of past\n        mask_seq_length = (\n            past_key_values[0][0].shape[2] + seq_length\n            if past_key_values is not None\n            else seq_length\n        )\n\n        if use_cache is True:\n            assert (\n                self.is_decoder\n            ), f\"`use_cache` can only be set to `True` if {self} is used as a decoder\"\n\n        if attention_mask is None:\n            attention_mask = torch.ones(\n                batch_size, mask_seq_length, device=inputs_embeds.device\n            )\n        if (\n            self.is_decoder\n            and encoder_attention_mask is None\n            and encoder_hidden_states is not None\n        ):\n            encoder_seq_length = encoder_hidden_states.shape[1]\n            encoder_attention_mask = torch.ones(\n                batch_size,\n                encoder_seq_length,\n                device=inputs_embeds.device,\n                dtype=torch.long,\n            )\n\n        # initialize past_key_values with `None` if past does not exist\n        if past_key_values is None:\n            past_key_values = [None] * len(self.block)\n\n        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]\n        # ourselves in which case we just need to make it broadcastable to all heads.\n        extended_attention_mask = self.get_extended_attention_mask(\n            attention_mask, input_shape\n        )\n\n        # If a 2D or 3D attention mask is provided for the cross-attention\n        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]\n        if self.is_decoder and encoder_hidden_states is not None:\n            (\n                encoder_batch_size,\n                encoder_sequence_length,\n                _,\n            ) = encoder_hidden_states.size()\n            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)\n            if encoder_attention_mask 
is None:\n                encoder_attention_mask = torch.ones(\n                    encoder_hidden_shape, device=inputs_embeds.device\n                )\n            encoder_extended_attention_mask = self.invert_attention_mask(\n                encoder_attention_mask\n            )\n        else:\n            encoder_extended_attention_mask = None\n\n        # Prepare head mask if needed\n        head_mask = self.get_head_mask(head_mask, self.config.num_layers)\n        cross_attn_head_mask = self.get_head_mask(\n            cross_attn_head_mask, self.config.num_layers\n        )\n        present_key_value_states = () if use_cache else None\n        all_hidden_states = () if output_hidden_states else None\n        all_attentions = () if output_attentions else None\n        all_cross_attentions = () if (output_attentions and self.is_decoder) else None\n        position_bias = None\n        encoder_decoder_position_bias = None\n\n        hidden_states = self.dropout(inputs_embeds)\n\n        for i, (layer_module, past_key_value) in enumerate(\n            zip(self.block, past_key_values)\n        ):\n            layer_head_mask = head_mask[i]\n            cross_attn_layer_head_mask = cross_attn_head_mask[i]\n            # Model parallel\n            if output_hidden_states:\n                all_hidden_states = all_hidden_states + (hidden_states,)\n\n            layer_outputs = layer_module(\n                hidden_states,\n                attention_mask=extended_attention_mask,\n                position_bias=position_bias,\n                encoder_hidden_states=encoder_hidden_states,\n                encoder_attention_mask=encoder_extended_attention_mask,\n                encoder_decoder_position_bias=encoder_decoder_position_bias,\n                layer_head_mask=layer_head_mask,\n                cross_attn_layer_head_mask=cross_attn_layer_head_mask,\n                past_key_value=past_key_value,\n                use_cache=use_cache,\n                
output_attentions=output_attentions,\n            )\n\n            # layer_outputs is a tuple with:\n            # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)\n            if use_cache is False:\n                layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]\n\n            hidden_states, present_key_value_state = layer_outputs[:2]\n\n            # We share the position biases between the layers - the first layer store them\n            # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),\n            # (cross-attention position bias), (cross-attention weights)\n            position_bias = layer_outputs[2]\n            if self.is_decoder and encoder_hidden_states is not None:\n                encoder_decoder_position_bias = layer_outputs[\n                    4 if output_attentions else 3\n                ]\n            # append next layer key value states\n            if use_cache:\n                present_key_value_states = present_key_value_states + (\n                    present_key_value_state,\n                )\n\n            if output_attentions:\n                all_attentions = all_attentions + (layer_outputs[3],)\n                if self.is_decoder:\n                    all_cross_attentions = all_cross_attentions + (layer_outputs[5],)\n\n        hidden_states = self.final_layer_norm(hidden_states)\n        hidden_states = self.dropout(hidden_states)\n\n        # Add last layer\n        if output_hidden_states:\n            all_hidden_states = all_hidden_states + (hidden_states,)\n\n        if not return_dict:\n            return tuple(\n                v\n                for v in [\n                    hidden_states,\n                    present_key_value_states,\n                    all_hidden_states,\n                    all_attentions,\n                    
all_cross_attentions,\n                ]\n                if v is not None\n            )\n        return BaseModelOutputWithPastAndCrossAttentions(\n            last_hidden_state=hidden_states,\n            past_key_values=present_key_value_states,\n            hidden_states=all_hidden_states,\n            attentions=all_attentions,\n            cross_attentions=all_cross_attentions,\n        )\n\n\nclass T5ForConditionalGeneration(T5PreTrainedModel):\n    def __init__(self, config: T5Config, weights):\n        super().__init__(config)\n        self.model_dim = config.d_model\n\n        self.shared = TensorParallelEmbedding(prefix=\"shared\", weights=weights)\n\n        encoder_config = copy.deepcopy(config)\n        encoder_config.is_decoder = False\n        encoder_config.use_cache = False\n        encoder_config.is_encoder_decoder = False\n        self.encoder = T5Stack(\n            config=encoder_config,\n            prefix=\"encoder\",\n            weights=weights,\n            embed_tokens=self.shared,\n        )\n\n        decoder_config = copy.deepcopy(config)\n        decoder_config.is_decoder = True\n        decoder_config.is_encoder_decoder = False\n        decoder_config.num_layers = config.num_decoder_layers\n        self.decoder = T5Stack(\n            config=decoder_config,\n            prefix=\"decoder\",\n            weights=weights,\n            embed_tokens=self.shared,\n        )\n\n        try:\n            self.lm_head = SpeculativeHead.load(\n                config, prefix=\"lm_head\", weights=weights\n            )\n        except RuntimeError:\n            # Some models like t5-small were saved with shared weights unlike flan\n            # Since they are declared as the same arch we have no choice but hope\n            # that this is OK instead of using a proper flag.\n            self.lm_head = SpeculativeHead.load(\n                config, prefix=\"shared\", weights=weights\n            )\n\n    def forward(\n        self,\n        
input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        decoder_input_ids: Optional[torch.LongTensor] = None,\n        decoder_attention_mask: Optional[torch.BoolTensor] = None,\n        head_mask: Optional[torch.FloatTensor] = None,\n        decoder_head_mask: Optional[torch.FloatTensor] = None,\n        cross_attn_head_mask: Optional[torch.Tensor] = None,\n        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,\n        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask\n        if head_mask is not None and decoder_head_mask is None:\n            if self.config.num_layers == self.config.num_decoder_layers:\n                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)\n                decoder_head_mask = head_mask\n\n        # Encode if needed (training, first prediction pass)\n        if encoder_outputs is None:\n            # Convert encoder inputs in embeddings if needed\n            encoder_outputs = self.encoder(\n                input_ids=input_ids,\n                attention_mask=attention_mask,\n                inputs_embeds=inputs_embeds,\n                head_mask=head_mask,\n                output_attentions=output_attentions,\n        
        output_hidden_states=output_hidden_states,\n                return_dict=return_dict,\n            )\n        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):\n            encoder_outputs = BaseModelOutput(\n                last_hidden_state=encoder_outputs[0],\n                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,\n                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,\n            )\n\n        hidden_states = encoder_outputs[0]\n\n        if (\n            labels is not None\n            and decoder_input_ids is None\n            and decoder_inputs_embeds is None\n        ):\n            # get decoder inputs from shifting lm labels to the right\n            decoder_input_ids = self._shift_right(labels)\n\n        # Decode\n        decoder_outputs = self.decoder(\n            input_ids=decoder_input_ids,\n            attention_mask=decoder_attention_mask,\n            inputs_embeds=decoder_inputs_embeds,\n            past_key_values=past_key_values,\n            encoder_hidden_states=hidden_states,\n            encoder_attention_mask=attention_mask,\n            head_mask=decoder_head_mask,\n            cross_attn_head_mask=cross_attn_head_mask,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        sequence_output = decoder_outputs[0]\n\n        if self.config.tie_word_embeddings:\n            # Rescale output before projecting on vocab\n            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586\n            sequence_output = sequence_output * (self.model_dim**-0.5)\n\n        logits, speculative_logits = self.lm_head(sequence_output)\n\n        loss = None\n        if labels is not None:\n            loss_fct = 
CrossEntropyLoss(ignore_index=-100)\n            # move labels to correct device to enable PP\n            labels = labels.to(logits.device)\n            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))\n            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666\n\n        if not return_dict:\n            output = (logits,) + decoder_outputs[1:] + encoder_outputs\n            return ((loss,) + output) if loss is not None else output\n\n        return (\n            Seq2SeqLMOutput(\n                loss=loss,\n                logits=logits,\n                past_key_values=decoder_outputs.past_key_values,\n                decoder_hidden_states=decoder_outputs.hidden_states,\n                decoder_attentions=decoder_outputs.attentions,\n                cross_attentions=decoder_outputs.cross_attentions,\n                encoder_last_hidden_state=encoder_outputs.last_hidden_state,\n                encoder_hidden_states=encoder_outputs.hidden_states,\n                encoder_attentions=encoder_outputs.attentions,\n            ),\n            speculative_logits,\n        )\n\n    def prepare_inputs_for_generation(\n        self,\n        input_ids,\n        past_key_values=None,\n        attention_mask=None,\n        head_mask=None,\n        decoder_head_mask=None,\n        decoder_attention_mask=None,\n        cross_attn_head_mask=None,\n        use_cache=None,\n        encoder_outputs=None,\n        **kwargs,\n    ):\n        # cut decoder_input_ids if past is used\n        if past_key_values is not None:\n            input_ids = input_ids[:, -1:]\n\n        return {\n            \"decoder_input_ids\": input_ids,\n            \"past_key_values\": past_key_values,\n            \"encoder_outputs\": encoder_outputs,\n            \"attention_mask\": attention_mask,\n            \"head_mask\": head_mask,\n            \"decoder_head_mask\": decoder_head_mask,\n      
      \"decoder_attention_mask\": decoder_attention_mask,\n            \"cross_attn_head_mask\": cross_attn_head_mask,\n            \"use_cache\": use_cache,\n        }\n\n    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):\n        return self._shift_right(labels)\n\n    def _reorder_cache(self, past_key_values, beam_idx):\n        # if decoder past is not included in output\n        # speedy decoding is disabled and no need to reorder\n        if past_key_values is None:\n            logger.warning(\n                \"You might want to consider setting `use_cache=True` to speed up decoding\"\n            )\n            return past_key_values\n\n        reordered_decoder_past = ()\n        for layer_past_states in past_key_values:\n            # get the correct batch idx from layer past batch dim\n            # batch dim of `past` is at 2nd position\n            reordered_layer_past_states = ()\n            for layer_past_state in layer_past_states:\n                # need to set correct `past` for each of the four key / value states\n                reordered_layer_past_states = reordered_layer_past_states + (\n                    layer_past_state.index_select(\n                        0, beam_idx.to(layer_past_state.device)\n                    ),\n                )\n\n            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape\n            assert len(reordered_layer_past_states) == len(layer_past_states)\n\n            reordered_decoder_past = reordered_decoder_past + (\n                reordered_layer_past_states,\n            )\n        return reordered_decoder_past\n"
  },
  {
    "path": "server/text_generation_server/models/custom_modeling/vlm.py",
    "content": "def load_text_model(prefix, config, weights, name=None):\n    if config.model_type == \"llama\":\n        from text_generation_server.models.custom_modeling.flash_llama_modeling import (\n            FlashLlamaForCausalLM,\n        )\n\n        return FlashLlamaForCausalLM(prefix, config, weights, name=name)\n    elif config.model_type == \"mistral\":\n        from text_generation_server.models.custom_modeling.flash_mistral_modeling import (\n            FlashMistralForCausalLM,\n        )\n\n        return FlashMistralForCausalLM(prefix, config, weights, name=name)\n    elif config.model_type == \"gemma\":\n        from text_generation_server.models.custom_modeling.flash_gemma_modeling import (\n            FlashGemmaForCausalLM,\n        )\n\n        return FlashGemmaForCausalLM(prefix, config, weights, causal=False)\n    elif config.model_type == \"gemma2\":\n        from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (\n            FlashGemma2ForCausalLM,\n        )\n\n        return FlashGemma2ForCausalLM(prefix, config, weights)\n\n    elif config.model_type == \"gemma3\" or config.model_type == \"gemma3_text\":\n        from text_generation_server.models.custom_modeling.flash_gemma3_modeling import (\n            FlashGemma3ForCausalLM,\n        )\n\n        return FlashGemma3ForCausalLM(prefix, config, weights)\n    elif config.model_type == \"paligemma\":\n        from text_generation_server.models.custom_modeling.flash_gemma_modeling import (\n            FlashGemmaForCausalLM,\n        )\n\n        return FlashGemmaForCausalLM(prefix, config, weights)\n    else:\n        raise RuntimeError(f\"Unsupported model type {config.model_type}\")\n\n\ndef load_vision_model(prefix, config, weights):\n    if config.model_type == \"clip_vision_model\":\n        from text_generation_server.models.custom_modeling.clip import (\n            CLIPVisionTransformer,\n        )\n\n        return CLIPVisionTransformer(\n           
 prefix=f\"{prefix}.vision_model\", config=config, weights=weights\n        )\n    if (\n        config.model_type == \"siglip_vision_model\"\n        or config.model_type == \"gemma3_vision\"\n    ):\n        from text_generation_server.models.custom_modeling.siglip import (\n            SiglipVisionTransformer,\n        )\n\n        # TODO: ensure that using the prefix doesn't break any existing models\n        # that rely on the old prefix (update the old models if necessary)\n        return SiglipVisionTransformer(\n            # prefix=\"vision_model.vision_model\", config=config, weights=weights\n            prefix=f\"{prefix}.vision_model\",\n            config=config,\n            weights=weights,\n        )\n    else:\n        raise RuntimeError(f\"Unsupported model type {config.model_type}\")\n"
  },
  {
    "path": "server/text_generation_server/models/flash_causal_lm.py",
    "content": "from contextlib import nullcontext\nimport math\nimport os\nimport time\nimport torch\nimport torch.distributed\n\nimport numpy as np\n\nfrom loguru import logger\nfrom dataclasses import dataclass\nfrom opentelemetry import trace\nfrom transformers import (\n    PreTrainedTokenizerBase,\n    AutoConfig,\n    AutoTokenizer,\n    GenerationConfig,\n)\nfrom typing import (\n    Any,\n    ContextManager,\n    Iterable,\n    Optional,\n    Tuple,\n    List,\n    Type,\n    Dict,\n    Union,\n)\n\nfrom text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata\nfrom huggingface_hub.constants import HUGGINGFACE_HUB_CACHE\nfrom text_generation_server.utils.chunks import concat_text_chunks\nfrom text_generation_server.utils.import_utils import SYSTEM\nfrom text_generation_server.models import Model\nfrom text_generation_server.utils.log import log_master\nfrom text_generation_server.utils.prefill_chunking import (\n    get_support_chunking,\n    get_max_prefill_tokens,\n)\nfrom text_generation_server.utils.tokens import batch_top_tokens\nfrom text_generation_server.utils.speculate import get_speculate\nfrom text_generation_server.utils import (\n    initialize_torch_distributed,\n    weight_files,\n    Weights,\n)\nfrom text_generation_server.models.types import (\n    Batch,\n    Tokens,\n    Generation,\n    GeneratedText,\n)\nfrom text_generation_server.pb import generate_pb2\nfrom text_generation_server.models.globals import (\n    MEM_POOL,\n    ATTENTION,\n    BLOCK_SIZE,\n    CUDA_GRAPHS,\n    REQUEST_LOGPROBS,\n    TGI_WIGGLE_ROOM,\n    get_adapter_to_index,\n)\nfrom text_generation_server.layers.attention import KVCache, Seqlen\nfrom text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser\nfrom text_generation_server.utils.dist import MEMORY_FRACTION\nfrom text_generation_server.utils.quantization import get_loader\nfrom text_generation_server.utils.segments import SegmentConcatBuilder, 
find_segments\n\nfrom text_generation_server.utils.import_utils import (\n    empty_cache,\n    synchronize,\n    get_free_memory,\n)\nfrom text_generation_server.models.metadata_kernels import (\n    has_triton,\n    copy_next_input_ids_inplace,\n    block_tables_to_ragged,\n    block_tables_to_padded,\n    prepare_position_slot_ids,\n    slots_filtering,\n)\n\ntracer = trace.get_tracer(__name__)\n\n\ndef small_power_of_2(n: int):\n    return 1 << ((n - 1).bit_length() - 1)\n\n\ndef init_cpu_threads_env(rank_id: int, world_size: int):\n    import importlib.util\n\n    if importlib.util.find_spec(\"numa\") is not None:\n        import numa\n        import psutil\n\n        nodes = numa.info.get_max_node() + 1\n        rank_per_node = math.ceil(world_size / nodes)\n        num_cpus_per_nodes = int(psutil.cpu_count(logical=False) / nodes)\n        node_id = int(rank_id / rank_per_node)\n        rank_offset_per_node = rank_id % rank_per_node\n        if os.getenv(\"OMP_NUM_THREADS\") is None:\n            num_cpus_per_rank = max(int(num_cpus_per_nodes / rank_per_node), 1)\n        else:\n            num_cpus_per_rank = int(os.getenv(\"OMP_NUM_THREADS\"))\n        if len(numa.memory.get_membind_nodes()) == nodes:\n            numa.memory.set_membind_nodes((node_id))\n        torch.set_num_threads(num_cpus_per_rank)\n        if len(numa.schedule.get_affinitive_cpus(0)) == psutil.cpu_count(logical=True):\n            cpu_start = num_cpus_per_rank * rank_offset_per_node\n            numa.schedule.run_on_cpus(\n                0,\n                *(\n                    numa.info.node_to_cpus(node_id)[\n                        cpu_start : cpu_start + num_cpus_per_rank\n                    ]\n                ),\n            )\n        logger.info(\n            f\"affinity={numa.schedule.get_affinitive_cpus(0)}, membind = {numa.memory.get_membind_nodes()}\"\n        )\n\n\n@dataclass\nclass FlashCausalLMBatch(Batch):\n    batch_id: int\n    requests: 
List[generate_pb2.Request]\n    # request id -> idx in list mapping\n    requests_idx_mapping: Dict[int, int]\n\n    # Decoder values\n    # Can be a list for easy filtering\n    # If `input_ids` is a list, it needs to be materialized to a tensor first\n    input_ids: Union[torch.Tensor, List[List[int]]]\n    # Will be set by `generate_token` and reset after each prefill forward before staying set in decode\n    position_ids: Optional[torch.Tensor]\n    speculative_ids: Optional[torch.Tensor]\n\n    # Set when creating the batch\n    # tensor of indices of the currently used slots, length = \\sum_{i=0}^{b} s_i in prefill, length = b in decode\n    # Will be set by `generate_token` and reset after each prefill forward before staying set in decode\n    slot_indices: Optional[torch.Tensor]\n\n    # list of length b of list of length s_i // block_size\n    block_tables: List[List[int]]\n    # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences\n    block_tables_tensor: torch.Tensor\n    # tensor of length \\sum_{i=0}^{b} max_s_i  holding the paged attention slots for all sequences\n    slots: torch.Tensor\n    # list of length b + 1  containing the cumulative sequence slot lengths of the sequences in the batch\n    # used for filtering\n    cu_slots: torch.Tensor\n\n    max_input_length: int\n    max_current_length: int\n\n    # Whether this batch contains at least one request that is prefilling\n    prefilling: bool\n    # Whether each request is prefilling\n    prefilling_mask: List[bool]\n\n    # Prefill metadata tensors to efficiently compute logprobs\n    # tensor of length b + 1  containing the cumulative sequence lengths of the sequences in the batch, only used in prefill\n    cu_seqlen_prefill: Optional[torch.Tensor]\n    # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers\n    # as we only keep SLIDING_WINDOW values instead of the whole tensor\n 
   prefill_cache_indices: Optional[torch.Tensor]\n    # Will be set by `generate_token` and reset after each prefill forward\n    prefill_head_indices: Optional[torch.Tensor]\n    # Will be set by `generate_token` and reset after each prefill forward\n    prefill_next_token_indices: Optional[torch.tensor]\n    # Will be set by `generate_token` and reset after each prefill forward\n    prefill_cu_outlens: Optional[List[int]]\n    # Will be set by `generate_token` and reset after each prefill forward\n    prefill_logprob_tokens: List[Optional[Tokens]]\n\n    # All tokens\n    all_input_ids: List[List[int]]\n    all_input_ids_tensor: torch.Tensor\n\n    # Lengths of all generations present in the batch\n    input_lengths: List[int]\n    # size [b], containing the number of blocks that can be retrieved from the cache\n    cache_lengths: List[int]\n    prompt_lengths: List[int]\n    # Will be set by `generate_token` and reset after each prefill forward before staying set in decode\n    input_lengths_tensor: Optional[torch.Tensor]\n    cache_lengths_tensor: Optional[torch.Tensor]\n    prompt_lengths_tensor: torch.Tensor\n\n    prefix_offsets: List[Optional[int]]\n    read_offsets: List[Optional[int]]\n\n    # Generation helpers\n    next_token_chooser: HeterogeneousNextTokenChooser\n    stopping_criterias: List[StoppingCriteria]\n    top_n_tokens: List[int]\n    top_n_tokens_tensor: torch.Tensor\n\n    # Adapter metadata for each request\n    # Will be set by `generate_token` and reset after each prefill forward before staying set in decode\n    adapter_meta: Optional[AdapterBatchMetadata]\n\n    # Number of blocks in this batch\n    num_blocks: int\n    # Maximum number of blocks\n    max_blocks: int\n\n    def to_pb(self) -> generate_pb2.CachedBatch:\n        return generate_pb2.CachedBatch(\n            id=self.batch_id,\n            request_ids=[r.id for r in self.requests],\n            size=len(self),\n            max_tokens=self.num_blocks * BLOCK_SIZE,\n          
  current_tokens=(\n                sum([len(i) for i in self.input_ids])\n                if isinstance(self.input_ids, list)\n                else len(self.input_ids)\n            ),\n        )\n\n    @classmethod\n    def batch_tokenized_inputs(\n        cls, requests: Iterable[generate_pb2.Request], tokenizer\n    ):\n        max_length = 0\n        all_input_ids = []\n        batch_size = 0\n        for r in requests:\n            batch_size += 1\n            inputs = concat_text_chunks(r.input_chunks.chunks)\n            input_ids = tokenizer(\n                inputs,\n                truncation=True,\n                max_length=r.truncate,\n                add_special_tokens=r.add_special_tokens,\n            )[\"input_ids\"]\n            max_length = max(max_length, len(input_ids))\n            all_input_ids.append(input_ids)\n        return all_input_ids\n\n    @classmethod\n    def from_tokenized(\n        cls,\n        pb: generate_pb2.Batch,\n        tokenizer: PreTrainedTokenizerBase,\n        batch_tokenized_inputs,\n        dtype: torch.dtype,\n        device: torch.device,\n    ) -> \"FlashCausalLMBatch\":\n        speculate = get_speculate()\n\n        cache_lengths = []\n        input_lengths = []\n        prompt_lengths = []\n        prefix_offsets = []\n        read_offsets = []\n        all_input_ids = []\n        all_postfix_ids = []\n        requests_idx_mapping = {}\n        slots = []\n        cu_slots = [0]\n\n        next_token_chooser_parameters = []\n        stopping_criterias = []\n        top_n_tokens = []\n\n        num_blocks = 0\n        max_input_length = 0\n        max_current_length = 0\n        max_length = 0\n        max_blocks = 0\n\n        cu_blocks = [0]\n        block_tables = []\n        block_tables_ragged = []\n\n        # Parse batch\n        for i, (r, tokenized_input) in enumerate(\n            zip(pb.requests, batch_tokenized_inputs)\n        ):\n            ### XXX: This consumes so much memory on long requests\n  
          ### Deactivating it by default seems like the best course.\n            if not REQUEST_LOGPROBS:\n                r.prefill_logprobs = False\n            # request id -> idx in list mapping\n            requests_idx_mapping[r.id] = i\n\n            prompt_length = len(tokenized_input)\n            prompt_lengths.append(prompt_length)\n\n            cache_length = r.cache_len\n\n            assert (\n                cache_length <= prompt_length\n            ), f\"Prefix {cache_length} vs input {prompt_length}\"\n            if cache_length == prompt_length:\n                assert False, \"unreachable\"\n\n            # `chunk_len` is an optional field in the protobuf\n            # It is only set if the model supports chunking\n            if r.HasField(\"chunk_len\"):\n                input_length = r.chunk_len\n\n                if cache_length + input_length < prompt_length:\n                    # FIXME: speculate is not supported for context chunking at the moment\n                    assert speculate == 0\n                    assert get_support_chunking()\n                    assert input_length > 0\n\n                postfix_ids = tokenized_input[\n                    cache_length : cache_length + input_length\n                ]\n                assert (\n                    len(postfix_ids) == input_length\n                ), \"Rust and Python tokenizers are not aligned\"\n            else:\n                # Use all the remaining ids\n                postfix_ids = tokenized_input[cache_length:]\n                input_length = len(postfix_ids)\n\n            input_lengths.append(input_length)\n\n            prefix_offsets.append(prompt_length - 5)\n            read_offsets.append(prompt_length)\n\n            all_postfix_ids.append(postfix_ids)\n            all_input_ids.append(tokenized_input)\n\n            next_token_chooser_parameters.append(r.parameters)\n\n            stopping_criteria = StoppingCriteria.from_pb(\n                
r.stopping_parameters, tokenizer\n            )\n            max_new_tokens = stopping_criteria.max_new_tokens\n            stopping_criterias.append(stopping_criteria)\n            top_n_tokens.append(r.top_n_tokens)\n\n            # Paged attention\n            # Remove one as the first token does not have a past\n            speculative_length = get_speculate()\n            speculative_length = 0 if speculative_length is None else speculative_length\n\n            # Tokens that need to be mapped to blocks.\n            block_tokens = prompt_length + max_new_tokens - 1 + speculative_length\n\n            # blocks and slots can be empty (for example in warmup)\n            if not r.blocks:\n                needed_blocks = math.ceil(block_tokens / BLOCK_SIZE)\n                request_blocks = [\n                    b for b in range(num_blocks, num_blocks + needed_blocks)\n                ]\n                request_slots = [\n                    s\n                    for b in request_blocks\n                    for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)\n                ]\n            else:\n                request_blocks = r.blocks\n                request_slots = r.slots\n\n            block_tables.append(request_blocks)\n            block_tables_ragged.extend(request_blocks)\n            cu_blocks.append(len(block_tables_ragged))\n\n            slots.extend(request_slots)\n            cu_slots.append(len(slots))\n\n            cache_lengths.append(cache_length)\n            num_blocks += len(request_blocks)\n\n            # Update\n            max_blocks = max(max_blocks, len(request_blocks))\n            max_input_length = max(max_input_length, input_length)\n            max_current_length = max(max_current_length, cache_length + input_length)\n            max_length = max(\n                max_length,\n                prompt_length + max_new_tokens + speculative_length,\n            )\n\n        next_token_chooser = 
HeterogeneousNextTokenChooser.from_pb(\n            next_token_chooser_parameters, dtype, device, tokenizer\n        )\n\n        # Padded all_input_ids_tensor\n        all_input_ids_tensor = np.zeros(\n            (len(all_input_ids), max_length), dtype=np.int64\n        )\n        for i, input_ids in enumerate(all_input_ids):\n            all_input_ids_tensor[i, : len(input_ids)] = input_ids\n\n        # Create tensors on device\n        all_input_ids_tensor = torch.tensor(\n            all_input_ids_tensor, dtype=torch.int64, device=device\n        )\n\n        top_n_tokens_tensor = torch.tensor(\n            top_n_tokens, device=device, dtype=torch.int64\n        )\n\n        block_tables_ragged = torch.tensor(\n            block_tables_ragged, device=device, dtype=torch.int32\n        )\n        cu_blocks = torch.tensor(cu_blocks, device=device, dtype=torch.int64)\n        block_tables_tensor = torch.empty(\n            (len(block_tables), max_blocks),\n            device=device,\n            dtype=torch.int32,\n        )\n\n        # If the device supports Triton, we can use a fused kernel\n        if has_triton():\n            block_tables_to_padded(\n                max_blocks, cu_blocks, block_tables_tensor, block_tables_ragged\n            )\n        else:\n            for i, request_blocks in enumerate(block_tables):\n                block_tables_tensor[i, : len(request_blocks)] = torch.tensor(\n                    request_blocks\n                )\n\n        prompt_lengths_tensor = torch.tensor(\n            prompt_lengths, dtype=torch.int32, device=device\n        )\n\n        slots = torch.tensor(slots, dtype=torch.int64, device=device)\n        cu_slots = torch.tensor(cu_slots, dtype=torch.int64)\n\n        return cls(\n            batch_id=pb.id,\n            requests=pb.requests,\n            requests_idx_mapping=requests_idx_mapping,\n            input_ids=all_postfix_ids,\n            block_tables=block_tables,\n            
block_tables_tensor=block_tables_tensor,\n            cache_lengths=cache_lengths,\n            max_input_length=max_input_length,\n            max_current_length=max_current_length,\n            prefilling=True,\n            prefilling_mask=[True] * len(pb.requests),\n            prefill_logprob_tokens=[None] * len(pb.requests),\n            input_lengths=input_lengths,\n            prompt_lengths=prompt_lengths,\n            prefix_offsets=prefix_offsets,\n            read_offsets=read_offsets,\n            all_input_ids=all_input_ids,\n            all_input_ids_tensor=all_input_ids_tensor,\n            next_token_chooser=next_token_chooser,\n            stopping_criterias=stopping_criterias,\n            top_n_tokens=top_n_tokens,\n            top_n_tokens_tensor=top_n_tokens_tensor,\n            num_blocks=num_blocks,\n            max_blocks=max_blocks,\n            speculative_ids=None,\n            prompt_lengths_tensor=prompt_lengths_tensor,\n            # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`\n            position_ids=None,\n            cu_seqlen_prefill=None,\n            prefill_cache_indices=None,\n            slot_indices=None,\n            slots=slots,\n            cu_slots=cu_slots,\n            prefill_head_indices=None,\n            prefill_next_token_indices=None,\n            prefill_cu_outlens=None,\n            cache_lengths_tensor=None,\n            input_lengths_tensor=None,\n            adapter_meta=None,\n        )\n\n    @classmethod\n    def from_pb(\n        cls,\n        pb: generate_pb2.Batch,\n        tokenizer: PreTrainedTokenizerBase,\n        dtype: torch.dtype,\n        device: torch.device,\n    ) -> \"FlashCausalLMBatch\":\n        assert len(pb.requests) > 0\n        batch_tokenized_inputs = cls.batch_tokenized_inputs(pb.requests, tokenizer)\n        return cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)\n\n    @tracer.start_as_current_span(\"filter\")\n    def filter(self, 
request_ids: List[int]) -> \"FlashCausalLMBatch\":\n        if len(request_ids) == 0:\n            raise ValueError(\"Batch must have at least one request\")\n        # We assume that if len(requests) == len(self) then the requests are the same\n        if len(request_ids) == len(self):\n            return self\n\n        device = self.block_tables_tensor.device\n\n        # New values after filtering\n        requests_idx_mapping = {}\n\n        # Used to index into tensors\n        indices = []\n\n        if not has_triton():\n            # slots to keep after filtering\n            slot_filtering_indices = torch.zeros(\n                self.slots.shape[0], dtype=torch.bool, device=device\n            )\n\n        # Create on CPU to only move to GPU once instead of at every copy\n        slot_indices = torch.empty(len(request_ids), dtype=torch.int64)\n        max_input_length = 0\n        max_current_length = 0\n\n        requests = []\n        block_tables = []\n        all_input_ids = []\n        input_ids = []\n\n        prompt_lengths = []\n        input_lengths = []\n        cache_lengths = []\n        prefix_offsets = []\n        read_offsets = []\n        cu_slots = [0]\n\n        prefilling_mask = []\n        prefill_logprob_tokens = []\n\n        stopping_criterias = []\n        top_n_tokens = []\n        adapter_set = set()\n\n        num_blocks = 0\n        max_blocks = 0\n        max_slots = 0\n        cumulative_slot_tokens = 0\n\n        for i, request_id in enumerate(request_ids):\n            idx = self.requests_idx_mapping[request_id]\n            indices.append(idx)\n            requests_idx_mapping[request_id] = i\n\n            requests.append(self.requests[idx])\n\n            # Prefilling\n            request_prefilling = self.prefilling_mask[idx]\n            prefilling_mask.append(request_prefilling)\n\n            # Get length\n            request_input_length = self.input_lengths[idx]\n            request_cache_length = 
self.cache_lengths[idx]\n            max_input_length = max(max_input_length, request_input_length)\n            max_current_length = max(\n                max_current_length, request_cache_length + request_input_length\n            )\n\n            all_input_ids.append(self.all_input_ids[idx])\n\n            prompt_lengths.append(self.prompt_lengths[idx])\n            input_lengths.append(request_input_length)\n            cache_lengths.append(request_cache_length)\n            prefix_offsets.append(self.prefix_offsets[idx])\n            read_offsets.append(self.read_offsets[idx])\n\n            stopping_criteria = self.stopping_criterias[idx]\n            stopping_criterias.append(stopping_criteria)\n\n            top_n_tokens.append(self.top_n_tokens[idx])\n            prefill_logprob_tokens.append(self.prefill_logprob_tokens[idx])\n\n            ADAPTER_TO_INDEX = get_adapter_to_index()\n            adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)\n            adapter_set.add(adapter_index)\n\n            request_block_table = self.block_tables[idx]\n            num_blocks += len(request_block_table)\n            block_tables.append(request_block_table)\n\n            start_slot = self.cu_slots[idx]\n            end_slot = self.cu_slots[idx + 1]\n            slot_length = end_slot - start_slot\n\n            if not has_triton():\n                # Set slice\n                slot_filtering_indices[start_slot:end_slot] = True\n\n            cu_slots.append(cumulative_slot_tokens + slot_length)\n\n            # Input ids if the request was part of a prefilling batch\n            # If the batch was decoding we can index into the tensor directly later\n            if self.prefilling:\n                input_ids.append(self.input_ids[idx])\n            else:\n                # Copy to tensor (CPU)\n                slot_indices[i] = cumulative_slot_tokens + request_cache_length\n\n            cumulative_slot_tokens += slot_length\n            
max_blocks = max(max_blocks, len(request_block_table))\n            max_slots = max(max_slots, slot_length)\n\n        all_input_ids_tensor = self.all_input_ids_tensor[indices]\n        block_tables_tensor = self.block_tables_tensor[indices]\n        next_token_chooser = self.next_token_chooser.filter(indices)\n        top_n_tokens_tensor = self.top_n_tokens_tensor[indices]\n        speculative_ids = (\n            self.speculative_ids[indices] if self.speculative_ids is not None else None\n        )\n        prompt_lengths_tensor = self.prompt_lengths_tensor[indices]\n\n        cu_slots = torch.tensor(cu_slots, dtype=torch.int64)\n\n        if not has_triton():\n            slots = self.slots[slot_filtering_indices]\n        else:\n            slots = self.slots.new_empty(cumulative_slot_tokens)\n            gpu_cu_slots = cu_slots.to(device)\n            slots_indexing_start = self.cu_slots.to(device)[indices]\n            slots_filtering(\n                max_slots, self.slots, slots, gpu_cu_slots, slots_indexing_start\n            )\n\n        if self.prefilling:\n            # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`\n            position_ids = None\n            slot_indices = None\n            cache_lengths_tensor = None\n            input_lengths_tensor = None\n            adapter_meta = None\n        else:\n            # Index into tensors\n            input_ids = self.input_ids[indices]\n            position_ids = self.position_ids[indices]\n            adapter_indices = self.adapter_meta.adapter_indices[indices]\n            input_lengths_tensor = self.input_lengths_tensor[indices]\n            cache_lengths_tensor = self.cache_lengths_tensor[indices]\n\n            # Move to GPU now that we have the whole tensor\n            slot_indices = slot_indices.to(device)\n\n            adapter_segments, adapter_segment_indices = find_segments(adapter_indices)\n            adapter_segments = torch.tensor(\n                
adapter_segments, dtype=torch.int32, device=device\n            )\n            adapter_meta = AdapterBatchMetadata(\n                adapter_indices=adapter_indices,\n                adapter_set=adapter_set,\n                adapter_segments=adapter_segments,\n                segment_indices=adapter_segment_indices,\n            )\n\n        return type(self)(\n            batch_id=self.batch_id,\n            requests=requests,\n            requests_idx_mapping=requests_idx_mapping,\n            input_ids=input_ids,\n            position_ids=position_ids,\n            cu_seqlen_prefill=None,\n            prefill_cache_indices=None,\n            slot_indices=slot_indices,\n            block_tables=block_tables,\n            block_tables_tensor=block_tables_tensor,\n            slots=slots,\n            cu_slots=cu_slots,\n            max_input_length=max_input_length,\n            max_current_length=max_current_length,\n            prefilling=self.prefilling,\n            prefilling_mask=prefilling_mask,\n            prefill_head_indices=None,\n            prefill_next_token_indices=None,\n            prefill_cu_outlens=None,\n            prefill_logprob_tokens=prefill_logprob_tokens,\n            prompt_lengths=prompt_lengths,\n            prompt_lengths_tensor=prompt_lengths_tensor,\n            input_lengths=input_lengths,\n            input_lengths_tensor=input_lengths_tensor,\n            cache_lengths=cache_lengths,\n            cache_lengths_tensor=cache_lengths_tensor,\n            prefix_offsets=prefix_offsets,\n            read_offsets=read_offsets,\n            all_input_ids=all_input_ids,\n            all_input_ids_tensor=all_input_ids_tensor,\n            next_token_chooser=next_token_chooser,\n            stopping_criterias=stopping_criterias,\n            top_n_tokens=top_n_tokens,\n            top_n_tokens_tensor=top_n_tokens_tensor,\n            num_blocks=num_blocks,\n            max_blocks=max_blocks,\n            
speculative_ids=speculative_ids,\n            adapter_meta=adapter_meta,\n        )\n\n    @classmethod\n    @tracer.start_as_current_span(\"concatenate\")\n    def concatenate(cls, batches: List[\"FlashCausalLMBatch\"]) -> \"FlashCausalLMBatch\":\n        # Batch attributes\n        requests = []\n        requests_idx_mapping = {}\n\n        prefilling = False\n        num_blocks = 0\n        total_batch_size = 0\n        total_slots = 0\n        max_blocks = 0\n        max_length = 0\n        max_input_length = 0\n        max_current_length = 0\n        for b in batches:\n            total_batch_size += len(b)\n            max_blocks = max(max_blocks, b.max_blocks)\n            total_slots += len(b.slots)\n            num_blocks += b.num_blocks\n            speculative_length = (\n                b.speculative_ids.shape[1] if b.speculative_ids is not None else 0\n            )\n            max_input_length = max(max_input_length, b.max_input_length)\n            max_current_length = max(max_current_length, b.max_current_length)\n            max_length = max(\n                max_length,\n                max(\n                    prompt_length\n                    + stopping_criteria.max_new_tokens\n                    + speculative_length\n                    for prompt_length, stopping_criteria in zip(\n                        b.prompt_lengths, b.stopping_criterias\n                    )\n                ),\n            )\n            prefilling = prefilling or b.prefilling\n\n        slots = batches[0].slots.new_empty(total_slots)\n        cu_slots = torch.zeros(total_batch_size + 1, dtype=torch.int64)\n        if prefilling:\n            input_ids = []\n            # These values will be set by `FlashCausalLMBatch.prepare_for_prefill`\n            position_ids = None\n            slot_indices = None\n            cache_lengths_tensor = None\n            input_lengths_tensor = None\n            adapter_meta = None\n            adapter_segment_builder = None\n    
    else:\n            input_ids = batches[0].input_ids.new_empty(total_batch_size)\n            if (\n                batches[0].position_ids is not None\n                and batches[0].position_ids.dim() == 2\n            ):\n                # Qwen2_vl case:\n                position_ids = batches[0].position_ids.new_empty(\n                    (total_batch_size, batches[0].position_ids.shape[-1])\n                )\n            else:\n                position_ids = batches[0].position_ids.new_empty(total_batch_size)\n            slot_indices = batches[0].slot_indices.new_empty(total_batch_size)\n            input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(\n                total_batch_size\n            )\n            cache_lengths_tensor = batches[0].cache_lengths_tensor.new_empty(\n                total_batch_size\n            )\n            total_indices_size = sum(\n                b.adapter_meta.adapter_indices.shape[0] for b in batches\n            )\n            adapter_indices = batches[0].adapter_meta.adapter_indices.new_empty(\n                total_indices_size\n            )\n            adapter_segment_builder = SegmentConcatBuilder()\n            adapter_set = set()\n\n        prompt_lengths_tensor = batches[0].prompt_lengths_tensor.new_empty(\n            total_batch_size\n        )\n        block_tables_tensor = batches[0].block_tables_tensor.new_zeros(\n            (total_batch_size, max_blocks)\n        )\n        all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(\n            (total_batch_size, max_length)\n        )\n        top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(\n            total_batch_size,\n        )\n\n        block_tables = []\n        cache_lengths = []\n        all_input_ids = []\n\n        prompt_lengths = []\n        input_lengths = []\n        prefix_offsets = []\n        read_offsets = []\n\n        prefill_logprob_tokens = []\n\n        next_token_chooser_parameters = []\n    
    fsm_grammar_states = []\n        stopping_criterias = []\n        top_n_tokens = []\n        prefilling_mask = []\n\n        # Cumulative length\n        cumulative_batch_size = 0\n        cumulative_slots = 0\n        cumulative_adapter_indices_size = 0\n\n        for i, batch in enumerate(batches):\n            requests.extend(batch.requests)\n\n            if i == 0:\n                requests_idx_mapping = batch.requests_idx_mapping\n            else:\n                # We need to offset the mapping for each batch by the cumulative batch size\n                for k, v in batch.requests_idx_mapping.items():\n                    requests_idx_mapping[k] = v + cumulative_batch_size\n\n            start_index = cumulative_batch_size\n            end_index = cumulative_batch_size + len(batch)\n\n            # Copy tensors (GPU)\n            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor\n            all_input_ids_tensor[\n                start_index:end_index, : batch.all_input_ids_tensor.shape[1]\n            ] = batch.all_input_ids_tensor[:, :max_length]\n\n            block_tables_tensor[\n                start_index:end_index, : batch.block_tables_tensor.shape[1]\n            ] = batch.block_tables_tensor[:, :max_blocks]\n            prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor\n\n            slots_start_index = cumulative_slots\n            slots_end_index = cumulative_slots + len(batch.slots)\n            slots[slots_start_index:slots_end_index] = batch.slots\n            cu_slots[start_index + 1 : end_index + 1] = (\n                batch.cu_slots[1:] + cumulative_slots\n            )\n\n            if not prefilling:\n                input_ids[start_index:end_index] = batch.input_ids\n                position_ids[start_index:end_index] = batch.position_ids\n                slot_indices[start_index:end_index] = (\n                    batch.slot_indices + cumulative_slots\n                )\n              
  input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor\n                cache_lengths_tensor[start_index:end_index] = batch.cache_lengths_tensor\n\n                # Copy over adapter indices\n                adapter_start_index = cumulative_adapter_indices_size\n                adapter_end_index = (\n                    cumulative_adapter_indices_size\n                    + batch.adapter_meta.adapter_indices.shape[0]\n                )\n                adapter_indices[adapter_start_index:adapter_end_index] = (\n                    batch.adapter_meta.adapter_indices\n                )\n                cumulative_adapter_indices_size = adapter_end_index\n                adapter_set.update(batch.adapter_meta.adapter_set)\n                adapter_segment_builder.concat(\n                    batch.adapter_meta.adapter_segments,\n                    batch.adapter_meta.segment_indices,\n                )\n            else:\n                if isinstance(batch.input_ids, torch.Tensor):\n                    batch.input_ids = batch.input_ids.view(-1, 1).tolist()\n                input_ids.extend(batch.input_ids)\n\n            prefilling_mask.extend(batch.prefilling_mask)\n            block_tables.extend(batch.block_tables)\n            cache_lengths.extend(batch.cache_lengths)\n            all_input_ids.extend(batch.all_input_ids)\n\n            prompt_lengths.extend(batch.prompt_lengths)\n            input_lengths.extend(batch.input_lengths)\n            prefix_offsets.extend(batch.prefix_offsets)\n            read_offsets.extend(batch.read_offsets)\n\n            prefill_logprob_tokens.extend(batch.prefill_logprob_tokens)\n\n            next_token_chooser_parameters.extend([r.parameters for r in batch.requests])\n            fsm_grammar_states.extend(batch.next_token_chooser.fsm_grammar_states)\n            stopping_criterias.extend(batch.stopping_criterias)\n\n            top_n_tokens.extend(batch.top_n_tokens)\n\n            # Update\n            
cumulative_slots += len(batch.slots)\n            cumulative_batch_size += len(batch)\n\n        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(\n            next_token_chooser_parameters,\n            dtype=batches[0].next_token_chooser.dtype,\n            device=batches[0].next_token_chooser.device,\n            tokenizer=batches[0].next_token_chooser.tokenizer,\n            fsm_grammar_states=fsm_grammar_states,\n        )\n\n        # We skip computing the speculative_ids when the batch size is too large, so\n        # we must check that all batches have them, otherwise they must be discarded\n        if get_speculate() > 0 and all(b.speculative_ids is not None for b in batches):\n            speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0)\n        else:\n            speculative_ids = None\n\n        if adapter_segment_builder is not None:\n            adapter_segments, adapter_segment_indices = adapter_segment_builder.build()\n            adapter_meta = AdapterBatchMetadata(\n                adapter_indices=adapter_indices,\n                adapter_set=adapter_set,\n                adapter_segments=adapter_segments,\n                segment_indices=adapter_segment_indices,\n            )\n\n        return cls(\n            batch_id=batches[0].batch_id,\n            requests=requests,\n            requests_idx_mapping=requests_idx_mapping,\n            input_ids=input_ids,\n            position_ids=position_ids,\n            cu_seqlen_prefill=None,\n            prefill_cache_indices=None,\n            slot_indices=slot_indices,\n            block_tables=block_tables,\n            block_tables_tensor=block_tables_tensor,\n            cache_lengths=cache_lengths,\n            cache_lengths_tensor=cache_lengths_tensor,\n            slots=slots,\n            cu_slots=cu_slots,\n            max_input_length=max_input_length,\n            max_current_length=max_current_length,\n            prefilling=prefilling,\n            
prefilling_mask=prefilling_mask,\n            prefill_head_indices=None,\n            prefill_next_token_indices=None,\n            prefill_cu_outlens=None,\n            prefill_logprob_tokens=prefill_logprob_tokens,\n            prompt_lengths=prompt_lengths,\n            prompt_lengths_tensor=prompt_lengths_tensor,\n            input_lengths=input_lengths,\n            input_lengths_tensor=input_lengths_tensor,\n            prefix_offsets=prefix_offsets,\n            read_offsets=read_offsets,\n            all_input_ids=all_input_ids,\n            all_input_ids_tensor=all_input_ids_tensor,\n            next_token_chooser=next_token_chooser,\n            stopping_criterias=stopping_criterias,\n            top_n_tokens=top_n_tokens,\n            top_n_tokens_tensor=top_n_tokens_tensor,\n            num_blocks=num_blocks,\n            max_blocks=max_blocks,\n            speculative_ids=speculative_ids,\n            adapter_meta=adapter_meta,\n        )\n\n    def prepare_for_prefill(self):\n        # Prepare values if we need to continue prefilling\n        # Speculation must be ignored while we prefill even with chunking\n        # it simplifies everything\n        assert self.speculative_ids is None\n\n        device = self.block_tables_tensor.device\n\n        if isinstance(self.input_ids, list):\n            if len(self) > 1:\n                input_ids = np.concatenate(self.input_ids, dtype=np.int64)\n            else:\n                input_ids = self.input_ids[0]\n            self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)\n\n        self.input_lengths_tensor = torch.tensor(\n            self.input_lengths, dtype=torch.int32, device=device\n        )\n        cu_seqlen_prefill = self.input_lengths_tensor.new_zeros(len(self) + 1)\n        torch.cumsum(self.input_lengths_tensor, out=cu_seqlen_prefill[1:], dim=0)\n        self.cu_seqlen_prefill = cu_seqlen_prefill.to(torch.int32)\n        self.cache_lengths_tensor = torch.tensor(\n      
      self.cache_lengths, dtype=torch.int32, device=device\n        )\n\n        # If the device supports Triton, we can use a fused kernel\n        if has_triton():\n            self.position_ids = torch.empty(\n                len(self.input_ids), dtype=torch.int32, device=device\n            )\n            self.slot_indices = torch.empty(\n                len(self.input_ids), dtype=torch.int64, device=device\n            )\n            cu_slots_gpu = self.cu_slots.to(device)\n\n            prepare_position_slot_ids(\n                self.max_input_length,\n                self.cache_lengths_tensor,\n                self.cu_seqlen_prefill,\n                cu_slots_gpu,\n                self.position_ids,\n                self.slot_indices,\n            )\n\n        position_ids = []\n        slot_indices = []\n        all_prefill_logprobs = True\n        no_prefill_logprobs = True\n        prefill_cu_outlens = [0]\n\n        # Cumulative length\n        cumulative_length = 0\n        cumulative_slot_tokens = 0\n        prefill_out_cumulative_length = 0\n\n        adapter_indices_list = []\n        adapter_set = set()\n\n        for i, (\n            r,\n            cache_length,\n            input_length,\n            prompt_length,\n            request_prefilling,\n            blocks,\n        ) in enumerate(\n            zip(\n                self.requests,\n                self.cache_lengths,\n                self.input_lengths,\n                self.prompt_lengths,\n                self.prefilling_mask,\n                self.block_tables,\n            )\n        ):\n            next_chunk_length = input_length\n\n            if not has_triton():\n                # Position ids\n                request_position_ids = torch.arange(\n                    cache_length, cache_length + input_length, dtype=torch.int32\n                )\n                position_ids.append(request_position_ids)\n\n                if not r.slots:\n                    request_slots = 
[\n                        s\n                        for b in blocks\n                        for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)\n                    ]\n                else:\n                    request_slots = r.slots\n\n                request_slot_indices = torch.arange(\n                    cache_length + cumulative_slot_tokens,\n                    cache_length + cumulative_slot_tokens + input_length,\n                    dtype=torch.int64,\n                )\n\n                slot_indices.append(request_slot_indices)\n\n                # Update\n                cumulative_slot_tokens += len(request_slots)\n\n            # Prefill logprobs is ignored if the request is done prefilling\n            prefill_logprobs = r.prefill_logprobs and request_prefilling\n\n            all_prefill_logprobs = all_prefill_logprobs and prefill_logprobs\n            no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs\n\n            if prefill_logprobs:\n                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)\n                prefill_out_cumulative_length += input_length\n            else:\n                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)\n                prefill_out_cumulative_length += 1\n\n            ADAPTER_TO_INDEX = get_adapter_to_index()\n            if ADAPTER_TO_INDEX:\n                adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)\n                adapter_indices_list.append(\n                    torch.full((next_chunk_length,), adapter_index)\n                )\n                adapter_set.add(adapter_index)\n\n            # Update\n            cumulative_length += next_chunk_length\n\n        if not all_prefill_logprobs and not no_prefill_logprobs:\n            prefill_head_indices = []\n            prefill_next_token_indices = []\n\n            # Cumulative length\n            cumulative_length = 0\n            prefill_out_cumulative_length = 0\n\n            for 
i, (\n                r,\n                input_length,\n                request_prefilling,\n            ) in enumerate(\n                zip(\n                    self.requests,\n                    self.input_lengths,\n                    self.prefilling_mask,\n                )\n            ):\n                # Prefill logprobs is ignored if the request is done prefilling\n                prefill_logprobs = r.prefill_logprobs and request_prefilling\n\n                if prefill_logprobs:\n                    prefill_head_indices.append(\n                        torch.arange(\n                            cumulative_length,\n                            cumulative_length + input_length,\n                            dtype=torch.int64,\n                        )\n                    )\n                    prefill_next_token_indices.append(\n                        prefill_out_cumulative_length + input_length - 1\n                    )\n                    prefill_out_cumulative_length += input_length\n                else:\n                    prefill_head_indices.append(\n                        torch.tensor(\n                            [cumulative_length + input_length - 1],\n                            dtype=torch.int64,\n                        )\n                    )\n                    prefill_next_token_indices.append(prefill_out_cumulative_length)\n                    prefill_out_cumulative_length += 1\n\n                # Update\n                cumulative_length += input_length\n\n        if len(self) > 1:\n            if position_ids:\n                position_ids = torch.cat(position_ids)\n            if slot_indices:\n                slot_indices = torch.cat(slot_indices)\n        else:\n            if position_ids:\n                position_ids = position_ids[0]\n            if slot_indices:\n                slot_indices = slot_indices[0]\n\n        if not has_triton():\n            self.position_ids = position_ids.to(device)\n            
self.slot_indices = slot_indices.to(device)\n\n        self.prefill_cu_outlens = prefill_cu_outlens\n        self.prefill_cache_indices = None\n\n        if all_prefill_logprobs:\n            prefill_head_indices = None\n            prefill_next_token_indices = self.cu_seqlen_prefill[1:] - 1\n        elif no_prefill_logprobs:\n            prefill_head_indices = self.cu_seqlen_prefill[1:] - 1\n            prefill_next_token_indices = None\n        else:\n            prefill_head_indices = torch.cat(prefill_head_indices).to(device)\n            prefill_next_token_indices = torch.tensor(\n                prefill_next_token_indices, dtype=torch.int64, device=device\n            )\n\n        self.prefill_head_indices = prefill_head_indices\n        self.prefill_next_token_indices = prefill_next_token_indices\n\n        if adapter_set:\n            adapter_indices = torch.cat(adapter_indices_list).to(\n                dtype=torch.int64, device=device\n            )\n            adapter_segments, adapter_segment_indices = find_segments(adapter_indices)\n        else:\n            adapter_indices = torch.zeros_like(self.input_ids)\n            adapter_segments = [0, len(adapter_indices)]\n            adapter_segment_indices = [len(adapter_indices) - 1]\n\n        adapter_segments = torch.tensor(\n            adapter_segments, dtype=torch.int32, device=device\n        )\n\n        self.adapter_meta = AdapterBatchMetadata(\n            adapter_indices=adapter_indices,\n            adapter_set=adapter_set,\n            adapter_segments=adapter_segments,\n            segment_indices=adapter_segment_indices,\n        )\n\n    def __len__(self):\n        return len(self.requests)\n\n\nADAPTER_LAYERS = [\n    \"q_proj\",\n    \"k_proj\",\n    \"v_proj\",\n    \"o_proj\",\n    \"gate_proj\",\n    \"up_proj\",\n    \"down_proj\",\n]\nROW_PARALLEL = {\"o_proj\", \"down_proj\", \"lm_head\"}\n\n\nclass FlashCausalLM(Model):\n    def __init__(\n        self,\n        model_id: str,\n   
     model_class,\n        revision: Optional[str] = None,\n        quantize: Optional[str] = None,\n        speculator: Optional[str] = None,\n        dtype: Optional[torch.dtype] = None,\n        trust_remote_code: bool = False,\n        lora_adapter_ids: Optional[list] = [],\n        tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer,\n        config_class: PreTrainedTokenizerBase = AutoConfig,\n        default_dtype=torch.float16,\n        aliases=None,\n        # Used for Santacoder override of config\n        num_kv_heads: Optional[int] = None,\n        # Deepseek V2 uses different QK and V dims.\n        head_size: Optional[int] = None,\n        skip_special_tokens: bool = True,\n        kv_cache_dtype: Optional[torch.dtype] = None,\n        support_chunking: bool = True,\n    ):\n        self.quantize = quantize\n        self.process_group, rank, world_size = initialize_torch_distributed()\n        if torch.cuda.is_available():\n            device = torch.device(f\"cuda:{rank}\")\n            dtype = default_dtype if dtype is None else dtype\n        elif SYSTEM == \"ipex\":\n            if hasattr(torch, \"xpu\") and torch.xpu.is_available():\n                device = torch.device(f\"xpu:{rank}\")\n                dtype = default_dtype if dtype is None else dtype\n            else:\n                device = torch.device(\"cpu\")\n                dtype = torch.bfloat16 if dtype is None else dtype\n                init_cpu_threads_env(rank_id=rank, world_size=world_size)\n        else:\n            raise NotImplementedError(f\"{model_class} is only available on GPU\")\n\n        tokenizer = tokenizer_class.from_pretrained(\n            model_id,\n            revision=revision,\n            padding_side=\"left\",\n            truncation_side=\"left\",\n            trust_remote_code=trust_remote_code,\n        )\n        try:\n            generation_config = GenerationConfig.from_pretrained(\n                model_id, revision=revision, 
trust_remote_code=trust_remote_code\n            )\n            if isinstance(generation_config.eos_token_id, (list, set)):\n                # TODO Huge hack\n                tokenizer._eos_token_ids = set(generation_config.eos_token_id)\n        except Exception:\n            pass\n\n        config = config_class.from_pretrained(\n            model_id, revision=revision, trust_remote_code=trust_remote_code\n        )\n        config.quantize = quantize\n        config.speculator = speculator\n\n        torch.distributed.barrier(group=self.process_group)\n\n        weights_loader = get_loader(quantize, model_id, revision)\n        filenames = weight_files(model_id, revision=revision, extension=\".safetensors\")\n        weights = Weights(\n            filenames,\n            device,\n            dtype,\n            process_group=self.process_group,\n            aliases=aliases,\n            weights_loader=weights_loader,\n        )\n\n        prefix = None\n        model = model_class(prefix, config, weights)\n        torch.distributed.barrier(group=self.process_group)\n\n        # VLM models define the config we care about in their text_config\n        text_config = getattr(config, \"text_config\", None)\n        if text_config is not None:\n            config = text_config\n\n        if getattr(config, \"sliding_window\", None) is None:\n            config.sliding_window = None\n\n        self.num_layers = config.num_hidden_layers\n        self.num_heads = config.num_attention_heads // self.process_group.size()\n        self.config = config\n        # Validation is done in the model itself\n        if num_kv_heads is None:\n            num_kv_heads = getattr(config, \"num_key_value_heads\", None)\n            # GPT-2 workaround\n            if num_kv_heads is None:\n                num_kv_heads = getattr(config, \"n_head\", None)\n        if num_kv_heads is None:\n            raise ValueError(\"Cannot get the number of key/value heads\")\n        
self.num_kv_heads = (\n            num_kv_heads // self.process_group.size()\n            if num_kv_heads > 1\n            else num_kv_heads\n        )\n        assert self.num_kv_heads > 0\n\n        if head_size is None:\n            # Some models use GQA and different sizes for o_proj\n            # and q_proj, that allows for that.\n            if getattr(config, \"head_dim\", None) is not None:\n                self.head_size = config.head_dim\n            else:\n                self.head_size = config.hidden_size // config.num_attention_heads\n        else:\n            self.head_size = head_size\n\n        self.cuda_graphs = {}\n        self.kv_cache = []\n        self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype\n\n        if ATTENTION == \"flashinfer\":\n            from text_generation_server.layers.attention.flashinfer import (\n                create_prefill_state,\n                create_decode_state,\n                create_prefill_with_paged_kv_state,\n            )\n\n            self.prefill_state = create_prefill_state(device=device)\n            self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state(\n                device=device\n            )\n\n            self.decode_state = create_decode_state(\n                device=device,\n                num_heads=self.num_heads,\n                num_kv_heads=self.num_kv_heads,\n            )\n\n        super().__init__(\n            model_id=model_id,\n            model=model,\n            tokenizer=tokenizer,\n            requires_padding=False,\n            dtype=dtype,\n            device=device,\n            rank=rank,\n            world_size=world_size,\n            sliding_window=config.sliding_window,\n            support_chunking=support_chunking,\n        )\n\n    @property\n    def batch_type(self) -> Type[FlashCausalLMBatch]:\n        return FlashCausalLMBatch\n\n    def init_kv_cache(\n        self,\n        num_blocks: int,\n        num_layers: 
int,\n        num_heads: int,\n        head_size: int,\n        dtype: torch.dtype,\n        device: torch.device,\n    ):\n        self.kv_cache = []\n        empty_cache()\n        self.kv_cache = [\n            KVCache(\n                num_blocks=num_blocks,\n                num_heads=num_heads,\n                head_size=head_size,\n                dtype=dtype,\n                device=device,\n            )\n            for _ in range(num_layers)\n        ]\n\n    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):\n        max_bs = max(self.cuda_graphs.keys()) if self.cuda_graphs else None\n        input_lengths = [max_s] * bs\n        cache_lengths = [0] * bs\n        if max_bs is None:\n            input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)\n            position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)\n            config = getattr(self.model, \"config\", None)\n            rope_scaling = getattr(config, \"rope_scaling\", None) if config else None\n            if (  # mrope have position_ids per section, if so repeat n times\n                isinstance(rope_scaling, dict) and rope_scaling[\"rope_type\"] == \"mrope\"\n            ):\n                n_sections = len(self.model.config.rope_scaling[\"mrope_section\"])\n                position_ids = position_ids.unsqueeze(1).repeat(1, n_sections)\n            slots = torch.arange(bs, dtype=torch.int64, device=self.device)\n            input_lengths_tensor = (\n                torch.ones(bs, dtype=torch.int32, device=self.device) * max_s\n            )\n            cache_lengths_tensor = torch.zeros(\n                bs, dtype=torch.int32, device=self.device\n            )\n            block_tables = torch.arange(\n                max_bt, dtype=torch.int32, device=self.device\n            ).repeat(bs)\n            block_tables = block_tables.reshape((bs, max_bt))\n            if ATTENTION == \"flashinfer\":\n                block_tables = 
block_tables_to_ragged(\n                    block_tables=block_tables,\n                    input_lengths=input_lengths,\n                    cache_lengths=cache_lengths,\n                    input_lengths_tensor=input_lengths_tensor,\n                    cache_lengths_tensor=cache_lengths_tensor,\n                    max_current_length=max_s,\n                )\n        else:\n            if bs > max_bs:\n                raise RuntimeError(\n                    \"Cuda graphs should be generated in decreasing order size to reduce VRAM usage\"\n                )\n            input_ids = self.cuda_graphs[max_bs][\"input_ids\"][:bs]\n            position_ids = self.cuda_graphs[max_bs][\"position_ids\"][:bs]\n            if ATTENTION == \"flashinfer\":\n                block_tables = self.cuda_graphs[max_bs][\"block_tables\"][: bs * max_bt]\n            else:\n                block_tables = self.cuda_graphs[max_bs][\"block_tables\"][:bs]\n            slots = self.cuda_graphs[max_bs][\"slots\"][:bs]\n            input_lengths_tensor = self.cuda_graphs[max_bs][\"input_lengths\"][:bs]\n            cache_lengths_tensor = self.cuda_graphs[max_bs][\"cache_lengths\"][:bs]\n\n        if ATTENTION == \"flashinfer\":\n            from text_generation_server.layers.attention.flashinfer import (\n                create_decode_state_cuda_graphs,\n            )\n\n            block_tables_ptr = torch.zeros(\n                bs + 1, dtype=torch.int32, device=self.device\n            )\n            last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)\n            state = create_decode_state_cuda_graphs(\n                device=input_ids.device,\n                block_tables=block_tables,\n                block_tables_ptr=block_tables_ptr,\n                last_page_len=last_page_len,\n                num_heads=self.num_heads,\n                num_kv_heads=self.num_kv_heads,\n            )\n        else:\n            state = None\n\n        graph = 
torch.cuda.CUDAGraph()\n        self.cuda_graphs[bs] = {\n            \"input_ids\": input_ids,\n            \"position_ids\": position_ids,\n            \"kv_cache\": self.kv_cache,\n            \"block_tables\": block_tables,\n            \"slots\": slots,\n            \"input_lengths\": input_lengths_tensor,\n            \"cache_lengths\": cache_lengths_tensor,\n            \"state\": state,\n            \"graph\": graph,\n        }\n\n        torch.cuda.synchronize()\n        # Run once outside to warmup\n        with self._forward_context(\n            block_tables=block_tables,\n            cu_seqlen_prefill=None,\n            input_lengths_tensor=input_lengths_tensor,\n            state=state,\n            cache_lengths_tensor=cache_lengths_tensor,\n        ):\n            seqlen = Seqlen(\n                input_lengths=input_lengths_tensor,\n                cache_lengths=cache_lengths_tensor,\n                cu_seqlen_q=None,\n                max_q=1,\n                max_k=max_s,\n            )\n            self.model.forward(\n                input_ids=input_ids,\n                position_ids=position_ids,\n                cu_seqlen_prefill=None,\n                kv_cache=self.kv_cache,\n                block_tables=block_tables,\n                slots=slots,\n                seqlen=seqlen,\n                max_s=max_s,\n                prefill_cache_indices=None,\n                lm_head_indices=None,\n            )\n            del seqlen\n\n            torch.cuda.synchronize()\n\n            with torch.cuda.graph(graph, pool=MEM_POOL):\n                seqlen = Seqlen(\n                    input_lengths=input_lengths_tensor,\n                    cache_lengths=cache_lengths_tensor,\n                    cu_seqlen_q=None,\n                    max_q=1,\n                    max_k=max_s,\n                )\n                logits, speculative_logits = self.model.forward(\n                    input_ids=input_ids,\n                    
position_ids=position_ids,\n                    cu_seqlen_prefill=None,\n                    kv_cache=self.kv_cache,\n                    block_tables=block_tables,\n                    slots=slots,\n                    seqlen=seqlen,\n                    max_s=max_s,\n                    prefill_cache_indices=None,\n                    lm_head_indices=None,\n                )\n                self.cuda_graphs[bs][\"logits\"] = logits\n                self.cuda_graphs[bs][\"speculative_logits\"] = speculative_logits\n        torch.cuda.synchronize()\n\n    def warmup(\n        self,\n        batch: FlashCausalLMBatch,\n        max_input_tokens: Optional[int],\n        max_total_tokens: Optional[int],\n    ):\n        # The warmup batch is the biggest batch we could ever receive\n        self.kv_cache = []\n        empty_cache()\n\n        # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)\n        # Calculate the number of blocks that can be allocated with the free memory\n        dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()\n        cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size\n        total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size\n\n        try:\n            self.init_kv_cache(\n                batch.num_blocks,\n                self.num_layers,\n                self.num_kv_heads,\n                self.head_size,\n                self.kv_cache_dtype,\n                self.device,\n            )\n\n            batch_num_blocks = batch.num_blocks\n\n            num_tokens = batch.to_pb().current_tokens\n            if SYSTEM == \"rocm\" and os.environ.get(\"PYTORCH_TUNABLEOP_ENABLED\", False):\n                torch.cuda.tunable.tuning_enable(False)\n            synchronize(self.device)\n            free_memory = get_free_memory(\n                self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM\n            )\n            real_free_memory = 
get_free_memory(self.device, MEMORY_FRACTION)\n            log_master(\n                logger.debug,\n                f\"Free memory {free_memory / 1e9:.2f}GB , (real: {real_free_memory / 1e9:.2f}GB\",\n            )\n\n            _, _batch, _ = self.generate_token(batch)\n        except torch.cuda.OutOfMemoryError as e:\n            raise RuntimeError(\n                f\"Not enough memory to handle {num_tokens} prefill tokens. \"\n                f\"You need to decrease `--max-batch-prefill-tokens`\"\n            ) from e\n\n        synchronize(self.device)\n        free_memory = get_free_memory(self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM)\n        kv_memory = free_memory\n        num_blocks = (\n            # Leave 5% for some wiggle room\n            int(kv_memory // total_cache_size)\n            # Add batch.num_blocks as we allocated it above, so it is included in the peak memory.\n            + batch_num_blocks\n        )\n\n        log_master(logger.info, f\"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}\")\n        if max_total_tokens is None:\n            if get_support_chunking():\n                model_max_length = self.tokenizer.model_max_length\n                max_position_embeddings = getattr(\n                    self.config, \"max_position_embeddings\", model_max_length\n                )\n                max_total_tokens = min(\n                    num_blocks * BLOCK_SIZE, model_max_length, max_position_embeddings\n                )\n            else:\n                max_total_tokens = sum(batch.cache_lengths)\n\n        if max_input_tokens is None:\n            max_input_tokens = max_total_tokens - 1\n\n        del _batch, batch\n        self.kv_cache = []\n        empty_cache()\n\n        self.init_kv_cache(\n            num_blocks,\n            self.num_layers,\n            self.num_kv_heads,\n            self.head_size,\n            self.kv_cache_dtype,\n            self.device,\n        )\n\n        if SYSTEM == \"rocm\":\n          
  if (\n                os.environ.get(\"PYTORCH_TUNABLEOP_ENABLED\") is None\n                or os.environ.get(\"PYTORCH_TUNABLEOP_ENABLED\") == \"1\"\n            ):\n                torch.cuda.tunable.enable()\n\n                if os.environ.get(\"PYTORCH_TUNABLEOP_TUNING\") != \"0\":\n                    torch.cuda.tunable.tuning_enable(True)\n\n                if os.environ.get(\"PYTORCH_TUNABLEOP_SEQLENS\") is not None:\n                    tuning_sequences = [\n                        int(val)\n                        for val in os.environ[\"PYTORCH_TUNABLEOP_SEQLENS\"].split(\",\")\n                    ]\n                elif CUDA_GRAPHS is not None:\n                    tuning_sequences = CUDA_GRAPHS\n                else:\n                    tuning_sequences = [1, 2, 3, 4, 5, 6, 7]\n\n                tunableop_filepath = os.path.join(\n                    HUGGINGFACE_HUB_CACHE,\n                    f\"tunableop_{self.model_id.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv\",\n                )\n\n                log_master(\n                    logger.info,\n                    f\"PyTorch TunableOp is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. 
To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.\",\n                )\n\n                torch.cuda.tunable.set_filename(\n                    tunableop_filepath, insert_device_ordinal=False\n                )\n\n                if os.path.isfile(tunableop_filepath):\n                    log_master(\n                        logger.info,\n                        f\"The file {tunableop_filepath} already exists and will be reused.\",\n                    )\n                    torch.cuda.tunable.read_file(tunableop_filepath)\n\n                os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)\n\n                for seqlen in tuning_sequences:\n                    log_master(logger.info, f\"Warming up TunableOp for seqlen={seqlen}\")\n                    self.tunableop_warmup(seqlen, max_total_tokens)\n                    torch.cuda.tunable.write_file(tunableop_filepath)\n                if os.environ.get(\"PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP\") != \"1\":\n                    torch.cuda.tunable.tuning_enable(False)\n            else:\n                log_master(\n                    logger.info,\n                    \"PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. 
If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.\",\n                )\n\n        if CUDA_GRAPHS:\n            try:\n                log_master(\n                    logger.info, f\"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}\"\n                )\n                # Warmup cuda graphs\n                for bs in CUDA_GRAPHS:\n                    synchronize(self.device)\n                    free_memory = get_free_memory(\n                        self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM\n                    )\n                    log_master(\n                        logger.debug,\n                        f\"Free RAM before cuda graph {bs} {free_memory / 1e9:.2f}GB\",\n                    )\n                    if self.speculate is None or self.speculate + 1 <= bs:\n                        self.cuda_graph_warmup(bs, max_total_tokens, max_total_tokens)\n                empty_cache()\n                synchronize(self.device)\n                free_memory = get_free_memory(\n                    self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM\n                )\n                log_master(\n                    logger.debug,\n                    f\"Free RAM after cuda graphs {free_memory / 1e9:.2f}GB\",\n                )\n            except torch.cuda.OutOfMemoryError:\n                logger.exception(\"Decode cuda graph warmup failed\")\n        else:\n            log_master(\n                logger.info, f\"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).\"\n            )\n\n        assert max_input_tokens is not None\n        assert max_total_tokens is not None\n        return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens\n\n    def tunableop_warmup(self, seqlen: int, max_bt: int):\n        input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)\n        position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)\n        slots = torch.arange(seqlen, 
dtype=torch.int64, device=self.device)\n\n        # Dummy value, some models (starcoder2) don't accept `None`.\n        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)\n        cache_lengths_tensor = torch.zeros(\n            seqlen, dtype=torch.int32, device=self.device\n        )\n        cu_seqlen_prefill = torch.tensor(\n            [0, seqlen], device=self.device, dtype=torch.int32\n        )\n        max_s = seqlen\n\n        block_tables = torch.arange(\n            max_bt, dtype=torch.int32, device=self.device\n        ).repeat(seqlen)\n        block_tables = block_tables.reshape((seqlen, max_bt))\n\n        seqlen = Seqlen(\n            input_lengths=input_lengths,\n            cache_lengths=cache_lengths_tensor,\n            max_k=seqlen,\n        )\n\n        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.\n        self.model.forward(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            cu_seqlen_prefill=cu_seqlen_prefill,\n            kv_cache=self.kv_cache,\n            block_tables=block_tables,\n            seqlen=seqlen,\n            slots=slots,\n            max_s=max_s,\n            lm_head_indices=None,\n            prefill_cache_indices=None,\n        )\n\n    def forward(\n        self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        # Model Forward\n        if batch.speculative_ids is not None:\n            input_ids = batch.input_ids\n            position_ids = batch.position_ids\n            cu_seqlen_prefill = batch.cu_seqlen_prefill\n            kv_cache = self.kv_cache\n            block_tables = batch.block_tables_tensor\n            slots = batch.slots[batch.slot_indices]\n            input_lengths = batch.input_lengths_tensor\n            max_s = batch.max_current_length\n            lm_head_indices = batch.prefill_head_indices\n\n          
  speculative_ids = batch.speculative_ids\n\n            B, speculative_length = speculative_ids.shape\n            new_length = speculative_length + 1\n            new_input_ids = torch.cat(\n                [input_ids.unsqueeze(-1), speculative_ids], dim=1\n            ).reshape(-1)\n            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)\n            arange_int = arange.to(dtype=torch.int32)\n            new_position_ids = (\n                position_ids.unsqueeze(-1).expand(B, new_length) + arange\n            ).view(-1)\n\n            # Slots can be discontiguous when prefix caching is enabled, so we need to expand the slot_indices,\n            # then update the slots with the additional indices to ensure we're grabbing the ones that have been\n            # allocated\n            slot_indices = (\n                batch.slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int\n            ).view(-1)\n            slots = batch.slots[slot_indices]\n\n            input_lengths = (\n                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int\n            ).view(-1)\n            cache_lengths_tensor = (\n                batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)\n            ).reshape(-1)\n\n            # Add Copy the block tables for all members\n            block_tables = (\n                block_tables.unsqueeze(1)\n                .expand(B, new_length, -1)\n                .reshape(B * new_length, -1)\n                .contiguous()\n            )\n            max_s = max_s + speculative_length\n\n            input_ids = new_input_ids\n            position_ids = new_position_ids\n        else:\n            input_ids = batch.input_ids\n            position_ids = batch.position_ids\n            cu_seqlen_prefill = batch.cu_seqlen_prefill\n            kv_cache = self.kv_cache\n            block_tables = batch.block_tables_tensor\n            slots = batch.slots[batch.slot_indices]\n        
    input_lengths = batch.input_lengths_tensor\n            cache_lengths_tensor = batch.cache_lengths_tensor\n            max_s = batch.max_current_length\n            lm_head_indices = batch.prefill_head_indices\n\n        bs = input_ids.shape[0]\n        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])\n        if sorted_padded_bs:\n            # Get associated cuda graph\n            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]\n        else:\n            cuda_graph = None\n\n        if cu_seqlen_prefill is not None or cuda_graph is None:\n            if ATTENTION == \"flashinfer\":\n                block_tables = block_tables_to_ragged(\n                    block_tables=block_tables,\n                    input_lengths=batch.input_lengths,\n                    cache_lengths=batch.cache_lengths,\n                    input_lengths_tensor=batch.input_lengths_tensor,\n                    cache_lengths_tensor=batch.cache_lengths_tensor,\n                    max_current_length=batch.max_current_length,\n                )\n            with self._forward_context(\n                block_tables=block_tables,\n                cu_seqlen_prefill=cu_seqlen_prefill,\n                input_lengths_tensor=input_lengths,\n                cache_lengths_tensor=cache_lengths_tensor,\n            ):\n                seqlen = Seqlen(\n                    input_lengths=input_lengths,\n                    cache_lengths=cache_lengths_tensor,\n                    cu_seqlen_q=cu_seqlen_prefill,\n                    max_q=batch.max_input_length,\n                    max_k=batch.max_current_length,\n                )\n                logits, speculative_logits = self.model.forward(\n                    input_ids=input_ids,\n                    position_ids=position_ids,\n                    cu_seqlen_prefill=cu_seqlen_prefill,\n                    kv_cache=kv_cache,\n                    block_tables=block_tables,\n                    slots=slots,\n            
        seqlen=seqlen,\n                    max_s=max_s,\n                    prefill_cache_indices=batch.prefill_cache_indices,\n                    lm_head_indices=lm_head_indices,\n                    adapter_data=adapter_data,\n                )\n                if batch.prefill_cache_indices is not None:\n                    batch.prefill_cache_indices = None\n                return logits, speculative_logits\n\n        # Copy inputs to the static inputs of the cuda graph\n        # Static inputs are potentially padded\n        cuda_graph[\"input_ids\"][: input_ids.shape[0]] = input_ids\n        cuda_graph[\"position_ids\"][: position_ids.shape[-1]] = position_ids\n        if ATTENTION == \"flashinfer\":\n            block_tables = block_tables_to_ragged(\n                block_tables=block_tables,\n                input_lengths=batch.input_lengths,\n                cache_lengths=batch.cache_lengths,\n                input_lengths_tensor=batch.input_lengths_tensor,\n                cache_lengths_tensor=batch.cache_lengths_tensor,\n                max_current_length=batch.max_current_length,\n            )\n            # assert block_tables.shape[0] >= slots.shape[0]\n            cuda_graph[\"block_tables\"][: block_tables.shape[0]] = block_tables\n        else:\n            cuda_graph[\"block_tables\"][\n                : block_tables.shape[0], : block_tables.shape[1]\n            ] = block_tables\n\n        # XXX: This is working only because block 0 is reserved for the healthcheck\n        # so it doesn't matter if we override it with bogus values.\n        cuda_graph[\"slots\"].fill_(0)\n        cuda_graph[\"slots\"][: slots.shape[0]] = slots\n        cuda_graph[\"input_lengths\"].zero_()\n        cuda_graph[\"input_lengths\"][: input_lengths.shape[0]] = input_lengths\n        cuda_graph[\"cache_lengths\"].zero_()\n        cuda_graph[\"cache_lengths\"][\n            : cache_lengths_tensor.shape[0]\n        ] = cache_lengths_tensor\n\n        with 
self._forward_context(\n            block_tables=cuda_graph[\"block_tables\"],\n            cu_seqlen_prefill=None,\n            input_lengths_tensor=cuda_graph[\"input_lengths\"],\n            cache_lengths_tensor=cuda_graph[\"cache_lengths\"],\n            state=cuda_graph[\"state\"],\n        ):\n            # Replay the graph\n            cuda_graph[\"graph\"].replay()\n\n        # Slice output to the correct shape\n        speculative_logits = (\n            cuda_graph[\"speculative_logits\"][:bs]\n            if cuda_graph[\"speculative_logits\"] is not None\n            else None\n        )\n        logits = cuda_graph[\"logits\"][:bs]\n        return logits, speculative_logits\n\n    @tracer.start_as_current_span(\"generate_token\")\n    def generate_token(\n        self, batch: FlashCausalLMBatch\n    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:\n        start = time.time_ns()\n        prefill = batch.prefilling\n        if prefill:\n            batch.prepare_for_prefill()\n\n        if hasattr(self, \"set_inputs_embeds\") and callable(self.set_inputs_embeds):\n            self.set_inputs_embeds(batch)\n\n        prefill_logprobs = batch.prefill_next_token_indices is not None\n\n        # Update adapter indices for speculative tokens (if present)\n        adapter_meta = batch.adapter_meta\n        if batch.speculative_ids is not None:\n            B, speculative_length = batch.speculative_ids.shape\n            new_length = speculative_length + 1\n            adapter_indices = (\n                adapter_meta.adapter_indices.unsqueeze(-1)\n                .expand(B, new_length)\n                .reshape(-1)\n            )\n            adapter_segments = adapter_meta.adapter_segments * new_length\n            adapter_meta = AdapterBatchMetadata(\n                adapter_indices=adapter_indices,\n                adapter_set=adapter_meta.adapter_set,\n                adapter_segments=adapter_segments,\n                
segment_indices=adapter_meta.segment_indices,\n            )\n\n        # Assign pointers to adapter weights\n        # TODO(travis): don't update this if indices haven't changed\n        adapter_data = AdapterBatchData.from_meta(\n            adapter_meta,\n            self.layer_to_adapter_weights,\n            prefill,\n            batch.prefill_head_indices,\n        )\n\n        out, speculative_logits = self.forward(batch, adapter_data)\n\n        if prefill:\n            next_token_logits = (\n                out[batch.prefill_next_token_indices] if prefill_logprobs else out\n            )\n            if speculative_logits is not None:\n                speculative_logits = (\n                    speculative_logits[batch.prefill_next_token_indices]\n                    if prefill_logprobs\n                    else speculative_logits\n                )\n            if len(batch) > 1 and prefill_logprobs:\n                # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs\n                # When batch == 1, we will just use the batch.input_ids values directly\n                prefill_tokens_indices = batch.input_ids.new_zeros(len(out))\n        else:\n            prefill_logprobs = None\n            next_token_logits = out\n\n        finished_prefilling = True\n        next_chunk_lengths = []\n        current_prefilling_mask = batch.prefilling_mask\n        if prefill:\n            if get_support_chunking():\n                next_prefilling_mask = []\n                # Budget in tokens for the next batch\n                # We remove (len(batch) - 1) to always have enough space for at least a single decode\n                # for the remaining requests -1 because the first request does not need to be removed from the budget\n                # (ex: you have one request in the batch, you want it to take the full budget not budget -1)\n                batch_budget = get_max_prefill_tokens() - (len(batch) - 1)\n                
# We reverse to prioritize older requests\n                # zip() is not reversible so reverse the underlying lists instead\n                for cache_length, input_length, prompt_length in zip(\n                    reversed(batch.cache_lengths),\n                    reversed(batch.input_lengths),\n                    reversed(batch.prompt_lengths),\n                ):\n                    remaining_prefill_tokens = max(\n                        prompt_length - cache_length - input_length, 0\n                    )\n                    if remaining_prefill_tokens > 0:\n                        next_chunk_length = max(\n                            min(remaining_prefill_tokens, batch_budget), 1\n                        )\n                        batch_budget -= next_chunk_length\n                        finished_prefilling = False\n                        next_prefilling_mask.append(True)\n                    else:\n                        # FIXME: use true number of accepted tokens instead of 1\n                        # Since speculation will be turned off, this is always true\n                        next_chunk_length = 1\n                        next_prefilling_mask.append(False)\n                    next_chunk_lengths.append(next_chunk_length)\n\n                # Reverse back the obtained values²\n                next_chunk_lengths.reverse()\n                next_prefilling_mask.reverse()\n            else:\n                # The model does not support chunking\n                # We know we only do a single prefill\n                finished_prefilling = True\n                next_prefilling_mask = [False] * len(batch)\n\n            batch.prefilling = not finished_prefilling\n            batch.prefilling_mask = next_prefilling_mask\n\n        speculate = get_speculate()\n        (\n            next_input_ids,\n            next_token_logprobs,\n            logprobs,\n            accepted_ids,\n            speculative_ids,\n        ) = batch.next_token_chooser(\n  
          batch.all_input_ids_tensor[:, : batch.max_current_length],\n            next_token_logits,\n            speculate,\n            batch.speculative_ids,\n            speculative_logits,\n        )\n\n        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(\n            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids\n        )\n\n        # Since we are done prefilling, all the tensors that were concatenating values for all the requests\n        # instantly become of shape [BATCH_SIZE]\n        if prefill and finished_prefilling:\n            indices = batch.cu_seqlen_prefill[1:] - 1\n            batch.position_ids = batch.position_ids[indices]\n            batch.slot_indices = batch.slot_indices[indices]\n            batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[\n                indices\n            ]\n\n        # Zipped iterator\n        iterator = zip(\n            batch.requests,\n            batch.prompt_lengths,\n            batch.cache_lengths,\n            batch.input_lengths,\n            batch.all_input_ids,\n            accepted_ids,\n            current_prefilling_mask,\n            batch.prefilling_mask,\n        )\n\n        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second\n        # one, we need to first do a GPU <-> CPU sync\n        # It is faster if we delay this sync for the maximum amount of time\n\n        # For each member of the batch\n        # Cumulative length\n        cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)\n        torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])\n        cumulative_length = 0\n        for i, (\n            request,\n            prompt_length,\n            cache_length,\n            input_length,\n            all_input_ids,\n            n_accepted_ids,\n            request_was_prefilling,\n            request_is_prefilling,\n        ) in 
enumerate(iterator):\n            # Used to gather prefill logprobs\n            # Copy batch.all_input_ids_tensor to prefill_token_indices\n            if request.prefill_logprobs and request_was_prefilling:\n                # Indexing metadata\n                out_start_index = batch.prefill_cu_outlens[i]\n                out_end_index = batch.prefill_cu_outlens[i + 1]\n\n                # Logprobs generated by the model are for the next token\n                # So we need to translate the id tensor by 1\n                ids = batch.all_input_ids_tensor[\n                    i, cache_length + 1 : cache_length + input_length + 1\n                ]\n                if len(batch) > 1:\n                    prefill_tokens_indices[out_start_index:out_end_index] = ids\n                else:\n                    # Set prefill_tokens_indices to the correct slice\n                    prefill_tokens_indices = ids\n\n            # If the device does not support triton, we copy one by one\n            if not request_is_prefilling and not has_triton():\n                # Only save tokens if we are done prefilling for this request\n                batch.all_input_ids_tensor[\n                    i,\n                    batch.cache_lengths_tensor[i]\n                    + batch.input_lengths[i] : batch.cache_lengths_tensor[i]\n                    + batch.input_lengths[i]\n                    + accepted_ids[i],\n                ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]\n            cumulative_length += input_length\n\n        # If the device support triton, we can use a fused kernel\n        if has_triton():\n            copy_next_input_ids_inplace(\n                speculate + 1,\n                batch.all_input_ids_tensor,\n                batch.cache_lengths_tensor,\n                batch.input_lengths_tensor,\n                batch.prompt_lengths_tensor,\n                next_input_ids,\n                cu_accepted_ids,\n            )\n\n        # Update 
values\n        # These values can be updated without a GPU -> CPU sync\n        if not prefill or (prefill and finished_prefilling):\n            batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]\n            batch.speculative_ids = speculative_ids\n            if batch.position_ids.dim() == 2:\n                # Qwen2_vl case:\n                batch.position_ids += accepted_ids.unsqueeze(-1)\n            else:\n                batch.position_ids += accepted_ids\n            batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1\n            batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)\n            batch.slot_indices += accepted_ids\n\n        if prefill and prefill_logprobs:\n            # Get prefill logprobs with inplace softmax (avoid copying the `out` tensor (max_batch_prefill_tokens * vocab_size))\n            torch.log_softmax(out, -1, out=out)\n            prefill_logprobs_tensor = out\n            prefill_logprobs = torch.gather(\n                prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1)\n            )\n            # GPU <-> CPU sync\n            prefill_logprobs = prefill_logprobs.view(-1).tolist()\n\n        # Does a GPU <-> CPU sync internally\n        if prefill and finished_prefilling:\n            # adjust segment lengths to account for all request lengths being 1 during decoding\n            adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)\n            batch.adapter_meta.adapter_segments = torch.tensor(\n                adapter_segments,\n                dtype=torch.int32,\n                device=batch.adapter_meta.adapter_segments.device,\n            )\n\n        # GPU <-> CPU sync\n        next_token_logprobs = next_token_logprobs.tolist()\n        next_token_ids = next_input_ids.tolist()\n        accepted_ids = accepted_ids.tolist()\n\n        # Update values if we need to continue prefilling\n        # This represents the `else` case of the 
`Update values` if above\n        # but since this require the `next_token_ids` to be on CPU, it is better to do it here\n        if prefill and not finished_prefilling:\n            # Speculation must be ignored while we prefill even with chunking\n            # it simplifies everything\n            assert batch.speculative_ids is None\n\n            all_postfix_ids = []\n            for i, (\n                request_prefilling,\n                next_token_id,\n                all_input_ids,\n                cache_length,\n                input_length,\n                next_chunk_length,\n            ) in enumerate(\n                zip(\n                    batch.prefilling_mask,\n                    next_token_ids,\n                    batch.all_input_ids,\n                    batch.cache_lengths,\n                    batch.input_lengths,\n                    next_chunk_lengths,\n                )\n            ):\n                if request_prefilling:\n                    next_cache_length = cache_length + input_length\n                    # Get new prompt IDs to prefill\n                    postfix_ids = all_input_ids[\n                        next_cache_length : next_cache_length + next_chunk_length\n                    ]\n                else:\n                    # This request is done prefilling, the new id is the one selected the sampling method\n                    postfix_ids = [next_token_id]\n\n                all_postfix_ids.append(postfix_ids)\n\n            batch.input_ids = all_postfix_ids\n\n        start_decode = time.time_ns()\n\n        # Results\n        generations: List[Generation] = []\n        stopped = True\n\n        # Zipped iterator\n        iterator = zip(\n            batch.requests,\n            batch.prompt_lengths,\n            batch.cache_lengths,\n            batch.input_lengths,\n            batch.prefix_offsets,\n            batch.read_offsets,\n            batch.stopping_criterias,\n            batch.all_input_ids,\n         
   batch.next_token_chooser.do_sample,\n            batch.next_token_chooser.seeds,\n            batch.top_n_tokens,\n            current_prefilling_mask,\n            batch.prefilling_mask,\n            accepted_ids,\n            batch_top_token_ids,\n            batch_top_token_logprobs,\n        )\n\n        # Reset max_input_length\n        batch.max_input_length = 0\n        # For each member of the batch\n        index = 0\n        for i, (\n            request,\n            prompt_length,\n            cache_length,\n            input_length,\n            prefix_offset,\n            read_offset,\n            stopping_criteria,\n            all_input_ids,\n            do_sample,\n            seed,\n            top_n_tokens,\n            request_was_prefilling,\n            request_is_prefilling,\n            n_accepted_ids,\n            top_token_ids,\n            top_token_logprobs,\n        ) in enumerate(iterator):\n            # Compute logprobs first as, even though we might skip the token,\n            # it can still be required to compute the logprobs\n            # modulo on request.id as it is robust to batch.filter whereas the index in the batch is not and we need\n            # this state to be stable\n            if request.id % self.world_size == self.rank:\n                # Prefill\n                if request_was_prefilling and request.prefill_logprobs:\n                    out_start_index = batch.prefill_cu_outlens[i]\n                    out_end_index = batch.prefill_cu_outlens[i + 1]\n                    if not request_is_prefilling:\n                        # The request is dones prefilling, meaning that we started generating new tokens\n                        # The last logprob is a logprob for a generated token that was not part of the prompt\n                        # We need to remove it\n                        out_end_index -= 1\n\n                    request_prefill_logprobs = prefill_logprobs[\n                        
out_start_index:out_end_index\n                    ]\n                    # Logprobs generated by the model are for the next token\n                    # So we need to translate the id tensor by 1\n                    prefill_token_ids = all_input_ids[\n                        cache_length + 1 : cache_length + input_length + 1\n                    ]\n\n                    past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i]\n\n                    if past_prefill_logprob_tokens is None:\n                        # add nan for cached prompt tokens/first token\n                        request_prefill_logprobs = [float(\"nan\")] * (\n                            cache_length + 1\n                        ) + request_prefill_logprobs\n                        prefill_token_ids = (\n                            all_input_ids[: cache_length + 1] + prefill_token_ids\n                        )\n\n                    prefill_texts = self.tokenizer.batch_decode(\n                        prefill_token_ids,\n                        clean_up_tokenization_spaces=False,\n                        skip_special_tokens=False,\n                    )\n\n                    prefill_logprob_tokens = Tokens(\n                        prefill_token_ids,\n                        request_prefill_logprobs,\n                        prefill_texts,\n                        is_special=[],\n                    )\n                    if past_prefill_logprob_tokens is not None:\n                        prefill_logprob_tokens = (\n                            past_prefill_logprob_tokens + prefill_logprob_tokens\n                        )\n\n                    batch.prefill_logprob_tokens[i] = prefill_logprob_tokens\n                else:\n                    batch.prefill_logprob_tokens[i] = None\n\n            # If it is, the tokens we decoded should be ignored\n            if request_is_prefilling:\n                # Make sure that we do not stop as even though this request did not create a token, 
it is still\n                # processing\n                stopped = False\n                new_input_length = next_chunk_lengths[i]\n                new_cache_length = cache_length + input_length\n            else:\n                new_input_length = 1\n                new_cache_length = cache_length + input_length + n_accepted_ids - 1\n                # Append next token to all tokens\n                next_token_texts = []\n                left = 0\n\n                if n_accepted_ids > 1:\n                    log_master(logger.debug, f\"speculated ids {n_accepted_ids - 1}\")\n\n                current_stopped = False\n                for j in range(index, index + n_accepted_ids):\n                    # Generated token\n                    next_token_id = next_token_ids[j]\n                    all_input_ids.append(next_token_id)\n                    next_token_text, prefix_offset, read_offset = self.decode_token(\n                        all_input_ids,\n                        prefix_offset,\n                        read_offset,\n                    )\n                    next_token_texts.append(next_token_text)\n\n                    stop, reason = stopping_criteria(\n                        next_token_id,\n                        next_token_text,\n                    )\n\n                    if stop:\n                        left = index + n_accepted_ids - j - 1\n                        current_stopped = True\n                        break\n                    else:\n                        current_stopped = False\n                stopped = stopped and current_stopped\n\n                _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]\n                _next_token_logprobs = next_token_logprobs[\n                    index : index + n_accepted_ids - left\n                ]\n\n                # Shard generations\n                # All generations will be appended in the rust sharded client\n                if request.id % self.world_size == 
self.rank:\n                    if stop:\n                        # Decode generated tokens\n                        output_text, _, _ = self.decode_token(\n                            all_input_ids,\n                            prefix_offset=len(all_input_ids)\n                            - stopping_criteria.current_tokens\n                            - 1,\n                            read_offset=len(all_input_ids)\n                            - stopping_criteria.current_tokens,\n                            skip_special_tokens=True,\n                        )\n                        generated_text = GeneratedText(\n                            output_text,\n                            stopping_criteria.current_tokens,\n                            reason,\n                            seed if do_sample else None,\n                        )\n                    else:\n                        generated_text = None\n\n                    if top_n_tokens > 0:\n                        all_top_tokens = []\n                        for top_token_ids, top_token_logprobs in zip(\n                            top_token_ids, top_token_logprobs\n                        ):\n                            toptoken_texts = self.tokenizer.batch_decode(\n                                top_token_ids,\n                                clean_up_tokenization_spaces=False,\n                                skip_special_tokens=False,\n                            )\n                            special_toptokens = [\n                                token_id in self.all_special_ids\n                                for token_id in top_token_ids\n                            ]\n                            top_tokens = Tokens(\n                                top_token_ids,\n                                top_token_logprobs,\n                                toptoken_texts,\n                                special_toptokens,\n                            )\n                            
all_top_tokens.append(top_tokens)\n                        top_tokens = all_top_tokens\n                    else:\n                        top_tokens = None\n\n                    generation = Generation(\n                        request.id,\n                        batch.prefill_logprob_tokens[i],\n                        Tokens(\n                            _next_token_ids,\n                            _next_token_logprobs,\n                            next_token_texts,\n                            [nid in self.all_special_ids for nid in _next_token_ids],\n                        ),\n                        generated_text,\n                        top_tokens,\n                    )\n\n                    generations.append(generation)\n\n                # accept each new token for this specific request since we may\n                # have more than one new token per request with speculative decoding\n                for next_token_id in _next_token_ids:\n                    batch.next_token_chooser = (\n                        batch.next_token_chooser.advance_grammar_single(\n                            i, next_token_id\n                        )\n                    )\n\n            # Update values\n            index += n_accepted_ids\n            batch.cache_lengths[i] = new_cache_length\n            batch.max_input_length = max(batch.max_input_length, new_input_length)\n            batch.input_lengths[i] = new_input_length\n            current_length = new_cache_length + new_input_length\n            batch.max_current_length = max(batch.max_current_length, current_length)\n\n            batch.prefix_offsets[i] = prefix_offset\n            batch.read_offsets[i] = read_offset\n            batch.all_input_ids[i] = all_input_ids\n\n        if stopped:\n            # No need to return a batch if we know that all requests stopped\n            forward_ns = start_decode - start\n            decode_ns = time.time_ns() - start_decode\n            return generations, 
None, (forward_ns, decode_ns)\n\n        if prefill and finished_prefilling:\n            # We do not need prefill tensors anymore\n            batch.cu_seqlen_prefill = None\n            batch.prefill_cache_indices = None\n            batch.prefill_cu_outlens = None\n            batch.prefill_head_indices = None\n            batch.prefill_next_token_indices = None\n\n        forward_ns = start_decode - start\n        decode_ns = time.time_ns() - start_decode\n        return generations, batch, (forward_ns, decode_ns)\n\n    def _forward_context(\n        self,\n        *,\n        block_tables: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        input_lengths_tensor: torch.Tensor,\n        cache_lengths_tensor: torch.Tensor,\n        state: Optional[Any] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n    ) -> ContextManager:\n        if ATTENTION != \"flashinfer\":\n            return nullcontext()\n\n        from text_generation_server.layers.attention.flashinfer import (\n            use_decode_state,\n            use_prefill_with_paged_kv_state,\n        )\n\n        if cu_seqlen_prefill is not None:\n            return use_prefill_with_paged_kv_state(\n                state=(\n                    state if state is not None else self.prefill_with_paged_kv_state\n                ),\n                block_tables=block_tables,\n                cu_seqlens=cu_seqlen_prefill,\n                custom_mask=attention_mask,\n                input_lengths=input_lengths_tensor + cache_lengths_tensor,\n                num_heads=self.num_heads,\n                num_kv_heads=self.num_kv_heads,\n                head_size=self.head_size,\n                page_size=BLOCK_SIZE,\n                kv_dtype=self.kv_cache_dtype,\n                q_dtype=self.dtype,\n            )\n        else:\n            assert input_lengths_tensor is not None\n            return use_decode_state(\n                state=state if state is not None else 
self.decode_state,\n                input_lengths=input_lengths_tensor + cache_lengths_tensor,\n                block_tables=block_tables,\n                num_heads=self.num_heads,\n                num_kv_heads=self.num_kv_heads,\n                head_size=self.head_size,\n                page_size=BLOCK_SIZE,\n                kv_cache_dtype=self.kv_cache_dtype,\n                q_dtype=self.dtype,\n            )\n"
  },
  {
    "path": "server/text_generation_server/models/galactica.py",
    "content": "import re\nimport torch\nimport torch.distributed\n\n\nfrom transformers import (\n    PreTrainedTokenizerBase,\n)\nfrom text_generation_server.models.causal_lm import CausalLMBatch\nfrom text_generation_server.pb import generate_pb2\nfrom text_generation_server.utils import (\n    NextTokenChooser,\n    StoppingCriteria,\n)\nfrom text_generation_server.utils.chunks import concat_text_chunks\n\n# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py\n\n# we split individual characters inside special tokens like [START_DNA]\nCUSTOM_SEQ_RE = re.compile(r\"(\\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\\[END_\\2])\")\n\n# token added to implement a custom sequence tokenization. This token is added at\n# corpus cleaning step and removed in pretokenization. The digits are added to increase the chance\n# that they do not occur in the corpus. The digits are escaped so that the token does not appear\n# literally in the source code in case we ever include it in the training data.\nSPLIT_MARKER = f\"SPL{1}T-TH{1}S-Pl3A5E\"\n\n\ndef _insert_split_marker(m: re.Match):\n    \"\"\"\n    Applies split marker based on a regex match of special tokens such as\n    [START_DNA].\n    Parameters\n    ----------\n    n : str\n        Input text to split\n    Returns\n    ----------\n    str - the text with the split token added\n    \"\"\"\n    start_token, _, sequence, end_token = m.groups()\n    sequence = re.sub(r\"(.)\", rf\"{SPLIT_MARKER}\\1\", sequence, flags=re.DOTALL)\n    return f\"{start_token}{sequence}{SPLIT_MARKER}{end_token}\"\n\n\ndef escape_custom_split_sequence(text):\n    \"\"\"\n    Applies custom splitting to the text for GALILEO's tokenization\n    Parameters\n    ----------\n    text : str\n        Input text to split\n    Returns\n    ----------\n    str - the text with the split token added\n    \"\"\"\n    return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)\n\n\n# END CREDIT\n\n\nclass 
GalacticaCausalLMBatch(CausalLMBatch):\n    @classmethod\n    def from_pb(\n        cls,\n        pb: generate_pb2.Batch,\n        tokenizer: PreTrainedTokenizerBase,\n        dtype: torch.dtype,\n        device: torch.device,\n    ) -> \"GalacticaCausalLMBatch\":\n        inputs = []\n        next_token_choosers = []\n        stopping_criterias = []\n        prefix_offsets = []\n        top_n_tokens = []\n        read_offsets = []\n        requests_idx_mapping = {}\n\n        # Parse batch\n        max_truncation = 0\n        padding_right_offset = 0\n        max_decode_tokens = 0\n        for i, r in enumerate(pb.requests):\n            requests_idx_mapping[r.id] = i\n            # Add escape_custom_split_sequence to the CausalLMBatch logic\n            inputs.append(\n                escape_custom_split_sequence(concat_text_chunks(r.input_chunks.chunks))\n            )\n            next_token_choosers.append(\n                NextTokenChooser.from_pb(r.parameters, device, tokenizer)\n            )\n            stopping_criteria = StoppingCriteria.from_pb(\n                r.stopping_parameters, tokenizer\n            )\n            stopping_criterias.append(stopping_criteria)\n            top_n_tokens.append(r.top_n_tokens)\n            max_truncation = max(max_truncation, r.truncate)\n            max_decode_tokens += stopping_criteria.max_new_tokens\n            padding_right_offset = max(\n                padding_right_offset, stopping_criteria.max_new_tokens\n            )\n\n        tokenized_inputs = tokenizer(\n            inputs,\n            return_tensors=\"pt\",\n            padding=True,\n            return_token_type_ids=False,\n            truncation=True,\n            max_length=max_truncation,\n        ).to(device)\n        for _ in pb.requests:\n            input_len = tokenized_inputs[\"input_ids\"].shape[1]\n            prefix_offsets.append(0)\n            read_offsets.append(input_len)\n\n        input_lengths = 
tokenized_inputs[\"attention_mask\"].sum(1)\n        max_input_length = input_lengths.max()\n\n        input_ids = tokenized_inputs[\"input_ids\"]\n        # Allocate maximum attention_mask\n        attention_mask = input_ids.new_zeros(\n            (pb.size, max_input_length + padding_right_offset)\n        )\n        # Copy tokenizer attention_mask into fully allocated attention_mask\n        attention_mask[:, :max_input_length] = tokenized_inputs[\"attention_mask\"]\n\n        position_ids = tokenized_inputs[\"attention_mask\"].long().cumsum(-1) - 1\n        position_ids.masked_fill_(tokenized_inputs[\"attention_mask\"] == 0, 1)\n        all_input_ids = tokenized_inputs[\"input_ids\"].T.split(1, dim=1)\n        top_n_tokens_tensor = torch.tensor(\n            top_n_tokens, device=device, dtype=torch.int64\n        )\n\n        max_tokens = len(inputs) * max_input_length + max_decode_tokens\n\n        return cls(\n            batch_id=pb.id,\n            requests=pb.requests,\n            requests_idx_mapping=requests_idx_mapping,\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=None,\n            all_input_ids=list(all_input_ids),\n            input_lengths=input_lengths.tolist(),\n            prefix_offsets=prefix_offsets,\n            read_offsets=read_offsets,\n            next_token_choosers=next_token_choosers,\n            stopping_criterias=stopping_criterias,\n            top_n_tokens=top_n_tokens,\n            top_n_tokens_tensor=top_n_tokens_tensor,\n            max_input_length=max_input_length.item(),\n            padding_right_offset=padding_right_offset,\n            max_tokens=max_tokens,\n        )\n"
  },
  {
    "path": "server/text_generation_server/models/globals.py",
    "content": "import torch\nimport os\nfrom loguru import logger\nfrom typing import Dict, Optional\n\nfrom text_generation_server.utils.log import log_master\n\nREQUEST_LOGPROBS = os.getenv(\"REQUEST_LOGPROBS\", \"0\").lower() in {\"1\", \"true\"}\nATTENTION = os.environ[\"ATTENTION\"]\n# default_prefix_caching = \"1\" if ATTENTION in {\"flashinfer\", \"flashdecoding\"} else \"0\"\nPREFIX_CACHING = os.environ[\"PREFIX_CACHING\"].lower() in {\n    \"1\",\n    \"true\",\n}\nPREFILL_CHUNKING = os.getenv(\"PREFILL_CHUNKING\", \"1\").lower() in {\"1\", \"true\"}\nlog_master(logger.info, f\"Using prefix caching = {PREFIX_CACHING}\")\n_expected = {\"paged\", \"flashdecoding\", \"flashdecoding-ipex\", \"flashinfer\"}\nassert (\n    ATTENTION in _expected\n), f\"Attention is not valid {ATTENTION}, expected {_expected}\"\nlog_master(logger.info, f\"Using Attention = {ATTENTION}\")\n\nif PREFIX_CACHING and ATTENTION not in {\n    \"flashinfer\",\n    \"flashdecoding\",\n    \"flashdecoding-ipex\",\n}:\n    raise RuntimeError(\"Prefix caching is only supported with flashinfer\")\n\nMEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None\n# Test a 70B model on 4xA100 under load for latest failure\nTGI_WIGGLE_ROOM = float(os.getenv(\"TGI_WIGGLE_ROOM\", \"0.90\"))\nassert TGI_WIGGLE_ROOM > 0\nassert TGI_WIGGLE_ROOM < 1\n\n\n# This is overridden by the cli\nBLOCK_SIZE: int\nif ATTENTION == \"flashdecoding\":\n    BLOCK_SIZE = 256\nelif ATTENTION == \"flashinfer\":\n    BLOCK_SIZE = 1\nelif ATTENTION == \"flashdecoding-ipex\":\n    BLOCK_SIZE = 64\nelse:\n    BLOCK_SIZE = 16\n\ncuda_graphs = os.getenv(\"CUDA_GRAPHS\")\nif cuda_graphs is not None:\n    try:\n        cuda_graphs = [int(item) for item in cuda_graphs.split(\",\")]\n    except Exception as e:\n        raise RuntimeError(\n            f\"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}\"\n        )\nelse:\n    cuda_graphs = None\n# sorting 
the cuda graphs in descending order helps reduce the\n# memory impact and results in less memory usage\nif cuda_graphs is not None:\n    cuda_graphs.sort(reverse=True)\n\nCUDA_GRAPHS = cuda_graphs\n\n# NOTE: eventually we should move this into the router and pass back the\n# index in all cases.\nADAPTER_TO_INDEX: Optional[Dict[str, int]] = None\n\n\ndef set_adapter_to_index(adapter_to_index: Dict[str, int]):\n    global ADAPTER_TO_INDEX\n    ADAPTER_TO_INDEX = adapter_to_index\n\n\ndef get_adapter_to_index():\n    global ADAPTER_TO_INDEX\n    return ADAPTER_TO_INDEX\n"
  },
  {
    "path": "server/text_generation_server/models/idefics_causal_lm.py",
    "content": "from io import BytesIO\nfrom PIL import Image\nimport torch\nimport time\n\nfrom dataclasses import dataclass\nfrom opentelemetry import trace\nfrom transformers import (\n    AutoConfig,\n    AutoProcessor,\n    AutoTokenizer,\n    PreTrainedTokenizerBase,\n    ProcessorMixin,\n)\nfrom typing import Optional, Tuple, List, Type, Dict\n\nfrom text_generation_server.models import Model\nfrom text_generation_server.models.types import (\n    Batch,\n    Tokens,\n    Generation,\n    GeneratedText,\n)\nfrom text_generation_server.pb import generate_pb2\nfrom text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling\nimport torch.distributed\nfrom text_generation_server.models.custom_modeling.idefics_modeling import (\n    IdeficsForVisionText2Text,\n)\nfrom text_generation_server.utils import (\n    initialize_torch_distributed,\n    weight_files,\n    Weights,\n)\nfrom text_generation_server.utils.quantization import get_loader\n\nfrom text_generation_server.utils.import_utils import SYSTEM\n\n\ntracer = trace.get_tracer(__name__)\n\n\n@dataclass\nclass IdeficsCausalLMBatch(Batch):\n    batch_id: int\n    requests: List[generate_pb2.Request]\n    requests_idx_mapping: Dict[int, int]\n\n    # Decoder values\n    input_ids: torch.Tensor\n    attention_mask: torch.Tensor\n    position_ids: torch.Tensor\n    pixel_values: Optional[torch.Tensor]\n    image_hidden_states: Optional[torch.Tensor]\n    image_attention_mask: Optional[torch.Tensor]\n    past_key_values: Optional[List[Tuple]]\n\n    # All tokens\n    all_input_ids: List[torch.Tensor]\n\n    # Lengths of all generations present in the batch\n    input_lengths: List[int]\n    prefix_offsets: List[int]\n    read_offsets: List[int]\n\n    # Generation helpers\n    next_token_choosers: List[NextTokenChooser]\n    stopping_criterias: List[StoppingCriteria]\n\n    # Metadata used for padding\n    max_input_length: int\n    padding_right_offset: int\n\n    # Maximum number of tokens 
this batch will grow to\n    max_tokens: int\n\n    # Past metadata\n    keys_head_dim_last: bool = True\n\n    def to_pb(self) -> generate_pb2.CachedBatch:\n        return generate_pb2.CachedBatch(\n            id=self.batch_id,\n            request_ids=[r.id for r in self.requests],\n            size=len(self),\n            max_tokens=self.max_tokens,\n            current_tokens=len(self),\n        )\n\n    @classmethod\n    def from_pb(\n        cls,\n        pb: generate_pb2.Batch,\n        tokenizer: PreTrainedTokenizerBase,\n        dtype: torch.dtype,\n        device: torch.device,\n    ) -> \"IdeficsCausalLMBatch\":\n        raise NotImplementedError\n\n    @classmethod\n    def from_pb_processor(\n        cls,\n        pb: generate_pb2.Batch,\n        tokenizer: PreTrainedTokenizerBase,\n        processor: ProcessorMixin,  # Hack\n        config,\n        dtype: torch.dtype,\n        device: torch.device,\n    ) -> \"IdeficsCausalLMBatch\":\n        inputs = []\n        next_token_choosers = []\n        stopping_criterias = []\n        prefix_offsets = []\n        read_offsets = []\n        requests_idx_mapping = {}\n\n        # Parse batch\n        max_truncation = 0\n        padding_right_offset = 0\n        max_decode_tokens = 0\n        for i, r in enumerate(pb.requests):\n            requests_idx_mapping[r.id] = i\n            inputs.append(r.input_chunks.chunks)\n            next_token_choosers.append(\n                NextTokenChooser.from_pb(r.parameters, device, tokenizer)\n            )\n            stopping_criteria = StoppingCriteria.from_pb(\n                r.stopping_parameters, tokenizer\n            )\n            stopping_criterias.append(stopping_criteria)\n            max_truncation = max(max_truncation, r.truncate)\n            max_decode_tokens += stopping_criteria.max_new_tokens\n            padding_right_offset = max(\n                padding_right_offset, stopping_criteria.max_new_tokens\n            )\n\n        # TODO Check 
impact on idefics\n        prompts = []\n        for inp in inputs:\n            # Each input is encoded into a list, where each element of this input list is either a string or a URL\n            prompt = []\n            for chunk in inp:\n                chunk_type = chunk.WhichOneof(\"chunk\")\n                if chunk_type == \"text\":\n                    prompt.append(chunk.text)\n                elif chunk_type == \"image\":\n                    image = Image.open(BytesIO(chunk.image.data))\n                    prompt.append(image)\n                else:\n                    raise RuntimeError(f\"Invalid chunk type {chunk_type}\")\n            prompts.append(prompt)\n\n        # The processor replaces the call to tokenizer, and\n        # a/ takes care of fetching images from the URL\n        # b/ generate the correct input_ids, attention_mask, pixel_values, image_attention_mask to feed to the model\n        tokenized_inputs = processor(\n            prompts,\n            return_tensors=\"pt\",\n            padding=True,\n            truncation=True,\n            max_length=max_truncation,\n            # TODO Check impact on idefics\n            # add_end_of_utterance_token=False,  # Already taken care of inside the prompts, so bypassing the processor's handling of this token\n        ).to(device)\n        for _ in pb.requests:\n            input_len = tokenized_inputs[\"input_ids\"].shape[1]\n            prefix_offsets.append(\n                input_len - 5\n            )  # To decode without potential fallbacks errors\n            read_offsets.append(\n                input_len\n            )  # To decode without potential fallbacks errors\n\n        input_lengths = tokenized_inputs[\"attention_mask\"].sum(1)\n        max_input_length = input_lengths.max()\n\n        input_ids = tokenized_inputs[\"input_ids\"]\n        pixel_values = tokenized_inputs.get(\"pixel_values\", None)\n        image_hidden_states = None\n        # Allocate maximum 
attention_mask\n        attention_mask = input_ids.new_zeros(\n            (pb.size, max_input_length + padding_right_offset)\n        )\n        # Copy tokenizer attention_mask into fully allocated attention_mask\n        attention_mask[:, :max_input_length] = tokenized_inputs[\"attention_mask\"]\n        # Do the same for image_attention_mask\n        if pixel_values is None:\n            image_attention_mask = None\n        else:\n            image_attention_mask = input_ids.new_zeros(\n                (\n                    pb.size,\n                    max_input_length + padding_right_offset,\n                    pixel_values.size(1),\n                )\n            )\n            image_attention_mask[:, :max_input_length, :] = tokenized_inputs[\n                \"image_attention_mask\"\n            ]\n\n        position_ids = tokenized_inputs[\"attention_mask\"].long().cumsum(-1) - 1\n        position_ids.masked_fill_(tokenized_inputs[\"attention_mask\"] == 0, 1)\n        all_input_ids = tokenized_inputs[\"input_ids\"].T.split(\n            1, dim=1\n        )  # It's input_ids but splitted into a tuple of tensors where each tensor is (seq_len, 1) size. 
It is then transformed into a list\n\n        max_tokens = len(inputs) * (max_input_length + max_decode_tokens)\n\n        return cls(\n            batch_id=pb.id,\n            requests=pb.requests,\n            requests_idx_mapping=requests_idx_mapping,\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            pixel_values=pixel_values,\n            image_hidden_states=image_hidden_states,\n            image_attention_mask=image_attention_mask,\n            past_key_values=None,\n            all_input_ids=list(all_input_ids),\n            input_lengths=input_lengths.tolist(),\n            prefix_offsets=prefix_offsets,\n            read_offsets=read_offsets,\n            next_token_choosers=next_token_choosers,\n            stopping_criterias=stopping_criterias,\n            max_input_length=max_input_length.item(),\n            padding_right_offset=padding_right_offset,\n            max_tokens=max_tokens,\n        )\n\n    @tracer.start_as_current_span(\"filter\")\n    def filter(self, request_ids: List[int]) -> Optional[\"IdeficsCausalLMBatch\"]:\n        # It deletes requests from the batch. 
For instance when client lost connection\n        if len(request_ids) == 0:\n            raise ValueError(\"Batch must have at least one request\")\n        if len(request_ids) == len(self):\n            return self\n\n        keep_indices = []\n\n        # New values after filtering\n        requests_idx_mapping = {}\n        requests = []\n        input_lengths = []\n        prefix_offsets = []\n        read_offsets = []\n        all_input_ids = []\n        max_input_length = 0\n\n        next_token_choosers = []\n        stopping_criterias = []\n\n        total_remaining_decode_tokens = 0\n        new_padding_right_offset = 0\n\n        for i, request_id in enumerate(request_ids):\n            idx = self.requests_idx_mapping[request_id]\n            requests_idx_mapping[request_id] = i\n            keep_indices.append(idx)\n\n            requests.append(self.requests[idx])\n            prefix_offsets.append(self.prefix_offsets[idx])\n            read_offsets.append(self.read_offsets[idx])\n            all_input_ids.append(self.all_input_ids[idx])\n\n            request_input_length = self.input_lengths[idx]\n            input_lengths.append(request_input_length)\n            max_input_length = max(max_input_length, request_input_length)\n\n            next_token_choosers.append(self.next_token_choosers[idx])\n            stopping_criteria = self.stopping_criterias[idx]\n            stopping_criterias.append(stopping_criteria)\n            remaining_decode_tokens = (\n                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens\n            )\n            total_remaining_decode_tokens += remaining_decode_tokens\n            new_padding_right_offset = max(\n                new_padding_right_offset, remaining_decode_tokens\n            )\n\n        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached\n        input_ids = self.input_ids[keep_indices]\n        position_ids = 
self.position_ids[keep_indices]\n        self.attention_mask = self.attention_mask[\n            keep_indices,\n            -(self.padding_right_offset + max_input_length) : (\n                self.attention_mask.shape[1] - self.padding_right_offset\n            )\n            + new_padding_right_offset,\n        ]\n        # Do the same for pixel_values and image_attention_mask\n        pixel_values = self.pixel_values[keep_indices]\n        self.image_attention_mask = self.image_attention_mask[\n            keep_indices,\n            -(self.padding_right_offset + max_input_length) : (\n                self.image_attention_mask.shape[1] - self.padding_right_offset\n            )\n            + new_padding_right_offset,\n            :,\n        ]\n        if self.image_hidden_states is None:\n            image_hidden_states = None\n        else:\n            image_hidden_states = self.image_hidden_states[keep_indices]\n\n        # Ensure that past_key_values tensors can be updated in-place\n        if type(self.past_key_values[0]) is tuple:\n            self.past_key_values = [list(layer) for layer in self.past_key_values]\n\n        # Update tensors in-place to allow incremental garbage collection\n        past_kv_length = max_input_length - 1\n        for layer in self.past_key_values:\n            past_keys, past_values = layer\n            if len(past_keys.shape) == 3:\n                # Force past to be of dim [self_size, num_heads, ...] 
for easy indexing\n                past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])\n                past_values = past_values.view(len(self), -1, *past_values.shape[-2:])\n            if self.keys_head_dim_last:\n                layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]\n            else:\n                layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]\n            del past_keys\n            layer[1] = past_values[keep_indices, :, -past_kv_length:, :]\n            del past_values\n\n        max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens\n\n        self.requests = requests\n        self.requests_idx_mapping = requests_idx_mapping\n        self.input_ids = input_ids\n        self.pixel_values = pixel_values\n        self.image_hidden_states = image_hidden_states\n        self.position_ids = position_ids\n        self.all_input_ids = all_input_ids\n        self.input_lengths = input_lengths\n        self.prefix_offsets = prefix_offsets\n        self.read_offsets = read_offsets\n        self.next_token_choosers = next_token_choosers\n        self.stopping_criterias = stopping_criterias\n        self.max_input_length = max_input_length\n        self.padding_right_offset = new_padding_right_offset\n        self.max_tokens = max_tokens\n\n        return self\n\n    @classmethod\n    @tracer.start_as_current_span(\"concatenate\")\n    def concatenate(\n        cls, batches: List[\"IdeficsCausalLMBatch\"]\n    ) -> \"IdeficsCausalLMBatch\":\n        # It adds new requests to the batch\n        # Used for padding\n        total_batch_size = 0\n        max_input_length = 0\n        max_num_images = 0\n        padding_right_offset = 0\n        for batch in batches:\n            total_batch_size += len(batch)\n            max_input_length = max(max_input_length, batch.max_input_length)\n            max_num_images = max(max_num_images, batch.pixel_values.size(1))\n            padding_right_offset = 
max(padding_right_offset, batch.padding_right_offset)\n\n        # Batch attributes\n        requests = []\n        requests_idx_mapping = {}\n        input_lengths = []\n        prefix_offsets = []\n        read_offsets = []\n        all_input_ids = []\n        next_token_choosers = []\n        stopping_criterias = []\n        max_tokens = 0\n\n        # Batch tensors\n        input_ids = None\n        attention_mask = None\n        position_ids = None\n        pixel_values = None\n        image_hidden_states = None\n        image_attention_mask = None\n        past_key_values = []\n\n        # Used for slicing correctly inside the tensors\n        # Equivalent to a cumsum on batch sizes\n        start_index = 0\n        for i, batch in enumerate(batches):\n            requests.extend(batch.requests)\n            input_lengths.extend(batch.input_lengths)\n            prefix_offsets.extend(batch.prefix_offsets)\n            read_offsets.extend(batch.read_offsets)\n            all_input_ids.extend(batch.all_input_ids)\n            next_token_choosers.extend(batch.next_token_choosers)\n            stopping_criterias.extend(batch.stopping_criterias)\n\n            if i == 0:\n                requests_idx_mapping = batch.requests_idx_mapping\n            else:\n                # We need to offset the mapping for each batch by the cumulative batch size\n                for k, v in batch.requests_idx_mapping.items():\n                    requests_idx_mapping[k] = v + start_index\n\n            # Slicing end index for this batch\n            end_index = start_index + len(batch)\n\n            # We only concatenate batches that did at least one step\n            if batch.past_key_values is None:\n                raise ValueError(\"only concatenate prefilled batches\")\n\n            # Create empty tensor\n            # input_ids is always of shape [batch_size, 1]\n            # We do not need to pad it\n            if input_ids is None:\n                input_ids = 
batch.input_ids.new_empty((total_batch_size, 1))\n            # Copy to correct indices\n            input_ids[start_index:end_index] = batch.input_ids\n\n            # Create padded tensor\n            if attention_mask is None:\n                attention_mask = batch.attention_mask.new_zeros(\n                    (total_batch_size, max_input_length + padding_right_offset),\n                )\n\n            curr_batch_max_num_images = batch.pixel_values.size(1)\n            if pixel_values is None:\n                pixel_values = batch.pixel_values.new_zeros(\n                    (total_batch_size, max_num_images, 3, 224, 224)\n                )\n            pixel_values[start_index:end_index, :curr_batch_max_num_images] = (\n                batch.pixel_values\n            )\n\n            if image_attention_mask is None:\n                image_attention_mask = batch.image_attention_mask.new_zeros(\n                    (\n                        total_batch_size,\n                        max_input_length + padding_right_offset,\n                        max_num_images,\n                    )\n                )\n\n            # We need to slice the attention mask to remove padding from previous steps\n            # and to remove unused allocated space\n            left_offset = max_input_length - batch.max_input_length\n            batch_left_offset = (\n                batch.attention_mask.shape[1]\n                - batch.max_input_length\n                - batch.padding_right_offset\n            )\n            attention_mask[\n                start_index:end_index,\n                left_offset:-padding_right_offset,\n            ] = batch.attention_mask[\n                :,\n                batch_left_offset : -batch.padding_right_offset,\n            ]\n            image_attention_mask[\n                start_index:end_index,\n                left_offset:-padding_right_offset,\n                :curr_batch_max_num_images,\n            ] = 
batch.image_attention_mask[\n                :, batch_left_offset : -batch.padding_right_offset, :\n            ]\n\n            # Create empty tensor\n            # position_ids is always of shape [batch_size, 1]\n            if position_ids is None:\n                position_ids = batch.position_ids.new_empty((total_batch_size, 1))\n            position_ids[start_index:end_index] = batch.position_ids\n\n            # Shenanigans to get dimensions because BLOOM outputs a past with a different shape\n            # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]\n            # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]\n            # And ensure that we can update tensors in-place\n            if isinstance(batch.past_key_values[0], tuple):\n                batch.past_key_values = [\n                    [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]\n                    for layer in batch.past_key_values\n                ]\n            elif len(batch.past_key_values[0][0].shape) == 3:\n                for layer in batch.past_key_values:\n                    for k, t in enumerate(layer):\n                        layer[k] = t.view(len(batch), -1, *t.shape[-2:])\n\n            # Add eventual padding tokens that were added while concatenating\n            max_tokens += batch.max_tokens + (\n                max_input_length - batch.max_input_length\n            ) * len(batch)\n\n            start_index = end_index\n\n        first_past_kvs = batches[0].past_key_values\n        _, num_heads, padded_sequence_length, head_dim = first_past_kvs[0][1].shape\n\n        padded_past_values_shape = (\n            total_batch_size,\n            num_heads,\n            max_input_length - 1,\n            head_dim,\n        )\n\n        if batches[0].keys_head_dim_last:\n            padded_past_keys_shape = padded_past_values_shape\n        else:\n            # seq_length is last for BLOOM\n            padded_past_keys_shape = (\n                
total_batch_size,\n                num_heads,\n                head_dim,\n                max_input_length - 1,\n            )\n\n        # Iterate over attention layers\n        # Concatenate past key values layer by layer to allow incremental garbage collection\n        for j in range(len(first_past_kvs)):\n            padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape)\n            start_index = 0\n            for batch in batches:\n                past_keys = batch.past_key_values[j][0]\n                # Clear reference to the original tensor\n                batch.past_key_values[j][0] = None\n\n                # Slicing end index for this batch\n                end_index = start_index + len(batch)\n                # We slice the keys to remove the padding from previous batches\n                past_seq_len = batch.max_input_length - 1\n                if batch.keys_head_dim_last:\n                    padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (\n                        past_keys[:, :, -past_seq_len:, :]\n                    )\n                else:\n                    # BLOOM case\n                    padded_past_keys[start_index:end_index, :, :, -past_seq_len:] = (\n                        past_keys[:, :, :, -past_seq_len:]\n                    )\n                del past_keys\n\n                start_index = end_index\n\n            padded_past_values = first_past_kvs[j][1].new_zeros(\n                padded_past_values_shape\n            )\n            start_index = 0\n            for batch in batches:\n                past_values = batch.past_key_values[j][1]\n                # Clear reference to the original tensor\n                batch.past_key_values[j][1] = None\n\n                # Slicing end index for this batch\n                end_index = start_index + len(batch)\n                # We slice the past values to remove the padding from previous batches\n                past_seq_len = 
batch.max_input_length - 1\n                padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (\n                    past_values[:, :, -past_seq_len:, :]\n                )\n                del past_values\n\n                # Update values\n                start_index = end_index\n\n            past_key_values.append([padded_past_keys, padded_past_values])\n\n        return cls(\n            batch_id=batches[0].batch_id,\n            requests=requests,\n            requests_idx_mapping=requests_idx_mapping,\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            pixel_values=pixel_values,\n            image_hidden_states=image_hidden_states,\n            image_attention_mask=image_attention_mask,\n            past_key_values=past_key_values,\n            all_input_ids=all_input_ids,\n            input_lengths=input_lengths,\n            prefix_offsets=prefix_offsets,\n            read_offsets=read_offsets,\n            next_token_choosers=next_token_choosers,\n            stopping_criterias=stopping_criterias,\n            max_input_length=max_input_length,\n            padding_right_offset=padding_right_offset,\n            keys_head_dim_last=batches[0].keys_head_dim_last,\n            max_tokens=max_tokens,\n        )\n\n    def __len__(self):\n        return len(self.requests)\n\n\nclass IdeficsCausalLM(Model):\n    def __init__(\n        self,\n        model_id: str,\n        revision: Optional[str] = None,\n        quantize: Optional[str] = None,\n        speculator: Optional[str] = None,\n        dtype: Optional[torch.dtype] = None,\n        trust_remote_code: bool = False,\n    ):\n        self.quantize = quantize\n        self.process_group, rank, world_size = initialize_torch_distributed()\n        if torch.cuda.is_available():\n            device = torch.device(f\"cuda:{rank}\")\n            # 9b seems to work correctly enough in float16, but 80b seems\n            
# to be really saturating for f16.\n            dtype = torch.float16 if dtype is None else dtype\n        elif SYSTEM == \"ipex\":\n            if hasattr(torch, \"xpu\") and torch.xpu.is_available():\n                device = torch.device(f\"xpu:{rank}\")\n                dtype = torch.float16 if dtype is None else dtype\n            else:\n                device = torch.device(\"cpu\")\n                # Float16 doesn't exist on target.\n                dtype = torch.bfloat16 if dtype is None else dtype\n        else:\n            device = torch.device(\"cpu\")\n            dtype = torch.float32 if dtype is None else dtype\n        self.device, self.dtype = device, dtype\n\n        config = AutoConfig.from_pretrained(\n            model_id,\n            revision=revision,\n            trust_remote_code=trust_remote_code,\n        )\n        config.quantize = quantize\n        config.speculator = speculator\n        config.vision_config.quantize = quantize\n\n        tokenizer = AutoTokenizer.from_pretrained(\n            model_id,\n            revision=revision,\n            padding_side=\"left\",\n            truncation_side=\"left\",\n            trust_remote_code=trust_remote_code,\n        )\n        self.processor = AutoProcessor.from_pretrained(\n            model_id,\n            revision=revision,\n            padding_side=\"left\",\n            truncation_side=\"left\",\n            trust_remote_code=trust_remote_code,\n        )\n\n        weights_loader = get_loader(\n            quantize=quantize, model_id=model_id, revision=revision\n        )\n        torch.distributed.barrier(group=self.process_group)\n        filenames = weight_files(model_id, revision=revision, extension=\".safetensors\")\n        weights = Weights(\n            filenames,\n            device=device,\n            dtype=dtype,\n            process_group=self.process_group,\n            weights_loader=weights_loader,\n        )\n\n        model = IdeficsForVisionText2Text(config, 
weights)\n\n        self.config = config\n\n        torch.distributed.barrier(group=self.process_group)\n        super().__init__(\n            model_id=model_id,\n            model=model,\n            tokenizer=tokenizer,\n            requires_padding=True,\n            dtype=dtype,\n            device=device,\n            rank=rank,\n            world_size=world_size,\n        )\n\n    @property\n    def batch_type(self) -> Type[IdeficsCausalLMBatch]:\n        return IdeficsCausalLMBatch\n\n    def forward(\n        self,\n        input_ids,\n        attention_mask,\n        position_ids,\n        pixel_values,\n        image_hidden_states,\n        image_attention_mask,\n        past_key_values: Optional = None,\n    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:\n        # Model Forward\n        kwargs = {\n            \"input_ids\": input_ids,\n            \"attention_mask\": attention_mask,\n            \"pixel_values\": pixel_values,\n            \"image_hidden_states\": image_hidden_states,\n            \"image_attention_mask\": image_attention_mask,\n            \"past_key_values\": past_key_values,\n            \"use_cache\": True,\n            \"return_dict\": True,\n        }\n        if self.has_position_ids:\n            kwargs[\"position_ids\"] = position_ids\n\n        outputs, speculative_logits = self.model.forward(**kwargs)\n        return (\n            outputs.logits,\n            speculative_logits,\n            outputs.past_key_values,\n            outputs.image_hidden_states,\n        )\n\n    @tracer.start_as_current_span(\"generate_token\")\n    def generate_token(\n        self, batch: IdeficsCausalLMBatch\n    ) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch], Tuple[int, int]]:\n        start = time.time_ns()\n        # slice the attention mask to the correct shape\n        attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]\n        if batch.image_attention_mask is None:\n            
image_attention_mask = None\n        else:\n            if batch.input_ids.size(1) == 1:\n                # THIS is a hack: when calling idefics.generate, the first time, we need the whole image_attention_mask (size bs x max_seq_len x max_num_images),\n                # but the subsequent times, we only need the last attention mask along the `max_seq_len` dimension\n                # this is due to the nature IDEFICS: it's an encoder decoder, and so when decoding, only the currently generated\n                # token need to attend to the encoder hidden states (i.e. the vision encoder)\n                # Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic\n                image_attention_mask = batch.image_attention_mask[\n                    :, -(batch.padding_right_offset + 1)\n                ].unsqueeze(1)\n            else:\n                image_attention_mask = batch.image_attention_mask[\n                    :, : -batch.padding_right_offset\n                ]\n\n        logits, speculative_logits, past, image_hidden_states = self.forward(\n            input_ids=batch.input_ids,\n            attention_mask=attention_mask,\n            position_ids=batch.position_ids,\n            pixel_values=batch.pixel_values,\n            image_hidden_states=batch.image_hidden_states,\n            image_attention_mask=image_attention_mask,\n            past_key_values=batch.past_key_values,\n        )\n        # Hardcoded remove image tokens\n        logits[:, 32000:32001] = torch.finfo(logits.dtype).min\n\n        start_decode = time.time_ns()\n\n        # Results\n        generations: List[Generation] = []\n        stopped = True\n\n        # Zipped iterator\n        iterator = zip(\n            batch.requests,\n            batch.input_lengths,\n            batch.prefix_offsets,\n            batch.read_offsets,\n            logits,\n            batch.next_token_choosers,\n            batch.stopping_criterias,\n            batch.all_input_ids,\n 
       )\n\n        # For each member of the batch\n        for i, (\n            request,\n            input_length,\n            prefix_offset,\n            read_offset,\n            logits,\n            next_token_chooser,\n            stopping_criteria,\n            all_input_ids,\n        ) in enumerate(iterator):\n            # Select next token\n            next_token_id, logprobs = next_token_chooser(\n                all_input_ids.view(1, -1), logits[-1:, :]\n            )\n\n            # Append next token to all tokens\n            all_input_ids = torch.cat([all_input_ids, next_token_id])\n            new_input_length = input_length + 1\n\n            # Generated token\n            next_token_logprob = logprobs[-1, next_token_id]\n            next_token_id_squeezed = next_token_id.squeeze()\n            next_token_text, prefix_offset, read_offset = self.decode_token(\n                all_input_ids[:, 0], prefix_offset, read_offset\n            )\n\n            # Evaluate stopping criteria\n            stop, reason = stopping_criteria(\n                next_token_id_squeezed,\n                next_token_text,\n            )\n\n            if not stop:\n                stopped = False\n\n            # Shard generations\n            # All generations will be appended in the rust sharded client\n            if i % self.world_size == self.rank:\n                if stop:\n                    # Decode generated tokens\n                    output_text, _, _ = self.decode_token(\n                        all_input_ids[:, 0],\n                        prefix_offset=len(all_input_ids)\n                        - stopping_criteria.current_tokens\n                        - 1,\n                        read_offset=len(all_input_ids)\n                        - stopping_criteria.current_tokens,\n                        skip_special_tokens=True,\n                    )\n                    # Get seed\n                    if isinstance(next_token_chooser.choice, Sampling):\n   
                     seed = next_token_chooser.choice.seed\n                    else:\n                        seed = None\n\n                    generated_text = GeneratedText(\n                        output_text, stopping_criteria.current_tokens, reason, seed\n                    )\n                else:\n                    generated_text = None\n\n                # Prefill\n                if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:\n                    # Remove generated token to only have prefill and add nan for first prompt token\n                    prefill_logprobs = [float(\"nan\")] + torch.log_softmax(\n                        logits, -1\n                    ).gather(1, all_input_ids[1:]).squeeze(1)[\n                        -new_input_length:-1\n                    ].tolist()\n                    prefill_token_ids = all_input_ids[-new_input_length:-1]\n                    prefill_texts = self.tokenizer.batch_decode(\n                        prefill_token_ids,\n                        clean_up_tokenization_spaces=False,\n                        skip_special_tokens=False,\n                    )\n                    prefill_tokens = Tokens(\n                        prefill_token_ids,\n                        prefill_logprobs,\n                        prefill_texts,\n                        is_special=[],\n                    )\n                else:\n                    prefill_tokens = None\n\n                top_tokens = None\n\n                generation = Generation(\n                    request.id,\n                    prefill_tokens,\n                    Tokens(\n                        [next_token_id_squeezed],\n                        [next_token_logprob],\n                        [next_token_text],\n                        [next_token_id_squeezed.item() in self.all_special_ids],\n                    ),\n                    generated_text,\n                    top_tokens,\n                )\n\n                
generations.append(generation)\n\n            # Update values\n            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(\n                next_token_id_squeezed.item()\n            )\n            batch.input_ids[i, 0] = next_token_id\n            batch.all_input_ids[i] = all_input_ids\n            batch.input_lengths[i] = new_input_length\n            batch.prefix_offsets[i] = prefix_offset\n            batch.read_offsets[i] = read_offset\n            batch.max_input_length = max(batch.max_input_length, new_input_length)\n\n        # We finished all generations in the batch; there is no next batch\n        if stopped:\n            forward_ns = start_decode - start\n            decode_ns = time.time_ns() - start_decode\n            return generations, None, (forward_ns, decode_ns)\n\n        # Slice unused values from prefill\n        batch.input_ids = batch.input_ids[:, :1]\n\n        # Update attention_mask as we added a new token to input_ids\n        batch.attention_mask[:, -batch.padding_right_offset] = 1\n        batch.image_attention_mask[:, -batch.padding_right_offset, :] = (\n            batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :]\n        )\n        # Decrease right offset\n        batch.padding_right_offset -= 1\n\n        # Update position_ids\n        batch.position_ids = batch.position_ids[:, -1:] + 1\n\n        # Update past key values\n        batch.past_key_values = past\n        batch.image_hidden_states = image_hidden_states\n\n        forward_ns = start_decode - start\n        decode_ns = time.time_ns() - start_decode\n        return generations, batch, (forward_ns, decode_ns)\n"
  },
  {
    "path": "server/text_generation_server/models/mamba.py",
    "content": "import torch\nimport torch.distributed\nfrom transformers import AutoTokenizer, PreTrainedTokenizerBase\nfrom typing import Optional, Union\nfrom text_generation_server.models.custom_modeling.mamba_modeling import (\n    MambaConfig,\n)\nfrom loguru import logger\nfrom text_generation_server.pb import generate_pb2\nfrom text_generation_server.utils import (\n    initialize_torch_distributed,\n    weight_files,\n    Weights,\n)\nfrom text_generation_server.models.globals import CUDA_GRAPHS, MEM_POOL\nimport time\nfrom text_generation_server.models.custom_modeling.mamba_modeling import (\n    MambaModel,\n    InferenceParams,\n)\nfrom text_generation_server.models import Model\nfrom typing import Any, List, Tuple, Type, Dict\nfrom text_generation_server.models.types import (\n    Batch,\n    Tokens,\n    Generation,\n    GeneratedText,\n)\nfrom text_generation_server.utils.chunks import concat_text_chunks\nfrom text_generation_server.utils.quantization import get_loader\nfrom text_generation_server.utils.tokens import batch_top_tokens, Sampling\nfrom dataclasses import dataclass\nfrom text_generation_server.utils import NextTokenChooser, StoppingCriteria\n\n\ndef new_inference_params(\n    n_blocks: int,\n    batch_size: int,\n    d_inner: int,\n    d_conv: int,\n    d_state: int,\n    seqlen_offset: int,\n    dtype: torch.dtype,\n    device: torch.device,\n):\n    max_seqlen = 0\n    conv_states = torch.zeros(\n        (\n            n_blocks,\n            batch_size,\n            d_inner,\n            d_conv,\n        ),\n        device=device,\n        dtype=dtype,\n    )\n    ssm_states = torch.zeros(\n        (\n            n_blocks,\n            batch_size,\n            d_inner,\n            d_state,\n        ),\n        device=device,\n        dtype=dtype,\n    )\n    inference_params = InferenceParams(\n        max_seqlen=max_seqlen,\n        max_batch_size=batch_size,\n        seqlen_offset=seqlen_offset,\n        conv_states=conv_states,\n   
     ssm_states=ssm_states,\n    )\n    return inference_params\n\n\n@dataclass\nclass MambaBatch(Batch):\n    batch_id: int\n    requests: List[generate_pb2.Request]\n    requests_idx_mapping: Dict[int, int]\n\n    # Decoder values\n    input_ids: torch.Tensor\n\n    # All tokens\n    all_input_ids: List[torch.Tensor]\n\n    # Lengths of all generations present in the batch\n    input_lengths: List[int]\n    prefix_offsets: List[int]\n    read_offsets: List[int]\n\n    # Generation helpers\n    next_token_choosers: List[NextTokenChooser]\n    stopping_criterias: List[StoppingCriteria]\n    top_n_tokens: List[int]\n    top_n_tokens_tensor: torch.Tensor\n\n    # Metadata used for padding\n    max_input_length: int\n    padding_right_offset: int\n\n    # Maximum number of tokens this batch will grow to\n    max_tokens: int\n\n    # Past metadata\n    keys_head_dim_last: bool = True\n\n    # Inference params\n    inference_params: Optional[Dict[str, Any]] = None\n\n    def to_pb(self) -> generate_pb2.CachedBatch:\n        return generate_pb2.CachedBatch(\n            id=self.batch_id,\n            request_ids=[r.id for r in self.requests],\n            size=len(self),\n            max_tokens=self.max_tokens,\n            current_tokens=len(self),\n        )\n\n    @classmethod\n    def from_pb(\n        cls,\n        pb: generate_pb2.Batch,\n        tokenizer: PreTrainedTokenizerBase,\n        dtype: torch.dtype,\n        device: torch.device,\n    ) -> \"MambaBatch\":\n        inputs = []\n        next_token_choosers = []\n        stopping_criterias = []\n        top_n_tokens = []\n        prefix_offsets = []\n        read_offsets = []\n        requests_idx_mapping = {}\n\n        # Parse batch\n        max_truncation = 0\n        padding_right_offset = 0\n        max_decode_tokens = 0\n        for i, r in enumerate(pb.requests):\n            requests_idx_mapping[r.id] = i\n            inputs.append(concat_text_chunks(r.input_chunks.chunks))\n            
next_token_choosers.append(\n                NextTokenChooser.from_pb(r.parameters, device, tokenizer)\n            )\n            stopping_criteria = StoppingCriteria.from_pb(\n                r.stopping_parameters, tokenizer\n            )\n            stopping_criterias.append(stopping_criteria)\n            top_n_tokens.append(r.top_n_tokens)\n            max_truncation = max(max_truncation, r.truncate)\n            max_decode_tokens += stopping_criteria.max_new_tokens\n            padding_right_offset = max(\n                padding_right_offset, stopping_criteria.max_new_tokens\n            )\n\n        tokenized_inputs = tokenizer(\n            inputs,\n            return_tensors=\"pt\",\n            padding=True,\n            return_token_type_ids=False,\n            truncation=True,\n            max_length=max_truncation,\n        ).to(device)\n        for _ in pb.requests:\n            input_len = tokenized_inputs[\"input_ids\"].shape[1]\n            prefix_offsets.append(input_len - 5)\n            read_offsets.append(input_len)\n\n        input_lengths = tokenized_inputs[\"attention_mask\"].sum(1)\n        max_input_length = input_lengths.max()\n        input_ids = tokenized_inputs[\"input_ids\"]\n        all_input_ids = tokenized_inputs[\"input_ids\"].T.split(1, dim=1)\n        top_n_tokens_tensor = torch.tensor(\n            top_n_tokens, device=device, dtype=torch.int64\n        )\n        max_tokens = len(inputs) * (max_input_length + max_decode_tokens)\n        return cls(\n            batch_id=pb.id,\n            requests=pb.requests,\n            requests_idx_mapping=requests_idx_mapping,\n            input_ids=input_ids,\n            # past_input_ids=None,\n            all_input_ids=list(all_input_ids),\n            input_lengths=input_lengths.tolist(),\n            prefix_offsets=prefix_offsets,\n            read_offsets=read_offsets,\n            next_token_choosers=next_token_choosers,\n            stopping_criterias=stopping_criterias,\n     
       top_n_tokens=top_n_tokens,\n            top_n_tokens_tensor=top_n_tokens_tensor,\n            max_input_length=max_input_length.item(),\n            padding_right_offset=padding_right_offset,\n            max_tokens=max_tokens,\n        )\n\n    def filter(self, request_ids: List[int]) -> Optional[\"MambaBatch\"]:\n        if len(request_ids) == 0:\n            raise ValueError(\"Batch must have at least one request\")\n        if len(request_ids) == len(self):\n            return self\n\n        keep_indices = []\n\n        # New values after filtering\n        requests_idx_mapping = {}\n        requests = []\n        input_lengths = []\n        prefix_offsets = []\n        read_offsets = []\n        all_input_ids = []\n        max_input_length = 0\n\n        next_token_choosers = []\n        stopping_criterias = []\n        top_n_tokens = []\n\n        total_remaining_decode_tokens = 0\n        new_padding_right_offset = 0\n\n        indices = []\n        for i, request_id in enumerate(request_ids):\n            idx = self.requests_idx_mapping[request_id]\n            requests_idx_mapping[request_id] = i\n            keep_indices.append(idx)\n\n            requests.append(self.requests[idx])\n            prefix_offsets.append(self.prefix_offsets[idx])\n            read_offsets.append(self.read_offsets[idx])\n            all_input_ids.append(self.all_input_ids[idx])\n\n            request_input_length = self.input_lengths[idx]\n            input_lengths.append(request_input_length)\n            max_input_length = max(max_input_length, request_input_length)\n            indices.append(idx)\n\n            next_token_choosers.append(self.next_token_choosers[idx])\n            stopping_criteria = self.stopping_criterias[idx]\n            stopping_criterias.append(stopping_criteria)\n            top_n_tokens.append(self.top_n_tokens[idx])\n            remaining_decode_tokens = (\n                stopping_criteria.max_new_tokens - 
stopping_criteria.current_tokens\n            )\n            total_remaining_decode_tokens += remaining_decode_tokens\n            new_padding_right_offset = max(\n                new_padding_right_offset, remaining_decode_tokens\n            )\n\n        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached\n        input_ids = self.input_ids[keep_indices]\n\n        top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]\n        max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens\n\n        self.requests = requests\n        self.requests_idx_mapping = requests_idx_mapping\n        self.input_ids = input_ids\n        self.all_input_ids = all_input_ids\n        self.input_lengths = input_lengths\n        self.prefix_offsets = prefix_offsets\n        self.read_offsets = read_offsets\n        self.next_token_choosers = next_token_choosers\n        self.stopping_criterias = stopping_criterias\n        self.top_n_tokens = top_n_tokens\n        self.top_n_tokens_tensor = top_n_tokens_tensor\n        self.max_input_length = max_input_length\n        self.padding_right_offset = new_padding_right_offset\n        self.max_tokens = max_tokens\n\n        # TODO\n        # Kept it simple by just updating the state, maybe updating the other CPU values is necessary.\n        self.inference_params.conv_states = self.inference_params.conv_states[\n            :, indices\n        ]\n        self.inference_params.ssm_states = self.inference_params.ssm_states[:, indices]\n        return self\n\n    @classmethod\n    def concatenate(cls, batches: List[\"MambaBatch\"]) -> \"MambaBatch\":\n        # Used for padding\n        total_batch_size = 0\n        max_input_length = 0\n        padding_right_offset = 0\n        for batch in batches:\n            total_batch_size += len(batch)\n            max_input_length = max(max_input_length, batch.max_input_length)\n            padding_right_offset = 
max(padding_right_offset, batch.padding_right_offset)\n\n        # Batch attributes\n        requests = []\n        requests_idx_mapping = {}\n        input_lengths = []\n        prefix_offsets = []\n        read_offsets = []\n        all_input_ids = []\n        next_token_choosers = []\n        stopping_criterias = []\n        top_n_tokens = []\n        max_tokens = 0\n        seqlen_offset = 0\n\n        (n_blocks, _, d_inner, d_conv) = batches[0].inference_params.conv_states.shape\n        (_, _, _, d_state) = batches[0].inference_params.ssm_states.shape\n        dtype = batches[0].inference_params.conv_states.dtype\n        device = batches[0].inference_params.conv_states.device\n        inference_params = new_inference_params(\n            n_blocks=n_blocks,\n            batch_size=total_batch_size,\n            d_state=d_state,\n            d_conv=d_conv,\n            d_inner=d_inner,\n            seqlen_offset=seqlen_offset,\n            device=device,\n            dtype=dtype,\n        )\n\n        # Batch tensors\n        input_ids = None\n        top_n_tokens_tensor = None\n\n        # Used for slicing correctly inside the tensors\n        # Equivalent to a cumsum on batch sizes\n        start_index = 0\n        for i, batch in enumerate(batches):\n            requests.extend(batch.requests)\n            input_lengths.extend(batch.input_lengths)\n            prefix_offsets.extend(batch.prefix_offsets)\n            read_offsets.extend(batch.read_offsets)\n            all_input_ids.extend(batch.all_input_ids)\n            next_token_choosers.extend(batch.next_token_choosers)\n            stopping_criterias.extend(batch.stopping_criterias)\n            top_n_tokens.extend(batch.top_n_tokens)\n\n            if i == 0:\n                requests_idx_mapping = batch.requests_idx_mapping\n            else:\n                # We need to offset the mapping for each batch by the cumulative batch size\n                for k, v in batch.requests_idx_mapping.items():\n 
                   requests_idx_mapping[k] = v + start_index\n\n            # Slicing end index for this batch\n            end_index = start_index + len(batch)\n\n            # Create empty tensor\n            # input_ids is always of shape [batch_size, 1]\n            # We do not need to pad it\n            if input_ids is None:\n                input_ids = batch.input_ids.new_empty((total_batch_size, 1))\n            # Copy to correct indices\n            input_ids[start_index:end_index] = batch.input_ids\n\n            if top_n_tokens_tensor is None:\n                top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(\n                    total_batch_size,\n                )\n            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor\n\n            # Add eventual padding tokens that were added while concatenating\n            max_tokens += batch.max_tokens + (\n                max_input_length - batch.max_input_length\n            ) * len(batch)\n\n            inference_params.max_seqlen = max(\n                inference_params.max_seqlen, batch.inference_params.max_seqlen\n            )\n            assert batch.inference_params.seqlen_offset != 0, \"Invalid seqlen offset\"\n            inference_params.seqlen_offset = max(\n                inference_params.seqlen_offset, batch.inference_params.seqlen_offset\n            )\n\n            inference_params.conv_states[:, start_index:end_index] = (\n                batch.inference_params.conv_states\n            )\n            inference_params.ssm_states[:, start_index:end_index] = (\n                batch.inference_params.ssm_states\n            )\n\n            start_index = end_index\n\n        return cls(\n            batch_id=batches[0].batch_id,\n            requests=requests,\n            requests_idx_mapping=requests_idx_mapping,\n            input_ids=input_ids,\n            all_input_ids=all_input_ids,\n            input_lengths=input_lengths,\n            
prefix_offsets=prefix_offsets,\n            read_offsets=read_offsets,\n            next_token_choosers=next_token_choosers,\n            stopping_criterias=stopping_criterias,\n            top_n_tokens=top_n_tokens,\n            top_n_tokens_tensor=top_n_tokens_tensor,\n            max_input_length=max_input_length,\n            padding_right_offset=padding_right_offset,\n            keys_head_dim_last=batches[0].keys_head_dim_last,\n            max_tokens=max_tokens,\n            inference_params=inference_params,\n        )\n\n    def __len__(self):\n        return len(self.requests)\n\n\nclass Mamba(Model):\n    def __init__(\n        self,\n        model_id: str,\n        revision: Optional[str] = None,\n        quantize: Optional[str] = None,\n        speculator: Optional[str] = None,\n        dtype: Optional[torch.dtype] = None,\n        trust_remote_code: bool = False,\n    ):\n        self.quantize = quantize\n        self.process_group, _rank, world_size = initialize_torch_distributed()\n        if world_size > 1:\n            raise RuntimeError(\"Mamba does not support Tensor Parallelism (TP)\")\n        self.cuda_graphs = {}\n        if torch.cuda.is_available():\n            device = torch.device(\"cuda\")\n            # Bf16 is important. 
In f16 accumulations in the matmul are causing\n            # differences while the server is under load.\n            # This is detectable by the integration load test\n            dtype = torch.bfloat16 if dtype is None else dtype\n        else:\n            if quantize:\n                raise ValueError(\"quantization is not available on CPU\")\n\n            device = torch.device(\"cpu\")\n            dtype = torch.float32 if dtype is None else dtype\n\n        tokenizer = AutoTokenizer.from_pretrained(\n            \"EleutherAI/gpt-neox-20b\",\n            revision=revision,\n            padding_side=\"left\",\n            truncation_side=\"left\",\n            trust_remote_code=trust_remote_code,\n        )\n        config = MambaConfig.from_pretrained(\n            model_id, revision=revision, trust_remote_code=trust_remote_code\n        )\n\n        tokenizer.bos_token_id = config.bos_token_id\n        tokenizer.eos_token_id = config.eos_token_id\n        tokenizer.pad_token = tokenizer.eos_token\n\n        config.quantize = quantize\n        config.speculator = speculator\n        torch.distributed.barrier(group=self.process_group)\n        weights_loader = get_loader(\n            quantize=quantize, model_id=model_id, revision=revision\n        )\n        filenames = weight_files(model_id, revision=revision, extension=\".safetensors\")\n        weights = Weights(\n            filenames,\n            device,\n            dtype,\n            process_group=self.process_group,\n            weights_loader=weights_loader,\n        )\n        model = MambaModel(config, weights)\n        torch.distributed.barrier(group=self.process_group)\n        super(Mamba, self).__init__(\n            model_id=model_id,\n            model=model,\n            tokenizer=tokenizer,\n            requires_padding=True,\n            dtype=dtype,\n            device=device,\n        )\n\n    @property\n    def batch_type(self) -> Type[MambaBatch]:\n        return MambaBatch\n\n    
def warmup(\n        self, batch, max_input_tokens: Optional[int], max_total_tokens: Optional[int]\n    ) -> Union[Optional[int], Optional[int], Optional[int]]:\n        # TODO: implement warmup for Mamba if needed\n        if CUDA_GRAPHS:\n            if self.speculate is None or self.speculate == 0:\n                try:\n                    logger.info(f\"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}\")\n                    # Warmup cuda graphs\n                    for bs in CUDA_GRAPHS:\n                        self.cuda_graph_warmup(bs)\n                except Exception:\n                    logger.exception(\"Decode cuda graph warmup failed\")\n        else:\n            logger.info(f\"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).\")\n\n        if max_total_tokens is None:\n            max_total_tokens = min(self.tokenizer.model_max_length, 4096)\n\n        if max_input_tokens is None:\n            max_input_tokens = max_total_tokens - 1\n        return None, max_input_tokens, max_total_tokens\n\n    def cuda_graph_warmup(self, batch_size: int):\n        input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device)\n        n_blocks = len(self.model.blocks)\n\n        d_state = self.model.config.d_state\n        d_conv = self.model.config.d_conv\n        # Inner takes the expand multiplication\n        d_inner = self.model.config.d_inner\n\n        # Important seqlen_offset to go through the update mecanism with the state\n        seqlen_offset = 1\n        inference_params = new_inference_params(\n            n_blocks=n_blocks,\n            batch_size=batch_size,\n            d_state=d_state,\n            d_conv=d_conv,\n            d_inner=d_inner,\n            seqlen_offset=seqlen_offset,\n            device=self.device,\n            dtype=self.dtype,\n        )\n\n        graph = torch.cuda.CUDAGraph()\n\n        torch.cuda.synchronize()\n        # Run once outside to warmup\n        self.model.forward(input_ids=input_ids, 
inference_params=inference_params)\n        torch.cuda.synchronize()\n\n        with torch.cuda.graph(graph, pool=MEM_POOL):\n            logits, speculative_logits = self.model.forward(\n                input_ids=input_ids, inference_params=inference_params\n            )\n        torch.cuda.synchronize()\n        graph_dict = {\n            \"input_ids\": input_ids,\n            \"inference_params\": inference_params,\n            \"graph\": graph,\n            \"logits\": logits,\n            \"speculative_logits\": speculative_logits,\n        }\n        self.cuda_graphs[batch_size] = graph_dict\n\n    def tunableop_warmup(self, batch_size: int, seqlen: int):\n        input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device)\n        n_blocks = len(self.model.blocks)\n\n        d_state = self.model.config.d_state\n        d_conv = self.model.config.d_conv\n        # Inner takes the expand multiplication\n        d_inner = self.model.config.d_inner\n\n        # Important seqlen_offset to go through the update mecanism with the state\n        seqlen_offset = 1\n        inference_params = new_inference_params(\n            n_blocks=n_blocks,\n            batch_size=seqlen,\n            d_state=d_state,\n            d_conv=d_conv,\n            d_inner=d_inner,\n            seqlen_offset=seqlen_offset,\n            device=self.device,\n            dtype=self.dtype,\n        )\n\n        self.model.forward(input_ids=input_ids, inference_params=inference_params)\n\n    def forward(\n        self, input_ids: torch.Tensor, inference_params: Any\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        bs = input_ids.shape[0]\n        padded_bs = bs\n        if bs == 3:\n            padded_bs = 4\n        elif 3 < bs <= 8:\n            padded_bs = 8\n        elif bs > 8:\n            padded_bs = (bs + 7) // 8 * 8\n\n        # Try to find an associated cuda graph\n        cuda_graph = self.cuda_graphs.get(padded_bs, None)\n        is_prefill = 
inference_params is None or inference_params.seqlen_offset == 0\n\n        if is_prefill or cuda_graph is None:\n            return self.model(\n                input_ids,\n                inference_params=inference_params,\n            )\n\n        # Copy inputs to the static inputs of the cuda graph\n        # Static inputs are potentially padded\n        cuda_graph[\"input_ids\"][:bs] = input_ids\n        cuda_graph[\"inference_params\"].conv_states[\n            :, :bs\n        ] = inference_params.conv_states\n        cuda_graph[\"inference_params\"].ssm_states[:, :bs] = inference_params.ssm_states\n\n        # Replay the graph\n        cuda_graph[\"graph\"].replay()\n\n        inference_params.conv_states.copy_(\n            cuda_graph[\"inference_params\"].conv_states[:, :bs]\n        )\n        inference_params.ssm_states.copy_(\n            cuda_graph[\"inference_params\"].ssm_states[:, :bs]\n        )\n        # Slice output to the correct shape\n        speculative_logits = (\n            cuda_graph[\"speculative_logits\"][:bs]\n            if cuda_graph[\"speculative_logits\"] is not None\n            else None\n        )\n        logits = cuda_graph[\"logits\"][:bs]\n        return logits, speculative_logits\n\n    def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:\n        start = time.time_ns()\n        input_ids = (\n            batch.input_ids\n        )  # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids\n\n        batch_size, max_seqlen = input_ids.shape\n        # Inference params\n\n        if batch.inference_params is None:\n            # 0 is important here\n            seqlen_offset = 0\n            n_blocks = len(self.model.blocks)\n            d_state = self.model.config.d_state\n            d_conv = self.model.config.d_conv\n            d_inner = self.model.config.d_inner\n            inference_params = new_inference_params(\n                n_blocks=n_blocks,\n                
batch_size=batch_size,\n                d_state=d_state,\n                d_conv=d_conv,\n                d_inner=d_inner,\n                seqlen_offset=seqlen_offset,\n                device=self.device,\n                dtype=self.dtype,\n            )\n            batch.inference_params = inference_params\n\n        # Forward pass\n        logits, speculative_logits = self.forward(\n            input_ids, inference_params=batch.inference_params\n        )\n\n        # batch.inference_params = new_inference_params\n        # Results\n        generations: List[Generation] = []\n        stopped = True\n\n        # Speculation is not active for causal\n        accepted_ids = torch.ones_like(batch.input_ids)[:, 0]\n        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(\n            batch.top_n_tokens,\n            batch.top_n_tokens_tensor,\n            torch.log_softmax(logits[:, -1], -1),\n            accepted_ids,\n        )\n\n        start_decode = time.time_ns()\n\n        # Zipped iterator\n        iterator = zip(\n            batch.requests,\n            batch.input_lengths,\n            batch.prefix_offsets,\n            batch.read_offsets,\n            logits,\n            batch.next_token_choosers,\n            batch.stopping_criterias,\n            batch.all_input_ids,\n            batch.top_n_tokens,\n            batch_top_token_ids,\n            batch_top_token_logprobs,\n        )\n\n        # For each member of the batch\n        for i, (\n            request,\n            input_length,\n            prefix_offset,\n            read_offset,\n            logits,\n            next_token_chooser,\n            stopping_criteria,\n            all_input_ids,\n            top_n_tokens,\n            top_token_ids,\n            top_token_logprobs,\n        ) in enumerate(iterator):\n            # Select next token\n            next_token_id, logprobs = next_token_chooser(\n                all_input_ids.view(1, -1), logits[-1:, :]\n           
 )\n\n            # Append next token to all tokens\n            all_input_ids = torch.cat([all_input_ids, next_token_id])\n            new_input_length = input_length + 1\n\n            # Generated token\n            next_token_logprob = logprobs[-1, next_token_id]\n            next_token_id_squeezed = next_token_id.squeeze()\n            next_token_text, prefix_offset, read_offset = self.decode_token(\n                all_input_ids[:, 0], prefix_offset, read_offset\n            )\n\n            # Evaluate stopping criteria\n            stop, reason = stopping_criteria(\n                next_token_id_squeezed,\n                next_token_text,\n            )\n\n            if not stop:\n                stopped = False\n\n            # Shard generations\n            # All generations will be appended in the rust sharded client\n            if i % self.world_size == self.rank:\n                if stop:\n                    # Decode generated tokens\n                    output_text, _, _ = self.decode_token(\n                        all_input_ids[:, 0],\n                        prefix_offset=len(all_input_ids)\n                        - stopping_criteria.current_tokens\n                        - 1,\n                        read_offset=len(all_input_ids)\n                        - stopping_criteria.current_tokens,\n                        skip_special_tokens=True,\n                    )\n                    # Get seed\n                    if isinstance(next_token_chooser.choice, Sampling):\n                        seed = next_token_chooser.choice.seed\n                    else:\n                        seed = None\n\n                    generated_text = GeneratedText(\n                        output_text, stopping_criteria.current_tokens, reason, seed\n                    )\n                else:\n                    generated_text = None\n\n                if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:\n                    # Remove generated 
token to only have prefill and add nan for first prompt token\n                    prefill_logprobs = [float(\"nan\")] + torch.log_softmax(\n                        logits, -1\n                    ).gather(1, all_input_ids[1:]).squeeze(1)[\n                        -new_input_length:-1\n                    ].tolist()\n                    prefill_token_ids = all_input_ids[-new_input_length:-1]\n                    prefill_texts = self.tokenizer.batch_decode(\n                        prefill_token_ids,\n                        clean_up_tokenization_spaces=False,\n                        skip_special_tokens=False,\n                    )\n                    prefill_tokens = Tokens(\n                        prefill_token_ids,\n                        prefill_logprobs,\n                        prefill_texts,\n                        is_special=[],\n                    )\n                else:\n                    prefill_tokens = None\n\n                if top_n_tokens > 0:\n                    toptoken_texts = self.tokenizer.batch_decode(\n                        top_token_ids,\n                        clean_up_tokenization_spaces=False,\n                        skip_special_tokens=False,\n                    )\n                    special_toptokens = [\n                        token_id in self.all_special_ids for token_id in top_token_ids\n                    ]\n                    top_tokens = Tokens(\n                        top_token_ids,\n                        top_token_logprobs,\n                        toptoken_texts,\n                        special_toptokens,\n                    )\n                else:\n                    top_tokens = None\n\n                generation = Generation(\n                    request.id,\n                    prefill_tokens,\n                    Tokens(\n                        [next_token_id_squeezed],\n                        [next_token_logprob],\n                        [next_token_text],\n                        
[next_token_id_squeezed.item() in self.all_special_ids],\n                    ),\n                    generated_text,\n                    top_tokens,\n                )\n\n                generations.append(generation)\n\n                # Update values\n                batch.next_token_choosers[i] = batch.next_token_choosers[\n                    i\n                ].advance_grammar(next_token_id_squeezed.item())\n                batch.input_ids[i, 0] = next_token_id\n                batch.all_input_ids[i] = all_input_ids\n                batch.input_lengths[i] = new_input_length\n                batch.prefix_offsets[i] = prefix_offset\n                batch.read_offsets[i] = read_offset\n                batch.max_input_length = max(batch.max_input_length, new_input_length)\n\n        # We finished all generations in the batch; there is no next batch\n        if stopped:\n            forward_ns = start_decode - start\n            decode_ns = time.time_ns() - start_decode\n            return generations, None, (forward_ns, decode_ns)\n\n        # Slice unused values from prefill\n        batch.input_ids = batch.input_ids[:, :1]\n\n        forward_ns = start_decode - start\n        decode_ns = time.time_ns() - start_decode\n        return generations, batch, (forward_ns, decode_ns)\n"
  },
  {
    "path": "server/text_generation_server/models/metadata_kernels.py",
    "content": "import torch\nimport triton\n\nimport triton.language as tl\n\nfrom loguru import logger\nfrom typing import List, Optional\nfrom torch.utils._triton import has_triton as has_triton_torch\n\nfrom text_generation_server.utils.import_utils import (\n    SYSTEM,\n)\nfrom text_generation_server.utils.log import log_master\n\n_HAS_TRITON: Optional[bool] = None\n\n\ndef has_triton():\n    global _HAS_TRITON\n    if _HAS_TRITON is None:\n        # FIXME: it seems that has_triton_torch is bugged on RocM\n        #        For now, only accept cuda\n        _HAS_TRITON = has_triton_torch() if SYSTEM == \"cuda\" else False\n        if _HAS_TRITON:\n            log_master(logger.info, \"Using optimized Triton indexing kernels.\")\n\n    return _HAS_TRITON\n\n\ndef block_tables_to_padded(\n    max_blocks: int,\n    cu_seqlen: torch.Tensor,\n    block_tables: torch.Tensor,\n    block_tables_ragged: torch.Tensor,\n):\n    def grid(meta):\n        return (\n            triton.cdiv(max_blocks, meta[\"BLOCK_SIZE\"]),\n            len(block_tables),\n        )\n\n    triton_block_tables_to_padded[grid](\n        cu_seqlen,\n        block_tables,\n        block_tables_ragged,\n        block_tables.shape[1],\n        BLOCK_SIZE=256,\n    )\n\n\ndef block_tables_to_ragged(\n    *,\n    block_tables: torch.Tensor,\n    input_lengths: List[int],\n    cache_lengths: List[int],\n    input_lengths_tensor: torch.Tensor,\n    cache_lengths_tensor: torch.Tensor,\n    max_current_length: int,\n) -> torch.Tensor:\n    \"\"\"Convert block table to ragged format compatible with FlashInfer.\"\"\"\n    assert len(input_lengths) == len(cache_lengths)\n\n    total_len = sum(input_lengths) + sum(cache_lengths)\n    block_tables_ragged = torch.empty(\n        total_len, dtype=torch.int32, device=block_tables.device\n    )\n\n    if has_triton():\n        cu_seqlen = input_lengths_tensor.new_zeros(input_lengths_tensor.shape[0] + 1)\n        torch.cumsum(\n            input_lengths_tensor + 
cache_lengths_tensor, out=cu_seqlen[1:], dim=0\n        )\n\n        def grid(meta):\n            return (\n                triton.cdiv(max_current_length, meta[\"BLOCK_SIZE\"]),\n                len(cache_lengths),\n            )\n\n        triton_block_tables_to_ragged[grid](\n            cu_seqlen,\n            block_tables,\n            block_tables_ragged,\n            block_tables.shape[1],\n            BLOCK_SIZE=256,\n        )\n    else:\n        offset = 0\n        for i, (input_length, cache_length) in enumerate(\n            zip(input_lengths, cache_lengths)\n        ):\n            seq_len = cache_length + input_length\n            block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len]\n            offset += seq_len\n\n    return block_tables_ragged\n\n\ndef copy_next_input_ids_inplace(\n    max_next_input_ids: int,\n    all_input_ids: torch.Tensor,\n    cache_lengths: torch.Tensor,\n    input_lengths: torch.Tensor,\n    prompt_lengths: torch.Tensor,\n    next_input_ids: torch.Tensor,\n    cu_accepted_ids: torch.Tensor,\n):\n    def grid(meta):\n        return (\n            triton.cdiv(max_next_input_ids, meta[\"BLOCK_SIZE\"]),\n            len(all_input_ids),\n        )\n\n    triton_copy_next_input_ids_inplace[grid](\n        all_input_ids,\n        cache_lengths,\n        input_lengths,\n        prompt_lengths,\n        next_input_ids,\n        cu_accepted_ids,\n        all_input_ids.shape[1],\n        BLOCK_SIZE=16,\n    )\n\n\ndef prepare_position_slot_ids(\n    max_input_length: int,\n    cache_lengths: torch.Tensor,\n    cu_seqlen: torch.Tensor,\n    cu_slots: torch.Tensor,\n    position_ids: torch.Tensor,\n    slot_indices: torch.Tensor,\n):\n    def grid(meta):\n        return (\n            triton.cdiv(max_input_length, meta[\"BLOCK_SIZE\"]),\n            len(cache_lengths),\n        )\n\n    triton_prepare_position_slot_ids[grid](\n        cache_lengths, cu_seqlen, cu_slots, position_ids, slot_indices, BLOCK_SIZE=256\n   
 )\n\n\ndef slots_filtering(\n    max_slots: int,\n    slots: torch.Tensor,\n    filtered_slots: torch.Tensor,\n    cu_slots: torch.Tensor,\n    slots_start: torch.Tensor,\n):\n    def grid(meta):\n        return (\n            triton.cdiv(max_slots, meta[\"BLOCK_SIZE\"]),\n            len(slots_start),\n        )\n\n    triton_slots_filtering[grid](\n        slots, filtered_slots, slots_start, cu_slots, BLOCK_SIZE=256\n    )\n\n\n@triton.jit\ndef triton_slots_filtering(\n    # Inputs\n    slots_ptr,\n    filtered_slots_ptr,\n    slots_start_ptr,\n    cu_slots_ptr,\n    # Const values\n    BLOCK_SIZE: \"tl.constexpr\",\n):\n    # Position in block_tables_ragged.numel() / BLOCK_SIZE\n    pid = tl.program_id(axis=0)\n    # Position in batch\n    bid = tl.program_id(axis=1)\n\n    block_start = pid * BLOCK_SIZE\n    block_arange = block_start + tl.arange(0, BLOCK_SIZE)\n\n    filter_start = tl.load(slots_start_ptr + bid)\n\n    slot_start = tl.load(cu_slots_ptr + bid)\n    slot_end = tl.load(cu_slots_ptr + bid + 1)\n\n    mask = (slot_start + block_arange) < slot_end\n\n    slots = tl.load(slots_ptr + filter_start + block_arange, mask=mask)\n    tl.store(filtered_slots_ptr + slot_start + block_arange, slots, mask=mask)\n\n\n@triton.jit\ndef triton_block_tables_to_padded(\n    # Inputs\n    cu_seqlen_ptr,\n    # Outputs\n    block_tables_ptr,\n    block_tables_ragged_ptr,\n    # Stride\n    stride_block_tables,\n    # Const values\n    BLOCK_SIZE: \"tl.constexpr\",\n):\n    # Position in block_tables_ragged.numel() / BLOCK_SIZE\n    pid = tl.program_id(axis=0)\n    # Position in batch\n    bid = tl.program_id(axis=1)\n\n    block_start = pid * BLOCK_SIZE\n    block_arange = block_start + tl.arange(0, BLOCK_SIZE)\n\n    seq_start = tl.load(cu_seqlen_ptr + bid)\n    seq_end = tl.load(cu_seqlen_ptr + bid + 1)\n\n    mask = (seq_start + block_arange) < seq_end\n\n    blocks = tl.load(block_tables_ragged_ptr + seq_start + block_arange, mask=mask)\n    tl.store(\n        
block_tables_ptr + bid * stride_block_tables + block_arange, blocks, mask=mask\n    )\n\n\n@triton.jit\ndef triton_block_tables_to_ragged(\n    # Inputs\n    cu_seqlen_ptr,\n    # Outputs\n    block_tables_ptr,\n    block_tables_ragged_ptr,\n    # Stride\n    stride_block_tables,\n    # Const values\n    BLOCK_SIZE: \"tl.constexpr\",\n):\n    # Position in block_tables_ragged.numel() / BLOCK_SIZE\n    pid = tl.program_id(axis=0)\n    # Position in batch\n    bid = tl.program_id(axis=1)\n\n    block_start = pid * BLOCK_SIZE\n    block_arange = block_start + tl.arange(0, BLOCK_SIZE)\n\n    seq_start = tl.load(cu_seqlen_ptr + bid)\n    seq_end = tl.load(cu_seqlen_ptr + bid + 1)\n\n    mask = (seq_start + block_arange) < seq_end\n\n    blocks = tl.load(\n        block_tables_ptr + bid * stride_block_tables + block_arange, mask=mask\n    )\n    tl.store(block_tables_ragged_ptr + seq_start + block_arange, blocks, mask=mask)\n\n\n@triton.jit\ndef triton_copy_next_input_ids_inplace(\n    # Inputs\n    all_input_ids_ptr,\n    cache_lengths_ptr,\n    input_lengths_ptr,\n    prompt_lengths_ptr,\n    next_input_ids_ptr,\n    cu_accepted_ids_ptr,\n    # Stride\n    stride_all_input_ids,\n    # Const values\n    BLOCK_SIZE: \"tl.constexpr\",\n):\n    # Position in max_accepted_ids / BLOCK_SIZE\n    pid = tl.program_id(axis=0)\n    # Position in batch\n    bid = tl.program_id(axis=1)\n\n    block_start = pid * BLOCK_SIZE\n    block_arange = block_start + tl.arange(0, BLOCK_SIZE)\n\n    # Used for correctly indexing in all_input_ids\n    cache_length = tl.load(cache_lengths_ptr + bid)\n    input_length = tl.load(input_lengths_ptr + bid)\n    prompt_length = tl.load(prompt_lengths_ptr + bid)\n\n    # Start/End of next_input_ids for this request\n    next_input_ids_start = tl.load(cu_accepted_ids_ptr + bid)\n    next_input_ids_end = tl.load(cu_accepted_ids_ptr + bid + 1)\n\n    # Mask values out of range\n    mask = (next_input_ids_start + block_arange) < next_input_ids_end\n\n    # 
Mask values for request still prefilling\n    decode_mask = (cache_length + input_length + block_arange) >= prompt_length\n\n    mask = mask & decode_mask\n\n    # Load this request next input ids\n    next_input_ids = tl.load(\n        next_input_ids_ptr + next_input_ids_start + block_arange, mask=mask\n    )\n\n    # Store in all_input_ids, since it is a 2D tensor, apply stride * bid\n    tl.store(\n        all_input_ids_ptr\n        + stride_all_input_ids * bid\n        + cache_length\n        + input_length\n        + block_arange,\n        next_input_ids,\n        mask=mask,\n    )\n\n\n@triton.jit\ndef triton_prepare_position_slot_ids(\n    # Inputs\n    cache_lengths_ptr,\n    cu_seqlen_ptr,\n    cu_slots_ptr,\n    # Outputs\n    position_ids_ptr,\n    slot_indices_ptr,\n    # Const values\n    BLOCK_SIZE: \"tl.constexpr\",\n):\n    # Position in max_input_length / BLOCK_SIZE\n    pid = tl.program_id(axis=0)\n    # Position in batch\n    bid = tl.program_id(axis=1)\n\n    block_start = pid * BLOCK_SIZE\n    block_arange = block_start + tl.arange(0, BLOCK_SIZE)\n\n    cache_length = tl.load(cache_lengths_ptr + bid)\n\n    seq_start = tl.load(cu_seqlen_ptr + bid)\n    seq_end = tl.load(cu_seqlen_ptr + bid + 1)\n\n    slot_start = tl.load(cu_slots_ptr + bid)\n\n    mask = (seq_start + block_arange) < seq_end\n\n    tl.store(\n        position_ids_ptr + seq_start + block_arange,\n        cache_length + block_arange,\n        mask=mask,\n    )\n    tl.store(\n        slot_indices_ptr + seq_start + block_arange,\n        slot_start + cache_length + block_arange,\n        mask=mask,\n    )\n"
  },
  {
    "path": "server/text_generation_server/models/mllama_causal_lm.py",
    "content": "import torch\n\nimport numpy as np\n\nfrom typing import Iterable, Optional, Tuple, List, Dict\nfrom text_generation_server.pb.generate_pb2 import Request\nfrom io import BytesIO\nfrom PIL import Image\nfrom dataclasses import dataclass\nfrom opentelemetry import trace\nfrom transformers import (\n    PreTrainedTokenizerBase,\n)\n\nfrom text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM\nfrom text_generation_server.pb import generate_pb2\nfrom text_generation_server.models.globals import PREFIX_CACHING, ATTENTION\nfrom text_generation_server.layers.attention import Seqlen\nfrom text_generation_server.models.metadata_kernels import block_tables_to_ragged\n\n\ntracer = trace.get_tracer(__name__)\n\n\n@dataclass\nclass MllamaCausalLMBatch(VlmCausalLMBatch):\n    image_indices: List[int] = 42\n    aspect_ratio_ids: Optional[torch.Tensor] = None\n    aspect_ratio_mask: Optional[torch.Tensor] = None\n    cross_attention_states: Optional[torch.Tensor] = None\n\n    def prepare_for_prefill(self):\n        super(VlmCausalLMBatch, self).prepare_for_prefill()\n\n    @classmethod\n    @tracer.start_as_current_span(\"concatenate\")\n    def concatenate(cls, batches):\n        batch = super(VlmCausalLMBatch, cls).concatenate(batches)\n        batch.pixel_values = None\n        batch.pixel_attention_mask = None\n\n        offset = 0\n        image_indices = []\n        attention_states = []\n        for b in batches:\n            if b.cross_attention_states is not None:\n                attention_states.append(b.cross_attention_states)\n            image_indices.extend([i + offset for i in b.image_indices])\n            offset += len(b.image_indices)\n        if len(attention_states) > 0:\n            assert len(image_indices) > 0\n            batch.cross_attention_states = torch.cat(attention_states, dim=0)\n            batch.image_indices = image_indices\n        else:\n            batch.cross_attention_states = None\n            
batch.image_indices = []\n        return batch\n\n    @tracer.start_as_current_span(\"filter\")\n    def filter(self, request_ids: List[int]):\n        assert self.image_indices is not None\n        batch = super(VlmCausalLMBatch, self).filter(request_ids)\n        assert self.image_indices is not None\n        indices = []\n        for i, request_id in enumerate(request_ids):\n            idx = self.requests_idx_mapping[request_id]\n            indices.append(idx)\n\n        offset = 0\n        new_image_indices = []\n        prev_i = None\n        for i in self.image_indices:\n            if i in indices:\n                new_image_indices.append(offset)\n                if i != prev_i:\n                    offset += 1\n                prev_i = i\n\n        batch.image_indices = new_image_indices\n        if len(new_image_indices) > 0:\n            assert max(new_image_indices) < self.cross_attention_states.shape[0]\n            assert offset <= self.cross_attention_states.shape[0]\n            batch.cross_attention_states = self.cross_attention_states[\n                new_image_indices\n            ]\n        else:\n            batch.cross_attention_states = None\n        batch.pixel_values = None\n        return batch\n\n    @classmethod\n    def batch_tokenized_inputs(\n        cls, requests: Iterable[Request], tokenizer, processor, config\n    ):\n        image_inputs = []\n        texts = []\n        image_indices = []\n        batch_tokenized_inputs = []\n\n        for i, r in enumerate(requests):\n            # Each input is encoded into a list, where each element of this input list is either a string or a URL\n            curr_text = \"\"\n            curr_image = None\n            curr_i = None\n            for chunk in r.input_chunks.chunks:\n                chunk_type = chunk.WhichOneof(\"chunk\")\n                if chunk_type == \"text\":\n                    curr_text += chunk.text\n                elif chunk_type == \"image\":\n                    
image = Image.open(BytesIO(chunk.image.data))\n                    # TODO unsure about BOS\n                    curr_text += \"<|image|>\"\n                    image_input = processor.image_processor(image, return_tensors=\"pt\")\n                    curr_image = image_input\n                    curr_i = i\n                    # image_inputs.append(image_input)\n                    # image_indices.append(i)\n                else:\n                    raise RuntimeError(f\"Invalid chunk type {chunk_type}\")\n            texts.append(curr_text)\n            if curr_image is not None:\n                image_inputs.append(curr_image)\n                image_indices.append(curr_i)\n\n            input_ids = tokenizer(\n                curr_text,\n                truncation=True,\n                max_length=r.truncate,\n                add_special_tokens=r.add_special_tokens,\n            )[\"input_ids\"]\n            batch_tokenized_inputs.append(input_ids)\n        if image_inputs:\n            image_input = image_inputs[0]\n            new_image_inputs = {\n                \"pixel_values\": torch.cat(\n                    [img[\"pixel_values\"] for img in image_inputs], dim=0\n                ),\n            }\n            if \"aspect_ratio_ids\" in image_input:\n                new_image_inputs[\"aspect_ratio_ids\"] = torch.cat(\n                    [img[\"aspect_ratio_ids\"] for img in image_inputs], dim=0\n                )\n            if \"aspect_ratio_mask\" in image_input:\n                new_image_inputs[\"aspect_ratio_mask\"] = torch.cat(\n                    [img[\"aspect_ratio_mask\"] for img in image_inputs], dim=0\n                )\n            image_inputs = new_image_inputs\n            image_inputs[\"image_indices\"] = image_indices\n        else:\n            image_inputs = None\n\n        if image_inputs is not None:\n            assert len(image_indices) == image_inputs[\"pixel_values\"].shape[0]\n\n        return batch_tokenized_inputs, 
image_inputs\n\n    @classmethod\n    def from_pb_processor(\n        cls,\n        pb: generate_pb2.Batch,\n        tokenizer: PreTrainedTokenizerBase,\n        processor,\n        config,\n        dtype: torch.dtype,\n        device: torch.device,\n    ) -> \"VlmCausalLMBatch\":\n        batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(\n            pb.requests, tokenizer, processor, config\n        )\n        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)\n        # XXX: <|image|> token is actually out of bounds and bugs out the logit processors.\n        batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp(\n            max=config.text_config.vocab_size - 1\n        )\n        if isinstance(batch.input_ids, list):\n            if len(batch) > 1:\n                input_ids = np.concatenate(batch.input_ids, dtype=np.int64)\n            else:\n                input_ids = batch.input_ids[0]\n            batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)\n\n        batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)\n\n        if image_inputs is not None:\n            batch.pixel_values = image_inputs[\"pixel_values\"].to(\n                device=device, dtype=dtype\n            )\n            batch.aspect_ratio_ids = image_inputs[\"aspect_ratio_ids\"].to(device=device)\n            batch.aspect_ratio_mask = image_inputs[\"aspect_ratio_mask\"].to(\n                device=device\n            )\n            batch.image_indices = image_inputs[\"image_indices\"]\n        else:\n            batch.pixel_values = None\n            batch.aspect_ratio_ids = None\n            batch.aspect_ratio_mask = None\n            batch.image_indices = []\n        assert batch.image_indices is not None\n        return batch\n\n\nclass MllamaCausalLM(VlmCausalLM):\n    def set_inputs_embeds(self, batch):\n        # Set the input embeddings to None, as we are using the input_ids 
for the model\n        batch.inputs_embeds = None\n\n    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):\n        super(VlmCausalLM, self).cuda_graph_warmup(bs, max_s, max_bt)\n\n    def forward(\n        self,\n        batch: MllamaCausalLMBatch,\n        adapter_data: Optional[Dict[str, torch.Tensor]] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        # Model Forward\n        if batch.speculative_ids is not None:\n            input_ids = batch.input_ids\n            position_ids = batch.position_ids\n            cu_seqlen_prefill = batch.cu_seqlen_prefill\n            kv_cache = self.kv_cache\n            block_tables = batch.block_tables_tensor\n            slots = batch.slots[batch.slot_indices]\n            input_lengths = batch.input_lengths_tensor\n            max_s = batch.max_current_length\n            lm_head_indices = batch.prefill_head_indices\n\n            speculative_ids = batch.speculative_ids\n\n            B, speculative_length = speculative_ids.shape\n            new_length = speculative_length + 1\n            new_input_ids = torch.cat(\n                [input_ids.unsqueeze(-1), speculative_ids], dim=1\n            ).reshape(-1)\n            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)\n            arange_int = arange.to(dtype=torch.int32)\n            new_position_ids = (\n                position_ids.unsqueeze(-1).expand(B, new_length) + arange\n            ).view(-1)\n            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)\n            input_lengths = (\n                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int\n            ).view(-1)\n            cache_lengths_tensor = (\n                batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)\n            ).reshape(-1)\n\n            # Add Copy the block tables for all members\n            block_tables = (\n                block_tables.unsqueeze(1)\n                
.expand(B, new_length, -1)\n                .reshape(B * new_length, -1)\n                .contiguous()\n            )\n            max_s = max_s + speculative_length\n\n            input_ids = new_input_ids\n            position_ids = new_position_ids\n        else:\n            input_ids = batch.input_ids\n            position_ids = batch.position_ids\n            cu_seqlen_prefill = batch.cu_seqlen_prefill\n            kv_cache = self.kv_cache\n            block_tables = batch.block_tables_tensor\n            slots = batch.slots[batch.slot_indices]\n            input_lengths = batch.input_lengths_tensor\n            cache_lengths_tensor = batch.cache_lengths_tensor\n            max_s = batch.max_current_length\n            lm_head_indices = batch.prefill_head_indices\n\n        # Try to find an associated cuda graph\n        bs = input_ids.shape[0]\n        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])\n        if sorted_padded_bs:\n            # Get associated cuda graph\n            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]\n        else:\n            cuda_graph = None\n        if (\n            cu_seqlen_prefill is not None\n            or cuda_graph is None\n            # Only run cuda graphs when there's no images.\n            or batch.cross_attention_states is not None\n        ):\n            if PREFIX_CACHING:\n                block_tables = block_tables_to_ragged(\n                    block_tables=block_tables,\n                    input_lengths=batch.input_lengths,\n                    cache_lengths=batch.cache_lengths,\n                    input_lengths_tensor=batch.input_lengths_tensor,\n                    cache_lengths_tensor=batch.cache_lengths_tensor,\n                    max_current_length=batch.max_current_length,\n                )\n            with self._forward_context(\n                block_tables=block_tables,\n                cu_seqlen_prefill=cu_seqlen_prefill,\n                
input_lengths_tensor=input_lengths,\n                cache_lengths_tensor=cache_lengths_tensor,\n            ):\n                seqlen = Seqlen(\n                    input_lengths=input_lengths,\n                    cache_lengths=cache_lengths_tensor,\n                    cu_seqlen_q=cu_seqlen_prefill,\n                    max_q=batch.max_input_length,\n                    max_k=batch.max_current_length,\n                )\n\n                if batch.pixel_values is not None:\n                    cross_attention_states = self.model.vision_forward(\n                        pixel_values=batch.pixel_values,\n                        aspect_ratio_ids=batch.aspect_ratio_ids,\n                        aspect_ratio_mask=batch.aspect_ratio_mask,\n                    )\n                    batch.cross_attention_states = cross_attention_states\n\n                cross_attention_states = batch.cross_attention_states\n\n                logits, speculative_logits = self.model.forward(\n                    input_ids=input_ids,\n                    position_ids=position_ids,\n                    cu_seqlen_prefill=cu_seqlen_prefill,\n                    kv_cache=kv_cache,\n                    block_tables=block_tables,\n                    slots=slots,\n                    seqlen=seqlen,\n                    max_s=max_s,\n                    prefill_cache_indices=batch.prefill_cache_indices,\n                    lm_head_indices=lm_head_indices,\n                    cross_attention_states=cross_attention_states,\n                    adapter_data=adapter_data,\n                    image_indices=batch.image_indices[:],\n                )\n                if batch.prefill_cache_indices is not None:\n                    batch.prefill_cache_indices = None\n                if batch.pixel_values is not None:\n                    batch.pixel_values = None\n                return logits, speculative_logits\n\n        # Copy inputs to the static inputs of the cuda graph\n        # Static 
inputs are potentially padded\n        cuda_graph[\"input_ids\"][: input_ids.shape[0]] = input_ids\n        cuda_graph[\"position_ids\"][: position_ids.shape[0]] = position_ids\n        if ATTENTION == \"flashinfer\":\n            block_tables = block_tables_to_ragged(\n                block_tables=block_tables,\n                input_lengths=batch.input_lengths,\n                cache_lengths=batch.cache_lengths,\n                input_lengths_tensor=batch.input_lengths_tensor,\n                cache_lengths_tensor=batch.cache_lengths_tensor,\n                max_current_length=batch.max_current_length,\n            )\n            cuda_graph[\"block_tables\"][: block_tables.shape[0]] = block_tables\n        else:\n            cuda_graph[\"block_tables\"][\n                : block_tables.shape[0], : block_tables.shape[1]\n            ] = block_tables\n\n        # XXX: This is working only because block 0 is reserved for the healthcheck\n        # so it doesn't matter if we override it with bogus values.\n        cuda_graph[\"slots\"].fill_(0)\n        cuda_graph[\"slots\"][: slots.shape[0]] = slots\n        cuda_graph[\"input_lengths\"].zero_()\n        cuda_graph[\"input_lengths\"][: input_lengths.shape[0]] = input_lengths\n        cuda_graph[\"cache_lengths\"].zero_()\n        cuda_graph[\"cache_lengths\"][\n            : cache_lengths_tensor.shape[0]\n        ] = cache_lengths_tensor\n\n        with self._forward_context(\n            block_tables=cuda_graph[\"block_tables\"],\n            cu_seqlen_prefill=None,\n            input_lengths_tensor=cuda_graph[\"input_lengths\"],\n            cache_lengths_tensor=cuda_graph[\"cache_lengths\"],\n            state=cuda_graph[\"state\"],\n        ):\n            # Replay the graph\n            cuda_graph[\"graph\"].replay()\n\n        # Slice output to the correct shape\n        speculative_logits = (\n            cuda_graph[\"speculative_logits\"][:bs]\n            if cuda_graph[\"speculative_logits\"] is not None\n  
          else None\n        )\n        logits = cuda_graph[\"logits\"][:bs]\n        return logits, speculative_logits\n"
  },
  {
    "path": "server/text_generation_server/models/model.py",
    "content": "import inspect\nimport torch\n\nfrom abc import ABC, abstractmethod\nfrom typing import List, Tuple, Optional, TypeVar, Type, Dict\nfrom collections import defaultdict\nfrom transformers import PreTrainedTokenizerBase\nfrom loguru import logger\n\nfrom text_generation_server.models.globals import (\n    ATTENTION,\n    PREFIX_CACHING,\n    BLOCK_SIZE,\n    PREFILL_CHUNKING,\n)\nfrom text_generation_server.models.types import Batch, Generation\nfrom text_generation_server.utils.log import log_master\nfrom text_generation_server.utils.prefill_chunking import set_support_chunking\nfrom text_generation_server.utils.speculate import get_speculate\nfrom text_generation_server.pb.generate_pb2 import InfoResponse\nfrom text_generation_server.adapters.weights import LayerAdapterWeights\n\nBASE_MODEL_ADAPTER_ID = \"__base_model__\"\n\n\nB = TypeVar(\"B\", bound=Batch)\n\n\nclass Model(ABC):\n    def __init__(\n        self,\n        model_id: str,\n        model: torch.nn.Module,\n        tokenizer: PreTrainedTokenizerBase,\n        requires_padding: bool,\n        dtype: torch.dtype,\n        device: torch.device,\n        rank: int = 0,\n        world_size: int = 1,\n        sliding_window: Optional[int] = None,\n        speculate: Optional[int] = None,\n        adapter_id: str = BASE_MODEL_ADAPTER_ID,\n        support_chunking: bool = False,\n    ):\n        self.model_id = model_id\n        self.model = model.eval()\n        self.tokenizer = tokenizer\n\n        # all_special_ids is not set correctly if the rust tokenizer is unpacked\n        # TODO report this to transformers.\n        other_special_ids = {\n            id for id, token in tokenizer.added_tokens_decoder.items() if token.special\n        }\n        self.all_special_ids = set(tokenizer.all_special_ids)\n        self.all_special_ids.update(other_special_ids)\n        self.requires_padding = requires_padding\n        self.dtype = dtype\n        self.device = device\n        self.rank = 
rank\n        self.world_size = world_size\n        self.sliding_window = sliding_window if sliding_window != -1 else None\n\n        self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(\n            LayerAdapterWeights\n        )\n        self.loaded_adapters = set()\n        self.static_adapter_id = adapter_id\n\n        if speculate is None:\n            speculate = get_speculate()\n        self.speculate = speculate\n\n        support_chunking = support_chunking and PREFILL_CHUNKING\n\n        if speculate != 0 and support_chunking:\n            log_master(\n                logger.warning,\n                \"Prefill chunking does not support speculation yet. \"\n                \"Prefill chunking will be turned off\",\n            )\n            support_chunking = False\n        if (\n            ATTENTION not in [\"flashinfer\", \"flashdecoding\", \"flashdecoding-ipex\"]\n            and support_chunking\n        ):\n            log_master(\n                logger.warning,\n                \"Prefill chunking is only supported with `flashinfer` or `flashdecoding` or `flashdecoding-ipex` attention types.\",\n            )\n            support_chunking = False\n\n        log_master(logger.info, f\"Using prefill chunking = {support_chunking}\")\n\n        self.support_chunking = support_chunking\n        set_support_chunking(support_chunking)\n\n        self.has_position_ids = (\n            inspect.signature(model.forward).parameters.get(\"position_ids\", None)\n            is not None\n        )\n\n        self.check_initialized()\n\n    @property\n    def info(self) -> InfoResponse:\n        if self.requires_padding and self.sliding_window is not None:\n            raise NotImplementedError(\"sliding_window is not implemented with padding\")\n\n        return InfoResponse(\n            requires_padding=self.requires_padding,\n            dtype=str(self.dtype),\n            device_type=self.device.type,\n            window_size=None,  # 
Setting this parameter to None disabled the block logic with sliding window.\n            speculate=self.speculate,\n            support_chunking=self.support_chunking,\n            use_prefix_caching=PREFIX_CACHING,\n            attention_impl=ATTENTION,\n            block_size=BLOCK_SIZE,\n        )\n\n    @property\n    @abstractmethod\n    def batch_type(self) -> Type[B]:\n        raise NotImplementedError\n\n    @abstractmethod\n    def generate_token(\n        self, batch: B\n    ) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]:\n        raise NotImplementedError\n\n    def warmup(\n        self, batch: B, max_input_tokens: Optional[int], max_total_tokens: Optional[int]\n    ) -> Tuple[Optional[int], int, int]:\n        self.generate_token(batch)\n        total = sum(len(i) for i in batch.input_ids)\n        if max_total_tokens is None:\n            max_total_tokens = total\n\n        if max_input_tokens is None:\n            max_input_tokens = max_total_tokens - 1\n        return None, max_input_tokens, max_total_tokens\n\n    def decode_token(\n        self,\n        all_input_ids: List[int],\n        prefix_offset: int = 0,\n        read_offset: int = 0,\n        skip_special_tokens: bool = False,\n    ) -> Tuple[str, int, int]:\n        \"\"\"Hack to hopefully support generate_stream for the maximum number of tokenizers\"\"\"\n\n        # The prefix text is necessary only to defeat cleanup algorithms in the decode\n        # which decide to add a space or not depending on the surrounding ids.\n        prefix_text = self.tokenizer.decode(\n            all_input_ids[prefix_offset:read_offset],\n            skip_special_tokens=skip_special_tokens,\n        )\n        new_text = self.tokenizer.decode(\n            all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens\n        )\n\n        if len(new_text) > len(prefix_text) and not new_text.endswith(\"�\"):\n            # utf-8 char at the end means it's a potential unfinished byte 
sequence\n            # from byte fallback tokenization.\n            # If it's in the middle, it's probably a real invalid id generated\n            # by the model\n            new_text = new_text[len(prefix_text) :]\n            return new_text, read_offset, len(all_input_ids)\n        else:\n            return \"\", prefix_offset, read_offset\n\n    def check_initialized(self):\n        uninitialized_parameters = []\n        for n, p in self.model.named_parameters():\n            if p.data.device == torch.device(\"meta\"):\n                uninitialized_parameters.append(n)\n        if uninitialized_parameters:\n            raise RuntimeError(\n                f\"found uninitialized parameters in model {self.__class__.__name__}: {uninitialized_parameters}\"\n            )\n"
  },
  {
    "path": "server/text_generation_server/models/seq2seq_lm.py",
    "content": "import torch\nimport torch.distributed\nimport time\nfrom dataclasses import dataclass\nfrom opentelemetry import trace\nfrom transformers import (\n    AutoTokenizer,\n    AutoModelForSeq2SeqLM,\n    PreTrainedTokenizerBase,\n    AutoConfig,\n)\nfrom typing import Optional, Tuple, List, Type, Dict\nfrom text_generation_server.utils.import_utils import SYSTEM\nfrom text_generation_server.utils import (\n    initialize_torch_distributed,\n    weight_files,\n    Weights,\n)\nfrom text_generation_server.utils.chunks import concat_text_chunks\nfrom text_generation_server.utils.quantization import get_loader\nfrom text_generation_server.utils.tokens import batch_top_tokens\nfrom text_generation_server.models import Model\nfrom text_generation_server.models.types import (\n    GeneratedText,\n    Batch,\n    Generation,\n    Tokens,\n)\nfrom text_generation_server.pb import generate_pb2\nfrom text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling\n\ntracer = trace.get_tracer(__name__)\n\n\n@dataclass\nclass Seq2SeqLMBatch(Batch):\n    batch_id: int\n    requests: List[generate_pb2.Request]\n    requests_idx_mapping: Dict[int, int]\n\n    # Encoder values\n    input_ids: Optional[torch.Tensor]\n    attention_mask: torch.Tensor\n\n    # Decoder values\n    decoder_input_ids: torch.Tensor\n    decoder_attention_mask: Optional[torch.Tensor]\n    encoder_last_hidden_state: Optional[torch.Tensor]\n\n    # All tokens\n    all_decoder_input_ids: List[torch.Tensor]\n\n    # Seq2SeqLM keeps track of both encoder and decoder attention keys and values\n    past_key_values: Optional[List[Tuple]]\n\n    # Lengths of all generations present in the batch\n    input_lengths: List[int]\n    decoder_input_lengths: List[int]\n    prefix_offsets: List[int]\n    read_offsets: List[int]\n\n    # Generation helpers\n    next_token_choosers: List[NextTokenChooser]\n    stopping_criterias: List[StoppingCriteria]\n    top_n_tokens: List[int]\n    
top_n_tokens_tensor: torch.Tensor\n\n    # Metadata used for padding\n    max_input_length: int\n    max_decoder_input_length: int\n    padding_right_offset: int\n\n    # Maximum number of tokens this batch will grow to\n    max_tokens: int\n\n    def to_pb(self) -> generate_pb2.CachedBatch:\n        \"\"\"Convert a Seq2SeqLMBatch to a text_generation_server.v1.CachedBatch protobuf\"\"\"\n        return generate_pb2.CachedBatch(\n            id=self.batch_id,\n            request_ids=[r.id for r in self.requests],\n            size=len(self),\n            max_tokens=self.max_tokens,\n            current_tokens=len(self.decoder_input_ids),\n        )\n\n    @classmethod\n    def from_pb(\n        cls,\n        pb: generate_pb2.Batch,\n        tokenizer: PreTrainedTokenizerBase,\n        dtype: torch.dtype,\n        device: torch.device,\n    ) -> \"Seq2SeqLMBatch\":\n        \"\"\"Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch\"\"\"\n        inputs = []\n        next_token_choosers = []\n        stopping_criterias = []\n        top_n_tokens = []\n        decoder_input_lengths = []\n        prefix_offsets = []\n        read_offsets = []\n        requests_idx_mapping = {}\n\n        # Parse batch\n        max_truncation = 0\n        padding_right_offset = 0\n        max_decode_tokens = 0\n        for i, r in enumerate(pb.requests):\n            inputs.append(concat_text_chunks(r.input_chunks.chunks))\n            requests_idx_mapping[r.id] = i\n            decoder_input_lengths.append(1)\n            next_token_choosers.append(\n                NextTokenChooser.from_pb(r.parameters, device, tokenizer)\n            )\n            stopping_criteria = StoppingCriteria.from_pb(\n                r.stopping_parameters, tokenizer\n            )\n            stopping_criterias.append(stopping_criteria)\n            top_n_tokens.append(r.top_n_tokens)\n            max_truncation = max(max_truncation, r.truncate)\n            max_decode_tokens += 
stopping_criteria.max_new_tokens\n            padding_right_offset = max(\n                padding_right_offset, stopping_criteria.max_new_tokens\n            )\n\n        # Tokenize batch\n        tokenized_inputs = tokenizer(\n            inputs,\n            return_tensors=\"pt\",\n            padding=True,\n            return_token_type_ids=False,\n            truncation=True,\n            max_length=max_truncation,\n        ).to(device)\n\n        input_lengths = tokenized_inputs[\"attention_mask\"].sum(1)\n        max_input_length = input_lengths.max()\n\n        # Decoder sequence only contains the bos_token\n        decoder_input_ids = (\n            torch.tensor(tokenizer.bos_token_id, device=device)\n            .repeat(len(pb.requests))\n            .view(-1, 1)\n        )\n        for _ in pb.requests:\n            prefix_offsets.append(0)\n            read_offsets.append(1)\n        all_decoder_input_ids = decoder_input_ids.view(-1).split(1)\n        top_n_tokens_tensor = torch.tensor(\n            top_n_tokens, device=device, dtype=torch.int64\n        )\n\n        max_tokens = len(inputs) * (max_input_length + max_decode_tokens)\n\n        return cls(\n            batch_id=pb.id,\n            requests=pb.requests,\n            requests_idx_mapping=requests_idx_mapping,\n            input_ids=tokenized_inputs[\"input_ids\"],\n            attention_mask=tokenized_inputs[\"attention_mask\"],\n            decoder_input_ids=decoder_input_ids,\n            all_decoder_input_ids=list(all_decoder_input_ids),\n            decoder_attention_mask=None,\n            encoder_last_hidden_state=None,\n            past_key_values=None,\n            input_lengths=input_lengths.tolist(),\n            decoder_input_lengths=decoder_input_lengths,\n            prefix_offsets=prefix_offsets,\n            read_offsets=read_offsets,\n            next_token_choosers=next_token_choosers,\n            stopping_criterias=stopping_criterias,\n            
top_n_tokens=top_n_tokens,\n            top_n_tokens_tensor=top_n_tokens_tensor,\n            max_input_length=max_input_length.item(),\n            max_decoder_input_length=1,\n            padding_right_offset=padding_right_offset,\n            max_tokens=max_tokens,\n        )\n\n    @tracer.start_as_current_span(\"filter\")\n    def filter(self, request_ids: List[int]) -> Optional[\"Seq2SeqLMBatch\"]:\n        if len(request_ids) == 0:\n            raise ValueError(\"Batch must have at least one request\")\n        if len(request_ids) == len(self):\n            return self\n\n        keep_indices = []\n\n        # New values after filtering\n        requests_idx_mapping = {}\n        requests = []\n        input_lengths = []\n        decoder_input_lengths = []\n        prefix_offsets = []\n        read_offsets = []\n\n        all_decoder_input_ids = []\n\n        next_token_choosers = []\n        stopping_criterias = []\n        top_n_tokens = []\n\n        max_input_length = 0\n        max_decoder_input_length = 0\n        padding_right_offset = 0\n\n        total_remaining_decode_tokens = 0\n\n        for i, request_id in enumerate(request_ids):\n            idx = self.requests_idx_mapping[request_id]\n            requests_idx_mapping[request_id] = i\n            keep_indices.append(idx)\n\n            requests.append(self.requests[idx])\n            prefix_offsets.append(self.prefix_offsets[idx])\n            read_offsets.append(self.read_offsets[idx])\n\n            all_decoder_input_ids.append(self.all_decoder_input_ids[idx])\n\n            request_input_length = self.input_lengths[idx]\n            input_lengths.append(request_input_length)\n            max_input_length = max(max_input_length, request_input_length)\n\n            request_decoder_input_length = self.decoder_input_lengths[idx]\n            decoder_input_lengths.append(request_decoder_input_length)\n            max_decoder_input_length = max(\n                max_decoder_input_length, 
request_decoder_input_length\n            )\n\n            next_token_choosers.append(self.next_token_choosers[idx])\n            stopping_criteria = self.stopping_criterias[idx]\n            stopping_criterias.append(stopping_criteria)\n            top_n_tokens.append(self.top_n_tokens[idx])\n            remaining_decode_tokens = (\n                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens\n            )\n            total_remaining_decode_tokens += remaining_decode_tokens\n            padding_right_offset = max(padding_right_offset, remaining_decode_tokens)\n\n        # Apply indices to input_ids, attention mask, past key values and other items that need to be cached\n        self.decoder_input_ids = self.decoder_input_ids[keep_indices]\n        self.attention_mask = self.attention_mask[keep_indices, -max_input_length:]\n        if self.decoder_attention_mask is not None:\n            self.decoder_attention_mask = self.decoder_attention_mask[\n                keep_indices,\n                -(self.padding_right_offset + max_decoder_input_length) : (\n                    self.decoder_attention_mask.shape[1] - self.padding_right_offset\n                )\n                + padding_right_offset,\n            ]\n\n        self.encoder_last_hidden_state = self.encoder_last_hidden_state[\n            keep_indices, -max_input_length:\n        ]\n\n        # Ensure that past_key_values tensors can be updated in-place\n        if type(self.past_key_values[0]) is tuple:\n            self.past_key_values = [\n                [t for t in layer] for layer in self.past_key_values\n            ]\n\n        decoder_past_seq_len = max_decoder_input_length - 1\n        for layer in self.past_key_values:\n            layer[0] = layer[0][keep_indices, :, -decoder_past_seq_len:]\n            layer[1] = layer[1][keep_indices, :, -decoder_past_seq_len:]\n            layer[2] = layer[2][keep_indices, :, -max_input_length:]\n            layer[3] = 
layer[3][keep_indices, :, -max_input_length:]\n\n        top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]\n        max_tokens = (\n            len(request_ids) * (max_input_length + max_decoder_input_length)\n            + remaining_decode_tokens\n        )\n\n        self.requests = requests\n        self.requests_idx_mapping = requests_idx_mapping\n        self.input_ids = None\n        self.all_decoder_input_ids = all_decoder_input_ids\n        self.input_lengths = input_lengths\n        self.decoder_input_lengths = decoder_input_lengths\n        self.prefix_offsets = prefix_offsets\n        self.read_offsets = read_offsets\n        self.next_token_choosers = next_token_choosers\n        self.stopping_criterias = stopping_criterias\n        self.top_n_tokens = top_n_tokens\n        self.top_n_tokens_tensor = top_n_tokens_tensor\n        self.max_input_length = max_input_length\n        self.max_decoder_input_length = max_decoder_input_length\n        self.padding_right_offset = padding_right_offset\n        self.max_tokens = max_tokens\n\n        return self\n\n    @classmethod\n    @tracer.start_as_current_span(\"concatenate\")\n    def concatenate(cls, batches: List[\"Seq2SeqLMBatch\"]) -> \"Seq2SeqLMBatch\":\n        \"\"\"Concatenate multiple batches together by padding internal torch tensors\"\"\"\n\n        # Used for padding\n        total_batch_size = 0\n        max_input_length = 0\n        max_decoder_input_length = 0\n        padding_right_offset = 0\n        for batch in batches:\n            total_batch_size += len(batch)\n            max_input_length = max(max_input_length, batch.max_input_length)\n            max_decoder_input_length = max(\n                max_decoder_input_length, batch.max_decoder_input_length\n            )\n            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)\n\n        # Batch attributes\n        requests = []\n        requests_idx_mapping = {}\n        all_decoder_input_ids = 
[]\n        input_lengths = []\n        decoder_input_lengths = []\n        prefix_offsets = []\n        read_offsets = []\n        next_token_choosers = []\n        stopping_criterias = []\n        top_n_tokens = []\n        max_tokens = 0\n\n        # Batch tensors\n        attention_mask = None\n        decoder_input_ids = None\n        decoder_attention_mask = None\n        encoder_last_hidden_state = None\n        top_n_tokens_tensor = None\n        past_key_values = []\n\n        # Used for slicing correctly inside the tensors\n        # Equivalent to a cumsum on batch sizes\n        start_index = 0\n\n        for i, batch in enumerate(batches):\n            # Extend all list attributes\n            requests.extend(batch.requests)\n            all_decoder_input_ids.extend(batch.all_decoder_input_ids)\n            input_lengths.extend(batch.input_lengths)\n            decoder_input_lengths.extend(batch.decoder_input_lengths)\n            prefix_offsets.extend(batch.prefix_offsets)\n            read_offsets.extend(batch.read_offsets)\n            next_token_choosers.extend(batch.next_token_choosers)\n            stopping_criterias.extend(batch.stopping_criterias)\n            top_n_tokens.extend(batch.top_n_tokens)\n\n            if i == 0:\n                requests_idx_mapping = batch.requests_idx_mapping\n            else:\n                # We need to offset the mapping for each batch by the cumulative batch size\n                for k, v in batch.requests_idx_mapping.items():\n                    requests_idx_mapping[k] = v + start_index\n\n            # Slicing end index for this batch\n            end_index = start_index + len(batch)\n\n            # We only concatenate batches that did at least one step\n            if batch.encoder_last_hidden_state is None:\n                raise ValueError(\"Batch encoder_last_hidden_state cannot be None\")\n\n            # Create padded tensor\n            if attention_mask is None:\n                attention_mask = 
batch.attention_mask.new_zeros(\n                    (total_batch_size, max_input_length),\n                )\n            # Copy to correct indices\n            attention_mask[start_index:end_index, -batch.max_input_length :] = (\n                batch.attention_mask[:, -batch.max_input_length :]\n            )\n\n            # Create padded tensor\n            if decoder_input_ids is None:\n                decoder_input_ids = batch.decoder_input_ids.new_zeros(\n                    (total_batch_size, 1),\n                )\n            # Copy to correct indices\n            decoder_input_ids[start_index:end_index] = batch.decoder_input_ids\n\n            # Create padded tensor\n            if decoder_attention_mask is None:\n                # As decoder_attention_mask might not exist, we use `batch.attention_mask` for device here\n                decoder_attention_mask = batch.attention_mask.new_zeros(\n                    (total_batch_size, max_decoder_input_length + padding_right_offset),\n                )\n            # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated\n            # this batch. 
All generations are of length `batch.max_decoder_input_length`.\n            left_offset = max_decoder_input_length - batch.max_decoder_input_length\n            if batch.decoder_attention_mask is None:\n                decoder_attention_mask[\n                    start_index:end_index,\n                    left_offset:-padding_right_offset,\n                ] = 1\n            # If it exists, we need to index\n            else:\n                batch_left_offset = (\n                    batch.decoder_attention_mask.shape[1]\n                    - batch.max_decoder_input_length\n                    - batch.padding_right_offset\n                )\n                decoder_attention_mask[\n                    start_index:end_index,\n                    left_offset:-padding_right_offset,\n                ] = batch.decoder_attention_mask[\n                    :,\n                    batch_left_offset : -batch.padding_right_offset,\n                ]\n\n            # Create padded tensor\n            if encoder_last_hidden_state is None:\n                encoder_last_hidden_state = batch.encoder_last_hidden_state.new_zeros(\n                    (\n                        total_batch_size,\n                        max_input_length,\n                        batch.encoder_last_hidden_state.shape[-1],\n                    ),\n                )\n\n            if top_n_tokens_tensor is None:\n                top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros(\n                    total_batch_size,\n                )\n            top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor\n\n            # Copy to correct indices\n            encoder_last_hidden_state[\n                start_index:end_index, -batch.max_input_length :, :\n            ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]\n            batch.encoder_last_hidden_state = None\n\n            # Ensure that we can update tensors in-place\n            if 
isinstance(batch.past_key_values[0], tuple):\n                batch.past_key_values = [\n                    [t for t in layer] for layer in batch.past_key_values\n                ]\n\n            # Add eventual padding tokens that were added while concatenating\n            max_tokens += batch.max_tokens + (\n                max_input_length\n                - batch.max_input_length\n                + max_decoder_input_length\n                - batch.max_decoder_input_length\n            ) * len(batch)\n\n            start_index = end_index\n\n        # Determine shapes for new past kv tensors\n        first_past_kvs = batches[0].past_key_values\n        _, num_heads, _, head_dim = first_past_kvs[0][0].shape\n\n        padded_dec_t_shape = (\n            total_batch_size,\n            num_heads,\n            (max_decoder_input_length - 1),\n            head_dim,\n        )\n\n        padded_enc_t_shape = (\n            total_batch_size,\n            num_heads,\n            max_input_length,\n            head_dim,\n        )\n\n        # Iterate over attention layers\n        for j in range(len(first_past_kvs)):\n            past_key_values.append([])\n\n            # Decoder past\n            for k in range(0, 2):\n                # Initialize tensors\n                padded_past_values = first_past_kvs[j][k].new_zeros(padded_dec_t_shape)\n                past_key_values[j].append(padded_past_values)\n\n                start_index = 0\n                for batch in batches:\n                    t = batch.past_key_values[j][k]\n                    # Clear reference to the original tensor\n                    batch.past_key_values[j][k] = None\n                    # Slicing end index for this batch\n                    end_index = start_index + len(batch)\n                    # We slice the past keys and values to remove the padding from previous batches\n                    past_seq_len = batch.max_decoder_input_length - 1\n                    
padded_past_values[start_index:end_index, :, -past_seq_len:, :] = t[\n                        :, :, -past_seq_len:, :\n                    ]\n                    del t\n\n                    start_index = end_index\n\n            # Encoder past\n            for k in range(2, 4):\n                # Initialize tensors\n                padded_past_values = first_past_kvs[j][k].new_zeros(padded_enc_t_shape)\n                past_key_values[j].append(padded_past_values)\n\n                start_index = 0\n                for batch in batches:\n                    t = batch.past_key_values[j][k]\n                    # Clear reference to the original tensor\n                    batch.past_key_values[j][k] = None\n                    # Slicing end index for this batch\n                    end_index = start_index + len(batch)\n                    # We slice the past keys and values to remove the padding from previous batches\n                    padded_past_values[\n                        start_index:end_index, :, -batch.max_input_length :, :\n                    ] = t[:, :, -batch.max_input_length :, :]\n                    del t\n\n                    start_index = end_index\n\n        return cls(\n            batch_id=batches[0].batch_id,\n            requests=requests,\n            requests_idx_mapping=requests_idx_mapping,\n            input_ids=None,\n            attention_mask=attention_mask,\n            decoder_input_ids=decoder_input_ids,\n            all_decoder_input_ids=all_decoder_input_ids,\n            decoder_attention_mask=decoder_attention_mask,\n            encoder_last_hidden_state=encoder_last_hidden_state,\n            past_key_values=past_key_values,\n            input_lengths=input_lengths,\n            decoder_input_lengths=decoder_input_lengths,\n            prefix_offsets=prefix_offsets,\n            read_offsets=read_offsets,\n            next_token_choosers=next_token_choosers,\n            stopping_criterias=stopping_criterias,\n            
top_n_tokens=top_n_tokens,\n            top_n_tokens_tensor=top_n_tokens_tensor,\n            max_input_length=max_input_length,\n            max_decoder_input_length=max_decoder_input_length,\n            padding_right_offset=padding_right_offset,\n            max_tokens=max_tokens,\n        )\n\n    def __len__(self):\n        return len(self.requests)\n\n\nclass Seq2SeqLM(Model):\n    def __init__(\n        self,\n        model_id: str,\n        model_class,\n        revision: Optional[str] = None,\n        quantize: Optional[str] = None,\n        speculator: Optional[str] = None,\n        dtype: Optional[torch.dtype] = None,\n        default_dtype=torch.float16,\n        trust_remote_code: bool = False,\n        config_class=AutoConfig,\n        tokenizer_class=AutoTokenizer,\n        aliases=None,\n    ):\n        self.quantize = quantize\n        self.process_group, rank, world_size = initialize_torch_distributed()\n        if torch.cuda.is_available():\n            device = torch.device(f\"cuda:{rank}\")\n            dtype = default_dtype if dtype is None else dtype\n        elif hasattr(torch, \"xpu\") and torch.xpu.is_available():\n            device = torch.device(f\"xpu:{rank}\")\n            dtype = default_dtype if dtype is None else dtype\n        elif SYSTEM == \"ipex\":\n            device = torch.device(\"cpu\")\n            # Float16 doesn't exist on target.\n            dtype = torch.bfloat16 if dtype is None else dtype\n        else:\n            device = torch.device(\"cpu\")\n            dtype = torch.float32 if dtype is None else dtype\n\n        config = config_class.from_pretrained(\n            model_id,\n            revision=revision,\n            trust_remote_code=trust_remote_code,\n        )\n        config.quantize = quantize\n        config.speculator = speculator\n\n        tokenizer = tokenizer_class.from_pretrained(\n            model_id,\n            revision=revision,\n            padding_side=\"left\",\n            
truncation_side=\"left\",\n            trust_remote_code=trust_remote_code,\n        )\n        tokenizer.bos_token_id = config.decoder_start_token_id\n\n        weights_loader = get_loader(\n            quantize=quantize, model_id=model_id, revision=revision\n        )\n        torch.distributed.barrier(group=self.process_group)\n        filenames = weight_files(model_id, revision=revision, extension=\".safetensors\")\n        weights = Weights(\n            filenames,\n            device=device,\n            dtype=dtype,\n            process_group=self.process_group,\n            aliases=aliases,\n            weights_loader=weights_loader,\n        )\n        if config.quantize in [\"awq\", \"exl2\", \"gptq\", \"marlin\"]:\n            weights._set_gptq_params(model_id, revision)\n\n        model = model_class(config, weights)\n\n        torch.distributed.barrier(group=self.process_group)\n        super().__init__(\n            model_id=model_id,\n            model=model,\n            tokenizer=tokenizer,\n            requires_padding=True,\n            dtype=dtype,\n            device=device,\n            rank=rank,\n            world_size=world_size,\n        )\n\n    @classmethod\n    def fallback(\n        cls,\n        model_id: str,\n        revision: Optional[str] = None,\n        quantize: Optional[str] = None,\n        speculator: Optional[str] = None,\n        dtype: Optional[torch.dtype] = None,\n        trust_remote_code: bool = False,\n    ):\n        if speculator:\n            raise RuntimeError(\"Speculator decoding is not enabled for AutoModel\")\n\n        device_count = 0\n        if torch.cuda.is_available():\n            device = torch.device(\"cuda\")\n            device_count = torch.cuda.device_count()\n            dtype = torch.float16 if dtype is None else dtype\n        elif hasattr(torch, \"xpu\") and torch.xpu.is_available():\n            device = torch.device(\"xpu\")\n            device_count = torch.xpu.device_count()\n            
dtype = torch.float16 if dtype is None else dtype\n        else:\n            if quantize:\n                raise ValueError(\"quantization is not available on CPU\")\n\n            device = torch.device(\"cpu\")\n            dtype = torch.float32 if dtype is None else dtype\n\n        model = AutoModelForSeq2SeqLM.from_pretrained(\n            model_id,\n            revision=revision,\n            torch_dtype=dtype,\n            device_map=(\"auto\" if device_count > 1 else None),\n            load_in_8bit=quantize == \"bitsandbytes\",\n            trust_remote_code=trust_remote_code,\n        )\n        if device_count == 1:\n            model = model.to(device)\n\n        tokenizer = AutoTokenizer.from_pretrained(\n            model_id,\n            revision=revision,\n            padding_side=\"left\",\n            truncation_side=\"left\",\n            trust_remote_code=trust_remote_code,\n        )\n        tokenizer.bos_token_id = model.config.decoder_start_token_id\n\n        self = cls.__new__(\n            cls,\n        )\n        super().__init__(\n            self,\n            model_id=model_id,\n            model=model,\n            tokenizer=tokenizer,\n            requires_padding=True,\n            dtype=dtype,\n            device=device,\n        )\n        self.quantize = quantize\n        return self\n\n    @property\n    def batch_type(self) -> Type[Seq2SeqLMBatch]:\n        return Seq2SeqLMBatch\n\n    def forward(\n        self,\n        input_ids,\n        attention_mask,\n        decoder_input_ids,\n        decoder_attention_mask: Optional,\n        encoder_last_hidden_state: Optional,\n        past_key_values: Optional = None,\n    ) -> Tuple[\n        torch.Tensor,\n        Optional[torch.Tensor],\n        torch.Tensor,\n        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],\n    ]:\n        # Model Forward\n        outputs = self.model.forward(\n            input_ids=input_ids,\n            
attention_mask=attention_mask,\n            decoder_input_ids=decoder_input_ids,\n            decoder_attention_mask=decoder_attention_mask,\n            encoder_outputs=encoder_last_hidden_state,\n            past_key_values=past_key_values,\n            use_cache=True,\n        )\n        if isinstance(outputs, tuple):\n            # Our custom models\n            outputs, speculative_logits = outputs\n        else:\n            # Generic transformers models\n            speculative_logits = None\n        return (\n            outputs.logits,\n            speculative_logits,\n            outputs.encoder_last_hidden_state,\n            outputs.past_key_values,\n        )\n\n    @tracer.start_as_current_span(\"generate_token\")\n    def generate_token(\n        self, batch: Seq2SeqLMBatch\n    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch], Tuple[int, int]]:\n        start = time.time_ns()\n        if batch.decoder_attention_mask is not None:\n            # slice to the correct shape\n            decoder_attention_mask = batch.decoder_attention_mask[\n                :, : -batch.padding_right_offset\n            ]\n        else:\n            decoder_attention_mask = None\n\n        # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`\n        # internally...\n        if batch.encoder_last_hidden_state is not None:\n            encoder_last_hidden_state = [batch.encoder_last_hidden_state]\n        else:\n            encoder_last_hidden_state = None\n\n        logits, speculative_logits, encoder_last_hidden_state, past = self.forward(\n            batch.input_ids,\n            batch.attention_mask,\n            batch.decoder_input_ids,\n            decoder_attention_mask,\n            encoder_last_hidden_state,\n            batch.past_key_values,\n        )\n\n        # Speculation is not active for seq2seq\n        accepted_ids = torch.ones_like(batch.decoder_input_ids)[:, 0]\n        
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(\n            batch.top_n_tokens,\n            batch.top_n_tokens_tensor,\n            torch.log_softmax(logits[:, -1], -1),\n            accepted_ids,\n        )\n\n        start_decode = time.time_ns()\n\n        # Finished requests\n        generations: List[Generation] = []\n        stopped = True\n\n        # Zipped iterator\n        iterator = zip(\n            batch.requests,\n            batch.input_lengths,\n            batch.prefix_offsets,\n            batch.read_offsets,\n            batch.decoder_input_lengths,\n            logits,\n            batch.next_token_choosers,\n            batch.stopping_criterias,\n            batch.all_decoder_input_ids,\n            batch.top_n_tokens,\n            batch_top_token_ids,\n            batch_top_token_logprobs,\n        )\n\n        # For each member of the batch\n        for i, (\n            request,\n            input_length,\n            prefix_offset,\n            read_offset,\n            decoder_input_length,\n            logits,\n            next_token_chooser,\n            stopping_criteria,\n            all_decoder_input_ids,\n            top_n_tokens,\n            top_token_ids,\n            top_token_logprobs,\n        ) in enumerate(iterator):\n            # Select next token\n            next_token_id, logprobs = next_token_chooser(\n                all_decoder_input_ids.view(1, -1), logits[-1:, :]\n            )\n\n            # Append next token to decoder tokens\n            all_decoder_input_ids = torch.cat(\n                [all_decoder_input_ids, next_token_id.squeeze(1)]\n            )\n            new_decoder_input_length = decoder_input_length + 1\n\n            # Generated token\n            next_token_logprob = logprobs[-1, next_token_id]\n            next_token_id_squeezed = next_token_id.squeeze()\n            next_token_text, prefix_offset, read_offset = self.decode_token(\n                all_decoder_input_ids, 
prefix_offset, read_offset\n            )\n\n            # Evaluate stopping criteria\n            stop, reason = stopping_criteria(next_token_id, next_token_text)\n\n            if not stop:\n                stopped = False\n\n            # Shard generations\n            # All generations will be appended in the rust sharded client\n            if i % self.world_size == self.rank:\n                if stop:\n                    # Slice with decoder_input_length to remove padding\n                    # Decode all tokens\n                    output_text, _, _ = self.decode_token(\n                        all_decoder_input_ids,\n                        prefix_offset=len(all_decoder_input_ids)\n                        - decoder_input_length\n                        - 1,\n                        read_offset=len(all_decoder_input_ids) - decoder_input_length,\n                        skip_special_tokens=True,\n                    )\n\n                    # Get seed\n                    if isinstance(next_token_chooser.choice, Sampling):\n                        seed = next_token_chooser.choice.seed\n                    else:\n                        seed = None\n\n                    generated_text = GeneratedText(\n                        output_text, stopping_criteria.current_tokens, reason, seed\n                    )\n                else:\n                    generated_text = None\n\n                # Prefill\n                if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:\n                    prefill_tokens = Tokens(\n                        [self.tokenizer.bos_token_id],\n                        [float(\"nan\")],\n                        [self.tokenizer.bos_token],\n                        [False],\n                    )\n                else:\n                    prefill_tokens = None\n\n                if top_n_tokens > 0:\n                    all_top_tokens = []\n                    for top_token_ids, top_token_logprobs in zip(\n           
             top_token_ids, top_token_logprobs\n                    ):\n                        toptoken_texts = self.tokenizer.batch_decode(\n                            top_token_ids,\n                            clean_up_tokenization_spaces=False,\n                            skip_special_tokens=False,\n                        )\n                        special_toptokens = [\n                            token_id in self.all_special_ids\n                            for token_id in top_token_ids\n                        ]\n                        top_tokens = Tokens(\n                            top_token_ids,\n                            top_token_logprobs,\n                            toptoken_texts,\n                            special_toptokens,\n                        )\n                        all_top_tokens.append(top_tokens)\n                    top_tokens = all_top_tokens\n                else:\n                    top_tokens = None\n\n                generation = Generation(\n                    request.id,\n                    prefill_tokens,\n                    Tokens(\n                        [next_token_id_squeezed],\n                        [next_token_logprob],\n                        [next_token_text],\n                        [next_token_id_squeezed.item() in self.all_special_ids],\n                    ),\n                    generated_text,\n                    top_tokens,\n                )\n\n                generations.append(generation)\n\n            # Update values\n            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(\n                next_token_id_squeezed.item()\n            )\n            batch.decoder_input_ids[i] = next_token_id\n            batch.all_decoder_input_ids[i] = all_decoder_input_ids\n            batch.input_lengths[i] = input_length\n            batch.decoder_input_lengths[i] = new_decoder_input_length\n            batch.prefix_offsets[i] = prefix_offset\n            
batch.read_offsets[i] = read_offset\n            batch.max_input_length = max(batch.max_input_length, input_length)\n            batch.max_decoder_input_length = max(\n                batch.max_decoder_input_length, new_decoder_input_length\n            )\n\n        # We finished all generations in the batch; there is no next batch\n        if stopped:\n            forward_ns = start_decode - start\n            decode_ns = time.time_ns() - start_decode\n            return generations, None, (forward_ns, decode_ns)\n\n        # We don't need input_ids after the prefill forward\n        batch.input_ids = None\n        batch.encoder_last_hidden_state = encoder_last_hidden_state\n        batch.past_key_values = past\n        # Update decoder_attention_mask as we added a new token to input_ids\n        if batch.decoder_attention_mask is not None:\n            batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1\n        batch.padding_right_offset -= 1\n\n        forward_ns = start_decode - start\n        decode_ns = time.time_ns() - start_decode\n        return generations, batch, (forward_ns, decode_ns)\n"
  },
  {
    "path": "server/text_generation_server/models/transformers_flash_causal_lm.py",
    "content": "import math\nfrom typing import List, Optional\n\nimport torch\nfrom opentelemetry import trace\nfrom transformers import AutoTokenizer, AutoModelForCausalLM\nimport transformers.modeling_utils\n\nfrom text_generation_server.models.flash_causal_lm import FlashCausalLM\nfrom text_generation_server.utils import initialize_torch_distributed\n\nfrom text_generation_server.layers.attention import paged_attention, attention, Seqlen\nfrom text_generation_server.layers.attention.kv_cache import KVScales, KVCache\nfrom text_generation_server.models.globals import ATTENTION\nfrom text_generation_server.utils.import_utils import SYSTEM\n\ntracer = trace.get_tracer(__name__)\n\n\ndef tgi_flash_attention_forward(\n    module,\n    query_states: torch.Tensor,\n    key_states: torch.Tensor,\n    value_states: torch.Tensor,\n    attention_mask: Optional[torch.Tensor],  # This is a positional arg in Transformers\n    kv_cache: List[KVCache],\n    kv_head_mapping: torch.Tensor,\n    slots: torch.Tensor,\n    cu_seqlen_prefill: Optional[torch.Tensor],\n    seqlen: Seqlen,\n    block_tables: torch.Tensor,\n    max_s: int,\n    kv_scales: KVScales,\n    softmax_scale: Optional[float] = None,\n    sliding_window: Optional[int] = None,\n    softcap: Optional[float] = None,\n    **kwargs,  # This is needed to \"absorb\" other args passed by Transformers modeling\n):\n\n    kv_cache = kv_cache[module.layer_idx]\n    query_states = query_states.transpose(1, 2).squeeze(dim=0)\n    key_states = key_states.transpose(1, 2).squeeze(dim=0)\n    value_states = value_states.transpose(1, 2).squeeze(dim=0)\n\n    # Take care of updating the cache in-place\n    kv_cache.store(key=key_states, value=value_states, slots=slots, kv_scales=kv_scales)\n\n    _, num_heads, head_dim = query_states.shape\n    softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale\n    sliding_window = -1 if sliding_window is None else sliding_window\n\n    if cu_seqlen_prefill is not 
None:\n        attn_output = attention(\n            query=query_states,\n            key=key_states,\n            value=value_states,\n            kv_cache=kv_cache,\n            kv_scales=kv_scales,\n            seqlen=seqlen,\n            block_tables=block_tables,\n            softmax_scale=softmax_scale,\n            window_size_left=sliding_window,\n            softcap=softcap,\n        )\n    else:\n        attn_output = paged_attention(\n            query_states,\n            kv_cache,\n            kv_head_mapping,\n            softmax_scale,\n            block_tables,\n            seqlen,\n            max_s,\n            kv_scales=kv_scales,\n            softcap=softcap,\n            window_size_left=sliding_window,\n        )\n\n    attn_output = attn_output.view(-1, num_heads * head_dim)\n\n    return attn_output, None\n\n\ntransformers.modeling_utils.ALL_ATTENTION_FUNCTIONS[\"tgi\"] = tgi_flash_attention_forward\n\n# The base TP plan of these models has replicated q/k/v. This means that each process will see the full states,\n# hence we should not divide the number of heads by the world size. This is a known waste of VRAM (the cache\n# will be fully replicated on each process) and GPU communication (additional all-gather operations), however due\n# to internal constraints it was not (yet?) 
possible to circumvent\nREPLICATED_ATTENTION_MODELS = [\n    \"olmo2\",\n    \"phi3\",\n]\n\n\nclass TransformersFlashCausalLM(FlashCausalLM):\n    def __init__(\n        self,\n        model_id: str,\n        revision: Optional[str] = None,\n        quantize: Optional[str] = None,\n        speculator: Optional[str] = None,\n        dtype: Optional[torch.dtype] = None,\n        default_dtype=torch.float16,\n        trust_remote_code: bool = False,\n        tokenizer_class=AutoTokenizer,\n        kv_cache_dtype: Optional[torch.dtype] = None,\n    ):\n        self.quantize = quantize\n        self.process_group, rank, world_size = initialize_torch_distributed()\n\n        if speculator:\n            raise RuntimeError(\"Speculator decoding is not enabled for AutoModel\")\n\n        if torch.cuda.is_available():\n            device = torch.device(f\"cuda:{rank}\")\n            dtype = default_dtype if dtype is None else dtype\n        elif SYSTEM == \"ipex\":\n            if hasattr(torch, \"xpu\") and torch.xpu.is_available():\n                device = torch.device(f\"xpu:{rank}\")\n            else:\n                device = torch.device(\"cpu\")\n            dtype = default_dtype if dtype is None else dtype\n        else:\n            raise ValueError(\n                \"Flash `Transformers` modeling backend is not available on cpu.\"\n            )\n\n        tokenizer = tokenizer_class.from_pretrained(\n            model_id,\n            revision=revision,\n            padding_side=\"left\",\n            truncation_side=\"left\",\n            trust_remote_code=trust_remote_code,\n        )\n\n        model = AutoModelForCausalLM.from_pretrained(\n            model_id,\n            revision=revision,\n            torch_dtype=dtype,\n            load_in_8bit=quantize == \"bitsandbytes\",\n            trust_remote_code=trust_remote_code,\n            attn_implementation=\"tgi\",\n            device_map=device if world_size == 1 else None,\n            
tp_plan=\"auto\" if world_size > 1 else None,\n        )\n\n        torch.distributed.barrier(group=self.process_group)\n\n        if tokenizer.pad_token_id is None:\n            if model.config.pad_token_id is not None:\n                tokenizer.pad_token_id = model.config.pad_token_id\n            elif model.config.eos_token_id is not None and isinstance(\n                model.config.eos_token_id, int\n            ):\n                tokenizer.pad_token_id = model.config.eos_token_id\n            elif tokenizer.eos_token_id is not None:\n                tokenizer.pad_token_id = tokenizer.eos_token_id\n            else:\n                tokenizer.add_special_tokens({\"pad_token\": \"[PAD]\"})\n\n        self.num_layers = model.config.num_hidden_layers\n        self.num_heads = model.config.num_attention_heads\n        self.num_kv_heads = model.config.num_key_value_heads\n        # Some models use GQA and different sizes for o_proj\n        # and q_proj, that allows for that.\n        if hasattr(model.config, \"head_dim\"):\n            self.head_size = model.config.head_dim\n        else:\n            self.head_size = (\n                model.config.hidden_size // model.config.num_attention_heads\n            )\n\n        # Skip it for models in the exception list\n        if model.config.model_type not in REPLICATED_ATTENTION_MODELS:\n            self.num_heads = self.num_heads // self.process_group.size()\n            self.num_kv_heads = (\n                self.num_kv_heads // self.process_group.size()\n                if self.num_kv_heads > 1\n                else self.num_kv_heads\n            )\n\n        self.cuda_graphs = {}\n        self.kv_cache = []\n        self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype\n\n        if ATTENTION == \"flashinfer\":\n            from text_generation_server.layers.attention.flashinfer import (\n                create_prefill_state,\n                create_decode_state,\n                
create_prefill_with_paged_kv_state,\n            )\n\n            self.prefill_state = create_prefill_state(device=device)\n            self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state(\n                device=device\n            )\n\n            self.decode_state = create_decode_state(\n                device=device,\n                num_heads=self.num_heads,\n                num_kv_heads=self.num_kv_heads,\n            )\n\n        self.num_groups = self.num_heads // self.num_kv_heads\n\n        # Those will never change and will be used in the forwards\n        self.kv_head_mapping = torch.arange(\n            0, self.num_kv_heads, dtype=torch.int32, device=device\n        ).repeat_interleave(self.num_groups)\n        # This means no scale\n        self.kv_scales = KVScales(\n            torch.tensor(1.0, device=device),\n            torch.tensor(1.0, device=device),\n        )\n\n        # Skip FlashCausalLM init.\n        super(FlashCausalLM, self).__init__(\n            model_id=model_id,\n            model=model,\n            tokenizer=tokenizer,\n            requires_padding=False,\n            dtype=dtype,\n            device=device,\n            rank=rank,\n            world_size=world_size,\n        )\n\n        # Monkey patch of `self.model.forward` to match `FlashCausalLM`. 
It avoids duplicating a lot of code\n        # We first copy the original model.forward because we still need it in the monkey patch\n        self.model.original_forward = self.model.forward\n        self.model.forward = self._model_forward\n\n        torch.distributed.barrier(group=self.process_group)\n\n    @classmethod\n    def fallback(\n        cls,\n        model_id: str,\n        revision: Optional[str] = None,\n        quantize: Optional[str] = None,\n        speculator: Optional[str] = None,\n        dtype: Optional[torch.dtype] = None,\n        trust_remote_code: bool = False,\n    ):\n        return cls(\n            model_id=model_id,\n            revision=revision,\n            quantize=quantize,\n            speculator=speculator,\n            dtype=dtype,\n            trust_remote_code=trust_remote_code,\n        )\n\n    def _model_forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[KVCache],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        lm_head_indices: Optional[torch.Tensor],\n        prefill_cache_indices=None,  # not used, but passed to match original signature\n        adapter_data=None,  # not supported, but passed to match original signature\n    ):\n        # A value of `None` (i.e. 
no logit slicing) translates to `0` in Transformers\n        logits_to_keep = lm_head_indices if lm_head_indices is not None else 0\n\n        # This is equivalent to `self.model.forward`, see the monkey patch in __init__\n        logits = self.model.original_forward(\n            input_ids=input_ids.unsqueeze(0),  # expand dim to fit Transformers\n            position_ids=position_ids.unsqueeze(0),  # expand dim to fit Transformers\n            past_key_values=None,  # we use self.kv_cache instead of transformers cache object\n            use_cache=False,  # we use self.kv_cache instead of transformers cache object\n            logits_to_keep=logits_to_keep,\n            return_dict=True,\n            cu_seqlen_prefill=cu_seqlen_prefill,\n            kv_cache=kv_cache,\n            block_tables=block_tables,\n            slots=slots,\n            seqlen=seqlen,\n            max_s=max_s,\n            kv_head_mapping=self.kv_head_mapping,\n            kv_scales=self.kv_scales,\n        ).logits.squeeze(dim=0)\n\n        return logits, None\n"
  },
  {
    "path": "server/text_generation_server/models/transformers_flash_vlm.py",
    "content": "import math\nfrom typing import List, Optional\n\nimport torch\nfrom opentelemetry import trace\nfrom transformers import AutoTokenizer, AutoProcessor\nimport transformers.modeling_utils\n\nfrom text_generation_server.models.flash_causal_lm import FlashCausalLM\nfrom text_generation_server.models.vlm_causal_lm import VlmCausalLM, VlmCausalLMBatch\nfrom text_generation_server.utils import initialize_torch_distributed\n\nfrom text_generation_server.layers.attention import paged_attention, attention, Seqlen\nfrom text_generation_server.layers.attention.kv_cache import KVScales, KVCache\nfrom text_generation_server.models.globals import ATTENTION\nimport torch.nn.functional as F\nfrom text_generation_server.utils.import_utils import SYSTEM\n\ntracer = trace.get_tracer(__name__)\n\n# The base TP plan of these models has replicated q/k/v. This means that each process will see the full states,\n# hence we should not divide the number of heads by the world size. This is a known waste of VRAM (the cache\n# will be fully replicated on each process) and GPU communication (additional all-gather operations), however due\n# to internal constraints it was not (yet?) 
possible to circumvent\nREPLICATED_ATTENTION_MODELS = [\n    \"olmo2\",\n    \"phi3\",\n]\n\n\n# # Qwen2VL\n# transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[\n#     \"tgi\"\n# ] = transformers.models.qwen2_vl.modeling_qwen2_vl.QWEN2_VL_VISION_ATTENTION_CLASSES[\n#     \"eager\"\n# ]\ndef tgi_flash_attention_forward(\n    module,\n    query_states: torch.Tensor,\n    key_states: torch.Tensor,\n    value_states: torch.Tensor,\n    attention_mask: Optional[torch.Tensor],  # This is a positional arg in Transformers\n    kv_cache: List[KVCache],\n    kv_head_mapping: torch.Tensor,\n    slots: torch.Tensor,\n    cu_seqlen_prefill: Optional[torch.Tensor],\n    seqlen: Seqlen,\n    block_tables: torch.Tensor,\n    max_s: int,\n    kv_scales: KVScales,\n    softmax_scale: Optional[float] = None,\n    sliding_window: Optional[int] = None,\n    softcap: Optional[float] = None,\n    use_sdpa: Optional[bool] = False,\n    **kwargs,  # This is needed to \"absorb\" other args passed by Transformers modeling\n):\n    kv_cache = kv_cache[module.layer_idx]\n    query_states = query_states.transpose(1, 2).squeeze(dim=0)\n    key_states = key_states.transpose(1, 2).squeeze(dim=0)\n    value_states = value_states.transpose(1, 2).squeeze(dim=0)\n\n    # Take care of updating the cache in-place\n    kv_cache.store(key=key_states, value=value_states, slots=slots, kv_scales=kv_scales)\n\n    _, num_heads, head_dim = query_states.shape\n    softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale\n    sliding_window = -1 if sliding_window is None else sliding_window\n\n    if cu_seqlen_prefill is not None:\n        if not use_sdpa:\n            attn_output = attention(\n                query=query_states,\n                key=key_states,\n                value=value_states,\n                kv_cache=kv_cache,\n                kv_scales=kv_scales,\n                seqlen=seqlen,\n                block_tables=block_tables,\n         
       softmax_scale=softmax_scale,\n                window_size_left=sliding_window,\n                softcap=softcap,\n            )\n        else:\n            lengths = cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]\n            max_length = max(lengths)\n            attention_mask = attention_mask[:, :, :, :max_length]\n            enable_gqa = query_states.shape[1] != key_states.shape[1]\n            # Split tensors using vectorized split\n            query_list = torch.split(query_states, lengths.tolist(), dim=0)\n            key_list = torch.split(key_states, lengths.tolist(), dim=0)\n            value_list = torch.split(value_states, lengths.tolist(), dim=0)\n\n            padded_query = torch.nn.utils.rnn.pad_sequence(query_list, batch_first=True)\n            padded_key = torch.nn.utils.rnn.pad_sequence(key_list, batch_first=True)\n            padded_value = torch.nn.utils.rnn.pad_sequence(value_list, batch_first=True)\n\n            padded_query = padded_query.transpose(1, 2).contiguous()\n            padded_key = padded_key.transpose(1, 2).contiguous()\n            padded_value = padded_value.transpose(1, 2).contiguous()\n\n            # Compute attention\n            attn_output = F.scaled_dot_product_attention(\n                padded_query,\n                padded_key,\n                padded_value,\n                attn_mask=attention_mask,\n                scale=softmax_scale,\n                enable_gqa=enable_gqa,\n            )\n\n            attn_output = attn_output.transpose(\n                1, 2\n            )  # [batch_size, seq_len, num_heads, head_dim]\n            max_seq_len = padded_query.size(2)\n            seq_range = torch.arange(max_seq_len, device=padded_query.device).unsqueeze(\n                0\n            )\n            lengths_tensor = torch.tensor(\n                lengths, device=padded_query.device\n            ).unsqueeze(1)\n            mask = seq_range < lengths_tensor  # [batch, max_seq_len]\n            
attn_output = attn_output[mask]  # [total_seq_len, num_heads, head_dim]\n\n    else:\n        attn_output = paged_attention(\n            query_states,\n            kv_cache,\n            kv_head_mapping,\n            softmax_scale,\n            block_tables,\n            seqlen,\n            max_s,\n            kv_scales=kv_scales,\n            softcap=softcap,\n            window_size_left=sliding_window,\n        )\n\n    attn_output = attn_output.view(-1, num_heads * head_dim)\n\n    return attn_output, None\n\n\ntransformers.modeling_utils.ALL_ATTENTION_FUNCTIONS[\"tgi\"] = tgi_flash_attention_forward\n\n\n# TODO: implement\n# tgi_cross_attention_forward\n\n\nclass TransformersFlashVlmCausalLM(VlmCausalLM):\n    def __init__(\n        self,\n        model_id: str,\n        model_class,\n        revision: Optional[str] = None,\n        quantize: Optional[str] = None,\n        speculator: Optional[str] = None,\n        dtype: Optional[torch.dtype] = None,\n        default_dtype=torch.float16,\n        trust_remote_code: bool = False,\n        tokenizer_class=AutoTokenizer,\n        processor_class=AutoProcessor,\n        processor_kwargs=None,\n        kv_cache_dtype: Optional[torch.dtype] = None,\n        batch_class=VlmCausalLMBatch,\n        support_chunking: bool = True,\n    ):\n        self.batch_class = batch_class\n        self.quantize = quantize\n        self.process_group, rank, world_size = initialize_torch_distributed()\n        self.dtype = dtype\n\n        if speculator:\n            raise RuntimeError(\"Speculator decoding is not enabled for AutoModel\")\n\n        if torch.cuda.is_available():\n            device = torch.device(f\"cuda:{rank}\")\n            dtype = default_dtype if dtype is None else dtype\n        elif SYSTEM == \"ipex\":\n            if hasattr(torch, \"xpu\") and torch.xpu.is_available():\n                device = torch.device(f\"xpu:{rank}\")\n            else:\n                device = torch.device(\"cpu\")\n            
dtype = default_dtype if dtype is None else dtype\n        else:\n            raise ValueError(\n                \"Flash `Transformers` modeling backend is not available on cpu.\"\n            )\n\n        tokenizer = tokenizer_class.from_pretrained(\n            model_id,\n            revision=revision,\n            padding_side=\"left\",\n            truncation_side=\"left\",\n            trust_remote_code=trust_remote_code,\n        )\n\n        if processor_kwargs is None:\n            processor_kwargs = {}\n\n        self.processor = processor_class.from_pretrained(\n            model_id,\n            revision=revision,\n            trust_remote_code=trust_remote_code,\n            **processor_kwargs,\n        )\n\n        attn_implementation = {\n            \"text_config\": \"tgi\",\n            \"vision_config\": \"sdpa\",\n        }\n\n        model = model_class.from_pretrained(\n            model_id,\n            revision=revision,\n            torch_dtype=dtype,\n            load_in_8bit=quantize == \"bitsandbytes\",\n            trust_remote_code=trust_remote_code,\n            attn_implementation=attn_implementation,\n            device_map=device if world_size == 1 else None,\n            tp_plan=\"auto\" if world_size > 1 else None,\n        )\n\n        torch.distributed.barrier(group=self.process_group)\n        self.config = model.config\n        config = model.config\n\n        # VLM models define the config we care about in their text_config\n        text_config = getattr(model.config, \"text_config\", None)\n        if text_config is not None:\n            config = text_config\n\n        if tokenizer.pad_token_id is None:\n            if model.config.pad_token_id is not None:\n                tokenizer.pad_token_id = model.config.pad_token_id\n            elif model.config.eos_token_id is not None and isinstance(\n                model.config.eos_token_id, int\n            ):\n                tokenizer.pad_token_id = 
model.config.eos_token_id\n            elif tokenizer.eos_token_id is not None:\n                tokenizer.pad_token_id = tokenizer.eos_token_id\n            else:\n                tokenizer.add_special_tokens({\"pad_token\": \"[PAD]\"})\n\n        self.num_layers = config.num_hidden_layers\n        self.num_heads = config.num_attention_heads\n        self.num_kv_heads = config.num_key_value_heads\n        # Some models use GQA and different sizes for o_proj\n        # and q_proj, that allows for that.\n        if hasattr(config, \"head_dim\"):\n            self.head_size = config.head_dim\n        else:\n            self.head_size = config.hidden_size // config.num_attention_heads\n\n        # Skip it for models in the exception list\n        if config.model_type not in REPLICATED_ATTENTION_MODELS:\n            self.num_heads = self.num_heads // self.process_group.size()\n            self.num_kv_heads = (\n                self.num_kv_heads // self.process_group.size()\n                if self.num_kv_heads > 1\n                else self.num_kv_heads\n            )\n\n        self.cuda_graphs = {}\n        self.kv_cache = []\n        self.kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype\n\n        if ATTENTION == \"flashinfer\":\n            from text_generation_server.layers.attention.flashinfer import (\n                create_prefill_state,\n                create_decode_state,\n                create_prefill_with_paged_kv_state,\n            )\n\n            self.prefill_state = create_prefill_state(device=device)\n            self.prefill_with_paged_kv_state = create_prefill_with_paged_kv_state(\n                device=device\n            )\n\n            self.decode_state = create_decode_state(\n                device=device,\n                num_heads=self.num_heads,\n                num_kv_heads=self.num_kv_heads,\n            )\n\n        self.num_groups = self.num_heads // self.num_kv_heads\n\n        # Those will never change and will 
be used in the forwards\n        self.kv_head_mapping = torch.arange(\n            0, self.num_kv_heads, dtype=torch.int32, device=device\n        ).repeat_interleave(self.num_groups)\n        # This means no scale\n        self.kv_scales = KVScales(\n            torch.tensor(1.0, device=device),\n            torch.tensor(1.0, device=device),\n        )\n\n        # Skip FlashCausalLM init.\n        super(FlashCausalLM, self).__init__(\n            model_id=model_id,\n            model=model,\n            tokenizer=tokenizer,\n            requires_padding=False,\n            dtype=dtype,\n            device=device,\n            rank=rank,\n            world_size=world_size,\n            support_chunking=support_chunking,\n        )\n\n        # Monkey patch of `self.model.forward` to match `FlashCausalLM`. It avoids duplicating a lot of code\n        # We first copy the original model.forward because we still need it in the monkey patch\n        self.model.original_forward = self.model.forward\n        self.model.forward = self._model_forward\n        self.model.get_position_ids = self.get_position_ids\n\n        torch.distributed.barrier(group=self.process_group)\n\n    def get_position_ids(self, input_ids, image_grid_thw, position_ids):\n        return position_ids\n\n    def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):\n        return {\n            \"input_ids\": input_ids.unsqueeze(0),\n            \"position_ids\": position_ids.unsqueeze(0),\n        }\n\n    def post_process_outputs(self, logits, lm_head_indices):\n        return logits.squeeze(dim=0)\n\n    @classmethod\n    def fallback(\n        cls,\n        model_id: str,\n        model_class,\n        revision: Optional[str] = None,\n        quantize: Optional[str] = None,\n        speculator: Optional[str] = None,\n        dtype: Optional[torch.dtype] = None,\n        trust_remote_code: bool = False,\n        batch_class: Optional[type] = VlmCausalLMBatch,\n        
processor_kwargs: Optional[dict] = None,\n        support_chunking: bool = True,\n    ):\n        return cls(\n            model_id=model_id,\n            model_class=model_class,\n            revision=revision,\n            quantize=quantize,\n            speculator=speculator,\n            dtype=dtype,\n            trust_remote_code=trust_remote_code,\n            batch_class=batch_class,\n            processor_kwargs=processor_kwargs,\n            support_chunking=support_chunking,\n        )\n\n    def _model_forward(\n        self,\n        input_ids: torch.Tensor,\n        position_ids: torch.Tensor,\n        cu_seqlen_prefill: Optional[torch.Tensor],\n        kv_cache: List[KVCache],\n        block_tables: torch.Tensor,\n        slots: torch.Tensor,\n        seqlen: Seqlen,\n        max_s: int,\n        lm_head_indices: Optional[torch.Tensor],\n        prefill_cache_indices=None,  # not used, but passed to match original signature\n        adapter_data=None,  # not supported, but passed to match original signature\n        pixel_values: torch.FloatTensor = None,\n        image_grid_thw: Optional[torch.LongTensor] = None,\n        pixel_attention_mask=None,\n        image_sizes: Optional[torch.LongTensor] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n    ):\n        # A value of `None` (i.e. 
no logit slicing) translates to `0` in Transformers\n        logits_to_keep = lm_head_indices if lm_head_indices is not None else 0\n\n        inputs = self.pre_process_inputs(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            cu_seqlen_prefill=cu_seqlen_prefill,\n        )\n        inputs[\"input_ids\"] = None\n\n        # This is equivalent to `self.model.forward`, see the monkey patch in __init__\n        logits = self.model.original_forward(\n            input_ids=inputs[\"input_ids\"],\n            inputs_embeds=inputs_embeds.unsqueeze(0),\n            position_ids=inputs[\"position_ids\"],\n            past_key_values=None,  # we use self.kv_cache instead of transformers cache object\n            use_cache=False,  # we use self.kv_cache instead of transformers cache object\n            logits_to_keep=logits_to_keep,\n            return_dict=True,\n            cu_seqlen_prefill=cu_seqlen_prefill,\n            kv_cache=kv_cache,\n            block_tables=block_tables,\n            slots=slots,\n            seqlen=seqlen,\n            max_s=max_s,\n            kv_head_mapping=self.kv_head_mapping,\n            kv_scales=self.kv_scales,\n            pixel_values=pixel_values,\n            pixel_attention_mask=pixel_attention_mask,\n            image_sizes=image_sizes,\n            image_grid_thw=image_grid_thw,\n            attention_mask=inputs.get(\"attention_mask\", None),\n            use_sdpa=inputs.get(\"use_sdpa\", False),\n            cache_position=inputs.get(\"cache_position\", None),\n        ).logits\n\n        logits = self.post_process_outputs(logits, lm_head_indices)\n\n        return logits, None\n\n\nclass TransformersQwen2VlmCausalLM(TransformersFlashVlmCausalLM):\n    def get_position_ids(self, input_ids: torch.Tensor, image_grid_thw: torch.Tensor):\n        if image_grid_thw is None:\n            return (\n                torch.arange(input_ids.shape[0], device=input_ids.device)\n                
.unsqueeze(1)\n                .repeat(1, 3)\n            )\n\n        spatial_merge_size = self.config.vision_config.spatial_merge_size\n        vision_start_token_id = self.config.vision_start_token_id\n        vision_end_token_id = self.config.vision_end_token_id\n        device = input_ids.device\n        dtype = input_ids.dtype\n        input_ids_len = input_ids.shape[0]\n\n        vision_starts = torch.where(input_ids == vision_start_token_id)[0]\n        vision_ends = torch.where(input_ids == vision_end_token_id)[0]\n        vision_segments = torch.stack((vision_starts, vision_ends), dim=1)\n        prev_vision_end = torch.cat(\n            [torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]]\n        )\n        text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1\n        vision_widths_max = torch.cat(\n            [\n                torch.zeros(1, device=image_grid_thw.device, dtype=dtype),\n                image_grid_thw[:-1, 2] // spatial_merge_size,\n            ]\n        )\n        vision_segment_lengths = vision_widths_max + text_lengths_between_vision\n        vision_segment_lengths = vision_segment_lengths.cumsum(dim=0)\n        text_segment_lengths = vision_segment_lengths - text_lengths_between_vision\n\n        # create position ids for each vision segment based on the image grid\n        llm_pos_ids_list = []\n        for i, _ in enumerate(vision_segments):\n            t, h, w = (\n                image_grid_thw[i][0],\n                image_grid_thw[i][1] // spatial_merge_size,\n                image_grid_thw[i][2] // spatial_merge_size,\n            )\n            t_indices = torch.arange(t, device=device).repeat_interleave(h * w)\n            h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t)\n            w_indices = torch.arange(w, device=device).repeat(t * h)\n            image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)\n\n            # offset by 
the position of the last vision segment\n            im = image_position_ids + vision_segment_lengths[i]\n            llm_pos_ids_list.append(im)\n\n        # create position ids for each text segment\n        text_ranges = [\n            torch.arange(seq_len, device=device).view(1, -1).expand(3, -1)\n            + text_segment_lengths[i]\n            for i, seq_len in enumerate(text_lengths_between_vision)\n        ]\n\n        full_llm_pos_ids_list = [\n            item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist\n        ]\n        # import ipdb\n\n        # ipdb.set_trace()\n        max_s = full_llm_pos_ids_list[-1].max() + 1\n        final_text_len = input_ids_len - vision_ends[-1]\n        if final_text_len > 0:\n            m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)\n            full_llm_pos_ids_list.append(m + max_s)\n\n        position_ids = (\n            torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)\n        )\n        return position_ids\n\n    def post_process_outputs(self, logits, lm_head_indices):\n        return logits.squeeze(dim=0)[lm_head_indices].unsqueeze(0)\n\n    def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):\n        input_ids = input_ids.unsqueeze(0)\n        position_ids = position_ids.transpose(0, 1).unsqueeze(1)\n        return {\"input_ids\": input_ids, \"position_ids\": position_ids}\n\n\nclass TransformersGemma3VlmCausalLM(TransformersFlashVlmCausalLM):\n    def get_attention_mask(self, input_ids, cu_seqlen_prefill):\n        device = input_ids.device\n        dtype = self.dtype\n        min_dtype = torch.finfo(dtype).min\n\n        lengths = (cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]).tolist()\n        batch_size = len(lengths)\n\n        sequence_length = max(lengths)\n        target_length = sequence_length\n        # Create the padding mask from the computed lengths.\n        # pad_mask: [batch, sequence_length] where 
True indicates valid tokens.\n        seq_range = torch.arange(sequence_length, device=device).unsqueeze(0)\n        lengths_tensor = torch.tensor(lengths, device=device).unsqueeze(1)\n        pad_mask = seq_range < lengths_tensor  # shape: [batch, sequence_length]\n\n        # Build the base causal mask (for non-image tokens):\n        causal_mask = torch.tril(\n            torch.ones(\n                (sequence_length, sequence_length), dtype=torch.bool, device=device\n            )\n        )\n        base_mask = pad_mask.unsqueeze(2) & pad_mask.unsqueeze(\n            1\n        )  # [batch, sequence_length, sequence_length]\n        base_mask = base_mask & causal_mask.unsqueeze(0)  # apply causal constraint\n\n        image_token_mask = (input_ids == self.config.image_token_index).to(\n            input_ids.device\n        )\n\n        image_token_mask = torch.nn.utils.rnn.pad_sequence(\n            torch.split(image_token_mask, lengths), batch_first=True, padding_value=0\n        )\n        bidirectional_mask = image_token_mask.unsqueeze(2) & image_token_mask.unsqueeze(\n            1\n        )\n\n        # Combine the causal base mask and the bidirectional mask.\n        combined_mask = torch.logical_or(\n            base_mask.unsqueeze(1), bidirectional_mask.unsqueeze(1)\n        ).to(device)\n        # combined_mask now has shape [batch, 1, sequence_length, sequence_length]\n\n        full_attention_mask = torch.zeros(\n            (batch_size, 1, sequence_length, target_length),\n            device=device,\n            dtype=torch.bool,\n        )\n        full_attention_mask[:, :, :, :sequence_length] = combined_mask\n\n        final_attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(device)\n\n        return final_attention_mask\n\n    def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):\n        inputs = {\n            \"input_ids\": input_ids.unsqueeze(0),\n            \"position_ids\": 
position_ids.unsqueeze(0),\n        }\n\n        if cu_seqlen_prefill is not None:\n            attention_mask = self.get_attention_mask(\n                input_ids.squeeze(0), cu_seqlen_prefill\n            )\n            inputs[\"attention_mask\"] = attention_mask\n            inputs[\"use_sdpa\"] = True\n\n        return inputs\n\n\nclass TransformersLlama4VlmCausalLM(TransformersFlashVlmCausalLM):\n    def pre_process_inputs(self, input_ids, position_ids, cu_seqlen_prefill):\n        inputs = super().pre_process_inputs(input_ids, position_ids, cu_seqlen_prefill)\n        inputs[\"cache_position\"] = position_ids\n        inputs[\"attention_mask\"] = torch.zeros((1, 1, 1, 1), device=input_ids.device)\n        return inputs\n\n    def get_vision_embeds(\n        self,\n        pixel_values: torch.FloatTensor,\n        pixel_attention_mask: Optional[torch.FloatTensor] = None,\n        image_sizes: Optional[torch.Tensor] = None,\n        image_grid_thw: Optional[torch.LongTensor] = None,\n    ):\n        image_features = self.model.get_image_features(\n            pixel_values=pixel_values,\n            vision_feature_layer=self.model.config.vision_config.vision_feature_layer,\n            vision_feature_select_strategy=self.model.config.vision_config.vision_feature_select_strategy,\n            image_sizes=image_sizes,\n        )\n\n        vision_flat = image_features.view(-1, image_features.size(-1))\n        projected_vision_flat = self.model.multi_modal_projector(vision_flat)\n        return projected_vision_flat\n\n    def get_inputs_embeds(self, input_ids, vision_embeds=None):\n        inputs_embeds = self.model.get_input_embeddings()(input_ids)\n\n        if vision_embeds is not None:\n            original_inputs_embeds_shape = inputs_embeds.shape\n            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(\n                -1\n            )\n            final_mask = special_image_mask.to(inputs_embeds.device)\n            
inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1))\n\n            final_mask_1d = final_mask[..., 0].reshape(-1)\n            num_tokens_to_fill = final_mask_1d.sum()\n\n            if num_tokens_to_fill != vision_embeds.size(0):\n                raise ValueError(\n                    f\"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, \"\n                    f\"but multi_modal_projector returned {vision_embeds.size(0)}\"\n                )\n\n            expanded_mask = final_mask_1d.unsqueeze(-1).expand(\n                -1, inputs_embeds.size(-1)\n            )\n            inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, vision_embeds)\n            inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)\n        return inputs_embeds\n"
  },
  {
    "path": "server/text_generation_server/models/types.py",
    "content": "import torch\n\nfrom abc import ABC, abstractmethod\nfrom dataclasses import dataclass\nfrom typing import List, Optional\n\nfrom transformers import PreTrainedTokenizerBase\n\nfrom text_generation_server.pb import generate_pb2\nfrom text_generation_server.pb.generate_pb2 import FinishReason\n\n\nclass Batch(ABC):\n    @abstractmethod\n    def to_pb(self) -> generate_pb2.CachedBatch:\n        raise NotImplementedError\n\n    @classmethod\n    @abstractmethod\n    def from_pb(\n        cls,\n        pb: generate_pb2.Batch,\n        tokenizer: PreTrainedTokenizerBase,\n        dtype: torch.dtype,\n        device: torch.device,\n    ) -> \"Batch\":\n        raise NotImplementedError\n\n    @abstractmethod\n    def filter(self, request_ids: List[int]) -> \"Batch\":\n        raise NotImplementedError\n\n    @classmethod\n    @abstractmethod\n    def concatenate(cls, batches: List[\"Batch\"]) -> \"Batch\":\n        raise NotImplementedError\n\n    @abstractmethod\n    def __len__(self):\n        raise NotImplementedError\n\n\n@dataclass\nclass GeneratedText:\n    text: str\n    generated_tokens: int\n    finish_reason: FinishReason\n    seed: Optional[int]\n\n    def to_pb(self) -> generate_pb2.GeneratedText:\n        return generate_pb2.GeneratedText(\n            text=self.text,\n            generated_tokens=self.generated_tokens,\n            finish_reason=self.finish_reason,\n            seed=self.seed,\n        )\n\n\n@dataclass\nclass Tokens:\n    token_ids: List[int]\n    logprobs: List[float]\n    texts: List[str]\n    is_special: List[bool]\n\n    def to_pb(self) -> generate_pb2.Tokens:\n        return generate_pb2.Tokens(\n            ids=self.token_ids,\n            logprobs=self.logprobs,\n            texts=self.texts,\n            is_special=self.is_special,\n        )\n\n    def __len__(self):\n        return len(self.token_ids)\n\n    def __add__(self, other: \"Tokens\") -> \"Tokens\":\n        return Tokens(\n            self.token_ids + 
other.token_ids,\n            self.logprobs + other.logprobs,\n            self.texts + other.texts,\n            self.is_special + other.is_special,\n        )\n\n\n@dataclass\nclass Generation:\n    request_id: int\n    prefill_tokens: Optional[Tokens]\n    tokens: Tokens\n    generated_text: Optional[GeneratedText]\n    # Optional for now, since it's not yet supported for every model.\n    top_tokens: Optional[List[Tokens]]\n\n    def to_pb(self) -> generate_pb2.Generation:\n        return generate_pb2.Generation(\n            request_id=self.request_id,\n            prefill_tokens=(\n                self.prefill_tokens.to_pb() if self.prefill_tokens is not None else None\n            ),\n            tokens=self.tokens.to_pb(),\n            generated_text=(\n                self.generated_text.to_pb() if self.generated_text is not None else None\n            ),\n            top_tokens=(\n                [top_tokens.to_pb() for top_tokens in self.top_tokens]\n                if self.top_tokens is not None\n                else None\n            ),\n        )\n"
  },
  {
    "path": "server/text_generation_server/models/vlm_causal_lm.py",
    "content": "from dataclasses import dataclass\nimport torch\nfrom PIL import Image\nfrom io import BytesIO\n\nfrom opentelemetry import trace\nfrom typing import Iterable, Optional, Tuple, List, Type, Dict\n\nfrom transformers import PreTrainedTokenizerBase\nfrom transformers.image_processing_utils import select_best_resolution\nfrom text_generation_server.pb import generate_pb2\nfrom text_generation_server.models.flash_causal_lm import (\n    FlashCausalLMBatch,\n    FlashCausalLM,\n)\nfrom text_generation_server.models.globals import PREFIX_CACHING, ATTENTION, MEM_POOL\nfrom loguru import logger\nfrom text_generation_server.utils.log import log_master\nfrom transformers import AutoProcessor\nfrom text_generation_server.layers.attention import Seqlen\nfrom text_generation_server.models.metadata_kernels import block_tables_to_ragged\n\ntracer = trace.get_tracer(__name__)\n\nIDEFICS2_FAKE_TOKEN = \"<fake_token_around_image>\"\nIDEFICS2_IMAGE_TOKEN = \"<image>\"\n\nIDEFICS3_IMAGE_TOKEN = \"<image>\"\nIDEFICS3_FAKE_IMAGE_TOKEN = \"<fake_token_around_image>\"\nIDEFICS3_GLOBAL_IMG_TOKEN = \"<global-img>\"\n\n\ndef prompt_split_image_llama4(aspect_ratio, num_patches_per_chunk):\n    \"\"\"\n    Create a structured string representation of image tokens\n\n    Args:\n       aspect_ratio: Pair (ratio_h, ratio_w) with the number of tile rows and tile columns\n       num_patches_per_chunk: Number of patch tokens emitted for each tile and for the final image section\n\n    Returns:\n        String with appropriate image tokens\n    \"\"\"\n    img_string = \"<|image_start|>\"\n    ratio_h, ratio_w = aspect_ratio\n    if ratio_h * ratio_w > 1:\n        for yy in range(ratio_h):\n            for xx in range(ratio_w):\n                img_string += \"<|patch|>\" * num_patches_per_chunk\n                if xx < ratio_w - 1:\n                    img_string += \"<|tile_x_separator|>\"\n\n            img_string += \"<|tile_y_separator|>\"\n    img_string += \"<|image|>\"\n    img_string += \"<|patch|>\" * num_patches_per_chunk\n    img_string += \"<|image_end|>\"\n\n    return img_string\n\n\n# copied from: 
https://github.com/huggingface/transformers/blob/02ed609285c2448b3b54c31e362f2c389fa952ab/src/transformers/models/idefics3/processing_idefics3.py#L44-L60\ndef _prompt_split_image(\n    *,\n    image_seq_len: int,\n    image_rows: int,\n    image_cols: int,\n    fake_token_around_image: str,\n    image_token: str,\n    global_img_token: str,\n):\n    \"\"\"Prompt with expanded image tokens for when the image is split into patches.\"\"\"\n    text_split_images = \"\"\n    for n_h in range(image_rows):\n        for n_w in range(image_cols):\n            text_split_images += (\n                f\"{fake_token_around_image}\"\n                + f\"<row_{n_h + 1}_col_{n_w + 1}>\"\n                + f\"{image_token}\" * image_seq_len\n            )\n        text_split_images += \"\\n\"\n\n    text_split_images += (\n        f\"\\n{fake_token_around_image}\"\n        + f\"{global_img_token}\"\n        + f\"{image_token}\" * image_seq_len\n        + f\"{fake_token_around_image}\"\n    )\n    return text_split_images\n\n\ndef get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):\n    \"\"\"\n    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.\n\n    Args:\n        image_size (`tuple`):\n            The size of the input image in the format (height, width).\n        grid_pinpoints (`List`):\n            A list containing possible resolutions. 
Each item in the list should be a tuple or list\n            of the form `(height, width)`.\n        patch_size (`int`):\n            The size of each image patch.\n\n    Returns:\n        tuple: The shape of the image patch grid in the format (width, height).\n    \"\"\"\n    if not isinstance(grid_pinpoints, list):\n        raise ValueError(\"grid_pinpoints should be a list of tuples or lists\")\n\n    height, width = select_best_resolution(image_size, grid_pinpoints)\n    return height // patch_size, width // patch_size\n\n\ndef image_text_replacement(processor, image_input, config) -> str:\n    if config.model_type == \"idefics2\":\n        image_seq_len = 64\n        image_str = f\"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}\"\n        if processor.image_processor.do_image_splitting:\n            image_str *= 5\n        return image_str, IDEFICS2_FAKE_TOKEN\n    if config.model_type == \"idefics3\":\n        # TODO: implement this in a more general way\n        n_rows = image_input[\"rows\"][0][0]\n        n_cols = image_input[\"cols\"][0][0]\n        image_seq_len = int(\n            ((config.vision_config.image_size // config.vision_config.patch_size) ** 2)\n            / (config.scale_factor**2)\n        )\n        image_str = _prompt_split_image(\n            image_seq_len=image_seq_len,\n            image_rows=n_rows,\n            image_cols=n_cols,\n            fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN,\n            image_token=IDEFICS3_IMAGE_TOKEN,\n            global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,\n        )\n        return image_str, IDEFICS3_FAKE_IMAGE_TOKEN\n    elif config.model_type == \"llava_next\":\n        height, width = image_input[\"image_sizes\"][0]\n        num_features = get_number_of_features(height, width, config)\n\n        log_master(\n            logger.info,\n            f\"Found {num_features} features in image of resolution {height}x{width}\",\n        )\n        return 
\"<image>\" * num_features, \"<image>\"\n\n    elif config.model_type == \"paligemma\":\n        return \"<image>\" * config.text_config.num_image_tokens, \"<image>\"\n    elif config.model_type == \"qwen2_vl\":\n        grid_t, grid_h, grid_w = image_input[\"image_grid_thw\"][0]\n        num_pads = grid_t * grid_h * grid_w // 4\n        padding = \"<|image_pad|>\" * num_pads\n        return f\"<|vision_start|>{padding}<|vision_end|>\", \"<|vision_start|>\"\n    elif config.model_type == \"qwen2_5_vl\":\n        grid_t, grid_h, grid_w = image_input[\"image_grid_thw\"][0]\n        num_pads = grid_t * grid_h * grid_w // 4\n        padding = \"<|image_pad|>\" * num_pads\n        return f\"<|vision_start|>{padding}<|vision_end|>\", \"<|vision_start|>\"\n    elif config.model_type == \"gemma3\":\n        # TODO: get correct number of features via reviewing the Gemma3 architecture\n        # and calculating the number of image tokens\n        num_pads = 256\n        padding = \"<image_soft_token>\" * num_pads\n        return f\"\\n\\n<start_of_image>{padding}<end_of_image>\\n\\n\", \"<start_of_image>\"\n    elif config.model_type == \"llama4\":\n        patch_size = config.vision_config.patch_size\n        pixel_shuffle_ratio = config.vision_config.pixel_shuffle_ratio\n        downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))\n        aspect_ratios = image_input[\"aspect_ratios\"][0]\n        image_height, image_width = image_input[\"pixel_values\"][0].shape[-2:]\n\n        num_patches_per_chunk = int(\n            (image_height // patch_size)\n            * (image_width // patch_size)\n            // downsample_ratio\n        )\n        tokens_for_this_image = prompt_split_image_llama4(\n            aspect_ratios, num_patches_per_chunk\n        )\n\n        return tokens_for_this_image, \"<|image_start|>\"\n    else:\n        raise RuntimeError(f\"Unknown config {config.model_type} for multimodal\")\n\n\ndef image_text_replacement_fixup(config, text: str) -> 
str:\n    if config.model_type == \"idefics2\":\n        return text.replace(\n            f\"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}\", IDEFICS2_FAKE_TOKEN\n        )\n    return text\n\n\ndef preprocess_text(config, text: str) -> str:\n    if config.model_type == \"paligemma\":\n        return \"<bos>\" + text + \"\\n\"\n    return text\n\n\ndef preprocess_image(config, img):\n    model_type = config.model_type\n\n    if model_type in {\"qwen2_vl\", \"qwen2_5_vl\"} and img.width <= 20:\n        img = img.resize((img.width * 2, img.height * 2))\n    if model_type == \"paligemma\":\n        img = img.convert(\"RGB\")\n\n    if model_type not in {\"llava_next\", \"gemma3\", \"llama4\"}:\n        # TODO: check if this is needed\n        img = [img]\n\n    return img\n\n\ndef get_unpadded_features(\n    original_height: int,\n    original_width: int,\n    npatches: int,\n    num_patch_height: int,\n    num_patch_width: int,\n) -> Tuple[int, int]:\n    current_height = npatches * num_patch_height\n    current_width = npatches * num_patch_width\n\n    aspect_ratio: float = original_width / original_height\n    current_aspect_ratio: float = current_width / current_height\n\n    if aspect_ratio > current_aspect_ratio:\n        new_height = (original_height * current_width) // original_width\n        padding = (current_height - new_height) // 2\n        current_height = current_height - (2 * padding)\n    else:\n        new_width = (original_width * current_height) // original_height\n        padding = (current_width - new_width) // 2\n        current_width = current_width - (2 * padding)\n\n    unpadded_features = current_height * current_width\n    newline_features = current_height\n    return (unpadded_features, newline_features)\n\n\ndef get_number_of_features(height: int, width: int, config) -> int:\n    # From config\n    # Hardcoded for CLIP for now\n    # image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]\n    
image_grid_pinpoints = config.image_grid_pinpoints\n    image_size = config.vision_config.image_size\n    patch_size = config.vision_config.patch_size\n\n    assert image_size % patch_size == 0\n\n    npatches = image_size // patch_size\n\n    # Dimensions are intentionally swapped to be bug-compatible with\n    # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59\n    num_patch_width, num_patch_height = get_anyres_image_grid_shape(\n        [height, width],\n        image_grid_pinpoints,\n        image_size,\n    )\n    unpadded_features, newline_features = get_unpadded_features(\n        height, width, npatches, num_patch_height, num_patch_width\n    )\n    # The base patch covers the entire image\n    base_features = npatches**2\n    return unpadded_features + newline_features + base_features\n\n\ndef scatter_image_embeds(\n    embeds: torch.Tensor, is_embed: Optional[torch.Tensor]\n) -> torch.Tensor:\n    if is_embed is None:\n        return embeds\n\n    placeholders = embeds.new_full(\n        (is_embed.shape[0], embeds.shape[-1]),\n        fill_value=torch.nan,\n    )\n    placeholders[is_embed] = embeds\n    return placeholders\n\n\ndef gather_image_embeds(\n    embeds: torch.Tensor, is_embed: Optional[torch.Tensor]\n) -> Optional[torch.Tensor]:\n    if is_embed is None:\n        return embeds\n    sel = embeds[is_embed]\n    return sel if sel.numel() else None\n\n\n@dataclass\nclass ImagePositions:\n    offset: int\n    length: int\n    id: int\n    num_placeholder_tokens: int\n    is_embed: Optional[torch.Tensor] = None\n\n\nclass VlmCausalLMBatch(FlashCausalLMBatch):\n    image_inputs: Optional[List[List[Dict[str, torch.Tensor]]]]\n    image_positions: Optional[List[List[ImagePositions]]]\n    encoder_cache: Optional[List[Dict[int, torch.Tensor]]]\n    pixel_values: Optional[List[torch.Tensor]]\n    pixel_attention_mask: Optional[List[torch.Tensor]]\n    image_sizes: Optional[List[Tuple[int, int]]]\n    image_grid_thw: Optional[torch.Tensor]\n    
cache_entries_to_free: List[Tuple[int, int]]\n    has_image_inputs: bool = False\n    inputs_embeds: Optional[torch.Tensor] = None\n\n    @classmethod\n    @tracer.start_as_current_span(\"concatenate\")\n    def concatenate(cls, batches):\n        batch = super(VlmCausalLMBatch, cls).concatenate(batches)\n\n        batch.image_inputs = []\n        batch.image_positions = []\n        batch.encoder_cache = []\n        for b in batches:\n            if b.image_inputs is not None:\n                batch.image_inputs.extend(b.image_inputs)\n            else:\n                batch.image_inputs.append(None)\n            if b.image_positions is not None:\n                batch.image_positions.extend(b.image_positions)\n            else:\n                batch.image_positions.append(None)\n            if b.encoder_cache is not None:\n                batch.encoder_cache.extend(b.encoder_cache)\n            else:\n                batch.encoder_cache.append(None)\n\n        batch.pixel_values = None\n        batch.pixel_attention_mask = None\n        batch.image_sizes = None\n        batch.image_grid_thw = None\n        batch.inputs_embeds = None\n\n        # To be filled in prepare_for_prefill\n        batch.has_image_inputs = False\n        batch.cache_entries_to_free = []\n\n        return batch\n\n    @tracer.start_as_current_span(\"filter\")\n    def filter(self, request_ids: List[int]):\n        if len(request_ids) == 0:\n            raise ValueError(\"Batch must have at least one request\")\n\n        image_inputs = []\n        image_positions = []\n        encoder_cache = []\n\n        for request_id in request_ids:\n            idx = self.requests_idx_mapping[request_id]\n            image_inputs.append(self.image_inputs[idx])\n            image_positions.append(self.image_positions[idx])\n            encoder_cache.append(self.encoder_cache[idx])\n\n        batch = super().filter(request_ids)\n        batch.pixel_values = None\n        batch.pixel_attention_mask = 
None\n        batch.image_sizes = None\n        batch.image_grid_thw = None\n        batch.inputs_embeds = None\n\n        batch.image_inputs = image_inputs\n        batch.image_positions = image_positions\n        batch.encoder_cache = encoder_cache\n\n        # To be filled in prepare_for_prefill\n        batch.has_image_inputs = False\n        batch.cache_entries_to_free = []\n        return batch\n\n    @classmethod\n    def batch_tokenized_inputs(\n        cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config\n    ):\n        kwargs = {}\n        if (\n            hasattr(processor, \"image_processor_class\")\n            and processor.image_processor_class == \"Idefics3ImageProcessor\"\n        ):\n            kwargs[\"return_row_col_info\"] = True\n\n        max_length = 0\n        vocab = tokenizer.get_vocab()\n\n        if not hasattr(config, \"image_token_index\"):\n            config.image_token_index = config.image_token_id\n\n        batch_tokenized_inputs: List[List[int]] = []\n        batch_image_inputs: List[Optional[List[dict]]] = []\n        batch_image_positions: List[Optional[List[ImagePositions]]] = []\n\n        for r in requests:\n            text_parts = []\n            image_inputs = []\n            image_texts = []\n\n            image_id = 0\n\n            for chunk in r.input_chunks.chunks:\n                chunk_type = chunk.WhichOneof(\"chunk\")\n                if chunk_type == \"text\":\n                    text = preprocess_text(config, chunk.text)\n                    text_parts.append(text)\n                elif chunk_type == \"image\":\n                    img = Image.open(BytesIO(chunk.image.data))\n                    img = preprocess_image(config, img)\n\n                    image_input = processor.image_processor(\n                        [img], return_tensors=\"pt\", **kwargs\n                    )\n                    image_inputs.append(image_input)\n\n                    img_text, img_start_token_str 
= image_text_replacement(\n                        processor, image_input, config\n                    )\n                    text_parts.append(img_text)\n\n                    image_texts.append([image_id, img_start_token_str, img_text])\n                    image_id += 1\n                else:\n                    raise RuntimeError(f\"Invalid chunk type {chunk_type}\")\n\n            full_text = image_text_replacement_fixup(config, \"\".join(text_parts))\n            input_ids = tokenizer(\n                full_text,\n                truncation=True,\n                max_length=r.truncate,\n                add_special_tokens=(\n                    r.add_special_tokens if config.model_type != \"paligemma\" else False\n                ),\n            )[\"input_ids\"]\n            max_length = max(max_length, len(input_ids))\n\n            if len(image_inputs) > 0:\n                img_start_token = vocab[image_texts[0][1]]\n                image_positions = cls.get_image_positions(\n                    input_ids, image_texts, img_start_token, config, tokenizer\n                )\n            else:\n                image_inputs = None\n                image_positions = None\n\n            batch_tokenized_inputs.append(input_ids)\n            batch_image_inputs.append(image_inputs)\n            batch_image_positions.append(image_positions)\n\n        return batch_tokenized_inputs, batch_image_inputs, batch_image_positions\n\n    @classmethod\n    def get_image_positions(\n        cls,\n        input_ids: List[int],\n        image_texts: List[Tuple[int, str, str]],\n        img_start_token: int,\n        config,\n        tokenizer: PreTrainedTokenizerBase,\n    ) -> List[ImagePositions]:\n        image_positions = []\n        num_images = len(image_texts)\n\n        input_ids_t = torch.as_tensor(input_ids)\n        img_start_token_pos = torch.where(input_ids_t.eq(img_start_token))[0]\n        num_tokens = input_ids_t.numel()\n\n        last_pos = 0\n        for i in 
range(num_images):\n            image_id, img_start_token_str, img_text = image_texts[i]\n            img_text = image_text_replacement_fixup(config, img_text)\n\n            if config.model_type == \"gemma3\":\n                img_text = img_text.replace(\"\\n\\n\", \"\")\n\n            tokens = tokenizer(img_text, add_special_tokens=False, return_tensors=\"pt\")[\n                \"input_ids\"\n            ][0]\n            length = tokens.numel()\n\n            assert (\n                length <= num_tokens\n            ), f\"{length} > {num_tokens} Image is truncated, try increasing --max-batch-prefill-tokens\"\n\n            pos = torch.searchsorted(img_start_token_pos, last_pos, right=False)\n            index = img_start_token_pos[pos]\n            assert torch.equal(\n                input_ids_t[index : index + length], tokens\n            ), \"Image tokens not found in input_ids\"\n\n            is_embed = tokens == config.image_token_index\n            num_placeholder_tokens = int(is_embed.sum())\n            if num_placeholder_tokens == length:\n                is_embed = None\n\n            pos = ImagePositions(\n                offset=index,\n                length=length,\n                id=image_id,\n                num_placeholder_tokens=num_placeholder_tokens,\n                is_embed=is_embed,\n            )\n\n            image_positions.append(pos)\n            last_pos = index + length\n\n            if (\n                config.model_type == \"idefics2\"\n                and i + 1 != num_images\n                and input_ids[last_pos] == config.image_token_index\n            ):\n                fake_token = last_pos - 1\n                fake_token_index = torch.searchsorted(\n                    img_start_token_pos, fake_token, right=False\n                )\n                img_start_token_pos[fake_token_index] = last_pos\n                image_texts[i + 1][2] = image_texts[i + 1][2][\n                    len(img_start_token_str) :\n        
        ]\n\n        return image_positions\n\n    @classmethod\n    def from_pb_processor(\n        cls,\n        pb: generate_pb2.Batch,\n        tokenizer: PreTrainedTokenizerBase,\n        processor,\n        config,\n        dtype: torch.dtype,\n        device: torch.device,\n    ) -> \"VlmCausalLMBatch\":\n        batch_tokenized_inputs, image_inputs, image_positions = (\n            cls.batch_tokenized_inputs(pb.requests, tokenizer, processor, config)\n        )\n        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)\n        batch.image_inputs = image_inputs\n        batch.image_positions = image_positions\n        batch.encoder_cache = [{} for _ in range(len(pb.requests))]\n        if len(image_inputs):\n            batch.pixel_values = None\n            batch.pixel_attention_mask = None\n            batch.image_sizes = None\n            batch.image_grid_thw = None\n        return batch\n\n    def prepare_for_prefill(self):\n        super().prepare_for_prefill()\n\n        self.has_image_inputs = False\n        self.cache_entries_to_free = []\n\n        self.pixel_values = []\n\n        assert (\n            len(self.cache_lengths)\n            == len(self.input_lengths)\n            == len(self.prefilling_mask)\n        ), \"Mismatch in lengths of cache_lengths, input_lengths, and prefilling_mask\"\n\n        for i, (\n            cache_length,\n            input_length,\n            request_prefilling,\n        ) in enumerate(\n            zip(\n                self.cache_lengths,\n                self.input_lengths,\n                self.prefilling_mask,\n            )\n        ):\n            if not request_prefilling or self.image_positions[i] is None:\n                continue\n\n            for image_position in self.image_positions[i]:\n                if image_position is None:\n                    continue\n                start_pos = image_position.offset\n                length = image_position.length\n\n       
         if start_pos >= cache_length + input_length:\n                    # No encoder input required at this step\n                    break\n                if start_pos + length <= cache_length:\n                    # The encode input is already processed\n                    continue\n\n                self.has_image_inputs = True\n\n                if image_position.id not in self.encoder_cache[i]:\n                    image_inputs = self.image_inputs[i][image_position.id]\n                    self.pixel_values.append((i, image_position.id, image_inputs))\n\n                    # Remove the image from the image_inputs\n                    self.image_inputs[i][image_position.id] = None\n\n        if not self.has_image_inputs:\n            self.pixel_values = None\n            self.pixel_attention_mask = None\n            self.image_sizes = None\n            self.image_grid_thw = None\n        else:\n            image_grid_thw_list = [\n                x[2][\"image_grid_thw\"]\n                for x in self.pixel_values\n                if \"image_grid_thw\" in x[2]\n            ]\n            if image_grid_thw_list:\n                self.image_grid_thw = torch.cat(image_grid_thw_list, dim=0).to(\n                    self.input_ids.device\n                )\n            else:\n                self.image_grid_thw = None\n\n    def update_encoder_cache(self, encoder_outputs, request_id, img_pos):\n        self.encoder_cache[request_id][img_pos.id] = scatter_image_embeds(\n            encoder_outputs, img_pos.is_embed\n        )\n\n    def gather_vision_embeds(self):\n        device = self.input_ids.device\n        chunks = []\n        for (\n            i,\n            cache_length,\n            input_length,\n            request_prefilling,\n        ) in zip(\n            range(len(self.requests)),\n            self.cache_lengths,\n            self.input_lengths,\n            self.prefilling_mask,\n        ):\n            if not request_prefilling or 
self.image_positions[i] is None:\n                continue\n\n            for image_position in self.image_positions[i]:\n                if image_position is None:\n                    continue\n                start_pos = image_position.offset\n                length = image_position.length\n\n                if start_pos >= cache_length + input_length:\n                    # No encoder input required at this step\n                    break\n                if start_pos + length <= cache_length:\n                    # The encode input is already processed\n                    continue\n\n                start_idx = max(cache_length - start_pos, 0)\n                end_idx = min(cache_length - start_pos + input_length, length)\n\n                assert (\n                    image_position.id in self.encoder_cache[i]\n                ), f\"image_id {image_position.id} not in encoder_cache {self.encoder_cache[i]}\"\n                encoder_output = self.encoder_cache[i][image_position.id]\n\n                is_embed = image_position.is_embed\n                if is_embed is not None:\n                    is_embed = is_embed[start_idx:end_idx]\n\n                from loguru import logger\n\n                logger.info(\n                    f\"image_id {image_position.id} start_idx {start_idx} end_idx {end_idx}, length {length}\"\n                )\n\n                embeds = gather_image_embeds(\n                    encoder_output[start_idx:end_idx],\n                    is_embed=is_embed,\n                )\n                if embeds is not None:\n                    chunks.append(embeds)\n\n                if end_idx == length:\n                    self.cache_entries_to_free.append((i, image_position.id))\n                    self.image_positions[i][image_position.id] = None\n\n        if len(chunks) == 0:\n            return None\n        return torch.cat(chunks, dim=0).to(device)\n\n    def free_encoder_cache(self):\n        for i, image_id in 
self.cache_entries_to_free:\n            self.encoder_cache[i].pop(image_id, None)\n\n        self.cache_entries_to_free = []\n\n        # release any freed GPU memory immediately?\n\n\nclass VlmCausalLM(FlashCausalLM):\n    def __init__(\n        self,\n        model_id: str,\n        *,\n        processor_class=AutoProcessor,\n        processor_kwargs=None,\n        batch_class=VlmCausalLMBatch,\n        revision,\n        trust_remote_code: bool,\n        support_chunking: bool = True,\n        **kwargs,\n    ):\n        if PREFIX_CACHING:\n            raise NotImplementedError(\"Vlm do not work with prefix caching yet\")\n        if processor_kwargs is None:\n            processor_kwargs = {}\n        self.processor = processor_class.from_pretrained(\n            model_id,\n            revision=revision,\n            trust_remote_code=trust_remote_code,\n            **processor_kwargs,\n        )\n        self.batch_class = batch_class\n        super().__init__(\n            model_id=model_id,\n            revision=revision,\n            trust_remote_code=trust_remote_code,\n            support_chunking=support_chunking,\n            **kwargs,\n        )\n\n    @property\n    def batch_type(self) -> Type[VlmCausalLMBatch]:\n        return self.batch_class\n\n    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):\n        max_bs = max(self.cuda_graphs.keys()) if self.cuda_graphs else None\n        input_lengths = [max_s] * bs\n        cache_lengths = [0] * bs\n        config = getattr(self.model.config, \"text_config\", self.model.config)\n        if max_bs is None:\n            inputs_embeds = torch.zeros(\n                (bs, config.hidden_size),\n                device=self.device,\n                dtype=self.dtype,\n            )\n            position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)\n            config = getattr(self.model, \"config\", None)\n            rope_scaling = getattr(config, \"rope_scaling\", None) if 
config else None\n            if (  # mrope have position_ids per section, if so repeat n times\n                isinstance(rope_scaling, dict) and rope_scaling[\"rope_type\"] == \"mrope\"\n            ):\n                n_sections = len(self.model.config.rope_scaling[\"mrope_section\"])\n                position_ids = position_ids.unsqueeze(1).repeat(1, n_sections)\n            slots = torch.arange(bs, dtype=torch.int64, device=self.device)\n            input_lengths_tensor = (\n                torch.ones(bs, dtype=torch.int32, device=self.device) * max_s\n            )\n            cache_lengths_tensor = torch.zeros(\n                bs, dtype=torch.int32, device=self.device\n            )\n            block_tables = torch.arange(\n                max_bt, dtype=torch.int32, device=self.device\n            ).repeat(bs)\n            block_tables = block_tables.reshape((bs, max_bt))\n            if ATTENTION == \"flashinfer\":\n                block_tables = block_tables_to_ragged(\n                    block_tables=block_tables,\n                    input_lengths=input_lengths,\n                    cache_lengths=cache_lengths,\n                    input_lengths_tensor=input_lengths_tensor,\n                    cache_lengths_tensor=cache_lengths_tensor,\n                    max_current_length=max_s,\n                )\n        else:\n            if bs > max_bs:\n                raise RuntimeError(\n                    \"Cuda graphs should be generated in decreasing order size to reduce VRAM usage\"\n                )\n            inputs_embeds = self.cuda_graphs[max_bs][\"inputs_embeds\"][:bs]\n            position_ids = self.cuda_graphs[max_bs][\"position_ids\"][:bs]\n            if ATTENTION == \"flashinfer\":\n                block_tables = self.cuda_graphs[max_bs][\"block_tables\"][: bs * max_bt]\n            else:\n                block_tables = self.cuda_graphs[max_bs][\"block_tables\"][:bs]\n            slots = self.cuda_graphs[max_bs][\"slots\"][:bs]\n       
     input_lengths_tensor = self.cuda_graphs[max_bs][\"input_lengths\"][:bs]\n            cache_lengths_tensor = self.cuda_graphs[max_bs][\"cache_lengths\"][:bs]\n\n        if ATTENTION == \"flashinfer\":\n            from text_generation_server.layers.attention.flashinfer import (\n                create_decode_state_cuda_graphs,\n            )\n\n            block_tables_ptr = torch.zeros(\n                bs + 1, dtype=torch.int32, device=self.device\n            )\n            last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)\n            state = create_decode_state_cuda_graphs(\n                device=inputs_embeds.device,\n                block_tables=block_tables,\n                block_tables_ptr=block_tables_ptr,\n                last_page_len=last_page_len,\n                num_heads=self.num_heads,\n                num_kv_heads=self.num_kv_heads,\n            )\n        else:\n            state = None\n\n        graph = torch.cuda.CUDAGraph()\n        self.cuda_graphs[bs] = {\n            \"inputs_embeds\": inputs_embeds,\n            \"position_ids\": position_ids,\n            \"kv_cache\": self.kv_cache,\n            \"block_tables\": block_tables,\n            \"slots\": slots,\n            \"input_lengths\": input_lengths_tensor,\n            \"cache_lengths\": cache_lengths_tensor,\n            \"state\": state,\n            \"graph\": graph,\n        }\n\n        torch.cuda.synchronize()\n        # Run once outside to warmup\n        with self._forward_context(\n            block_tables=block_tables,\n            cu_seqlen_prefill=None,\n            input_lengths_tensor=input_lengths_tensor,\n            state=state,\n            cache_lengths_tensor=cache_lengths_tensor,\n        ):\n            seqlen = Seqlen(\n                input_lengths=input_lengths_tensor,\n                cache_lengths=cache_lengths_tensor,\n                cu_seqlen_q=None,\n                max_q=1,\n                max_k=max_s,\n            )\n      
      self.model.forward(\n                inputs_embeds=inputs_embeds,\n                position_ids=position_ids,\n                cu_seqlen_prefill=None,\n                kv_cache=self.kv_cache,\n                block_tables=block_tables,\n                slots=slots,\n                seqlen=seqlen,\n                max_s=max_s,\n                prefill_cache_indices=None,\n                lm_head_indices=None,\n            )\n            del seqlen\n\n            torch.cuda.synchronize()\n\n            with torch.cuda.graph(graph, pool=MEM_POOL):\n                seqlen = Seqlen(\n                    input_lengths=input_lengths_tensor,\n                    cache_lengths=cache_lengths_tensor,\n                    cu_seqlen_q=None,\n                    max_q=1,\n                    max_k=max_s,\n                )\n                logits, speculative_logits = self.model.forward(\n                    inputs_embeds=inputs_embeds,\n                    position_ids=position_ids,\n                    cu_seqlen_prefill=None,\n                    kv_cache=self.kv_cache,\n                    block_tables=block_tables,\n                    slots=slots,\n                    seqlen=seqlen,\n                    max_s=max_s,\n                    prefill_cache_indices=None,\n                    lm_head_indices=None,\n                )\n                self.cuda_graphs[bs][\"logits\"] = logits\n                self.cuda_graphs[bs][\"speculative_logits\"] = speculative_logits\n        torch.cuda.synchronize()\n\n    def get_vision_embeds(\n        self,\n        pixel_values: torch.Tensor,\n        pixel_attention_mask: torch.Tensor,\n        image_sizes: torch.Tensor,\n        image_grid_thw: torch.Tensor,\n    ):\n        embeds = self.model.get_vision_embeds(\n            pixel_values=pixel_values,\n            pixel_attention_mask=pixel_attention_mask,\n            image_sizes=image_sizes,\n            image_grid_thw=image_grid_thw,\n        )\n        return embeds\n\n    
def get_inputs_embeds(\n        self,\n        input_ids: torch.Tensor,\n        vision_embeds: Optional[torch.Tensor] = None,\n    ):\n        return self.model.get_inputs_embeds(\n            input_ids=input_ids,\n            vision_embeds=vision_embeds,\n        )\n\n    def encode_images(self, batch):\n        if batch.pixel_values is not None:\n            device = batch.input_ids.device\n            for request_id, image_id, image_input in batch.pixel_values:\n                pixel_values = image_input[\"pixel_values\"].to(device)\n\n                if \"pixel_attention_mask\" in image_input:\n                    pixel_attention_mask = image_input[\"pixel_attention_mask\"].to(\n                        device\n                    )\n                else:\n                    pixel_attention_mask = None\n\n                if \"image_sizes\" in image_input:\n                    image_sizes = image_input[\"image_sizes\"].to(device)\n                else:\n                    image_sizes = None\n\n                if \"image_grid_thw\" in image_input:\n                    image_grid_thw = image_input[\"image_grid_thw\"].to(device)\n                else:\n                    image_grid_thw = None\n\n                encoder_outputs = self.get_vision_embeds(\n                    pixel_values=pixel_values,\n                    pixel_attention_mask=pixel_attention_mask,\n                    image_sizes=image_sizes,\n                    image_grid_thw=image_grid_thw,\n                )\n                batch.update_encoder_cache(\n                    encoder_outputs,\n                    request_id,\n                    batch.image_positions[request_id][image_id],\n                )\n\n        batch.pixel_values = None\n        batch.pixel_attention_mask = None\n        batch.image_sizes = None\n\n    def set_inputs_embeds(self, batch):\n        if batch.has_image_inputs:\n            self.encode_images(batch)\n            vision_embeds = batch.gather_vision_embeds()\n   
         batch.has_image_inputs = False\n        else:\n            vision_embeds = None\n\n        inputs_embeds = self.get_inputs_embeds(\n            batch.input_ids, vision_embeds=vision_embeds\n        )\n\n        batch.inputs_embeds = inputs_embeds\n\n    def forward(\n        self,\n        batch: VlmCausalLMBatch,\n        adapter_data: Optional[Dict[str, torch.Tensor]] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        # Model Forward\n        if batch.speculative_ids is not None:\n            input_ids = batch.input_ids\n            position_ids = batch.position_ids\n            cu_seqlen_prefill = batch.cu_seqlen_prefill\n            kv_cache = self.kv_cache\n            block_tables = batch.block_tables_tensor\n            slots = batch.slots[batch.slot_indices]\n            input_lengths = batch.input_lengths_tensor\n            max_s = batch.max_current_length\n            lm_head_indices = batch.prefill_head_indices\n\n            speculative_ids = batch.speculative_ids\n\n            B, speculative_length = speculative_ids.shape\n            new_length = speculative_length + 1\n            new_input_ids = torch.cat(\n                [input_ids.unsqueeze(-1), speculative_ids], dim=1\n            ).reshape(-1)\n            arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)\n            arange_int = arange.to(dtype=torch.int32)\n            new_position_ids = (\n                position_ids.unsqueeze(-1).expand(B, new_length) + arange\n            ).view(-1)\n            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)\n            input_lengths = (\n                input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int\n            ).view(-1)\n            cache_lengths_tensor = (\n                batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)\n            ).reshape(-1)\n\n            # Add Copy the block tables for all members\n            block_tables = (\n 
               block_tables.unsqueeze(1)\n                .expand(B, new_length, -1)\n                .reshape(B * new_length, -1)\n                .contiguous()\n            )\n            max_s = max_s + speculative_length\n\n            input_ids = new_input_ids\n            position_ids = new_position_ids\n        else:\n            input_ids = batch.input_ids\n            inputs_embeds = batch.inputs_embeds\n            position_ids = batch.position_ids\n            cu_seqlen_prefill = batch.cu_seqlen_prefill\n            kv_cache = self.kv_cache\n            block_tables = batch.block_tables_tensor\n            slots = batch.slots[batch.slot_indices]\n            input_lengths = batch.input_lengths_tensor\n            cache_lengths_tensor = batch.cache_lengths_tensor\n            max_s = batch.max_current_length\n            lm_head_indices = batch.prefill_head_indices\n\n        if self.model.config.model_type in {\"qwen2_vl\", \"qwen2_5_vl\"}:\n            if position_ids.dim() == 1 and batch.prefilling:\n                position_ids = self.model.get_position_ids(\n                    input_ids, batch.image_grid_thw\n                )\n                batch.position_ids = position_ids\n\n        attention_mask = None\n        attention_mask_forward = None\n        if self.model.config.model_type == \"gemma3\" and cu_seqlen_prefill is not None:\n            attention_mask = self.model.get_attention_mask(\n                input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True\n            )\n            min_dtype = torch.finfo(self.dtype).min\n            attention_mask_forward = torch.where(attention_mask, 0, min_dtype).to(\n                input_ids.device\n            )\n            attention_mask = attention_mask.reshape(-1)\n\n        # Try to find an associated cuda graph\n        bs = input_ids.shape[0]\n        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])\n        if sorted_padded_bs:\n            # Get associated cuda 
graph\n            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]\n        else:\n            cuda_graph = None\n        if cu_seqlen_prefill is not None or cuda_graph is None:\n            if ATTENTION == \"flashinfer\":\n                block_tables = block_tables_to_ragged(\n                    block_tables=block_tables,\n                    input_lengths=batch.input_lengths,\n                    cache_lengths=batch.cache_lengths,\n                    input_lengths_tensor=batch.input_lengths_tensor,\n                    cache_lengths_tensor=batch.cache_lengths_tensor,\n                    max_current_length=batch.max_current_length,\n                )\n            with self._forward_context(\n                block_tables=block_tables,\n                cu_seqlen_prefill=cu_seqlen_prefill,\n                input_lengths_tensor=input_lengths,\n                cache_lengths_tensor=cache_lengths_tensor,\n                attention_mask=attention_mask,\n            ):\n                seqlen = Seqlen(\n                    input_lengths=input_lengths,\n                    cache_lengths=cache_lengths_tensor,\n                    cu_seqlen_q=cu_seqlen_prefill,\n                    max_q=batch.max_input_length,\n                    max_k=batch.max_current_length,\n                )\n                logits, speculative_logits = self.model.forward(\n                    inputs_embeds=inputs_embeds,\n                    position_ids=position_ids,\n                    cu_seqlen_prefill=cu_seqlen_prefill,\n                    kv_cache=kv_cache,\n                    block_tables=block_tables,\n                    slots=slots,\n                    seqlen=seqlen,\n                    max_s=max_s,\n                    prefill_cache_indices=batch.prefill_cache_indices,\n                    lm_head_indices=lm_head_indices,\n                    attention_mask=attention_mask_forward,\n                )\n                if batch.prefill_cache_indices is not None:\n                    
batch.prefill_cache_indices = None\n                batch.image_grid_thw = None\n                batch.free_encoder_cache()\n                return logits, speculative_logits\n\n        # Copy inputs to the static inputs of the cuda graph\n        # Static inputs are potentially padded\n        cuda_graph[\"inputs_embeds\"][: inputs_embeds.shape[0]] = inputs_embeds\n        cuda_graph[\"position_ids\"][: position_ids.shape[0]] = position_ids\n        if ATTENTION == \"flashinfer\":\n            block_tables = block_tables_to_ragged(\n                block_tables=block_tables,\n                input_lengths=batch.input_lengths,\n                cache_lengths=batch.cache_lengths,\n                input_lengths_tensor=batch.input_lengths_tensor,\n                cache_lengths_tensor=batch.cache_lengths_tensor,\n                max_current_length=batch.max_current_length,\n            )\n            cuda_graph[\"block_tables\"][: block_tables.shape[0]] = block_tables\n        else:\n            cuda_graph[\"block_tables\"][\n                : block_tables.shape[0], : block_tables.shape[1]\n            ] = block_tables\n\n        # XXX: This is working only because block 0 is reserved for the healthcheck\n        # so it doesn't matter if we override it with bogus values.\n        cuda_graph[\"slots\"].fill_(0)\n        cuda_graph[\"slots\"][: slots.shape[0]] = slots\n        cuda_graph[\"input_lengths\"].zero_()\n        cuda_graph[\"input_lengths\"][: input_lengths.shape[0]] = input_lengths\n        cuda_graph[\"cache_lengths\"].zero_()\n        cuda_graph[\"cache_lengths\"][\n            : cache_lengths_tensor.shape[0]\n        ] = cache_lengths_tensor\n\n        with self._forward_context(\n            block_tables=cuda_graph[\"block_tables\"],\n            cu_seqlen_prefill=None,\n            input_lengths_tensor=cuda_graph[\"input_lengths\"],\n            cache_lengths_tensor=cuda_graph[\"cache_lengths\"],\n            state=cuda_graph[\"state\"],\n        ):\n    
        # Replay the graph\n            cuda_graph[\"graph\"].replay()\n\n        # Slice output to the correct shape\n        speculative_logits = (\n            cuda_graph[\"speculative_logits\"][:bs]\n            if cuda_graph[\"speculative_logits\"] is not None\n            else None\n        )\n        logits = cuda_graph[\"logits\"][:bs]\n\n        batch.free_encoder_cache()\n        return logits, speculative_logits\n"
  },
  {
    "path": "server/text_generation_server/pb/.gitignore",
    "content": "*.py\n*.pyi\n*.py-e\n"
  },
  {
    "path": "server/text_generation_server/server.py",
    "content": "import asyncio\nimport os\nimport torch\nimport time\nimport signal\n\nfrom grpc import aio\nfrom loguru import logger\n\nfrom grpc_reflection.v1alpha import reflection\nfrom pathlib import Path\nfrom typing import List, Optional\n\nfrom text_generation_server.cache import Cache\nfrom text_generation_server.interceptor import ExceptionInterceptor\nfrom text_generation_server.models import Model, get_model_with_lora_adapters\nfrom text_generation_server.utils.adapter import AdapterInfo\nfrom text_generation_server.utils.prefill_chunking import set_max_prefill_tokens\n\ntry:\n    from text_generation_server.models.vlm_causal_lm import (\n        VlmCausalLMBatch,\n    )\n    from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch\n    from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch\n\n    VLM_BATCH_TYPES = {\n        VlmCausalLMBatch,\n        IdeficsCausalLMBatch,\n        MllamaCausalLMBatch,\n    }\nexcept (ImportError, NotImplementedError):\n    # These imports can fail on CPU/Non flash.\n    VLM_BATCH_TYPES = set()\n\nfrom text_generation_server.pb import generate_pb2_grpc, generate_pb2\nfrom text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor\nfrom text_generation_server.models.globals import set_adapter_to_index\n\n\nclass SignalHandler:\n    KEEP_PROCESSING = True\n\n    def __init__(self):\n        signal.signal(signal.SIGINT, self.exit_gracefully)\n        signal.signal(signal.SIGTERM, self.exit_gracefully)\n\n    def set_keep_processing(self, value: bool):\n        self.KEEP_PROCESSING = value\n\n    def exit_gracefully(self, signum, frame):\n        print(f\"Exiting gracefully: Signal {signum}\")\n        self.set_keep_processing(False)\n\n\nclass TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):\n    def __init__(\n        self,\n        model: Model,\n        cache: Cache,\n        server_urls: List[str],\n    ):\n        self.cache = 
cache\n        self.model = model\n        # Quantize is resolved during model loading\n        self.quantize = model.quantize\n        self.server_urls = server_urls\n        # For some reason, inference_mode does not work well with GLOO which we use on CPU\n        # if model.device.type == \"cuda\":\n        #     # Force inference mode for the lifetime of TextGenerationService\n        #     self._inference_mode_raii_guard = torch._C._InferenceMode(True)\n\n    async def Info(self, request, context):\n        return self.model.info\n\n    async def Health(self, request, context):\n        if self.model.device.type == \"cuda\":\n            torch.zeros((2, 2)).cuda()\n        return generate_pb2.HealthResponse()\n\n    async def ServiceDiscovery(self, request, context):\n        return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)\n\n    async def ClearCache(self, request, context):\n        if request.HasField(\"id\"):\n            self.cache.delete(request.id)\n        else:\n            self.cache.clear()\n        return generate_pb2.ClearCacheResponse()\n\n    async def FilterBatch(self, request, context):\n        batch = self.cache.pop(request.batch_id)\n        if batch is None:\n            raise ValueError(f\"Batch ID {request.batch_id} not found in cache.\")\n        filtered_batch = batch.filter(request.request_ids)\n        self.cache.set(filtered_batch)\n\n        return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())\n\n    async def Warmup(self, request, context):\n        set_max_prefill_tokens(request.max_prefill_tokens)\n\n        if self.quantize in {\"exl2\", \"gptq\"}:\n            try:\n                # When using GPTQ, Exllama kernels need some global kernels\n                # For which we have the finale shapes only after the model has loaded\n                # This will allocate those buffers.\n                from text_generation_server.layers.gptq import (\n                    create_exllama_buffers,\n   
                 set_device,\n                )\n\n                set_device(self.model.device)\n                create_exllama_buffers(request.max_prefill_tokens)\n            except ImportError:\n                pass\n\n        if (\n            self.model.batch_type in VLM_BATCH_TYPES\n        ):  # Hack, i would rather use kwargs in the `from_pb` call\n            batch = self.model.batch_type.from_pb_processor(\n                request.batch,\n                self.model.tokenizer,\n                self.model.processor,\n                self.model.model.config,\n                self.model.dtype,\n                self.model.device,\n            )\n        else:\n            batch = self.model.batch_type.from_pb(\n                request.batch, self.model.tokenizer, self.model.dtype, self.model.device\n            )\n\n        # Override default values with None for clearer semantics.\n        max_input_tokens = (\n            request.max_input_tokens if request.HasField(\"max_input_tokens\") else None\n        )\n        max_total_tokens = (\n            request.max_total_tokens if request.HasField(\"max_total_tokens\") else None\n        )\n        max_supported_total_tokens, max_input_tokens, max_total_tokens = (\n            self.model.warmup(batch, max_input_tokens, max_total_tokens)\n        )\n\n        return generate_pb2.WarmupResponse(\n            max_supported_total_tokens=max_supported_total_tokens,\n            max_input_tokens=max_input_tokens,\n            max_total_tokens=max_total_tokens,\n        )\n\n    async def Prefill(self, request, context):\n        start = time.time_ns()\n        if (\n            self.model.batch_type in VLM_BATCH_TYPES\n        ):  # Hack, i would rather use kwargs in the `from_pb` call\n            batch = self.model.batch_type.from_pb_processor(\n                request.batch,\n                self.model.tokenizer,\n                self.model.processor,\n                self.model.model.config,\n                
self.model.dtype,\n                self.model.device,\n            )\n        else:\n            batch = self.model.batch_type.from_pb(\n                request.batch, self.model.tokenizer, self.model.dtype, self.model.device\n            )\n\n        concat_ns = None\n        if self.model.support_chunking:\n            if request.HasField(\"cached_batch\"):\n                cached_batch = self.cache.pop(request.cached_batch.id)\n                if cached_batch is None:\n                    raise ValueError(\n                        f\"Batch ID {request.cached_batch.id} not found in cache.\"\n                    )\n                start_concat = time.time_ns()\n                batch = self.model.batch_type.concatenate([cached_batch, batch])\n                concat_ns = time.time_ns() - start_concat\n\n        generations, next_batch, timings = self.model.generate_token(batch)\n        self.cache.set(next_batch)\n\n        return generate_pb2.PrefillResponse(\n            generations=[generation.to_pb() for generation in generations],\n            batch=next_batch.to_pb() if next_batch else None,\n            forward_ns=timings[0],\n            decode_ns=timings[1],\n            total_ns=time.time_ns() - start,\n            concat_ns=concat_ns,\n        )\n\n    async def Decode(self, request, context):\n        start = time.time_ns()\n        if len(request.batches) == 0:\n            raise ValueError(\"Must provide at least one batch\")\n\n        batches = []\n        for batch_pb in request.batches:\n            batch = self.cache.pop(batch_pb.id)\n            if batch is None:\n                raise ValueError(f\"Batch ID {batch_pb.id} not found in cache.\")\n            batches.append(batch)\n\n        if len(batches) == 0:\n            raise ValueError(\"All batches are empty\")\n\n        if len(batches) > 1:\n            start_concat = time.time_ns()\n            batch = self.model.batch_type.concatenate(batches)\n            concat_ns = time.time_ns() - 
start_concat\n        else:\n            batch = batches[0]\n            concat_ns = None\n\n        generations, next_batch, timings = self.model.generate_token(batch)\n        self.cache.set(next_batch)\n\n        return generate_pb2.DecodeResponse(\n            generations=[generation.to_pb() for generation in generations],\n            batch=next_batch.to_pb() if next_batch else None,\n            concat_ns=concat_ns,\n            forward_ns=timings[0],\n            decode_ns=timings[1],\n            total_ns=time.time_ns() - start,\n        )\n\n\ndef serve(\n    model_id: str,\n    lora_adapters: Optional[List[AdapterInfo]],\n    revision: Optional[str],\n    sharded: bool,\n    quantize: Optional[str],\n    speculate: Optional[int],\n    dtype: Optional[str],\n    kv_cache_dtype: Optional[str],\n    trust_remote_code: bool,\n    uds_path: Path,\n    max_input_tokens: int,\n):\n    async def serve_inner(\n        model_id: str,\n        lora_adapters: Optional[List[AdapterInfo]],\n        revision: Optional[str],\n        sharded: bool = False,\n        quantize: Optional[str] = None,\n        speculate: Optional[int] = None,\n        dtype: Optional[str] = None,\n        kv_cache_dtype: Optional[str] = None,\n        trust_remote_code: bool = False,\n    ):\n        unix_socket_template = \"unix://{}-{}\"\n        adapter_to_index = {}\n        if sharded:\n            server_urls = [\n                unix_socket_template.format(uds_path, rank)\n                for rank in range(int(os.environ[\"WORLD_SIZE\"]))\n            ]\n            local_url = server_urls[int(os.environ[\"RANK\"])]\n        else:\n            local_url = unix_socket_template.format(uds_path, 0)\n            server_urls = [local_url]\n\n        try:\n            model = get_model_with_lora_adapters(\n                model_id,\n                lora_adapters,\n                revision,\n                sharded,\n                quantize,\n                speculate,\n                
dtype,\n                kv_cache_dtype,\n                trust_remote_code,\n                max_input_tokens,\n                adapter_to_index,\n            )\n\n        except Exception:\n            logger.exception(\"Error when initializing model\")\n            raise\n\n        signal_handler = SignalHandler()\n\n        set_adapter_to_index(adapter_to_index)\n        server = aio.server(\n            interceptors=[\n                ExceptionInterceptor(lambda: signal_handler.set_keep_processing(False)),\n                UDSOpenTelemetryAioServerInterceptor(),\n            ],\n            options=[\n                # Set the maximum possible message length: i32::MAX\n                (\"grpc.max_receive_message_length\", (1 << 31) - 1)\n            ],\n        )\n        generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(\n            TextGenerationService(model, Cache(), server_urls), server\n        )\n        SERVICE_NAMES = (\n            generate_pb2.DESCRIPTOR.services_by_name[\"TextGenerationService\"].full_name,\n            reflection.SERVICE_NAME,\n        )\n        reflection.enable_server_reflection(SERVICE_NAMES, server)\n        server.add_insecure_port(local_url)\n\n        await server.start()\n\n        logger.info(\"Server started at {}\".format(local_url))\n        while signal_handler.KEEP_PROCESSING:\n            await asyncio.sleep(0.5)\n\n    asyncio.run(\n        serve_inner(\n            model_id,\n            lora_adapters,\n            revision,\n            sharded,\n            quantize,\n            speculate,\n            dtype,\n            kv_cache_dtype,\n            trust_remote_code,\n        )\n    )\n"
  },
  {
    "path": "server/text_generation_server/tracing.py",
    "content": "import grpc\n\nfrom opentelemetry import trace\nfrom opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter\nfrom opentelemetry.instrumentation.grpc._aio_server import (\n    OpenTelemetryAioServerInterceptor,\n)\nfrom opentelemetry.semconv.trace import SpanAttributes\nfrom opentelemetry.sdk.resources import Resource\nfrom opentelemetry.sdk.trace import TracerProvider\nfrom opentelemetry.sdk.trace.export import (\n    BatchSpanProcessor,\n)\n\n\nclass UDSOpenTelemetryAioServerInterceptor(OpenTelemetryAioServerInterceptor):\n    def __init__(self):\n        super().__init__(trace.get_tracer(__name__))\n\n    def _start_span(self, handler_call_details, context, set_status_on_exception=False):\n        \"\"\"\n        Rewrite _start_span method to support Unix Domain Socket gRPC contexts\n        \"\"\"\n\n        # standard attributes\n        attributes = {\n            SpanAttributes.RPC_SYSTEM: \"grpc\",\n            SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[0],\n        }\n\n        # if we have details about the call, split into service and method\n        if handler_call_details.method:\n            service, method = handler_call_details.method.lstrip(\"/\").split(\"/\", 1)\n            attributes.update(\n                {\n                    SpanAttributes.RPC_METHOD: method,\n                    SpanAttributes.RPC_SERVICE: service,\n                }\n            )\n\n        # add some attributes from the metadata\n        metadata = dict(context.invocation_metadata())\n        if \"user-agent\" in metadata:\n            attributes[\"rpc.user_agent\"] = metadata[\"user-agent\"]\n\n        # We use gRPC over a UNIX socket\n        attributes.update({SpanAttributes.NET_TRANSPORT: \"unix\"})\n\n        return self._tracer.start_as_current_span(\n            name=handler_call_details.method,\n            kind=trace.SpanKind.SERVER,\n            attributes=attributes,\n            
set_status_on_exception=set_status_on_exception,\n        )\n\n\ndef setup_tracing(otlp_service_name: str, otlp_endpoint: str):\n    resource = Resource.create(attributes={\"service.name\": otlp_service_name})\n    span_exporter = OTLPSpanExporter(endpoint=otlp_endpoint, insecure=True)\n    span_processor = BatchSpanProcessor(span_exporter)\n\n    trace.set_tracer_provider(TracerProvider(resource=resource))\n    trace.get_tracer_provider().add_span_processor(span_processor)\n"
  },
  {
    "path": "server/text_generation_server/utils/__init__.py",
    "content": "from text_generation_server.utils.convert import convert_file, convert_files\nfrom text_generation_server.utils.dist import initialize_torch_distributed\nfrom text_generation_server.utils.weights import Weights\nfrom text_generation_server.utils.peft import download_and_unload_peft\nfrom text_generation_server.utils.hub import (\n    weight_files,\n    weight_hub_files,\n    download_weights,\n    EntryNotFoundError,\n    LocalEntryNotFoundError,\n    RevisionNotFoundError,\n)\nfrom text_generation_server.utils.tokens import (\n    NextTokenChooser,\n    HeterogeneousNextTokenChooser,\n    StoppingCriteria,\n    StopSequenceCriteria,\n    FinishReason,\n    Sampling,\n    Greedy,\n)\n\n__all__ = [\n    \"convert_file\",\n    \"convert_files\",\n    \"initialize_torch_distributed\",\n    \"weight_files\",\n    \"weight_hub_files\",\n    \"download_weights\",\n    \"download_and_unload_peft\",\n    \"EntryNotFoundError\",\n    \"HeterogeneousNextTokenChooser\",\n    \"LocalEntryNotFoundError\",\n    \"RevisionNotFoundError\",\n    \"Greedy\",\n    \"NextTokenChooser\",\n    \"Sampling\",\n    \"StoppingCriteria\",\n    \"StopSequenceCriteria\",\n    \"FinishReason\",\n    \"Weights\",\n]\n"
  },
  {
    "path": "server/text_generation_server/utils/adapter.py",
    "content": "# Origin:   https://github.com/predibase/lorax\n# Path:     lorax/server/lorax_server/utils/adapter.py\n# License:  Apache License Version 2.0, January 2004\n\nimport warnings\nimport re\nfrom dataclasses import dataclass\nfrom functools import lru_cache\nfrom typing import TYPE_CHECKING, Set, Tuple, Optional, List\n\nfrom safetensors.torch import load_file\nfrom transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer\n\nfrom text_generation_server.utils.merges.strategies import merge_adapters\n\nfrom text_generation_server.utils import hub\nfrom text_generation_server.adapters.lora import LoraConfig\n\n\nif TYPE_CHECKING:\n    from text_generation_server.adapters.config import AdapterConfig, ModuleMap\n\n\nBASE_MODEL_ADAPTER_ID = \"__base_model__\"\n\n\n@dataclass\nclass AdapterInfo:\n    id: str\n    path: Optional[str]\n    revision: Optional[str] = None\n\n\n@dataclass\nclass AdapterParameters:\n    adapter_info: Tuple[AdapterInfo]\n    weights: Tuple[float]\n    merge_strategy: NotImplemented\n    density: float\n    majority_sign_method: NotImplemented\n\n\n@dataclass\nclass AdapterSource:\n    adapter_id: str\n    model_id: str\n    revision: str\n\n\ndef parse_lora_adapters(lora_adapters: Optional[str]) -> List[AdapterInfo]:\n    if not lora_adapters:\n        return []\n\n    adapter_list = []\n    for adapter in lora_adapters.split(\",\"):\n        adapter = adapter.strip()\n        if adapter.count(\"=\") > 1 or adapter.count(\"@\") > 1:\n            raise ValueError(f\"Invalid LoRA adapter format: {adapter}\")\n        match = re.match(r\"^([^=@]+)(?:=([^@]+))?(?:@(.+))?$\", adapter)\n\n        if match:\n            adapter_id, path, revision = match.groups()\n            adapter_list.append(\n                AdapterInfo(id=adapter_id, path=path, revision=revision)\n            )\n        else:\n            raise ValueError(f\"Invalid LoRA adapter format: {adapter}\")\n    return adapter_list\n\n\ndef 
load_and_merge_adapters(\n    model_id: str,\n    adapter_parameters: AdapterParameters,\n    adapter_index: int,\n    weight_names: Tuple[str],\n    trust_remote_code: bool = False,\n) -> Tuple[\"ModuleMap\", \"AdapterConfig\", Set[str], PreTrainedTokenizer]:\n    if len(adapter_parameters.adapter_info) == 1:\n        adapter = next(iter(adapter_parameters.adapter_info))\n        return load_module_map(\n            model_id,\n            adapter.revision,\n            adapter.id,\n            adapter.path,\n            weight_names,\n            trust_remote_code,\n        )\n\n    adapter_params = AdapterParametersContainer(adapter_parameters, adapter_index)\n    return _load_and_merge(\n        model_id,\n        adapter_params,\n        weight_names,\n        trust_remote_code,\n    )\n\n\n@dataclass\nclass AdapterParametersContainer:\n    adapter_parameters: AdapterParameters\n    adapter_index: int\n\n    def __hash__(self) -> int:\n        return self.adapter_index\n\n\n@lru_cache(maxsize=32)\ndef _load_and_merge(\n    model_id: str,\n    adapter_params: AdapterParametersContainer,\n    weight_names: Tuple[str],\n    trust_remote_code: bool = False,\n) -> Tuple[\"ModuleMap\", \"AdapterConfig\", Set[str], PreTrainedTokenizer]:\n    params = adapter_params.adapter_parameters\n\n    adapters_to_merge = []\n    merged_weight_names = set()\n    tokenizer = None\n    for adapter in params.adapter_info:\n        if adapter.id == BASE_MODEL_ADAPTER_ID:\n            raise ValueError(\"Base model adapter cannot be merged.\")\n\n        (\n            module_map,\n            adapter_config,\n            adapter_weight_names,\n            adapter_tokenizer,\n        ) = load_module_map(\n            model_id,\n            adapter.revision,\n            adapter.id,\n            adapter.path,\n            weight_names,\n            trust_remote_code,\n        )\n\n        adapters_to_merge.append((module_map, adapter_config))\n        merged_weight_names = 
merged_weight_names.union(adapter_weight_names)\n        if tokenizer is None:\n            tokenizer = adapter_tokenizer\n\n    if len(adapters_to_merge) == 0:\n        raise ValueError(\"No adapters to merge.\")\n\n    module_map, adapter_config = merge_adapters(adapters_to_merge, params)\n    return module_map, adapter_config, merged_weight_names, tokenizer\n\n\ndef check_architectures(\n    model_id: str,\n    adapter_id: str,\n    adapter_config: \"AdapterConfig\",\n    trust_remote_code: bool = False,\n):\n    try:\n        if not adapter_config.base_model_name_or_path:\n            # Avoid execution latency caused by the network connection retrying for AutoConfig.from_pretrained(None)\n            return\n\n        expected_config = AutoConfig.from_pretrained(\n            model_id, trust_remote_code=trust_remote_code\n        )\n        model_config = AutoConfig.from_pretrained(\n            adapter_config.base_model_name_or_path, trust_remote_code=trust_remote_code\n        )\n    except Exception as e:\n        warnings.warn(\n            f\"Unable to check architecture compatibility for adapter '{adapter_id}' \"\n            f\"against model '{model_id}'. Assuming they are compatible. Error: {e}\"\n        )\n        return\n\n    if model_config.architectures == expected_config.architectures:\n        warnings.warn(\n            f\"Adapter '{adapter_id}' was not trained on base model '{model_id}'. \"\n            f\"If you encounter issues, use --model-id '{adapter_config.base_model_name_or_path}' instead.\"\n        )\n    else:\n        # TODO(travis): revisit this when we support clasification heads which will not use CausalLM\n        raise ValueError(\n            f\"Adapter '{adapter_id}' is not compatible with model '{model_id}'. \"\n            f\"Architectures differ: {model_config.architectures} != {expected_config.architectures}. 
\"\n            f\"Use --model-id '{adapter_config.base_model_name_or_path}' instead.\"\n        )\n\n\n@lru_cache(maxsize=128)\ndef load_module_map(\n    model_id: str,\n    revision: str,\n    adapter_id: str,\n    adapter_path: Optional[str],\n    weight_names: Tuple[str],\n    trust_remote_code: bool = False,\n) -> Tuple[\"ModuleMap\", \"AdapterConfig\", Set[str], PreTrainedTokenizer]:\n    adapter_config = LoraConfig.load(adapter_path or adapter_id, None)\n\n    if not adapter_path and adapter_config.base_model_name_or_path != model_id:\n        check_architectures(model_id, adapter_id, adapter_config, trust_remote_code)\n\n    adapter_filenames = (\n        hub._weight_files_from_dir(adapter_path, extension=\".safetensors\")\n        if adapter_path\n        else hub._cached_weight_files(\n            adapter_id, revision=revision, extension=\".safetensors\"\n        )\n    )\n\n    # throw an error if no adapter weights are found\n    if not adapter_filenames:\n        raise FileNotFoundError(\n            f\"No adapter weights found for adapter '{adapter_id}' and revision '{revision}'.\"\n        )\n\n    try:\n        adapter_tokenizer = AutoTokenizer.from_pretrained(\n            adapter_config.config_path,\n            trust_remote_code=trust_remote_code,\n        )\n    except Exception:\n        # Adapter does not have a tokenizer, so fallback to base model tokenizer\n        adapter_tokenizer = None\n\n    # load adapter weights from all shards (should have relatively small memory footprint)\n    adapter_weights = {}\n    for filename in adapter_filenames:\n        adapter_weights.update(load_file(filename))\n\n    # map the model weights to the relevant adapter weights (LoRA A and B matrices)\n    module_map, adapter_weight_names = adapter_config.map_weights_for_model(\n        adapter_weights, weight_names\n    )\n    return module_map, adapter_config, adapter_weight_names, adapter_tokenizer\n\n\ndef get_attn_weights(i, layer):\n    qkv = 
layer.self_attn.query_key_value\n    weights = {}\n\n    for k in [\"q\", \"k\", \"v\"]:\n        key = (i, f\"{k}_proj\")\n        value = (f\"model.layers.{i}.self_attn.{k}_proj\", qkv)\n        weights[key] = value\n\n    # also add the qkv_proj weight for the adapter\n    weights[(i, \"qkv_proj\")] = (\n        f\"model.layers.{i}.self_attn.qkv_proj\",\n        qkv,\n    )\n\n    weights[(i, \"o_proj\")] = (\n        f\"model.layers.{i}.self_attn.o_proj\",\n        layer.self_attn.o_proj,\n    )\n\n    return weights\n\n\ndef get_mlp_weights(i, layer):\n    weights = {}\n    if hasattr(layer, \"mlp\"):\n        mlp = layer.mlp\n        if hasattr(mlp, \"gate_up_proj\"):\n            # handle combined gate_up_proj (e.g., for some LLaMA variants)\n            weights.update(\n                {\n                    (i, \"gate_proj\"): (\n                        f\"model.layers.{i}.mlp.gate_proj\",\n                        mlp.gate_up_proj,\n                    ),\n                    (i, \"up_proj\"): (f\"model.layers.{i}.mlp.up_proj\", mlp.gate_up_proj),\n                }\n            )\n        else:\n            # handle separate gate_proj, up_proj, and down_proj (e.g., for Gemma)\n            if hasattr(mlp, \"gate_proj\"):\n                weights[(i, \"gate_proj\")] = (\n                    f\"model.layers.{i}.mlp.gate_proj\",\n                    mlp.gate_proj,\n                )\n            if hasattr(mlp, \"up_proj\"):\n                weights[(i, \"up_proj\")] = (f\"model.layers.{i}.mlp.up_proj\", mlp.up_proj)\n\n        if hasattr(mlp, \"c_fc\"):\n            weights[(i, \"c_fc\")] = (f\"model.layers.{i}.mlp.c_fc\", mlp.c_fc)\n\n        if hasattr(mlp, \"c_proj\"):\n            weights[(i, \"c_proj\")] = (f\"model.layers.{i}.mlp.c_proj\", mlp.c_proj)\n\n        if hasattr(mlp, \"down_proj\"):\n            weights[(i, \"down_proj\")] = (\n                f\"model.layers.{i}.mlp.down_proj\",\n                mlp.down_proj,\n            )\n\n    return 
weights\n\n\n# build_layer_weight_lookup creates a mapping of model layers to their corresponding\n# weight tensors and paths. It builds a dictionary that maps layer identifiers to tuples\n# containing the weight tensor path and the actual layer object. This mapping is needed\n# for the lora adapter to know which weights to update when applying the adapter.\ndef build_layer_weight_lookup(model):\n    if hasattr(model, \"language_model\"):\n        m = model.language_model.model\n    elif hasattr(model, \"text_model\"):\n        m = model.text_model.model\n    else:\n        m = model.model\n\n    layer_weights = {}\n\n    for i, layer in enumerate(m.layers):\n        attn_weights = get_attn_weights(i, layer)\n        mlp_weights = get_mlp_weights(i, layer)\n\n        layer_weights.update(attn_weights)\n        layer_weights.update(mlp_weights)\n\n    lm_head = None\n    if hasattr(m, \"lm_head\"):\n        lm_head = m.lm_head\n    elif hasattr(model, \"lm_head\"):\n        lm_head = model.lm_head\n\n    if lm_head:\n        layer_weights[(0, \"lm_head\")] = (\"lm_head\", lm_head)\n\n    return layer_weights\n"
  },
  {
    "path": "server/text_generation_server/utils/chunks.py",
    "content": "from typing import Iterable\n\nfrom loguru import logger\n\nfrom text_generation_server.pb import generate_pb2\n\n\ndef concat_text_chunks(chunks: Iterable[generate_pb2.InputChunk]) -> str:\n    \"\"\"\n    Concatenate text in text chunks. Non-text chunks are dropped.\n    \"\"\"\n    text = None\n    for chunk in chunks:\n        chunk_type = chunk.WhichOneof(\"chunk\")\n        if chunk_type == \"text\":\n            if text is None:\n                text = chunk.text\n            else:\n                raise NotImplementedError(\"Request contained more than one text chunk\")\n        else:\n            # We cannot reject this, e.g. warmup sends an image chunk.\n            logger.debug(f\"Encountered non-text chunk type {chunk_type}\")\n\n    if text is None:\n        raise NotImplementedError(\"Request without a text chunk\")\n\n    return text\n"
  },
  {
    "path": "server/text_generation_server/utils/convert.py",
    "content": "import datetime\nimport torch\nimport os\n\nfrom loguru import logger\nfrom pathlib import Path\nfrom safetensors.torch import save_file, load_file, _find_shared_tensors, _is_complete\nfrom typing import List, Dict\nfrom collections import defaultdict\n\n\ndef _remove_duplicate_names(\n    state_dict: Dict[str, torch.Tensor],\n    *,\n    preferred_names: List[str] = None,\n    discard_names: List[str] = None,\n) -> Dict[str, List[str]]:\n    if preferred_names is None:\n        preferred_names = []\n    preferred_names = set(preferred_names)\n    if discard_names is None:\n        discard_names = []\n    discard_names = set(discard_names)\n\n    shareds = _find_shared_tensors(state_dict)\n    to_remove = defaultdict(list)\n    for shared in shareds:\n        complete_names = set(\n            [name for name in shared if _is_complete(state_dict[name])]\n        )\n        if not complete_names:\n            if len(shared) == 1:\n                # Force contiguous\n                name = list(shared)[0]\n                state_dict[name] = state_dict[name].clone()\n                complete_names = {name}\n            else:\n                raise RuntimeError(\n                    f\"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. 
Or open an issue.\"\n                )\n\n        keep_name = sorted(list(complete_names))[0]\n\n        # Mecanism to preferentially select keys to keep\n        # coming from the on-disk file to allow\n        # loading models saved with a different choice\n        # of keep_name\n        preferred = complete_names.difference(discard_names)\n        if preferred:\n            keep_name = sorted(list(preferred))[0]\n\n        if preferred_names:\n            preferred = preferred_names.intersection(complete_names)\n            if preferred:\n                keep_name = sorted(list(preferred))[0]\n        for name in sorted(shared):\n            if name != keep_name:\n                to_remove[keep_name].append(name)\n    return to_remove\n\n\ndef convert_file(pt_file: Path, sf_file: Path, discard_names: List[str]):\n    \"\"\"\n    Convert a pytorch file to a safetensors file\n    This will remove duplicate tensors from the file.\n\n    Unfortunately, this might not respect *transformers* convention.\n    Forcing us to check for potentially different keys during load when looking\n    for specific tensors (making tensor sharing explicit).\n    \"\"\"\n    loaded = torch.load(pt_file, map_location=\"cpu\", weights_only=True)\n    if \"state_dict\" in loaded:\n        loaded = loaded[\"state_dict\"]\n    to_removes = _remove_duplicate_names(loaded, discard_names=discard_names)\n\n    metadata = {\"format\": \"pt\"}\n    for kept_name, to_remove_group in to_removes.items():\n        for to_remove in to_remove_group:\n            if to_remove not in metadata:\n                metadata[to_remove] = kept_name\n            del loaded[to_remove]\n    # Force tensors to be contiguous\n    loaded = {k: v.contiguous() for k, v in loaded.items()}\n\n    dirname = os.path.dirname(sf_file)\n    os.makedirs(dirname, exist_ok=True)\n    save_file(loaded, sf_file, metadata=metadata)\n    reloaded = load_file(sf_file)\n    for k in loaded:\n        pt_tensor = loaded[k]\n        
sf_tensor = reloaded[k]\n        if not torch.equal(pt_tensor, sf_tensor):\n            raise RuntimeError(f\"The output tensors do not match for key {k}\")\n\n\ndef convert_files(pt_files: List[Path], sf_files: List[Path], discard_names: List[str]):\n    assert len(pt_files) == len(sf_files)\n\n    N = len(pt_files)\n    # We do this instead of using tqdm because we want to parse the logs with the launcher\n\n    for i, (pt_file, sf_file) in enumerate(zip(pt_files, sf_files)):\n        # Skip blacklisted files\n        if (\n            \"arguments\" in pt_file.name\n            or \"args\" in pt_file.name\n            or \"training\" in pt_file.name\n        ):\n            continue\n\n        start = datetime.datetime.now()\n        convert_file(pt_file, sf_file, discard_names)\n        elapsed = datetime.datetime.now() - start\n        logger.info(f\"Convert: [{i + 1}/{N}] -- Took: {elapsed}\")\n"
  },
  {
    "path": "server/text_generation_server/utils/dist.py",
    "content": "import os\nimport torch\nfrom torch.distributed import ProcessGroup\nfrom datetime import timedelta\nfrom loguru import logger\nfrom text_generation_server.utils.import_utils import SYSTEM\n\n# Tensor Parallelism settings\nRANK = int(os.getenv(\"RANK\", \"0\"))\nWORLD_SIZE = int(os.getenv(\"WORLD_SIZE\", \"1\"))\n\n# CUDA memory fraction\nMEMORY_FRACTION = float(os.getenv(\"CUDA_MEMORY_FRACTION\", \"1.0\"))\n\n\nclass FakeBarrier:\n    def wait(self):\n        pass\n\n\nclass FakeGroup(ProcessGroup):\n    def __init__(self, rank, size):\n        self._rank = rank\n        self._size = size\n        super().__init__(rank, size)\n\n    def allreduce(self, *args, **kwargs):\n        return FakeBarrier()\n\n    def allgather(self, inputs, local_tensor, **kwargs):\n        assert (\n            len(inputs[0]) == len(local_tensor) == 1\n        ), f\"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors\"\n        for input_ in inputs:\n            input_[0].data = local_tensor[0].data\n        return FakeBarrier()\n\n    def barrier(self, *args, **kwargs):\n        return FakeBarrier()\n\n    def size(self):\n        return self._size\n\n    def rank(self):\n        return self._rank\n\n\ndef initialize_torch_distributed():\n    if torch.cuda.is_available():\n        from torch.distributed import ProcessGroupNCCL\n\n        # Set the device id.\n        assert WORLD_SIZE <= torch.cuda.device_count(), \"Each process is one gpu\"\n        device = RANK % torch.cuda.device_count()\n        torch.cuda.set_device(device)\n        torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device)\n        backend = \"nccl\"\n        options = ProcessGroupNCCL.Options()\n        options.is_high_priority_stream = True\n        options._timeout = timedelta(seconds=120)\n    else:\n        backend = \"gloo\"\n        options = None\n\n    if WORLD_SIZE == 1:\n        return FakeGroup(RANK, WORLD_SIZE), RANK, 
WORLD_SIZE\n    else:\n        if os.getenv(\"DEBUG\", None) == \"1\":\n            return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE\n\n        if not torch.distributed.is_initialized():\n            # Call the init process.\n            if SYSTEM == \"ipex\":\n                import intel_extension_for_pytorch as ipex\n\n                if torch.xpu.is_available():\n                    assert (\n                        WORLD_SIZE <= torch.xpu.device_count()\n                    ), \"Each process is one xpu\"\n                    device = RANK % torch.xpu.device_count()\n                    torch.xpu.set_device(device)\n                    device_id = torch.device(f\"xpu:{RANK}\")\n                    torch.distributed.init_process_group(\n                        backend=\"xccl\",\n                        world_size=WORLD_SIZE,\n                        rank=RANK,\n                        timeout=timedelta(seconds=120),\n                        pg_options=options,\n                        device_id=device_id,\n                    )\n                else:\n                    ipex.distributed.init_process_group(\n                        backend=\"ccl\",\n                        world_size=WORLD_SIZE,\n                        rank=RANK,\n                        timeout=timedelta(seconds=120),\n                        pg_options=options,\n                    )\n            else:\n                device = torch.device(f\"cuda:{RANK}\")\n                torch.distributed.init_process_group(\n                    backend=backend,\n                    world_size=WORLD_SIZE,\n                    rank=RANK,\n                    timeout=timedelta(seconds=120),\n                    pg_options=options,\n                    device_id=device,\n                )\n        else:\n            logger.warning(\"torch.distributed is already initialized.\")\n\n        return torch.distributed.group.WORLD, RANK, WORLD_SIZE\n"
  },
  {
    "path": "server/text_generation_server/utils/hub.py",
    "content": "import time\nimport os\n\nfrom datetime import timedelta\nfrom loguru import logger\nfrom pathlib import Path\nfrom typing import Optional, List\n\nfrom huggingface_hub import file_download, hf_api, HfApi, hf_hub_download\nfrom huggingface_hub.constants import HUGGINGFACE_HUB_CACHE\nfrom huggingface_hub.utils import (\n    LocalEntryNotFoundError,\n    EntryNotFoundError,\n    RevisionNotFoundError,  # noqa # Import here to ease try/except in other part of the lib\n)\n\nWEIGHTS_CACHE_OVERRIDE = os.getenv(\"WEIGHTS_CACHE_OVERRIDE\", None)\nHF_HUB_OFFLINE = os.environ.get(\"HF_HUB_OFFLINE\", \"0\").lower() in [\"true\", \"1\", \"yes\"]\n\n\ndef _cached_weight_files(\n    model_id: str, revision: Optional[str], extension: str\n) -> List[str]:\n    \"\"\"Guess weight files from the cached revision snapshot directory\"\"\"\n    d = _get_cached_revision_directory(model_id, revision)\n    if not d:\n        return []\n    filenames = _weight_files_from_dir(d, extension)\n    return filenames\n\n\ndef _weight_hub_files_from_model_info(\n    info: hf_api.ModelInfo, extension: str\n) -> List[str]:\n    return [\n        s.rfilename\n        for s in info.siblings\n        if s.rfilename.endswith(extension)\n        and len(s.rfilename.split(\"/\")) == 1\n        and \"arguments\" not in s.rfilename\n        and \"args\" not in s.rfilename\n        and \"training\" not in s.rfilename\n    ]\n\n\ndef _weight_files_from_dir(d: Path, extension: str) -> List[str]:\n    # os.walk: do not iterate, just scan for depth 1, not recursively\n    # see _weight_hub_files_from_model_info, that's also what is\n    # done there with the len(s.rfilename.split(\"/\")) == 1 condition\n    root, _, files = next(os.walk(str(d)))\n    filenames = [\n        os.path.join(root, f)\n        for f in files\n        if f.endswith(extension)\n        and \"arguments\" not in f\n        and \"args\" not in f\n        and \"training\" not in f\n    ]\n    return filenames\n\n\ndef 
_get_cached_revision_directory(\n    model_id: str, revision: Optional[str]\n) -> Optional[Path]:\n    if revision is None:\n        revision = \"main\"\n\n    repo_cache = Path(HUGGINGFACE_HUB_CACHE) / Path(\n        file_download.repo_folder_name(repo_id=model_id, repo_type=\"model\")\n    )\n\n    if not repo_cache.is_dir():\n        # No cache for this model\n        return None\n\n    refs_dir = repo_cache / \"refs\"\n    snapshots_dir = repo_cache / \"snapshots\"\n\n    # Resolve refs (for instance to convert main to the associated commit sha)\n    if refs_dir.is_dir():\n        revision_file = refs_dir / revision\n        if revision_file.exists():\n            with revision_file.open() as f:\n                revision = f.read()\n\n    # Check if revision folder exists\n    if not snapshots_dir.exists():\n        return None\n    cached_shas = os.listdir(snapshots_dir)\n    if revision not in cached_shas:\n        # No cache for this revision and we won't try to return a random revision\n        return None\n\n    return snapshots_dir / revision\n\n\ndef weight_hub_files(\n    model_id: str, revision: Optional[str] = None, extension: str = \".safetensors\"\n) -> List[str]:\n    \"\"\"Get the weights filenames on the hub\"\"\"\n    api = HfApi()\n\n    if HF_HUB_OFFLINE:\n        filenames = _cached_weight_files(model_id, revision, extension)\n    else:\n        # Online case, fetch model info from the Hub\n        info = api.model_info(model_id, revision=revision)\n        filenames = _weight_hub_files_from_model_info(info, extension)\n\n    if not filenames:\n        raise EntryNotFoundError(\n            f\"No {extension} weights found for model {model_id} and revision {revision}.\",\n            None,\n        )\n\n    return filenames\n\n\ndef try_to_load_from_cache(\n    model_id: str, revision: Optional[str], filename: str\n) -> Optional[Path]:\n    \"\"\"Try to load a file from the Hugging Face cache\"\"\"\n\n    d = 
_get_cached_revision_directory(model_id, revision)\n    if not d:\n        return None\n\n    # Check if file exists in cache\n    cached_file = d / filename\n    return cached_file if cached_file.is_file() else None\n\n\ndef weight_files(\n    model_id: str, revision: Optional[str] = None, extension: str = \".safetensors\"\n) -> List[Path]:\n    \"\"\"Get the local files\"\"\"\n    # Local model\n    d = Path(model_id)\n    if d.exists() and d.is_dir():\n        local_files = _weight_files_from_dir(d, extension)\n        if not local_files:\n            raise FileNotFoundError(\n                f\"No local weights found in {model_id} with extension {extension}\"\n            )\n        return [Path(f) for f in local_files]\n\n    try:\n        filenames = weight_hub_files(model_id, revision, extension)\n    except EntryNotFoundError as e:\n        if extension != \".safetensors\":\n            raise e\n        # Try to see if there are pytorch weights\n        pt_filenames = weight_hub_files(model_id, revision, extension=\".bin\")\n        # Change pytorch extension to safetensors extension\n        # It is possible that we have safetensors weights locally even though they are not on the\n        # hub if we converted weights locally without pushing them\n        filenames = [\n            f\"{Path(f).stem.lstrip('pytorch_')}.safetensors\" for f in pt_filenames\n        ]\n\n    if WEIGHTS_CACHE_OVERRIDE is not None:\n        files = []\n        for filename in filenames:\n            p = Path(WEIGHTS_CACHE_OVERRIDE) / filename\n            if not p.exists():\n                raise FileNotFoundError(\n                    f\"File {p} not found in {WEIGHTS_CACHE_OVERRIDE}.\"\n                )\n            files.append(p)\n        return files\n\n    files = []\n    for filename in filenames:\n        cache_file = try_to_load_from_cache(\n            model_id, revision=revision, filename=filename\n        )\n        if cache_file is None:\n            raise 
LocalEntryNotFoundError(\n                f\"File {filename} of model {model_id} not found in \"\n                f\"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. \"\n                f\"Please run `text-generation-server download-weights {model_id}` first.\"\n            )\n        files.append(cache_file)\n\n    return files\n\n\ndef download_weights(\n    filenames: List[str], model_id: str, revision: Optional[str] = None\n) -> List[Path]:\n    \"\"\"Download the safetensors files from the hub\"\"\"\n\n    def download_file(fname, tries=5, backoff: int = 5):\n        local_file = try_to_load_from_cache(model_id, revision, fname)\n        if local_file is not None:\n            logger.info(f\"File {fname} already present in cache.\")\n            return Path(local_file)\n\n        for idx in range(tries):\n            try:\n                logger.info(f\"Download file: {fname}\")\n                stime = time.time()\n                local_file = hf_hub_download(\n                    filename=fname,\n                    repo_id=model_id,\n                    revision=revision,\n                    local_files_only=HF_HUB_OFFLINE,\n                )\n                logger.info(\n                    f\"Downloaded {local_file} in {timedelta(seconds=int(time.time() - stime))}.\"\n                )\n                return Path(local_file)\n            except Exception as e:\n                if idx + 1 == tries:\n                    raise e\n                logger.error(e)\n                logger.info(f\"Retrying in {backoff} seconds\")\n                time.sleep(backoff)\n                logger.info(f\"Retry {idx + 1}/{tries - 1}\")\n\n    # We do this instead of using tqdm because we want to parse the logs with the launcher\n    start_time = time.time()\n    files = []\n    for i, filename in enumerate(filenames):\n        file = download_file(filename)\n\n        elapsed = timedelta(seconds=int(time.time() - start_time))\n        remaining = 
len(filenames) - (i + 1)\n        eta = (elapsed / (i + 1)) * remaining if remaining > 0 else 0\n\n        logger.info(f\"Download: [{i + 1}/{len(filenames)}] -- ETA: {eta}\")\n        files.append(file)\n\n    return files\n"
  },
  {
    "path": "server/text_generation_server/utils/import_utils.py",
    "content": "import torch\nfrom loguru import logger\nimport os\n\n\nimport importlib.util\n\n\ndef is_ipex_available():\n    return importlib.util.find_spec(\"intel_extension_for_pytorch\") is not None\n\n\ndef get_cuda_free_memory(device, memory_fraction):\n    total_free_memory, _ = torch.cuda.mem_get_info(device)\n    total_gpu_memory = torch.cuda.get_device_properties(device).total_memory\n    free_memory = max(0, total_free_memory - (1 - memory_fraction) * total_gpu_memory)\n    return free_memory\n\n\ndef get_xpu_free_memory(device, memory_fraction):\n    total_free_memory, total_xpu_memory = torch.xpu.mem_get_info(device)\n    memory_fraction = float(os.getenv(\"XPU_MEMORY_FRACTION\", \"0.9\"))\n    free_memory = max(\n        0, int(total_free_memory - (1 - memory_fraction) * total_xpu_memory)\n    )\n    return free_memory\n\n\ndef get_cpu_free_memory(device, memory_fraction):\n    import psutil\n    from text_generation_server.utils.dist import WORLD_SIZE\n\n    mem = psutil.virtual_memory()\n    free_memory = int(mem.available * 0.95 / WORLD_SIZE)\n    return free_memory\n\n\ndef noop(*args, **kwargs):\n    pass\n\n\nSYSTEM = None\nif torch.version.hip is not None:\n    SYSTEM = \"rocm\"\n    empty_cache = torch.cuda.empty_cache\n    synchronize = torch.cuda.synchronize\n    get_free_memory = get_cuda_free_memory\nelif torch.version.cuda is not None and torch.cuda.is_available():\n    SYSTEM = \"cuda\"\n    empty_cache = torch.cuda.empty_cache\n    synchronize = torch.cuda.synchronize\n    get_free_memory = get_cuda_free_memory\nelif is_ipex_available():\n    SYSTEM = \"ipex\"\n    import intel_extension_for_pytorch  # noqa: F401\n\n    if hasattr(torch, \"xpu\") and torch.xpu.is_available():\n        empty_cache = torch.xpu.empty_cache\n        synchronize = torch.xpu.synchronize\n        get_free_memory = get_xpu_free_memory\n    else:\n        empty_cache = noop\n        synchronize = noop\n        get_free_memory = get_cpu_free_memory\nelif 
hasattr(torch, \"xpu\") and torch.xpu.is_available():\n    SYSTEM = \"xpu\"\n    empty_cache = torch.xpu.empty_cache\n    synchronize = torch.xpu.synchronize\n    get_free_memory = get_xpu_free_memory\nelse:\n    SYSTEM = \"cpu\"\n\n    empty_cache = noop\n    synchronize = noop\n    get_free_memory = get_cpu_free_memory\nlogger.info(f\"Detected system {SYSTEM}\")\n"
  },
  {
    "path": "server/text_generation_server/utils/kernels.py",
    "content": "import importlib\n\nfrom loguru import logger\nfrom kernels import load_kernel as hf_load_kernel\n\nfrom text_generation_server.utils.log import log_once\n\n\ndef load_kernel(*, module: str, repo_id: str):\n    \"\"\"\n    Load a kernel. First try to load it as the given module (e.g. for\n    local development), falling back to a locked Hub kernel.\n    \"\"\"\n    try:\n        m = importlib.import_module(module)\n        log_once(logger.info, f\"Using local module for `{module}`\")\n        return m\n    except ModuleNotFoundError:\n        return hf_load_kernel(repo_id=repo_id)\n\n\n__all__ = [\"load_kernel\"]\n"
  },
  {
    "path": "server/text_generation_server/utils/log.py",
    "content": "from functools import lru_cache\nfrom text_generation_server.utils.dist import RANK\n\n\n@lru_cache(10)\ndef log_once(log, msg: str, master=True):\n    if master:\n        log_master(log, msg)\n    else:\n        log(msg)\n\n\ndef log_master(log, msg: str):\n    if RANK == 0:\n        log(msg)\n"
  },
  {
    "path": "server/text_generation_server/utils/logits_process.py",
    "content": "from functools import lru_cache\nimport math\nimport time\nimport torch\nfrom typing import List, Optional, DefaultDict\n\nfrom loguru import logger\nfrom typing import Dict\nfrom text_generation_server.pb.generate_pb2 import GrammarType\n\nfrom outlines.fsm.guide import RegexGuide\n\nfrom transformers import (\n    LogitsProcessor,\n    PreTrainedTokenizerBase,\n    TemperatureLogitsWarper,\n    TopKLogitsWarper,\n    TopPLogitsWarper,\n    TypicalLogitsWarper,\n)\n\nmempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None\n\n\nclass StaticWarper:\n    def __init__(\n        self,\n        temperature=1.0,\n        top_k=None,\n        top_p=None,\n        typical_p=None,\n    ):\n        self.warpers = []\n\n        if temperature is not None and temperature != 1.0:\n            temperature = float(temperature)\n            self.warpers.append(TemperatureLogitsWarper(temperature))\n        if top_k is not None and top_k != 0:\n            self.warpers.append(TopKLogitsWarper(top_k=top_k))\n        if top_p is not None and top_p < 1.0:\n            self.warpers.append(TopPLogitsWarper(top_p=top_p))\n        if typical_p is not None and typical_p < 1.0:\n            self.warpers.append(TypicalLogitsWarper(mass=typical_p))\n\n        self.cuda_graph = None\n        self.static_scores = None\n        self.static_warped_scores = None\n        self.static_next_logprob = None\n\n    def __call__(self, scores):\n        if torch.cuda.is_available():\n            if self.cuda_graph is None:\n                self.static_scores = scores\n                self.cuda_graph = torch.cuda.CUDAGraph()\n\n                with torch.cuda.graph(self.cuda_graph, pool=mempool):\n                    local_scores = self.static_scores\n                    for warper in self.warpers:\n                        local_scores = warper(None, local_scores)\n\n                    self.static_warped_scores = local_scores\n                    # Compute 
logprobs\n                    self.static_next_logprob = torch.log_softmax(\n                        self.static_warped_scores, -1\n                    )\n\n            self.static_scores.copy_(scores)\n            self.cuda_graph.replay()\n\n            return self.static_warped_scores, self.static_next_logprob\n\n        # CPU branch\n        for warper in self.warpers:\n            scores = warper(None, scores)\n        return scores, torch.log_softmax(scores, -1)\n\n\n@lru_cache(10)\ndef static_warper(\n    temperature: Optional[float],\n    top_k: Optional[int],\n    top_p: Optional[float],\n    typical_p: Optional[float],\n) -> StaticWarper:\n    return StaticWarper(\n        temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p\n    )\n\n\nclass HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):\n    r\"\"\"\n    [`LogitsProcessor`] enforcing an exponential penalty on repeated sequences.\n    This version allows for a separate value for each sample and runs inplace when possible.\n    It doesn't validate inputs.\n\n    Args:\n        repetition_penalty (`List[float]`):\n            The parameter for repetition penalty. 1.0 means no penalty. 
See [this\n            paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.\n    \"\"\"\n\n    def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):\n        self.penalty = penalty\n        self.penalty_tensor = torch.tensor(\n            penalty, dtype=dtype, device=device\n        ).unsqueeze(1)\n\n    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:\n        score = torch.gather(scores, 1, input_ids)\n\n        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability\n        score = torch.where(\n            score < 0, score * self.penalty_tensor, score / self.penalty_tensor\n        )\n\n        scores.scatter_(1, input_ids, score)\n        return scores\n\n    def filter(self, indices):\n        self.penalty = [self.penalty[i] for i in indices]\n        if any([x != 1.0 for x in self.penalty]):\n            self.penalty_tensor = self.penalty_tensor[indices]\n            return self\n        return None\n\n\nclass FrequencyPenaltyLogitsProcessor(LogitsProcessor):\n    r\"\"\"\n    Frequency penalty as defined by OpenAI\n\n    Args:\n        penalty (`float`):\n            The parameter for frequency penalty. 
0.0 means no penalty.\n    \"\"\"\n\n    def __init__(self, penalty: float):\n        self.penalty = penalty\n\n    def __call__(\n        self, input_ids: torch.LongTensor, scores: torch.FloatTensor\n    ) -> torch.FloatTensor:\n        score = torch.gather(scores, 1, input_ids)\n        # if score < 0 then penalty has to be multiplied to reduce the previous token probability\n        score = -torch.where(score < 0, score * self.penalty, score / self.penalty)\n        # set score to 0 where input_ids is a padding token\n        score *= input_ids.ne(0)\n\n        return scores.scatter_add_(1, input_ids, score)\n\n\nclass HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):\n    r\"\"\"\n    Frequency penalty as defined by OpenAI in\n    https://platform.openai.com/docs/guides/text-generation/parameter-details\n\n    Args:\n        frequency_penalty (`List[float]`):\n            The parameter for frequency penalty. 0.0 means no penalty.\n    \"\"\"\n\n    def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):\n        self.penalty = penalty\n        self.penalty_tensor = torch.tensor(\n            penalty, dtype=dtype, device=device\n        ).unsqueeze(1)\n\n    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:\n        batch_size, input_size = input_ids.size()\n        vocab_size = scores.size(1)\n\n        # Calculate the frequency for each token so far\n        token_freq = torch.zeros(batch_size, vocab_size, device=input_ids.device)\n        token_freq.scatter_add_(\n            1, input_ids, torch.ones_like(input_ids, dtype=torch.float)\n        )\n        token_freq /= input_size\n\n        # Apply the frequency penalty to logits\n        scores -= token_freq * self.penalty_tensor\n        return scores\n\n    def filter(self, indices):\n        self.penalty = [self.penalty[i] for i in indices]\n        if any([x != 0.0 for x in self.penalty]):\n            self.penalty_tensor = 
self.penalty_tensor[indices]\n            return self\n        return None\n\n\nclass HeterogeneousTemperatureLogitsWarper:\n    r\"\"\"\n    [`LogitsWarper`] for temperature (exponential scaling output probability distribution).\n    This version allows for a separate value for each sample and runs inplace when possible.\n    It doesn't validate inputs.\n\n    Args:\n        temperature (`List[float]`):\n            The value used to modulate the logits distribution, one per sample.\n    \"\"\"\n\n    def __init__(\n        self, temperature: List[float], dtype: torch.dtype, device: torch.device\n    ):\n        self.temperature = temperature\n        self.temperature_tensor = torch.tensor(\n            temperature, dtype=dtype, device=device\n        ).unsqueeze(1)\n\n    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:\n        scores.div_(self.temperature_tensor)\n        return scores\n\n    def filter(self, indices):\n        self.temperature = [self.temperature[i] for i in indices]\n        if any([x != 1.0 for x in self.temperature]):\n            self.temperature_tensor = self.temperature_tensor[indices]\n            return self\n        return None\n\n\nclass HeterogeneousTopPLogitsWarper(LogitsProcessor):\n    \"\"\"\n    [`LogitsWarper`] that performs top-p, i.e. 
restricting to the smallest set of most probable tokens whose cumulative probability reaches `top_p`.\n    This version allows for a separate value for each sample and runs inplace when possible.\n    It doesn't validate inputs.\n\n    Args:\n        top_p (`List[float]`):\n            If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or\n            higher are kept for generation.\n        filter_value (`float`, *optional*, defaults to `-float(\"Inf\")`):\n            All filtered values will be set to this float value.\n        min_tokens_to_keep (`int`, *optional*, defaults to 1):\n            Minimum number of tokens that cannot be filtered.\n    \"\"\"\n\n    def __init__(\n        self,\n        top_p: List[float],\n        dtype: torch.dtype,\n        device: torch.device,\n        filter_value: float = -math.inf,\n        min_tokens_to_keep: int = 1,\n    ):\n        self.top_p = top_p\n        self.top_p_opposite = 1 - torch.tensor(\n            top_p, dtype=dtype, device=device\n        ).unsqueeze(1)\n        self.filter_value = filter_value\n        self.min_tokens_to_keep = min_tokens_to_keep\n\n    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:\n        sorted_logits, sorted_indices = torch.sort(scores, descending=False)\n        probs = sorted_logits.softmax(dim=-1)\n        # This is way faster for some reason\n        for i in range(probs.shape[0]):\n            probs[i] = probs[i].cumsum(dim=-1)\n\n        # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)\n        sorted_indices_to_remove = probs <= self.top_p_opposite\n        # Keep at least min_tokens_to_keep\n        sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0\n\n        # scatter sorted tensors to original indexing\n        indices_to_remove = sorted_indices_to_remove.scatter(\n            1, sorted_indices, sorted_indices_to_remove\n        )\n        warped_scores = 
scores.masked_fill_(indices_to_remove, self.filter_value)\n\n        return warped_scores\n\n    def filter(self, indices):\n        self.top_p = [self.top_p[i] for i in indices]\n        if any([x < 1.0 for x in self.top_p]):\n            self.top_p_opposite = self.top_p_opposite[indices]\n            return self\n        return None\n\n\nclass HeterogeneousTopKLogitsWarper(LogitsProcessor):\n    r\"\"\"\n    [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.\n    This version allows for a separate value for each sample and runs inplace when possible.\n    It doesn't validate inputs.\n\n    Args:\n        top_k (`int`):\n            The number of highest probability vocabulary tokens to keep for top-k-filtering.\n        filter_value (`float`, *optional*, defaults to `-float(\"Inf\")`):\n            All filtered values will be set to this float value.\n        min_tokens_to_keep (`int`, *optional*, defaults to 1):\n            Minimum number of tokens that cannot be filtered.\n    \"\"\"\n\n    def __init__(\n        self,\n        top_k: List[int],\n        device: torch.device,\n        filter_value: float = -math.inf,\n        min_tokens_to_keep: int = 1,\n    ):\n        self.top_k = top_k\n        self.max_top_k = max(top_k)\n        # value - 1 as we will use top_k to index and python uses 0 based numbering\n        self.top_k_tensor = torch.tensor(\n            [max(x - 1, min_tokens_to_keep - 1) for x in top_k],\n            dtype=torch.int64,\n            device=device,\n        ).unsqueeze(1)\n\n        # 0 is a special value that disables top_k warping for this member of the batch\n        disabled = [x == 0 for x in top_k]\n\n        if any(disabled):\n            self.top_k_disabled_mask = torch.tensor(\n                disabled, dtype=torch.bool, device=device\n            ).view(-1, 1)\n        else:\n            self.top_k_disabled_mask = None\n\n        self.filter_value = filter_value\n\n    def 
__call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:\n        # If max_top_k is superior to the vocab, we need to clamp or the warper will fail\n        if scores.size(-1) < self.max_top_k:\n            max_top_k = scores.size(-1)\n            top_k = torch.clamp_max(self.top_k_tensor, max_top_k)\n        else:\n            max_top_k = self.max_top_k\n            top_k = self.top_k_tensor\n\n        # Get the kth score for each member of the batch\n        kth_scores = torch.gather(torch.topk(scores, max_top_k)[0], 1, top_k)\n\n        # Mask member of kth_scores that do not want to use top_k warping\n        if self.top_k_disabled_mask is not None:\n            kth_scores.masked_fill_(self.top_k_disabled_mask, self.filter_value)\n\n        # Remove all tokens with a probability less than the last token of the top-k\n        indices_to_remove = scores < kth_scores\n        scores.masked_fill_(indices_to_remove, self.filter_value)\n        return scores\n\n    def filter(self, indices):\n        self.top_k = [self.top_k[i] for i in indices]\n        disabled = [x == 0 for x in self.top_k]\n\n        if not all(disabled):\n            self.top_k_tensor = self.top_k_tensor[indices]\n            self.max_top_k = max(self.top_k)\n\n            if self.top_k_disabled_mask is not None:\n                self.top_k_disabled_mask = (\n                    self.top_k_disabled_mask[indices] if any(disabled) else None\n                )\n\n            return self\n        return None\n\n\nclass HeterogeneousTypicalLogitsWarper(LogitsProcessor):\n    r\"\"\"\n    [`LogitsWarper`] that performs typical decoding. 
See [Typical Decoding for Natural Language\n    Generation](https://arxiv.org/abs/2202.00666) for more information.\n    This version allows for a separate value for each sample and runs inplace when possible.\n    It doesn't validate inputs.\n\n    Args:\n        mass (`List[float]`):\n            Values of typical_p between 0 and 1 inclusive, one per sample. A value of 1 disables typical_p warping\n            for that sample.\n        filter_value (`float`, *optional*, defaults to `-float(\"Inf\")`):\n            All filtered values will be set to this float value.\n        min_tokens_to_keep (`int`, *optional*, defaults to 1):\n            Minimum number of tokens that cannot be filtered.\n    \"\"\"\n\n    def __init__(\n        self,\n        mass: List[float],\n        dtype: torch.dtype,\n        device: torch.device,\n        filter_value: float = -math.inf,\n        min_tokens_to_keep: int = 1,\n    ):\n        self.mass = mass\n        self.mass_tensor = torch.tensor(mass, dtype=dtype, device=device).unsqueeze(1)\n\n        # 1 is a special value that disables typical_p warping for this member of the batch\n        disabled = [x == 1.0 for x in mass]\n\n        if any(disabled):\n            self.disabled_mask = torch.tensor(disabled, dtype=torch.bool, device=device)\n        else:\n            self.disabled_mask = None\n\n        self.filter_value = filter_value\n        self.min_tokens_to_keep = min_tokens_to_keep\n\n    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:\n        # calculate entropy\n        normalized = torch.nn.functional.log_softmax(scores, dim=-1)\n        p = torch.exp(normalized)\n        ent = -(normalized * p).nansum(-1, keepdim=True)\n\n        # shift and sort\n        shifted_scores = torch.abs((-normalized) - ent)\n        sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)\n        sorted_logits = scores.gather(-1, sorted_indices)\n        probs = sorted_logits.softmax(dim=-1)\n        # This is way faster for some reason\n        for i in 
range(probs.shape[0]):\n            probs[i] = probs[i].cumsum(dim=-1)\n\n        # Remove tokens with cumulative mass above the threshold\n        last_ind = (probs < self.mass_tensor).sum(dim=1)\n        last_ind[last_ind < 0] = 0\n\n        if self.disabled_mask is not None:\n            last_ind.masked_fill_(self.disabled_mask, scores.shape[-1] - 1)\n\n        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(\n            1, last_ind.view(-1, 1)\n        )\n        if self.min_tokens_to_keep > 1:\n            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)\n            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0\n        indices_to_remove = sorted_indices_to_remove.scatter(\n            1, sorted_indices, sorted_indices_to_remove\n        )\n\n        warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)\n\n        return warped_scores\n\n    def filter(self, indices):\n        self.mass = [self.mass[i] for i in indices]\n        disabled = [x == 1.0 for x in self.mass]\n\n        if not all(disabled):\n            self.mass_tensor = self.mass_tensor[indices]\n\n            if self.disabled_mask is not None:\n                self.disabled_mask = (\n                    self.disabled_mask[indices] if any(disabled) else None\n                )\n\n            return self\n        return None\n\n\nclass HeterogeneousProcessorWrapper(LogitsProcessor):\n    r\"\"\"\n    A wrapper for logit warpers or processors without heterogeneous parameter support.\n    Args:\n        processors (`Dict[int, LogitsProcessor]`):\n            A mapping of sample indices to logit warpers or processors, to be run sequentially.\n    \"\"\"\n\n    def __init__(\n        self,\n        processors: Dict[int, LogitsProcessor],\n    ):\n        self.processors = processors\n\n    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:\n        for i, processor in 
self.processors.items():\n            scores[i : i + 1] = processor(input_ids[i : i + 1], scores[i : i + 1])\n        return scores\n\n    def filter(self, indices):\n        new_processors = {}\n        for i, idx in enumerate(indices):\n            if idx in self.processors:\n                new_processors[i] = self.processors[idx]\n\n        if new_processors:\n            self.processors = new_processors\n            return self\n        return None\n\n\nclass GrammarLogitProcessor(LogitsProcessor):\n    fsm_state: DefaultDict[int, int]\n    fsm: RegexGuide\n\n    def __init__(\n        self,\n        tokenizer: Optional[PreTrainedTokenizerBase],\n        device: str,\n        grammar: str,\n        grammar_type: GrammarType,\n    ):\n        self.device = device\n        self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)\n        self.fsm = GrammarLogitProcessor._cached_compile_fsm(\n            grammar_type, grammar, self.tokenizer\n        )\n\n    def __call__(\n        self,\n        logits: torch.Tensor,\n        fsm_grammar_state: int,\n    ):\n        if fsm_grammar_state == -1 or self.fsm is None:\n            return logits\n        allowed_tokens = self.fsm.get_next_instruction(fsm_grammar_state).tokens\n        mask = torch.full_like(logits, -math.inf)\n        if allowed_tokens is not None:\n            mask[:, allowed_tokens] = 0\n        biased_scores = logits + mask\n        return biased_scores\n\n    def advance(self, next_token_id, fsm_grammar_state):\n        return GrammarLogitProcessor._advance(\n            next_token_id, fsm_grammar_state, self.fsm\n        )\n\n    @staticmethod\n    def _advance(next_token_id, fsm_grammar_state, fsm):\n        if fsm_grammar_state == -1:\n            return fsm_grammar_state\n        return fsm.get_next_state(fsm_grammar_state, next_token_id)\n\n    # TODO: move grammar compilation into the router\n    @staticmethod\n    @lru_cache(maxsize=32, typed=True)\n    def 
_cached_compile_fsm(\n        grammar_type: GrammarType,\n        schema: str,\n        tokenizer: Optional[PreTrainedTokenizerBase],\n    ):\n        start_time = time.time()\n        if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:\n            # JSON schema is compiled by the v3 router.\n            logger.error(\n                \"Non-regex grammars must be compiled by the router, grammar won't be enforced\"\n            )\n            # allows everything\n            schema = \"(.*?)\"\n\n        fsm = RegexGuide.from_regex(schema, tokenizer)\n        logger.debug(f\"Compiled FSM in {time.time() - start_time:.2f}s\")\n        return fsm\n\n    @staticmethod\n    @lru_cache(maxsize=32, typed=True)\n    def _cached_adapt_tokenizer(tokenizer):\n        \"\"\"Adapt tokenizer to work with the FSM.\n\n        The API of Outlines tokenizers is slightly different to that of\n        `transformers`. In addition we need to handle the missing spaces to\n        Llama's tokenizer to be able to compile FSMs for this model.\n\n        \"\"\"\n        start_time = time.time()\n        tokenizer.vocabulary = tokenizer.get_vocab()\n        tokenizer.special_tokens = set(tokenizer.all_special_tokens)\n\n        def convert_token_to_string(token: str) -> str:\n            from transformers.file_utils import SPIECE_UNDERLINE\n\n            string = tokenizer.convert_tokens_to_string([token])\n\n            # A hack to handle missing spaces to HF's Llama tokenizers\n            if token.startswith(SPIECE_UNDERLINE) or token == \"<0x20>\":\n                return \" \" + string\n\n            return string\n\n        tokenizer.convert_token_to_string = convert_token_to_string\n        logger.debug(f\"Adapted tokenizer in {time.time() - start_time:.2f}s\")\n        return tokenizer\n\n\nclass HeterogeneousGrammarLogitProcessor(LogitsProcessor):\n    def __init__(self, tokenizer, device, grammars, grammar_types):\n        self.device = device\n        self.tokenizer = 
GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)\n        self.fsms = []\n        for grammar, grammar_type in zip(grammars, grammar_types):\n            if len(grammar) == 0:\n                self.fsms.append(None)\n                continue\n            fsm = GrammarLogitProcessor._cached_compile_fsm(\n                grammar_type, grammar, self.tokenizer\n            )\n            self.fsms.append(fsm)\n\n    def __call__(\n        self,\n        logits: torch.Tensor,\n        fsm_grammar_states: List[int],\n    ):\n        mask = torch.full_like(logits, -math.inf)\n        for i in range(logits.shape[0]):\n            fsm = self.fsms[i]\n            if fsm_grammar_states[i] == -1 or fsm is None:\n                continue\n            allowed_tokens = fsm.get_next_instruction(fsm_grammar_states[i]).tokens\n            if allowed_tokens is not None:\n                mask[i, allowed_tokens] = 0\n            logits[i] += mask[i]\n        return logits\n\n    def advance_batch(self, next_token_ids, fsm_grammar_states):\n        return [\n            GrammarLogitProcessor._advance(\n                next_token_ids[i], fsm_grammar_states[i], self.fsms[i]\n            )\n            for i in range(len(next_token_ids))\n        ]\n\n    def advance_at_index(self, next_token_id, fsm_grammar_state, index):\n        if self.fsms[index] is None:\n            return fsm_grammar_state\n        return GrammarLogitProcessor._advance(\n            next_token_id, fsm_grammar_state, self.fsms[index]\n        )\n\n    def filter(self, indices):\n        new_fsms = []\n        for i in indices:\n            new_fsms.append(self.fsms[i])\n        self.fsms = new_fsms\n        return self\n"
  },
  {
    "path": "server/text_generation_server/utils/merges/strategies.py",
    "content": "import copy\nfrom abc import ABC\nfrom collections import defaultdict\nfrom typing import TYPE_CHECKING, Dict, List, Tuple, Type, Union\nfrom text_generation_server.utils.merges.utils import (\n    calculate_majority_sign_mask,\n    disjoint_merge,\n    prune,\n)\nimport torch\n\nif TYPE_CHECKING:\n    from text_generation_server.adapters.lora import LoraConfig\n    from text_generation_server.utils.adapter import ModuleMap\n\n\nclass AdapterParameters:\n    def __init__(\n        self, adapter_ids, weights, merge_strategy, density, majority_sign_method\n    ):\n        self.adapter_ids = adapter_ids\n        self.weights = weights\n        self.merge_strategy = merge_strategy\n        self.density = density\n        self.majority_sign_method = majority_sign_method\n\n\ndef _apply_weights(\n    tensors: Union[torch.Tensor, List[torch.Tensor]], w: torch.Tensor\n) -> torch.Tensor:\n    if isinstance(tensors, torch.Tensor):\n        t = tensors\n    else:\n        t = torch.stack(tensors, dim=0)\n\n    # element-wise weighting of each task tensor\n    # need to unsqueeze weights to match task tensor dimensions\n    # for multiplication to apply element-wise\n    while len(t.shape) > len(w.shape):\n        w = w.unsqueeze(-1)\n    return t * w\n\n\nclass MergeStrategy(ABC):\n    def merge(\n        self, task_tensors: List[torch.Tensor], weights: torch.Tensor\n    ) -> torch.Tensor:\n        raise NotImplementedError()\n\n\nclass LinearMerge(MergeStrategy):\n    def __init__(self, **kwargs):\n        pass\n\n    def merge(\n        self, task_tensors: List[torch.Tensor], weights: torch.Tensor\n    ) -> torch.Tensor:\n        weighted_task_tensors = _apply_weights(task_tensors, weights)\n        return weighted_task_tensors.sum(dim=0)\n\n\nclass TiesMerge(MergeStrategy):\n    def __init__(self, density: float, majority_sign_method: str = \"total\", **kwargs):\n        self.density = density\n        self.majority_sign_method = majority_sign_method\n\n    
def merge(\n        self, task_tensors: List[torch.Tensor], weights: torch.Tensor\n    ) -> torch.Tensor:\n        # sparsify\n        task_tensors = [\n            prune(tensor, self.density, method=\"magnitude\") for tensor in task_tensors\n        ]\n        task_tensors = torch.stack(task_tensors, dim=0)\n\n        # elect sign before applying weights\n        majority_sign_mask = calculate_majority_sign_mask(\n            task_tensors, method=self.majority_sign_method\n        )\n        weighted_task_tensors = _apply_weights(task_tensors, weights)\n\n        # disjoint merge\n        return disjoint_merge(weighted_task_tensors, majority_sign_mask)\n\n\nclass DareLinearMerge(MergeStrategy):\n    def __init__(self, density: float, **kwargs):\n        self.density = density\n\n    def merge(\n        self, task_tensors: List[torch.Tensor], weights: torch.Tensor\n    ) -> torch.Tensor:\n        # sparsify\n        task_tensors = [\n            prune(tensor, self.density, method=\"random\", rescale=True)\n            for tensor in task_tensors\n        ]\n        weighted_task_tensors = _apply_weights(task_tensors, weights)\n        return weighted_task_tensors.sum(dim=0)\n\n\nclass DareTiesMerge(MergeStrategy):\n    def __init__(self, density: float, majority_sign_method: str = \"total\", **kwargs):\n        self.density = density\n        self.majority_sign_method = majority_sign_method\n\n    def merge(\n        self, task_tensors: List[torch.Tensor], weights: torch.Tensor\n    ) -> torch.Tensor:\n        # sparsify\n        task_tensors = [\n            prune(tensor, self.density, method=\"random\", rescale=True)\n            for tensor in task_tensors\n        ]\n        task_tensors = torch.stack(task_tensors, dim=0)\n\n        # elect sign before applying weights\n        majority_sign_mask = calculate_majority_sign_mask(\n            task_tensors, method=self.majority_sign_method\n        )\n        weighted_task_tensors = _apply_weights(task_tensors, 
weights)\n\n        # disjoint merge\n        mixed_task_tensors = disjoint_merge(weighted_task_tensors, majority_sign_mask)\n        return mixed_task_tensors\n\n\nstrategy_registry: Dict[str, Type[MergeStrategy]] = {\n    \"linear\": LinearMerge,\n    \"ties\": TiesMerge,\n    \"dare_linear\": DareLinearMerge,\n    \"dare_ties\": DareTiesMerge,\n}\n\n\ndef merge_adapters(\n    adapters: List[Tuple[\"ModuleMap\", \"LoraConfig\"]],\n    merge_params: AdapterParameters,\n) -> Tuple[\"ModuleMap\", \"LoraConfig\"]:\n    # strategy_name = MergeStrategyEnum.Name(merge_params.merge_strategy).lower()\n    strategy_name = \"linear\"\n\n    weights = merge_params.weights\n    if not weights:\n        weights = torch.ones(len(adapters))\n    else:\n        weights = torch.tensor(weights)\n\n    merge_config = {\n        \"density\": merge_params.density,\n        # \"majority_sign_method\": MajoritySignMethodEnum.Name(\n        #     merge_params.majority_sign_method\n        # ).lower(),\n        \"majority_sign_method\": \"total\",\n    }\n    merge_strategy = strategy_registry[strategy_name](**merge_config)\n\n    module_maps: Dict[str, Dict[str, Dict[str, List[torch.Tensor]]]] = defaultdict(\n        lambda: defaultdict(lambda: defaultdict(list))\n    )\n    lora_configs = []\n    weight_name_to_adapter_idx = defaultdict(list)\n\n    # input is list of (module_map, lora_config) tuples\n    # convert into dict[k][param_name] -> list of tensors\n    for idx, (module_map, lora_config) in enumerate(adapters):\n        for weight_name, data in module_map.items():\n            weight_name_to_adapter_idx[weight_name].append(idx)\n            for k, (param_data, param_name) in data.items():\n                module_maps[weight_name][k][param_name].append(param_data)\n        lora_configs.append(lora_config)\n\n    # validate lora configs are compatible\n    _validate_lora_configs(lora_configs)\n\n    # merge tensors for each module such that we have a single ModuleMap:\n    # 
dict[k] -> merged tensor\n    merged_module_map: \"ModuleMap\" = defaultdict(dict)\n    for weight_name, data in module_maps.items():\n        indices = weight_name_to_adapter_idx[weight_name]\n        param_weights = weights[indices]\n        for k, param_data in data.items():\n            for param_name, tensors in param_data.items():\n                merged_tensor = merge_strategy.merge(tensors, param_weights)\n                merged_module_map[weight_name][k] = (merged_tensor, param_name)\n\n    # merge lora configs\n    merged_lora_config = _merge_lora_configs(lora_configs)\n\n    return merged_module_map, merged_lora_config\n\n\ndef _validate_lora_configs(lora_configs: List[\"LoraConfig\"]):\n    # check that all configs have the same rank\n    ranks = set(lora_config.r for lora_config in lora_configs)\n    if len(ranks) > 1:\n        raise ValueError(\n            f\"unable to merge adapters, lora configs have different ranks: {ranks}\"\n        )\n\n    if all(len(lora_config.target_modules) == 0 for lora_config in lora_configs):\n        raise ValueError(\n            \"unable to merge adapters, lora configs have no target modules\"\n        )\n\n\ndef _merge_lora_configs(lora_configs: List[\"LoraConfig\"]) -> \"LoraConfig\":\n    merged_lora_config = copy.copy(lora_configs[0])\n\n    # merge target modules as a union operation\n    merged_target_modules = sorted(\n        set(\n            module\n            for lora_config in lora_configs\n            for module in lora_config.target_modules\n        )\n    )\n    merged_lora_config.target_modules = merged_target_modules\n\n    return merged_lora_config\n"
  },
  {
    "path": "server/text_generation_server/utils/merges/utils.py",
    "content": "# coding=utf-8\n# From: https://github.com/huggingface/peft/pull/1364\n# Copyright 2024-present the HuggingFace Inc. team.\n# Modifications by Predibase, Inc.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Literal\n\nimport torch\n\n\ndef magnitude_based_pruning(tensor: torch.Tensor, density: float) -> torch.Tensor:\n    \"\"\"\n    Prune the smallest values of the task tensors and retain the top-k values based on the specified fraction\n    `density`.\n\n    Args:\n    tensor (`torch.Tensor`):The tensor to prune.\n    density (`float`):The fraction of values to preserve. Should be in [0,1].\n    \"\"\"\n    mask = torch.zeros_like(tensor).reshape(-1)\n    k = int(density * tensor.reshape(-1).shape[0])\n    top_k = torch.topk(tensor.abs().reshape(-1), k=k, largest=True)\n    mask[top_k[1]] = 1\n    return tensor * mask.reshape(tensor.shape)\n\n\ndef random_pruning(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor:\n    \"\"\"\n    Prune the smallest values of the task tensors and retain the top-k values based on the specified fraction\n    `density`.\n\n    Args:\n    tensor (`torch.Tensor`):The tensor to prune.\n    density (`float`):The fraction of values to preserve. 
Should be in [0,1].\n    rescale (`bool`):Whether to rescale the result to preserve the expected value of the original tensor.\n    \"\"\"\n    mask = torch.bernoulli(torch.full_like(input=tensor, fill_value=density))\n    pruned_tensor = tensor * mask\n    if rescale:\n        torch.div(input=pruned_tensor, other=density)\n    return pruned_tensor\n\n\ndef prune(\n    tensor: torch.Tensor,\n    density: float,\n    method: Literal[\"magnitude\", \"random\"],\n    rescale: bool = False,\n) -> torch.Tensor:\n    \"\"\"\n    Prune the values of task tensors based on the `method`.\n\n    Args:\n    tensor (`torch.Tensor`):The tensor to prune.\n    density (`float`):The fraction of values to preserve. Should be in [0,1].\n    method (`str`):The method to use to prune. Should be one of [\"magnitude\", \"random\"].\n    rescale (`bool`):Whether to rescale the result to preserve the expected value of the original tensor.\n    \"\"\"\n    if density >= 1:\n        return tensor\n    elif density < 0:\n        raise ValueError(\"Density should be >= 0, got {density}\")\n    if method == \"magnitude\":\n        return magnitude_based_pruning(tensor, density)\n    elif method == \"random\":\n        return random_pruning(tensor, density, rescale=rescale)\n    else:\n        raise ValueError(f\"Unknown method {method}\")\n\n\ndef calculate_majority_sign_mask(\n    tensor: torch.Tensor, method: Literal[\"total\", \"frequency\"] = \"total\"\n):\n    \"\"\"\n    Get the mask of the majority sign across the task tensors. Task tensors are stacked on dimension 0.\n\n    Args:\n    tensor (`torch.Tensor`):The tensor to get the mask from.\n    method (`str`):The method to use to get the mask. 
Should be one of [\"total\", \"frequency\"].\n    \"\"\"\n\n    sign = tensor.sign()\n    if method == \"total\":\n        sign_magnitude = (sign * tensor.abs()).sum(dim=0)\n    elif method == \"frequency\":\n        sign_magnitude = sign.sum(dim=0)\n    else:\n        raise RuntimeError(f'Unimplemented mask method \"{method}\"')\n    majority_sign = torch.where(sign_magnitude >= 0, 1, -1)\n    return sign == majority_sign\n\n\ndef disjoint_merge(task_tensors, majority_sign_mask):\n    mixed_task_tensors = (task_tensors * majority_sign_mask).sum(dim=0)\n    num_params_preserved = majority_sign_mask.sum(dim=0)\n    return mixed_task_tensors / torch.clamp(num_params_preserved, min=1.0)\n"
  },
  {
    "path": "server/text_generation_server/utils/peft.py",
    "content": "import os\nfrom typing import Union\nfrom loguru import logger\nimport torch\n\nfrom transformers import AutoTokenizer\nfrom peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM\n\n\ndef download_and_unload_peft(model_id, revision, trust_remote_code):\n    torch_dtype = torch.float16\n\n    logger.info(\"Trying to load a Peft model. It might take a while without feedback\")\n    try:\n        model = AutoPeftModelForCausalLM.from_pretrained(\n            model_id,\n            revision=revision,\n            torch_dtype=torch_dtype,\n            trust_remote_code=trust_remote_code,\n            low_cpu_mem_usage=True,\n        )\n    except Exception:\n        model = AutoPeftModelForSeq2SeqLM.from_pretrained(\n            model_id,\n            revision=revision,\n            torch_dtype=torch_dtype,\n            trust_remote_code=trust_remote_code,\n            low_cpu_mem_usage=True,\n        )\n    logger.info(\"Peft model detected.\")\n    logger.info(\"Merging the lora weights.\")\n\n    base_model_id = model.peft_config[\"default\"].base_model_name_or_path\n\n    model = model.merge_and_unload()\n\n    os.makedirs(model_id, exist_ok=True)\n    cache_dir = model_id\n    logger.info(f\"Saving the newly created merged model to {cache_dir}\")\n    tokenizer = AutoTokenizer.from_pretrained(\n        base_model_id, trust_remote_code=trust_remote_code\n    )\n    model.save_pretrained(cache_dir, safe_serialization=True)\n    model.config.save_pretrained(cache_dir)\n    tokenizer.save_pretrained(cache_dir)\n\n\ndef download_peft(\n    model_id: Union[str, os.PathLike], revision: str, trust_remote_code: bool\n):\n    torch_dtype = torch.float16\n    try:\n        _model = AutoPeftModelForCausalLM.from_pretrained(\n            model_id,\n            revision=revision,\n            torch_dtype=torch_dtype,\n            trust_remote_code=trust_remote_code,\n            low_cpu_mem_usage=True,\n        )\n    except Exception:\n        _model = 
AutoPeftModelForSeq2SeqLM.from_pretrained(\n            model_id,\n            revision=revision,\n            torch_dtype=torch_dtype,\n            trust_remote_code=trust_remote_code,\n            low_cpu_mem_usage=True,\n        )\n    logger.info(\"Peft model downloaded.\")\n"
  },
  {
    "path": "server/text_generation_server/utils/prefill_chunking.py",
    "content": "from typing import Optional\n\nSUPPORT_CHUNKING: Optional[bool] = None\nMAX_PREFILL_TOKENS: Optional[int] = None\n\n\ndef set_support_chunking(support_chunking: bool):\n    global SUPPORT_CHUNKING\n    SUPPORT_CHUNKING = support_chunking\n\n\ndef get_support_chunking() -> bool:\n    global SUPPORT_CHUNKING\n    return SUPPORT_CHUNKING\n\n\ndef set_max_prefill_tokens(max_prefill_tokens: int):\n    global MAX_PREFILL_TOKENS\n    MAX_PREFILL_TOKENS = max_prefill_tokens\n\n\ndef get_max_prefill_tokens() -> int:\n    global MAX_PREFILL_TOKENS\n    return MAX_PREFILL_TOKENS\n"
  },
  {
    "path": "server/text_generation_server/utils/quantization.py",
    "content": "import json\nimport os\nfrom dataclasses import dataclass\nfrom typing import Optional, List\n\nfrom huggingface_hub import hf_hub_download\nfrom text_generation_server.layers.marlin.gptq import can_use_gptq_marlin\nfrom text_generation_server.utils.weights import (\n    DefaultWeightsLoader,\n    WeightsLoader,\n)\n\n\n# TODO: Split this config to have a single config type per quant method\n@dataclass\nclass _QuantizerConfig:\n    bits: int\n    checkpoint_format: Optional[str]\n    desc_act: bool\n    groupsize: int\n    quant_method: str\n    sym: bool\n    weight_block_size: Optional[List[int]]\n    modules_to_not_convert: List[str]\n\n\n@dataclass\nclass _FP8QuantizerConfig:\n    activation_scale_ub: float\n\n\ndef _get_config_json(model_id: str, revision: Optional[str], filename: str):\n    if os.path.exists(\n        os.path.join(\n            model_id,\n        )\n    ):\n        filename = os.path.join(model_id, filename)\n    else:\n        filename = hf_hub_download(model_id, filename=filename, revision=revision)\n    with open(filename, \"r\") as f:\n        return json.load(f)\n\n\n# We should probably do this with Pydantic JSON deserialization,\n# but for now we'll stay close to the old _set_gptq_params.\ndef _get_quantizer_config(model_id, revision):\n    bits = 4\n    groupsize = -1\n    quant_method = \"gptq\"\n    checkpoint_format = None\n    sym = False\n    desc_act = False\n    weight_block_size = None\n    modules_to_not_convert = []\n\n    filename = \"config.json\"\n    try:\n        data = _get_config_json(model_id, revision, filename)\n        # FP8 config\n        if data[\"quantization_config\"][\"quant_method\"] == \"fbgemm_fp8\":\n            return _FP8QuantizerConfig(\n                activation_scale_ub=data[\"quantization_config\"][\"activation_scale_ub\"]\n            )\n        weight_block_size = data[\"quantization_config\"].get(\"weight_block_size\", None)\n\n        if \"zero_point\" in 
data[\"quantization_config\"]:\n            sym = not data[\"quantization_config\"][\"zero_point\"]\n            quant_method = \"awq\"\n        elif \"sym\" in data[\"quantization_config\"]:\n            sym = data[\"quantization_config\"][\"sym\"]\n\n        bits = data[\"quantization_config\"][\"bits\"]\n        groupsize = data[\"quantization_config\"][\"group_size\"]\n        # Order is important here, desc_act is missing on some real models\n        quant_method = data[\"quantization_config\"][\"quant_method\"]\n        checkpoint_format = data[\"quantization_config\"].get(\"checkpoint_format\")\n        desc_act = data[\"quantization_config\"].get(\"desc_act\", False)\n        modules_to_not_convert = data[\"quantization_config\"].get(\n            \"modules_to_not_convert\", []\n        )\n        if modules_to_not_convert is None:\n            modules_to_not_convert = []\n    except Exception:\n        filename = \"quantize_config.json\"\n        try:\n            data = _get_config_json(model_id, revision, filename)\n            bits = data[\"bits\"]\n            groupsize = data[\"group_size\"]\n\n            if \"zero_point\" in data:\n                sym = not data[\"zero_point\"]\n                quant_method = \"awq\"\n            elif \"sym\" in data:\n                sym = data[\"sym\"]\n\n            desc_act = data[\"desc_act\"]\n            if \"version\" in data and data[\"version\"] == \"GEMM\":\n                quant_method = \"awq\"\n        except Exception:\n            filename = \"quant_config.json\"\n            try:\n                data = _get_config_json(model_id, revision, filename)\n                bits = data[\"w_bit\"]\n                groupsize = data[\"q_group_size\"]\n                desc_act = data[\"desc_act\"]\n                if \"version\" in data and data[\"version\"] == \"GEMM\":\n                    quant_method = \"awq\"\n            except Exception:\n                pass\n\n    return _QuantizerConfig(\n        
bits=bits,\n        groupsize=groupsize,\n        quant_method=quant_method,\n        checkpoint_format=checkpoint_format,\n        sym=sym,\n        desc_act=desc_act,\n        weight_block_size=weight_block_size,\n        modules_to_not_convert=modules_to_not_convert,\n    )\n\n\ndef get_loader(\n    quantize: Optional[str], model_id: str, revision: Optional[str]\n) -> WeightsLoader:\n    if quantize == \"compressed-tensors\":\n        config = _get_config_json(model_id, revision, \"config.json\")\n        from text_generation_server.layers.compressed_tensors import (\n            CompressedTensorsLoader,\n        )\n\n        return CompressedTensorsLoader(config)\n\n    quantizer_config = _get_quantizer_config(model_id, revision)\n    if quantize in {\"awq\", \"gptq\"}:\n        from text_generation_server.layers.gptq import GPTQWeightsLoader\n\n        # TODO: improve check once we have one config type per quantize value\n        if not isinstance(quantizer_config, _QuantizerConfig):\n            raise ValueError(\n                f\"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config.\"\n            )\n\n        if can_use_gptq_marlin(\n            bits=quantizer_config.bits,\n            groupsize=quantizer_config.groupsize,\n            quant_method=quantizer_config.quant_method,\n            quantize=quantize,\n            sym=quantizer_config.sym,\n        ):\n            from text_generation_server.layers.marlin import GPTQMarlinWeightsLoader\n\n            return GPTQMarlinWeightsLoader(\n                bits=quantizer_config.bits,\n                desc_act=quantizer_config.desc_act,\n                groupsize=quantizer_config.groupsize,\n                quant_method=quantizer_config.quant_method,\n                quantize=quantize,\n                sym=quantizer_config.sym,\n            )\n        else:\n            return GPTQWeightsLoader(\n                bits=quantizer_config.bits,\n                
desc_act=quantizer_config.desc_act,\n                groupsize=quantizer_config.groupsize,\n                quant_method=quantizer_config.quant_method,\n                quantize=quantize,\n                sym=quantizer_config.sym,\n                modules_to_not_convert=quantizer_config.modules_to_not_convert,\n            )\n    elif quantize == \"bitsandbytes\":\n        from text_generation_server.layers.bnb import BNBWeight\n\n        return DefaultWeightsLoader(BNBWeight)\n    elif quantize == \"bitsandbytes-fp4\":\n        from text_generation_server.layers.bnb import BNBFP4Weight\n\n        return DefaultWeightsLoader(BNBFP4Weight)\n    elif quantize == \"bitsandbytes-nf4\":\n        from text_generation_server.layers.bnb import BNBNF4Weight\n\n        return DefaultWeightsLoader(BNBNF4Weight)\n    elif quantize == \"eetq\":\n        from text_generation_server.layers.eetq import EETQWeight\n\n        return DefaultWeightsLoader(EETQWeight)\n    elif quantize == \"exl2\":\n        from text_generation_server.layers.exl2 import Exl2WeightsLoader\n\n        return Exl2WeightsLoader()\n    elif quantize == \"marlin\":\n        from text_generation_server.layers.marlin import MarlinWeightsLoader\n\n        # TODO: improve check once we have one config type per quantize value\n        if not isinstance(quantizer_config, _QuantizerConfig):\n            raise ValueError(\n                f\"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config.\"\n            )\n\n        return MarlinWeightsLoader(\n            bits=quantizer_config.bits,\n            is_marlin_24=quantizer_config.checkpoint_format == \"marlin_24\",\n        )\n    elif quantize == \"fp8\" or quantize is None:\n        from text_generation_server.layers.fp8 import HybridFP8UnquantLoader\n\n        # Since the default for the quantize config is _QuantizerConfig,\n        # we need to add this check to not get an attribute error\n        activation_scale_ub = 
None\n        weight_block_size = quantizer_config.weight_block_size\n        if isinstance(quantizer_config, _FP8QuantizerConfig):\n            activation_scale_ub = quantizer_config.activation_scale_ub\n\n        return HybridFP8UnquantLoader(\n            activation_scale_ub,\n            to_fp8=quantize == \"fp8\",\n            weight_block_size=weight_block_size,\n        )\n    else:\n        raise ValueError(f\"Unknown quantization method: {quantize}\")\n"
  },
  {
    "path": "server/text_generation_server/utils/segments.py",
    "content": "# Origin:   https://github.com/predibase/lorax\n# Path:     lorax/server/lorax_server/utils/segments.py\n# License:  Apache License Version 2.0, January 2004\n\nfrom typing import List, Tuple, Union\n\nimport torch\nimport numpy as np\n\n\ndef find_segments(\n    adapter_indices: Union[torch.Tensor, List[int]],\n) -> Tuple[List[int], List[int]]:\n    if isinstance(adapter_indices, torch.Tensor):\n        adapter_indices = adapter_indices.cpu().numpy()\n    elif isinstance(adapter_indices, list):\n        adapter_indices = np.array(adapter_indices)\n\n    change_mask = np.diff(adapter_indices, prepend=adapter_indices[0] - 1)\n    change_indices = np.nonzero(change_mask)[0]\n\n    segments = [0]\n    segments.extend(change_indices[1:].tolist())\n    segments.append(len(adapter_indices))\n\n    segment_indices = adapter_indices[change_indices].tolist()\n\n    return segments, segment_indices\n\n\nclass SegmentConcatBuilder:\n    def __init__(self):\n        self.adapter_segment_indices = []\n        self.adapter_segment_tensors = []\n\n    def concat(self, adapter_segments: torch.Tensor, segment_indices: List[int]):\n        # Update adapter segments\n        if self.adapter_segment_tensors:\n            # Because we have already processed at least one batch, remove the 0 start index\n            # from this batch denoting the beginning of the segment, then offset all segment\n            # positions by the value of the last segment in the previous batch to account for\n            # the concatenation.\n            adapter_segments = (\n                adapter_segments[1:] + self.adapter_segment_tensors[-1][-1]\n            )\n\n        if (\n            self.adapter_segment_indices\n            and self.adapter_segment_indices[-1] == segment_indices[0]\n        ):\n            # If the last segment in the previous batch is the same as the first segment in this batch,\n            # then we merge them together into a single segment. 
In effect, this means removing it from\n            # the segment indices of this batch, and extending the segment span by removing the segment\n            # end index from the previous batch.\n            segment_indices = segment_indices[1:]\n            self.adapter_segment_tensors[-1] = self.adapter_segment_tensors[-1][:-1]\n\n        self.adapter_segment_indices.extend(segment_indices)\n        self.adapter_segment_tensors.append(adapter_segments)\n\n    def build(self) -> Tuple[torch.Tensor, List[int]]:\n        return torch.concat(self.adapter_segment_tensors), self.adapter_segment_indices\n"
  },
  {
    "path": "server/text_generation_server/utils/speculate.py",
    "content": "SPECULATE = None\n\n\ndef get_speculate() -> int:\n    global SPECULATE\n    return SPECULATE\n\n\ndef set_speculate(speculate: int):\n    global SPECULATE\n    SPECULATE = speculate\n"
  },
  {
    "path": "server/text_generation_server/utils/tokens.py",
    "content": "import re\nfrom typing import List, Optional, Tuple, Set, Union\n\nimport torch\nfrom text_generation_server.pb import generate_pb2\nfrom text_generation_server.pb.generate_pb2 import FinishReason, GrammarType\nfrom text_generation_server.utils.logits_process import (\n    FrequencyPenaltyLogitsProcessor,\n    GrammarLogitProcessor,\n    HeterogeneousProcessorWrapper,\n    HeterogeneousRepetitionPenaltyLogitsProcessor,\n    HeterogeneousFrequencyPenaltyLogitsProcessor,\n    HeterogeneousTemperatureLogitsWarper,\n    HeterogeneousTopKLogitsWarper,\n    HeterogeneousTopPLogitsWarper,\n    HeterogeneousTypicalLogitsWarper,\n    HeterogeneousGrammarLogitProcessor,\n    static_warper,\n)\nfrom text_generation_server.utils.watermark import WatermarkLogitsProcessor\nfrom transformers import PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor\n\n\nclass NextTokenChooser:\n    def __init__(\n        self,\n        watermark: bool = False,\n        temperature: float = 1.0,\n        repetition_penalty: float = 1.0,\n        frequency_penalty: float = 0.0,\n        top_k: Optional[int] = None,\n        top_p: Optional[float] = None,\n        typical_p: Optional[float] = None,\n        do_sample: bool = False,\n        seed: int = 0,\n        device: str = \"cpu\",\n        tokenizer: Optional[PreTrainedTokenizerBase] = None,\n        grammar: str = \"\",\n        grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE,\n        fsm_grammar_state: int = 0,\n    ):\n        self.watermark_processor = (\n            WatermarkLogitsProcessor(device=device) if watermark else None\n        )\n        self.repetition_processor = (\n            RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)\n            if repetition_penalty and repetition_penalty != 1.0\n            else None\n        )\n        self.frequency_processor = (\n            FrequencyPenaltyLogitsProcessor(penalty=frequency_penalty)\n            if frequency_penalty and 
frequency_penalty != 0.0\n            else None\n        )\n        self.grammar_processor = (\n            GrammarLogitProcessor(tokenizer, device, grammar, grammar_type)\n            if grammar != \"\"\n            else None\n        )\n        self.tokenizer = tokenizer\n\n        has_warpers = (\n            (temperature is not None and temperature != 1.0)\n            or (top_k is not None and top_k != 0)\n            or (top_p is not None and top_p < 1.0)\n            or (typical_p is not None and typical_p < 1.0)\n        )\n        if has_warpers:\n            self.static_warper = static_warper(\n                temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p\n            )\n        else:\n            self.static_warper = None\n\n        sampling = do_sample or has_warpers\n\n        self.choice = Sampling(seed, device) if sampling else Greedy()\n        self.fsm_grammar_state = fsm_grammar_state\n        self.grammar = grammar\n\n    def __call__(self, input_ids, scores):\n        if self.watermark_processor is not None:\n            scores = self.watermark_processor(input_ids, scores)\n        if self.repetition_processor is not None:\n            scores = self.repetition_processor(input_ids, scores)\n        if self.frequency_processor is not None:\n            scores = self.frequency_processor(input_ids, scores)\n        if self.grammar_processor is not None:\n            scores = self.grammar_processor(scores, self.fsm_grammar_state)\n\n        if self.static_warper is None:\n            next_logprob = torch.log_softmax(scores, -1)\n        else:\n            scores, next_logprob = self.static_warper(scores)\n\n        next_id = self.choice(scores[-1]).view(1, 1)\n\n        return next_id, next_logprob\n\n    def advance_grammar(self, next_id: int):\n        if self.grammar_processor is not None:\n            self.fsm_grammar_state = self.grammar_processor.advance(\n                next_id, self.fsm_grammar_state\n            )\n 
       return self\n\n    @classmethod\n    def from_pb(\n        cls,\n        pb: generate_pb2.NextTokenChooserParameters,\n        device: torch.device,\n        tokenizer: PreTrainedTokenizerBase,\n    ) -> \"NextTokenChooser\":\n        return NextTokenChooser(\n            watermark=pb.watermark,\n            temperature=pb.temperature,\n            repetition_penalty=pb.repetition_penalty,\n            frequency_penalty=pb.frequency_penalty,\n            top_k=pb.top_k,\n            top_p=pb.top_p,\n            typical_p=pb.typical_p,\n            do_sample=pb.do_sample,\n            seed=pb.seed,\n            device=device,\n            tokenizer=tokenizer,\n            grammar=pb.grammar,\n            grammar_type=pb.grammar_type,\n        )\n\n\nclass StopSequenceCriteria:\n    def __init__(self, stop_sequence: str):\n        stop_sequence = re.escape(stop_sequence)\n        self.regex = re.compile(f\"{stop_sequence}$\")\n\n    def __call__(self, output: str) -> bool:\n        if self.regex.findall(output):\n            return True\n        return False\n\n\nclass StoppingCriteria:\n    def __init__(\n        self,\n        eos_token_ids: Optional[Union[Set[int], int]],\n        stop_sequence_criterias: List[StopSequenceCriteria],\n        max_new_tokens: int = 20,\n        ignore_eos_token: bool = False,\n    ):\n        if eos_token_ids is None:\n            eos_token_ids = set()\n        elif isinstance(eos_token_ids, int):\n            eos_token_ids = set([eos_token_ids])\n        elif isinstance(eos_token_ids, set):\n            eos_token_ids = eos_token_ids\n        else:\n            raise RuntimeError(\n                f\"eos_token_ids is of invalid type {type(eos_token_ids)}, expected int, None or set[int]\"\n            )\n        self.eos_token_ids = eos_token_ids\n        self.stop_sequence_criterias = stop_sequence_criterias\n        self.max_new_tokens = max_new_tokens\n        self.current_tokens = 0\n        self.current_output = \"\"\n    
    self.ignore_eos_token = ignore_eos_token\n\n    def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:\n        self.current_tokens += 1\n        if self.current_tokens >= self.max_new_tokens:\n            return True, FinishReason.FINISH_REASON_LENGTH\n\n        if isinstance(last_token, torch.Tensor):\n            last_token = last_token.item()\n\n        if not self.ignore_eos_token and last_token in self.eos_token_ids:\n            return True, FinishReason.FINISH_REASON_EOS_TOKEN\n\n        if self.stop_sequence_criterias:\n            self.current_output += last_output\n            # There is no need to keep an output that is too long\n            if len(self.current_output) > 300:\n                # Slice to -200 to avoid doing it all the time\n                self.current_output = self.current_output[-200:]\n            for stop_sequence_criteria in self.stop_sequence_criterias:\n                if stop_sequence_criteria(self.current_output):\n                    return True, FinishReason.FINISH_REASON_STOP_SEQUENCE\n\n        return False, None\n\n    @classmethod\n    def from_pb(\n        cls,\n        pb: generate_pb2.StoppingCriteriaParameters,\n        tokenizer: PreTrainedTokenizerBase,\n    ) -> \"StoppingCriteria\":\n        stop_sequence_criterias = [\n            StopSequenceCriteria(sequence) for sequence in pb.stop_sequences\n        ]\n        # TODO Hack because eos_token_id cannot be what we want.\n        eos_token_id = getattr(tokenizer, \"_eos_token_ids\", tokenizer.eos_token_id)\n        return StoppingCriteria(\n            eos_token_id,\n            stop_sequence_criterias,\n            pb.max_new_tokens,\n            pb.ignore_eos_token,\n        )\n\n\ndef create_n_gram_speculation(\n    input_ids: torch.Tensor,\n    next_ids: torch.Tensor,\n    accepted_ids: torch.Tensor,\n    speculate: int,\n    verbose: bool,\n):\n    # Very trivial approach, find first match in the string.\n    # This is much 
less refined than actual n-gram but seems to work\n    # relatively OK in grounded mode and is by far much faster with\n    # much less worst case complexity as everything happens on device.\n    B = accepted_ids.shape[0]\n    device = input_ids.device\n    seeds = next_ids[accepted_ids.cumsum(dim=-1) - 1]\n    indices = (input_ids == seeds.unsqueeze(-1)).max(dim=1).indices + 1\n    all_indices = indices.unsqueeze(-1).expand(B, speculate) + torch.arange(\n        speculate, device=device\n    )\n    all_indices = torch.clamp(all_indices, max=input_ids.shape[1] - 1)\n\n    speculative_ids = input_ids.gather(dim=-1, index=all_indices)\n    return speculative_ids\n\n\nclass HeterogeneousNextTokenChooser:\n    def __init__(\n        self,\n        dtype: torch.dtype,\n        device: torch.device,\n        watermark: List[bool],\n        temperature: List[float],\n        repetition_penalty: List[float],\n        frequency_penalty: List[float],\n        top_k: List[int],\n        top_p: List[float],\n        typical_p: List[float],\n        do_sample: List[bool],\n        seeds: List[int],\n        tokenizer: PreTrainedTokenizerBase,\n        grammars: List[str],\n        grammar_types: List[int],\n        fsm_grammar_states=List[int],\n    ):\n        warpers = []\n\n        self.watermark_processor = (\n            HeterogeneousProcessorWrapper(\n                {\n                    i: WatermarkLogitsProcessor(device=device)\n                    for i, do_watermark in enumerate(watermark)\n                    if do_watermark\n                }\n            )\n            if any(watermark)\n            else None\n        )\n\n        self.repetition_processor = (\n            HeterogeneousRepetitionPenaltyLogitsProcessor(\n                repetition_penalty, dtype, device\n            )\n            if any([x != 1.0 for x in repetition_penalty])\n            else None\n        )\n\n        self.frequency_processor = (\n            
HeterogeneousFrequencyPenaltyLogitsProcessor(\n                frequency_penalty, dtype, device\n            )\n            if any([x != 0.0 for x in frequency_penalty])\n            else None\n        )\n\n        self.grammar_processor = (\n            HeterogeneousGrammarLogitProcessor(\n                tokenizer, device, grammars, grammar_types\n            )\n            if any([grammar != \"\" for grammar in grammars])\n            else None\n        )\n\n        if any(x != 1.0 for x in temperature):\n            do_sample = [\n                sample or x != 1.0 for x, sample in zip(temperature, do_sample)\n            ]\n            warpers.append(\n                HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)\n            )\n\n        if any(x != 0 for x in top_k):\n            do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]\n            warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))\n\n        if any(x < 1.0 for x in top_p):\n            do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]\n            warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device))\n\n        if any(x < 1.0 for x in typical_p):\n            do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)]\n            warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device))\n\n        self.warpers = warpers\n\n        if any(do_sample):\n            self.choice = HeterogeneousSampling(do_sample, seeds, device)\n        else:\n            self.choice = Greedy()\n\n        self.seeds = seeds\n        self.do_sample = do_sample\n        self.dtype = dtype\n        self.device = device\n        self.tokenizer = tokenizer\n        self.fsm_grammar_states = fsm_grammar_states\n        self.grammars = grammars\n        self.grammar_types = grammar_types\n\n    def __call__(\n        self,\n        input_ids: torch.Tensor,\n        scores: torch.Tensor,\n        speculate: 
int,\n        speculated_ids: Optional[torch.Tensor] = None,\n        speculative_scores: Optional[torch.Tensor] = None,\n        verbose=False,\n    ):\n        if speculated_ids is not None:\n            B = scores.shape[0] // (speculated_ids.shape[1] + 1)\n            S = speculated_ids.shape[1] + 1\n            scores = scores.view(B, S, -1)\n        else:\n            B = scores.shape[0]\n            S = 1\n            scores = scores.view(B, S, -1)\n\n        next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)\n\n        for j in range(S):\n            _scores = scores[:, j]\n            if self.watermark_processor is not None:\n                _scores = self.watermark_processor(input_ids, _scores)\n            if self.repetition_processor is not None:\n                _scores = self.repetition_processor(input_ids, _scores)\n            if self.frequency_processor is not None:\n                _scores = self.frequency_processor(input_ids, _scores)\n            if self.grammar_processor is not None:\n                _scores = self.grammar_processor(_scores, self.fsm_grammar_states)\n            for warper in self.warpers:\n                _scores = warper(input_ids, _scores)\n            _next_ids = self.choice(_scores)\n            scores[:, j] = _scores\n            next_ids[:, j] = _next_ids\n        next_ids = next_ids.view(B * S)\n        allscores = scores.view(B * S, -1)\n        alllogprobs = torch.log_softmax(allscores, -1)\n\n        if speculated_ids is not None:\n            accepted_ids = []\n            B = next_ids.shape[0] // (speculated_ids.shape[1] + 1)\n            S = speculated_ids.shape[1] + 1\n            indices = []\n            for i in range(B):\n                _next_ids = next_ids[i * S : (i + 1) * S]\n                _speculated_ids = speculated_ids[i]\n                validate_speculative = _next_ids[:-1] == _speculated_ids\n                index = i * S\n                accepted = 1\n                # First is 
always valid\n                indices.append(index)\n                for valid in validate_speculative.tolist():\n                    if valid:\n                        index += 1\n                        accepted += 1\n                        indices.append(index)\n                    else:\n                        break\n                accepted_ids.append(accepted)\n\n            accepted_ids = torch.tensor(\n                accepted_ids, device=input_ids.device, dtype=input_ids.dtype\n            )\n            next_ids = next_ids[indices]\n            logprobs = alllogprobs[indices]\n            indices = torch.arange(B, device=input_ids.device) * S\n            if speculative_scores is not None:\n                speculative_scores = speculative_scores[indices + accepted_ids - 1]\n        else:\n            accepted_ids = torch.ones_like(next_ids)\n            logprobs = alllogprobs\n\n        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)\n\n        if speculate > 0:\n            if speculative_scores is not None:\n                # Medusa provided some scores\n                speculative_ids = Greedy()(speculative_scores)\n            else:\n                # n-gram\n                speculative_ids = create_n_gram_speculation(\n                    input_ids, next_ids, accepted_ids, speculate, verbose\n                )\n        else:\n            speculative_ids = None\n\n        return next_ids, next_logprobs, alllogprobs, accepted_ids, speculative_ids\n\n    def advance_grammar(self, next_ids: List[int]):\n        if self.grammar_processor is not None:\n            other_new_states = self.grammar_processor.advance_batch(\n                next_ids, self.fsm_grammar_states\n            )\n            self.fsm_grammar_states = other_new_states\n        return self\n\n    def advance_grammar_single(self, grammar_state_index: int, next_id: int):\n        if self.grammar_processor is not None:\n            
self.fsm_grammar_states[grammar_state_index] = (\n                self.grammar_processor.advance_at_index(\n                    next_id,\n                    self.fsm_grammar_states[grammar_state_index],\n                    grammar_state_index,\n                )\n            )\n        return self\n\n    def filter(self, indices):\n        if self.watermark_processor is not None:\n            self.watermark_processor = self.watermark_processor.filter(indices)\n\n        if self.repetition_processor is not None:\n            self.repetition_processor = self.repetition_processor.filter(indices)\n\n        if self.frequency_processor is not None:\n            self.frequency_processor = self.frequency_processor.filter(indices)\n\n        if self.grammar_processor is not None:\n            self.grammar_processor = self.grammar_processor.filter(indices)\n\n        filtered_warpers = []\n        for warper in self.warpers:\n            filtered_warper = warper.filter(indices)\n            if filtered_warper is not None:\n                filtered_warpers.append(filtered_warper)\n        self.warpers = filtered_warpers\n\n        self.seeds = [self.seeds[i] for i in indices]\n        self.do_sample = [self.do_sample[i] for i in indices]\n\n        new_grammars = []\n        new_fsm_grammar_states = []\n        new_grammar_types = []\n        for i in indices:\n            new_grammars.append(self.grammars[i])\n            new_fsm_grammar_states.append(self.fsm_grammar_states[i])\n            new_grammar_types.append(self.grammar_types[i])\n\n        self.grammars = new_grammars\n        self.fsm_grammar_states = new_fsm_grammar_states\n        self.grammar_types = new_grammar_types\n\n        if any(self.do_sample):\n            self.choice.filter(indices)\n        else:\n            self.choice = Greedy()\n\n        return self\n\n    @classmethod\n    def from_pb(\n        cls,\n        pb: List[generate_pb2.NextTokenChooserParameters],\n        dtype: torch.dtype,\n    
    device: torch.device,\n        tokenizer: PreTrainedTokenizerBase,\n        fsm_grammar_states: Optional[List[int]] = None,\n    ) -> \"HeterogeneousNextTokenChooser\":\n        return HeterogeneousNextTokenChooser(\n            watermark=[pb_.watermark for pb_ in pb],\n            temperature=[pb_.temperature for pb_ in pb],\n            repetition_penalty=[pb_.repetition_penalty for pb_ in pb],\n            frequency_penalty=[pb_.frequency_penalty for pb_ in pb],\n            top_k=[pb_.top_k for pb_ in pb],\n            top_p=[pb_.top_p for pb_ in pb],\n            typical_p=[pb_.typical_p for pb_ in pb],\n            do_sample=[pb_.do_sample for pb_ in pb],\n            seeds=[pb_.seed for pb_ in pb],\n            device=device,\n            dtype=dtype,\n            tokenizer=tokenizer,\n            grammars=[pb_.grammar for pb_ in pb],\n            grammar_types=[pb_.grammar_type for pb_ in pb],\n            fsm_grammar_states=(\n                fsm_grammar_states if fsm_grammar_states else [0] * len(pb)\n            ),\n        )\n\n\nclass Sampling:\n    def __init__(self, seed: int, device: str = \"cpu\"):\n        self.generator = torch.Generator(device)\n        self.generator.manual_seed(seed)\n        self.seed = seed\n\n    def __call__(self, logits):\n        probs = torch.nn.functional.softmax(logits, -1)\n        # Avoid GPU<->CPU sync done by torch multinomial\n        # See: https://github.com/pytorch/pytorch/blob/925a3788ec5c06db62ca732a0e9425a26a00916f/aten/src/ATen/native/Distributions.cpp#L631-L637\n        q = torch.empty_like(probs).exponential_(1, generator=self.generator)\n        return probs.div_(q).argmax()\n\n\nclass Greedy:\n    def __call__(self, logits):\n        return logits.argmax(dim=-1)\n\n\nclass HeterogeneousSampling:\n    r\"\"\"\n    Mixed greedy and probabilistic sampling. 
Compute both and pick the right one for each sample.\n    \"\"\"\n\n    def __init__(self, do_sample: List[bool], seeds: List[int], device: torch.device):\n        self.seeds = seeds\n\n        self.greedy_indices = []\n        self.sampling_mapping = {}\n        for i, (sample, seed) in enumerate(zip(do_sample, seeds)):\n            if sample:\n                self.sampling_mapping[i] = Sampling(seed, device)\n            else:\n                self.greedy_indices.append(i)\n\n        self.greedy = Greedy()\n\n    def __call__(self, logits):\n        out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device)\n        if self.greedy_indices:\n            # Computing for all indices is faster than slicing\n            torch.argmax(logits, -1, out=out)\n\n        for i, sampling in self.sampling_mapping.items():\n            out[i] = sampling(logits[i])\n        return out\n\n    def filter(self, indices):\n        new_greedy_indices = []\n        new_sampling_mapping = {}\n        for i, idx in enumerate(indices):\n            if idx in self.sampling_mapping:\n                new_sampling_mapping[i] = self.sampling_mapping[idx]\n            else:\n                new_greedy_indices.append(i)\n\n        self.greedy_indices = new_greedy_indices\n        self.sampling_mapping = new_sampling_mapping\n        return self\n\n\ndef batch_top_tokens(\n    top_n_tokens: List[int],\n    top_n_tokens_tensor: torch.Tensor,\n    logprobs: torch.Tensor,\n    accepted_ids: torch.Tensor,\n) -> Tuple[List[List[List[int]]], List[List[List[float]]]]:\n    \"\"\"Find the top n most likely tokens for a batch of generations.\n\n    When multiple tokens have equal probabilities and they don't all fit, the\n    remaining tokens are also returned.\n    \"\"\"\n    max_top_n = max(top_n_tokens)\n    # Early exit when top_n_tokens is not used\n    if max_top_n == 0:\n        return [[[]]] * len(top_n_tokens), [[[]]] * len(top_n_tokens)\n\n    batch_size = 
accepted_ids.shape[0]\n    speculate_size = logprobs.shape[0] // batch_size\n    top_n_tokens_tensor = top_n_tokens_tensor.repeat_interleave(speculate_size)\n    # Ensure top_n doesn't exceed vocab size\n    top_n_tokens = [\n        min(tok, logprobs.size(-1))\n        for tok in top_n_tokens\n        for _ in range(speculate_size)\n    ]\n\n    # Parallel kthvalue adapted from https://discuss.pytorch.org/t/how-to-efficiently-get-the-k-th-largest-values-in-parallel/160529/2\n    # Sorted topk is faster than torch.sort() since we only need a small subset\n    sorted_top_k = torch.topk(logprobs, k=max_top_n, dim=-1, sorted=True).values\n\n    nth_highest = torch.gather(\n        sorted_top_k, 1, (top_n_tokens_tensor - 1).clip(min=0).unsqueeze(1)\n    )\n    nth_highest[nth_highest == -float(\"inf\")] = torch.finfo(logprobs.dtype).min\n\n    # Find the new \"fuzzy\" top n values\n    top_n_indices = (logprobs >= nth_highest).nonzero()\n    _, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)\n\n    k = 1 if top_n_ishes.numel() == 0 else top_n_ishes.max()\n    # Take a new topk for these new max n values\n    top_k = torch.topk(logprobs, k=k, dim=1, sorted=True)\n\n    top_n_ishes = top_n_ishes.tolist()\n    top_indices = top_k.indices.tolist()\n    top_values = top_k.values.tolist()\n\n    batch_top_token_ids = []\n    batch_top_token_logprobs = []\n    accepted_ids_list = accepted_ids.tolist()\n    for i, n_accepted_ids in enumerate(accepted_ids_list):\n        start = speculate_size * i\n        stop = speculate_size * (i + 1)\n        _top_indices = top_indices[start:stop]\n        _top_values = top_values[start:stop]\n        _top_n_ishes = top_n_ishes[start:stop]\n        _top_n_tokens = top_n_tokens[start:stop]\n\n        _top_indices = _top_indices[:n_accepted_ids]\n        _top_values = _top_values[:n_accepted_ids]\n        _top_n_ishes = _top_n_ishes[:n_accepted_ids]\n        _top_n_tokens = _top_n_tokens[:n_accepted_ids]\n\n    
    row_top_token_ids = []\n        row_top_token_logprobs = []\n\n        for idxs, vals, n, req_n in zip(\n            _top_indices, _top_values, _top_n_ishes, _top_n_tokens\n        ):\n            indices = idxs[:n] if req_n > 0 else []\n            values = vals[:n] if req_n > 0 else []\n\n            row_top_token_ids.append(indices)\n            row_top_token_logprobs.append(values)\n\n        batch_top_token_ids.append(row_top_token_ids)\n        batch_top_token_logprobs.append(row_top_token_logprobs)\n\n    return batch_top_token_ids, batch_top_token_logprobs\n"
  },
  {
    "path": "server/text_generation_server/utils/watermark.py",
    "content": "# coding=utf-8\n# Copyright 2023 Authors of \"A Watermark for Large Language Models\"\n# available at https://arxiv.org/abs/2301.10226\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport os\n\nimport torch\nfrom transformers import LogitsProcessor\nfrom typing import List, Union\n\nGAMMA = float(os.getenv(\"WATERMARK_GAMMA\", 0.5))\nDELTA = float(os.getenv(\"WATERMARK_DELTA\", 2.0))\n\n\nclass WatermarkLogitsProcessor(LogitsProcessor):\n    def __init__(\n        self,\n        gamma: float = GAMMA,\n        delta: float = DELTA,\n        hash_key: int = 15485863,  # just a large prime number to create a rng seed with sufficient bit width\n        device: str = \"cpu\",\n    ):\n        # watermarking parameters\n        self.gamma = gamma\n        self.delta = delta\n        self.rng = torch.Generator(device=device)\n        self.hash_key = hash_key\n\n    def _seed_rng(self, input_ids: Union[List[int], torch.LongTensor]):\n        if isinstance(input_ids, list):\n            assert (\n                len(input_ids) >= 1\n            ), \"requires at least a 1 token prefix sequence to seed rng\"\n            prev_token = input_ids[-1]\n        else:\n            assert len(input_ids) == 1\n            input_ids = input_ids[0]\n            assert (\n                input_ids.shape[-1] >= 1\n            ), \"requires at least a 1 token prefix sequence to seed rng\"\n            prev_token = input_ids[-1].item()\n        
self.rng.manual_seed(self.hash_key * prev_token)\n\n    def _get_greenlist_ids(\n        self,\n        input_ids: Union[List[int], torch.LongTensor],\n        max_value: int,\n        device: torch.device,\n    ) -> List[int]:\n        # seed the rng using the previous tokens/prefix\n        self._seed_rng(input_ids)\n\n        greenlist_size = int(max_value * self.gamma)\n        vocab_permutation = torch.randperm(max_value, device=device, generator=self.rng)\n        greenlist_ids = vocab_permutation[:greenlist_size]\n        return greenlist_ids\n\n    @staticmethod\n    def _calc_greenlist_mask(\n        scores: torch.FloatTensor, greenlist_token_ids\n    ) -> torch.BoolTensor:\n        green_tokens_mask = torch.zeros_like(scores)\n        green_tokens_mask[-1, greenlist_token_ids] = 1\n        final_mask = green_tokens_mask.bool()\n        return final_mask\n\n    @staticmethod\n    def _bias_greenlist_logits(\n        scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float\n    ) -> torch.Tensor:\n        scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias\n        return scores\n\n    def __call__(\n        self, input_ids: Union[List[int], torch.LongTensor], scores: torch.FloatTensor\n    ) -> torch.FloatTensor:\n        greenlist_ids = self._get_greenlist_ids(\n            input_ids, scores.shape[-1], scores.device\n        )\n        green_tokens_mask = self._calc_greenlist_mask(\n            scores=scores, greenlist_token_ids=greenlist_ids\n        )\n\n        scores = self._bias_greenlist_logits(\n            scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta\n        )\n        return scores\n"
  },
  {
    "path": "server/text_generation_server/utils/weights.py",
    "content": "import torch\n\nfrom abc import ABC, abstractmethod\nfrom contextlib import contextmanager\nfrom pathlib import Path\nfrom typing import Dict, List, Optional, Union, Type\nfrom safetensors import safe_open\nfrom dataclasses import dataclass\n\nfrom text_generation_server.utils.import_utils import SYSTEM\n\n\nclass WeightsLoader(ABC):\n    \"\"\"\n    Instances of this type implement higher-level weight loading.\n\n    At a low-level, every weight is stored in the Safetensors format.\n    The interpretation of weights may be different however, for instance\n    could be packed, quantized weights. Loaders are responsible for\n    interpreting the raw tensors, sharding tensors in a manner compatible\n    with the format, etc.\n    \"\"\"\n\n    @abstractmethod\n    def get_weights(self, weights: \"Weights\", prefix: str):\n        \"\"\"\n        Get weights at the given prefix and apply without tensor paralllism.\n        \"\"\"\n        ...\n\n    @abstractmethod\n    def get_weights_col_packed(\n        self,\n        weights: \"Weights\",\n        prefix: str,\n        block_sizes: Union[int, List[int]],\n    ):\n        \"\"\"\n        Get the packed weights at the given prefix with column-splitting for\n        tensor parallelism. This method should be used when multiple different\n        weights are packed into a tensor, for instance, query/key/value\n        weights or a gate/up projection.\n\n        The `block_sizes` determines the proportions of the packed tensors.\n        The columns are split in equally sized blocks when `block_sizes` is an\n        `int`, or in blocks proportional given to the sizes. 
For instance\n        `[2, 1, 1]` will divide an input with dimensionality `1024` in\n        `[512, 256, 256]`.\n        \"\"\"\n        ...\n\n    def get_weights_col(self, weights: \"Weights\", prefix: str):\n        \"\"\"\n        Get weights at the given prefix and apply column-splitting for tensor\n        paralllism.\n        \"\"\"\n        return weights.get_multi_weights_col([prefix], 0)\n\n    @abstractmethod\n    def get_multi_weights_col(self, weights: \"Weights\", prefixes: List[str], dim: int):\n        \"\"\"\n        Get the weights at the given prefixes, column-split them for tensor\n        parallelim, and then concatenate the weights along the given dimension.\n        \"\"\"\n        ...\n\n    @abstractmethod\n    def get_weights_row(self, weights: \"Weights\", prefix: str):\n        \"\"\"\n        Get the weights at the given prefix and apply row-splitting for tensor\n        parallism.\n        \"\"\"\n        ...\n\n\nclass Weight(ABC):\n    \"\"\"Instances of this type implement unquantized/quantized/to-be\n    quantized weights.\"\"\"\n\n    @abstractmethod\n    def get_linear(self, bias: torch.Tensor):\n        \"\"\"Create a linear layer from this weight.\"\"\"\n        ...\n\n\n@dataclass\nclass UnquantizedWeight(Weight):\n    weight: torch.Tensor\n\n    def get_linear(self, bias: torch.Tensor):\n        from text_generation_server.layers.linear import FastLinear, FastLinearROCm\n\n        if SYSTEM == \"rocm\":\n            return FastLinearROCm(self.weight, bias)\n        else:\n            return FastLinear(self.weight, bias)\n\n\nclass DefaultWeightsLoader(WeightsLoader):\n    \"\"\"Weight loader that loads (unquantized) Torch tensors.\"\"\"\n\n    def __init__(self, weight_class: Type[UnquantizedWeight]):\n        \"\"\"Create a loader. 
Weights will be wrapped using the given `weights_class`,\n        normally this will be `UnquantizedWeight`, but a quantizer-specific class\n        such as `Fp8Weight` can be used to quantize the weights during loading.\n        \"\"\"\n        self.weight_class = weight_class\n\n    \"\"\"\n    Loader that uses tensors as-is with the exception of applying sharding\n    and/or concatenation.\n    \"\"\"\n\n    def get_weights(self, weights: \"Weights\", prefix: str):\n        return weights.get_tensor(f\"{prefix}.weight\")\n\n    def get_weights_col_packed(\n        self,\n        weights: \"Weights\",\n        prefix: str,\n        block_sizes: Union[int, List[int]],\n    ):\n        return self.weight_class(\n            weights.get_packed_sharded(\n                f\"{prefix}.weight\", dim=0, block_sizes=block_sizes\n            ),\n        )\n\n    def get_multi_weights_col(self, weights: \"Weights\", prefixes: List[str], dim: int):\n        w = [weights.get_sharded(f\"{p}.weight\", dim=0) for p in prefixes]\n        return self.weight_class(torch.cat(w, dim=dim))\n\n    def get_weights_row(self, weights: \"Weights\", prefix: str):\n        return self.weight_class(\n            weights.get_sharded(f\"{prefix}.weight\", dim=1),\n        )\n\n\nclass Weights:\n    def __init__(\n        self,\n        filenames: List[Path],\n        device,\n        dtype,\n        process_group,\n        weights_loader: WeightsLoader,\n        aliases: Optional[Dict[str, List[str]]] = None,\n        prefix: Optional[str] = None,\n    ):\n        routing = {}\n        for filename in filenames:\n            with safe_open(filename, framework=\"pytorch\") as f:\n                for k in f.keys():\n                    if k in routing:\n                        raise RuntimeError(\n                            f\"Key {k} was found in multiple files: {filename} and {routing[k]}\"\n                        )\n                    routing[k] = filename\n        if aliases is None:\n      
      aliases = {}\n        self.aliases = aliases\n        self.routing = routing\n        self.device = device\n        self.dtype = dtype\n        self.process_group = process_group\n        self.prefix = prefix\n        self.weights_loader = weights_loader\n        self._handles = {}\n\n    def _get_handle(self, filename):\n        if filename not in self._handles:\n            f = safe_open(filename, framework=\"pytorch\")\n            self._handles[filename] = f\n\n        return self._handles[filename]\n\n    def get_filename(self, tensor_name: str) -> (str, str):\n        names = [tensor_name]\n        if self.prefix is not None:\n            prefixed = f\"{self.prefix}.{tensor_name}\"\n            names.append(prefixed)\n        for name in names:\n            filename = self.routing.get(name, None)\n            if filename is not None:\n                return str(filename), name\n\n            aliases = self.aliases.get(name, [])\n            for alias in aliases:\n                filename = self.routing.get(alias, None)\n                if filename is not None:\n                    return str(filename), alias\n        raise RuntimeError(f\"weight {tensor_name} does not exist\")\n\n    def _get_slice(self, tensor_name: str):\n        filename, tensor_name = self.get_filename(tensor_name)\n        f = self._get_handle(filename)\n        slice_ = f.get_slice(tensor_name)\n        return slice_\n\n    def has_tensor(self, tensor_name: str):\n        try:\n            self.get_filename(tensor_name)\n        except Exception:\n            return False\n        return True\n\n    def get_shape(self, tensor_name: str):\n        return self._get_slice(tensor_name).get_shape()\n\n    def get_tensor(\n        self, tensor_name: str, to_device: bool = True, to_dtype: bool = True\n    ) -> torch.Tensor:\n        filename, tensor_name = self.get_filename(tensor_name)\n        f = self._get_handle(filename)\n        tensor = f.get_tensor(tensor_name)\n        # Special 
case for gptq which shouldn't convert\n        # u4 which are disguised as int32. Exl2 uses int16\n        # as well. FP8 uses torch.float8_e4m3fn\n        if (\n            tensor.dtype\n            not in [\n                torch.float8_e4m3fn,\n                torch.int8,\n                torch.int16,\n                torch.int32,\n                torch.int64,\n            ]\n            and to_dtype\n        ):\n            tensor = tensor.to(dtype=self.dtype)\n        if to_device:\n            tensor = tensor.to(device=self.device)\n        return tensor\n\n    def get_partial_sharded(\n        self, tensor_name: str, dim: int, to_device=True, to_dtype=True\n    ):\n        filename, tensor_name = self.get_filename(tensor_name)\n        f = self._get_handle(filename)\n        slice_ = f.get_slice(tensor_name)\n        world_size = self.process_group.size()\n        rank = self.process_group.rank()\n\n        size = slice_.get_shape()[dim]\n        block_size = (size + world_size - 1) // world_size\n        start = rank * block_size\n        stop = (rank + 1) * block_size\n\n        if dim == 0:\n            tensor = slice_[start:stop]\n        elif dim == 1:\n            tensor = slice_[:, start:stop]\n        else:\n            raise NotImplementedError(\"Let's make that generic when needed\")\n        # Special case for gptq which shouldn't convert\n        # u4 which are disguised as int32. 
exl2 uses int16.\n        # FP8 uses torch.float8_e4m3fn.\n        if (\n            tensor.dtype\n            not in (torch.float8_e4m3fn, torch.int8, torch.int16, torch.int32)\n            and to_dtype\n        ):\n            tensor = tensor.to(dtype=self.dtype)\n        if to_device:\n            tensor = tensor.to(device=self.device)\n        return tensor\n\n    def get_sharded(self, tensor_name: str, dim: int, to_device=True, to_dtype=True):\n        filename, tensor_name = self.get_filename(tensor_name)\n        f = self._get_handle(filename)\n        slice_ = f.get_slice(tensor_name)\n        world_size = self.process_group.size()\n        size = slice_.get_shape()[dim]\n        assert (\n            size % world_size == 0\n        ), f\"The choosen size {size} is not compatible with sharding on {world_size} shards\"\n        return self.get_partial_sharded(\n            tensor_name, dim, to_device=to_device, to_dtype=to_dtype\n        )\n\n    def get_packed_sharded(\n        self,\n        tensor_name: str,\n        dim: int,\n        block_sizes: Union[int, List[int]],\n        to_dtype=True,\n    ) -> torch.Tensor:\n        \"\"\"\n        Get a shard from a tensor that packs multiple tensors.\n\n        When a tensor packs multiple tensors (such as QKV or an up\n        projection + gate projection), sharding with `get_sharded` is not\n        safe since it would not split the packed tensors across shards.\n\n        This method shards a tensor, such that the packed tensors are\n        split across shards.\n\n        The columns are split in equally sized blocks when blocks is an `int`, or\n        in blocks proportional given to the sizes. For instance `[2, 1, 1]` will\n        divide an input with dimensionality `1024` in `[512, 256, 256]`. This is\n        convenient for e.g. 
splitting QKV without knowing the storage details of\n        quantized weights.\n        \"\"\"\n        slice_ = self._get_slice(tensor_name)\n        total_size = slice_.get_shape()[dim]\n        block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=block_sizes)\n\n        world_size = self.process_group.size()\n        rank = self.process_group.rank()\n\n        tensors = []\n        block_offset = 0\n        for block_size in block_sizes:\n            assert (\n                block_size % world_size == 0\n            ), f\"Prepacked tensor cannot be sharded across {world_size} shards\"\n            shard_block_size = block_size // world_size\n            start = rank * shard_block_size\n            stop = (rank + 1) * shard_block_size\n            if dim == 0:\n                tensor = slice_[block_offset + start : block_offset + stop]\n            elif dim == 1:\n                tensor = slice_[:, block_offset + start : block_offset + stop]\n            else:\n                raise NotImplementedError(\"Currently only dim=0 or dim=1 is supported\")\n            tensors.append(tensor)\n            block_offset += block_size\n        tensor = torch.cat(tensors, dim=dim)\n        tensor = tensor.to(device=self.device)\n\n        # Avoid casting quantizer dtypes.\n        if (\n            tensor.dtype\n            not in [\n                torch.float8_e4m3fn,\n                torch.int8,\n                torch.int16,\n                torch.int32,\n                torch.int64,\n            ]\n            and to_dtype\n        ):\n            tensor = tensor.to(dtype=self.dtype)\n\n        return tensor\n\n    def get_weights(self, prefix: str):\n        return self.weights_loader.get_weights(self, prefix)\n\n    def get_weights_col_packed_qkv(\n        self,\n        prefix: str,\n        num_heads: int,\n        num_key_value_heads: int,\n    ):\n        return self.get_weights_col_packed(\n            prefix, [num_heads, num_key_value_heads, 
num_key_value_heads]\n        )\n\n    def get_weights_col_packed_gate_up(self, prefix: str):\n        return self.get_weights_col_packed(prefix, 2)\n\n    def get_weights_col_packed(self, prefix: str, block_sizes: Union[int, List[int]]):\n        \"\"\"\n        The columns are split in equally sized blocks when blocks is an `int`, or\n        in blocks proportional given to the sizes. For instance `[2, 1, 1]` will\n        divide an input with dimensionality `1024` in `[512, 256, 256]`. This is\n        convenient for e.g. splitting QKV without knowing the storage details of\n        quantized weights.\n        \"\"\"\n        return self.weights_loader.get_weights_col_packed(self, prefix, block_sizes)\n\n    def get_weights_col(self, prefix: str):\n        return self.weights_loader.get_weights_col(self, prefix)\n\n    def get_multi_weights_col(self, prefixes: List[str], dim: int):\n        return self.weights_loader.get_multi_weights_col(self, prefixes, dim)\n\n    def get_tensor_shard(self, var, dim):\n        world_size = self.process_group.size()\n        rank = self.process_group.rank()\n        block_size = var.size()[dim] // world_size\n        start = rank * block_size\n        stop = (rank + 1) * block_size\n        if dim == 0:\n            tensor = var[start:stop]\n        elif dim == 1:\n            tensor = var[:, start:stop]\n        else:\n            raise NotImplementedError(\"Let's make that generic when needed\")\n        tensor = tensor.to(dtype=self.dtype)\n        tensor = tensor.to(device=self.device)\n        return tensor\n\n    def get_weights_row(self, prefix: str):\n        return self.weights_loader.get_weights_row(self, prefix)\n\n    @contextmanager\n    def use_loader(self, weights_loader: WeightsLoader):\n        \"\"\"\n        This method is a context manager that can be used to use `Weights` with\n        a different loader for the duration of the context.\n        \"\"\"\n\n        old_loader = self.weights_loader\n        
self.weights_loader = weights_loader\n        try:\n            yield\n        finally:\n            self.weights_loader = old_loader\n\n    @property\n    def loader(self):\n        return self.weights_loader\n\n\ndef _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:\n    \"\"\"\n    Convert block count or proportions to block sizes.\n\n    This function accepts\n\n    - The number of blocks (int), in which case the block size is\n      total_size//blocks; or\n    - A list of block sizes (List[int]).\n\n    In the latter case, if sum(blocks) < total_size, the ratios between\n    the block sizes will be preserved. For instance, if blocks is\n    [2, 1, 1] and total_size is 1024, the returned block sizes are\n    [512, 256, 256].\n    \"\"\"\n    if isinstance(blocks, list):\n        total_blocks = sum(blocks)\n        assert (\n            total_size % total_blocks == 0\n        ), f\"Cannot split {total_size} in proportional blocks: {blocks}\"\n        part_size = total_size // total_blocks\n        return [part_size * block for block in blocks]\n    else:\n        assert total_size % blocks == 0, f\"Prepacked is not divisible by {blocks}\"\n        single_size = total_size // blocks\n        return [single_size] * blocks\n"
  },
  {
    "path": "tgi-entrypoint.sh",
    "content": "#!/bin/bash\n\nldconfig 2>/dev/null || echo 'unable to refresh ld cache, not a big deal in most cases'\n\nsource /usr/src/.venv/bin/activate\nexec text-generation-launcher $@\n"
  },
  {
    "path": "update_doc.py",
    "content": "import subprocess\nimport argparse\nimport ast\nimport json\nimport os\n\nTEMPLATE = \"\"\"\n# Supported Models\n\nText Generation Inference enables serving optimized models. The following sections list which models (VLMs & LLMs) are supported.\n\nSUPPORTED_MODELS\n\n\nIf the above list lacks the model you would like to serve, depending on the model's pipeline type, you can try to initialize and serve the model anyways to see how well it performs, but performance isn't guaranteed for non-optimized models:\n\n```python\n# for causal LMs/text-generation models\nAutoModelForCausalLM.from_pretrained(<model>, device_map=\"auto\")\n# or, for text-to-text generation models\nAutoModelForSeq2SeqLM.from_pretrained(<model>, device_map=\"auto\")\n```\n\nIf you wish to serve a supported model that already exists on a local folder, just point to the local folder.\n\n```bash\ntext-generation-launcher --model-id <PATH-TO-LOCAL-BLOOM>\n```\n\"\"\"\n\n\ndef check_cli(check: bool):\n    output = subprocess.check_output([\"text-generation-launcher\", \"--help\"]).decode(\n        \"utf-8\"\n    )\n\n    wrap_code_blocks_flag = \"<!-- WRAP CODE BLOCKS -->\"\n    final_doc = f\"# Text-generation-launcher arguments\\n\\n{wrap_code_blocks_flag}\\n\\n\"\n\n    lines = output.split(\"\\n\")\n\n    header = \"\"\n    block = []\n    for line in lines:\n        if line.startswith(\"  -\") or line.startswith(\"      -\"):\n            rendered_block = \"\\n\".join(block)\n            if header:\n                final_doc += f\"## {header}\\n```shell\\n{rendered_block}\\n```\\n\"\n            else:\n                final_doc += f\"```shell\\n{rendered_block}\\n```\\n\"\n            block = []\n            tokens = line.split(\"<\")\n            if len(tokens) > 1:\n                header = tokens[-1][:-1]\n            else:\n                header = line.split(\"--\")[-1]\n            header = header.upper().replace(\"-\", \"_\")\n\n        block.append(line)\n\n    
rendered_block = \"\\n\".join(block)\n    final_doc += f\"## {header}\\n```shell\\n{rendered_block}\\n```\\n\"\n    block = []\n\n    filename = \"docs/source/reference/launcher.md\"\n    if check:\n        with open(filename, \"r\") as f:\n            doc = f.read()\n            if doc != final_doc:\n                tmp = \"launcher.md\"\n                with open(tmp, \"w\") as g:\n                    g.write(final_doc)\n                diff = subprocess.run(\n                    [\"diff\", tmp, filename], capture_output=True\n                ).stdout.decode(\"utf-8\")\n                print(diff)\n                raise Exception(\n                    \"Cli arguments Doc is not up-to-date, run `python update_doc.py` in order to update it\"\n                )\n    else:\n        with open(filename, \"w\") as f:\n            f.write(final_doc)\n\n\ndef check_supported_models(check: bool):\n    filename = \"server/text_generation_server/models/__init__.py\"\n    with open(filename, \"r\") as f:\n        tree = ast.parse(f.read())\n\n    enum_def = [\n        x for x in tree.body if isinstance(x, ast.ClassDef) and x.name == \"ModelType\"\n    ][0]\n    _locals = {}\n    _globals = {}\n    exec(f\"import enum\\n{ast.unparse(enum_def)}\", _globals, _locals)\n    ModelType = _locals[\"ModelType\"]\n    list_string = \"\"\n    for data in ModelType:\n        list_string += f\"- [{data.value['name']}]({data.value['url']})\"\n        if data.value.get(\"multimodal\", None):\n            list_string += \" (Multimodal)\"\n        list_string += \"\\n\"\n\n    final_doc = TEMPLATE.replace(\"SUPPORTED_MODELS\", list_string)\n\n    filename = \"docs/source/supported_models.md\"\n    if check:\n        with open(filename, \"r\") as f:\n            doc = f.read()\n            if doc != final_doc:\n                tmp = \"supported.md\"\n                with open(tmp, \"w\") as g:\n                    g.write(final_doc)\n                diff = subprocess.run(\n                    
[\"diff\", tmp, filename], capture_output=True\n                ).stdout.decode(\"utf-8\")\n                print(diff)\n                raise Exception(\n                    \"Supported models is not up-to-date, run `python update_doc.py` in order to update it\"\n                )\n    else:\n        with open(filename, \"w\") as f:\n            f.write(final_doc)\n\n\ndef get_openapi_schema():\n    try:\n        output = subprocess.check_output([\"text-generation-router\", \"print-schema\"])\n        return json.loads(output)\n    except subprocess.CalledProcessError as e:\n        print(f\"Error running text-generation-router print-schema: {e}\")\n        raise SystemExit(1)\n    except json.JSONDecodeError:\n        print(\"Error: Invalid JSON received from text-generation-router print-schema\")\n        raise SystemExit(1)\n\n\ndef check_openapi(check: bool):\n    new_openapi_data = get_openapi_schema()\n    filename = \"docs/openapi.json\"\n    tmp_filename = \"openapi_tmp.json\"\n\n    with open(tmp_filename, \"w\") as f:\n        json.dump(new_openapi_data, f, indent=2)\n        f.write(\"\\n\")\n\n    if check:\n        diff = subprocess.run(\n            [\n                \"diff\",\n                tmp_filename,\n                filename,\n            ],\n            capture_output=True,\n        ).stdout.decode(\"utf-8\")\n        os.remove(tmp_filename)\n\n        if diff:\n            print(diff)\n            raise Exception(\n                \"OpenAPI documentation is not up-to-date, run `python update_doc.py` in order to update it\"\n            )\n\n    else:\n        os.rename(tmp_filename, filename)\n        print(\"OpenAPI documentation updated.\")\n    p = subprocess.run(\n        [\n            \"redocly\",\n            # allow for trailing whitespace since it's not significant\n            # and the precommit hook will remove it\n            \"lint\",\n            \"--skip-rule\",\n            \"security-defined\",\n            filename,\n    
    ],\n        capture_output=True,\n    )\n    errors = p.stderr.decode(\"utf-8\")\n    # The openapi specs fails on `exclusive_minimum` which is expected to be a boolean where\n    # utoipa outputs a value instead: https://github.com/juhaku/utoipa/issues/969\n    print(errors)\n    if p.returncode != 0:\n        print(errors)\n        raise Exception(\n            f\"OpenAPI documentation is invalid, `redocly lint {filename}` showed some error:\\n {errors}\"\n        )\n    return True\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--check\", action=\"store_true\")\n\n    args = parser.parse_args()\n\n    check_cli(args.check)\n    check_supported_models(args.check)\n    check_openapi(args.check)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  }
]