[
  {
    "path": ".github/CODEOWNERS",
    "content": "* @dacorvo @sunmarc\n"
  },
  {
    "path": ".github/PULL_REQUEST_TEMPLATE.md",
    "content": "# What does this PR do?\n\n<!--\nCongratulations! You've made it this far! You're not quite done yet though.\n\nOnce merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution.\n\nThen, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change.\n\nOnce you're done, someone will review your PR shortly (see the section \"Who can review?\" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost.\n-->\n\n<!-- Remove if not applicable -->\n\nFixes # (issue)\n\n\n## Before submitting\n- [ ] Did you read the [contributor guideline](https://github.com/huggingface/optimum-quanto/blob/main/CONTRIBUTING.md#create-a-pull-request),\n      Pull Request section?\n- [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link\n      to it if that's the case.\n- [ ] Did you run all tests locally and make sure they pass?\n- [ ] Did you write any new necessary tests?\n\n\n## Who can review?\n\nAnyone in the community is free to review the PR once the tests have passed. Feel free to tag\nmembers/contributors who may be interested in your PR.\n"
  },
  {
    "path": ".github/workflows/check-commits.yml",
    "content": "name: Check Commits\n\non: [workflow_call]\n\njobs:\n  build:\n    name: Check commits\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v3\n\n      - uses: huggingface/action-check-commits@v1.0.0\n        with:\n          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n          max-commits: \"10\"\n          min-words: \"3\"\n          forbidden-words: \"fixup\"\n"
  },
  {
    "path": ".github/workflows/linux-cpu-tests.yml",
    "content": "name: Linux CPU tests\n\non:\n  push:\n    branches:\n      - main\n    paths:\n      - \"optimum/quanto/**\"\n      - \"tests/**\"\n      - \"pyproject.toml\"\n  pull_request:\n    types: [assigned, opened, synchronize, reopened]\n    paths:\n      - \"optimum/quanto/**\"\n      - \"tests/**\"\n      - \"pyproject.toml\"\n\njobs:\n  check-commits:\n    uses: ./.github/workflows/check-commits.yml\n  python-quality:\n    uses: ./.github/workflows/python-quality.yml\n  test-ubuntu-cpu:\n    needs: [check-commits, python-quality]\n    runs-on: ubuntu-latest\n    strategy:\n      fail-fast: false\n      matrix:\n        python-version: [\"3.9\", \"3.11\"]\n\n    steps:\n      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd  # v6.0.2\n      - name: Set up Python ${{ matrix.python-version }}\n        uses: actions/setup-python@e9aba2c848f5ebd159c070c61ea2c4e2b122355e  # v2\n        with:\n          python-version: ${{ matrix.python-version }}\n\n      - name: Build and install quanto\n        run: |\n          pip install --upgrade pip\n          pip install -e .[dev]\n\n      - name: Run base tests\n        run: |\n          python -m pytest tests --ignore=tests/models --ignore=tests/cli\n\n      - name: Run models tests\n        run: |\n          pip install accelerate transformers diffusers\n          python -m pytest tests/models\n\n\n      - name: Run CLI tests\n        run: |\n          pip install optimum\n          python -m pytest tests/cli\n\n  run_staging_tests:\n    needs: [check-commits, python-quality]\n    runs-on: ubuntu-latest\n    strategy:\n      fail-fast: false\n      matrix:\n        python-version: [\"3.9\", \"3.11\"]\n\n    steps:\n      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd  # v6.0.2\n      - name: Set up Python ${{ matrix.python-version }}\n        uses: actions/setup-python@e9aba2c848f5ebd159c070c61ea2c4e2b122355e  # v2\n        with:\n          python-version: ${{ 
matrix.python-version }}\n\n      - name: Build and install quanto\n        run: |\n          pip install --upgrade pip\n          pip install -e .[dev]\n\n      - name: Run models hub tests\n        run: |\n          pip install accelerate transformers diffusers\n          HUGGINGFACE_CO_STAGING=true python -m pytest tests/models -k \"hub\"\n"
  },
  {
    "path": ".github/workflows/linux-cuda-tests.yml",
    "content": "name: Linux CUDA tests\n\non:\n  push:\n    branches:\n      - main\n    paths:\n      - \"optimum/quanto/**\"\n      - \"tests/**\"\n      - \"pyproject.toml\"\n  pull_request:\n    types: [assigned, opened, synchronize, reopened]\n    paths:\n      - \"optimum/quanto/**\"\n      - \"tests/**\"\n      - \"pyproject.toml\"\n\njobs:\n  check-commits:\n    uses: ./.github/workflows/check-commits.yml\n  python-quality:\n    uses: ./.github/workflows/python-quality.yml\n  test-ubuntu-cuda:\n    needs: [check-commits, python-quality]\n    runs-on:\n      group: aws-g5-4xlarge-plus\n    strategy:\n      fail-fast: false\n      matrix:\n        cuda-version: [\"11.8\", \"12.4\", \"12.6\"]\n    container:\n      image: pytorch/pytorch:2.6.0-cuda${{ matrix.cuda-version }}-cudnn9-devel\n      options: --gpus 0\n\n    steps:\n      - uses: actions/checkout@v2\n      - name: Check CUDA installation\n        run: |\n          nvcc -V\n\n      - name: Build and install quanto\n        run: |\n          pip install --upgrade pip\n          pip install -e .[dev]\n\n      - name: Run base tests\n        run: |\n          python -m pytest tests --ignore=tests/models --ignore=tests/cli\n\n      - name: Run models tests\n        run: |\n          pip install accelerate transformers diffusers\n          python -m pytest tests/models\n\n      - name: Run CLI tests\n        run: |\n          pip install optimum\n          python -m pytest tests/cli\n"
  },
  {
    "path": ".github/workflows/linux-examples.yml",
    "content": "name: Linux examples (CPU, CUDA)\n\non:\n  push:\n    branches:\n      - main\n    paths:\n      - \"optimum/quanto/**\"\n      - \"examples/**\"\n      - \"pyproject.toml\"\n  pull_request:\n    types: [assigned, opened, synchronize, reopened]\n    paths:\n      - \"optimum/quanto/**\"\n      - \"examples/**\"\n      - \"pyproject.toml\"\n\njobs:\n  check-commits:\n    uses: ./.github/workflows/check-commits.yml\n  python-quality:\n    uses: ./.github/workflows/python-quality.yml\n  run-examples:\n    needs: [check-commits, python-quality]\n    runs-on:\n      group: aws-g5-4xlarge-plus\n    strategy:\n      fail-fast: false\n      matrix:\n        device: [\"cpu\", \"cuda\"]\n    container:\n      image: pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel\n      options: --gpus 0\n\n    steps:\n      - uses: actions/checkout@v2\n      - name: Check CUDA installation\n        run: |\n          nvcc -V\n\n      - name: Build and install packages\n        run: |\n          pip install --upgrade pip\n          pip install -e .[examples]\n\n      # Run examples\n      - name: Run MNIST classification example\n        run: |\n          for w in int4 int8 float8; do \\\n            for a in none int8 float8; do \\\n              python examples/vision/image-classification/mnist/quantize_mnist_model.py \\\n                --weights $w --activations $a --device ${{ matrix.device }}; \\\n            done; \\\n          done\n      - name: Run OWL detection example\n        run: |\n          for w in int4 int8 float8; do \\\n            python examples/vision/object-detection/quantize_owl_model.py \\\n              --image http://images.cocodataset.org/val2017/000000039769.jpg \\\n              --texts \"a photo of a cat\" \"a remote\" \\\n              --weights $w --device ${{ matrix.device }}; \\\n          done\n      - name: Run text-classification example\n        run: |\n          for w in int4 int8; do \\\n            for a in none int8; do \\\n              
python examples/nlp/text-classification/sst2/quantize_sst2_model.py \\\n                --weights $w --activations $a --device ${{ matrix.device }}; \\\n            done; \\\n          done\n      - name: Run text-to-image example\n        if: ${{ matrix.device == 'cuda'}}\n        run: |\n          for w in int4 int8 fp8; do \\\n            python examples/vision/text-to-image/quantize_pixart_sigma.py \\\n              --qtype $w --device ${{ matrix.device }}; \\\n          done\n"
  },
  {
    "path": ".github/workflows/python-quality.yml",
    "content": "name: Python code quality\n\non: [workflow_call]\n\njobs:\n  check_code_quality:\n    runs-on: ubuntu-latest\n\n    steps:\n      - uses: actions/checkout@v2\n      - name: Set up Python\n        uses: actions/setup-python@v2\n        with:\n          python-version: 3.9\n      - name: Install dependencies\n        run: |\n          pip install --upgrade pip\n          pip install .[dev]\n      - run: ruff format bench examples optimum tests --diff\n      - run: ruff check --show-fixes bench examples optimum tests\n"
  },
  {
    "path": ".github/workflows/security.yml",
    "content": "name: Security Checks\n\non:\n  push:\n\npermissions:\n  contents: read\n\njobs:\n  secrets:\n    runs-on: ubuntu-latest\n    steps:\n      - shell: bash\n        env:\n          REF_NAME: ${{ github.ref_name }}\n          HEAD_REF: ${{ github.event.pull_request.head.ref }}\n        run: |\n          if [ \"${{ github.event_name }}\" == \"push\" ]; then\n            echo \"depth=$(($(jq length <<< '${{ toJson(github.event.commits) }}') + 2))\" >> $GITHUB_ENV\n            echo \"branch=$REF_NAME\" >> $GITHUB_ENV\n          fi\n          if [ \"${{ github.event_name }}\" == \"pull_request\" ]; then\n            echo \"depth=$((${{ github.event.pull_request.commits }}+2))\" >> $GITHUB_ENV\n            echo \"branch=$HEAD_REF\" >> $GITHUB_ENV\n          fi\n      - name: Checkout code\n        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd  # v6.0.2\n        with:\n          ref: ${{env.branch}}\n          fetch-depth: ${{env.depth}}\n      - name: Scan for secrets\n        uses: trufflesecurity/trufflehog@6bd2d14f7a4bc1e569fa3550efa7ec632a4fa67b  # main"
  },
  {
    "path": ".github/workflows/stale.yml",
    "content": "name: 'Close stale issues and PRs'\non:\n  schedule:\n    - cron: '30 1 * * *'\n  workflow_dispatch:\n\npermissions:\n  issues: write\n  pull-requests: write\n\njobs:\n  stale:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/stale@v9\n        with:\n          stale-issue-message: 'This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.'\n          stale-pr-message: 'This PR is stale because it has been open 15 days with no activity. Remove stale label or comment or this will be closed in 5 days.'\n          close-issue-message: 'This issue was closed because it has been stalled for 5 days with no activity.'\n          close-pr-message: 'This PR was closed because it has been stalled for 5 days with no activity.'\n          days-before-issue-stale: 30\n          days-before-pr-stale: 15\n          days-before-issue-close: 5\n          days-before-pr-close: 5\n"
  },
  {
    "path": ".gitignore",
    "content": "__pycache__\n.pytest_cache\n*.egg-info\ndist\n.venv\nbuild/\n"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "<!---\nCopyright 2024 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n-->\n\n# Contribute to optimum-quanto\n\nEveryone is welcome to contribute, and we value everybody's contribution. Code\ncontributions are not the only way to help the community. Answering questions, helping\nothers, and improving the documentation are also immensely valuable.\n\nIt also helps us if you spread the word! Reference the library in blog posts\nabout the awesome projects it made possible, shout out on Twitter every time it has\nhelped you, or simply ⭐️ the repository to say thank you.\n\nHowever you choose to contribute, please be mindful and respect our\n[code of conduct](https://github.com/huggingface/transformers/blob/main/CODE_OF_CONDUCT.md).\n\n**This guide is directly inspired by [transformers guide to contributing](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md).**\n\n## Ways to contribute\n\nThere are several ways you can contribute:\n\n* Fix outstanding issues with the existing code.\n* Submit issues related to bugs or desired new features.\n* Implement new kernels.\n\n> All contributions are equally valuable to the community. 
🥰\n\n## Fixing outstanding issues\n\nIf you notice an issue with the existing code and have a fix in mind, feel free to [start contributing](https://github.com/huggingface/optimum-quanto/blob/main/CONTRIBUTING.md/#create-a-pull-request) and open a Pull Request!\n\n## Submitting a bug-related issue or feature request\n\nDo your best to follow these guidelines when submitting a bug-related issue or a feature\nrequest. It will make it easier for us to come back to you quickly and with good\nfeedback.\n\n### Did you find a bug?\n\nThe `optimum-quanto` backend will become more robust and reliable thanks to users who will report the problems they encounter.\n\nBefore you report an issue, we would really appreciate it if you could **make sure the bug was not\nalready reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the library itself, and not your code. If you're unsure whether the bug is in your code or the library, please ask in the [forum](https://discuss.huggingface.co/) first. This helps us respond quicker to fixing issues related to the library versus general questions.\n\nOnce you've confirmed the bug hasn't already been reported, please include the following information in your issue so we can quickly resolve it:\n\n* Your **OS type and version** and **Python** and **PyTorch** versions.\n* A short, self-contained, code snippet that allows us to reproduce the bug in\n  less than 30s.\n* The *full* traceback if an exception is raised.\n* Attach any other additional information, like screenshots, you think may help.\n\n### Do you want a new feature?\n\nIf there is a new feature you'd like to see, please open an issue and describe:\n\n1. What is the *motivation* behind this feature? Is it related to a problem or frustration with the library? Is it a feature related to something you need for a project? 
Is it something you worked on and think it could benefit the community?\n\n   Whatever it is, we'd love to hear about it!\n\n2. Describe your requested feature in as much detail as possible. The more you can tell us about it, the better we'll be able to help you.\n3. Provide a *code snippet* that demonstrates the feature's usage.\n4. If the feature is related to a paper, please include a link.\n\nIf your issue is well written we're already 80% of the way there by the time you create it.\n\n## Do you want to implement a new kernel?\n\nWith the constant evolution of hardware backends, there is always a need for updating the kernels for better performance.\n\nIf you would like to request a new kernel, please open an issue and describe:\n\n* The hardware configuration(s) it will apply to.\n* If any, a short description of the novel techniques that should be used to implement the kernel.\n\nIf you are willing to contribute the kernel yourself, let us know so we can help you add it to `optimum-quanto`!\n\n## Create a Pull Request\n\nBefore writing any code, we strongly advise you to search through the existing PRs or\nissues to make sure nobody is already working on the same thing. If you are\nunsure, it is always a good idea to open an issue to get some feedback.\n\nYou will need basic `git` proficiency to contribute. While `git` is not the easiest tool to use, it has the greatest manual. Type `git --help` in a shell and enjoy! If you prefer books, [Pro Git](https://git-scm.com/book/en/v2) is a very good reference.\n\nYou'll need **Python 3.8** or above to contribute. Follow the steps below to start contributing:\n\n1. Fork the [repository](https://github.com/huggingface/optimum-quanto) by\n   clicking on the **[Fork](https://github.com/huggingface/optimum-quanto/fork)** button on the repository's page. This creates a copy of the code\n   under your GitHub user account.\n\n2. 
Clone your fork to your local disk, and add the base repository as a remote:\n\n   ```bash\n   git clone git@github.com:<your Github handle>/optimum-quanto.git\n   cd optimum-quanto\n   git remote add upstream https://github.com/huggingface/optimum-quanto.git\n   ```\n\n3. Create a new branch to hold your development changes:\n\n   ```bash\n   git checkout -b a-descriptive-name-for-my-changes\n   ```\n\n   🚨 **Do not** work on the `main` branch!\n\n4. Set up a development environment by running the following command in a virtual environment:\n\n   ```bash\n   pip install -e \".[dev]\"\n   ```\n\n   If `optimum-quanto` was already installed in the virtual environment, remove\n   it with `pip uninstall optimum-quanto` before reinstalling it in editable\n   mode with the `-e` flag.\n\n5. Develop the features in your branch.\n\n   As you work on your code, you should make sure the test suite\n   passes. Run the tests impacted by your changes like this:\n\n   ```bash\n   pytest tests/<TEST_TO_RUN>.py\n   ```\n\n   `optimum-quanto` relies on `ruff` to format its source code\n   consistently. 
After you make changes, apply automatic style corrections and code verifications\n   that can't be automated in one go with:\n\n   ```bash\n   make style\n   ```\n   Once you're happy with your changes, add the changed files with `git add` and\n   record your changes locally with `git commit`:\n\n   ```bash\n   git add modified_file.py\n   git commit\n   ```\n\n   This repository uses a `rebase` strategy when merging pull-requests, meaning that your commits will **not** be squashed automatically.\n\n   We therefore request you to keep a tidy queue of commits in your pull-request, clearly communicating the changes you made in each commit.\n\n   **This is enforced by the continuous integration, so your pull-request will not be reviewed if your commit queue is not clean.**\n\n   Although this is not mandatory, we kindly ask you to consider using [conventional commits](https://www.conventionalcommits.org/en/v1.0.0/#summary)\n   (here the full [specification](https://www.conventionalcommits.org/en/v1.0.0/))!\n\n   This article gives a brief [rationale](https://julien.ponge.org/blog/the-power-of-conventional-commits/) of why this will make our life and yours easier.\n\n   To keep your copy of the code up to date with the original\n   repository, rebase your branch on `upstream/branch` *before* you open a pull request or if requested by a maintainer:\n\n   ```bash\n   git fetch upstream\n   git rebase upstream/main\n   ```\n\n   Before submitting, cleanup your commit history to make it more readable for the reviewer (like squashing temporary commits and editing commit messages to clearly explain what you changed).\n\n   Push your changes to your branch:\n\n   ```bash\n   git push -u origin a-descriptive-name-for-my-changes\n   ```\n\n   If you've already opened a pull request, you'll need to force push with the `--force` flag. Otherwise, if the pull request hasn't been opened yet, you can just push your changes normally.\n\n6. 
Now you can go to your fork of the repository on GitHub and click on **Pull Request** to open a pull request. Make sure you tick off all the boxes on our [checklist](https://github.com/huggingface/optimum-quanto/blob/main/CONTRIBUTING.md/#pull-request-checklist) below. When you're ready, you can send your changes to the project maintainers for review.\n\n7. It's ok if maintainers request changes, it happens to our core contributors\n   too! So everyone can see the changes in the pull request, work in your local\n   branch and push the changes to your fork. They will automatically appear in\n   the pull request.\n\n### Pull request checklist\n\n☐ The pull request title should summarize your contribution.<br>\n☐ If your pull request addresses an issue, please mention the issue number in the pull\nrequest description to make sure they are linked (and people viewing the issue know you\nare working on it).<br>\n☐ To indicate a work in progress please prefix the title with `[WIP]`. These are\nuseful to avoid duplicated work, and to differentiate it from PRs ready to be merged.<br>\n☐ Make sure existing tests pass.<br>\n☐ If adding a new feature, also add tests for it.<br>\n☐ All public methods must have informative docstrings.<br>\n\n### Tests\n\nAn extensive test suite is included to test the library behavior in the [tests](https://github.com/huggingface/optimum-quanto/tree/main/tests) folder.\n\nFrom the root of the repository, specify a *path to a subfolder or a test file* to run the test.\n\n```bash\npython -m pytest -sv ./tests/<subfolder>/<test>.py\n```\n\nYou can run all tests by typing:\n\n```bash\nmake test\n```\n\n### Style guide\n\nFor documentation strings, `optimum-quanto` follows the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html).\nCheck `transformers` [documentation writing guide](https://github.com/huggingface/transformers/tree/main/docs#writing-documentation---specification)\nfor more information.\n"
  },
  {
    "path": "LICENSE",
    "content": "Copyright 2023 - The Hugging Face team. All rights reserved.\n\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative 
Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "Makefile",
    "content": ".PHONY: check test style\n\ncheck_dirs := optimum tests bench examples\n\ncheck:\n\truff check --show-fixes ${check_dirs}\n\truff format ${check_dirs} --diff\n\nstyle:\n\truff check ${check_dirs} --fix\n\truff format ${check_dirs}\n\ntest:\n\tpython -m pytest -sv tests\n"
  },
  {
    "path": "README.md",
    "content": "# Optimum Quanto\n\n> This project is currently in maintenance mode. We accept pull requests only for minor bug fixes, documentation improvements, and other maintenance tasks. Major new features or breaking changes are unlikely to be merged. For production-ready quantization features or active development, consider alternative projects such as [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) or [torchAO](https://github.com/pytorch/ao).\n\n🤗 Optimum Quanto is a pytorch quantization backend for [optimum](https://huggingface.co/docs/optimum/en/index).\n\nIt has been designed with versatility and simplicity in mind:\n\n- all features are available in eager mode (works with non-traceable models),\n- quantized models can be placed on any device (including CUDA and MPS),\n- automatically inserts quantization and dequantization stubs,\n- automatically inserts quantized functional operations,\n- automatically inserts quantized modules (see below the list of supported modules),\n- provides a seamless workflow from a float model to a dynamic to a static quantized model,\n- serialization compatible with pytorch `weight_only` and 🤗 `safetensors`,\n- accelerated matrix multiplications on CUDA devices (int8-int8, fp16-int4, bf16-int8, bf16-int4),\n- supports int2, int4, int8 and float8 weights,\n- supports int8 and float8 activations.\n\nFeatures yet to be implemented:\n\n- dynamic activations smoothing,\n- kernels for all mixed matrix multiplications on all devices,\n- compatibility with [torch compiler](https://pytorch.org/docs/stable/torch.compiler.html) (aka dynamo).\n\n## Performances\n\nIn a nutshell:\n\n- accuracy: models compiled with `int8`/`float8` weights and `float8` activations are very close to the full-precision models,\n- latency: whenever optimized kernels are available, the inference of quantized model is comparable with the full-precision models when quantizing only the model weights,\n- device memory: approximately 
divided by float bits / integer bits.\n\nThe paragraph below is just an example. Please refer to the `bench` folder for detailed results per use-case of model.\n\n### meta-llama/Meta-Llama-3.1-8B\n\n<div class=\"row\"><center>\n  <div class=\"column\">\n    <img src=\"https://github.com/huggingface/optimum-quanto/blob/main/bench/generation/charts/meta-llama-Meta-Llama-3.1-8B_bf16_Perplexity.png\" alt=\"meta-llama/Meta-Llama-3.1-8B WikiText perplexity\">\n  </div>\n </center>\n</div>\n\n<div class=\"row\"><center>\n  <div class=\"column\">\n    <img src=\"https://github.com/huggingface/optimum-quanto/blob/main/bench/generation/charts/meta-llama-Meta-Llama-3.1-8B_bf16_Latency__ms_.png\" alt=\"meta-llama/Meta-Llama-3.1-8B Latency\">\n  </div>\n </center>\n</div>\n\n## Installation\n\nOptimum Quanto is available as a pip package.\n\n```sh\npip install optimum-quanto\n```\n\n## Quantization workflow for Hugging Face models\n\n`optimum-quanto` provides helper classes to quantize, save and reload Hugging Face quantized models.\n\n### LLM models\n\nThe first step is to quantize the model\n\n```python\nfrom transformers import AutoModelForCausalLM\nfrom optimum.quanto import QuantizedModelForCausalLM, qint4\n\nmodel = AutoModelForCausalLM.from_pretrained('meta-llama/Meta-Llama-3-8B')\nqmodel = QuantizedModelForCausalLM.quantize(model, weights=qint4, exclude='lm_head')\n```\n\nNote: the model quantized weights will be frozen. 
If you want to keep them unfrozen to train them you need to use `optimum.quanto.quantize` directly.\n\nThe quantized model can be saved using `save_pretrained`:\n\n```python\nqmodel.save_pretrained('./Llama-3-8B-quantized')\n```\n\nIt can later be reloaded using `from_pretrained`:\n\n```python\nfrom optimum.quanto import QuantizedModelForCausalLM\n\nqmodel = QuantizedModelForCausalLM.from_pretrained('Llama-3-8B-quantized')\n```\n\n### Diffusers models\n\nYou can quantize any of the submodels inside a diffusers pipeline and seamlessly include them later in another pipeline.\n\nHere we quantize the `transformer` of a `Pixart` pipeline.\n\n```python\nfrom diffusers import PixArtTransformer2DModel\nfrom optimum.quanto import QuantizedPixArtTransformer2DModel, qfloat8\n\nmodel = PixArtTransformer2DModel.from_pretrained(\"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS\", subfolder=\"transformer\")\nqmodel = QuantizedPixArtTransformer2DModel.quantize(model, weights=qfloat8)\nqmodel.save_pretrained(\"./pixart-sigma-fp8\")\n```\n\nLater, we can reload the quantized model and recreate the pipeline:\n\n```python\nfrom diffusers import PixArtTransformer2DModel\nfrom optimum.quanto import QuantizedPixArtTransformer2DModel\n\ntransformer = QuantizedPixArtTransformer2DModel.from_pretrained(\"./pixart-sigma-fp8\")\ntransformer.to(device=\"cuda\")\npipe = PixArtSigmaPipeline.from_pretrained(\n  \"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS\",\n  transformer=None,\n  torch_dtype=torch.float16,\n).to(\"cuda\")\npipe.transformer = transformer\n```\n\n## Quantization workflow for vanilla pytorch models (low-level API)\n\nOne thing to keep in mind when using the low-level quanto API is that by default models\nweights are dynamically quantized: an explicit call must be made to 'freeze' the quantized weights.\n\nA typical quantization workflow would consist of the following steps:\n\n**1. 
Quantize**\n\nThe first step converts a standard float model into a dynamically quantized model.\n\n```python\nfrom optimum.quanto import quantize, qint8\n\nquantize(model, weights=qint8, activations=qint8)\n```\n\nAt this stage, only the inference of the model is modified to dynamically quantize the weights.\n\n**2. Calibrate (optional if activations are not quantized)**\n\nQuanto supports a calibration mode that allows to record the activation ranges while passing representative samples through the quantized model.\n\n```python\nfrom optimum.quanto import Calibration\n\nwith Calibration(momentum=0.9):\n    model(samples)\n```\n\nThis automatically activates the quantization of the activations in the quantized modules.\n\n\n**3. Tune, aka Quantization-Aware-Training (optional)**\n\nIf the performance of the model degrades too much, one can tune it for a few epochs to recover the float model performance.\n\n```python\nimport torch\n\nmodel.train()\nfor batch_idx, (data, target) in enumerate(train_loader):\n    data, target = data.to(device), target.to(device)\n    optimizer.zero_grad()\n    output = model(data).dequantize()\n    loss = torch.nn.functional.nll_loss(output, target)\n    loss.backward()\n    optimizer.step()\n```\n\n**4. Freeze integer weights**\n\nWhen freezing a model, its float weights are replaced by quantized integer weights.\n\n```python\nfrom optimum.quanto import freeze\n\nfreeze(model)\n```\n\n**5. 
Serialize quantized model**\n\nQuantized models weights can be serialized to a `state_dict`, and saved to a file.\nBoth `pickle` and `safetensors` (recommended) are supported.\n\n```python\nfrom safetensors.torch import save_file\n\nsave_file(model.state_dict(), 'model.safetensors')\n```\n\nIn order to be able to reload these weights, you also need to store the quantized\nmodel quantization map.\n\n```python\nimport json\n\nfrom optimum.quanto import quantization_map\n\nwith open('quantization_map.json', 'w') as f:\n  json.dump(quantization_map(model), f)\n```\n\n**5. Reload a quantized model**\n\nA serialized quantized model can be reloaded from a `state_dict` and a `quantization_map` using the `requantize` helper.\nNote that you need first to instantiate an empty model.\n\n```python\nimport json\n\nfrom safetensors.torch import load_file\nfrom optimum.quanto import requantize\n\nstate_dict = load_file('model.safetensors')\nwith open('quantization_map.json', 'r') as f:\n  quantization_map = json.load(f)\n\n# Create an empty model from your modeling code and requantize it\nwith torch.device('meta'):\n  new_model = ...\nrequantize(new_model, state_dict, quantization_map, device=torch.device('cuda'))\n```\n\nPlease refer to the [examples](https://github.com/huggingface/quanto/tree/main/examples) for instantiations of that workflow.\n\n\n## Design overview\n\n### Tensors\n\nAt the heart of quanto is a Tensor subclass that corresponds to:\n- the projection of a source Tensor into the optimal range for a given destination type,\n- the mapping of projected values to the destination type.\n\nFor floating-point destination types, the mapping is done by the native pytorch cast (i.e. `Tensor.to()`).\n\nFor integer destination types, the mapping is a simple rounding operation (i.e. `torch.round()`).\n\nThe goal of the projection is to increase the accuracy of the conversion by minimizing the number of:\n- saturated values (i.e. 
mapped to the destination type min/max),\n- zeroed values (because they are below the smallest number that can be represented by the destination type)\n\nThe projection is symmetric per-tensor or per-channel for `int8` and `float8`, and group-wise affine (with a shift or 'zero-point') for lower bitwidth.\n\nOne of the benefits of using a lower-bitwidth representation is that you will be able to take advantage of accelerated operations\nfor the destination type, which is typically faster than their higher precision equivalents.\n\nQuanto does not support the conversion of a Tensor using mixed destination types.\n\n### Modules\n\nQuanto provides a generic mechanism to replace `torch` modules by `optimum-quanto` modules that are able to process quanto tensors.\n\n`optimum-quanto` modules dynamically convert their weights until a model is frozen, which slows down inference a bit but is\nrequired if the model needs to be tuned.\n\nWeights are usually quantized per-channel along the first dimension (output features).\n\nBiases are not converted to preserve the accuracy of a typical `addmm` operation.\n\nExplanation: to be consistent with the unquantized arithmetic operations, biases would need to be quantized with a scale that\nis equal to the product of the input and weight scales, which leads to a ridiculously small scale, and conversely\nrequires a very high bitwidth to avoid clipping. Typically, with `int8` inputs and weights, biases would need to be quantized\nwith at least `12` bits, i.e. in `int16`. 
Since most biases are today `float16`, this is a waste of time.\n\nActivations are dynamically quantized per-tensor using static scales (defaults to the range `[-1, 1]`).\n\nTo preserve accuracy, the model needs to be calibrated to evaluate the best activation scales (using a momentum).\n\nThe following modules can be quantized:\n\n- [Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) (QLinear).\nWeights are always quantized, and biases are not quantized. Inputs and outputs can be quantized.\n- [Conv2d](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html) (QConv2D).\nWeights are always quantized, and biases are not quantized. Inputs and outputs can be quantized.\n- [LayerNorm](https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html),\nWeights and biases are __not__ quantized. Outputs can be quantized.\n\n## Pitfalls to avoid when quantizing activations\n\nActivations are always quantized per-tensor because most linear algebra operations in a model graph are not compatible\nwith per-axis inputs: you simply cannot add numbers that are not expressed in the same base (`you cannot add apples and oranges`).\n\nWeights involved in matrix multiplications are, on the contrary, always quantized along their first axis, because all output features\nare evaluated independently from one another.\n\nThe outputs of a quantized matrix multiplication will anyway always be dequantized, even if activations are quantized, because:\n\n- the resulting accumulated values are expressed with a much higher bitwidth (typically `int32` or `float32`) than the activation bitwidth (typically `int8` or `float8`),\n- they might be combined with a `float` bias.\n\nQuantizing activations per-tensor to `int8` can lead to serious quantization errors if the corresponding tensors contain large outlier values.\nTypically, this will lead to quantized tensors with most values set to zero (except the outliers).\n\nA possible solution to work around that issue is to 
'smooth' the activations statically as illustrated by [SmoothQuant](https://github.com/mit-han-lab/smoothquant).\nYou can find a script to smooth some model architectures under [external/smoothquant](external/smoothquant).\n\nA better option is to represent activations using `float8`.\n"
  },
  {
    "path": "bench/generation/README.md",
    "content": "# Quanto generation benchmark\n\nThis repository contains scripts to evaluate the performances of quantized models using three metrics:\n\n- `latency.py` evaluates the latency per generated token,\n- `prediction.py` evaluates the accuracy when predicting the last token of prompts from the [Lambada dataset](https://huggingface.co/datasets/lambada),\n- `perplexity.py` evaluates the perplexity of the model on the [WikiText dataset](https://huggingface.co/datasets/wikitext), as defined in the [transformers documentation](https://huggingface.co/docs/transformers/en/perplexity).\n\nA `evaluate_model.py` utility script is also provided to evaluate the metrics on a specific model for several quantization configurations, and output the result to a `png` barchart and/or a `json` file.\n\nNote: the language modeling head (lm_head) of the tested models is not quantized.\n\nThe paragraphs below display results for some popular models on a NVIDIA A10 GPU.\n\n## meta-llama/Meta-Llama-3.1-8B\n\n<div class=\"row\"><center>\n  <div class=\"column\">\n    <img src=\"https://github.com/huggingface/quanto/blob/main/bench/generation/charts/meta-llama-Meta-Llama-3.1-8B_bf16_Accuracy.png\" alt=\"meta-llama/Meta-llama-3.1-8B Lambada prediction accuracy\">\n  </div>\n </center>\n</div>\n\n<div class=\"row\"><center>\n  <div class=\"column\">\n    <img src=\"https://github.com/huggingface/quanto/blob/main/bench/generation/charts/meta-llama-Meta-Llama-3.1-8B_bf16_Perplexity.png\" alt=\"meta-llama/Meta-Llama-3.1-8B WikiText perplexity\">\n  </div>\n </center>\n</div>\n\n<div class=\"row\"><center>\n  <div class=\"column\">\n    <img src=\"https://github.com/huggingface/quanto/blob/main/bench/generation/charts/meta-llama-Meta-Llama-3.1-8B_bf16_Latency__ms_.png\" alt=\"meta-llama/Meta-Llama-3.1-8B Latency\">\n  </div>\n </center>\n</div>\n\n## mistralai/Mistral-7B-Instruct-v0.3\n\n<div class=\"row\"><center>\n  <div class=\"column\">\n    <img 
src=\"https://github.com/huggingface/quanto/blob/main/bench/generation/charts/mistralai-Mistral-7B-Instruct-v0.3_bf16_Accuracy.png\" alt=\"mistralai/Mistral-7B-Instruct-v0.3 Lambada prediction accuracy\">\n  </div>\n </center>\n</div>\n\n<div class=\"row\"><center>\n  <div class=\"column\">\n    <img src=\"https://github.com/huggingface/quanto/blob/main/bench/generation/charts/mistralai-Mistral-7B-Instruct-v0.3_bf16_Perplexity.png\" alt=\"mistralai/Mistral-7B-Instruct-v0.3 WikiText perplexity\">\n  </div>\n </center>\n</div>\n\n<div class=\"row\"><center>\n  <div class=\"column\">\n    <img src=\"https://github.com/huggingface/quanto/blob/main/bench/generation/charts/mistralai-Mistral-7B-Instruct-v0.3_bf16_Latency__ms_.png\" alt=\"mistralai/Mistral-7B-Instruct-v0.3 Latency\">\n  </div>\n </center>\n</div>\n\n## google/gemma-2b\n\n<div class=\"row\"><center>\n  <div class=\"column\">\n    <img src=\"https://github.com/huggingface/quanto/blob/main/bench/generation/charts/google-gemma-2b_bf16_Accuracy.png\" alt=\"google-gemma-2b Lambada prediction accuracy\">\n  </div>\n </center>\n</div>\n\n<div class=\"row\"><center>\n  <div class=\"column\">\n    <img src=\"https://github.com/huggingface/quanto/blob/main/bench/generation/charts/google-gemma-2b_bf16_Perplexity.png\" alt=\"google-gemma-2b WikiText perplexity\">\n  </div>\n </center>\n</div>\n\n<div class=\"row\"><center>\n  <div class=\"column\">\n    <img src=\"https://github.com/huggingface/quanto/blob/main/bench/generation/charts/google-gemma-2b_bf16_Latency__ms_.png\" alt=\"google-gemma-2b Latency\">\n  </div>\n </center>\n</div>\n"
  },
  {
    "path": "bench/generation/evaluate_configurations.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport json\n\nimport torch\nfrom evaluate_model import evaluate\nfrom gen_barchart import gen_barchart\nfrom transformers import AutoConfig\n\nfrom optimum.quanto import qtype\n\n\ndef evaluate_model_configurations(\n    model_id: str, metric: str, device: torch.device, batch_size: int = 32, dtype: torch.dtype = torch.float16\n):\n    weights = [\n        \"int4\",\n        \"int8\",\n        \"float8\",\n    ]\n\n    activations = [\n        \"none\",\n        \"float8\",\n    ]\n\n    def short_name(qtype: qtype):\n        return {\n            \"none\": \"f16\" if dtype == torch.float16 else \"bf16\",\n            \"int4\": \"i4\",\n            \"int8\": \"i8\",\n            \"float8\": \"f8\",\n        }[qtype]\n\n    results = {}\n\n    # Evaluate float16/bfloat16 model\n    config_name = f\"W{short_name('none')}A{short_name('none')}\"\n    print(f\"{model_id}[{config_name}]:\")\n    results[config_name] = evaluate(model_id, metric, \"quanto\", \"none\", \"none\", batch_size, device, dtype)\n    # Evaluate quantized models\n    for w in weights:\n        for a in activations:\n            config_name = f\"W{short_name(w)}A{short_name(a)}\"\n            print(f\"{model_id}[{config_name}]:\")\n            results[config_name] = evaluate(model_id, metric, \"quanto\", w, a, batch_size, device, dtype)\n\n    return 
results\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Evaluate quantized model predictions on Lambada Dataset\")\n    parser.add_argument(\"--seed\", type=int, default=1, metavar=\"S\", help=\"random seed (default: 1)\")\n    parser.add_argument(\n        \"--model\",\n        type=str,\n        default=\"facebook/opt-350m\",\n        help=\"The name of the trained Model.\",\n    )\n    parser.add_argument(\"--device\", type=str, default=None, help=\"The device to use for generation.\")\n    parser.add_argument(\"--metric\", type=str, default=\"prediction\", choices=[\"latency\", \"prediction\", \"perplexity\"])\n    parser.add_argument(\"--batch_size\", type=int, default=32, help=\"The batch size during evaluation.\")\n    parser.add_argument(\"--dtype\", type=str, help=\"Use the following dtype to load the model.\")\n    parser.add_argument(\"--json\", action=\"store_true\", help=\"Dump the results to a json file.\")\n    parser.add_argument(\"--png\", action=\"store_true\", help=\"Generate a PNG.\")\n    args = parser.parse_args()\n\n    torch.manual_seed(args.seed)\n\n    if args.device is None:\n        if torch.cuda.is_available():\n            device = torch.device(\"cuda\")\n        elif torch.backends.mps.is_available():\n            device = torch.device(\"mps\")\n        elif torch.xpu.is_available():\n            device = torch.device(\"xpu\")\n        else:\n            device = torch.device(\"cpu\")\n    else:\n        device = torch.device(args.device)\n\n    if args.dtype is None:\n        config = AutoConfig.from_pretrained(args.model)\n        dtype = getattr(config, \"torch_dtype\", torch.float16)\n    else:\n        dtype = torch.float16 if args.dtype == \"fp16\" else torch.bfloat16\n    results = evaluate_model_configurations(args.model, args.metric, device, batch_size=args.batch_size, dtype=dtype)\n    if args.json:\n        model_name = args.model.split(\"/\")[-1]\n        json_path = 
f\"{model_name}-{args.metric}.json\"\n        with open(json_path, \"w\") as fp:\n            json.dump({model_name: results}, fp, indent=4)\n    if args.png:\n        if args.metric == \"latency\":\n            title = f\"{args.model}: Mean latency per token\"\n            label = \"Latency (ms)\"\n        elif args.metric == \"prediction\":\n            title = f\"{args.model}: Prediction accuracy on Lambada dataset\"\n            label = \"Accuracy\"\n        elif args.metric == \"perplexity\":\n            title = f\"{args.model}: Perplexity evaluated on WikiText dataset\"\n            label = \"Perplexity\"\n        gen_barchart(args.model, title, label, results, dtype)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "bench/generation/evaluate_many_models.sh",
    "content": "#!/bin/bash\n# Absolute path to this script, e.g. /home/user/bin/foo.sh\nSCRIPT=$(readlink -f \"$0\")\n# Absolute path this script is in, thus /home/user/bin\nSCRIPT_PATH=$(dirname \"$SCRIPT\")\n\nmodels=(\n    google/gemma-2b\n    meta-llama/Meta-Llama-3.1-8B\n    mistralai/Mistral-7B-Instruct-v0.3\n)\n\nfor m in ${models[@]}; do\n    python ${SCRIPT_PATH}/evaluate_configurations.py --model $m --metric prediction --png --json --batch_size 16\n    python ${SCRIPT_PATH}/evaluate_configurations.py --model $m --metric perplexity --png --json --batch_size 16\n    python ${SCRIPT_PATH}/evaluate_configurations.py --model $m --metric latency --png --json --batch_size 16\ndone\n"
  },
  {
    "path": "bench/generation/evaluate_model.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport importlib\n\nimport torch\nfrom datasets import load_dataset\nfrom metrics.latency import latency\nfrom metrics.perplexity import perplexity\nfrom metrics.prediction import prediction_accuracy\n\n\nif importlib.util.find_spec(\"awq\") is not None:\n    from setup.awq import setup as awq_setup\nif importlib.util.find_spec(\"bitsandbytes\") is not None:\n    from setup.bnb import setup as bnb_setup\nif importlib.util.find_spec(\"hqq\") is not None:\n    from setup.hqq import setup as hqq_setup\nfrom setup.quanto import setup as quanto_setup\nfrom transformers import AutoConfig\n\n\n@torch.no_grad()\ndef calibrate(model, tokenizer, batch_size, batches):\n    samples = batch_size * batches\n    cal_dataset = load_dataset(\"lambada\", split=[\"validation\"])[0]\n    model.eval()\n    total = 0\n    for batch in cal_dataset.iter(batch_size=batch_size):\n        inputs = tokenizer(batch[\"text\"], return_tensors=\"pt\", padding=True)\n        input_ids = inputs.input_ids.to(model.device)\n        attention_mask = inputs.attention_mask.to(model.device)\n        model(input_ids, attention_mask=attention_mask)\n        total += input_ids.size(0)\n        if total >= samples:\n            break\n\n\ndef evaluate(\n    model_id: str,\n    metric: str,\n    quantizer: str,\n    weights: str,\n    activations: str,\n    
batch_size: int,\n    device: torch.device,\n    dtype: torch.dtype = None,\n):\n    if quantizer == \"quanto\":\n        if dtype is None:\n            config = AutoConfig.from_pretrained(model_id)\n            dtype = getattr(config, \"torch_dtype\", torch.float16)\n        model, tokenizer = quanto_setup(model_id, weights, activations, batch_size, device, dtype)\n    elif quantizer == \"awq\":\n        model, tokenizer = awq_setup(model_id, weights, activations, group_size=128)\n    elif quantizer == \"bnb\":\n        model, tokenizer = bnb_setup(model_id, weights, activations, device)\n    elif quantizer == \"hqq\":\n        model, tokenizer = hqq_setup(model_id, weights, activations, device)\n    else:\n        raise ValueError(f\"Unsupported quantizer {quantizer}\")\n    dtype = next(model.parameters()).dtype\n    weights = dtype if weights == \"none\" else weights\n    activations = dtype if activations == \"none\" else activations\n    print(f\"Evaluating {model_id} {metric} with {weights} weights and {activations} activations.\")\n    if metric == \"latency\":\n        return latency(model, tokenizer, device, batch_size=1, prompt_length=512, nb_tokens=512, iterations=3)\n    elif metric == \"prediction\":\n        return prediction_accuracy(model, tokenizer, batch_size)\n    elif metric == \"perplexity\":\n        return perplexity(model, tokenizer)\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Evaluate quantized model metrics\")\n    parser.add_argument(\"--seed\", type=int, default=1, metavar=\"S\", help=\"random seed (default: 1)\")\n    parser.add_argument(\n        \"--model\",\n        type=str,\n        default=\"facebook/opt-350m\",\n        help=\"The name of the trained Model.\",\n    )\n    parser.add_argument(\"--device\", type=str, default=None, help=\"The device to use for generation.\")\n    parser.add_argument(\"--metric\", type=str, default=\"prediction\", choices=[\"latency\", \"prediction\", \"perplexity\"])\n    
parser.add_argument(\"--quantizer\", type=str, default=\"quanto\", choices=[\"quanto\", \"awq\", \"bnb\", \"hqq\"])\n    parser.add_argument(\n        \"--weights\",\n        type=str,\n        default=\"none\",\n        choices=[\"none\", \"int4\", \"int8\", \"float8\"],\n    )\n    parser.add_argument(\n        \"--activations\",\n        type=str,\n        default=\"none\",\n        choices=[\"none\", \"int8\", \"float8\"],\n    )\n    parser.add_argument(\"--batch_size\", type=int, default=32, help=\"The batch size during evaluation.\")\n    parser.add_argument(\n        \"--dtype\",\n        type=str,\n        default=\"none\",\n        choices=[\"none\", \"fp16\", \"bf16\"],\n    )\n    args = parser.parse_args()\n\n    torch.manual_seed(args.seed)\n\n    if args.device is None:\n        if torch.cuda.is_available():\n            device = torch.device(\"cuda\")\n        elif torch.backends.mps.is_available():\n            device = torch.device(\"mps\")\n        elif torch.xpu.is_available():\n            device = torch.device(\"xpu\")\n        else:\n            device = torch.device(\"cpu\")\n    else:\n        device = torch.device(args.device)\n    dtype = {\"none\": None, \"fp16\": torch.float16, \"bf16\": torch.bfloat16}[args.dtype]\n    evaluate(args.model, args.metric, args.quantizer, args.weights, args.activations, args.batch_size, device, dtype)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "bench/generation/gen_barchart.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport json\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport torch\n\n\ndef save_bar_chart(title, labels, ylabel, series, save_path):\n    x = np.arange(len(labels))  # the label locations\n    width = 0.15  # the width of the bars\n    multiplier = 0\n\n    fig, ax = plt.subplots(layout=\"constrained\")\n    fig.set_figwidth(10)\n\n    max_value = 0\n\n    for attribute, measurement in series.items():\n        max_value = max(max_value, max(measurement))\n        offset = width * multiplier\n        rects = ax.bar(x + offset, measurement, width, label=attribute)\n        ax.bar_label(rects, padding=5)\n        multiplier += 1\n\n    # Add some text for labels, title and custom x-axis tick labels, etc.\n    ax.set_ylabel(ylabel)\n    ax.set_title(title)\n    ax.set_xticks(x + width, labels)\n    ax.legend(loc=\"upper left\", ncols=4)\n    ax.set_ylim(0, max_value * 1.2)\n\n    plt.savefig(save_path)\n\n\ndef gen_barchart(model_id, title, label, results, dtype):\n    dtype_str = \"f16\" if dtype is torch.float16 else \"bf16\"\n    activations = (dtype_str, \"f8\")\n    weights = (\"i4\", \"i8\", \"f8\")\n    series = {}\n    reference = round(results[f\"W{dtype_str}A{dtype_str}\"], 2)\n    series[f\"Weights {dtype_str}\"] = [\n        reference,\n    ] * len(activations)\n    for w in weights:\n        name 
= f\"Weights {w}\"\n        series[name] = []\n        for a in activations:\n            result = results[f\"W{w}A{a}\"]\n            series[name].append(round(result, 2))\n    model_name = model_id.replace(\"/\", \"-\")\n    metric_name = label.replace(\" \", \"_\").replace(\"(\", \"_\").replace(\")\", \"_\")\n    save_bar_chart(\n        title=title,\n        labels=[f\"Activations {a}\" for a in activations],\n        series=series,\n        ylabel=label,\n        save_path=f\"{model_name}_{dtype_str}_{metric_name}.png\",\n    )\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"benchmark\", type=str, help=\"A benchmark result file (.json).\")\n    parser.add_argument(\"--title\", type=str, required=True, help=\"The graph title.\")\n    parser.add_argument(\"--label\", type=str, required=True, help=\"The graph vertical label.\")\n    args = parser.parse_args()\n    with open(args.benchmark) as f:\n        benchmark = json.load(f)\n        for model_id, results in benchmark.items():\n            gen_barchart(model_id, args.title, args.label, results)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "bench/generation/metrics/__init__.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "bench/generation/metrics/latency.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport gc\nimport time\n\nimport numpy as np\nimport torch\nfrom tqdm.auto import tqdm\nfrom transformers import GenerationConfig\n\n\ndef latency(model, tokenizer, device, batch_size=1, prompt_length=512, nb_tokens=512, iterations=10):\n    def synchronize(device):\n        if device.type == \"cuda\":\n            torch.cuda.synchronize()\n        elif device.type == \"mps\":\n            torch.mps.synchronize()\n        elif device.type == \"xpu\":\n            torch.xpu.synchronize()\n        else:\n            torch.cpu.synchronize()\n\n    def timing_event(device):\n        if device.type == \"cuda\":\n            return torch.cuda.Event(enable_timing=True)\n        elif device.type == \"mps\":\n            return torch.mps.Event(enable_timing=True)\n        elif device.type == \"xpu\":\n            return torch.xpu.Event(enable_timing=True)\n\n        class CPUEvent:\n            def __init__(self):\n                self.time = None\n\n            def record(self):\n                self.time = time.time()\n\n            def elapsed_time(self, other):\n                assert self.time is not None\n                assert other.time is not None\n                return (other.time - self.time) * 1000\n\n        return CPUEvent()\n\n    generation_config = GenerationConfig(\n        max_new_tokens=nb_tokens,\n        
min_new_tokens=nb_tokens,\n        use_cache=True,\n        pad_token_id=tokenizer.pad_token_id,\n        num_beams=1,\n        do_sample=False,\n        eos_token_id=None,  # This is required for min_new_tokens to actually have an effect.\n    )\n    if getattr(model, \"generation_config\", None) is not None:\n        model.generation_config.eos_token_id = None  # greedy_search falls back on this eos_token_id that we need to set to None as well for min_new_tokens to have an effect.\n\n    synchronize(device)\n    if device.type == \"cuda\":\n        torch.cuda.reset_peak_memory_stats()\n    elif device.type == \"xpu\":\n        torch.xpu.reset_peak_memory_stats()\n\n    memory = get_device_memory(device)\n    if memory is not None:\n        print(f\"Device memory: {memory / (2**30):.4f} GB\")\n\n    latencies = []\n    input_ids = torch.randint(1, model.config.vocab_size - 1, size=(batch_size, prompt_length)).to(device)\n    masks = torch.ones(batch_size, prompt_length, dtype=torch.int32).to(device)\n\n    for _ in tqdm(range(iterations)):\n        start_event = timing_event(device)\n        end_event = timing_event(device)\n        synchronize(device)\n        start_event.record()\n\n        _ = model.generate(input_ids, attention_mask=masks, generation_config=generation_config)\n        end_event.record()\n        synchronize(device)\n\n        latency_ms = start_event.elapsed_time(end_event)\n        latencies.append(latency_ms)\n\n    if device.type == \"cuda\":\n        peak_memory = torch.cuda.max_memory_allocated()\n        print(f\"Peak memory during benchmark: {peak_memory / (2**30):.4f} GB\")\n    elif device.type == \"xpu\":\n        peak_memory = torch.xpu.max_memory_allocated()\n        print(f\"Peak memory during benchmark: {peak_memory / (2**30):.4f} GB\")\n\n    mean_latency = np.mean(latencies) / generation_config.min_new_tokens\n    print(f\"Average latency per token: {mean_latency} ms\")\n    return mean_latency\n\n\ndef 
get_device_memory(device):\n    gc.collect()\n    if device.type == \"cuda\":\n        torch.cuda.empty_cache()\n        return torch.cuda.memory_allocated()\n    elif device.type == \"mps\":\n        torch.mps.empty_cache()\n        return torch.mps.current_allocated_memory()\n    elif device.type == \"xpu\":\n        torch.xpu.empty_cache()\n        return torch.xpu.memory_allocated()\n    return None\n"
  },
  {
    "path": "bench/generation/metrics/perplexity.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport sys\n\nimport numpy as np\nimport torch\nfrom datasets import load_dataset\nfrom tqdm import tqdm\n\n\nclass Perplexity:\n    \"\"\"\n    A class for calculating the perplexity of a language model.\n    \"\"\"\n\n    def __init__(self, model, tokenizer, dataset_path=\"wikitext\", dataset_name=None, split=\"test\", text_column=\"text\"):\n        \"\"\"\n        Calculate perplexity using the same method as seen in llama.cpp.\n\n        Parameters\n        ----------\n        model : AutoModelForCausalLM\n            The language model for which the perplexity is calculated.\n        tokenizer : AutoTokenizer\n            The tokenizer corresponding to the model.\n        dataset_path : str, optional\n            The path to the dataset on the Hugging Face dataset hub. Default is 'wikitext'.\n        dataset_name : str, optional\n            The name of the dataset. Default is None.\n        split : str, optional\n            The split of the dataset to use. Default is 'test'.\n        text_column : str, optional\n            The name of the column in the dataset that contains the text data. 
Default is 'text'.\n        \"\"\"\n        self._model = model\n        self._tokenizer = tokenizer\n        self._dataset_path = dataset_path\n        self._dataset_name = dataset_name\n        self._split = split\n        self._text_column = text_column\n        self._text = self._prepare_data()\n\n    def _prepare_data(self):\n        \"\"\"\n        Prepares the dataset by loading and formatting.\n\n        Returns\n        -------\n        str\n            The formatted dataset as a single string.\n        \"\"\"\n        if self._dataset_path == \"wikitext\":\n            self._dataset_name = \"wikitext-2-raw-v1\"\n\n        # Load the dataset\n        data = load_dataset(self._dataset_path, self._dataset_name, split=self._split)\n        # Format the text column of the dataset\n        text_list = [\" \\n\" if s == \"\" else s for s in data[self._text_column]]\n        return \"\".join(text_list)\n\n    @staticmethod\n    def softmax(logits):\n        \"\"\"\n        Static method for applying the softmax function.\n\n        Parameters\n        ----------\n        logits : np.ndarray\n            The input to the softmax function.\n\n        Returns\n        -------\n        np.ndarray\n            The output of the softmax function.\n        \"\"\"\n        e_x = np.exp(logits - np.max(logits))\n        return e_x / e_x.sum(axis=0)\n\n    def calculate_perplexity(self, n_ctx=512, n_batch=512):\n        \"\"\"\n        Calculates the perplexity of the language model.\n\n        Parameters\n        ----------\n        n_ctx : int\n            The context size.\n        n_batch : int\n            The batch size.\n\n        Returns\n        -------\n        list\n            The list of perplexity scores calculated.\n        \"\"\"\n        # Tokenize the text\n        self._tokenizer.model_max_length = sys.maxsize\n        tokens = self._tokenizer(self._text, truncation=False, return_tensors=\"pt\").input_ids.to(self._model.device)\n\n        nll = 0.0  # 
Negative log likelihood\n        count = 0  # Counter for processed tokens\n        curr_ppl = 0\n        all_perplexity = []\n\n        with tqdm(range(len(tokens[0]) // n_ctx), desc=\"Perplexity: - \") as progress:\n            for i in progress:\n                # Process each batch of tokens\n                nll, count = self._process_batch(i, n_ctx, n_batch, tokens, nll, count)\n\n                # Calculate and display the current perplexity\n                curr_ppl = np.exp(nll / count)\n                all_perplexity.append(curr_ppl)\n                progress.set_description(f\"Perplexity: {curr_ppl:.4f}\")\n\n        return all_perplexity\n\n    def _process_batch(self, i, n_ctx, n_batch, tokens, nll, count):\n        \"\"\"\n        Processes each batch of tokens.\n\n        Parameters\n        ----------\n        i : int\n            The batch index.\n        n_ctx : int\n            The context size.\n        n_batch : int\n            The batch size.\n        tokens : torch.Tensor\n            The tokenized text.\n        nll : float\n            The current negative log likelihood.\n        count : int\n            The current count of processed tokens.\n\n        Returns\n        -------\n        float\n            The updated negative log likelihood.\n        int\n            The updated count of processed tokens.\n        \"\"\"\n        start = i * n_ctx\n        end = start + n_ctx\n\n        num_batches = (n_ctx + n_batch - 1) // n_batch\n\n        logits = []\n\n        for j in range(num_batches):\n            batch_start = start + j * n_batch\n            batch_size = min(end - batch_start, n_batch)\n\n            token_org = tokens[0][batch_start].item()\n\n            if j == 0:\n                # Replace the first token with the BOS token\n                tokens[0][batch_start] = self._tokenizer.bos_token_id\n\n            # Compute the logits for the current batch of tokens\n            batch_logits = self._compute_batch_logits(tokens, 
batch_start, batch_size)\n\n            tokens[0][batch_start] = token_org\n\n            logits.append(batch_logits)\n\n        # We rely on the fact that attention in the forward pass only looks at previous\n        # tokens here, so the logits returned for each token are an accurate representation\n        # of what the model would have predicted at that point.\n        #\n        # Example, we have a context window of 512, we will compute perplexity for each of the\n        # last 256 tokens.  Then, we split the input up into context window size chunks to\n        # process the entire prompt.\n\n        for j in range(min(512, n_ctx // 2), n_ctx - 1):\n            tok_logits = logits[0][0][j].cpu().numpy()\n            # Compute the probability of the next token\n            prob = self.softmax(tok_logits)[tokens[0][start + j + 1]]\n\n            # Update the negative log likelihood and the count of processed tokens\n            nll += -np.log(prob, where=prob > 0)\n            count += 1\n\n        return nll, count\n\n    def _compute_batch_logits(self, tokens, batch_start, batch_size):\n        \"\"\"\n        Computes the logits for a batch of tokens.\n\n        Parameters\n        ----------\n        tokens : torch.Tensor\n            The tokenized text.\n        batch_start : int\n            The start index of the batch.\n        batch_size : int\n            The size of the batch.\n\n        Returns\n        -------\n        torch.Tensor\n            The logits for the batch of tokens.\n        \"\"\"\n        # Compute the logits without keeping track of gradients\n        with torch.no_grad():\n            outputs = self._model(tokens[:, batch_start : batch_start + batch_size])\n        return outputs.logits.detach()\n\n\ndef perplexity(\n    model,\n    tokenizer,\n    stride: int = 512,\n):\n    print(\"Evaluating perplexity\")\n    ppl = Perplexity(model, tokenizer)\n    ppl_value = np.mean(ppl.calculate_perplexity(n_ctx=stride))\n    return 
ppl_value\n"
  },
  {
    "path": "bench/generation/metrics/prediction.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport time\n\nimport torch\nfrom datasets import load_dataset\n\n\n@torch.no_grad()\ndef prediction_accuracy(model, tokenizer, batch_size, samples=None):\n    test_dataset = load_dataset(\"lambada\", split=[\"test\"])[0]\n    model.eval()\n    # The task is to predict the last token of the input.\n    total, hit = 0, 0\n    start = time.time()\n    for batch in test_dataset.iter(batch_size=batch_size):\n        inputs = tokenizer(batch[\"text\"], return_tensors=\"pt\", padding=True)\n        input_ids = inputs.input_ids.to(model.device)\n        attention_mask = inputs.attention_mask.to(model.device)\n        labels = input_ids[:, -1]\n        # Pass only the first tokens\n        outputs = model(input_ids[:, :-1], attention_mask=attention_mask[:, :-1])\n        preds = outputs.logits[:, -1, :].argmax(dim=-1)\n        total += labels.size(0)\n        hit += (preds == labels).sum().item()\n        if samples is not None and total >= samples:\n            break\n    end = time.time()\n    acc = hit / total\n    print(f\"{total} sequences evaluated in {end - start:.2f} s. accuracy = {acc:.2f}\")\n    return acc\n"
  },
  {
    "path": "bench/generation/setup/__init__.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n"
  },
  {
    "path": "bench/generation/setup/awq.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom awq import AutoAWQForCausalLM\nfrom transformers import AutoTokenizer\n\n\ndef prepare_inputs_for_generation(input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):\n    if past_key_values is not None:\n        cache_length = past_length = past_key_values[0][0].shape[2]\n        max_cache_length = None\n\n        # Keep only the unprocessed tokens:\n        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where\n        # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as\n        # input)\n        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:\n            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]\n        # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. 
We can discard\n        # input_ids based on the past_length.\n        elif past_length < input_ids.shape[1]:\n            input_ids = input_ids[:, past_length:]\n        # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.\n\n        # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.\n        if (\n            max_cache_length is not None\n            and attention_mask is not None\n            and cache_length + input_ids.shape[1] > max_cache_length\n        ):\n            attention_mask = attention_mask[:, -max_cache_length:]\n\n    position_ids = kwargs.get(\"position_ids\", None)\n    if attention_mask is not None and position_ids is None:\n        # create position_ids on the fly for batch generation\n        position_ids = attention_mask.long().cumsum(-1) - 1\n        position_ids.masked_fill_(attention_mask == 0, 1)\n        if past_key_values:\n            position_ids = position_ids[:, -input_ids.shape[1] :]\n\n    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step\n    if inputs_embeds is not None and past_key_values is None:\n        model_inputs = {\"inputs_embeds\": inputs_embeds}\n    else:\n        model_inputs = {\"input_ids\": input_ids}\n\n    model_inputs.update(\n        {\n            \"position_ids\": position_ids,\n            \"past_key_values\": past_key_values,\n            \"use_cache\": kwargs.get(\"use_cache\"),\n            \"attention_mask\": attention_mask,\n        }\n    )\n    return model_inputs\n\n\ndef setup(model_id: str, weights: str, activations: str, group_size: int = 64, version=\"GEMV_FAST\"):\n    if activations != \"none\":\n        raise ValueError(\"Activation quantization is not supported by AWQ\")\n    if weights != \"int4\":\n        raise ValueError(\"AWQ only supports int4 weights.\")\n    quant_config = {\"zero_point\": True, \"q_group_size\": group_size, \"w_bit\": 4, 
\"version\": version}\n    # Load model\n    model = AutoAWQForCausalLM.from_pretrained(model_id)\n    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)\n    tokenizer.pad_token_id = tokenizer.eos_token_id\n    tokenizer.padding_side = \"left\"\n    # Quantize\n    model.quantize(tokenizer, quant_config=quant_config)\n    # We need to save otherwise it doesn't work\n    quant_path = model_id.replace(\"/\", \"-\") + f\"_{group_size}_{version}\"\n    model.save_quantized(quant_path)\n    # Reload model\n    model = AutoAWQForCausalLM.from_quantized(quant_path)\n    # Hack: force transformers 4.36.2 behaviour\n    model.model.prepare_inputs_for_generation = prepare_inputs_for_generation\n    # Hack because AWQ models are not transformers models\n    model.device = next(model.parameters()).device\n    return model, tokenizer\n"
  },
  {
    "path": "bench/generation/setup/bnb.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig\n\n\ndef setup(\n    model_id: str,\n    weights: str,\n    activations: str,\n    device: torch.device,\n):\n    if activations != \"none\":\n        raise ValueError(\"Activation quantization is not supported by BitsAndBytes\")\n    if weights == \"int4\":\n        quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type=\"fp4\")\n    elif weights == \"int8\":\n        quantization_config = BitsAndBytesConfig(load_in_8bit=True)\n    else:\n        raise ValueError(\"BitsAndBytes only supports int4 and int8 weights.\")\n    dtype = torch.float32 if device.type == \"cpu\" else torch.float16\n    tokenizer = AutoTokenizer.from_pretrained(model_id)\n    tokenizer.pad_token_id = tokenizer.eos_token_id\n    tokenizer.padding_side = \"left\"\n    quantization_config.bnb_4bit_compute_dtype = dtype\n    model = AutoModelForCausalLM.from_pretrained(\n        model_id, torch_dtype=dtype, low_cpu_mem_usage=True, quantization_config=quantization_config\n    )\n\n    return model, tokenizer\n"
  },
  {
    "path": "bench/generation/setup/hqq.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom hqq.core.quantize import BaseQuantizeConfig\nfrom hqq.engine.hf import HQQModelForCausalLM\nfrom transformers import AutoTokenizer\n\n\ndef setup(model_id: str, weights: str, activations: str, device: torch.device, group_size: int = 64):\n    if activations != \"none\":\n        raise ValueError(\"Activation quantization is not supported by HQQ\")\n    if weights == \"int4\":\n        quant_config = BaseQuantizeConfig(nbits=4, group_size=group_size)\n    elif weights == \"int8\":\n        quant_config = BaseQuantizeConfig(nbits=8, group_size=group_size)\n    else:\n        raise ValueError(\"HQQ only supports int4 and int8 weights.\")\n    # Load model\n    model = HQQModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)\n    # Quantize\n    model.quantize_model(quant_config=quant_config, compute_dtype=torch.float16, device=device)\n    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)\n    tokenizer.pad_token_id = tokenizer.eos_token_id\n    tokenizer.padding_side = \"left\"\n    return model, tokenizer\n"
  },
  {
    "path": "bench/generation/setup/quanto.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport time\n\nimport torch\nfrom datasets import load_dataset\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\n\nfrom optimum.quanto import Calibration, freeze, qfloat8, qint4, qint8, quantize\n\n\n@torch.no_grad()\ndef calibrate(model, tokenizer, batch_size, batches):\n    samples = batch_size * batches\n    cal_dataset = load_dataset(\"lambada\", split=[\"validation\"])[0]\n    model.eval()\n    total = 0\n    for batch in cal_dataset.iter(batch_size=batch_size):\n        inputs = tokenizer(batch[\"text\"], return_tensors=\"pt\", padding=True)\n        input_ids = inputs.input_ids.to(model.device)\n        attention_mask = inputs.attention_mask.to(model.device)\n        model(input_ids, attention_mask=attention_mask)\n        total += input_ids.size(0)\n        if total >= samples:\n            break\n\n\ndef setup(\n    model_id: str,\n    weights: str,\n    activations: str,\n    batch_size: int,\n    device: torch.device,\n    dtype: torch.dtype,\n):\n    weights = keyword_to_qtype(weights)\n    activations = keyword_to_qtype(activations)\n    tokenizer = AutoTokenizer.from_pretrained(model_id)\n    tokenizer.pad_token_id = tokenizer.eos_token_id\n    tokenizer.padding_side = \"left\"\n    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, low_cpu_mem_usage=True).to(device)\n    if weights is 
not None or activations is not None:\n        print(\"Quantizing\")\n        start = time.time()\n        quantization_root = model\n        if hasattr(model, \"model\"):\n            quantization_root = model.model\n        quantize(quantization_root, weights=weights, activations=activations)\n        if activations is not None:\n            print(\"Calibrating\")\n            with Calibration():\n                calibrate(model, tokenizer, batch_size, batches=4)\n        print(\"Freezing\")\n        freeze(model)\n        print(f\"Finished: {time.time() - start:.2f}\")\n    return model, tokenizer\n\n\ndef keyword_to_qtype(k):\n    return {\n        \"none\": None,\n        \"int4\": qint4,\n        \"int8\": qint8,\n        \"float8\": qfloat8,\n    }[k]\n"
  },
  {
    "path": "bench/kernels/benchmark.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport time\nfrom contextlib import nullcontext\n\nimport numpy as np\nimport torch\nfrom tqdm.auto import tqdm\n\nfrom optimum.quanto.library import disable_extensions\n\n\ndef get_unpack_bench(bits, device):\n    qmax = 2**bits\n    a = torch.randint(0, qmax, [10240, 10240], dtype=torch.uint8).to(device)\n\n    def bench_fn():\n        return torch.ops.quanto.unpack(a, bits)\n\n    return bench_fn\n\n\ndef timing(get_bench_func, device, iterations=10):\n    def synchronize(device):\n        if device.type == \"cuda\":\n            torch.cuda.synchronize()\n        elif device.type == \"mps\":\n            torch.mps.synchronize()\n        elif device.type == \"xpu\":\n            torch.xpu.synchronize()\n        else:\n            torch.cpu.synchronize()\n\n    def timing_event(device):\n        if device.type == \"cuda\":\n            return torch.cuda.Event(enable_timing=True)\n        elif device.type == \"mps\":\n            return torch.mps.Event(enable_timing=True)\n        elif device.type == \"xpu\":\n            return torch.xpu.Event(enable_timing=True)\n\n        class CPUEvent:\n            def __init__(self):\n                self.time = None\n\n            def record(self):\n                self.time = time.time()\n\n            def elapsed_time(self, other):\n                assert self.time is not 
None\n                assert other.time is not None\n                return (other.time - self.time) * 1000\n\n        return CPUEvent()\n\n    synchronize(device)\n\n    bench_func = get_bench_func(device)\n    # Warmup to load library\n    bench_func()\n    latencies = np.empty((iterations, 2))\n    for i in tqdm(range(iterations)):\n        for j, context in enumerate([disable_extensions(), nullcontext()]):\n            start_event = timing_event(device)\n            end_event = timing_event(device)\n            synchronize(device)\n            start_event.record()\n            with context:\n                bench_func()\n            end_event.record()\n            synchronize(device)\n            latencies[i, j] = start_event.elapsed_time(end_event)\n    return np.mean(latencies[:, 0]), np.mean(latencies[:, 1])\n\n\nGET_BENCH_FUNCTIONS = {\n    \"unpack_2bit\": lambda device: get_unpack_bench(2, device),\n    \"unpack_4bit\": lambda device: get_unpack_bench(4, device),\n}\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Kernel benchmark\")\n    parser.add_argument(\"--kernel\", type=str, default=None, help=\"The kernel to benchmark. 
None to test all of them\")\n    parser.add_argument(\"--device\", type=str, default=None, help=\"The device to use for benchmark.\")\n    parser.add_argument(\"--it\", type=int, default=10, help=\"The number of benchmark iterations\")\n    args = parser.parse_args()\n    if args.device is None:\n        if torch.cuda.is_available():\n            device = torch.device(\"cuda\")\n        elif torch.backends.mps.is_available():\n            device = torch.device(\"mps\")\n        elif torch.xpu.is_available():\n            device = torch.device(\"xpu\")\n        else:\n            device = torch.device(\"cpu\")\n    else:\n        device = torch.device(args.device)\n    all_kernels = GET_BENCH_FUNCTIONS.keys()\n    kernels = all_kernels if args.kernel is None else [args.kernel]\n    for kernel in kernels:\n        get_bench_fn = GET_BENCH_FUNCTIONS[kernel]\n        python_ms, ext_ms = timing(get_bench_fn, device, iterations=args.it)\n        ratio = python_ms / ext_ms\n        print(f\"\\n{kernel}[{device.type}]: python = {python_ms:.3f} ms, ext = {ext_ms:.3f} ms, ratio = {ratio:.1f}x\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "bench/kernels/benchmark_marlin_fp8.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport argparse\nfrom typing import Optional\n\nimport numpy as np\nimport torch\n\nfrom optimum.quanto.tensor.weights.marlin.packed import pack_fp8_as_int32\n\n\nM_SHAPES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]\nN_SHAPES = [4096]\nK_SHAPES = [4096]\n\n\ndef run_benchmark(\n    m: Optional[int],\n    n: Optional[int],\n    k: Optional[int],\n    n_runs: int,\n    n_warmup: int,\n    dtype: torch.dtype = torch.float16,\n):\n    print(f\"\\n----------- m={m}, n={n}, k={k}\")\n    n_tokens = m\n    in_features = k\n    out_features = n\n\n    assert m is not None\n\n    device = torch.device(\"cuda\")\n    inputs = torch.rand(n_tokens, in_features, dtype=dtype, device=device)\n\n    other_shape = (in_features, out_features)\n    other_data = torch.rand(other_shape, dtype=dtype, device=device).to(torch.float8_e4m3fn)\n    other_data_int32 = pack_fp8_as_int32(other_data)\n    perm = torch.empty(0, dtype=torch.int, device=device)\n\n    other_data_repack = torch.ops.quanto.gptq_marlin_repack(\n        b_q_weight=other_data_int32, perm=perm, size_k=in_features, size_n=out_features, num_bits=8\n    )\n    other_scale = torch.rand(1, dtype=dtype, device=device)\n    other_scale = other_scale.repeat(1, out_features)\n\n    workspace = torch.zeros(out_features // 64 * 16, dtype=torch.int, device=device)\n\n    
latencies_marlin_fp8 = []\n    latencies_torch = []\n    with torch.no_grad():\n        for i in range(n_runs):\n            start_event = torch.cuda.Event(enable_timing=True)\n            end_event = torch.cuda.Event(enable_timing=True)\n            torch.cuda.synchronize(device)\n            start_event.record()\n\n            _ = torch.ops.quanto.fp8_marlin_gemm(\n                a=inputs,\n                b_q_weight=other_data_repack,\n                b_scales=other_scale,\n                workspace=workspace,\n                num_bits=8,\n                size_m=n_tokens,\n                size_n=out_features,\n                size_k=in_features,\n            )\n            end_event.record()\n            torch.cuda.synchronize(device)\n\n            latency_ms = start_event.elapsed_time(end_event)\n            if i >= n_warmup:\n                latencies_marlin_fp8.append(latency_ms)\n\n            start_event = torch.cuda.Event(enable_timing=True)\n            end_event = torch.cuda.Event(enable_timing=True)\n            torch.cuda.synchronize(device)\n            start_event.record()\n            other = other_data.to(dtype) * other_scale\n            _ = torch.matmul(inputs, other)\n            end_event.record()\n            torch.cuda.synchronize(device)\n\n            latency_ms = start_event.elapsed_time(end_event)\n            if i >= n_warmup:\n                latencies_torch.append(latency_ms)\n\n    mean_latency_torch = np.mean(latencies_torch)\n    mean_latency_marlin_fp8 = np.mean(latencies_marlin_fp8)\n    print(\"mean_latency_torch:\", mean_latency_torch)\n    print(\"mean_latency_marlin_fp8:\", mean_latency_marlin_fp8)\n\n    return mean_latency_torch, mean_latency_marlin_fp8\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Marlin FP8 kernel benchmark\")\n    parser.add_argument(\"--nruns\", type=int, default=20, help=\"The number of benchmark iterations\")\n    parser.add_argument(\"--nwarmup\", type=int, 
default=2, help=\"The number of warmup iterations (deducted from nruns)\")\n    parser.add_argument(\n        \"--m\",\n        type=int,\n        help=\"m dimension of A=m*k\",\n        default=None,\n    )\n    parser.add_argument(\n        \"--n\",\n        type=int,\n        help=\"n dimension of B=k*n (out_features)\",\n        default=None,\n    )\n    parser.add_argument(\n        \"--k\",\n        type=int,\n        help=\"k dimension of A=m*k and B=k*n (in_features), hidden_size\",\n        default=None,\n    )\n    args = parser.parse_args()\n\n    if args.m is not None:\n\n        def shape_generator():\n            yield (args.m, args.n, args.k)\n\n    else:\n\n        def shape_generator():\n            for m in M_SHAPES:\n                for n in N_SHAPES:\n                    for k in K_SHAPES:\n                        yield (m, n, k)\n\n    result = \"m,n_out,k_in,torch_latency_ms,marlin_fp8_latency_ms\\n\"\n    for m, n, k in shape_generator():\n        mean_latency_torch, mean_latency_marlin_fp8 = run_benchmark(m, n, k, args.nruns, args.nwarmup)\n\n        result += (\n            \",\".join(\n                [\n                    str(m),\n                    str(n),\n                    str(k),\n                    f\"{mean_latency_torch:.4f}\",\n                    f\"{mean_latency_marlin_fp8:.4f}\",\n                ]\n            )\n            + \"\\n\"\n        )\n\n    print(\"\\nResults:\")\n    print(result)\n"
  },
  {
    "path": "bench/kernels/benchmark_w4a16.py",
    "content": "# From: https://github.com/IST-DASLab/marlin/blob/master/bench.py\nimport argparse\nimport time\n\nimport torch\n\nfrom optimum.quanto.tensor.weights.awq import AWQPackedTensor, AWQPacking\nfrom optimum.quanto.tensor.weights.marlin import marlin_permute\nfrom optimum.quanto.tensor.weights.marlin.int4 import MarlinInt4PackedTensor\n\n\ndef benchmark(f, warmup=1, iter=10):\n    for i in range(warmup + iter):\n        f()\n        # We do not synchronize here in order to hide the kernel launch overhead during benchmarkining as this will also\n        # happen during realistic model inference as many launches are submitted to the kernel queue.\n        if i == warmup - 1:\n            torch.cuda.synchronize()\n            tick = time.time()\n    torch.cuda.synchronize()\n    res = (time.time() - tick) / iter\n    # Make sure there is enough to \"cool down\" the GPU in between benchmarks to avoid throttling for later runs when\n    # we execute many benchmarks consecutively\n    time.sleep(1.0)\n    return res\n\n\ndef get_problem(m, n, k, groupsize=128):\n    dev = torch.device(\"cuda:0\")\n    A = torch.rand((m, k), dtype=torch.half, device=dev)\n    B_4bit = torch.randint(0, 2**4, (n, k), dtype=torch.uint8, device=dev)\n    B_awq = AWQPackedTensor.pack(B_4bit, packing=AWQPacking.V2)._data\n    B_marlin = MarlinInt4PackedTensor.pack(B_4bit)._data\n    B_ref = torch.rand((k, n), dtype=torch.half, device=dev)\n    s = torch.rand((k // groupsize, n), dtype=torch.half, device=dev) / 2**4\n    s_marlin = marlin_permute(s)\n    z = torch.randint(-(2 ** (4 - 1)), 2 ** (4 - 1), (k // groupsize, n), dtype=torch.int8, device=dev)\n    sz = -z * s\n    sz_marlin = marlin_permute(sz)\n    torch.cuda.synchronize()\n    return A, B_ref, B_awq, B_marlin, s, s_marlin, sz, sz_marlin\n\n\ndef benchmark_dense(A, B, m, n, k):\n    res = benchmark(lambda: torch.matmul(A, B))\n    return {\n        \"s\": res,\n        \"TFLOP/s\": 2 * A.numel() * n / res / 10**12,\n        
\"GB/s\": (2 * A.numel() + 2 * B.numel() + 2 * (m * n)) / res / 10**9,\n    }\n\n\ndef benchmark_awq(A, B, s, sz, m, n, k):\n    res = benchmark(\n        lambda: torch.ops.quanto.gemm_f16i4_awq(A, B, s, sz, rows=m, out_cols=n, in_cols=k, bits=4, group_size=128)\n    )\n    return {\n        \"s\": res,\n        \"TFLOP/s\": 2 * (m * k) * n / res / 10**12,\n        \"GB/s\": (2 * A.numel() + 2 * B.numel() + 2 * (m * n) + 2 * s.numel() + 2 * sz.numel()) / res / 10**9,\n    }\n\n\ndef benchmark_marlin(A, B, s, sz, m, n, k):\n    workspace = torch.zeros(n // 128 * 16, dtype=torch.int, device=torch.device(\"cuda:0\"))\n    res = benchmark(lambda: torch.ops.quanto.gemm_f16i4_marlin(A, B, s, sz, workspace))\n    return {\n        \"s\": res,\n        \"TFLOP/s\": 2 * (m * k) * n / res / 10**12,\n        \"GB/s\": (2 * A.numel() + 4 * B.numel() + 2 * (m * n) + 2 * s.numel() + 2 * sz.numel()) / res / 10**9,\n    }\n\n\nMODELS = {\n    \"Llama7B\": [(4096, 3 * 4096), (4096, 4096), (4096, 2 * 10752), (10752, 4096)],\n    \"Llama13B\": [(5120, 3 * 5120), (5120, 5120), (5120, 2 * 13568), (13568, 5120)],\n    \"Llama33B\": [(6656, 3 * 6656), (6656, 6656), (6656, 2 * 17664), (17664, 6656)],\n    \"Llama65B\": [(8192, 3 * 8192), (8192, 8192), (8192, 2 * 21760), (21760, 8192)],\n    \"Falcon180B\": [\n        # Note that parallel attention and FC allows layer fusions\n        (14848, 14848 * 5 + 1024),\n        (14848 * 5, 14848),\n    ],\n}\n\n\ndef run_benchmark(model, tokens=None):\n    if tokens is None:\n        tokens = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]\n    elif not isinstance(tokens, (list, tuple)):\n        tokens = [tokens]\n    groupsize = 128\n    layers = MODELS[model]\n    print(model)\n    for m in tokens:\n        tot_awq = {\"s\": 0, \"TFLOP/s\": 0, \"GB/s\": 0, \"speedup\": 0}\n        tot_marlin = {\"s\": 0, \"TFLOP/s\": 0, \"GB/s\": 0, \"speedup\": 0}\n        for layer in layers:\n            k, n = layer\n            A, B_ref, B_awq, 
B_marlin, s, s_marlin, sz, sz_marlin = get_problem(m, n, k, groupsize)\n            res_d = benchmark_dense(A, B_ref, m, n, k)\n            res_awq = benchmark_awq(A, B_awq, s, sz, m, n, k)\n            res_awq[\"speedup\"] = res_d[\"s\"] / res_awq[\"s\"]\n            tot_awq[\"s\"] += res_awq[\"s\"]\n            for key in tot_awq:\n                if key != \"s\":\n                    tot_awq[key] += res_awq[key] * res_awq[\"s\"]\n            res_marlin = benchmark_marlin(A, B_marlin, s_marlin, sz_marlin, m, n, k)\n            res_marlin[\"speedup\"] = res_d[\"s\"] / res_marlin[\"s\"]\n            tot_marlin[\"s\"] += res_marlin[\"s\"]\n            for key in tot_marlin:\n                if key != \"s\":\n                    tot_marlin[key] += res_marlin[key] * res_marlin[\"s\"]\n        for key in tot_awq:\n            if key != \"s\":\n                tot_awq[key] /= tot_awq[\"s\"]\n        for key in tot_marlin:\n            if key != \"s\":\n                tot_marlin[key] /= tot_marlin[\"s\"]\n        print(\n            \"AWQ, tokens=%04d: s=%.5f, TFLOP/s=%07.3f, GB/s=%08.3f, speedup=%.2f\"\n            % (m, tot_awq[\"s\"], tot_awq[\"TFLOP/s\"], tot_awq[\"GB/s\"], tot_awq[\"speedup\"])\n        )\n        print(\n            \"Marlin, batch=%04d: s=%.5f, TFLOP/s=%07.3f, GB/s=%08.3f, speedup=%.2f\"\n            % (m, tot_marlin[\"s\"], tot_marlin[\"TFLOP/s\"], tot_marlin[\"GB/s\"], tot_marlin[\"speedup\"])\n        )\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"W4A16 Matrix Multiplication Kernel benchmark\")\n    parser.add_argument(\n        \"--model\", type=str, default=None, help=\"The model configuration to benchmark. None to test all of them.\"\n    )\n    parser.add_argument(\n        \"--tokens\",\n        type=int,\n        default=None,\n        help=\"The numbers of input tokens used to benchmark. 
None to test a predefined range.\",\n    )\n    args = parser.parse_args()\n    models = MODELS if args.model is None else [args.model]\n    for model in models:\n        run_benchmark(model, args.tokens)\n        print()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "bench/torch_kernels/README.md",
    "content": "This contains a few scripts to test pytorch kernels that are relevant for quantization.\n"
  },
  {
    "path": "bench/torch_kernels/test_int_mm.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport timeit\n\nimport torch\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Torch integer matmul benchmark\")\n    parser.add_argument(\"--seed\", type=int, default=1, metavar=\"S\", help=\"random seed (default: 1)\")\n    parser.add_argument(\"--device\", type=str, default=None, help=\"The device to use for the test.\")\n    parser.add_argument(\"--it\", type=int, default=100, help=\"Number of iterations for average\")\n    args = parser.parse_args()\n\n    torch.manual_seed(args.seed)\n\n    if args.device is None:\n        if torch.cuda.is_available():\n            device = torch.device(\"cuda\")\n        elif torch.backends.mps.is_available():\n            device = torch.device(\"mps\")\n        elif torch.xpu.is_available():\n            device = torch.device(\"xpu\")\n        else:\n            device = torch.device(\"cpu\")\n    else:\n        device = torch.device(args.device)\n\n    def avg_time(f, it):\n        return timeit.Timer(f).timeit(it) / it\n\n    # Resstrictions for accelerated integer matmul:\n    # - input matrices must be 2D\n    # - the collapsing dimension must be a multiple of 8\n    A = torch.randint(1, 10, [2400, 3200]).type(torch.int8).to(device)\n    B = torch.randint(1, 10, [3200, 4800]).type(torch.int8).to(device)\n\n    print(f\"Evaluating integer matmul on 
{device.type}:\")\n    # Warmup (slow)\n    torch._int_mm(A, B)\n    # Average on several calls\n    t = avg_time(lambda: torch._int_mm(A, B), args.it) * 1000\n    print(f\"Average inference on {args.it} iterations: {t:.4f} ms\")\n\n    # Convert inputs to float\n\n    def to_float(x):\n        if x.device.type == (\"cpu\"):\n            # matrix multiplication is not supported for float16 on CPU\n            return x.to(torch.float32)\n        return x.to(torch.float16)\n\n    A = to_float(A)\n    B = to_float(B)\n    print(f\"Evaluating {A.dtype} matmul on {device.type}:\")\n\n    # Warmup (slow)\n    torch.matmul(A, B)\n    # Average on several calls\n    t = avg_time(lambda: torch.matmul(A, B), args.it) * 1000\n    print(f\"Average inference on {args.it} iterations: {t:.4f} ms\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "bench/torch_kernels/test_int_mm_inductor.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport timeit\n\nimport torch\n\n\ndef mm(a, b):\n    return torch._int_mm(a, b)\n\n\nA = torch.randint(1, 10, [2400, 2400]).type(torch.int8).cuda()\nB = torch.randint(1, 10, [2400, 2400]).type(torch.int8).cuda()\nit = 100\n\n# Warmup (slow)\nmm(A, B)\n# Get a reference\nprint(timeit.Timer(lambda: mm(A, B)).timeit(it) / it)\n\ncmm = torch.compile(mm, backend=\"inductor\")\n# First invocation will trigger the actual compilation\ncmm(A, B)\n# Now compare execution time\nprint(timeit.Timer(lambda: cmm(A, B)).timeit(it) / it)\n"
  },
  {
    "path": "bench/torch_kernels/test_weight_int4pack_mm.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport timeit\n\nimport torch\n\n\ndef _group_quantize_tensor(w, n_bit=4, q_group_size=16):\n    assert w.dim() == 2\n    w = w.transpose(0, 1).contiguous()\n    assert q_group_size > 1\n    assert w.shape[-1] % q_group_size == 0\n\n    to_quant = w.reshape(-1, q_group_size)\n    assert torch.isnan(to_quant).sum() == 0\n\n    max_val = to_quant.amax(dim=1, keepdim=True)\n    min_val = to_quant.amin(dim=1, keepdim=True)\n    max_int = 2**n_bit - 1\n    min_int = 0\n    scales = (max_val - min_val).clamp(min=1e-6) / max_int\n    assert torch.isnan(scales).sum() == 0\n\n    zeros = min_val + scales * (2 ** (n_bit - 1))\n    assert torch.isnan(zeros).sum() == 0\n\n    out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int)\n    assert torch.isnan(out).sum() == 0\n\n    out = out.to(dtype=torch.int32).reshape(w.shape)\n\n    # Scales and zeros for the same q-group should be contiguous, so we can\n    # load as a 32-bit word\n    scales = scales.view(w.shape[0], -1)\n    zeros = zeros.view(w.shape[0], -1)\n    scales_and_zeros = (\n        torch.cat(\n            [\n                scales.reshape(scales.size(0), scales.size(1), 1),\n                zeros.reshape(zeros.size(0), zeros.size(1), 1),\n            ],\n            2,\n        )\n        .transpose(0, 1)\n        .contiguous()\n    )\n\n    return 
out, scales_and_zeros\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Torch quantized int4 weight matmul benchmark\")\n    parser.add_argument(\"--seed\", type=int, default=1, metavar=\"S\", help=\"random seed (default: 1)\")\n    parser.add_argument(\"--dtype\", type=str, default=\"fp16\", choices=[\"fp16\", \"bf16\"], help=\"floating point type\")\n    parser.add_argument(\"--device\", type=str, default=None, help=\"The device to use for the test.\")\n    parser.add_argument(\"--it\", type=int, default=10, help=\"Number of iterations for average\")\n    args = parser.parse_args()\n\n    torch.manual_seed(args.seed)\n\n    if args.device is None:\n        if torch.cuda.is_available():\n            device = torch.device(\"cuda\")\n        elif torch.backends.mps.is_available():\n            device = torch.device(\"mps\")\n        elif torch.xpu.is_available():\n            device = torch.device(\"xpu\")\n        else:\n            device = torch.device(\"cpu\")\n    else:\n        device = torch.device(args.device)\n\n    def avg_time(f, it):\n        return timeit.Timer(f).timeit(it) / it\n\n    dtype = {\"fp16\": torch.float16, \"bf16\": torch.bfloat16}[args.dtype]\n\n    A = torch.rand([2400, 3200], dtype=dtype, device=device)\n    B = torch.rand([3200, 4800], dtype=dtype, device=device)\n    group_size = 128\n    B_int32, B_scale_and_zeros = _group_quantize_tensor(B, n_bit=4, q_group_size=group_size)\n    if device.type == \"cpu\":\n        B_packed = torch._convert_weight_to_int4pack_for_cpu(B_int32, innerKTiles=2)\n    else:\n        B_uint8 = (B_int32[::, ::2] << 4 | B_int32[::, 1::2]).to(torch.uint8)\n        B_packed = torch._convert_weight_to_int4pack(B_uint8, innerKTiles=2)\n\n    # Check quantized mm is close to float mm\n    if device.type == \"cpu\":\n        qout = torch._weight_int4pack_mm_for_cpu(A, B_packed, group_size, B_scale_and_zeros)\n    else:\n        qout = torch._weight_int4pack_mm(A, B_packed, group_size, 
B_scale_and_zeros)\n    out = torch.mm(A, B)\n\n    mean_err = ((qout - out).abs() / out.abs()).mean()\n    print(mean_err)\n\n    print(f\"Evaluating quantized int4 matmul on {device.type}:\")\n    # Warmup (slow)\n    if device.type == \"cpu\":\n        torch._weight_int4pack_mm_for_cpu(A, B_packed, group_size, B_scale_and_zeros)\n    else:\n        torch._weight_int4pack_mm(A, B_packed, group_size, B_scale_and_zeros)\n    # Average on several calls\n    if device.type == \"cpu\":\n        t = (\n            avg_time(lambda: torch._weight_int4pack_mm_for_cpu(A, B_packed, group_size, B_scale_and_zeros), args.it)\n            * 1000\n        )\n    else:\n        t = avg_time(lambda: torch._weight_int4pack_mm(A, B_packed, group_size, B_scale_and_zeros), args.it) * 1000\n    print(f\"Average inference on {args.it} iterations: {t:.4f} ms\")\n\n    print(f\"Evaluating {A.dtype} matmul on {device.type}:\")\n\n    # Warmup (slow)\n    torch.mm(A, B)\n    # Average on several calls\n    t = avg_time(lambda: torch.mm(A, B), args.it) * 1000\n    print(f\"Average inference on {args.it} iterations: {t:.4f} ms\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "bench/torch_kernels/test_weight_int8pack_mm.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport timeit\n\nimport torch\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Torch quantized int8 weight matmul benchmark\")\n    parser.add_argument(\"--seed\", type=int, default=1, metavar=\"S\", help=\"random seed (default: 1)\")\n    parser.add_argument(\"--device\", type=str, default=None, help=\"The device to use for the test.\")\n    parser.add_argument(\"--it\", type=int, default=10, help=\"Number of iterations for average\")\n    args = parser.parse_args()\n\n    torch.manual_seed(args.seed)\n\n    if args.device is None:\n        if torch.cuda.is_available():\n            device = torch.device(\"cuda\")\n        elif torch.backends.mps.is_available():\n            device = torch.device(\"mps\")\n        elif torch.xpu.is_available():\n            device = torch.device(\"xpu\")\n        else:\n            device = torch.device(\"cpu\")\n    else:\n        device = torch.device(args.device)\n\n    def avg_time(f, it):\n        return timeit.Timer(f).timeit(it) / it\n\n    A = torch.rand([2400, 3200], dtype=torch.bfloat16, device=device)\n    B = torch.randint(-128, 127, [4800, 3200], dtype=torch.int8, device=device)\n    B_scale = torch.rand([4800], dtype=torch.bfloat16, device=device)\n\n    print(f\"Evaluating quantized int8 matmul on {device.type}:\")\n    # Warmup (slow)\n    
torch._weight_int8pack_mm(A, B, B_scale)\n    # Average on several calls\n    t = avg_time(lambda: torch._weight_int8pack_mm(A, B, B_scale), args.it) * 1000\n    print(f\"Average inference on {args.it} iterations: {t:.4f} ms\")\n\n    # Convert weights to float\n\n    B = B.to(torch.bfloat16).t()\n    print(f\"Evaluating {A.dtype} matmul on {device.type}:\")\n\n    # Warmup (slow)\n    torch.matmul(A, B) * B_scale\n    # Average on several calls\n    t = avg_time(lambda: torch.matmul(A, B) * B_scale, args.it) * 1000\n    print(f\"Average inference on {args.it} iterations: {t:.4f} ms\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/nlp/text-classification/sst2/quantize_sst2_model.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport io\nimport time\n\nimport numpy as np\nimport torch\nfrom datasets import load_dataset\nfrom transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline\nfrom transformers.pipelines.pt_utils import KeyDataset\n\nfrom optimum.quanto import Calibration, freeze, qint4, qint8, quantize\n\n\ndef evaluate_model(model, tokenizer, dataset, device, batch_size):\n    p = pipeline(\"sentiment-analysis\", model, tokenizer=tokenizer, device=device)\n    results = p(KeyDataset(dataset, \"sentence\"), batch_size=batch_size)\n    start = time.time()\n    pred_labels = [0 if result[\"label\"] == \"NEGATIVE\" else 1 for result in results]\n    end = time.time()\n    accuracy = np.sum(np.equal(pred_labels, dataset[\"label\"])) / len(pred_labels)\n    print(f\"{len(pred_labels)} sentences evaluated in {end - start:.2f} s. 
accuracy = {accuracy}\")\n\n\ndef keyword_to_itype(k):\n    return {\"none\": None, \"int8\": qint8, \"int4\": qint4}[k]\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Transformers SST2 Example\")\n    parser.add_argument(\"--seed\", type=int, default=1, metavar=\"S\", help=\"random seed (default: 1)\")\n    parser.add_argument(\n        \"--model\",\n        type=str,\n        default=\"distilbert-base-uncased-finetuned-sst-2-english\",\n        help=\"The name of the trained Model.\",\n    )\n    parser.add_argument(\"--samples\", type=int, default=872, help=\"The number of sst2 samples to use for evaluation.\")\n    parser.add_argument(\"--batch_size\", type=int, default=100, help=\"The batch size to use for evaluation.\")\n    parser.add_argument(\"--weights\", type=str, default=\"int8\", choices=[\"int4\", \"int8\"])\n    parser.add_argument(\"--activations\", type=str, default=\"int8\", choices=[\"none\", \"int8\"])\n    parser.add_argument(\"--device\", type=str, default=None, help=\"The device to use for evaluation.\")\n    args = parser.parse_args()\n\n    torch.manual_seed(args.seed)\n\n    if args.device is None:\n        if torch.cuda.is_available():\n            device = torch.device(\"cuda\")\n        elif torch.backends.mps.is_available():\n            device = torch.device(\"mps\")\n        elif torch.xpu.is_available():\n            device = torch.device(\"xpu\")\n        else:\n            device = torch.device(\"cpu\")\n    else:\n        device = torch.device(args.device)\n\n    model = AutoModelForSequenceClassification.from_pretrained(args.model).to(device)\n    tokenizer = AutoTokenizer.from_pretrained(args.model)\n    dataset = load_dataset(\"sst2\", split=f\"validation[:{args.samples}]\")\n\n    print(\"Float model\")\n    evaluate_model(model, tokenizer, dataset, device, args.batch_size)\n    weights = keyword_to_itype(args.weights)\n    activations = keyword_to_itype(args.activations)\n    quantize(model, 
weights=weights, activations=activations)\n    if activations is not None:\n        print(\"Calibrating ...\")\n        with Calibration():\n            evaluate_model(model, tokenizer, dataset, device, args.batch_size)\n    freeze(model)\n    print(f\"Quantized model (w: {args.weights}, a: {args.activations})\")\n    evaluate_model(model, tokenizer, dataset, device, args.batch_size)\n    b = io.BytesIO()\n    torch.save(model.state_dict(), b)\n    b.seek(0)\n    state_dict = torch.load(b)\n    model_reloaded = AutoModelForSequenceClassification.from_pretrained(args.model).to(device)\n    quantize(model_reloaded, weights=weights, activations=activations)\n    model_reloaded.load_state_dict(state_dict)\n    print(\"Serialized quantized model\")\n    evaluate_model(model_reloaded, tokenizer, dataset, device, args.batch_size)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/nlp/text-generation/quantize_causal_lm_model.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport time\n\nimport torch\nfrom datasets import load_dataset\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\n\nfrom optimum.quanto import Calibration, QuantizedModelForCausalLM, qfloat8, qint4, qint8\n\n\n@torch.no_grad()\ndef generate(model, tokenizer, device, prompt, max_new_tokens):\n    inputs = tokenizer(prompt, return_tensors=\"pt\", padding=True)\n    start = time.time()\n    outputs = model.generate(\n        input_ids=inputs.input_ids.to(device),\n        max_new_tokens=max_new_tokens,\n        attention_mask=inputs.attention_mask.to(device),\n        do_sample=True,\n        top_k=50,\n        top_p=0.9,\n    )\n    end = time.time()\n    generated_text = tokenizer.decode(outputs[0])\n    print(f\"Generated '{generated_text}' in [{end - start:.2f} s]\")\n\n\n@torch.no_grad()\ndef calibrate(model, tokenizer, dataset, device, batch_size, samples=None):\n    model.eval()\n    total = 0\n    for batch in dataset.iter(batch_size=batch_size):\n        inputs = tokenizer(batch[\"text\"], return_tensors=\"pt\", padding=True)\n        input_ids = inputs.input_ids.to(device)\n        attention_mask = inputs.attention_mask.to(device)\n        model(input_ids, attention_mask=attention_mask)\n        total += input_ids.size(0)\n        if samples is not None and total >= samples:\n            
break\n\n\ndef keyword_to_itype(k):\n    return {\n        \"none\": None,\n        \"int4\": qint4,\n        \"int8\": qint8,\n        \"float8\": qfloat8,\n    }[k]\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Transformers Causal LM Example\")\n    parser.add_argument(\"--seed\", type=int, default=1, metavar=\"S\", help=\"random seed (default: 1)\")\n    parser.add_argument(\n        \"--model\",\n        type=str,\n        default=\"facebook/opt-350m\",\n        help=\"The name of the trained Model.\",\n    )\n    parser.add_argument(\"--prompt\", type=str, default=\"One of my fondest memory is\", help=\"The generation prompt.\")\n    parser.add_argument(\"--max_new_tokens\", type=int, default=20, help=\"The maximum number of tokens to generate.\")\n    parser.add_argument(\"--batch_size\", type=int, default=32, help=\"The batch_size for evaluation (and calibration).\")\n    parser.add_argument(\"--validation_batch\", type=int, default=4, help=\"The number of batch to use for calibration.\")\n    parser.add_argument(\n        \"--load_dtype\",\n        type=str,\n        default=\"float16\",\n        choices=[\"float16\", \"float32\", \"bfloat16\"],\n        help=\"Precision to load the initial model\",\n    )\n    parser.add_argument(\n        \"--weights\",\n        type=str,\n        default=\"int8\",\n        choices=[\"int4\", \"int8\", \"float8\"],\n    )\n    parser.add_argument(\n        \"--activations\",\n        type=str,\n        default=\"int8\",\n        choices=[\"none\", \"int8\", \"float8\"],\n    )\n    parser.add_argument(\"--device\", type=str, default=None, help=\"The device to use for generation.\")\n    parser.add_argument(\n        \"--no-streamline\",\n        action=\"store_false\",\n        help=\"Do not remove consecutive quantize/dequantize (not recommended).\",\n    )\n    parser.add_argument(\n        \"--debug\", action=\"store_true\", help=\"Provide detailed feedback on the console during calibration.\"\n  
  )\n    args = parser.parse_args()\n\n    torch.manual_seed(args.seed)\n\n    if args.device is None:\n        if torch.cuda.is_available():\n            device = torch.device(\"cuda\")\n        elif torch.backends.mps.is_available():\n            device = torch.device(\"mps\")\n        elif torch.xpu.is_available():\n            device = torch.device(\"xpu\")\n        else:\n            device = torch.device(\"cpu\")\n    else:\n        device = torch.device(args.device)\n\n    torch_dtype = (\n        torch.float16\n        if args.load_dtype == \"float16\"\n        else torch.bfloat16\n        if args.load_dtype == \"bfloat16\"\n        else torch.float32\n    )\n    model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch_dtype, low_cpu_mem_usage=True).to(\n        device\n    )\n    tokenizer = AutoTokenizer.from_pretrained(args.model)\n    tokenizer.pad_token_id = tokenizer.eos_token_id\n    tokenizer.padding_side = \"left\"\n    cal_dataset = load_dataset(\"lambada\", split=[\"validation\"])[0]\n\n    print(f\"{args.model} (w: {args.weights}, a: {args.activations})\")\n    weights = keyword_to_itype(args.weights)\n    activations = keyword_to_itype(args.activations)\n    qmodel = QuantizedModelForCausalLM.quantize(model, weights=weights, activations=activations)\n    if activations is not None:\n        print(\"Calibrating ...\")\n        cal_dataset.shuffle(args.seed)\n        with Calibration(streamline=args.no_streamline, debug=args.debug):\n            cal_samples = args.batch_size * args.validation_batch\n            calibrate(qmodel, tokenizer, cal_dataset, device, args.batch_size, samples=cal_samples)\n    generate(qmodel, tokenizer, device, args.prompt, args.max_new_tokens)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/speech/speech_recognition/quantize_asr_model.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# REQUIRES: librosa, soundfile\nimport argparse\nimport io\nimport time\nfrom functools import partial\n\nimport evaluate\nimport numpy as np\nimport torch\nfrom datasets import load_dataset\nfrom evaluate import load\nfrom transformers import WhisperForConditionalGeneration, WhisperProcessor\n\nfrom optimum.quanto import Calibration, freeze, qint4, qint8, quantize\n\n\ndef map_to_feats(batch, processor):\n    audio = batch[\"audio\"]\n    input_features = processor(\n        audio[\"array\"], sampling_rate=audio[\"sampling_rate\"], return_tensors=\"pt\"\n    ).input_features\n    batch[\"input_features\"] = input_features\n    batch[\"reference\"] = processor.tokenizer.normalize(batch[\"text\"])\n\n    return batch\n\n\ndef transcribe_batch(batch, model, processor):\n    with torch.no_grad():\n        features = torch.from_numpy(np.array(batch[\"input_features\"], dtype=np.float32)).squeeze(1)\n        predicted_ids = model.generate(features.to(model.device))\n    transcription = [processor.decode(ids) for ids in predicted_ids]\n    batch[\"prediction\"] = [processor.tokenizer.normalize(x) for x in transcription]\n    return batch\n\n\ndef evaluate_model(model, processor, dataset, metric: evaluate.EvaluationModule, batch_size=10):\n    map_fn = partial(transcribe_batch, model=model, processor=processor)\n    start = time.time()\n    
result = dataset.map(map_fn, batched=True, batch_size=batch_size)\n    end = time.time()\n    score = 100 * metric.compute(references=result[\"reference\"], predictions=result[\"prediction\"])\n    print(score)\n    print(f\"{len(result)} sentences evaluated in {end - start:.2f} s. {metric.name} = {score}\")\n\n\ndef keyword_to_itype(k):\n    return {\"none\": None, \"int8\": qint8, \"int4\": qint4}[k]\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Transformers Whisper Example\")\n    parser.add_argument(\"--seed\", type=int, default=1, metavar=\"S\", help=\"random seed (default: 1)\")\n    parser.add_argument(\n        \"--model\",\n        type=str,\n        default=\"openai/whisper-medium\",\n        help=\"The name of the trained Model.\",\n    )\n    parser.add_argument(\n        \"--samples\", type=int, default=872, help=\"The number of librispeech samples to use for evaluation.\"\n    )\n    parser.add_argument(\"--batch_size\", type=int, default=10, help=\"The batch size to use for evaluation.\")\n    parser.add_argument(\"--weights\", type=str, default=\"int8\", choices=[\"int4\", \"int8\"])\n    parser.add_argument(\"--activations\", type=str, default=\"int8\", choices=[\"none\", \"int8\"])\n    parser.add_argument(\"--device\", type=str, default=None, help=\"The device to use for evaluation.\")\n    args = parser.parse_args()\n\n    torch.manual_seed(args.seed)\n\n    if args.device is None:\n        if torch.cuda.is_available():\n            device = torch.device(\"cuda\")\n            print(\"USING CUDA\")\n        elif torch.backends.mps.is_available():\n            device = torch.device(\"mps\")\n        else:\n            device = torch.device(\"cpu\")\n            print(\"USING CPU\")\n    else:\n        device = torch.device(args.device)\n\n    model = WhisperForConditionalGeneration.from_pretrained(args.model).to(device)\n    model.config.forced_decoder_ids = None\n    processor = 
WhisperProcessor.from_pretrained(args.model)\n    dataset = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\n    processed_dataset = dataset.map(lambda x: map_to_feats(x, processor))\n    wer = load(\"wer\")\n\n    print(\"Float model:\")\n    evaluate_model(model, processor, processed_dataset, wer, args.batch_size)\n    weights = keyword_to_itype(args.weights)\n    activations = keyword_to_itype(args.activations)\n    quantize(model, weights=weights, activations=activations)\n    if activations is not None:\n        print(\"Calibrating ...\")\n        with Calibration():\n            evaluate_model(model, processor, processed_dataset, wer, args.batch_size)\n    freeze(model)\n    print(f\"Quantized model (w: {args.weights}, a: {args.activations})\")\n    evaluate_model(model, processor, processed_dataset, wer, args.batch_size)\n    b = io.BytesIO()\n    torch.save(model.state_dict(), b)\n    b.seek(0)\n    state_dict = torch.load(b)\n    model_reloaded = WhisperForConditionalGeneration.from_pretrained(args.model).to(device)\n    quantize(model_reloaded, weights=weights, activations=activations)\n    model_reloaded.load_state_dict(state_dict)\n    print(\"Serialized quantized model\")\n    evaluate_model(model_reloaded, processor, processed_dataset, wer, args.batch_size)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/speech/speech_recognition/requirements.txt",
    "content": "transformers\nevaluate\nlibrosa\nsoundfile\njiwer\n"
  },
  {
    "path": "examples/vision/StableDiffusion/README.md",
    "content": "# Quantize Stable Diffusion examples\n\n## Running locally with PyTorch\n\n### Installing the dependencies\n\nBefore running the scripts, make sure to install the library's training dependencies:\n\n**Important**\n\nTo make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:\n```bash\ngit clone https://github.com/huggingface/quanto\ncd quanto\npip install -e .\n```\n\nThen cd in the `examples/vision/StableDiffusion` folder and run\n```bash\npip install -r requirements.txt\n```\n\n**Now, we can launch the image generation script:**\n\n```bash\npython quantize_StableDiffusion.py --batch_size=1 --torch_dtype=\"fp32\"\n```\n\nTo better track our training experiments, we're using the following flags in the command above:\n\n* `batch_size` Batch size is the number of samples used in one iteration of training.\n\n* `torch_dtype` {fp32,fp16,bf16}\n* `unet_qtype` {fp8,int8,int4,none}\n\nOur experiments were conducted on a single 24GB A10 GPU.\n```bash\nfp16-fp16\n\nbatch_size: 1, torch_dtype: fp16, unet_dtype: none  in 3.307 seconds.Memory: 3.192GB.\n```\n\n```bash\nbf16-int8\n\nbatch_size: 1, torch_dtype: bf16, unet_dtype: int8  in 3.918 seconds.Memory: 2.644GB.\n```\n\n```bash\nfp16-int8\n\nbatch_size: 1, torch_dtype: fp16, unet_dtype: int8  in 3.920 seconds.Memory: 2.634GB.\n``` \n\nwill both get high-quality images at fast speed generation"
  },
  {
    "path": "examples/vision/StableDiffusion/quantize_StableDiffusion.py",
    "content": "import argparse\nimport gc\n\nimport torch\nimport torch.utils.benchmark as benchmark\nfrom diffusers import DiffusionPipeline\n\nfrom optimum.quanto import freeze, qfloat8, qint4, qint8, quantize\n\n\nCKPT = \"runwayml/stable-diffusion-v1-5\"\nNUM_INFERENCE_STEPS = 50\nWARM_UP_ITERS = 5\nPROMPT = \"ghibli style, a fantasy landscape with castles\"\n\nTORCH_DTYPES = {\"fp32\": torch.float32, \"fp16\": torch.float16, \"bf16\": torch.bfloat16}\nUNET_QTYPES = {\n    \"fp8\": qfloat8,\n    \"int8\": qint8,\n    \"int4\": qint4,\n    \"none\": None,\n}\n\n\ndef load_pipeline(torch_dtype, unet_dtype=None, device=\"cpu\"):\n    pipe = DiffusionPipeline.from_pretrained(CKPT, torch_dtype=torch_dtype, use_safetensors=True).to(device)\n\n    if unet_dtype:\n        quantize(pipe.unet, weights=unet_dtype)\n        freeze(pipe.unet)\n\n    pipe.set_progress_bar_config(disable=True)\n    return pipe\n\n\ndef run_inference(pipe, batch_size=1):\n    _ = pipe(\n        prompt=args.prompt,\n        num_inference_steps=args.num_inference_steps,\n        num_images_per_prompt=args.batch_size,\n        generator=torch.manual_seed(0),\n    )\n\n\ndef benchmark_fn(f, *args, **kwargs):\n    t0 = benchmark.Timer(stmt=\"f(*args, **kwargs)\", globals={\"args\": args, \"kwargs\": kwargs, \"f\": f})\n    return f\"{(t0.blocked_autorange().mean):.3f}\"\n\n\ndef bytes_to_giga_bytes(bytes):\n    return f\"{(bytes / 1024 / 1024 / 1024):.3f}\"\n\n\ndef get_device_memory(device):\n    gc.collect()\n    if device.type == \"cuda\":\n        torch.cuda.empty_cache()\n        return torch.cuda.memory_allocated()\n    elif device.type == \"mps\":\n        torch.mps.empty_cache()\n        return torch.mps.current_allocated_memory()\n    elif device.type == \"xpu\":\n        torch.xpu.empty_cache()\n        return torch.xpu.memory_allocated()\n    return None\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--prompt\", type=str, 
default=\"ghibli style, a fantasy landscape with castles\")\n    parser.add_argument(\"--output_path\", type=str, default=None)\n    parser.add_argument(\"--num_inference_steps\", type=int, default=50, help=\"Number of inference steps\")\n    parser.add_argument(\"--batch_size\", type=int, default=1)\n    parser.add_argument(\"--torch_dtype\", type=str, default=\"fp32\", choices=list(TORCH_DTYPES.keys()))\n    parser.add_argument(\"--unet_qtype\", type=str, default=None, choices=list(UNET_QTYPES.keys()))\n    parser.add_argument(\"--device\", type=str, default=None, help=\"The device to use for generation.\")\n    args = parser.parse_args()\n\n    if args.device is None:\n        if torch.cuda.is_available():\n            device = torch.device(\"cuda\")\n        elif torch.backends.mps.is_available():\n            device = torch.device(\"mps\")\n        elif torch.xpu.is_available():\n            device = torch.device(\"xpu\")\n        else:\n            device = torch.device(\"cpu\")\n    else:\n        device = torch.device(args.device)\n\n    pipeline = load_pipeline(\n        TORCH_DTYPES[args.torch_dtype], UNET_QTYPES[args.unet_qtype] if args.unet_qtype else None, device\n    )\n\n    for _ in range(WARM_UP_ITERS):\n        run_inference(pipeline, args.batch_size)\n\n    time = benchmark_fn(run_inference, pipeline, args.batch_size)\n    if device.type == \"cuda\":\n        memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())  # in GBs.\n    elif device.type == \"xpu\":\n        memory = bytes_to_giga_bytes(torch.xpu.max_memory_allocated())  # in GBs.\n    else:\n        memory = 0\n    get_device_memory(device)\n    print(\n        f\"batch_size: {args.batch_size}, torch_dtype: {args.torch_dtype}, unet_dtype: {args.unet_qtype}  in {time} seconds.\"\n    )\n    print(f\"Memory: {memory}GB.\")\n\n    img_name = f\"bs@{args.batch_size}-dtype@{args.torch_dtype}-unet_dtype@{args.unet_qtype}.png\"\n    image = pipeline(\n        prompt=args.prompt,\n      
  num_inference_steps=NUM_INFERENCE_STEPS,\n        num_images_per_prompt=args.batch_size,\n    ).images[0]\n    image.save(img_name)\n"
  },
  {
    "path": "examples/vision/StableDiffusion/requirements.txt",
    "content": "quanto\ndiffusers\ntorch\ntransformers\naccelerate\nwandb"
  },
  {
    "path": "examples/vision/image-classification/mnist/quantize_mnist_model.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport argparse\nimport time\nfrom tempfile import NamedTemporaryFile\n\nimport torch\nimport torch.nn.functional as F\nfrom accelerate import init_empty_weights\nfrom safetensors.torch import load_file, save_file\nfrom torchvision import datasets, transforms\nfrom transformers import AutoConfig, AutoModel\n\nfrom optimum.quanto import (\n    Calibration,\n    QTensor,\n    freeze,\n    qfloat8,\n    qint4,\n    qint8,\n    quantization_map,\n    quantize,\n    requantize,\n)\n\n\ndef test(model, device, test_loader):\n    model.to(device)\n    model.eval()\n    test_loss = 0\n    correct = 0\n    with torch.no_grad():\n        start = time.time()\n        for data, target in test_loader:\n            data, target = data.to(device), target.to(device)\n            output = model(data)\n            if isinstance(output, QTensor):\n                output = output.dequantize()\n            test_loss += F.nll_loss(output, target, reduction=\"sum\").item()  # sum up batch loss\n            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability\n            correct += pred.eq(target.view_as(pred)).sum().item()\n        end = time.time()\n\n    test_loss /= len(test_loader.dataset)\n\n    print(\n        \"\\nTest set evaluated in {:.2f} s: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n\".format(\n            
end - start, test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset)\n        )\n    )\n\n\ndef train(log_interval, model, device, train_loader, optimizer, epoch):\n    model.to(device)\n    model.train()\n    for batch_idx, (data, target) in enumerate(train_loader):\n        data, target = data.to(device), target.to(device)\n        optimizer.zero_grad()\n        output = model(data)\n        if isinstance(output, QTensor):\n            output = output.dequantize()\n        loss = F.nll_loss(output, target)\n        loss.backward()\n        optimizer.step()\n        if batch_idx % log_interval == 0:\n            print(\n                \"Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}\".format(\n                    epoch,\n                    batch_idx * len(data),\n                    len(train_loader.dataset),\n                    100.0 * batch_idx / len(train_loader),\n                    loss.item(),\n                )\n            )\n\n\ndef keyword_to_itype(k):\n    return {\"none\": None, \"int4\": qint4, \"int8\": qint8, \"float8\": qfloat8}[k]\n\n\ndef main():\n    # Training settings\n    parser = argparse.ArgumentParser(description=\"PyTorch MNIST Example\")\n    parser.add_argument(\n        \"--batch-size\", type=int, default=250, metavar=\"N\", help=\"input batch size for testing (default: 250)\"\n    )\n    parser.add_argument(\"--seed\", type=int, default=1, metavar=\"S\", help=\"random seed (default: 1)\")\n    parser.add_argument(\"--model\", type=str, default=\"dacorvo/mnist-mlp\", help=\"The name of the trained Model.\")\n    parser.add_argument(\"--weights\", type=str, default=\"int8\", choices=[\"int4\", \"int8\", \"float8\"])\n    parser.add_argument(\"--activations\", type=str, default=\"int8\", choices=[\"none\", \"int8\", \"float8\"])\n    parser.add_argument(\"--device\", type=str, default=None, help=\"The device to use for evaluation.\")\n    args = parser.parse_args()\n\n    
torch.manual_seed(args.seed)\n\n    if args.device is None:\n        if torch.cuda.is_available():\n            device = torch.device(\"cuda\")\n        elif torch.backends.mps.is_available():\n            device = torch.device(\"mps\")\n        elif torch.xpu.is_available():\n            device = torch.device(\"xpu\")\n        else:\n            device = torch.device(\"cpu\")\n    else:\n        device = torch.device(args.device)\n\n    dataset_kwargs = {\"batch_size\": args.batch_size}\n    if torch.cuda.is_available() or torch.xpu.is_available():\n        backend_kwargs = {\"num_workers\": 1, \"pin_memory\": True, \"shuffle\": True}\n        dataset_kwargs.update(backend_kwargs)\n\n    transform = transforms.Compose(\n        [\n            transforms.ToTensor(),\n            transforms.Normalize((0.1307,), (0.3081,)),\n            transforms.Lambda(lambda x: torch.flatten(x)),\n        ]\n    )\n    dataset1 = datasets.MNIST(\"./data\", train=True, download=True, transform=transform)\n    train_loader = torch.utils.data.DataLoader(dataset1, **dataset_kwargs)\n    dataset2 = datasets.MNIST(\"./data\", train=False, download=True, transform=transform)\n    test_loader = torch.utils.data.DataLoader(dataset2, **dataset_kwargs)\n    model = AutoModel.from_pretrained(args.model, trust_remote_code=True)\n    model.eval()\n    print(\"Float model\")\n    test(model, device, test_loader)\n    weights = keyword_to_itype(args.weights)\n    activations = keyword_to_itype(args.activations)\n    quantize(model, weights=weights, activations=activations)\n    if activations is not None:\n        print(\"Calibrating ...\")\n        with Calibration():\n            test(model, device, test_loader)\n    print(f\"Quantized model (w: {args.weights}, a: {args.activations})\")\n    test(model, device, test_loader)\n    print(\"Tuning quantized model for one epoch\")\n    optimizer = torch.optim.Adadelta(model.parameters(), lr=0.5)\n    train(50, model, device, train_loader, optimizer, 
1)\n    print(\"Quantized tuned model\")\n    test(model, device, test_loader)\n    print(\"Quantized frozen model\")\n    freeze(model)\n    test(model, device, test_loader)\n    # Serialize model to a state_dict, save it to disk and reload it\n    with NamedTemporaryFile() as tmp_file:\n        save_file(model.state_dict(), tmp_file.name)\n        state_dict = load_file(tmp_file.name)\n    model_reloaded = AutoModel.from_pretrained(args.model, trust_remote_code=True)\n    # Create an empty model\n    config = AutoConfig.from_pretrained(args.model, trust_remote_code=True)\n    with init_empty_weights():\n        model_reloaded = AutoModel.from_config(config, trust_remote_code=True)\n    # Requantize it using the serialized state_dict\n    requantize(model_reloaded, state_dict, quantization_map(model), device)\n    print(\"Serialized quantized model\")\n    test(model_reloaded, device, test_loader)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/vision/image-classification/pets/quantize_vit_model.py",
    "content": "import argparse\nimport time\nfrom tempfile import NamedTemporaryFile\n\nimport torch\nimport torch.nn.functional as F\nfrom accelerate import init_empty_weights\nfrom datasets import load_dataset\nfrom safetensors.torch import load_file, save_file\nfrom transformers import (\n    ViTConfig,\n    ViTForImageClassification,\n    ViTImageProcessor,\n)\n\nfrom optimum.quanto import (\n    Calibration,\n    QTensor,\n    freeze,\n    qfloat8,\n    qint4,\n    qint8,\n    quantization_map,\n    quantize,\n    requantize,\n)\n\n\ndef test(model, device, test_loader):\n    model.to(device)\n    model.eval()\n    test_loss = 0\n    correct = 0\n    with torch.no_grad():\n        start = time.time()\n        for batch in test_loader:\n            data, target = batch[\"pixel_values\"], batch[\"labels\"]\n            data, target = data.to(device), target.to(device)\n            output = model(data).logits\n            if isinstance(output, QTensor):\n                output = output.dequantize()\n            test_loss += F.nll_loss(output, target, reduction=\"sum\").item()  # sum up batch loss\n            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability\n            correct += pred.eq(target.view_as(pred)).sum().item()\n        end = time.time()\n\n    test_loss /= len(test_loader.dataset)\n\n    print(\n        \"\\nTest set evaluated in {:.2f} s: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n\".format(\n            end - start, test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset)\n        )\n    )\n\n\ndef keyword_to_itype(k):\n    return {\"none\": None, \"int4\": qint4, \"int8\": qint8, \"float8\": qfloat8}[k]\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"ViT PETS Example\")\n    parser.add_argument(\"--model\", type=str, default=\"super-j/vit-base-pets\")\n    parser.add_argument(\"--device\", type=str, default=None, help=\"The device to use for 
evaluation.\")\n    parser.add_argument(\"--weights\", type=str, default=\"int8\", choices=[\"int4\", \"int8\", \"float8\"])\n    parser.add_argument(\"--activations\", type=str, default=\"int8\", choices=[\"none\", \"int8\", \"float8\"])\n    args = parser.parse_args()\n\n    dataset_kwargs = {}\n\n    if args.device is None:\n        if torch.cuda.is_available():\n            device = torch.device(\"cuda\")\n            cuda_kwargs = {\"num_workers\": 1, \"pin_memory\": True, \"shuffle\": True}\n            dataset_kwargs.update(cuda_kwargs)\n        elif all([torch.backends.mps.is_available(), args.weights != \"float8\", args.activations != \"float8\"]):\n            device = torch.device(\"mps\")\n        else:\n            device = torch.device(\"cpu\")\n    else:\n        device = torch.device(args.device)\n\n    # load  the processor and model\n    model_name = args.model\n    processor = ViTImageProcessor.from_pretrained(model_name)\n    model = ViTForImageClassification.from_pretrained(model_name)\n\n    def transform(data_batch):\n        # Take a list of PIL images and turn them to pixel values\n        inputs = processor(data_batch[\"image\"], return_tensors=\"pt\")\n\n        # Don't forget to include the labels!\n        inputs[\"labels\"] = data_batch[\"label\"]\n        return inputs\n\n    ds = load_dataset(\"rokmr/pets\")\n    prepared_ds = ds.with_transform(transform)\n    test_loader = torch.utils.data.DataLoader(prepared_ds[\"test\"], **dataset_kwargs)\n    print(\"Model before quantization...\")\n    test(model, device, test_loader)\n    weights = keyword_to_itype(args.weights)\n    activations = keyword_to_itype(args.activations)\n    quantize(model, weights=weights, activations=activations)\n    if activations is not None:\n        print(\"Calibrating ...\")\n        with Calibration():\n            test(model, device, test_loader)\n    print(f\"Quantized model (w: {args.weights}, a: {args.activations})\")\n    test(model, device, 
test_loader)\n    print(\"Quantized frozen model\")\n    freeze(model)\n    test(model, device, test_loader)\n    # Serialize model to a state_dict, save it to disk and reload it\n    with NamedTemporaryFile() as tmp_file:\n        save_file(model.state_dict(), tmp_file.name)\n        state_dict = load_file(tmp_file.name)\n    model_reloaded = ViTForImageClassification.from_pretrained(model_name)\n    # Create an empty model\n    config = ViTConfig.from_pretrained(model_name)\n    with init_empty_weights():\n        model_reloaded = ViTForImageClassification(config)\n    # Requantize it using the serialized state_dict\n    requantize(model_reloaded, state_dict, quantization_map(model), device)\n    print(\"Serialized quantized model\")\n    test(model_reloaded, device, test_loader)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/vision/object-detection/quantize_owl_model.py",
    "content": "import argparse\nimport gc\n\nimport numpy as np\nimport requests\nimport torch\nfrom PIL import Image\nfrom transformers import AutoProcessor, Owlv2ForObjectDetection\nfrom transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD\n\nfrom optimum.quanto import freeze, qfloat8, qint4, qint8, quantize\n\n\ndef detect(model, processor, image, texts):\n    inputs = processor(text=texts, images=image, return_tensors=\"pt\").to(model.device)\n\n    # forward pass\n    with torch.no_grad():\n        outputs = model(**inputs)\n\n    # Note: boxes need to be visualized on the padded, unnormalized image\n    # hence we'll set the target image sizes (height, width) based on that\n    def get_preprocessed_image(pixel_values):\n        pixel_values = pixel_values.squeeze().cpu().numpy()\n        unnormalized_image = (pixel_values * np.array(OPENAI_CLIP_STD)[:, None, None]) + np.array(OPENAI_CLIP_MEAN)[\n            :, None, None\n        ]\n        unnormalized_image = (unnormalized_image * 255).astype(np.uint8)\n        unnormalized_image = np.moveaxis(unnormalized_image, 0, -1)\n        unnormalized_image = Image.fromarray(unnormalized_image)\n        return unnormalized_image\n\n    unnormalized_image = get_preprocessed_image(inputs.pixel_values)\n\n    target_sizes = torch.Tensor([unnormalized_image.size[::-1]])\n    # Convert outputs (bounding boxes and class logits) to final bounding boxes and scores\n    results = processor.post_process_object_detection(outputs=outputs, threshold=0.2, target_sizes=target_sizes)\n\n    i = 0  # Retrieve predictions for the first image for the corresponding text queries\n    text = texts[i]\n    boxes, scores, labels = results[i][\"boxes\"], results[i][\"scores\"], results[i][\"labels\"]\n\n    if len(boxes) == 0:\n        print(\"None of the specified labels were detected\")\n        return\n\n    for box, score, label in zip(boxes, scores, labels):\n        box = [round(i, 2) for i in box.tolist()]\n        
print(f\"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}\")\n\n\ndef get_device_memory(device):\n    gc.collect()\n    if device.type == \"cuda\":\n        torch.cuda.empty_cache()\n        return torch.cuda.memory_allocated()\n    elif device.type == \"mps\":\n        torch.mps.empty_cache()\n        return torch.mps.current_allocated_memory()\n    elif device.type == \"xpu\":\n        torch.xpu.empty_cache()\n        return torch.xpu.memory_allocated()\n    return None\n\n\ndef keyword_to_qtype(k):\n    return {\"none\": None, \"int4\": qint4, \"int8\": qint8, \"float8\": qfloat8}[k]\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model\", type=str, default=\"google/owlv2-base-patch16\")\n    parser.add_argument(\"--image\", type=str, required=True)\n    parser.add_argument(\"--texts\", type=str, nargs=\"+\", required=True)\n    parser.add_argument(\"--weights\", type=str, default=\"none\", choices=[\"none\", \"int4\", \"int8\", \"float8\"])\n    parser.add_argument(\"--exclude-heads\", action=\"store_true\", help=\"Do not quantize detection heads\")\n    parser.add_argument(\"--device\", type=str, default=None, help=\"The device to use for generation.\")\n    args = parser.parse_args()\n\n    if args.device is None:\n        if torch.cuda.is_available():\n            device = torch.device(\"cuda\")\n        elif torch.backends.mps.is_available():\n            # MPS backend does not support torch.float64 that is required for owl models\n            device = torch.device(\"cpu\")\n        elif torch.xpu.is_available():\n            device = torch.device(\"xpu\")\n        else:\n            device = torch.device(\"cpu\")\n    else:\n        device = torch.device(args.device)\n\n    processor = AutoProcessor.from_pretrained(args.model)\n    model = Owlv2ForObjectDetection.from_pretrained(args.model, low_cpu_mem_usage=True).to(device)\n\n    weights_qtype = keyword_to_qtype(args.weights)\n    if 
weights_qtype is not None:\n        if args.exclude_heads:\n            quantize(model.owlv2, weights=weights_qtype)\n        else:\n            quantize(model, weights=weights_qtype)\n        freeze(model)\n\n    memory = get_device_memory(device)\n    if memory is not None:\n        memory_gb = memory / 2**30\n        print(f\"{device.type} device memory: {memory_gb:.2f} GB.\")\n\n    image_path = args.image\n    if image_path.startswith(\"http\"):\n        image_path = requests.get(args.image, stream=True).raw\n    image = Image.open(image_path)\n\n    texts = [args.texts]\n    detect(model, processor, image, texts)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/vision/text-to-image/quantize_pixart_sigma.py",
    "content": "import argparse\nimport gc\n\nimport torch\nfrom diffusers import DiffusionPipeline\n\nfrom optimum.quanto import freeze, qfloat8, qint4, qint8, quantize\n\n\nNUM_INFERENCE_STEPS = 50\n\nTORCH_DTYPES = {\"fp16\": torch.float16, \"bf16\": torch.bfloat16}\nQTYPES = {\n    \"fp8\": qfloat8,\n    \"int8\": qint8,\n    \"int4\": qint4,\n    \"none\": None,\n}\n\n\ndef load_pipeline(model_id, torch_dtype, qtype=None, device=\"cpu\"):\n    pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype, use_safetensors=True).to(device)\n\n    if qtype:\n        quantize(pipe.transformer, weights=qtype)\n        freeze(pipe.transformer)\n        quantize(pipe.text_encoder, weights=qtype)\n        freeze(pipe.text_encoder)\n\n    pipe.set_progress_bar_config(disable=True)\n    return pipe\n\n\ndef get_device_memory(device):\n    gc.collect()\n    if device.type == \"cuda\":\n        torch.cuda.empty_cache()\n        return torch.cuda.memory_allocated()\n    elif device.type == \"mps\":\n        torch.mps.empty_cache()\n        return torch.mps.current_allocated_memory()\n    elif device.type == \"xpu\":\n        torch.xpu.empty_cache()\n        return torch.xpu.memory_allocated()\n    return None\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model_id\", type=str, default=\"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS\")\n    parser.add_argument(\"--prompt\", type=str, default=\"ghibli style, a fantasy landscape with castles\")\n    parser.add_argument(\"--torch_dtype\", type=str, default=\"fp16\", choices=list(TORCH_DTYPES.keys()))\n    parser.add_argument(\"--qtype\", type=str, default=None, choices=list(QTYPES.keys()))\n    parser.add_argument(\"--device\", type=str, default=None, help=\"The device to use for generation.\")\n    args = parser.parse_args()\n\n    if args.device is None:\n        if torch.cuda.is_available():\n            device = torch.device(\"cuda\")\n        elif 
torch.backends.mps.is_available():\n            device = torch.device(\"mps\")\n        elif torch.xpu.is_available():\n            device = torch.device(\"xpu\")\n        else:\n            device = torch.device(\"cpu\")\n    else:\n        device = torch.device(args.device)\n\n    pipeline = load_pipeline(\n        args.model_id, TORCH_DTYPES[args.torch_dtype], QTYPES[args.qtype] if args.qtype else None, device\n    )\n\n    print(f\"torch_dtype: {args.torch_dtype}, qtype: {args.qtype}.\")\n    memory = get_device_memory(device)\n    if memory is not None:\n        memory_gb = memory / 2**30\n        print(f\"{device.type} device memory: {memory_gb:.2f} GB.\")\n\n    if args.qtype == \"int4\" and device.type == \"cuda\":\n        raise ValueError(\"This example does not work (yet) for int4 on CUDA\")\n\n    img_name = f\"pixart-sigma-dtype@{args.torch_dtype}-qtype@{args.qtype}.png\"\n    image = pipeline(\n        prompt=args.prompt,\n        num_inference_steps=NUM_INFERENCE_STEPS,\n        num_images_per_prompt=1,\n        generator=torch.manual_seed(0),\n    ).images[0]\n    image.save(img_name)\n"
  },
  {
    "path": "external/awq/conftest.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\nimport torch\n\n\ndevices = [\"cpu\"]\nif torch.cuda.is_available():\n    devices += [\"cuda\"]\nelif torch.backends.mps.is_available():\n    devices += [\"mps\"]\n\n\n@pytest.fixture(scope=\"module\", params=devices)\ndef device(request):\n    return torch.device(request.param)\n\n\ndef pytest_configure(config):\n    # register additional markers\n    config.addinivalue_line(\"markers\", \"skip_device(type): mark test to be skipped for the specified device type\")\n\n\ndef pytest_runtest_call(item):\n    fixture_name = \"device\"\n    if fixture_name in item.fixturenames:\n        # TODO: should be able to recover the fixture id instead of the actual value\n        fixture_arg = item.funcargs[fixture_name].type\n        skip_marks = {mark.args[0] for mark in item.iter_markers(name=f\"skip_{fixture_name}\")}\n        if fixture_arg in skip_marks:\n            pytest.skip(f\"Test skipped for {fixture_name} {fixture_arg}\")\n"
  },
  {
    "path": "external/awq/pack_intweight.py",
    "content": "# MIT License\n#\n# Copyright (c) 2023 MIT HAN Lab\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\nimport torch\n\n\ndef pack_intweight(unpacked_qweight, interleave, kstride):\n    # unpacked_qweight: [N, K]\n    N = unpacked_qweight.shape[0]\n    K = unpacked_qweight.shape[1]\n\n    Packed_Kernel = unpacked_qweight.cpu().numpy().reshape(N, K // 32, 32)\n    # np.arange(32).reshape(4, 4, 2).transpose(1, 0, 2) => [0, 1, 8, 9, 16, 17, 24, 25, ...]\n    Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 3, 2, 4)\n    Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 32)\n\n    # reorder each 8 weights for fast dequantization\n    # [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7]\n    Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 8)\n    Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 2, 4, 3)\n    Packed_Kernel = Packed_Kernel.reshape(N, K)\n\n    # interleaving every four 
rows\n    Packed_Kernel = Packed_Kernel.reshape(\n        N // interleave, interleave, K // kstride, kstride\n    )\n    # N // 4, K // 64, 4, 64\n    Packed_Kernel = Packed_Kernel.transpose(0, 2, 1, 3)\n    Packed_Kernel = Packed_Kernel.reshape(\n        N // interleave, K // kstride, kstride, interleave\n    )\n    # Packing -> (N // 4, K // 64, 64)\n    Packed_Kernel = (\n        Packed_Kernel[..., 0]\n        | (Packed_Kernel[..., 1] << 4)\n        | (Packed_Kernel[..., 2] << 8)\n        | (Packed_Kernel[..., 3] << 12)\n    )\n    # reshape to (N // 4, K), FP16 format\n    Packed_Kernel = Packed_Kernel.reshape(N // interleave, K)\n    qweight = (\n        torch.tensor(Packed_Kernel.astype(\"int16\"))\n        .to(unpacked_qweight.device)\n        .contiguous()\n    )\n    return qweight\n"
  },
  {
    "path": "external/awq/packing_utils.py",
    "content": "import torch\n\n\nAWQ_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]\nAWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]\n\n\ndef pack_awq(intweight: torch.Tensor, reorder=False):\n    bits = 4\n    pack_num = 32 // bits\n    qweight = torch.zeros(intweight.shape[0], intweight.shape[1] // pack_num, dtype=torch.int32, device=intweight.device)\n    for col in range(intweight.shape[1] // pack_num):\n        if reorder:\n            order_map = [0, 2, 4, 6, 1, 3, 5, 7]\n        else:\n            order_map = [0, 1, 2, 3, 4, 5, 6, 7]\n        for i in range(pack_num):\n            qweight_col = intweight[:, col * pack_num + order_map[i]]\n            qweight[:, col] |= qweight_col << (i * bits)\n    return qweight\n\n\ndef unpack_awq(qweight: torch.Tensor, bits: int):\n    shifts = torch.arange(0, 32, bits, device=qweight.device)\n\n    # unpacking columnwise\n    iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(\n        torch.int8  # smallest dtype available\n    )\n    iweights = iweights.view(iweights.shape[0], -1)\n\n    return iweights\n\n\ndef reverse_awq_order(iweights: torch.Tensor, bits: int):\n    reverse_order_tensor = torch.arange(\n        iweights.shape[-1],\n        dtype=torch.int32,\n        device=iweights.device,\n    )\n    reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)\n    reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]\n    reverse_order_tensor = reverse_order_tensor.view(-1)\n\n    iweights = iweights[:, reverse_order_tensor]\n\n    return iweights\n\n\ndef pack_exllama(iweights: torch.Tensor, izeros: torch.Tensor, bits: int):\n    shifts = torch.arange(0, 32, bits, device=iweights.device)\n\n    # packing rowwise\n    iweights = iweights.view(iweights.shape[0] // (32 // bits), 32 // bits, -1)\n    qweight = (\n        torch.bitwise_left_shift(iweights, shifts[None, :, None])\n        .sum(dim=1)\n        .to(torch.int32)\n    )\n\n    # packing columnwise\n    izeros = 
izeros.view(-1, izeros.shape[1] // (32 // bits), 32 // bits)\n    qzeros = (\n        torch.bitwise_left_shift(izeros, shifts[None, None, :])\n        .sum(dim=-1)\n        .to(torch.int32)\n    )\n\n    return qweight, qzeros\n\n\ndef unpack_reorder_pack(qweight, qzeros, bits):\n    # Unpack the qweight and qzeros tensors\n    iweight, izeros = unpack_awq(qweight, qzeros, bits)\n    # Reverse the order of the iweight and izeros tensors\n    iweight, izeros = reverse_awq_order(iweight, izeros, bits)\n\n    # overflow checks\n    iweight = torch.bitwise_and(iweight, (2**bits) - 1)\n    izeros = torch.bitwise_and(izeros, (2**bits) - 1)\n\n    # Subtract 1 from the izeros tensor (exllama adds 1 during inference)\n    # We can remove it if we remove the +1 in the exllama code\n    izeros = izeros - 1\n    # Pack the qweight and qzeros tensors\n    qweight, qzeros = pack_exllama(iweight, izeros, bits)\n\n    return qweight, qzeros\n\n\ndef dequantize_gemm(qweight, qzeros, scales, bits, group_size):\n    # Unpack the qweight and qzeros tensors\n    iweight, izeros = unpack_awq(qweight, qzeros, bits)\n    # Reverse the order of the iweight and izeros tensors\n    iweight, izeros = reverse_awq_order(iweight, izeros, bits)\n\n    # overflow checks\n    iweight = torch.bitwise_and(iweight, (2**bits) - 1)\n    izeros = torch.bitwise_and(izeros, (2**bits) - 1)\n\n    # fp16 weights\n    scales = scales.repeat_interleave(group_size, dim=0)\n    izeros = izeros.repeat_interleave(group_size, dim=0)\n    iweight = (iweight - izeros) * scales\n\n    return iweight\n"
  },
  {
    "path": "external/awq/test_awq_kernels.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport pytest\nimport torch\nfrom pack import pack_awq\n\nfrom optimum.quanto import AffineQuantizer, MaxOptimizer, qint4, ungroup\n\n\ndef assert_similar(a, b, atol=None, rtol=None):\n    \"\"\"Verify that the cosine similarity of the two inputs is close to 1.0 everywhere\"\"\"\n    assert a.dtype == b.dtype\n    assert a.shape == b.shape\n    if atol is None:\n        # We use torch finfo resolution\n        atol = torch.finfo(a.dtype).resolution\n    if rtol is None:\n        # Please refer to that discussion for default rtol values based on the float type:\n        # https://scicomp.stackexchange.com/questions/43111/float-equality-tolerance-for-single-and-half-precision\n        rtol = {torch.float32: 1e-5, torch.float16: 1e-3, torch.bfloat16: 1e-1}[a.dtype]\n    sim = torch.nn.functional.cosine_similarity(a.flatten(), b.flatten(), dim=0)\n    if not torch.allclose(sim, torch.tensor(1.0, dtype=sim.dtype), atol=atol, rtol=rtol):\n        max_deviation = torch.min(sim)\n        raise ValueError(f\"Alignment {max_deviation:.8f} deviates too much from 1.0 with atol={atol}, rtol={rtol}\")\n\n\n@pytest.mark.skipif(not torch.cuda.is_available(), reason=\"CUDA not available\")\n@pytest.mark.parametrize(\"in_features, out_features\", [(256, 256), (512, 256)])\n@pytest.mark.parametrize(\"kernel\", [\"gemv\", \"gemm\"])\ndef 
test_standalone_kernel(in_features, out_features, kernel):\n    \"\"\"This test verifies that the GEMM operation is equivalent to torch.mm.\n    \"\"\"\n    bits = 4\n    group_size = 128 # Hard-coded in kernels\n    interleave = 4 # Hard-coded in kernels\n    kstride = 64 # Hard-coded in kernels\n    device = torch.device('cuda')\n    batch_size, tokens = (4, 1) if kernel ==\"gemv\" else (10, 128)\n    input_shape = (batch_size, tokens, in_features)\n    # FIXME: does not work if inputs are negative !!??\n    inputs = torch.rand(input_shape, dtype=torch.float16, device=device)\n    qmax = 2**bits\n    other_shape = (out_features, in_features)\n    other_data = torch.randint(0, qmax, other_shape, dtype=torch.uint8, device=device)\n    #packed_other_data = pack_intweight(other_data.to(torch.int32), interleave=interleave, kstride=kstride)\n    packed_other_data = pack_awq(other_data.to(torch.int32), interleave=interleave, kstride=kstride)\n    # The GEMM kernel works on transposed scales\n    scales_shape = (in_features // group_size, out_features)\n    other_scales = torch.rand(scales_shape, dtype=torch.float16, device=device) / qmax\n    # The GEMM kernel works on transposed, negated and scaled zeropoints\n    qmin = -2**(bits -1)\n    qmax = 2**(bits -1)\n    other_zeropoints = torch.randint(qmin, qmax, scales_shape, dtype=torch.int8, device=device)\n    # Negate and scale\n    other_scaled_zeropoints = - other_zeropoints * other_scales\n    # Evaluate mm outputs using the GEMM kernel\n    if kernel == \"gemv\":\n        awq_outputs = torch.ops.quanto.gemv(inputs,\n                                         packed_other_data,\n                                         other_scales,\n                                         other_scaled_zeropoints,\n                                         rows=inputs.numel() // inputs.shape[-1],\n                                         out_cols=out_features,\n                                         in_cols=in_features,\n            
                             bits=4,\n                                         group_size=group_size)\n    else:\n        awq_outputs = torch.ops.quanto.gemm(inputs,\n                                                  packed_other_data,\n                                                  other_scales,\n                                                  other_scaled_zeropoints,\n                                                  rows=inputs.numel() // inputs.shape[-1],\n                                                  out_cols=out_features,\n                                                  in_cols=in_features,\n                                                  bits=4,\n                                                  group_size=group_size)\n    # Transpose other data and reshape it to align it with transposed scales and zeros\n    other_data_t = other_data.t().reshape(group_size, in_features // group_size, out_features)\n    # Dequantize transposed other\n    other_t = (other_data_t - other_zeropoints) * other_scales\n    # Reshape it as expected by the matmul\n    other_t = other_t.reshape(in_features, out_features)\n    # Evaluate the matrix multiplication using pytorch float16 mm\n    pt_outputs = torch.matmul(inputs, other_t)\n    # Verify the results are similar\n    assert_similar(awq_outputs, pt_outputs, rtol=5e-3)\n\n\n@pytest.mark.skipif(not torch.cuda.is_available(), reason=\"CUDA not available\")\n@pytest.mark.parametrize(\"in_features, out_features\", [(256, 256), (512, 256)])\n@pytest.mark.parametrize(\"kernel\", [\"gemm\", \"gemv\"])\ndef test_integrated_kernel(in_features, out_features, kernel):\n    group_size = 128 # Hard-coded in kernels\n    interleave = 4 # Hard-coded in kernels\n    kstride = 64 # Hard-coded in kernels\n    device = torch.device('cuda')\n    batch_size, tokens = (4, 1) if kernel == \"gemv\" else (10, 128)\n    input_shape = (batch_size, tokens, in_features)\n    inputs = torch.rand(input_shape, dtype=torch.float16, device=device) 
* 2 - 1\n    other_shape = (out_features, in_features)\n    other = torch.rand(other_shape, dtype=torch.float16, device=device) * 2 - 1\n    # Quantize using quanto\n    scale, zeropoint = MaxOptimizer()(other, bits=4, axis=0, group_size=128)\n    quanto_base = AffineQuantizer.apply(other, qint4, 0, group_size, scale, zeropoint)\n    # Evaluate mm\n    quanto_outputs = torch.matmul(inputs, quanto_base.t())\n\n    # Extract quantized data, unpack and ungroup to recover original shape\n    quanto_data = ungroup(quanto_base._data.unpack(), axis=0, orig_shape=other_shape)\n    # Pack data for AWQ kernel\n    awq_data = pack_awq(quanto_data.to(torch.int32), interleave=interleave, kstride=kstride)\n    # Reshape and transpose scale as expected by AWQ kernel (! buffer must be contiguous)\n    awq_scale = scale.reshape(out_features, in_features // group_size).t().contiguous()\n    # Reshape and transpose zeropoint as expected by AWQ kernel (! buffer must be contiguous)\n    awq_zeropoint = zeropoint.reshape(out_features, in_features // group_size).t().contiguous()\n    # Negate and rescale\n    awq_scaled_zeropoint = - awq_zeropoint * awq_scale\n\n    # Evaluate mm outputs using the AWQ kernels\n    if kernel == \"gemv\":\n        awq_outputs = torch.ops.quanto.gemv(inputs,\n                                         awq_data,\n                                         awq_scale,\n                                         awq_scaled_zeropoint,\n                                         rows=inputs.numel() // inputs.shape[-1],\n                                         out_cols=out_features,\n                                         in_cols=in_features,\n                                         bits=4,\n                                         group_size=group_size)\n    else:\n        awq_outputs = torch.ops.quanto.gemm(inputs,\n                                                  awq_data,\n                                                  awq_scale,\n                              
                    awq_scaled_zeropoint,\n                                                  rows=inputs.numel() // inputs.shape[-1],\n                                                  out_cols=out_features,\n                                                  in_cols=in_features,\n                                                  bits=4,\n                                                  group_size=group_size)\n\n    # Verify the results are similar\n    assert_similar(awq_outputs, quanto_outputs, rtol=5e-3)\n"
  },
  {
    "path": "external/awq/test_awq_packing.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport numpy as np\nimport pytest\nimport torch\nfrom pack_intweight import pack_intweight\nfrom packing_utils import pack_awq, reverse_awq_order, unpack_awq\n\nfrom optimum.quanto import AWQPackedTensor, AWQPacking\n\n\n@pytest.mark.skipif(not torch.cuda.is_available(), reason=\"CUDA not available\")\n@pytest.mark.parametrize(\"in_features\", [128, 256, 512, 1024])\n@pytest.mark.parametrize(\"out_features\", [128, 256, 512, 1024])\n@pytest.mark.parametrize(\"reorder\", [True, False])\n@pytest.mark.parametrize(\"random\", [True, False])\ndef test_awq_pack(in_features, out_features, reorder, random):\n    \"\"\"This test verifies two things:\n\n    - that we are able to replicate awq packing,\n    - that we can unpack awq packed tensors and recover the original tensor.\n    \"\"\"\n    bits = 4\n    interleave = 4\n    kstride = 64\n    qmax = 2**bits\n    shape = (out_features, in_features)\n    device = torch.device('cuda')\n    if random:\n        t = torch.randint(0, qmax, shape, dtype=torch.uint8).to(device)\n    else:\n        numel = np.prod(shape)\n        t = torch.tensor(range(numel), dtype=torch.int32)\n        t = (t % qmax).reshape(shape).to(torch.uint8).to(device)\n    packed = pack_awq(t.to(torch.int32), reorder=reorder)\n    # Sanity check: verify we can recover the Tensor using AWQ unpacking\n    unpacked = 
unpack_awq(packed, bits=4)\n    if reorder:\n        unpacked = reverse_awq_order(unpacked, bits=4)\n    unpacked = torch.bitwise_and(unpacked, qmax - 1)\n    assert torch.equal(t, unpacked)\n    # Compare with quanto packing\n    repacked = AWQPackedTensor.pack(t, packing=AWQPacking.V1, reorder=reorder)\n    assert torch.equal(packed, repacked._data)\n    unpacked = repacked.unpack()\n    assert torch.equal(unpacked, t)\n\n\n@pytest.mark.skipif(not torch.cuda.is_available(), reason=\"CUDA not available\")\n@pytest.mark.parametrize(\"in_features\", [128, 256, 512, 1024])\n@pytest.mark.parametrize(\"out_features\", [128, 256, 512, 1024])\n@pytest.mark.parametrize(\"random\", [True, False])\ndef test_awq_pack_v2(in_features, out_features, random):\n    \"\"\"This test verifies two things:\n\n    - that we are able to replicate awq packing,\n    - that we can unpack awq packed tensors and recover the original tensor.\n    \"\"\"\n    bits = 4\n    interleave = 4\n    kstride = 64\n    qmax = 2**bits\n    shape = (out_features, in_features)\n    device = torch.device('cuda')\n    if random:\n        t = torch.randint(0, qmax, shape, dtype=torch.uint8).to(device)\n    else:\n        numel = np.prod(shape)\n        t = torch.tensor(range(numel), dtype=torch.int32)\n        t = (t % qmax).reshape(shape).to(torch.uint8).to(device)\n    packed = pack_intweight(t.to(torch.int32), interleave=interleave, kstride=kstride)\n    # Compare with quanto packing\n    repacked = AWQPackedTensor.pack(t, packing=AWQPacking.V2)\n    assert torch.equal(packed, repacked._data)\n    unpacked = repacked.unpack()\n    assert torch.equal(unpacked, t)\n\n"
  },
  {
    "path": "external/awq/test_awq_quantize.py",
    "content": "import pytest\nimport torch\n\nfrom optimum.quanto import AffineQuantizer, MaxOptimizer, qint4, ungroup\n\n\ndef awq_quantize(base, scales, zeros, group_size):\n    _, in_features = base.shape\n    scale_zeros = scales * zeros\n    intweight = []\n    # From https://github.com/casper-hansen/AutoAWQ/blob/main/awq/modules/linear/gemv_fast.py#L165\n    for idx in range(in_features):\n        intweight.append(\n            torch.round(\n                (base[:, idx] + scale_zeros[:, idx // group_size])\n                        / scales[:, idx // group_size]\n                    ).to(torch.uint8)[:, None]\n                )\n    intweight = torch.cat(intweight, dim=1)\n    return intweight\n\n\n@pytest.mark.parametrize(\"in_features, out_features\", [(256, 512), (1024, 1024)])\ndef test_awq_quantize(in_features, out_features):\n    \"\"\"Verify that AWQ quantization is equivalent to quanto affine quantization\n    \"\"\"\n    shape = (out_features, in_features)\n    base = torch.rand(shape, dtype=torch.float16)\n    group_size = 128\n\n    # Quantize using quanto\n    scale, zeropoint = MaxOptimizer()(base, bits=4, axis=0, group_size=128)\n    quanto_base = AffineQuantizer.apply(base, qint4, 0, group_size, scale, zeropoint)\n    # Extract quantized data, unpack and ungroup to recover original shape\n    quanto_data = ungroup(quanto_base._data.unpack(), axis=0, orig_shape=shape)\n\n    # Reshape scale and zeropoint as expected by awq\n    awq_shape = (out_features, in_features // group_size)\n    scale = scale.reshape(awq_shape)\n    zeropoint = zeropoint.reshape(awq_shape)\n\n    # Compare with awq quantization\n    awq_data = awq_quantize(base, scale, zeropoint, group_size)\n    # FIX: AWQ does not clamp values before packing\n    qmax = 2 ** 4 - 1\n    awq_data = torch.clamp(awq_data, 0, qmax)\n\n    mismatches = quanto_data != awq_data\n    n = torch.sum(mismatches).numpy()\n    rate = n / base.numel()\n    print(f\"Mismatches: {n}/{base.numel()} 
({rate:.8f} %)\")\n    # Extract mismatches\n    display = 10\n    quanto_values = torch.masked_select(quanto_data, mismatches)[:display]\n    awq_values = torch.masked_select(awq_data, mismatches)[:display]\n    print(f\"First {display} mismatches\")\n    print(list(quanto_values.numpy()))\n    print(list(awq_values.numpy()))\n    # Due to a slightly different order of operations (zero is multiplied by scale before subtracting it),\n    # there are some mismatches\n    assert rate < 5e-4\n"
  },
  {
    "path": "external/smoothquant/README.md",
    "content": "# SmoothQuant original conversion script\n\nThis converts an OPT or Bloom [🤗 transformers](https://github.com/huggingface/transformers) model to a \"smoothed\" version, as described in\n[SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models](https://arxiv.org/abs/2211.10438).\n\n```bash\n$ python smoothquant.py --model facebook/opt-1.3b --save-path smoothed-models/facebook/opt-1.3b\n```\n\nNote: due to hard-coded assumptions on model architecture in the script this only works for OPT models that apply the layer_norm\nbefore the attention (`do_layer_norm_before=true` in `config.json`). This means all models but `facebook/opt-350m`.\n"
  },
  {
    "path": "external/smoothquant/smoothquant.py",
    "content": "import argparse\nimport functools\nimport os\n\nimport torch\nimport torch.nn as nn\nfrom datasets import load_dataset\nfrom tqdm import tqdm\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\nfrom transformers.models.bloom.modeling_bloom import BloomBlock\nfrom transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm\nfrom transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralRMSNorm\nfrom transformers.models.opt.modeling_opt import OPTDecoderLayer\n\n\ndef get_act_scales(model, tokenizer, dataset, num_samples=512, seq_len=512):\n    model.eval()\n    device = next(model.parameters()).device\n    act_scales = {}\n\n    def stat_tensor(name, tensor):\n        hidden_dim = tensor.shape[-1]\n        tensor = tensor.view(-1, hidden_dim).abs().detach()\n        comming_max = torch.max(tensor, dim=0)[0].float().cpu()\n        if name in act_scales:\n            act_scales[name] = torch.max(act_scales[name], comming_max)\n        else:\n            act_scales[name] = comming_max\n\n    def stat_input_hook(m, x, y, name):\n        if isinstance(x, tuple):\n            x = x[0]\n        stat_tensor(name, x)\n\n    hooks = []\n    for name, m in model.named_modules():\n        if isinstance(m, nn.Linear):\n            hooks.append(m.register_forward_hook(functools.partial(stat_input_hook, name=name)))\n\n    for i in tqdm(range(num_samples)):\n        input_ids = tokenizer(\n            dataset[i][\"text\"], return_tensors=\"pt\", max_length=seq_len, truncation=True\n        ).input_ids.to(device)\n        model(input_ids)\n\n    for h in hooks:\n        h.remove()\n\n    return act_scales\n\n\n@torch.no_grad()\ndef smooth_ln_fcs(ln, fcs, act_scales, alpha=0.5):\n    if not isinstance(fcs, list):\n        fcs = [fcs]\n    assert isinstance(ln, (nn.LayerNorm, LlamaRMSNorm, MistralRMSNorm))\n    for fc in fcs:\n        assert isinstance(fc, nn.Linear)\n        assert ln.weight.numel() == 
fc.in_features == act_scales.numel()\n\n    device, dtype = fcs[0].weight.device, fcs[0].weight.dtype\n    act_scales = act_scales.to(device=device, dtype=dtype)\n    weight_scales = torch.cat([fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0)\n    weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)\n\n    scales = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5).to(device).to(dtype)\n\n    ln.weight.div_(scales)\n    if getattr(ln, 'bias', None) is not None:\n        ln.bias.div_(scales)\n\n    for fc in fcs:\n        fc.weight.mul_(scales.view(1, -1))\n\n\n@torch.no_grad()\ndef smooth_lm(model, scales, alpha=0.5):\n    for name, module in model.named_modules():\n        if isinstance(module, OPTDecoderLayer):\n            attn_ln = module.self_attn_layer_norm\n            qkv = [module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj]\n            qkv_input_scales = scales[name + \".self_attn.q_proj\"]\n            smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha)\n\n            ffn_ln = module.final_layer_norm\n            fc1 = module.fc1\n            fc1_input_scales = scales[name + \".fc1\"]\n            smooth_ln_fcs(ffn_ln, fc1, fc1_input_scales, alpha)\n        elif isinstance(module, BloomBlock):\n            attn_ln = module.input_layernorm\n            qkv = module.self_attention.query_key_value\n            qkv_input_scales = scales[name + \".self_attention.query_key_value\"]\n            smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha)\n\n            ffn_ln = module.post_attention_layernorm\n            fc1 = module.mlp.dense_h_to_4h\n            fc1_input_scales = scales[name + \".mlp.dense_h_to_4h\"]\n            smooth_ln_fcs(ffn_ln, fc1, fc1_input_scales, alpha)\n        elif isinstance(module, (LlamaDecoderLayer, MistralDecoderLayer)):\n            attn_ln = module.input_layernorm\n            qkv = [module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj]\n 
           qkv_input_scales = scales[name + \".self_attn.q_proj\"]\n            smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha)\n\n            ffn_ln = module.post_attention_layernorm\n            fc = [module.mlp.gate_proj, module.mlp.up_proj]\n            fc_input_scales = scales[name + \".mlp.gate_proj\"]\n            smooth_ln_fcs(ffn_ln, fc, fc_input_scales, alpha)\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model\", type=str, default=\"facebook/opt-125m\", help=\"model name\")\n    parser.add_argument(\"--save-path\", type=str, default=None, help=\"smoothed model model save path\")\n    parser.add_argument(\"--num-samples\", type=int, default=512)\n    parser.add_argument(\"--seq-len\", type=int, default=512)\n    parser.add_argument(\"--device\", type=str, default=None, help=\"The device to use for generation.\")\n    args = parser.parse_args()\n\n    if args.device is None:\n        if torch.cuda.is_available():\n            device = torch.device(\"cuda\")\n        elif torch.backends.mps.is_available():\n            device = torch.device(\"mps\")\n        else:\n            device = torch.device(\"cpu\")\n    else:\n        device = torch.device(args.device)\n\n    dataset = load_dataset(\"lambada\", split=f\"validation[:{args.num_samples}]\").shuffle()\n    tokenizer = AutoTokenizer.from_pretrained(args.model, model_max_length=args.seq_len)\n    model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=\"auto\").to(device)\n\n    act_scales = get_act_scales(model, tokenizer, dataset, args.num_samples, args.seq_len)\n    smooth_lm(model, act_scales, 0.5)\n    save_path = args.save_path\n    if save_path is None:\n        save_path = os.path.join(\"smoothed_models\", args.model)\n    model.save_pretrained(save_path)\n    tokenizer.save_pretrained(save_path)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "optimum/quanto/__init__.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n__version__ = \"0.2.7dev\"\n\nfrom .calibrate import *\nfrom .library import *\nfrom .models import *\nfrom .nn import *\nfrom .quantize import *\nfrom .tensor import *\n"
  },
  {
    "path": "optimum/quanto/calibrate.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Optional\n\nimport torch\nfrom torch.nn.modules.module import (\n    register_module_forward_hook,\n    register_module_forward_pre_hook,\n)\nfrom torch.overrides import TorchFunctionMode\n\nfrom .nn import QModuleMixin\nfrom .tensor import ActivationQBytesTensor, QTensor, axis_to_dim, dtype_info, qint8, qtype\n\n\n__all__ = [\"Calibration\", \"absmax_scale\"]\n\n\ndef _updated_scale(scale, new_scale, momentum):\n    if torch.all(scale == 1):\n        return new_scale\n    return momentum * scale + new_scale * (1.0 - momentum)\n\n\ndef absmax_scale(base: torch.Tensor, qtype: qtype = qint8, axis: Optional[int] = None) -> torch.Tensor:\n    \"\"\"Evaluate the quantization scale using the absmax algorithm.\n\n    The Absolute Maximum quantization algorithm is a symmetrical quantization\n    algorithm where the scale corresponds to the maximum absolute value of the\n    base divided by the highest positive integer value for the target integer\n    representation.\n\n    Args:\n        base (`torch.Tensor`): the base tensor on which the scale will be applied.\n        qtype (`quanto.qtype`): the target qtype for quantization.\n        axis (`int`): the index of the axis to preserve, or -1 for the last one.\n            Defaults to None to reduce all axis.\n\n    Returns:\n        `torch.Tensor`: a scale tensor of the same 
dtype as the base.\n    \"\"\"\n    base = torch.abs(base)\n    if axis is None:\n        qranges = torch.max(base)\n    else:\n        dim = axis_to_dim(base, axis)\n        qranges = torch.amax(base, dim=dim, keepdim=True)\n    info = dtype_info(qtype.dtype)\n    return qranges / info.max\n\n\nclass Calibration(TorchFunctionMode):\n    \"\"\"A custom torch dispatch mode to calibrate quantized modules.\n\n    In order to improve the accuracy of the quantized activations, the input and output\n    scales of each quantized module are evaluated per-batch using the absmax algorithm and aggregated using a\n    momentum.\n\n    The dispatch mode also tracks the calls to each torch function down the model graph, and applies optional\n    optimizations:\n    - streamline: do not quantize activations that are immediately consumed by an incompatible function (like `add` or `silu`).\n\n    Args:\n        momentum (`float`): the momentum to use when updating scales.\n        streamline (`bool`): if True, avoid quantizing activations when they are consumed by an incompatible function. 
Defaults to True.\n        debug (`bool`): provide very verbose feedback on the console during calibration.\n    \"\"\"\n\n    def __init__(self, *args, momentum: float = 0.9, streamline=True, debug=False, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.momentum = momentum\n        self.streamline = streamline\n        if streamline:\n            self.modules_qactivations = {}\n            self.streamline_hooks = {}\n        self.debug = debug\n\n    def __torch_function__(self, func, types, args=(), kwargs=None):\n        kwargs = kwargs if kwargs is not None else {}\n        qinput = QTensor in types\n        output = func(*args, **kwargs)\n        if self.streamline and qinput:\n            for i, arg in enumerate(args):\n                module = getattr(arg, \"src_module\", None)\n                if module is not None:\n                    if isinstance(output, ActivationQBytesTensor):\n                        # Quantized activations are required for that module\n                        self.modules_qactivations[module] = True\n                    elif isinstance(output, torch.Tensor):\n                        # Quantized activations are not required for that module unless another function requires them\n                        qactivations_required = self.modules_qactivations.get(module, False)\n                        self.modules_qactivations[module] = qactivations_required\n        return output\n\n    def __enter__(self):\n        super().__enter__()\n        self.pre_handle = register_module_forward_pre_hook(self.calibrate_input)\n        self.post_handle = register_module_forward_hook(self.calibrate_output)\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        super().__exit__(exc_type, exc_val, exc_tb)\n        self.pre_handle.remove()\n        self.post_handle.remove()\n        if self.streamline:\n            for handle in self.streamline_hooks.values():\n                handle.remove()\n\n    def calibrate_input(self, 
module: torch.nn.Module, input, momentum: float = 0.9):\n        \"\"\"Calibrate a module input scale\n\n        This is registered as a global hook that is called before any module forward pre hook.\n        \"\"\"\n        if isinstance(module, QModuleMixin) and module.activation_qtype is not None:\n            input = input[0]\n            if isinstance(input, ActivationQBytesTensor):\n                # Just adopt the maximum scale of the input\n                module.input_scale = torch.max(input._scale)\n            else:\n                # Evaluate the best scale\n                input_scale = absmax_scale(input, module.activation_qtype)\n                module.input_scale = _updated_scale(module.input_scale, input_scale, momentum)\n            if self.streamline and module not in self.streamline_hooks:\n                # Add a hook to tag the module outputs (after the module quantization hook in QModuleMixin)\n                self.streamline_hooks[module] = module.register_forward_hook(self.tag_outputs)\n            return input\n\n    def calibrate_output(\n        self,\n        module: torch.nn.Module,\n        input: torch.Tensor,\n        output: torch.Tensor,\n    ):\n        \"\"\"Calibrate a module output scale\n\n        This is registered as a global hook that is called before any module forward hook.\n\n        When the module is a QModuleMixin, its outputs are not quantized yet because they\n        are only quantized in the QModuleMixin.quantize_output forward hook.\n        \"\"\"\n        if isinstance(module, (QModuleMixin)) and module.activation_qtype is not None:\n            # Evaluate the optimal scale per-tensor and update output scale\n            output_scale = absmax_scale(output, module.activation_qtype, axis=None)\n            module.output_scale = _updated_scale(module.output_scale, output_scale, self.momentum)\n            return output\n        else:\n            if self.streamline:\n                for name, child in 
module.named_children():\n                    if isinstance(child, QModuleMixin) and child.activation_qtype is not None:\n                        qactivations_required = self.modules_qactivations.get(child, False)\n                        if not qactivations_required:\n                            # Disable output quantization for this child as its outputs are only consumed by incompatible functions.\n                            child.disable_output_quantization()\n            if self.debug:\n                for name, child in module.named_children():\n                    if isinstance(child, QModuleMixin):\n                        classname = child.__class__.__name__\n                        trace = f\"{name}({classname}) activations are\"\n                        if child.activation_qtype is None:\n                            trace += \" not quantized.\"\n                        else:\n                            trace += f\" quantized to {child.activation_qtype} with scale {child.output_scale}.\"\n                        print(trace)\n\n    def tag_outputs(\n        self,\n        module: torch.nn.Module,\n        input: torch.Tensor,\n        output: torch.Tensor,\n    ):\n        \"\"\"Mark outputs as generated by a module\n\n        This is called as a module forward hook that is called after the QModuleMixin.quantize_output\n        forward hook.\n\n        This is useful in streamline mode to identify the module that generated a specific QTensor.\n        \"\"\"\n        output.src_module = module\n"
  },
  {
    "path": "optimum/quanto/library/README.md",
    "content": "# Quanto operations library\n\nThis contains the `quanto::` operations, available in python under `torch.ops.quanto`.\n\nTo add a new operation:\n\n- add a definition for the operation in `library/ops.py`,\n- provide a default implementation using pytorch operators only under `library/python`,\n- provide optimized kernels for all devices under `library/ext`.\n"
  },
  {
    "path": "optimum/quanto/library/__init__.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .extensions import *\nfrom .qbytes_mm import *\nfrom .quantize import *\nfrom .unpack import *\n"
  },
  {
    "path": "optimum/quanto/library/extensions/README.md",
    "content": "# Quanto library extensions\n\nThis folder contains device-specific `quanto::` operations.\n\nImplementations can be provided as part of:\n\n- the generic C++ pytorch extension under `cpp`,\n- the CUDA extension under `cuda`,\n- the Metal Performance Shader extension under `mps`,\n- the XPU SYCL extension under `xpu`.\n\n\nTo provide a device-specific implementation of an operation that already has a default implementation (such as unpack), use the following syntax:\n\n```python\n@torch.library.impl(\"quanto::unpack\", [\"CPU\", \"CUDA\"])\ndef unpack(packed: torch.Tensor, bits: int) -> torch.Tensor:\n    return ext.unpack(t, bits)\n```\n\nTo declare a new device-specific operation, you need to add it to the library:\n\n```python\ntorch.library.define(\n    \"quanto::gemm_f16i4\",\n    \"(Tensor input,\"\n    \" Tensor other,\"\n    \" Tensor other_scale,\"\n    \" Tensor other_shift,\"\n    \" int group_size)\"\n    \" -> Tensor\",\n)\n```\n\nThen you can provide its implementation:\n\n```python\n@torch.library.impl(\"quanto::gemm_f16i4\", [\"CUDA\"])\ndef gemm_f16i4(\n    input: torch.Tensor,\n    other: torch.Tensor,\n    scales: torch.Tensor,\n    shift: torch.Tensor,\n    group_size: int,\n) -> torch.Tensor:\n    ...\n```\n\n\nPlease refer to each extension folder for examples.\n"
  },
  {
    "path": "optimum/quanto/library/extensions/__init__.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport platform\n\nimport torch\nfrom packaging import version\n\nfrom .cpp import *\nfrom .extension import *\n\n\nif torch.cuda.is_available() and platform.system() == \"Linux\":\n    if torch.version.cuda:\n        from .cuda import *\n    elif torch.version.hip:\n        from .hip import *\n\nif torch.backends.mps.is_available():\n    from .mps import *\n\n\ndef _is_xpu_available():\n    # SYCL extension support is added in torch>=2.7 on Linux\n    if platform.system() != \"Linux\":\n        return False\n    if version.parse(torch.__version__).release < version.parse(\"2.7\").release:\n        return False\n    return torch.xpu.is_available()\n\n\nif _is_xpu_available():\n    from .xpu import *\n"
  },
  {
    "path": "optimum/quanto/library/extensions/cpp/README.md",
    "content": "# Quanto generic C++ extension\n\nKernels in this extension must use only the C++ syntax.\n\nThey can use any pytorch operation defined under `aten::` or `c10::`.\n\nTo add a new implementation for an operation defined in `library./ops.py`:\n\n- add the corresponding `.cpp` file to the list of sources in `__init__.py`,\n- add a binding to `pybind_module.cpp`,\n- provide an implementation calling the binding in `__init__.py`.\n"
  },
  {
    "path": "optimum/quanto/library/extensions/cpp/__init__.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nimport torch\n\nfrom ..extension import Extension, register_extension\n\n\n__all__ = []\n\n\next = Extension(\n    \"quanto_cpp\",\n    root_dir=os.path.dirname(__file__),\n    sources=[\"unpack.cpp\", \"pybind_module.cpp\"],\n    extra_cflags=[\"-O3\"],\n)\nregister_extension(ext)\n\n\n@torch.library.impl(\"quanto::unpack\", [\"CPU\"])\ndef unpack_cpp(t: torch.Tensor, bits: int):\n    return ext.lib.unpack(t, bits)\n"
  },
  {
    "path": "optimum/quanto/library/extensions/cpp/pybind_module.cpp",
    "content": "// Copyright 2024 The HuggingFace Team. All rights reserved.\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#include <torch/extension.h>\n#include \"unpack.h\"\n\n// !IMPORTANT! Some python objects such as dtype, device, are not mapped to C++ types,\n// and need to be explicitly converted using dedicated helpers before calling a C++ method.\n// As a consequence, when an operation takes such an object as parameter, instead\n// of creating a binding directly to the C++ method, you must create a binding to a\n// lambda method that converts the unmapped types and calls the C++ method.\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"unpack\", &unpack, \"unpack\");\n}\n"
  },
  {
    "path": "optimum/quanto/library/extensions/cpp/unpack.cpp",
    "content": "// Copyright 2024 The HuggingFace Team. All rights reserved.\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#include \"unpack.h\"\n#include <torch/extension.h>\n\n\nstatic torch::Tensor unpack_4bit(torch::Tensor &t) {\n\treturn torch::cat({\n                      (t & 0x0F),\n                      (t & 0xF0).__rshift__(4)\n                    },\n                    0);\n}\n\nstatic torch::Tensor unpack_2bit(torch::Tensor &t) {\n\treturn torch::cat({\n                      (t & 0x03),\n                      (t & 0x0C).__rshift__(2),\n                      (t & 0x30).__rshift__(4),\n                      (t & 0xC0).__rshift__(6)\n                    },\n                    0);\n}\n\ntorch::Tensor unpack(torch::Tensor &t, int bits) {\n    TORCH_CHECK(t.scalar_type() == torch::kUInt8, \"Unsupported data type: \", t.scalar_type());\n    switch(bits) {\n      case 4:\n        return unpack_4bit(t);\n      case 2:\n        return unpack_2bit(t);\n      default:\n        throw std::invalid_argument(\"Can only unpack 2-bit or 4-bit tensors.\");\n    }\n}\n"
  },
  {
    "path": "optimum/quanto/library/extensions/cpp/unpack.h",
    "content": "// Copyright 2024 The HuggingFace Team. All rights reserved.\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#include <torch/extension.h>\n\ntorch::Tensor unpack(torch::Tensor &t, int bits);\n"
  },
  {
    "path": "optimum/quanto/library/extensions/cuda/README.md",
    "content": "# Quanto generic CUDA extension\n\nKernels in this extension can use both the C++ and CUDA syntax.\n\nThey can use any pytorch operation defined under `aten::` or `c10::`.\n\nTo add a new implementation for an operation defined in `library./ops.py`:\n\n- add the corresponding `.cpp` or `.cu` file to the list of sources in `__init__.py`,\n- add a binding to `pybind_module.cpp`,\n- provide an implementation calling the binding in `__init__.py`.\n"
  },
  {
    "path": "optimum/quanto/library/extensions/cuda/__init__.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nimport torch\n\nfrom ..extension import Extension, register_extension\n\n\n__all__ = []\n\n\ndef get_max_cuda_arch():\n    \"\"\"Select the maximum CUDA arch supported\n\n    This is a combination of the CUDA and pytorch version and all detected devices capabilities.\n    \"\"\"\n    capability_list = []\n    supported_sm = [int(arch.split(\"_\")[1]) for arch in torch.cuda.get_arch_list() if \"sm_\" in arch]\n    if supported_sm:\n        max_supported_sm = max((sm // 10, sm % 10) for sm in supported_sm)\n        for i in range(torch.cuda.device_count()):\n            capability = torch.cuda.get_device_capability(i)\n            # Capability of the device may be higher than what's supported by the user's\n            # NVCC, causing compilation error. 
User's NVCC is expected to match the one\n            # used to build pytorch, so we use the maximum supported capability of pytorch\n            # to clamp the capability.\n            capability = min(max_supported_sm, capability)\n            if capability not in capability_list:\n                capability_list.append(capability)\n    max_capability = max(sorted(capability_list)) if len(capability_list) > 0 else (0, 0)\n    return f\"{max_capability[0]}{max_capability[1]}0\"\n\n\nextra_cflags = [\"-g\", \"-O3\"]\nextra_cuda_cflags = [\n    \"--expt-extended-lambda\",\n    \"--use_fast_math\",\n]\n# We need to know the minimum CUDA Arch to select only the relevant kernels\n# but we cannot rely on __CUDA_ARCH__ as it is not set in host code (only on device code)\nquanto_cuda_arch = get_max_cuda_arch()\nextra_cuda_cflags += [f\"-DQUANTO_CUDA_ARCH={quanto_cuda_arch}\"]\nmodule_path = os.path.dirname(__file__)\nsources = [\n    \"unpack.cu\",\n    \"awq/v2/gemm_cuda.cu\",\n    \"awq/v2/gemv_cuda.cu\",\n    \"marlin/fp8_marlin.cu\",\n    \"marlin/gptq_marlin_repack.cu\",\n    \"marlin/marlin_cuda.cpp\",\n    \"marlin/marlin_cuda_kernel.cu\",\n    \"pybind_module.cpp\",\n]\next = Extension(\n    \"quanto_cuda\",\n    root_dir=os.path.dirname(__file__),\n    sources=sources,\n    extra_cflags=extra_cflags,\n    extra_cuda_cflags=extra_cuda_cflags,\n)\nregister_extension(ext)\n\n\n@torch.library.impl(\"quanto::unpack\", [\"CUDA\"])\ndef unpack_cuda(t: torch.Tensor, bits: int):\n    return ext.lib.unpack(t, bits)\n\n\ntorch.library.define(\n    \"quanto::gemm_f16i4_awq\",\n    \"(Tensor input,\"\n    \" Tensor other,\"\n    \" Tensor other_scale,\"\n    \" Tensor other_shift,\"\n    \" int rows,\"\n    \" int out_cols,\"\n    \" int in_cols,\"\n    \" int bits,\"\n    \" int group_size)\"\n    \" -> Tensor\",\n)\n\n\n@torch.library.impl(\"quanto::gemm_f16i4_awq\", [\"CUDA\"])\ndef gemm_f16i4_awq(\n    input: torch.Tensor,\n    other: torch.Tensor,\n    scales: 
torch.Tensor,\n    shift: torch.Tensor,\n    rows: int,\n    out_cols: int,\n    in_cols: int,\n    bits: int,\n    group_size: int,\n):\n    assert out_cols >= 128\n    assert input.dtype == torch.float16\n    assert input.numel() == rows * in_cols\n    assert other.dtype == torch.int16\n    assert scales.dtype == torch.float16\n    assert scales.shape[-1] == out_cols\n    assert shift.dtype == torch.float16\n    assert shift.shape[-1] == out_cols\n    assert bits == 4\n    assert group_size == 128\n    if rows < 8:\n        return ext.lib.awq_v2_gemv_f16i4(input, other, scales, shift, rows, out_cols, in_cols, group_size)\n    return ext.lib.awq_v2_gemm_f16i4(input, other, scales, shift)\n\n\ntorch.library.define(\n    \"quanto::gemm_f16f8_marlin\",\n    \"(Tensor a,\"\n    \"Tensor b_q_weight,\"\n    \"Tensor b_scales,\"\n    \"Tensor workspace,\"\n    \"int num_bits,\"\n    \"int size_m,\"\n    \"int size_n,\"\n    \"int size_k)\"\n    \" -> Tensor\",\n)\n\n\n@torch.library.impl(\"quanto::gemm_f16f8_marlin\", [\"CUDA\"])\ndef fp8_marlin_gemm(\n    a: torch.Tensor,\n    b_q_weight: torch.Tensor,\n    b_scales: torch.Tensor,\n    workspace: torch.Tensor,\n    num_bits: int,\n    size_m: int,\n    size_n: int,\n    size_k: int,\n) -> torch.Tensor:\n    assert b_scales.dtype == torch.float16 or b_scales.dtype == torch.bfloat16\n    assert b_q_weight.dim() == 2\n    assert b_q_weight.dtype == torch.int32\n    return ext.lib.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k)\n\n\ntorch.library.define(\n    \"quanto::pack_fp8_marlin\",\n    \"(Tensor b_q_weight, Tensor perm, int size_k, int size_n, int num_bits) -> Tensor\",\n)\n\n\n@torch.library.impl(\"quanto::pack_fp8_marlin\", [\"CUDA\"])\ndef gptq_marlin_repack(\n    b_q_weight: torch.Tensor, perm: torch.Tensor, size_k: int, size_n: int, num_bits: int\n) -> torch.Tensor:\n    assert b_q_weight.dim() == 2\n    assert b_q_weight.dtype == torch.int32\n    return 
ext.lib.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits)\n\n\ntorch.library.define(\n    \"quanto::gemm_f16i4_marlin\",\n    \"(Tensor input, Tensor other, Tensor other_scale, Tensor other_shift, Tensor workspace) -> Tensor\",\n)\n\n\n@torch.library.impl(\"quanto::gemm_f16i4_marlin\", [\"CUDA\"])\ndef gemm_f16i4_marlin(\n    input: torch.Tensor, other: torch.Tensor, scales: torch.Tensor, zeropoint: torch.Tensor, workspace: torch.Tensor\n) -> torch.Tensor:\n    assert input.dtype == torch.float16\n    assert other.dtype == torch.int32\n    assert scales.dtype == torch.float16\n    assert zeropoint.dtype == torch.float16\n    assert workspace.dtype == torch.int32\n    output = torch.empty(\n        input.shape[:-1] + (scales.shape[1],),\n        dtype=input.dtype,\n        device=input.device,\n    )\n    ext.lib.marlin_gemm_f16i4(\n        input.reshape((-1, input.shape[-1])),\n        other,\n        output.reshape((-1, output.shape[-1])),\n        scales,\n        zeropoint,\n        workspace,\n        -1,\n        -1,\n        -1,\n        16,\n    )\n    return output\n"
  },
  {
    "path": "optimum/quanto/library/extensions/cuda/awq/dequantize.cuh",
    "content": "/*\nModified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h\n\n@article{lin2023awq,\n  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},\n  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},\n  journal={arXiv},\n  year={2023}\n}\n*/\n#include <cuda_fp16.h>\n#pragma once\n\n__inline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uint4 *result)\n{\n    // uint4 result;\n\n    uint32_t *h = reinterpret_cast<uint32_t *>(result);\n    uint32_t const i4s = reinterpret_cast<uint32_t const &>(source);\n\n    // First, we extract the i4s and construct an intermediate fp16 number.\n    static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;\n    static constexpr uint32_t BOTTOM_MASK = 0x000f000f;\n    static constexpr uint32_t TOP_MASK = 0x00f000f0;\n    static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;\n\n    // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing\n    // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.\n    // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and\n    // elt_67 to fp16 without having to shift them to the bottom bits before hand.\n\n    // Shift right by 8 to now consider elt_45 and elt_67. 
Issue first to hide RAW dependency if we issue\n    // immediately before required.\n    const uint32_t top_i4s = i4s >> 8;\n    // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400\n    asm volatile(\"lop3.b32 %0, %1, %2, %3, %4;\\n\"\n                 : \"=r\"(h[0])\n                 : \"r\"(i4s), \"n\"(BOTTOM_MASK), \"n\"(I4s_TO_F16s_MAGIC_NUM), \"n\"(immLut));\n    // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400\n    asm volatile(\"lop3.b32 %0, %1, %2, %3, %4;\\n\"\n                 : \"=r\"(h[1])\n                 : \"r\"(i4s), \"n\"(TOP_MASK), \"n\"(I4s_TO_F16s_MAGIC_NUM), \"n\"(immLut));\n    // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400\n    asm volatile(\"lop3.b32 %0, %1, %2, %3, %4;\\n\"\n                 : \"=r\"(h[2])\n                 : \"r\"(top_i4s), \"n\"(BOTTOM_MASK), \"n\"(I4s_TO_F16s_MAGIC_NUM), \"n\"(immLut));\n    // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400\n    asm volatile(\"lop3.b32 %0, %1, %2, %3, %4;\\n\"\n                 : \"=r\"(h[3])\n                 : \"r\"(top_i4s), \"n\"(TOP_MASK), \"n\"(I4s_TO_F16s_MAGIC_NUM), \"n\"(immLut));\n\n    // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the\n    // half2 ctor. 
In this case, I chose performance reliability over code readability.\n\n    // This is the half2 {1032, 1032} represented as an integer.\n    // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;\n    // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]\n    static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;\n    // This is the half2 {1 / 16, 1 / 16} represented as an integer.\n    static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;\n    // This is the half2 {-72, -72} represented as an integer.\n    // static constexpr uint32_t NEG_72 = 0xd480d480;\n    // Haotian: Let's use {-64, -64}.\n    static constexpr uint32_t NEG_64 = 0xd400d400;\n\n    // Finally, we construct the output numbers.\n    // Convert elt_01\n    asm volatile(\"sub.f16x2 %0, %1, %2;\\n\" : \"=r\"(h[0]) : \"r\"(h[0]), \"r\"(FP16_TOP_MAGIC_NUM));\n    // Convert elt_23\n    asm volatile(\"fma.rn.f16x2 %0, %1, %2, %3;\\n\" : \"=r\"(h[1]) : \"r\"(h[1]), \"r\"(ONE_SIXTEENTH), \"r\"(NEG_64));\n    // Convert elt_45\n    asm volatile(\"sub.f16x2 %0, %1, %2;\\n\" : \"=r\"(h[2]) : \"r\"(h[2]), \"r\"(FP16_TOP_MAGIC_NUM));\n    // Convert elt_67\n    asm volatile(\"fma.rn.f16x2 %0, %1, %2, %3;\\n\" : \"=r\"(h[3]) : \"r\"(h[3]), \"r\"(ONE_SIXTEENTH), \"r\"(NEG_64));\n\n    // return result;\n}"
  },
  {
    "path": "optimum/quanto/library/extensions/cuda/awq/v2/gemm_cuda.cu",
    "content": "#include <cuda_fp16.h>\n#include \"semaphore.h\"\n#include \"gemm_cuda.h\"\n#include \"../dequantize.cuh\"\n#include <torch/extension.h>\n#include <cuda_pipeline_primitives.h>\n\n#if defined(QUANTO_CUDA_ARCH) and QUANTO_CUDA_ARCH >= 800\n// The following GEMMs requires m16n8k16 which is only supported for CUDA arch after sm80\n\n#define kInterleave 4\n#define OP_M 16\n#define OP_N 8\n#define OP_K 16\n#define INTRIN_M 16\n#define INTRIN_N 16\n#define INTRIN_K 16\n#define WARP_SIZE 32\n#define SMEM_PAD_A 0\n#define SMEM_PAD_B 0\n#define PACK_SIZE 8\n#if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 4)\n#define L2_CACHEHINT(size) \".L2::\" #size \"B\"\n#else\n#define L2_CACHEHINT(size)\n#endif\n\n#define KERNEL_LAUNCH_CODE                                                                                                                              \\\n  int num_mn_tiles = (num_in_feats + CTA_M - 1) / CTA_M * (num_out_channels + CTA_N - 1) / CTA_N;                                                       \\\n  torch::Tensor _semaphores = torch::empty({num_mn_tiles}, options_int);                                                                                \\\n  auto semaphores = reinterpret_cast<int *>(_semaphores.data_ptr<int>());                                                                               \\\n  constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N) * (CTA_K / WARP_K);                                                                     \\\n  constexpr int SCALES_SMEM_SIZE = (G >= CTA_K) ? 
(CTA_N / (G / CTA_K) * STAGES * 2) : (CTA_N * (CTA_K / G) * STAGES * 2);                              \\\n  constexpr int kSmemByteSize = (CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / kInterleave + SCALES_SMEM_SIZE) * STAGES * sizeof(half); \\\n  if (kSmemByteSize >= 99 * 1024)                                                                                                                       \\\n  {                                                                                                                                                     \\\n    printf(\"This kernel requires %d Bytes of shared memory, which exceeds device limit.\\n\", kSmemByteSize);                                             \\\n    return _out_feats;                                                                                                                                  \\\n  }                                                                                                                                                     \\\n  int j_factors1 = num_out_channels / CTA_N / 1;                                                                                                        \\\n  dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1 * SPLITK);                                                                           \\\n  dim3 threads_per_block(WARP_SIZE, NUM_WARPS);                                                                                                         \\\n  auto kernel_func = gemm_w4a16_T1<CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G, SPLITK>;                                                     \\\n  cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);                                                        \\\n  kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>(                                                                                        \\\n      in_feats, 
kernel, scales, zeros, out_feats, semaphores, num_in_feats, num_out_channels, num_in_channels);\n\ntemplate <int N>\n__inline__ __host__ __device__ int get_log_tile(int n)\n{\n  if (N >= 8 && n >= 6)\n    return 3;\n  else if (N >= 4 && n >= 3)\n    return 2;\n  else if (N >= 2 && n >= 2)\n    return 1;\n  else\n    return 0;\n}\n\n__inline__ __device__ uint2 get_block_idx_mapping(int blockIdx_x, int blockIdx_y, int log_tile)\n{\n  return make_uint2((blockIdx_x >> log_tile), (blockIdx_y << log_tile) + ((blockIdx_x) & ((1 << (log_tile)) - 1)));\n}\n\ntemplate <int SLICES, int NUM_WARPS_MN>\n__device__ void sync_slice(int slice_id)\n{\n  if constexpr (SLICES == 1)\n  {\n    __syncthreads();\n  }\n  else\n  {\n    constexpr int SLICE_GROUP = (SLICES + 7) / 8;\n    constexpr uint32_t num_threads = NUM_WARPS_MN * WARP_SIZE;\n    const uint32_t barrier_id = slice_id / SLICE_GROUP + 1;\n    asm volatile(\"bar.sync %0, %1;\" : : \"r\"(barrier_id), \"n\"(num_threads));\n  }\n}\n\n__inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const *const ptr)\n{\n  uint32_t smem_int_ptr;\n\n  asm(\"{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\\n\"\n      : \"=r\"(smem_int_ptr)\n      : \"l\"(ptr));\n\n  return smem_int_ptr;\n}\n\n__inline__ __device__ void ldmatrix_m8n8_x4_b16(half *shared_warp, int ax0_0, uint32_t addr)\n{\n  __asm__ __volatile__(\n      \"ldmatrix.sync.aligned.m8n8.x4.shared.b16\"\n      \"{%0, %1, %2, %3}, [%4];\"\n      : \"=r\"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]), \"=r\"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]), \"=r\"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]), \"=r\"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])\n      : \"r\"(addr));\n}\n\n__inline__ __device__ void ldmatrix_m8n8_x4_trans_b16(half *shared_warp, int ax0_0, uint32_t addr)\n{\n  __asm__ __volatile__(\n      \"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16\"\n      \"{%0, %1, %2, %3}, [%4];\"\n      : \"=r\"(((unsigned *)(shared_warp + 
(ax0_0 * 8)))[0]), \"=r\"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]), \"=r\"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]), \"=r\"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])\n      : \"r\"(addr));\n}\n\n__inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const uint4 *__restrict__ src, bool mask)\n{\n  const int cp_size = 16;\n  asm volatile(\"{\"\n               \"  .reg .pred p;\"\n               \"  setp.ne.b32 p, %0, 0;\"\n               \"  @p cp.async.cg.shared.global\" L2_CACHEHINT(128) \" [%1], [%2], %3;\"\n                                                                  \"}\" ::\"r\"((int)mask),\n               \"r\"(smem_int_ptr),\n               \"l\"(src),\n               \"n\"(cp_size));\n}\n\n__device__ __inline__ void mma_m16n8k16(float *C_warp, half *A_shared_warp, half *B_shared_warp)\n{\n  __asm__ __volatile__(\n      \"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32\"\n      \"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\"\n      : \"=f\"(((float *)C_warp)[0]), \"=f\"(((float *)C_warp)[1]), \"=f\"(((float *)C_warp)[2]), \"=f\"(((float *)C_warp)[3])\n      : \"r\"(((unsigned *)A_shared_warp)[0]), \"r\"(((unsigned *)A_shared_warp)[1]), \"r\"(((unsigned *)A_shared_warp)[2]), \"r\"(((unsigned *)A_shared_warp)[3]), \"r\"(((unsigned *)B_shared_warp)[0]), \"r\"(((unsigned *)B_shared_warp)[1]), \"f\"(((float *)C_warp)[0]), \"f\"(((float *)C_warp)[1]), \"f\"(((float *)C_warp)[2]), \"f\"(((float *)C_warp)[3]));\n}\n\ntemplate <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>\n__device__ __inline__ void global_to_share_one_stage_A(half *src, half *dst, int global_nrows, int global_ncols, int cta_offset_m, int cta_offset_n, int cta_offset_k, int global_iter_k, int shared_iter_k, bool mask)\n{\n  constexpr int threads_needed = (CTA_M * CTA_K) / PACK_SIZE / SHARED_K_ITERS;\n  constexpr int threads_used = threads_needed < CTA_SIZE ? 
threads_needed : CTA_SIZE;\n  constexpr int total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / threads_used;\n  constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;\n  constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K;\n  constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;\n  constexpr int threads_per_row = CTA_K / PACK_SIZE;\n  constexpr int kSmemCol = CTA_K + SMEM_PAD_A;\n  bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);\n  int ld_col = (threadIdx.x % threads_per_row);\n#pragma unroll\n  for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter)\n  {\n    int global_iter = shared_iter_k * partial_global_iters + _global_iter;\n    int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row);\n    int ld_col_swizzled = (ld_col ^ (ld_row) & 7) * PACK_SIZE;\n    void *dst_ptr = (void *)(dst + ld_row * kSmemCol + ld_col_swizzled);\n    uint4 *src_ptr = (uint4 *)(src + (ld_row + cta_offset_m) * global_ncols + ld_col * PACK_SIZE + global_iter_k * CTA_K + cta_offset_k); // cta_offset_m * global_ncols + global_iter * cta_step_m_or_n * global_ncols + threadIdx.y * warp_step_m_or_n * global_ncols + (threadIdx.x / threads_per_row) * global_ncols + global_iter_k * CTA_K + (threadIdx.x % threads_per_row) * PACK_SIZE);\n    if constexpr (STAGES > 1)\n    {\n      uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);\n      cp_async_cg_A(addr, src_ptr, local_mask & (ld_row + cta_offset_m < global_nrows));\n    }\n    else\n    {\n      if (local_mask & (ld_row + cta_offset_m < global_nrows))\n        *(uint4 *)dst_ptr = *src_ptr;\n    }\n  }\n}\n\ntemplate <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>\n__device__ __inline__ void global_to_share_one_stage_B(half *src, half *dst, int global_ncols, int cta_offset_m, int cta_offset_n, int cta_offset_k, int global_iter_k, int 
shared_iter_k, bool mask)\n{\n  constexpr int threads_needed = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / SHARED_K_ITERS;\n  constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;\n  constexpr int total_global_iters = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / threads_used;\n  constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;\n  constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K;\n  constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;\n  constexpr int threads_per_row = CTA_K / PACK_SIZE;\n  constexpr int kSmemCol = CTA_K + SMEM_PAD_B;\n  bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);\n#pragma unroll\n  for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter)\n  {\n    int global_iter = shared_iter_k * partial_global_iters + _global_iter;\n\n    int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row);\n    int ld_col = (threadIdx.x % threads_per_row);\n    int ld_col_swizzled = ld_col ^ (ld_row % 2) & 7;\n    void *dst_ptr = (void *)(dst + (ld_row * kSmemCol + ld_col_swizzled * PACK_SIZE));\n    uint4 *src_ptr = (uint4 *)(src + global_iter_k * CTA_K + cta_offset_n / kInterleave * global_ncols + ld_row * global_ncols + ld_col * PACK_SIZE + cta_offset_k);\n    if constexpr (STAGES > 1)\n    {\n      uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);\n      cp_async_cg_A(addr, src_ptr, local_mask);\n    }\n    else\n    {\n      if (local_mask)\n        *(uint4 *)dst_ptr = *src_ptr;\n    }\n  }\n}\n\ntemplate <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES, int G>\n__device__ __inline__ void global_to_share_one_stage_scales(half *src, half *dst, half *src_z, half *dst_z, int global_ncols, int cta_offset_m, int cta_offset_n, int cta_offset_k, int global_iter_k, int shared_iter_k, bool mask)\n{\n  constexpr int LD_AMOUNT = (G >= CTA_K) ? 
CTA_N : CTA_N * CTA_K / G;\n  constexpr int threads_needed = LD_AMOUNT / PACK_SIZE / 1;\n  constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;\n  constexpr int threads_per_row = CTA_N / PACK_SIZE;\n  constexpr int kSmemCol = CTA_N;\n  bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);\n  int g_idx = (cta_offset_k + global_iter_k * CTA_K) / G;\n\n  void *dst_ptr = (void *)(dst + (threadIdx.x / threads_per_row) * kSmemCol + (threadIdx.x % threads_per_row) * PACK_SIZE);\n  uint4 *src_ptr = (uint4 *)(src + g_idx * global_ncols + cta_offset_n + (threadIdx.x / threads_per_row) * global_ncols + (threadIdx.x % threads_per_row) * PACK_SIZE);\n  void *dst_ptr_z = (void *)(dst_z + (threadIdx.x / threads_per_row) * kSmemCol + (threadIdx.x % threads_per_row) * PACK_SIZE);\n  uint4 *src_ptr_z = (uint4 *)(src_z + g_idx * global_ncols + cta_offset_n + (threadIdx.x / threads_per_row) * global_ncols + (threadIdx.x % threads_per_row) * PACK_SIZE);\n  if (STAGES > 1)\n  {\n    uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);\n    cp_async_cg_A(addr, src_ptr, local_mask);\n    uint32_t addr_z = cast_smem_ptr_to_uint(dst_ptr_z);\n    cp_async_cg_A(addr_z, src_ptr_z, local_mask);\n  }\n  else\n  {\n    if (local_mask)\n    {\n      *(uint4 *)dst_ptr = *src_ptr;\n      *(uint4 *)dst_ptr_z = *src_ptr_z;\n    }\n  }\n}\n\ntemplate <int CTA_M, int CTA_N, int CTA_K, int STAGES, int shared_iters>\n__device__ __inline__ void share_to_reg_one_stage_A(half *src, half *dst, int warp_offset_m, int warp_offset_n, int warp_offset_k, int k_0_1)\n{\n  constexpr int kSmemCol = CTA_K + SMEM_PAD_A;\n\n  for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)\n  {\n\n    int ld_row = warp_offset_m + shared_iter * OP_M + (threadIdx.x % 16);\n    int ld_col = k_0_1 * 16 + (threadIdx.x / 16) * 8 + warp_offset_k;\n    int ld_col_swizzled = ((ld_col / PACK_SIZE) ^ (ld_row) & 7) * PACK_SIZE;\n    void *addr_ptr = (void *)(src + ld_row * 
kSmemCol + ld_col_swizzled);\n\n    uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);\n    ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);\n  }\n}\n\ntemplate <int CTA_M, int CTA_N, int CTA_K, int STAGES, bool ldmatrix, int shared_iters, int G>\n__device__ __inline__ void share_to_reg_one_stage_B(half *src, half *src_scales, half *src_zeros, half *dst, half *dst_fp16, int warp_offset_m, int warp_offset_n, int warp_offset_k, int k_0_1)\n{\n  constexpr int kSmemCol = CTA_K + SMEM_PAD_B;\n  int r0 = ((threadIdx.x / 8 / 2) * 8 + threadIdx.x % 8);\n  int c0 = ((threadIdx.x / 8) % 2) * 8;\n  int r = r0 / 4;\n  int c = (r0 % 4) * 16 + c0;\n  int c_swizzled = ((c / PACK_SIZE) ^ (r % 2) & 7) * PACK_SIZE;\n\n  if constexpr (ldmatrix)\n  {\n#pragma unroll\n    for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)\n    {\n      void *addr_ptr = (void *)(src + warp_offset_n / kInterleave * kSmemCol + shared_iter * 16 / kInterleave * kSmemCol + k_0_1 * 16 + r * kSmemCol + c_swizzled + warp_offset_k);\n      uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);\n      ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);\n    }\n  }\n\n#pragma unroll\n  for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)\n  {\n    half scale = src_scales[(warp_offset_k / G) * CTA_N + warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4];\n    half zero = src_zeros[(warp_offset_k / G) * CTA_N + warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4];\n    half2 scale2 = make_half2(scale, scale);\n    half2 zero2 = make_half2(zero, zero);\n    half2 loaded[4];\n\n    dequantize_s4_to_fp16x2(*reinterpret_cast<half2 *>(dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + shared_iter * 8), reinterpret_cast<uint4 *>(loaded));\n#pragma unroll\n    for (int i = 0; i < 4; i++)\n    {\n      loaded[i] = __hfma2(loaded[i], scale2, zero2);\n    }\n    *reinterpret_cast<uint4 *>(dst_fp16 + shared_iter * 16 + 8 * (k_0_1 % 2)) = *reinterpret_cast<uint4 *>(loaded);\n  
}\n}\n\ntemplate <int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K, int STAGES, int G, int SPLITK>\n__global__ void gemm_w4a16_T1(half *__restrict__ A, half *__restrict__ B, half *__restrict__ scales, half *__restrict__ zeros, half *__restrict__ C, int *__restrict__ semaphores, int M, int N, int K)\n{\n  constexpr int NUM_WARPS_MN = CTA_M / WARP_M * CTA_N / WARP_N;\n  constexpr int NUM_WARPS = NUM_WARPS_MN * CTA_K / WARP_K;\n  constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE;\n  constexpr int CTA_SIZE_MN = NUM_WARPS_MN * WARP_SIZE;\n  constexpr int SLICES = CTA_K / WARP_K;\n  int num_blocks_n = (N + CTA_N - 1) / CTA_N;\n  int num_blocks_m = (M + CTA_M - 1) / CTA_M;\n  int blockIdx_y = blockIdx.x % (num_blocks_m * num_blocks_n);\n  int blockIdx_z = blockIdx.x / (num_blocks_m * num_blocks_n);\n  const int log_tile = get_log_tile<1>((N + CTA_N - 1) / CTA_N);\n  int blockIdx_m = blockIdx_y / (num_blocks_n >> log_tile);\n  int blockIdx_n = blockIdx_y % (num_blocks_n >> log_tile);\n  const uint2 block_idx_mapping = get_block_idx_mapping(blockIdx_m, blockIdx_n, log_tile);\n  blockIdx_m = block_idx_mapping.x;\n  blockIdx_n = block_idx_mapping.y;\n\n  float C_warp[CTA_M * CTA_N / CTA_SIZE_MN];\n  constexpr int kSmemPadKA = CTA_K + SMEM_PAD_A;\n  constexpr int kSmemPadKB = CTA_K + SMEM_PAD_B;\n  constexpr int kSmemSizeAPerStage = CTA_M * kSmemPadKA;\n  constexpr int kSmemSizeBPerStage = CTA_N / kInterleave * kSmemPadKB;\n  constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES;\n  constexpr int kSmemSizeB = kSmemSizeBPerStage * STAGES;\n  constexpr int scales_load_interval = G >= CTA_K ? G / CTA_K : 1;\n  constexpr int scales_per_load = G < CTA_K ? 
CTA_K / G : 1;\n  constexpr int kSmemSizeScales = CTA_N * STAGES / scales_load_interval * scales_per_load;\n  extern __shared__ half mem_shared[];\n  half *A_shared = mem_shared;\n  half *B_shared = mem_shared + kSmemSizeA;\n  half *scales_shared = mem_shared + kSmemSizeA + kSmemSizeB;\n  half *zeros_shared = mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales;\n  float *C_shared = reinterpret_cast<float *>(mem_shared);\n  half A_shared_warp_[2][WARP_M * INTRIN_K /\n                         WARP_SIZE];\n  half B_shared_warp_[2][WARP_N * 32 /\n                         WARP_SIZE];\n  half B_shared_warp_tmp_[2][WARP_N * 16 /\n                             WARP_SIZE];\n  int cta_offset_m = blockIdx_m * CTA_M;\n  int cta_offset_n = blockIdx_n * CTA_N;\n  int cta_offset_k = blockIdx_z * (K / SPLITK);\n  int warp_mn = threadIdx.y % NUM_WARPS_MN;\n  int slice_id = threadIdx.y / NUM_WARPS_MN;\n  int warp_offset_n = (warp_mn % (CTA_N / WARP_N)) * WARP_N;\n  int warp_offset_m = (warp_mn / (CTA_N / WARP_N)) * WARP_M;\n  int warp_offset_k = slice_id * WARP_K;\n\n  for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE_MN; i++)\n    C_warp[i] = 0.0;\n\n  int gemm_iters = (K + CTA_K - 1) / CTA_K / SPLITK;\n  int k_0_0_ld = 0;\n  int k_0_0 = 0;\n  constexpr int prologue_stages = STAGES == 1 ? 
1 : STAGES - 1;\n#pragma unroll\n  for (k_0_0_ld = 0; k_0_0_ld < prologue_stages; ++k_0_0_ld)\n  {\n    global_to_share_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(A, A_shared + k_0_0_ld * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, 0, true);\n    global_to_share_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(B, B_shared + k_0_0_ld * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, 0, true);\n    global_to_share_one_stage_scales<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(\n        scales, scales_shared + (k_0_0_ld / scales_load_interval * scales_per_load) * CTA_N,\n        zeros, zeros_shared + (k_0_0_ld / scales_load_interval * scales_per_load) * CTA_N,\n        N, cta_offset_m, cta_offset_n, cta_offset_k,\n        k_0_0_ld, 0, k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);\n    if constexpr (STAGES > 1)\n      __pipeline_commit();\n  }\n  if constexpr (STAGES > 1)\n    __pipeline_wait_prior(STAGES - 2);\n  __syncthreads();\n\n  share_to_reg_one_stage_A<CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(A_shared, A_shared_warp_[0], warp_offset_m, warp_offset_n, warp_offset_k, 0);\n  share_to_reg_one_stage_B<CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(B_shared, scales_shared, zeros_shared, B_shared_warp_tmp_[0], B_shared_warp_[0], warp_offset_m, warp_offset_n, warp_offset_k, 0);\n  constexpr int SHARED_K_ITERS = WARP_K / INTRIN_K;\n\n  for (; k_0_0 < gemm_iters; ++k_0_0, ++k_0_0_ld)\n  {\n    int ld_stage = k_0_0_ld % STAGES;\n    int compute_stage = k_0_0 % STAGES;\n    half *A_shared_this_compute_stage;\n    half *B_shared_this_compute_stage;\n    half *scales_shared_this_compute_stage;\n    half *zeros_shared_this_compute_stage;\n\n#pragma unroll\n    for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k)\n    {\n      A_shared_this_compute_stage = A_shared + compute_stage * kSmemSizeAPerStage;\n      B_shared_this_compute_stage = B_shared + compute_stage * 
kSmemSizeBPerStage;\n      scales_shared_this_compute_stage = scales_shared + (compute_stage / scales_load_interval * scales_per_load) * CTA_N;\n      zeros_shared_this_compute_stage = zeros_shared + (compute_stage / scales_load_interval * scales_per_load) * CTA_N;\n      share_to_reg_one_stage_A<CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(A_shared_this_compute_stage, A_shared_warp_[(iter_k + 1) % 2], warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);\n      if ((iter_k + 1) % kInterleave == 0)\n      {\n        if (compute_stage % 2 == 1)\n        {\n          share_to_reg_one_stage_B<CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(\n              B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,\n              B_shared_warp_tmp_[1], B_shared_warp_[((iter_k + 1) / 2) % 2],\n              warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);\n        }\n        else\n        {\n          share_to_reg_one_stage_B<CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(\n              B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,\n              B_shared_warp_tmp_[0], B_shared_warp_[((iter_k + 1) / 2) % 2],\n              warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);\n        }\n      }\n      else\n      {\n        if (compute_stage % 2 == 1)\n        {\n          share_to_reg_one_stage_B<CTA_M, CTA_N, CTA_K, STAGES, false, WARP_N / INTRIN_N, G>(\n              B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,\n              B_shared_warp_tmp_[1], B_shared_warp_[((iter_k + 1) / 2) % 2],\n              warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);\n        }\n        else\n        {\n          share_to_reg_one_stage_B<CTA_M, CTA_N, CTA_K, STAGES, false, WARP_N / INTRIN_N, G>(\n              
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,\n              B_shared_warp_tmp_[0], B_shared_warp_[((iter_k + 1) / 2) % 2],\n              warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);\n        }\n      }\n      half *A_shared_warp = A_shared_warp_[iter_k % 2];\n      half *B_shared_warp = B_shared_warp_[(iter_k / 2) % 2];\n\n      for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3)\n      {\n        for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4)\n        {\n          mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4);\n          mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4 + 8);\n        }\n      }\n\n      if (iter_k < WARP_K / INTRIN_K - 1)\n      {\n        if constexpr (STAGES == 1)\n          __syncthreads();\n        global_to_share_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(A, A_shared + ld_stage * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters);\n        global_to_share_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(B, B_shared + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters);\n      }\n\n      if (iter_k == WARP_K / INTRIN_K - 2)\n      {\n        if constexpr (STAGES == 1 && WARP_K / INTRIN_K > 2)\n        {\n          __syncthreads();\n        }\n        global_to_share_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(A, A_shared + ld_stage * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, iter_k + 1, k_0_0_ld < gemm_iters);\n        global_to_share_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(B, B_shared + ld_stage * 
kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, iter_k + 1, k_0_0_ld < gemm_iters);\n        global_to_share_one_stage_scales<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(\n            scales, scales_shared + (ld_stage / scales_load_interval * scales_per_load) * CTA_N,\n            zeros, zeros_shared + (ld_stage / scales_load_interval * scales_per_load) * CTA_N,\n            N, cta_offset_m, cta_offset_n, cta_offset_k,\n            k_0_0_ld, iter_k, k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);\n        if constexpr (STAGES > 1)\n        {\n          __pipeline_commit();\n          __pipeline_wait_prior(STAGES - 2);\n        }\n        compute_stage = (k_0_0 + 1) % STAGES;\n        __syncthreads();\n      }\n    }\n  }\n  __pipeline_commit();\n  __pipeline_wait_prior(0);\n  __syncthreads();\n  if constexpr (SLICES > 1)\n  {\n#pragma unroll\n    for (int z = 0; z < SLICES; ++z)\n    {\n      if (slice_id == z)\n      {\n#pragma unroll\n        for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)\n        {\n#pragma unroll\n          for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)\n          {\n#pragma unroll\n            for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id)\n            {\n              if (z > 0)\n              {\n                C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] += C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2];\n              }\n              C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2] = C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id];\n            };\n          }\n        }\n      }\n      
__syncthreads();\n    }\n    if (slice_id == 0)\n    {\n#pragma unroll\n      for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)\n      {\n#pragma unroll\n        for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)\n        {\n#pragma unroll\n          for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id)\n          {\n            C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] = C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2];\n          };\n        }\n      }\n    }\n  }\n\n  if (slice_id == 0)\n  {\n    Semaphore semaphore(semaphores + blockIdx_y, threadIdx.x);\n\n    if constexpr (SPLITK > 1)\n    {\n      semaphore.fetch();\n    }\n\n    if (blockIdx_z != 0)\n    {\n      semaphore.wait(blockIdx_z);\n      for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)\n      {\n        for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)\n        {\n          for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2)\n          {\n            int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4));\n\n            if (write_row < M)\n            {\n              half2 *existing_psum_ptr = reinterpret_cast<half2 *>(\n                  C + write_row * N +\n                  cta_offset_n + warp_offset_n + ax1_0_1 * 16 +\n                  (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2);\n\n              *existing_psum_ptr = __hadd2(*existing_psum_ptr,\n                                           __float22half2_rn(*reinterpret_cast<float2 *>(C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 +\n                                                                                         ax1_0_1 * 8 + local_id)));\n            }\n          };\n        }\n      }\n    }\n    else\n    {\n    
  for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)\n      {\n        for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)\n        {\n          for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2)\n          {\n            int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4));\n            if (write_row < M)\n            {\n              *reinterpret_cast<half2 *>(\n                  C + write_row * N +\n                  cta_offset_n + warp_offset_n + ax1_0_1 * 16 +\n                  (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2) =\n                  __float22half2_rn(*reinterpret_cast<float2 *>(C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 +\n                                                                ax1_0_1 * 8 + local_id));\n            }\n          };\n        }\n      }\n    }\n\n    if constexpr (SPLITK > 1)\n    {\n\n      int lock = 0;\n      if (SPLITK == blockIdx_z + 1)\n      {\n\n        lock = 0;\n      }\n      else\n      {\n        lock = blockIdx_z + 1;\n      }\n      semaphore.release(lock);\n    }\n  }\n}\n\ntemplate <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>\n__device__ __inline__ void global_to_share_one_stage_A_T2(half *src, half *dst, int global_nrows, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask)\n{\n  constexpr int threads_needed = (CTA_M * CTA_K) / PACK_SIZE / SHARED_K_ITERS;\n  constexpr int threads_used = threads_needed < CTA_SIZE ? 
threads_needed : CTA_SIZE;\n  constexpr int total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / threads_used;\n  constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;\n  constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K;\n  constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;\n  constexpr int threads_per_row = CTA_K / PACK_SIZE;\n  constexpr int kSmemCol = CTA_K + SMEM_PAD_A;\n  bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);\n  int ld_col = (threadIdx.x % threads_per_row);\n#pragma unroll\n  for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter)\n  {\n    int global_iter = shared_iter_k * partial_global_iters + _global_iter;\n    int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row);\n    int ld_col_swizzled = (ld_col ^ (ld_row) & 7) * PACK_SIZE;\n    void *dst_ptr = (void *)(dst + ld_row * kSmemCol + ld_col_swizzled);\n    uint4 *src_ptr = (uint4 *)(src + (ld_row + cta_offset_m) * global_ncols + ld_col * PACK_SIZE + global_iter_k * CTA_K); // cta_offset_m * global_ncols + global_iter * cta_step_m_or_n * global_ncols + threadIdx.y * warp_step_m_or_n * global_ncols + (threadIdx.x / threads_per_row) * global_ncols + global_iter_k * CTA_K + (threadIdx.x % threads_per_row) * PACK_SIZE);\n    if constexpr (STAGES > 1)\n    {\n      uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);\n      cp_async_cg_A(addr, src_ptr, local_mask & (ld_row + cta_offset_m < global_nrows));\n    }\n    else\n    {\n      if (local_mask & (ld_row + cta_offset_m < global_nrows))\n        *(uint4 *)dst_ptr = *src_ptr;\n    }\n  }\n}\n\ntemplate <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>\n__device__ __inline__ void global_to_share_one_stage_B_T2(half *src, half *dst, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask)\n{\n  
constexpr int threads_needed = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / SHARED_K_ITERS;\n  constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;\n  constexpr int total_global_iters = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / threads_used;\n  constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;\n  constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K;\n  constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;\n  constexpr int threads_per_row = CTA_K / PACK_SIZE;\n  constexpr int kSmemCol = CTA_K + SMEM_PAD_B;\n  bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);\n#pragma unroll\n  for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter)\n  {\n    int global_iter = shared_iter_k * partial_global_iters + _global_iter;\n\n    int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row);\n    int ld_col = (threadIdx.x % threads_per_row);\n    int ld_col_swizzled = ld_col ^ (ld_row % 2) & 7;\n    void *dst_ptr = (void *)(dst + (ld_row * kSmemCol + ld_col_swizzled * PACK_SIZE));\n    uint4 *src_ptr = (uint4 *)(src + global_iter_k * CTA_K + cta_offset_n / kInterleave * global_ncols + ld_row * global_ncols + ld_col * PACK_SIZE);\n    if constexpr (STAGES > 1)\n    {\n      uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);\n      cp_async_cg_A(addr, src_ptr, local_mask);\n    }\n    else\n    {\n      if (local_mask)\n        *(uint4 *)dst_ptr = *src_ptr;\n    }\n  }\n}\n\ntemplate <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES, int G>\n__device__ __inline__ void global_to_share_one_stage_scales_T2(half *src, half *dst, half *src_z, half *dst_z, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask)\n{\n  constexpr int threads_needed = CTA_N / PACK_SIZE / 1;\n  constexpr int threads_used = threads_needed < CTA_SIZE ? 
threads_needed : CTA_SIZE;\n  constexpr int threads_per_row = CTA_N / PACK_SIZE;\n  bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);\n  int g_idx = global_iter_k * CTA_K / G;\n\n  void *dst_ptr = (void *)(dst + (threadIdx.x % threads_per_row) * PACK_SIZE);\n  uint4 *src_ptr = (uint4 *)(src + g_idx * global_ncols + cta_offset_n + (threadIdx.x % threads_per_row) * PACK_SIZE);\n  void *dst_ptr_z = (void *)(dst_z + (threadIdx.x % threads_per_row) * PACK_SIZE);\n  uint4 *src_ptr_z = (uint4 *)(src_z + g_idx * global_ncols + cta_offset_n + (threadIdx.x % threads_per_row) * PACK_SIZE);\n  if (STAGES > 1)\n  {\n    uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);\n    cp_async_cg_A(addr, src_ptr, local_mask);\n    uint32_t addr_z = cast_smem_ptr_to_uint(dst_ptr_z);\n    cp_async_cg_A(addr_z, src_ptr_z, local_mask);\n  }\n  else\n  {\n    if (local_mask)\n    {\n      *(uint4 *)dst_ptr = *src_ptr;\n      *(uint4 *)dst_ptr_z = *src_ptr_z;\n    }\n  }\n}\n\ntemplate <int CTA_M, int CTA_N, int CTA_K, int STAGES, int shared_iters>\n__device__ __inline__ void share_to_reg_one_stage_A_T2(half *src, half *dst, int warp_offset_m, int warp_offset_n, int k_0_1)\n{\n  constexpr int kSmemCol = CTA_K + SMEM_PAD_A;\n\n  for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)\n  {\n\n    int ld_row = warp_offset_m + shared_iter * OP_M + (threadIdx.x % 16);\n    int ld_col = k_0_1 * 16 + (threadIdx.x / 16) * 8;\n    int ld_col_swizzled = ((ld_col / PACK_SIZE) ^ (ld_row) & 7) * PACK_SIZE;\n    void *addr_ptr = (void *)(src + ld_row * kSmemCol + ld_col_swizzled);\n\n    uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);\n    ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);\n  }\n}\n\ntemplate <int CTA_M, int CTA_N, int CTA_K, int STAGES, bool ldmatrix, int shared_iters, int G>\n__device__ __inline__ void share_to_reg_one_stage_B_T2(half *src, half *src_scales, half *src_zeros, half *dst, half *dst_fp16, int warp_offset_m, int warp_offset_n, int 
k_0_1)\n{\n  constexpr int kSmemCol = CTA_K + SMEM_PAD_B;\n  int r0 = ((threadIdx.x / 8 / 2) * 8 + threadIdx.x % 8);\n  int c0 = ((threadIdx.x / 8) % 2) * 8;\n  int r = r0 / 4;\n  int c = (r0 % 4) * 16 + c0;\n  int c_swizzled = ((c / PACK_SIZE) ^ (r % 2) & 7) * PACK_SIZE;\n\n  if constexpr (ldmatrix)\n  {\n#pragma unroll\n    for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)\n    {\n      void *addr_ptr = (void *)(src + warp_offset_n / kInterleave * kSmemCol + shared_iter * 16 / kInterleave * kSmemCol + k_0_1 * 16 + r * kSmemCol + c_swizzled);\n      uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);\n      ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);\n    }\n  }\n\n#pragma unroll\n  for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)\n  {\n    half scale = src_scales[warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4];\n    half zero = src_zeros[warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4];\n    half2 scale2 = make_half2(scale, scale);\n    half2 zero2 = make_half2(zero, zero);\n    half2 loaded[4];\n    dequantize_s4_to_fp16x2(*reinterpret_cast<half2 *>(dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + shared_iter * 8), reinterpret_cast<uint4 *>(loaded));\n#pragma unroll\n    for (int i = 0; i < 4; i++)\n    {\n      loaded[i] = __hfma2(loaded[i], scale2, zero2);\n    }\n    *reinterpret_cast<uint4 *>(dst_fp16 + shared_iter * 16 + 8 * (k_0_1 % 2)) = *reinterpret_cast<uint4 *>(loaded);\n  }\n}\n\ntemplate <int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K, int STAGES, int G>\n__global__ void gemm_w4a16_T2(half *__restrict__ A, half *__restrict__ B, half *__restrict__ scales, half *__restrict__ zeros, half *__restrict__ C, int M, int N, int K)\n{\n  constexpr int NUM_WARPS = CTA_M / WARP_M * CTA_N / WARP_N;\n  constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE;\n  int num_blocks_n = (N + CTA_N - 1) / CTA_N;\n  int num_blocks_m = (M + CTA_M - 1) / CTA_M;\n  int blockIdx_y = 
blockIdx.x % (num_blocks_m * num_blocks_n);\n  const int log_tile = get_log_tile<1>((N + CTA_N - 1) / CTA_N);\n  int blockIdx_m = blockIdx_y / (num_blocks_n >> log_tile);\n  int blockIdx_n = blockIdx_y % (num_blocks_n >> log_tile);\n  const uint2 block_idx_mapping = get_block_idx_mapping(blockIdx_m, blockIdx_n, log_tile);\n  blockIdx_m = block_idx_mapping.x;\n  blockIdx_n = block_idx_mapping.y;\n\n  float C_warp[CTA_M * CTA_N / CTA_SIZE];\n  constexpr int kSmemPadKA = CTA_K + SMEM_PAD_A;\n  constexpr int kSmemPadKB = CTA_K + SMEM_PAD_B;\n  constexpr int kSmemSizeAPerStage = CTA_M * kSmemPadKA;\n  constexpr int kSmemSizeBPerStage = CTA_N / kInterleave * kSmemPadKB;\n  constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES;\n  constexpr int kSmemSizeB = kSmemSizeBPerStage * STAGES;\n  constexpr int kSmemSizeScales = CTA_N * STAGES / 2;\n  constexpr int scales_load_interval = G / CTA_K;\n  extern __shared__ half mem_shared[];\n  half *A_shared = mem_shared;\n  half *B_shared = mem_shared + kSmemSizeA;\n  half *scales_shared = mem_shared + kSmemSizeA + kSmemSizeB;\n  half *zeros_shared = mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales;\n  half A_shared_warp_[2][WARP_M * INTRIN_K /\n                         WARP_SIZE];\n  half B_shared_warp_[2][WARP_N * 32 /\n                         WARP_SIZE];\n  half B_shared_warp_tmp_[2][WARP_N * 16 /\n                             WARP_SIZE];\n  int cta_offset_m = blockIdx_m * CTA_M;\n  int cta_offset_n = blockIdx_n * CTA_N;\n  int warp_offset_m = (threadIdx.y % (CTA_M / WARP_M)) * WARP_M;\n  int warp_offset_n = (threadIdx.y / (CTA_M / WARP_M)) * WARP_N;\n\n  for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE; i++)\n    C_warp[i] = 0.0;\n\n  int gemm_iters = (K + CTA_K - 1) / CTA_K;\n  int k_0_0_ld = 0;\n  int k_0_0 = 0;\n  constexpr int prologue_stages = STAGES == 1 ? 
1 : STAGES - 1;\n#pragma unroll\n  for (k_0_0_ld = 0; k_0_0_ld < prologue_stages; ++k_0_0_ld)\n  {\n    global_to_share_one_stage_A_T2<CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(A, A_shared + k_0_0_ld * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true);\n    global_to_share_one_stage_B_T2<CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(B, B_shared + k_0_0_ld * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true);\n    global_to_share_one_stage_scales_T2<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(\n        scales, scales_shared + (k_0_0_ld / scales_load_interval) * CTA_N,\n        zeros, zeros_shared + (k_0_0_ld / scales_load_interval) * CTA_N,\n        N, cta_offset_m, cta_offset_n, k_0_0_ld, 0, k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);\n    if constexpr (STAGES > 1)\n      __pipeline_commit();\n  }\n  if constexpr (STAGES > 1)\n    __pipeline_wait_prior(STAGES - 2);\n  __syncthreads();\n\n  share_to_reg_one_stage_A_T2<CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(A_shared, A_shared_warp_[0], warp_offset_m, warp_offset_n, 0);\n  share_to_reg_one_stage_B_T2<CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(B_shared, scales_shared, zeros_shared, B_shared_warp_tmp_[0], B_shared_warp_[0], warp_offset_m, warp_offset_n, 0);\n  constexpr int SHARED_K_ITERS = WARP_K / INTRIN_K;\n\n  for (; k_0_0 < gemm_iters; ++k_0_0, ++k_0_0_ld)\n  {\n    int ld_stage = k_0_0_ld % STAGES;\n    int compute_stage = k_0_0 % STAGES;\n    half *A_shared_this_compute_stage;\n    half *B_shared_this_compute_stage;\n    half *scales_shared_this_compute_stage;\n    half *zeros_shared_this_compute_stage;\n\n    for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k)\n    {\n      A_shared_this_compute_stage = A_shared + compute_stage * kSmemSizeAPerStage;\n      B_shared_this_compute_stage = B_shared + compute_stage * kSmemSizeBPerStage;\n      scales_shared_this_compute_stage = scales_shared + (compute_stage / scales_load_interval) * 
CTA_N;\n      zeros_shared_this_compute_stage = zeros_shared + (compute_stage / scales_load_interval) * CTA_N;\n      share_to_reg_one_stage_A_T2<CTA_M, CTA_N, CTA_K, STAGES, WARP_M / INTRIN_M>(A_shared_this_compute_stage, A_shared_warp_[(iter_k + 1) % 2], warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS);\n      if ((iter_k + 1) % kInterleave == 0)\n      {\n        if (compute_stage % 2 == 1)\n        {\n          share_to_reg_one_stage_B_T2<CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(\n              B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,\n              B_shared_warp_tmp_[1], B_shared_warp_[((iter_k + 1) / 2) % 2],\n              warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS);\n        }\n        else\n        {\n          share_to_reg_one_stage_B_T2<CTA_M, CTA_N, CTA_K, STAGES, true, WARP_N / INTRIN_N, G>(\n              B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,\n              B_shared_warp_tmp_[0], B_shared_warp_[((iter_k + 1) / 2) % 2],\n              warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS);\n        }\n      }\n      else\n      {\n        if (compute_stage % 2 == 1)\n        {\n          share_to_reg_one_stage_B_T2<CTA_M, CTA_N, CTA_K, STAGES, false, WARP_N / INTRIN_N, G>(\n              B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,\n              B_shared_warp_tmp_[1], B_shared_warp_[((iter_k + 1) / 2) % 2],\n              warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS);\n        }\n        else\n        {\n          share_to_reg_one_stage_B_T2<CTA_M, CTA_N, CTA_K, STAGES, false, WARP_N / INTRIN_N, G>(\n              B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,\n              B_shared_warp_tmp_[0], B_shared_warp_[((iter_k + 1) / 2) % 2],\n              warp_offset_m, 
warp_offset_n, (iter_k + 1) % SHARED_K_ITERS);\n        }\n      }\n      __syncthreads();\n      half *A_shared_warp = A_shared_warp_[iter_k % 2];\n      half *B_shared_warp = B_shared_warp_[(iter_k / 2) % 2];\n      for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3)\n      {\n        for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4)\n        {\n          mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4);\n          mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4 + 8);\n        }\n      }\n\n      if (iter_k < WARP_K / INTRIN_K - 1)\n      {\n        if constexpr (STAGES == 1)\n          __syncthreads();\n        global_to_share_one_stage_A_T2<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(A, A_shared + ld_stage * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters);\n        global_to_share_one_stage_B_T2<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(B, B_shared + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters);\n      }\n\n      if (iter_k == WARP_K / INTRIN_K - 2)\n      {\n        if constexpr (STAGES == 1 && WARP_K / INTRIN_K > 2)\n        {\n          __syncthreads();\n        }\n        global_to_share_one_stage_A_T2<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(A, A_shared + ld_stage * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k + 1, k_0_0_ld < gemm_iters);\n        global_to_share_one_stage_B_T2<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(B, B_shared + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k + 1, k_0_0_ld < gemm_iters);\n        global_to_share_one_stage_scales_T2<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(\n            scales, scales_shared + (ld_stage / 
scales_load_interval) * CTA_N,\n            zeros, zeros_shared + (ld_stage / scales_load_interval) * CTA_N,\n            N, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);\n        if constexpr (STAGES > 1)\n        {\n          __pipeline_commit();\n          __pipeline_wait_prior(STAGES - 2);\n        }\n        compute_stage = (k_0_0 + 1) % STAGES;\n        __syncthreads();\n      }\n    }\n  }\n  for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)\n  {\n    for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)\n    {\n      for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2)\n      {\n        int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4));\n        if (write_row < M)\n        {\n          *reinterpret_cast<half2 *>(\n              C + write_row * N +\n              cta_offset_n + warp_offset_n + ax1_0_1 * 16 +\n              (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2) =\n              __float22half2_rn(*reinterpret_cast<float2 *>(C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 +\n                                                            ax1_0_1 * 8 + local_id));\n        }\n      };\n    }\n  }\n}\n\ntorch::Tensor awq_v2_gemm_f16i4(\n    torch::Tensor _in_feats,\n    torch::Tensor _kernel,\n    torch::Tensor _scales,\n    torch::Tensor _zeros)\n{\n  std::vector<int64_t> output_shape = _in_feats.sizes().vec();\n  output_shape.back() = _kernel.size(0) * kInterleave;\n  int num_in_feats = _in_feats.numel() / _in_feats.size(-1);\n  int num_in_channels = _in_feats.size(-1);\n  auto in_feats = reinterpret_cast<half *>(_in_feats.data_ptr<at::Half>());\n  auto kernel = reinterpret_cast<half *>(_kernel.data_ptr<int16_t>());\n  auto scales = reinterpret_cast<half *>(_scales.data_ptr<at::Half>());\n  auto zeros = reinterpret_cast<half *>(_zeros.data_ptr<at::Half>());\n  auto options =\n      
torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());\n  auto options_int =\n      torch::TensorOptions().dtype(torch::kInt32).device(_in_feats.device());\n  at::Tensor _out_feats = torch::empty(output_shape, options);\n  int num_out_feats = _out_feats.numel() / _out_feats.size(-1);\n  int num_out_channels = _out_feats.size(-1);\n  auto out_feats = reinterpret_cast<half *>(_out_feats.data_ptr<at::Half>());\n\n  if (num_out_feats <= 32)\n  {\n    constexpr int G = 128;\n    constexpr int CTA_M = 16;\n    constexpr int CTA_N = 128;\n    constexpr int CTA_K = 128;\n    constexpr int WARP_M = 16;\n    constexpr int WARP_N = 32;\n    constexpr int WARP_K = 64;\n    constexpr int SPLITK = 2;\n    constexpr int STAGES = 4;\n    KERNEL_LAUNCH_CODE\n  }\n  else if (num_out_feats <= 64)\n  {\n\n    constexpr int G = 128;\n    constexpr int CTA_M = 16;\n    constexpr int CTA_N = 128;\n    constexpr int CTA_K = 128;\n    constexpr int WARP_M = 16;\n    constexpr int WARP_N = 32;\n    constexpr int WARP_K = 64;\n    constexpr int SPLITK = 1;\n    constexpr int STAGES = 3;\n    KERNEL_LAUNCH_CODE\n  }\n  else if (num_out_feats <= 128)\n  {\n    constexpr int G = 128;\n    constexpr int CTA_M = 32;\n    constexpr int CTA_N = 128;\n    constexpr int CTA_K = 128;\n    constexpr int WARP_M = 32;\n    constexpr int WARP_N = 32;\n    constexpr int WARP_K = 64;\n    constexpr int SPLITK = 1;\n    constexpr int STAGES = 4;\n    KERNEL_LAUNCH_CODE\n  }\n  else if (num_out_feats <= 192)\n  {\n    constexpr int G = 128;\n    constexpr int CTA_M = 64;\n    constexpr int CTA_N = 128;\n    constexpr int CTA_K = 64;\n    constexpr int WARP_M = 64;\n    constexpr int WARP_N = 32;\n    constexpr int WARP_K = 64;\n    constexpr int SPLITK = 1;\n    constexpr int STAGES = 4;\n    KERNEL_LAUNCH_CODE\n  }\n  else\n  {\n    constexpr int G = 128;\n    constexpr int CTA_M = 64;\n    constexpr int CTA_N = 128;\n    constexpr int CTA_K = 64;\n    constexpr int WARP_M = 64;\n    
constexpr int WARP_N = 32;\n    constexpr int WARP_K = 64;\n    constexpr int STAGES = 4;\n\n    constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N);\n    constexpr int kSmemByteSize = (CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / kInterleave + CTA_N) * STAGES * sizeof(half);\n    if (kSmemByteSize >= 99 * 1024)\n    {\n      printf(\"This kernel requires %d Bytes of shared memory, which exceeds device limit.\\n\", kSmemByteSize);\n      return _out_feats;\n    }\n    int j_factors1 = num_out_channels / CTA_N / 1;\n    dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1);\n    dim3 threads_per_block(WARP_SIZE, NUM_WARPS);\n    auto kernel_func = gemm_w4a16_T2<CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G>;\n    cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);\n    kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>(\n        in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels);\n  }\n\n  return _out_feats;\n}\n\n#else // if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800\n\ntorch::Tensor awq_v2_gemm_f16i4(\n    torch::Tensor _in_feats,\n    torch::Tensor _kernel,\n    torch::Tensor _scales,\n    torch::Tensor _zeros)\n{\n  throw std::runtime_error(\"This GEMM requires a CUDA arch >= sm80.\\n\");\n}\n\n#endif // if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800\n"
  },
  {
    "path": "optimum/quanto/library/extensions/cuda/awq/v2/gemm_cuda.h",
    "content": "#include <torch/extension.h>\n\ntorch::Tensor awq_v2_gemm_f16i4(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scales, torch::Tensor _zeros);\n"
  },
  {
    "path": "optimum/quanto/library/extensions/cuda/awq/v2/gemv_cuda.cu",
    "content": "/*\n * Modified from NVIDIA [TRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/d37b507f41a87457fe9f10f7459d08f5db235745/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv)\n * Copyright (c) 2022-2024, NVIDIA CORPORATION.  All rights reserved.\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n/*\n@article{lin2023awq,\n  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},\n  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},\n  journal={arXiv},\n  year={2023}\n}\n*/\n\n#include <cuda_fp16.h>\n#include <stdio.h>\n#include <torch/extension.h>\n#include \"gemv_cuda.h\"\n#include \"../dequantize.cuh\"\n#define PACK_FACTOR 8\n#define WARP_SIZE 32\n#define MEM_ACCESS_SIZE 128\n\n// Reduce sum within the warp using the tree reduction algorithm.\ntemplate <int Num, int WarpSize>\n__device__ __forceinline__ static void warp_reduce(half* psum, float (*out_smem)[Num * 4])\n{\n  // kInterleave = 4\n      float fpsum[Num];\n      #pragma unroll\n      for (int i = 0; i < Num; ++i)\n      {\n          fpsum[i] = __half2float(psum[i]);\n      }\n\n      #pragma unroll\n      for (int i = 0; i < Num; ++i)\n      {\n          // T0 + T1 + T8 + T9 + T16 + T17 + T24 + T25 (kInterleave = 4)\n          fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 16);\n          fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 8);\n          fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 1);\n      
}\n      __syncthreads();\n      int warp = threadIdx.x / WarpSize, lane = threadIdx.x % WarpSize;\n      if (lane == 0 || lane == 2 || lane == 4 || lane == 6)\n      {\n          #pragma unroll\n          for (int i = 0; i < Num; ++i)\n          {\n              out_smem[warp][i * 4 + lane / 2] = fpsum[i];\n          }\n      }\n      __syncthreads();\n};\n\n__device__ __forceinline__ int make_divisible(int c, int divisor){\n  return (c + divisor - 1) / divisor;\n}\n\ntemplate <int NPerBlock, int Batch, int BlockSize, int GroupSize>\n__global__ void gemv_kernel(\n  const half* inputs, const uint32_t* weight, const half* scales, const half* zeros, half* outputs,\n  const int IC, const int OC)\n{\n    const int kStride = 64;\n    const int kElemsPerThread = MEM_ACCESS_SIZE / 4;\n    const int kThreadsNumPerTile = kStride / kElemsPerThread;\n    // assert(MEM_ACCESS_SIZE == 128);\n\n    static constexpr int kShuffleSize = 32;\n    static constexpr int kShuffleBasicTile = 2;\n    static constexpr int kShuffleContinous = 4;\n    static constexpr int kShuffleStrided = 4;\n\n    constexpr int Num = NPerBlock * Batch;\n    constexpr int kInterleave = 4;\n\n    half local_inputs[kElemsPerThread];\n    uint32_t local_qweights[MEM_ACCESS_SIZE / 32];\n    half half_weight_buffer[kElemsPerThread];\n    half dequantized_weight[kElemsPerThread * NPerBlock];\n    half local_scale[NPerBlock];\n    half local_scaled_zeros[NPerBlock];\n\n    half psum[Num];\n    for (int i = 0; i < Num; ++i)\n        psum[i] = __float2half(0.f);\n\n    extern __shared__ uint8_t shmem[];\n    float(*out_smem)[Num * kInterleave] = reinterpret_cast<float(*)[Num * kInterleave]>(shmem);\n\n    const int blk_row_offset = blockIdx.x * NPerBlock * kInterleave;\n    const int thd_row_offset = (threadIdx.x / kThreadsNumPerTile) % kInterleave;\n    const int act_k_offset = threadIdx.x / (kThreadsNumPerTile * kInterleave) * kStride\n                               + (threadIdx.x % kThreadsNumPerTile) * 
kElemsPerThread;\n    const int group_offset = act_k_offset / GroupSize;\n    // TODO: use make_divisible\n    const uint32_t* blk_weight_ptr = weight + blk_row_offset * IC / PACK_FACTOR;\n    const half* scale_ptr = scales + blk_row_offset + thd_row_offset + group_offset * OC;\n    const half* zeros_ptr = zeros + blk_row_offset + thd_row_offset + group_offset * OC;\n    const half* inputs_ptr = inputs + act_k_offset;\n\n    const int act_forward_step = BlockSize * kElemsPerThread / kInterleave;\n    const int scale_forward_step = act_forward_step / GroupSize * OC;\n\n    // Main loop iteration, each block completes the outputs for several OCs\n    for (int kk = threadIdx.x * kElemsPerThread; kk < IC * kInterleave; kk += BlockSize * kElemsPerThread)\n    {\n        // Load qweight, scales and scaled_zeros\n        #pragma unroll\n        for (int idx = 0; idx < NPerBlock; ++idx)\n        {\n            // use float4 to load weights, each thread load 32 int4 numbers (1 x float4, 128 bit)\n            *((float4*)(local_qweights)) =\n                *((float4*)(blk_weight_ptr + (idx * kInterleave * IC + kk)/ PACK_FACTOR));\n            local_scale[idx] = *(scale_ptr + idx * kInterleave);\n            local_scaled_zeros[idx] = *(zeros_ptr + idx * kInterleave);\n\n            // Map int4 qweight to fp format\n            #pragma unroll\n            for (int i = 0; i < MEM_ACCESS_SIZE / 32; ++i)\n            {\n                // Converts 32 bits (8 x int4) to 8 fp16\n                dequantize_s4_to_fp16x2(*reinterpret_cast<half2 *>(local_qweights + i), reinterpret_cast<uint4 *>(half_weight_buffer + i * PACK_FACTOR));\n            }\n\n            // Dequantize (apply s/z) and shuffle elements to match the weight packing format\n            #pragma unroll\n            for (int i = 0; i < kShuffleContinous; ++i)\n            {\n                #pragma unroll\n                for (int j = 0; j < kShuffleStrided; ++j)\n                {\n                    half2 w =\n     
                   *reinterpret_cast<half2*>(\n                          half_weight_buffer + (i + j * kShuffleContinous)* kShuffleBasicTile\n                        );\n                    w = __hfma2(w, __half2half2(local_scale[idx]), __half2half2(local_scaled_zeros[idx]));\n                    dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 0)\n                          * NPerBlock + idx]\n                        = w.x;\n                    dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 1)\n                            * NPerBlock + idx]\n                        = w.y;\n                }\n            }\n        }\n        #pragma unroll\n        for (int batch_idx = 0; batch_idx < Batch; ++batch_idx)\n        {\n            const half* local_inputs_ptr = inputs_ptr + batch_idx * IC;\n            #pragma unroll\n            for (int idx = 0; idx < kElemsPerThread / 8; ++idx)\n            {\n                // load activation, 8 halves (128 bits) / step.\n                *((float4*)(local_inputs + idx * 8)) = *((float4*)(local_inputs_ptr + idx * 8));\n            }\n            // Perform the MACs\n            #pragma unroll\n            for (int x = 0; x < NPerBlock / 2; ++x)\n            {\n                #pragma unroll\n                for (int y = 0; y < kElemsPerThread; ++y)\n                {\n                    *reinterpret_cast<half2*>(psum + batch_idx * NPerBlock + x * 2)\n                        = __hfma2(*reinterpret_cast<half2*>(dequantized_weight + y * NPerBlock + x * 2),\n                            __half2half2(local_inputs[y]),\n                            *reinterpret_cast<half2*>(psum + batch_idx * NPerBlock + x * 2));\n                }\n            }\n        }\n        inputs_ptr += act_forward_step;\n        scale_ptr += scale_forward_step;\n        zeros_ptr += scale_forward_step;\n    }\n\n    warp_reduce<Num, WARP_SIZE>(psum, out_smem);\n\n    // Num * Interleave = batch * NPerBlock * Interleave -> 
1 thread_block write back num\n    for (int i = threadIdx.x; i < Num * kInterleave; i += BlockSize)\n    {\n        int batch_idx = i / (NPerBlock * kInterleave);\n        int oc_idx = i % (NPerBlock * kInterleave);\n        float acc = 0.f;\n        for (int j = 0; j < BlockSize / WARP_SIZE; ++j)\n        {\n            acc += out_smem[j][i];\n        }\n        outputs[batch_idx * OC + blk_row_offset + oc_idx] = __float2half(acc);\n    }\n}\n\n/*\nComputes GEMV (PyTorch interface).\n\nArgs:\n  _in_feats: tensor of shape [B, IC];\n  _kernel: int tensor of shape [OC, IC // 8];\n  _zeros: int tensor of shape [OC, IC // G // 8];\n  _scaling_factors: tensor of shape [OC, IC // G];\n  blockDim_x: size of thread block, dimension x, where blockDim_x * workload_per_thread = IC;\n  blockDim_y: size of thread block, dimension y, where blockDim_y * gridDim_y = OC;\n\nReturns:\n  out_feats: tensor of shape [B, OC];\n*/\ntorch::Tensor awq_v2_gemv_f16i4(\n    torch::Tensor _in_feats,\n    torch::Tensor _kernel,\n    torch::Tensor _scaling_factors,\n    torch::Tensor _zeros,\n    int m,\n    int n,\n    int k,\n    int group_size)\n{\n\n    std::vector<int64_t> output_shape = _in_feats.sizes().vec();\n    output_shape.back() = n;\n\n    auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());\n    auto kernel = reinterpret_cast<uint32_t*>(_kernel.data_ptr());\n    auto zeros = reinterpret_cast<half*>(_zeros.data_ptr<at::Half>());\n    auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());\n\n    auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());\n    at::Tensor _out_feats = torch::empty(output_shape, options);\n    half * out_feats = reinterpret_cast<half *>(_out_feats.data_ptr());\n\n    static constexpr int N_PER_BLOCK = 2;\n    static constexpr int K_INTERLEAVE = 4;\n    static constexpr int BLOCK_SIZE = 256;\n\n    dim3 num_blocks(n / N_PER_BLOCK / K_INTERLEAVE);\n    dim3 
num_threads(BLOCK_SIZE);\n\n    // if (group_size == 64)\n    // {\n    //   gemv_kernel_g64<<<num_blocks, num_threads>>>(\n    //     // pointers\n    //     in_feats, kernel, zeros, scaling_factors, out_feats,\n    //     // constants\n    //     num_in_channels, num_out_channels\n    //   );\n    // }\n    if (group_size == 128)\n    {\n      switch (m)\n      {\n      case 1:\n        gemv_kernel<N_PER_BLOCK, 1, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(\n          in_feats, kernel, scaling_factors, zeros, out_feats, k, n\n        );\n        break;\n      case 2:\n        gemv_kernel<N_PER_BLOCK, 2, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(\n          in_feats, kernel, scaling_factors, zeros, out_feats, k, n\n        );\n        break;\n      case 3:\n        gemv_kernel<N_PER_BLOCK, 3, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(\n          in_feats, kernel, scaling_factors, zeros, out_feats, k, n\n        );\n        break;\n      case 4:\n        gemv_kernel<N_PER_BLOCK, 4, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(\n          in_feats, kernel, scaling_factors, zeros, out_feats, k, n\n        );\n        break;\n      case 5:\n        gemv_kernel<N_PER_BLOCK, 5, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(\n          in_feats, kernel, scaling_factors, zeros, out_feats, k, n\n        );\n        break;\n      case 6:\n        gemv_kernel<N_PER_BLOCK, 6, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(\n          in_feats, kernel, scaling_factors, zeros, out_feats, k, n\n        );\n        break;\n      case 7:\n        gemv_kernel<N_PER_BLOCK, 7, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(\n          in_feats, kernel, scaling_factors, zeros, out_feats, k, n\n        );\n        break;\n      default:\n        throw std::runtime_error(\"Unsupported batch size for gemv kernel.\\n\");\n      }\n    }\n    else\n    {\n      throw std::runtime_error(\"Unsupported group size for gemv kernel.\\n\");\n    }\n    return _out_feats;\n}\n\n"
  },
  {
    "path": "optimum/quanto/library/extensions/cuda/awq/v2/gemv_cuda.h",
    "content": "#pragma once\n#include <torch/extension.h>\n\ntorch::Tensor awq_v2_gemv_f16i4(\n    torch::Tensor _in_feats,\n    torch::Tensor _kernel,\n    torch::Tensor _scaling_factors,\n    torch::Tensor _zeros,\n    int m,\n    int n,\n    int k,\n    int group_size);\n"
  },
  {
    "path": "optimum/quanto/library/extensions/cuda/awq/v2/semaphore.h",
    "content": "/***************************************************************************************************\n * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n * SPDX-License-Identifier: BSD-3-Clause\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *\n * 1. Redistributions of source code must retain the above copyright notice, this\n * list of conditions and the following disclaimer.\n *\n * 2. Redistributions in binary form must reproduce the above copyright notice,\n * this list of conditions and the following disclaimer in the documentation\n * and/or other materials provided with the distribution.\n *\n * 3. Neither the name of the copyright holder nor the names of its\n * contributors may be used to endorse or promote products derived from\n * this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n **************************************************************************************************/\n/*! 
\\file\n    \\brief Implementation of a CTA-wide semaphore for inter-CTA synchronization.\n*/\n\n#pragma once\n\n/////////////////////////////////////////////////////////////////////////////////////////////////\n\n// namespace cutlass {\n\n/////////////////////////////////////////////////////////////////////////////////////////////////\n\n/// CTA-wide semaphore for inter-CTA synchronization.\nclass Semaphore\n{\npublic:\n  int *lock;\n  bool wait_thread;\n  int state;\n\npublic:\n  /// Implements a semaphore to wait for a flag to reach a given value\n  __host__ __device__ Semaphore(int *lock_, int thread_id) : lock(lock_),\n                                                             wait_thread(thread_id < 0 || thread_id == 0),\n                                                             state(-1)\n  {\n  }\n\n  /// Permit fetching the synchronization mechanism early\n  __device__ void fetch()\n  {\n    if (wait_thread)\n    {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700\n      asm volatile(\"ld.global.acquire.gpu.b32 %0, [%1];\\n\" : \"=r\"(state) : \"l\"(lock));\n#else\n      asm volatile(\"ld.global.cg.b32 %0, [%1];\\n\" : \"=r\"(state) : \"l\"(lock));\n#endif\n    }\n  }\n\n  /// Gets the internal state\n  __device__ int get_state() const\n  {\n    return state;\n  }\n\n  /// Waits until the semaphore is equal to the given value\n  __device__ void wait(int status = 0)\n  {\n    while (__syncthreads_and(state != status))\n    {\n      fetch();\n    }\n\n    __syncthreads();\n  }\n\n  /// Updates the lock with the given result\n  __device__ void release(int status = 0)\n  {\n    __syncthreads();\n\n    if (wait_thread)\n    {\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700\n      asm volatile(\"st.global.release.gpu.b32 [%0], %1;\\n\" : : \"l\"(lock), \"r\"(status));\n#else\n      asm volatile(\"st.global.cg.b32 [%0], %1;\\n\" : : \"l\"(lock), \"r\"(status));\n#endif\n    }\n  
}\n};\n\n/////////////////////////////////////////////////////////////////////////////////////////////////\n\n// } // namespace cutlass\n\n/////////////////////////////////////////////////////////////////////////////////////////////////\n"
  },
  {
    "path": "optimum/quanto/library/extensions/cuda/marlin/COPYRIGHT",
    "content": "These kernels were vendored from VLLM. The Marlin kernels were developed\nby Elias Frantar and extended by Neural Magic.\n\n---\n\nCopyright (C) Marlin.2024 Elias Frantar\nModified by Neural Magic\nCopyright 2024 The vLLM team.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n         http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n"
  },
  {
    "path": "optimum/quanto/library/extensions/cuda/marlin/fp8_marlin.cu",
    "content": "/*\n * Modified by Neural Magic\n * Copyright (C) Marlin.2024 Elias Frantar\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *         http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n/*\n * Adapted from https://github.com/IST-DASLab/marlin\n */\n\n#include \"gptq_marlin.cuh\"\n#include \"gptq_marlin_dtypes.cuh\"\n#include \"fp8_marlin.cuh\"\n\nusing namespace gptq_marlin;\n\n#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t)               \\\n  static_assert(std::is_same<scalar_t, half>::value ||          \\\n                    std::is_same<scalar_t, nv_bfloat16>::value, \\\n                \"only float16 and bfloat16 is supported\");\n\ntemplate <typename T>\ninline std::string str(T x) {\n  return std::to_string(x);\n}\n\nnamespace fp8_marlin {\n\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n\ntemplate <typename scalar_t,          // compute dtype, half or nv_float16\n          const int num_bits,         // number of bits used for weights\n          const int threads,          // number of threads in a threadblock\n          const int thread_m_blocks,  // number of 16x16 blocks in the m\n                                      // dimension (batchsize) of the\n                                      // threadblock\n          const int thread_n_blocks,  // same for n dimension (output)\n          const int thread_k_blocks,  // same for k dimension (reduction)\n          const int stages,  // number of stages for the async global->shared\n                             // fetch 
pipeline\n          const int group_blocks = -1  // number of consecutive 16x16 blocks\n                                       // with a separate quantization scale\n          >\n__global__ void Marlin(\n    const int4* __restrict__ A,  // fp16 input matrix of shape mxk\n    const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn\n    int4* __restrict__ C,        // fp16 output buffer of shape mxn\n    const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape\n                                          // (k/groupsize)xn\n    int num_groups,  // number of scale groups per output channel\n    int prob_m,      // batch dimension m\n    int prob_n,      // output dimension n\n    int prob_k,      // reduction dimension k\n    int* locks       // extra global storage for barrier synchronization\n) {}\n\n}  // namespace fp8_marlin\n\ntorch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,\n                              torch::Tensor& b_scales, torch::Tensor& workspace,\n                              int64_t num_bits, int64_t size_m, int64_t size_n,\n                              int64_t size_k) {\n  TORCH_CHECK_NOT_IMPLEMENTED(false,\n                              \"marlin_gemm(..) 
requires CUDA_ARCH >= 8.0\");\n  return torch::empty({1, 1});\n}\n\n#else\n\n// m16n8k16 tensor core mma instruction with fp16 inputs and fp32\n// output/accumulation.\ntemplate <typename scalar_t>\n__device__ inline void mma(const typename ScalarType<scalar_t>::FragA& a_frag,\n                           const typename ScalarType<scalar_t>::FragB& frag_b,\n                           typename ScalarType<scalar_t>::FragC& frag_c) {\n  const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);\n  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);\n  float* c = reinterpret_cast<float*>(&frag_c);\n  if constexpr (std::is_same<scalar_t, half>::value) {\n    asm volatile(\n        \"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \"\n        \"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\\n\"\n        : \"=f\"(c[0]), \"=f\"(c[1]), \"=f\"(c[2]), \"=f\"(c[3])\n        : \"r\"(a[0]), \"r\"(a[1]), \"r\"(a[2]), \"r\"(a[3]), \"r\"(b[0]), \"r\"(b[1]),\n          \"f\"(c[0]), \"f\"(c[1]), \"f\"(c[2]), \"f\"(c[3]));\n  } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {\n    asm volatile(\n        \"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \"\n        \"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\\n\"\n        : \"=f\"(c[0]), \"=f\"(c[1]), \"=f\"(c[2]), \"=f\"(c[3])\n        : \"r\"(a[0]), \"r\"(a[1]), \"r\"(a[2]), \"r\"(a[3]), \"r\"(b[0]), \"r\"(b[1]),\n          \"f\"(c[0]), \"f\"(c[1]), \"f\"(c[2]), \"f\"(c[3]));\n  } else {\n    STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);\n  }\n}\n\n// Instruction for loading a full 16x16 matrix fragment of operand A from shared\n// memory, directly in tensor core layout.\ntemplate <typename scalar_t>\n__device__ inline void ldsm4(typename ScalarType<scalar_t>::FragA& frag_a,\n                             const void* smem_ptr) {\n  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);\n  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));\n  asm 
volatile(\"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\\n\"\n               : \"=r\"(a[0]), \"=r\"(a[1]), \"=r\"(a[2]), \"=r\"(a[3])\n               : \"r\"(smem));\n}\n\n// Fast FP8ToFp16/FP8ToBf16: Efficiently dequantize 8bit fp8_e4m3 values to fp16\n// bf16 Reference:\n// - FP16:\n// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85\n// - BF16:\n// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175\ntemplate <typename scalar_t>\n__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit(int q) {\n  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);\n}\n\ntemplate <>\n__device__ inline typename ScalarType<half>::FragB dequant_8bit<half>(int q) {\n  // Constants for FP8 (E4M3) and FP16 formats\n  constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, FP16_EXPONENT = 5;\n  constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;\n\n  // Calculate MASK for extracting mantissa and exponent\n  constexpr int MASK1 = 0x80000000;\n  constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);\n  constexpr int MASK3 = MASK2 & 0x7fffffff;\n  constexpr int MASK = MASK3 | (MASK3 >> 16);\n  // Final MASK value: 0x7F007F00\n\n  // Extract and shift FP8 values to FP16 format\n  int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);\n  int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);\n\n  // Construct and apply exponent bias\n  constexpr int BIAS_OFFSET =\n      (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));\n  const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));\n\n  // Convert to half2 and apply bias\n  typename ScalarType<half>::FragB frag_b;\n  // Note: reverse indexing is intentional because weights are permuted\n  frag_b[1] = __hmul2(*reinterpret_cast<const 
half2*>(&Out1), bias_reg);\n  frag_b[0] = __hmul2(*reinterpret_cast<const half2*>(&Out2), bias_reg);\n  return frag_b;\n}\n\ntemplate <>\n__device__ inline typename ScalarType<nv_bfloat16>::FragB\ndequant_8bit<nv_bfloat16>(int q) {\n  // Constants for FP8 (E4M3) and BF16 formats\n  constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, BF16_EXPONENT = 8;\n  constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;\n\n  // Calculate MASK for extracting mantissa and exponent\n  constexpr int MASK1 = 0x80000000;\n  constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);\n  constexpr int MASK3 = MASK2 & 0x7fffffff;\n  constexpr int MASK = MASK3 | (MASK3 >> 16);\n  // Final MASK value: 0x7F007F00\n\n  // Extract and shift FP8 values to BF16 format\n  int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);\n  int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);\n\n  // Construct and apply exponent bias\n  constexpr int BIAS_OFFSET =\n      (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));\n  // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent\n  // position\n  constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;\n  const nv_bfloat162 bias_reg =\n      __float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));\n\n  // Convert to bfloat162 and apply bias\n  typename ScalarType<nv_bfloat16>::FragB frag_b;\n  // Note: reverse indexing is intentional because weights are permuted\n  frag_b[1] = __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out1), bias_reg);\n  frag_b[0] = __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out2), bias_reg);\n  return frag_b;\n}\n\n// Multiply dequantized values by the corresponding quantization scale; used\n// only for grouped quantization.\ntemplate <typename scalar_t>\n__device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,\n                             typename ScalarType<scalar_t>::FragS& frag_s,\n                             int i) {\n  using scalar_t2 = typename 
ScalarType<scalar_t>::scalar_t2;\n  scalar_t2 s =\n      ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_s)[i]);\n  frag_b[0] = __hmul2(frag_b[0], s);\n  frag_b[1] = __hmul2(frag_b[1], s);\n}\n\n// Given 2 floats multiply by 2 scales (halves)\ntemplate <typename scalar_t>\n__device__ inline void scale_float(float* c,\n                                   typename ScalarType<scalar_t>::FragS& s) {\n  scalar_t* s_ptr = reinterpret_cast<scalar_t*>(&s);\n  c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0]));\n  c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1]));\n}\n\n// Wait until barrier reaches `count`, then lock for current threadblock.\n__device__ inline void barrier_acquire(int* lock, int count) {\n  if (threadIdx.x == 0) {\n    int state = -1;\n    do\n      // Guarantee that subsequent writes by this threadblock will be visible\n      // globally.\n      asm volatile(\"ld.global.acquire.gpu.b32 %0, [%1];\\n\"\n                   : \"=r\"(state)\n                   : \"l\"(lock));\n    while (state != count);\n  }\n  __syncthreads();\n}\n\n// Release barrier and increment visitation count.\n__device__ inline void barrier_release(int* lock, bool reset = false) {\n  __syncthreads();\n  if (threadIdx.x == 0) {\n    if (reset) {\n      lock[0] = 0;\n      return;\n    }\n    int val = 1;\n    // Make sure that all writes since acquiring this barrier are visible\n    // globally, while releasing the barrier.\n    asm volatile(\"fence.acq_rel.gpu;\\n\");\n    asm volatile(\"red.relaxed.gpu.global.add.s32 [%0], %1;\\n\"\n                 :\n                 : \"l\"(lock), \"r\"(val));\n  }\n}\n\ntemplate <typename scalar_t,          // compute dtype, half or nv_float16\n          const int num_bits,         // number of bits used for weights\n          const int threads,          // number of threads in a threadblock\n          const int thread_m_blocks,  // number of 16x16 blocks in the m\n                                    
  // dimension (batchsize) of the\n                                      // threadblock\n          const int thread_n_blocks,  // same for n dimension (output)\n          const int thread_k_blocks,  // same for k dimension (reduction)\n          const int stages,  // number of stages for the async global->shared\n                             // fetch pipeline\n          const int group_blocks = -1  // number of consecutive 16x16 blocks\n                                       // with a separate quantization scale\n          >\n__global__ void Marlin(\n    const int4* __restrict__ A,  // fp16 input matrix of shape mxk\n    const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn\n    int4* __restrict__ C,        // fp16 output buffer of shape mxn\n    const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape\n                                          // (k/groupsize)xn\n    int num_groups,  // number of scale groups per output channel\n    int prob_m,      // batch dimension m\n    int prob_n,      // output dimension n\n    int prob_k,      // reduction dimension k\n    int* locks       // extra global storage for barrier synchronization\n) {\n  // Each threadblock processes one \"stripe\" of the B matrix with (roughly) the\n  // same size, which might involve multiple column \"slices\" (of width 16 *\n  // `thread_n_blocks`). 
Stripes are defined as shown in the 3x3 matrix 5 SM\n  // example:\n  //   0 1 3\n  //   0 2 3\n  //   1 2 4\n  // While this kind of partitioning makes things somewhat more complicated, it\n  // ensures good utilization of all SMs for many kinds of shape and GPU\n  // configurations, while requiring as few slow global cross-threadblock\n  // reductions as possible.\n  using Dtype = ScalarType<scalar_t>;\n  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;\n  using FragA = typename ScalarType<scalar_t>::FragA;\n  using FragB = typename ScalarType<scalar_t>::FragB;\n  using FragC = typename ScalarType<scalar_t>::FragC;\n  using FragS = typename ScalarType<scalar_t>::FragS;\n\n  constexpr int pack_factor = 32 / num_bits;\n\n  // For larger GEMMs we run multiple batchsize 64 versions in parallel for a\n  // better partitioning with less reductions\n  int parallel = 1;\n  if (prob_m > 16 * thread_m_blocks) {\n    parallel = prob_m / (16 * thread_m_blocks);\n    prob_m = 16 * thread_m_blocks;\n  }\n\n  int k_tiles = prob_k / 16 / thread_k_blocks;\n  int n_tiles = prob_n / 16 / thread_n_blocks;\n  int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x);\n\n  int slice_row = (iters * blockIdx.x) % k_tiles;\n  int slice_col_par = (iters * blockIdx.x) / k_tiles;\n  int slice_col = slice_col_par;\n  int slice_iters;  // number of threadblock tiles in the current slice\n  int slice_count =\n      0;          // total number of active threadblocks in the current slice\n  int slice_idx;  // index of threadblock in current slice; numbered bottom to\n                  // top\n\n  // We can easily implement parallel problem execution by just remapping\n  // indices and advancing global pointers\n  if (slice_col_par >= n_tiles) {\n    A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;\n    C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;\n    locks += (slice_col_par / n_tiles) * n_tiles;\n    slice_col = slice_col_par % n_tiles;\n 
 }\n\n  // Compute all information about the current slice which is required for\n  // synchronization.\n  auto init_slice = [&]() {\n    slice_iters =\n        iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);\n    if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;\n    if (slice_iters == 0) return;\n    if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;\n    slice_count = 1;\n    slice_idx = 0;\n    int col_first = iters * div_ceil(k_tiles * slice_col_par, iters);\n    if (col_first <= k_tiles * (slice_col_par + 1)) {\n      int col_off = col_first - k_tiles * slice_col_par;\n      slice_count = div_ceil(k_tiles - col_off, iters);\n      if (col_off > 0) slice_count++;\n      int delta_first = iters * blockIdx.x - col_first;\n      if (delta_first < 0 || (col_off == 0 && delta_first == 0))\n        slice_idx = slice_count - 1;\n      else {\n        slice_idx = slice_count - 1 - delta_first / iters;\n        if (col_off > 0) slice_idx--;\n      }\n    }\n    if (slice_col == n_tiles) {\n      A += 16 * thread_m_blocks * prob_k / 8;\n      C += 16 * thread_m_blocks * prob_n / 8;\n      locks += n_tiles;\n      slice_col = 0;\n    }\n  };\n  init_slice();\n\n  // A sizes/strides\n\n  // stride of the A matrix in global memory\n  int a_gl_stride = prob_k / 8;\n  // stride of an A matrix tile in shared memory\n  constexpr int a_sh_stride = 16 * thread_k_blocks / 8;\n  // delta between subsequent A tiles in global memory\n  constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8;\n  // between subsequent accesses within a tile\n  int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);\n  // between shared memory writes\n  constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);\n  // between shared memory tile reads\n  constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));\n  // within a shared memory tile\n  constexpr int a_sh_rd_delta_i = a_sh_stride * 
16;\n  // overall size of a tile\n  constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);\n  // number of shared write iterations for a tile\n  constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta);\n\n  // B sizes/strides\n  int b_gl_stride = 16 * prob_n / (pack_factor * 4);\n  constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;\n  constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2;\n  constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;\n\n  int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;\n  int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);\n  constexpr int b_sh_wr_delta = threads * b_thread_vecs;\n  constexpr int b_sh_rd_delta = threads * b_thread_vecs;\n  constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;\n  constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;\n\n  // Scale sizes/strides without act_order\n  int s_gl_stride = prob_n / 8;\n  constexpr int s_sh_stride = 16 * thread_n_blocks / 8;\n\n  // Scale size/strides with act_order\n  constexpr int tb_k = 16 * thread_k_blocks;\n  constexpr int g_idx_stage = 0;\n  // constexpr int act_s_row_stride      = 1;\n  // int           act_s_col_stride      = act_s_row_stride * num_groups;\n  int act_s_col_stride = 1;\n  int act_s_col_warp_stride = act_s_col_stride * 8;\n  int tb_n_warps = thread_n_blocks / 4;\n  int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;\n\n  // Global A read index of current thread.\n  int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +\n                (threadIdx.x % a_gl_rd_delta_o);\n  a_gl_rd += a_gl_rd_delta_o * slice_row;\n  // Shared write index of current thread.\n  int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +\n                (threadIdx.x % a_gl_rd_delta_o);\n  // Shared read index.\n  int a_sh_rd =\n      a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;\n  a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));\n\n  int 
b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +\n                (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;\n  b_gl_rd += b_sh_stride * slice_col;\n  b_gl_rd += b_gl_rd_delta_o * slice_row;\n  int b_sh_wr = threadIdx.x * b_thread_vecs;\n  int b_sh_rd = threadIdx.x * b_thread_vecs;\n\n  // For act_order\n  int slice_k_start = tb_k * slice_row;\n  int slice_k_start_shared_fetch = slice_k_start;\n  int slice_n_offset = act_s_col_tb_stride * slice_col;\n\n  // No act_order\n  int s_gl_rd = s_sh_stride * slice_col + threadIdx.x;\n  int s_sh_wr = threadIdx.x;\n  bool s_sh_wr_pred = threadIdx.x < s_sh_stride;\n\n  // We scale a `half2` tile in row-major layout for column-wise quantization.\n  int s_sh_rd =\n      8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4;\n\n  // Precompute which thread should not read memory in which iterations; this is\n  // needed if there are more threads than required for a certain tilesize or\n  // when the batchsize is not a multiple of 16.\n  bool a_sh_wr_pred[a_sh_wr_iters];\n  #pragma unroll\n  for (int i = 0; i < a_sh_wr_iters; i++)\n    a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;\n\n  // To ensure that writing and reading A tiles to/from shared memory, the\n  // latter in fragment format, is fully bank conflict free, we need to use a\n  // rather fancy XOR-based layout. The key here is that neither reads nor\n  // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the\n  // same shared memory banks. 
Further, it seems (based on NSight-Compute) that\n  // each warp must also write a consecutive memory segment?\n  auto transform_a = [&](int i) {\n    int row = i / a_gl_rd_delta_o;\n    return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;\n  };\n  // Since the computation of this remapping is non-trivial and, due to our main\n  // loop unrolls, all shared memory accesses are static, we simply precompute\n  // both transformed reads and writes.\n  int a_sh_wr_trans[a_sh_wr_iters];\n  #pragma unroll\n  for (int i = 0; i < a_sh_wr_iters; i++)\n    a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);\n  int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];\n  #pragma unroll\n  for (int i = 0; i < b_sh_wr_iters; i++) {\n  #pragma unroll\n    for (int j = 0; j < thread_m_blocks; j++)\n      a_sh_rd_trans[i][j] =\n          transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);\n  }\n\n  // Since B-accesses have non-constant stride they have to be computed at\n  // runtime; we break dependencies between subsequent accesses with a tile by\n  // maintining multiple pointers (we have enough registers), a tiny\n  // optimization.\n  const int4* B_ptr[b_sh_wr_iters];\n  #pragma unroll\n  for (int i = 0; i < b_sh_wr_iters; i++)\n    B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;\n\n  extern __shared__ int4 sh[];\n  // Shared memory storage for global fetch pipelines.\n  int4* sh_a = sh;\n  int4* sh_b = sh_a + (stages * a_sh_stage);\n  int4* sh_g_idx = sh_b + (stages * b_sh_stage);\n  int4* sh_s = sh_g_idx + (stages * g_idx_stage);\n\n  // Register storage for double buffer of shared memory reads.\n  FragA frag_a[2][thread_m_blocks];\n  I4 frag_b_quant[2][b_thread_vecs];\n  FragC frag_c[thread_m_blocks][4][2];\n  FragS frag_s[2][4];\n\n  // Zero accumulators.\n  auto zero_accums = [&]() {\n  #pragma unroll\n    for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)\n      reinterpret_cast<float*>(frag_c)[i] = 0;\n  };\n\n  int sh_first_group_id = -1;\n  int 
sh_num_groups = -1;\n  constexpr int sh_max_num_groups = 32;\n\n  auto fetch_scales_to_shared = [&](bool is_async, int first_group_id,\n                                    int last_group_id) {\n    sh_first_group_id = first_group_id;\n    sh_num_groups = last_group_id - first_group_id + 1;\n\n    if (sh_num_groups < sh_max_num_groups) {\n      sh_num_groups = sh_max_num_groups;\n    }\n\n    if (sh_first_group_id + sh_num_groups > num_groups) {\n      sh_num_groups = num_groups - sh_first_group_id;\n    }\n\n    int row_offset = first_group_id * s_gl_stride;\n\n    if (is_async) {\n      for (int i = 0; i < sh_num_groups; i++) {\n        if (threadIdx.x < s_sh_stride) {\n          cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x],\n                         &scales_ptr[row_offset + (i * s_gl_stride) +\n                                     slice_n_offset + threadIdx.x]);\n        }\n      }\n    } else {\n      for (int i = 0; i < sh_num_groups; i++) {\n        if (threadIdx.x < s_sh_stride) {\n          sh_s[(i * s_sh_stride) + threadIdx.x] =\n              scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset +\n                         threadIdx.x];\n        }\n      }\n    }\n  };\n  // Asynchronously fetch the next A, B and s tile from global to the next\n  // shared memory pipeline location.\n  auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {\n    if (pred) {\n      int4* sh_a_stage = sh_a + a_sh_stage * pipe;\n  #pragma unroll\n      for (int i = 0; i < a_sh_wr_iters; i++) {\n        cp_async4_pred(\n            &sh_a_stage[a_sh_wr_trans[i]],\n            &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],\n            a_sh_wr_pred[i]);\n      }\n      int4* sh_b_stage = sh_b + b_sh_stage * pipe;\n  #pragma unroll\n      for (int i = 0; i < b_sh_wr_iters; i++) {\n  #pragma unroll\n        for (int j = 0; j < b_thread_vecs; j++) {\n          cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);\n        
}\n\n        B_ptr[i] += b_gl_rd_delta_o;\n      }\n    }\n    // Insert a fence even when we are winding down the pipeline to ensure that\n    // waiting is also correct at this point.\n    cp_async_fence();\n  };\n\n  // Wait until the next thread tile has been loaded to shared memory.\n  auto wait_for_stage = [&]() {\n    // We only have `stages - 2` active fetches since we are double buffering\n    // and can only issue the next fetch when it is guaranteed that the previous\n    // shared memory load is fully complete (as it may otherwise be\n    // overwritten).\n    cp_async_wait<stages - 2>();\n    __syncthreads();\n  };\n\n  // Load the next sub-tile from the current location in the shared memory pipe\n  // into the current register buffer.\n  auto fetch_to_registers = [&](int k, int pipe) {\n    int4* sh_a_stage = sh_a + a_sh_stage * pipe;\n  #pragma unroll\n    for (int i = 0; i < thread_m_blocks; i++)\n      ldsm4<scalar_t>(frag_a[k % 2][i],\n                      &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);\n    int4* sh_b_stage = sh_b + b_sh_stage * pipe;\n\n  #pragma unroll\n    for (int i = 0; i < b_thread_vecs; i++) {\n      frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(\n          &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);\n    }\n  };\n\n  bool is_same_group[stages];\n  int same_group_id[stages];\n\n  auto init_same_group = [&](int pipe) {\n    is_same_group[pipe] = false;\n    same_group_id[pipe] = 0;\n    return;\n  };\n\n  // Execute the actual tensor core matmul of a sub-tile.\n  auto matmul = [&](int k) {\n  // We have the m dimension as the inner loop in order to encourage overlapping\n  // dequantization and matmul operations.\n  #pragma unroll\n    for (int j = 0; j < 4; j++) {\n      FragB frag_b0;\n      FragB frag_b1;\n\n      int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);\n      int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];\n      int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];\n\n      
frag_b0 = dequant_8bit<scalar_t>(b_quant_0);\n      frag_b1 = dequant_8bit<scalar_t>(b_quant_1);\n\n  #pragma unroll\n      for (int i = 0; i < thread_m_blocks; i++) {\n        mma<scalar_t>(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);\n        mma<scalar_t>(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);\n      }\n    }\n  };\n\n  // Since we slice across the k dimension of a tile in order to increase the\n  // number of warps while keeping the n dimension of a tile reasonable, we have\n  // multiple warps that accumulate their partial sums of the same output\n  // location; which we have to reduce over in the end. We do in shared memory.\n  auto thread_block_reduce = [&]() {\n    constexpr int red_off = threads / b_sh_stride_threads / 2;\n    if (red_off >= 1) {\n      int red_idx = threadIdx.x / b_sh_stride_threads;\n      constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;\n      constexpr int red_sh_delta = b_sh_stride_threads;\n      int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +\n                      (threadIdx.x % b_sh_stride_threads);\n\n      // Parallel logarithmic shared memory reduction. 
We make sure to avoid any\n      // unnecessary read or write iterations, e.g., for two warps we write only\n      // once by warp 1 and read only once by warp 0.\n\n  #pragma unroll\n      for (int m_block = 0; m_block < thread_m_blocks; m_block++) {\n  #pragma unroll\n        for (int i = red_off; i > 0; i /= 2) {\n          if (i <= red_idx && red_idx < 2 * i) {\n  #pragma unroll\n            for (int j = 0; j < 4 * 2; j++) {\n              int red_sh_wr =\n                  red_sh_delta * j + (red_sh_rd - red_sh_stride * i);\n              if (i < red_off) {\n                float* c_rd =\n                    reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);\n                float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);\n  #pragma unroll\n                for (int k = 0; k < 4; k++)\n                  reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=\n                      c_rd[k] + c_wr[k];\n              }\n              sh[red_sh_wr] =\n                  reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];\n            }\n          }\n          __syncthreads();\n        }\n        if (red_idx == 0) {\n  #pragma unroll\n          for (int i = 0; i < 4 * 2; i++) {\n            float* c_rd =\n                reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);\n  #pragma unroll\n            for (int j = 0; j < 4; j++)\n              reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=\n                  c_rd[j];\n          }\n        }\n        __syncthreads();\n      }\n    }\n  };\n\n  // Since multiple threadblocks may process parts of the same column slice, we\n  // finally have to globally reduce over the results. 
As the striped\n  // partitioning minimizes the number of such reductions and our outputs are\n  // usually rather small, we perform this reduction serially in L2 cache.\n  auto global_reduce = [&](bool first = false, bool last = false) {\n    // We are very careful here to reduce directly in the output buffer to\n    // maximize L2 cache utilization in this step. To do this, we write out\n    // results in FP16 (but still reduce with FP32 compute).\n    constexpr int active_threads = 32 * thread_n_blocks / 4;\n    if (threadIdx.x < active_threads) {\n      int c_gl_stride = prob_n / 8;\n      int c_gl_wr_delta_o = 8 * c_gl_stride;\n      int c_gl_wr_delta_i = 4 * (active_threads / 32);\n      int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +\n                    4 * (threadIdx.x / 32) + threadIdx.x % 4;\n      c_gl_wr += (2 * thread_n_blocks) * slice_col;\n      constexpr int c_sh_wr_delta = active_threads;\n      int c_sh_wr = threadIdx.x;\n\n      int row = (threadIdx.x % 32) / 4;\n\n      if (!first) {\n  // Interestingly, doing direct global accesses here really seems to mess up\n  // the compiler and lead to slowdowns, hence we also use async-copies even\n  // though these fetches are not actually asynchronous.\n  #pragma unroll\n        for (int i = 0; i < thread_m_blocks * 4; i++) {\n          cp_async4_pred(\n              &sh[c_sh_wr + c_sh_wr_delta * i],\n              &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +\n                 c_gl_wr_delta_i * (i % 2)],\n              i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);\n        }\n        cp_async_fence();\n        cp_async_wait<0>();\n      }\n\n  #pragma unroll\n      for (int i = 0; i < thread_m_blocks * 4; i++) {\n        if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {\n          if (!first) {\n            int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];\n  #pragma unroll\n            for (int j = 0; j < 2 * 4; j++) {\n              reinterpret_cast<float*>(\n          
        &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=\n                  Dtype::num2float(reinterpret_cast<scalar_t*>(&c_red)[j]);\n            }\n          }\n          if (!last) {\n            int4 c;\n  #pragma unroll\n            for (int j = 0; j < 2 * 4; j++) {\n              reinterpret_cast<scalar_t*>(&c)[j] =\n                  Dtype::float2num(reinterpret_cast<float*>(\n                      &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);\n            }\n            C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =\n                c;\n          }\n        }\n      }\n    }\n  };\n\n  // Write out the reduce final result in the correct layout. We only actually\n  // reshuffle matrix fragments in this step, the reduction above is performed\n  // in fragment layout.\n  auto write_result = [&]() {\n    int c_gl_stride = prob_n / 8;\n    constexpr int c_sh_stride = 2 * thread_n_blocks + 1;\n    int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));\n    constexpr int c_sh_rd_delta =\n        c_sh_stride * (threads / (2 * thread_n_blocks));\n\n    int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +\n                  (threadIdx.x % (2 * thread_n_blocks));\n    c_gl_wr += (2 * thread_n_blocks) * slice_col;\n    int c_sh_wr =\n        (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;\n    c_sh_wr += 32 * (threadIdx.x / 32);\n    int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +\n                  (threadIdx.x % (2 * thread_n_blocks));\n\n    int c_gl_wr_end = c_gl_stride * prob_m;\n\n    // We first reorder in shared memory to guarantee the most efficient final\n    // global write patterns\n    auto write = [&](int idx, float c0, float c1, FragS& s) {\n      scalar_t2 res =\n          Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));\n\n      ((scalar_t2*)sh)[idx] = res;\n    };\n\n    if (threadIdx.x / 32 < thread_n_blocks / 4) {\n  #pragma unroll\n      
for (int i = 0; i < thread_m_blocks; i++) {\n  #pragma unroll\n        for (int j = 0; j < 4; j++) {\n          int wr = c_sh_wr + 8 * j;\n          write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],\n                frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);\n          write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],\n                frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);\n          write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],\n                frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);\n          write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],\n                frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);\n        }\n        c_sh_wr += 16 * (4 * c_sh_stride);\n      }\n    }\n    __syncthreads();\n\n  #pragma unroll\n    for (int i = 0;\n         i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));\n         i++) {\n      if (c_gl_wr < c_gl_wr_end) {\n        C[c_gl_wr] = sh[c_sh_rd];\n        c_gl_wr += c_gl_wr_delta;\n        c_sh_rd += c_sh_rd_delta;\n      }\n    }\n  };\n\n  // Start global fetch and register load pipelines.\n  auto start_pipes = [&]() {\n\n  #pragma unroll\n    for (int i = 0; i < stages - 1; i++) {\n      fetch_to_shared(i, i, i < slice_iters);\n    }\n\n    zero_accums();\n    wait_for_stage();\n    init_same_group(0);\n    fetch_to_registers(0, 0);\n    a_gl_rd += a_gl_rd_delta_o * (stages - 1);\n    slice_k_start_shared_fetch += tb_k * (stages - 1);\n  };\n  if (slice_iters) {\n    start_pipes();\n  }\n\n  // Main loop.\n  while (slice_iters) {\n    // We unroll over both the global fetch and the register load pipeline to\n    // ensure all shared memory accesses are static. 
Note that both pipelines\n    // have even length meaning that the next iteration will always start at\n    // index 0.\n\n  #pragma unroll\n    for (int pipe = 0; pipe < stages;) {\n  #pragma unroll\n      for (int k = 0; k < b_sh_wr_iters; k++) {\n        fetch_to_registers(k + 1, pipe % stages);\n        if (k == b_sh_wr_iters - 2) {\n          fetch_to_shared((pipe + stages - 1) % stages, pipe,\n                          slice_iters >= stages);\n          pipe++;\n          wait_for_stage();\n          init_same_group(pipe % stages);\n        }\n        matmul(k);\n      }\n      slice_iters--;\n      if (slice_iters == 0) {\n        break;\n      }\n    }\n\n    a_gl_rd += a_gl_rd_delta_o * stages;\n    slice_k_start += tb_k * stages;\n    slice_k_start_shared_fetch += tb_k * stages;\n\n    // Process results and, if necessary, proceed to the next column slice.\n    // While this pattern may not be the most readable, other ways of writing\n    // the loop seemed to noticeably worse performance after compilation.\n    if (slice_iters == 0) {\n      cp_async_wait<0>();\n      bool last = slice_idx == slice_count - 1;\n      // For per-column scales, we only fetch them here in the final step before\n      // write-out\n      if (s_sh_wr_pred) {\n        cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);\n      }\n      cp_async_fence();\n\n      thread_block_reduce();\n\n      cp_async_wait<0>();\n      __syncthreads();\n      if (threadIdx.x / 32 < thread_n_blocks / 4) {\n        reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];\n        reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];\n      }\n\n      // For 8-bit channelwise, we apply the scale before the global reduction\n      // that converts the fp32 results to fp16 (so that we avoid possible\n      // overflow in fp16)\n      if (threadIdx.x / 32 < thread_n_blocks / 4) {\n  #pragma unroll\n        for (int i = 0; i < thread_m_blocks; i++) {\n  #pragma unroll\n          for (int j = 0; j < 4; 
j++) {\n            scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][0][0]),\n                                  frag_s[j / 2][2 * (j % 2) + 0]);\n            scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][0][2]),\n                                  frag_s[j / 2][2 * (j % 2) + 0]);\n\n            scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][1][0]),\n                                  frag_s[j / 2][2 * (j % 2) + 1]);\n            scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][1][2]),\n                                  frag_s[j / 2][2 * (j % 2) + 1]);\n          }\n        }\n      }\n\n      if (slice_count > 1) {  // only globally reduce if there is more than one\n                              // block in a slice\n        barrier_acquire(&locks[slice_col], slice_idx);\n        global_reduce(slice_idx == 0, last);\n        barrier_release(&locks[slice_col], last);\n      }\n      if (last)  // only the last block in a slice actually writes the result\n        write_result();\n      slice_row = 0;\n      slice_col_par++;\n      slice_col++;\n      init_slice();\n      if (slice_iters) {\n        a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +\n                  (threadIdx.x % a_gl_rd_delta_o);\n  #pragma unroll\n        for (int i = 0; i < b_sh_wr_iters; i++)\n          B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;\n        if (slice_col == 0) {\n  #pragma unroll\n          for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;\n        }\n\n        // Update slice k/n for scales loading\n        s_gl_rd = s_sh_stride * slice_col + threadIdx.x;\n\n        start_pipes();\n      }\n    }\n  }\n}\n\n  #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS,                \\\n                    THREAD_K_BLOCKS, GROUP_BLOCKS, NUM_THREADS)                \\\n    else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS &&     \\\n             thread_n_blocks == THREAD_N_BLOCKS 
&&                             \\\n             thread_k_blocks == THREAD_K_BLOCKS &&                             \\\n             group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) {     \\\n      cudaFuncSetAttribute(                                                    \\\n          Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS,             \\\n                 THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, GROUP_BLOCKS>, \\\n          cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);        \\\n      Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS,                 \\\n             THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, GROUP_BLOCKS>      \\\n          <<<blocks, NUM_THREADS, max_shared_mem, stream>>>(                   \\\n              A_ptr, B_ptr, C_ptr, s_ptr, num_groups, prob_m, prob_n, prob_k,  \\\n              locks);                                                          \\\n    }\n\ntypedef struct {\n  int thread_k;\n  int thread_n;\n  int num_threads;\n} thread_config_t;\n\ntypedef struct {\n  int max_m_blocks;\n  thread_config_t tb_cfg;\n} exec_config_t;\n\nthread_config_t small_batch_thread_configs[] = {\n    // Ordered by priority\n\n    // thread_k, thread_n, num_threads\n    {128, 128, 256},\n    {64, 128, 128},\n    {128, 64, 128},\n};\n\nthread_config_t large_batch_thread_configs[] = {\n    // Ordered by priority\n\n    // thread_k, thread_n, num_threads\n    {64, 256, 256},\n    {64, 128, 128},\n    {128, 64, 128},\n\n};\n\nint get_scales_cache_size(thread_config_t const& th_config, int prob_m,\n                          int prob_n, int prob_k, int num_bits,\n                          int group_size) {\n  int tb_n = th_config.thread_n;\n\n  // Get max scale groups per thread-block\n  // Fixed for channelwise\n  int tb_groups = 1;\n  int tb_scales = tb_groups * tb_n * 2;\n\n  return tb_scales * pipe_stages;\n}\n\nbool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,\n        
                 int prob_m, int prob_n, int prob_k, int num_bits,\n                         int scales_cache_size, int max_shared_mem) {\n  int pack_factor = 32 / num_bits;\n\n  // Get B size\n  int tb_k = th_config.thread_k;\n  int tb_n = th_config.thread_n;\n\n  int b_size = (tb_k * tb_n / pack_factor) * 4;\n\n  // Get A size\n  int m_blocks = div_ceil(prob_m, 16);\n  int tb_max_m = 16;\n\n  while (true) {\n    if (m_blocks >= max_m_blocks) {\n      tb_max_m *= max_m_blocks;\n      break;\n    }\n\n    max_m_blocks--;\n    if (max_m_blocks == 0) {\n      TORCH_CHECK(false, \"Unexpected m_blocks = \", m_blocks);\n    }\n  }\n\n  int a_size = (tb_max_m * tb_k) * 2;\n\n  float pipe_size = (a_size + b_size) * pipe_stages;\n\n  TORCH_CHECK(max_shared_mem / 2 > scales_cache_size);  // Sanity\n\n  return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);\n}\n\nbool is_valid_config(thread_config_t const& th_config, int max_m_blocks,\n                     int prob_m, int prob_n, int prob_k, int num_bits,\n                     int group_size, int max_shared_mem) {\n  // Sanity\n  if (th_config.thread_k == -1 || th_config.thread_n == -1 ||\n      th_config.num_threads == -1) {\n    return false;\n  }\n\n  // Verify K/N are divisible by thread K/N\n  if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {\n    return false;\n  }\n\n  // Verify min for thread K/N\n  if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {\n    return false;\n  }\n\n  // num_threads must be at least 128 (= 4 warps)\n  if (th_config.num_threads < 128) {\n    return false;\n  }\n\n  //  Determine cache for scales\n  int scales_cache_size = get_scales_cache_size(th_config, prob_m, prob_n,\n                                                prob_k, num_bits, group_size);\n\n  // Check that pipeline fits into cache\n  if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,\n                           num_bits, scales_cache_size, 
max_shared_mem)) {\n    return false;\n  }\n\n  return true;\n}\n\nexec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,\n                                      int num_bits, int group_size,\n                                      int max_shared_mem) {\n  int max_m_blocks = 4;\n  while (max_m_blocks > 0) {\n    if (prob_m <= 16) {\n      for (auto th_config : small_batch_thread_configs) {\n        if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,\n                            num_bits, group_size, max_shared_mem)) {\n          return exec_config_t{max_m_blocks, th_config};\n        }\n      }\n    } else {\n      for (auto th_config : large_batch_thread_configs) {\n        if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,\n                            num_bits, group_size, max_shared_mem)) {\n          return exec_config_t{max_m_blocks, th_config};\n        }\n      }\n    }\n\n    max_m_blocks--;  // Process less M blocks per invocation to reduce cache\n                     // usage\n  }\n\n  return exec_config_t{0, {-1, -1, -1}};\n}\n\n  #define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS)    \\\n    __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \\\n    __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \\\n    __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \\\n    __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS)\n\ntemplate <typename scalar_t>\nvoid marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, int prob_m,\n                     int prob_n, int prob_k, void* workspace, int num_bits,\n                     int num_groups, int group_size, int dev,\n                     cudaStream_t stream, int thread_k, int thread_n, int sms,\n                     int max_par) {\n  TORCH_CHECK(num_bits == 8, \"num_bits must be 8. 
Got = \", num_bits);\n  TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, \"Invalid MNK = [\", prob_m,\n              \", \", prob_n, \", \", prob_k, \"]\");\n\n  int tot_m = prob_m;\n  int tot_m_blocks = div_ceil(tot_m, 16);\n  int pad = 16 * tot_m_blocks - tot_m;\n\n  if (sms == -1) {\n    cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);\n  }\n\n  int max_shared_mem = 0;\n  cudaDeviceGetAttribute(&max_shared_mem,\n                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);\n  TORCH_CHECK(max_shared_mem > 0);\n\n  // Set thread config\n  exec_config_t exec_cfg;\n  if (thread_k != -1 && thread_n != -1) {\n    // User-defined config\n    exec_cfg =\n        exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}};\n  } else {\n    // Auto config\n    exec_cfg = determine_thread_config(prob_m, prob_n, prob_k, num_bits,\n                                       group_size, max_shared_mem);\n  }\n\n  TORCH_CHECK(\n      exec_cfg.max_m_blocks > 0 &&\n          is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, prob_m,\n                          prob_n, prob_k, num_bits, group_size, max_shared_mem),\n      \"Invalid thread config: max_m_blocks = \", exec_cfg.max_m_blocks,\n      \", thread_k = \", exec_cfg.tb_cfg.thread_k,\n      \", thread_n = \", exec_cfg.tb_cfg.thread_n,\n      \", num_threads = \", exec_cfg.tb_cfg.num_threads, \" for MKN = [\", prob_m,\n      \", \", prob_k, \", \", prob_n, \"] and num_bits = \", num_bits,\n      \", group_size = \", group_size, \", max_shared_mem = \", max_shared_mem);\n\n  int num_threads = exec_cfg.tb_cfg.num_threads;\n  thread_k = exec_cfg.tb_cfg.thread_k;\n  thread_n = exec_cfg.tb_cfg.thread_n;\n\n  int thread_k_blocks = thread_k / 16;\n  int thread_n_blocks = thread_n / 16;\n\n  int blocks = sms;\n\n  TORCH_CHECK(prob_n % thread_n == 0, \"prob_n = \", prob_n,\n              \" is not divisible by thread_n = \", thread_n);\n  TORCH_CHECK(prob_k % thread_k == 0, \"prob_k = \", 
prob_k,\n              \" is not divisible by thread_k = \", thread_k);\n\n  int group_blocks = -1;\n\n  const int4* A_ptr = (const int4*)A;\n  const int4* B_ptr = (const int4*)B;\n  int4* C_ptr = (int4*)C;\n  const int4* s_ptr = (const int4*)s;\n\n  int* locks = (int*)workspace;\n\n  // Main loop\n  for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) {\n    int thread_m_blocks = tot_m_blocks - i;\n    prob_m = tot_m - 16 * i;\n    int par = 1;\n    if (thread_m_blocks > exec_cfg.max_m_blocks) {\n      // Note that parallel > 1 currently only works for inputs without any\n      // padding\n      par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks);\n      if (par > max_par) par = max_par;\n      prob_m = (16 * exec_cfg.max_m_blocks) * par;\n      i += exec_cfg.max_m_blocks * (par - 1);\n      thread_m_blocks = exec_cfg.max_m_blocks;\n    }\n\n    // Define kernel configurations\n    if (false) {\n    }\n    CALL_IF(8, 32, 2, 256)\n    CALL_IF(8, 16, 4, 256)\n    CALL_IF(8, 8, 8, 256)\n    CALL_IF(8, 8, 4, 128)\n    CALL_IF(8, 4, 8, 128)\n    else {\n      TORCH_CHECK(false, \"Unsupported shapes: MNK = [\" + str(prob_m) + \", \" +\n                             str(prob_n) + \", \" + str(prob_k) + \"]\" +\n                             \", num_groups = \" + str(num_groups) +\n                             \", group_size = \" + str(group_size) +\n                             \", thread_m_blocks = \" + str(thread_m_blocks) +\n                             \", thread_n_blocks = \" + str(thread_n_blocks) +\n                             \", thread_k_blocks = \" + str(thread_k_blocks));\n    }\n\n    A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;\n    C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;\n  }\n}\n\n}  // namespace fp8_marlin\n\ntorch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,\n                              torch::Tensor& b_scales, torch::Tensor& workspace,\n                              int64_t num_bits, 
int64_t size_m, int64_t size_n,\n                              int64_t size_k) {\n  // Verify num_bits\n  TORCH_CHECK(num_bits == 8, \"num_bits must be 8. Got = \", num_bits);\n  int pack_factor = 32 / num_bits;\n\n  // Verify A\n  TORCH_CHECK(a.size(0) == size_m, \"Shape mismatch: a.size(0) = \", a.size(0),\n              \", size_m = \", size_m);\n  TORCH_CHECK(a.size(1) == size_k, \"Shape mismatch: a.size(1) = \", a.size(1),\n              \", size_k = \", size_k);\n\n  // Verify B\n  TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, \"size_k = \", size_k,\n              \" is not divisible by tile_size = \", gptq_marlin::tile_size);\n  TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0),\n              \"Shape mismatch: b_q_weight.size(0) = \", b_q_weight.size(0),\n              \", size_k = \", size_k, \", tile_size = \", gptq_marlin::tile_size);\n  TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0,\n              \"b_q_weight.size(1) = \", b_q_weight.size(1),\n              \" is not divisible by tile_size = \", gptq_marlin::tile_size);\n  int actual_size_n =\n      (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor;\n  TORCH_CHECK(size_n == actual_size_n, \"size_n = \", size_n,\n              \", actual_size_n = \", actual_size_n);\n\n  // Verify device and strides\n  TORCH_CHECK(a.device().is_cuda(), \"A is not on GPU\");\n  TORCH_CHECK(a.is_contiguous(), \"A is not contiguous\");\n\n  TORCH_CHECK(b_q_weight.device().is_cuda(), \"b_q_weight is not on GPU\");\n  TORCH_CHECK(b_q_weight.is_contiguous(), \"b_q_weight is not contiguous\");\n\n  TORCH_CHECK(b_scales.device().is_cuda(), \"b_scales is not on GPU\");\n  TORCH_CHECK(b_scales.is_contiguous(), \"b_scales is not contiguous\");\n\n  // Alloc buffers\n  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));\n  auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());\n  torch::Tensor c = torch::empty({size_m, size_n}, options);\n\n  // thread_k: 
`k` size of a thread_tile in `weights` (can usually be left as\n  // auto -1)\n  int thread_k = -1;\n  // thread_n: `n` size of a thread_tile in `weights` (can usually be left as\n  // auto -1)\n  int thread_n = -1;\n  // sms: number of SMs to use for the kernel (can usually be left as auto -1)\n  int sms = -1;\n\n  // Detect groupsize and act_order\n  int num_groups = -1;\n  int group_size = -1;\n\n  int b_rank = b_scales.sizes().size();\n  TORCH_CHECK(b_rank == 2, \"b_scales rank = \", b_rank, \" is not 2\");\n  TORCH_CHECK(b_scales.size(1) == size_n, \"b_scales dim 1 = \", b_scales.size(1),\n              \" is not size_n = \", size_n);\n  // Channelwise only for FP8\n  TORCH_CHECK(b_scales.size(0) == 1)\n  num_groups = b_scales.size(0);\n\n  // Verify workspace size\n  TORCH_CHECK(\n      size_n % gptq_marlin::min_thread_n == 0, \"size_n = \", size_n,\n      \", is not divisible by min_thread_n = \", gptq_marlin::min_thread_n);\n  int min_workspace_size =\n      (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par;\n  TORCH_CHECK(workspace.numel() >= min_workspace_size,\n              \"workspace.numel = \", workspace.numel(),\n              \" is below min_workspace_size = \", min_workspace_size);\n\n  int dev = a.get_device();\n  if (a.scalar_type() == at::ScalarType::Half) {\n    fp8_marlin::marlin_mm_f16i4<half>(\n        a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),\n        b_scales.data_ptr<at::Half>(), size_m, size_n, size_k,\n        workspace.data_ptr(), num_bits, num_groups, group_size, dev,\n        at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,\n        gptq_marlin::max_par);\n  } else if (a.scalar_type() == at::ScalarType::BFloat16) {\n    fp8_marlin::marlin_mm_f16i4<nv_bfloat16>(\n        a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),\n        c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(), size_m,\n        size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size,\n   
     dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,\n        gptq_marlin::max_par);\n  } else {\n    TORCH_CHECK(false, \"fp8_marlin_gemm only supports bfloat16 and float16\");\n  }\n\n  return c;\n}\n\n#endif"
  },
  {
    "path": "optimum/quanto/library/extensions/cuda/marlin/fp8_marlin.cuh",
    "content": "// #pragma once\n#include <torch/all.h>\n#include <stdint.h>\n\n\n// #ifndef _fp8_marlin_cuh\n// #define _fp8_marlin_cuh\n\n// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n// assert(0);\n// #else\ntorch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,\n                              torch::Tensor& b_scales, torch::Tensor& workspace,\n                              int64_t num_bits, int64_t size_m, int64_t size_n,\n                              int64_t size_k);\n// #endif\n\n// #endif"
  },
  {
    "path": "optimum/quanto/library/extensions/cuda/marlin/gptq_marlin.cuh",
    "content": "#pragma once\n\n#include <torch/all.h>\n\n#include <ATen/cuda/CUDAContext.h>\n#include <c10/cuda/CUDAGuard.h>\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n#include <iostream>\n\nnamespace gptq_marlin {\n\n// 8 warps are a good choice since every SM has 4 schedulers and having more\n// than 1 warp per schedule allows some more latency hiding. At the same time,\n// we want relatively few warps to have many registers per warp and small tiles.\nstatic constexpr int default_threads = 256;\n\nstatic constexpr int pipe_stages =\n    4;  // 4 pipeline stages fit into shared memory\n\nstatic constexpr int min_thread_n = 64;\nstatic constexpr int min_thread_k = 64;\n\nstatic constexpr int tile_size = 16;\nstatic constexpr int max_par = 16;\n\ntemplate <typename T, int n>\nstruct Vec {\n  T elems[n];\n  __device__ T& operator[](int i) { return elems[i]; }\n};\n\nusing I4 = Vec<int, 4>;\n\nconstexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }\n\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n// No support for async\n#else\n\n__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,\n                                      bool pred = true) {\n  const int BYTES = 16;\n  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));\n  asm volatile(\n      \"{\\n\"\n      \"   .reg .pred p;\\n\"\n      \"   setp.ne.b32 p, %0, 0;\\n\"\n      \"   @p cp.async.cg.shared.global [%1], [%2], %3;\\n\"\n      \"}\\n\" ::\"r\"((int)pred),\n      \"r\"(smem), \"l\"(glob_ptr), \"n\"(BYTES));\n}\n\n__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {\n  const int BYTES = 16;\n  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));\n  asm volatile(\n      \"{\\n\"\n      \"   cp.async.cg.shared.global [%0], [%1], %2;\\n\"\n      \"}\\n\" ::\"r\"(smem),\n      \"l\"(glob_ptr), \"n\"(BYTES));\n}\n\n__device__ inline void cp_async_fence() {\n  asm 
volatile(\"cp.async.commit_group;\\n\" ::);\n}\n\ntemplate <int n>\n__device__ inline void cp_async_wait() {\n  asm volatile(\"cp.async.wait_group %0;\\n\" ::\"n\"(n));\n}\n\n#endif\n\n}  // namespace gptq_marlin"
  },
  {
    "path": "optimum/quanto/library/extensions/cuda/marlin/gptq_marlin_dtypes.cuh",
    "content": "\n#ifndef _data_types_cuh\n#define _data_types_cuh\n#include \"gptq_marlin.cuh\"\n#include <cuda_fp16.h>\n#include <cuda_bf16.h>\n\nnamespace gptq_marlin {\n\ntemplate <typename scalar_t>\nclass ScalarType {};\n\ntemplate <>\nclass ScalarType<half> {\n public:\n  using scalar_t = half;\n  using scalar_t2 = half2;\n\n  // Matrix fragments for tensor core instructions; their precise layout is\n  // documented here:\n  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type\n  using FragA = Vec<half2, 4>;\n  using FragB = Vec<half2, 2>;\n  using FragC = Vec<float, 4>;\n  using FragS = Vec<half2, 1>;\n\n  static __device__ float inline num2float(const half x) {\n    return __half2float(x);\n  }\n\n  static __device__ half2 inline num2num2(const half x) {\n    return __half2half2(x);\n  }\n\n  static __device__ half2 inline nums2num2(const half x1, const half x2) {\n    return __halves2half2(x1, x2);\n  }\n\n  static __host__ __device__ half inline float2num(const float x) {\n    return __float2half(x);\n  }\n};\n\ntemplate <>\nclass ScalarType<nv_bfloat16> {\n public:\n  using scalar_t = nv_bfloat16;\n  using scalar_t2 = nv_bfloat162;\n\n  using FragA = Vec<nv_bfloat162, 4>;\n  using FragB = Vec<nv_bfloat162, 2>;\n  using FragC = Vec<float, 4>;\n  using FragS = Vec<nv_bfloat162, 1>;\n\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800\n  static __device__ float inline num2float(const nv_bfloat16 x) {\n    return __bfloat162float(x);\n  }\n\n  static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {\n    return __bfloat162bfloat162(x);\n  }\n\n  static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,\n                                                  const nv_bfloat16 x2) {\n    return __halves2bfloat162(x1, x2);\n  }\n\n  static __host__ __device__ nv_bfloat16 inline float2num(const float x) {\n    return __float2bfloat16(x);\n  }\n#endif\n};\n\n}  // 
namespace gptq_marlin\n\n#endif"
  },
  {
    "path": "optimum/quanto/library/extensions/cuda/marlin/gptq_marlin_repack.cu",
    "content": "#include \"gptq_marlin.cuh\"\n\nnamespace gptq_marlin {\n\nstatic constexpr int repack_stages = 8;\n\nstatic constexpr int repack_threads = 256;\n\nstatic constexpr int tile_k_size = tile_size;\nstatic constexpr int tile_n_size = tile_k_size * 4;\n\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n\ntemplate <int const num_threads, int const num_bits, bool const has_perm>\n__global__ void marlin_repack_kernel(\n    uint32_t const* __restrict__ b_q_weight_ptr,\n    uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,\n    int size_k, int size_n) {}\n\n}  // namespace gptq_marlin\n\ntorch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,\n                                 int64_t size_k, int64_t size_n,\n                                 int64_t num_bits) {\n  TORCH_CHECK_NOT_IMPLEMENTED(\n      false, \"marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0\");\n  return torch::empty({1, 1});\n}\n\n#else\n\ntemplate <int const num_threads, int const num_bits, bool const has_perm>\n__global__ void marlin_repack_kernel(\n    uint32_t const* __restrict__ b_q_weight_ptr,\n    uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,\n    int size_k, int size_n) {\n  constexpr int pack_factor = 32 / num_bits;\n\n  int k_tiles = size_k / tile_k_size;\n  int n_tiles = size_n / tile_n_size;\n  int block_k_tiles = div_ceil(k_tiles, gridDim.x);\n\n  int start_k_tile = blockIdx.x * block_k_tiles;\n  if (start_k_tile >= k_tiles) {\n    return;\n  }\n\n  int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);\n\n  // Wait until the next thread tile has been loaded to shared memory.\n  auto wait_for_stage = [&]() {\n    // We only have `stages - 2` active fetches since we are double buffering\n    // and can only issue the next fetch when it is guaranteed that the previous\n    // shared memory load is fully complete (as it may otherwise be\n    // overwritten).\n    cp_async_wait<repack_stages - 
2>();\n    __syncthreads();\n  };\n\n  extern __shared__ int4 sh[];\n\n  constexpr int perm_size = tile_k_size / 4;\n\n  int4* sh_perm_ptr = sh;\n  int4* sh_pipe_ptr = sh_perm_ptr;\n  if constexpr (has_perm) {\n    sh_pipe_ptr += perm_size;\n  }\n\n  constexpr int tile_ints = tile_k_size / pack_factor;\n\n  constexpr int stage_n_threads = tile_n_size / 4;\n  constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints;\n  constexpr int stage_size = stage_k_threads * stage_n_threads;\n\n  auto load_perm_to_shared = [&](int k_tile_id) {\n    int first_k_int4 = (k_tile_id * tile_k_size) / 4;\n\n    int4 const* perm_int4_ptr = reinterpret_cast<int4 const*>(perm_ptr);\n\n    if (threadIdx.x < perm_size) {\n      sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x];\n    }\n    __syncthreads();\n  };\n\n  auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {\n    if (n_tile_id >= n_tiles) {\n      cp_async_fence();\n      return;\n    }\n\n    int first_n = n_tile_id * tile_n_size;\n\n    int4* sh_ptr = sh_pipe_ptr + stage_size * pipe;\n\n    if constexpr (has_perm) {\n      if (threadIdx.x < stage_size) {\n        int k_id = threadIdx.x / stage_n_threads;\n        int n_id = threadIdx.x % stage_n_threads;\n\n        uint32_t const* sh_perm_int_ptr =\n            reinterpret_cast<uint32_t const*>(sh_perm_ptr);\n\n        int src_k = sh_perm_int_ptr[k_id];\n        int src_k_packed = src_k / pack_factor;\n\n        cp_async4(\n            &sh_ptr[k_id * stage_n_threads + n_id],\n            reinterpret_cast<int4 const*>(&(\n                b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));\n      }\n\n    } else {\n      if (threadIdx.x < stage_size) {\n        int k_id = threadIdx.x / stage_n_threads;\n        int n_id = threadIdx.x % stage_n_threads;\n\n        int first_k = k_tile_id * tile_k_size;\n        int first_k_packed = first_k / pack_factor;\n\n        cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],\n        
          reinterpret_cast<int4 const*>(\n                      &(b_q_weight_ptr[(first_k_packed + k_id) * size_n +\n                                       first_n + (n_id * 4)])));\n      }\n    }\n\n    cp_async_fence();\n  };\n\n  auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {\n    if (n_tile_id >= n_tiles) {\n      return;\n    }\n\n    int warp_id = threadIdx.x / 32;\n    int th_id = threadIdx.x % 32;\n\n    if (warp_id >= 4) {\n      return;\n    }\n\n    int tc_col = th_id / 4;\n    int tc_row = (th_id % 4) * 2;\n\n    constexpr int tc_offsets[4] = {0, 1, 8, 9};\n\n    int cur_n = warp_id * 16 + tc_col;\n\n    constexpr int sh_stride = 64;\n    constexpr uint32_t mask = (1 << num_bits) - 1;\n\n    int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;\n    uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);\n\n    uint32_t* sh_perm_int_ptr = reinterpret_cast<uint32_t*>(sh_perm_ptr);\n\n    uint32_t vals[8];\n\n    if constexpr (has_perm) {\n      for (int i = 0; i < 4; i++) {\n        int k_idx = tc_row + tc_offsets[i];\n\n        uint32_t src_k = sh_perm_int_ptr[k_idx];\n        uint32_t src_k_pos = src_k % pack_factor;\n\n        uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n];\n        uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask;\n\n        uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8];\n        uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask;\n\n        vals[i] = b1_cur_val;\n        vals[4 + i] = b2_cur_val;\n      }\n\n    } else {\n      uint32_t b1_vals[tile_ints];\n      uint32_t b2_vals[tile_ints];\n\n  #pragma unroll\n      for (int i = 0; i < tile_ints; i++) {\n        b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];\n        b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];\n      }\n\n  #pragma unroll\n      for (int i = 0; i < 4; i++) {\n        int cur_elem = tc_row + tc_offsets[i];\n        int cur_int = cur_elem / 
pack_factor;\n        int cur_pos = cur_elem % pack_factor;\n\n        vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;\n        vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;\n      }\n    }\n\n    constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;\n    int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;\n\n    // Result of:\n    // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h\n    if constexpr (num_bits == 4) {\n      constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};\n\n      uint32_t res = 0;\n  #pragma unroll\n      for (int i = 0; i < 8; i++) {\n        res |= vals[pack_idx[i]] << (i * 4);\n      }\n\n      out_ptr[out_offset + th_id * 4 + warp_id] = res;\n\n    } else {\n      constexpr int pack_idx[4] = {0, 2, 1, 3};\n\n      uint32_t res1 = 0;\n      uint32_t res2 = 0;\n  #pragma unroll\n      for (int i = 0; i < 4; i++) {\n        res1 |= vals[pack_idx[i]] << (i * 8);\n        res2 |= vals[4 + pack_idx[i]] << (i * 8);\n      }\n\n      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;\n      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;\n    }\n  };\n\n  auto start_pipes = [&](int k_tile_id, int n_tile_id) {\n  #pragma unroll\n    for (int pipe = 0; pipe < repack_stages - 1; pipe++) {\n      fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);\n    }\n\n    wait_for_stage();\n  };\n  #pragma unroll\n  for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {\n    int n_tile_id = 0;\n\n    if constexpr (has_perm) {\n      load_perm_to_shared(k_tile_id);\n    }\n\n    start_pipes(k_tile_id, n_tile_id);\n\n    while (n_tile_id < n_tiles) {\n  #pragma unroll\n      for (int pipe = 0; pipe < repack_stages; pipe++) {\n        fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,\n                        n_tile_id + pipe + repack_stages 
- 1);\n        repack_tile(pipe, k_tile_id, n_tile_id + pipe);\n        wait_for_stage();\n      }\n      n_tile_id += repack_stages;\n    }\n  }\n}\n\n}  // namespace gptq_marlin\n\n  #define CALL_IF(NUM_BITS, HAS_PERM)                                          \\\n    else if (num_bits == NUM_BITS && has_perm == HAS_PERM) {                   \\\n      cudaFuncSetAttribute(                                                    \\\n          gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads,       \\\n                                            NUM_BITS, HAS_PERM>,               \\\n          cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);        \\\n      gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, NUM_BITS, \\\n                                        HAS_PERM>                              \\\n          <<<blocks, gptq_marlin::repack_threads, max_shared_mem, stream>>>(   \\\n              b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);              \\\n    }\n\ntorch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,\n                                 int64_t size_k, int64_t size_n,\n                                 int64_t num_bits) {\n  // Verify compatibility with marlin tile of 16x64\n  TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, \"size_k = \", size_k,\n              \" is not divisible by tile_k_size = \", gptq_marlin::tile_k_size);\n  TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, \"size_n = \", size_n,\n              \" is not divisible by tile_n_size = \", gptq_marlin::tile_n_size);\n\n  TORCH_CHECK(num_bits == 4 || num_bits == 8,\n              \"num_bits must be 4 or 8. 
Got = \", num_bits);\n  int const pack_factor = 32 / num_bits;\n\n  // Verify B\n  TORCH_CHECK((size_k / pack_factor) == b_q_weight.size(0),\n              \"Shape mismatch: b_q_weight.size(0) = \", b_q_weight.size(0),\n              \", size_k = \", size_k, \", pack_factor = \", pack_factor);\n  TORCH_CHECK(b_q_weight.size(1) == size_n,\n              \"b_q_weight.size(1) = \", b_q_weight.size(1),\n              \" is not size_n = \", size_n);\n\n  // Verify device and strides\n  TORCH_CHECK(b_q_weight.device().is_cuda(), \"b_q_weight is not on GPU\");\n  TORCH_CHECK(b_q_weight.is_contiguous(), \"b_q_weight is not contiguous\");\n  TORCH_CHECK(b_q_weight.dtype() == at::kInt, \"b_q_weight type is not kInt\");\n\n  TORCH_CHECK(perm.device().is_cuda(), \"perm is not on GPU\");\n  TORCH_CHECK(perm.is_contiguous(), \"perm is not contiguous\");\n  TORCH_CHECK(perm.dtype() == at::kInt, \"perm type is not at::kInt\");\n\n  // Alloc buffers\n  const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));\n  auto options = torch::TensorOptions()\n                     .dtype(b_q_weight.dtype())\n                     .device(b_q_weight.device());\n  torch::Tensor out =\n      torch::empty({size_k / gptq_marlin::tile_size,\n                    size_n * gptq_marlin::tile_size / pack_factor},\n                   options);\n\n  // Detect if there is act_order\n  bool has_perm = perm.size(0) != 0;\n\n  // Get ptrs\n  uint32_t const* b_q_weight_ptr =\n      reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());\n  uint32_t const* perm_ptr = reinterpret_cast<uint32_t const*>(perm.data_ptr());\n  uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());\n\n  // Get dev info\n  int dev = b_q_weight.get_device();\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);\n  int blocks;\n  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);\n\n  int max_shared_mem = 0;\n  cudaDeviceGetAttribute(&max_shared_mem,\n                         
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);\n  TORCH_CHECK(max_shared_mem > 0);\n\n  if (false) {\n  }\n  CALL_IF(4, false)\n  CALL_IF(4, true)\n  CALL_IF(8, false)\n  CALL_IF(8, true)\n  else {\n    TORCH_CHECK(false, \"Unsupported repack config: num_bits = \", num_bits,\n                \", has_perm = \", has_perm);\n  }\n\n  return out;\n}\n\n#endif"
  },
  {
    "path": "optimum/quanto/library/extensions/cuda/marlin/gptq_marlin_repack.cuh",
    "content": "#include <torch/library.h>\n#include <torch/all.h>\n#include <stdint.h>\n\n#ifndef _gptq_marlin_repack_cuh\n#define _gptq_marlin_repack_cuh\n\ntorch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,\n                                 int64_t size_k, int64_t size_n,\n                                 int64_t num_bits);\n\n#endif\n"
  },
  {
    "path": "optimum/quanto/library/extensions/cuda/marlin/marlin_cuda.cpp",
    "content": "/*\n * Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at)\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *         http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n#include \"marlin_cuda.h\"\n\n#include <torch/all.h>\n#include <torch/python.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cuda_runtime.h>\n\n#include \"marlin_cuda_kernel.cuh\"\n\nconst int ERR_PROB_SHAPE = 1;\nconst int ERR_KERN_SHAPE = 2;\n\nvoid mul(\n  const torch::Tensor& A,\n  const torch::Tensor& B,\n        torch::Tensor& C,\n  const torch::Tensor& s,\n  const torch::Tensor& sz, // ADDED: add scaled zero point\n        torch::Tensor& workspace,\n  int thread_k,\n  int thread_n,\n  int sms,\n  int max_par\n) {\n  int prob_m = A.size(0);\n  int prob_n = C.size(1);\n  int prob_k = A.size(1);\n  int groupsize = (s.size(0) == 1) ? 
-1 : prob_k / s.size(0);\n  if (groupsize != -1 && groupsize * s.size(0) != prob_k)\n    AT_ERROR(\"k=\", prob_k, \" not compatible with \", s.size(0), \" groups.\");\n  if (workspace.numel() < prob_n / 128 * max_par)\n    AT_ERROR(\"workspace must be of size at least \", prob_n / 128 * max_par, \".\");\n  int dev = A.get_device();\n  int err = marlin_cuda(\n    A.data_ptr(),\n    B.data_ptr(),\n    C.data_ptr(),\n    s.data_ptr(),\n    sz.data_ptr(), // ADDED: add scaled zero point\n    prob_m, prob_n, prob_k,\n    workspace.data_ptr(),\n    groupsize,\n    dev,\n    at::cuda::getCurrentCUDAStream(dev),\n    thread_k,\n    thread_n,\n    sms,\n    max_par\n  );\n  if (err == ERR_PROB_SHAPE) {\n    AT_ERROR(\n      \"Problem (m=\", prob_m, \", n=\", prob_n, \", k=\", prob_k, \")\",\n      \" not compatible with thread_k=\", thread_k, \", thread_n=\", thread_n, \".\"\n    );\n  } else if (err == ERR_KERN_SHAPE) {\n    AT_ERROR(\n      \"No kernel implementation for thread_k=\", thread_k, \", thread_n=\", thread_n, \", groupsize=\", groupsize, \".\"\n    );\n  }\n}\n"
  },
  {
    "path": "optimum/quanto/library/extensions/cuda/marlin/marlin_cuda.h",
    "content": "/*\n * Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at)\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *         http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n#include <torch/extension.h>\n\nvoid mul(\n  const torch::Tensor& A,\n  const torch::Tensor& B,\n        torch::Tensor& C,\n  const torch::Tensor& s,\n  const torch::Tensor& sz,\n        torch::Tensor& workspace,\n  int thread_k = -1,\n  int thread_n = -1,\n  int sms = -1,\n  int max_par = 8\n);\n"
  },
  {
    "path": "optimum/quanto/library/extensions/cuda/marlin/marlin_cuda_kernel.cu",
    "content": "/*\n * Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at)\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *         http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n\n#ifndef MARLIN_CUDA_KERNEL_CUH\n#define MARLIN_CUDA_KERNEL_CUH\n\n\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n#include <iostream>\n\n\nconstexpr int ceildiv(int a, int b) {\n  return (a + b - 1) / b;\n}\n\n// Instances of `Vec` are used to organize groups of >>registers<<, as needed for instance as inputs to tensor core\n// operations. 
Consequently, all corresponding index accesses must be compile-time constants, which is why we\n// extensively use `#pragma unroll` throughout the kernel code to guarantee this.\ntemplate <typename T, int n>\nstruct Vec {\n  T elems[n];\n  __device__ T& operator[](int i) {\n    return elems[i];\n  }\n};\n\nusing I4 = Vec<int, 4>;\n\n// Matrix fragments for tensor core instructions; their precise layout is documented here:\n// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type\nusing FragA = Vec<half2, 4>;\nusing FragB = Vec<half2, 2>;\nusing FragC = Vec<float, 4>;\nusing FragS = Vec<half2, 1>; // quantization scales\n\n// Predicated asynchronous global->shared copy; used for inputs A where we apply predication to handle batchsizes that\n// are not multiples of 16.\n__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) {\n  const int BYTES = 16;\n  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));\n  asm volatile(\n    \"{\\n\"\n    \"   .reg .pred p;\\n\"\n    \"   setp.ne.b32 p, %0, 0;\\n\"\n    \"   @p cp.async.cg.shared.global [%1], [%2], %3;\\n\"\n    \"}\\n\" :: \"r\"((int) pred), \"r\"(smem), \"l\"(glob_ptr), \"n\"(BYTES)\n  );\n}\n\n// Asynchronous global->shared copy\n__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) {\n  const int BYTES = 16;\n  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));\n  asm volatile(\"{\\n\"\n               \"   cp.async.cg.shared.global [%0], [%1], %2;\\n\"\n               \"}\\n\" :: \"r\"(smem), \"l\"(glob_ptr), \"n\"(BYTES));\n}\n\n// Async copy fence.\n__device__ inline void cp_async_fence() {\n  asm volatile(\"cp.async.commit_group;\\n\" ::);\n}\n\n// Wait until at most `n` async copy stages are still pending.\ntemplate <int n>\n__device__ inline void cp_async_wait() {\n  asm volatile(\"cp.async.wait_group %0;\\n\" :: \"n\"(n));\n}\n\n// 
m16n8k16 tensor core mma instruction with fp16 inputs and fp32 output/accumulation.\n__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, FragC& frag_c) {\n  const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);\n  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);\n  float* c = reinterpret_cast<float*>(&frag_c);\n  asm volatile(\n    \"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \"\n    \"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\\n\"\n    : \"=f\"(c[0]), \"=f\"(c[1]), \"=f\"(c[2]), \"=f\"(c[3])\n    :  \"r\"(a[0]),  \"r\"(a[1]),  \"r\"(a[2]),  \"r\"(a[3]),  \"r\"(b[0]),  \"r\"(b[1]),\n       \"f\"(c[0]),  \"f\"(c[1]),  \"f\"(c[2]),  \"f\"(c[3])\n  );\n}\n\n// Instruction for loading a full 16x16 matrix fragment of operand A from shared memory, directly in tensor core layout.\n__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {\n  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);\n  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));\n  asm volatile(\n    \"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\\n\"\n    : \"=r\"(a[0]), \"=r\"(a[1]), \"=r\"(a[2]), \"=r\"(a[3]) : \"r\"(smem)\n  );\n}\n\n// Lookup-table based 3-input logical operation; explicitly used for dequantization as the compiler does not seem to\n// automatically recognize it in all cases.\ntemplate <int lut>\n__device__ inline int lop3(int a, int b, int c) {\n  int res;\n  asm volatile(\n    \"lop3.b32 %0, %1, %2, %3, %4;\\n\"\n    : \"=r\"(res) : \"r\"(a), \"r\"(b), \"r\"(c), \"n\"(lut)\n  );\n  return res;\n}\n\n// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 values.\n// We mostly follow the strategy in the link below, with some small changes:\n// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h\n__device__ inline FragB dequant(int q) {\n  const int LO 
= 0x000f000f;\n  const int HI = 0x00f000f0;\n  const int EX = 0x64006400;\n  // Guarantee that the `(a & b) | c` operations are LOP3s.\n  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);\n  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);\n  // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point directly into `SUB` and `ADD`.\n  // const int SUB = 0x64086408;\n  // const int MUL = 0x2c002c00;\n  // const int ADD = 0xd480d480;\n  // MODIFIED: use scaled zero point so do not need to map to [-8, 7]\n  const int SUB = 0x64006400;\n  const int MUL = 0x2c002c00;\n  const int ADD = 0xd400d400;\n  FragB frag_b;\n  frag_b[0] = __hsub2(\n    *reinterpret_cast<half2*>(&lo),\n    *reinterpret_cast<const half2*>(&SUB)\n  );\n  frag_b[1] = __hfma2(\n    *reinterpret_cast<half2*>(&hi),\n    *reinterpret_cast<const half2*>(&MUL), *reinterpret_cast<const half2*>(&ADD)\n  );\n  return frag_b;\n}\n\n// Multiply dequantized values by the corresponding quantization scale; used only for grouped quantization.\n// MODIFIED: add scaled zero point\n__device__ inline void scale(FragB& frag_b, FragS& frag_s, FragS& frag_sz, int i) {\n  half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);\n  half2 sz = __half2half2(reinterpret_cast<__half*>(&frag_sz)[i]);\n  // frag_b[0] = __hmul2(frag_b[0], s);\n  // frag_b[1] = __hmul2(frag_b[1], s);\n  frag_b[0] = __hfma2(frag_b[0], s, sz);\n  frag_b[1] = __hfma2(frag_b[1], s, sz);\n}\n\n// Wait until barrier reaches `count`, then lock for current threadblock.\n__device__ inline void barrier_acquire(int* lock, int count) {\n  if (threadIdx.x == 0) {\n    int state = -1;\n    do\n      // Guarantee that subsequent writes by this threadblock will be visible globally.\n      asm volatile (\"ld.global.acquire.gpu.b32 %0, [%1];\\n\" : \"=r\"(state) : \"l\"(lock));\n    while (state != count);\n  }\n  __syncthreads();\n}\n\n// Release barrier and increment visitation count.\n__device__ inline void barrier_release(int* lock, bool reset 
= false) {\n  __syncthreads();\n  if (threadIdx.x == 0) {\n    if (reset) {\n      lock[0] = 0;\n      return;\n    }\n    int val = 1;\n    // Make sure that all writes since acquiring this barrier are visible globally, while releasing the barrier.\n    asm volatile (\"fence.acq_rel.gpu;\\n\");\n    asm volatile (\"red.relaxed.gpu.global.add.s32 [%0], %1;\\n\" : : \"l\"(lock), \"r\"(val));\n  }\n}\n\n\ntemplate <\n  const int threads, // number of threads in a threadblock\n  const int thread_m_blocks, // number of 16x16 blocks in the m dimension (batchsize) of the threadblock\n  const int thread_n_blocks, // same for n dimension (output)\n  const int thread_k_blocks, // same for k dimension (reduction)\n  const int stages, // number of stages for the async global->shared fetch pipeline\n  const int group_blocks = -1 // number of consecutive 16x16 blocks with a separate quantization scale\n>\n__global__ void Marlin(\n  const int4* __restrict__ A, // fp16 input matrix of shape mxk\n  const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn\n        int4* __restrict__ C, // fp16 output buffer of shape mxn\n  const int4* __restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn\n  // ADDED: add scaled zero point\n  const int4* __restrict__ sz, // fp16 quantization scaled zero points of shape (k/groupsize)xn\n  int  prob_m, // batch dimension m\n  int  prob_n, // output dimension n\n  int  prob_k, // reduction dimension k\n  int* locks // extra global storage for barrier synchronization\n) {\n  // Each threadblock processes one \"stripe\" of the B matrix with (roughly) the same size, which might involve multiple\n  // column \"slices\" (of width 16 * `thread_n_blocks`). 
Stripes are defined as shown in the 3x3 matrix 5 SM example:\n  //   0 1 3\n  //   0 2 3\n  //   1 2 4\n  // While this kind of partitioning makes things somewhat more complicated, it ensures good utilization of all SMs\n  // for many kinds of shape and GPU configurations, while requiring as few slow global cross-threadblock reductions as\n  // possible.\n\n  // For larger GEMMs we run multiple batchsize 64 versions in parallel for a better partitioning with less reductions\n  int parallel = 1;\n  if (prob_m > 16 * thread_m_blocks) {\n    parallel = prob_m / (16 * thread_m_blocks);\n    prob_m = 16 * thread_m_blocks;\n  }\n\n  int k_tiles = prob_k / 16 / thread_k_blocks;\n  int n_tiles = prob_n / 16 / thread_n_blocks;\n  int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x);\n  // Ensure that the number of tiles in each stripe is a multiple of the groupsize; this avoids an annoying special case\n  // where a stripe starts in the middle of group.\n  if (group_blocks != -1)\n    iters = (group_blocks / thread_k_blocks) * ceildiv(iters, (group_blocks / thread_k_blocks));\n\n  int slice_row = (iters * blockIdx.x) % k_tiles;\n  int slice_col_par = (iters * blockIdx.x) / k_tiles;\n  int slice_col = slice_col_par;\n  int slice_iters; // number of threadblock tiles in the current slice\n  int slice_count = 0; // total number of active threadblocks in the current slice\n  int slice_idx; // index of threadblock in current slice; numbered bottom to top\n\n  // We can easily implement parallel problem execution by just remapping indices and advancing global pointers\n  if (slice_col_par >= n_tiles) {\n    A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;\n    C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;\n    locks += (slice_col_par / n_tiles) * n_tiles;\n    slice_col = slice_col_par % n_tiles;\n  }\n\n  // Compute all information about the current slice which is required for synchronization.\n  auto init_slice = [&] () {\n    
slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);\n    if (slice_iters < 0 || slice_col_par >= n_tiles * parallel)\n      slice_iters = 0;\n    if (slice_iters == 0)\n      return;\n    if (slice_row + slice_iters > k_tiles)\n      slice_iters = k_tiles - slice_row;\n    slice_count = 1;\n    slice_idx = 0;\n    int col_first = iters * ceildiv(k_tiles * slice_col_par, iters);\n    if (col_first <= k_tiles * (slice_col_par + 1)) {\n      int col_off = col_first - k_tiles * slice_col_par;\n      slice_count = ceildiv(k_tiles - col_off, iters);\n      if (col_off > 0)\n        slice_count++;\n      int delta_first = iters * blockIdx.x - col_first;\n      if (delta_first < 0 || (col_off == 0 && delta_first == 0))\n        slice_idx = slice_count - 1;\n      else {\n        slice_idx = slice_count - 1 - delta_first / iters;\n        if (col_off > 0)\n          slice_idx--;\n      }\n    }\n    if (slice_col == n_tiles) {\n      A += 16 * thread_m_blocks * prob_k / 8;\n      C += 16 * thread_m_blocks * prob_n / 8;\n      locks += n_tiles;\n      slice_col = 0;\n    }\n  };\n  init_slice();\n\n  int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory\n  // We typically use `constexpr` to indicate that this value is a compile-time constant\n  constexpr int a_sh_stride = 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory\n  constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; // delta between subsequent A tiles in global memory\n  int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile\n  constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes\n  constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); // between shared memory tile reads\n  constexpr int a_sh_rd_delta_i = a_sh_stride * 16; // within a shared memory tile\n  constexpr int a_sh_stage = a_sh_stride * (16 * 
thread_m_blocks); // overall size of a tile\n  constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); // number of shared write iterations for a tile\n\n  int b_gl_stride = 16 * prob_n / 32;\n  constexpr int b_sh_stride = 32 * thread_n_blocks / 4;\n  int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;\n  int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride);\n  constexpr int b_sh_wr_delta = threads;\n  constexpr int b_sh_rd_delta = threads;\n  constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;\n  constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;\n\n  int s_gl_stride = prob_n / 8;\n  constexpr int s_sh_stride = 16 * thread_n_blocks / 8;\n  constexpr int s_sh_stage = s_sh_stride;\n  int s_gl_rd_delta = s_gl_stride;\n\n  // Global A read index of current thread.\n  int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);\n  a_gl_rd += a_gl_rd_delta_o * slice_row;\n  // Shared write index of current thread.\n  int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);\n  // Shared read index.\n  int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;\n  a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));\n\n  int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);\n  b_gl_rd += b_sh_stride * slice_col;\n  b_gl_rd += b_gl_rd_delta_o * slice_row;\n  int b_sh_wr = threadIdx.x;\n  int b_sh_rd = threadIdx.x;\n\n  int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x;\n  int s_sh_wr = threadIdx.x;\n  int s_sh_rd;\n  // We use a different scale layout for grouped and column-wise quantization as we scale a `half2` tile in column-major\n  // layout in the former and in row-major in the latter case.\n  if (group_blocks != -1)\n    s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4;\n  else\n    s_sh_rd = 8 * 
((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4;\n\n  // Precompute which thread should not read memory in which iterations; this is needed if there are more threads than\n  // required for a certain tilesize or when the batchsize is not a multiple of 16.\n  bool a_sh_wr_pred[a_sh_wr_iters];\n  #pragma unroll\n  for (int i = 0; i < a_sh_wr_iters; i++)\n    a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;\n  bool s_sh_wr_pred = threadIdx.x < s_sh_stride;\n\n  // To ensure that writing and reading A tiles to/from shared memory, the latter in fragment format, is fully bank\n  // conflict free, we need to use a rather fancy XOR-based layout. The key here is that neither reads nor writes of\n  // the 16-byte `int4` blocks of 8 consecutive threads involve the same shared memory banks. Further, it seems (based\n  // on NSight-Compute) that each warp must also write a consecutive memory segment?\n  auto transform_a = [&] (int i) {\n    int row = i / a_gl_rd_delta_o;\n    return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;\n  };\n  // Since the computation of this remapping is non-trivial and, due to our main loop unrolls, all shared memory\n  // accesses are static, we simply precompute both transformed reads and writes.\n  int a_sh_wr_trans[a_sh_wr_iters];\n  #pragma unroll\n  for (int i = 0; i < a_sh_wr_iters; i++)\n    a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);\n  int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];\n  #pragma unroll\n  for (int i = 0; i < b_sh_wr_iters; i++) {\n    #pragma unroll\n    for (int j = 0; j < thread_m_blocks; j++)\n      a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);\n  }\n\n  // Since B-accesses have non-constant stride they have to be computed at runtime; we break dependencies between\n  // subsequent accesses within a tile by maintaining multiple pointers (we have enough registers), a tiny optimization.\n  const int4* 
B_ptr[b_sh_wr_iters];\n  #pragma unroll\n  for (int i = 0; i < b_sh_wr_iters; i++)\n    B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;\n\n  extern __shared__ int4 sh[];\n  // Shared memory storage for global fetch pipelines.\n  int4* sh_a = sh;\n  int4* sh_b = sh_a + (stages * a_sh_stage);\n  int4* sh_s = sh_b + (stages * b_sh_stage);\n  // ADDED: shared memory storage for scaled zero points\n  int4* sh_sz = sh_s + (stages * s_sh_stage);\n  // Register storage for double buffer of shared memory reads.\n  FragA frag_a[2][thread_m_blocks];\n  I4 frag_b_quant[2];\n  FragC frag_c[thread_m_blocks][4][2];\n  FragS frag_s[2][4];\n  // ADDED: register storage for scaled zero points\n  FragS frag_sz[2][4];\n\n  // Zero accumulators.\n  auto zero_accums = [&] () {\n    #pragma unroll\n    for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)\n      reinterpret_cast<float*>(frag_c)[i] = 0;\n  };\n\n  // Asynchronously fetch the next A, B and s tile from global to the next shared memory pipeline location.\n  auto fetch_to_shared = [&] (int pipe, int a_off, bool pred = true) {\n    if (pred) {\n      int4* sh_a_stage = sh_a + a_sh_stage * pipe;\n      #pragma unroll\n      for (int i = 0; i < a_sh_wr_iters; i++) {\n        cp_async4_pred(\n          &sh_a_stage[a_sh_wr_trans[i]],\n          &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],\n          a_sh_wr_pred[i]\n        );\n      }\n      int4* sh_b_stage = sh_b + b_sh_stage * pipe;\n      #pragma unroll\n      for (int i = 0; i < b_sh_wr_iters; i++) {\n        cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);\n        B_ptr[i] += b_gl_rd_delta_o;\n      }\n      // Only fetch scales if this tile starts a new group\n      if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {\n        // ADDED: fetch scaled zero pointers too\n        int4* sh_s_stage = sh_s + s_sh_stage * pipe;\n        int4* sh_sz_stage = sh_sz + s_sh_stage * pipe;\n        if (s_sh_wr_pred) {\n          
cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);\n          cp_async4(&sh_sz_stage[s_sh_wr], &sz[s_gl_rd]);\n        }\n        s_gl_rd += s_gl_rd_delta;\n      }\n    }\n    // Insert a fence even when we are winding down the pipeline to ensure that waiting is also correct at this point.\n    cp_async_fence();\n  };\n\n  // Wait until the next thread tile has been loaded to shared memory.\n  auto wait_for_stage = [&] () {\n    // We only have `stages - 2` active fetches since we are double buffering and can only issue the next fetch when\n    // it is guaranteed that the previous shared memory load is fully complete (as it may otherwise be overwritten).\n    cp_async_wait<stages - 2>();\n    __syncthreads();\n  };\n\n  // Load the next sub-tile from the current location in the shared memory pipe into the current register buffer.\n  auto fetch_to_registers = [&] (int k, int pipe) {\n    // It may seem inefficient that we reload the groups for every sub-tile; however, this does not seem to be a\n    // significant bottleneck, while some theoretically better attempts have led to bad instruction ordering by the\n    // compiler and correspondingly a noticeable drop in performance.\n    if (group_blocks != -1) {\n      // ADDED: load scaled zero pointers too\n      int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks)));\n      int4* sh_sz_stage = sh_sz + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks)));\n      reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];\n      reinterpret_cast<int4*>(&frag_sz[k % 2])[0] = sh_sz_stage[s_sh_rd];\n    }\n    int4* sh_a_stage = sh_a + a_sh_stage * pipe;\n    #pragma unroll\n    for (int i = 0; i < thread_m_blocks; i++)\n      ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);\n    int4* sh_b_stage = sh_b + b_sh_stage * pipe;\n    frag_b_quant[k % 2] = 
*reinterpret_cast<I4*>(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);\n  };\n\n  // Execute the actual tensor core matmul of a sub-tile.\n  auto matmul = [&] (int k) {\n    // We have the m dimension as the inner loop in order to encourage overlapping dequantization and matmul operations.\n    #pragma unroll\n    for (int j = 0; j < 4; j++) {\n      int b_quant = frag_b_quant[k % 2][j];\n      int b_quant_shift = b_quant >> 8;\n      FragB frag_b0 = dequant(b_quant);\n      // If there are no groups, we can just scale the final output once and can avoid doing so for each weight.\n      // MODIFIED: add scaled zero point\n      if (group_blocks != -1)\n        scale(frag_b0, frag_s[k % 2][j], frag_sz[k % 2][j], 0);\n      FragB frag_b1 = dequant(b_quant_shift);\n      if (group_blocks != -1)\n        scale(frag_b1, frag_s[k % 2][j], frag_sz[k % 2][j], 1);\n      #pragma unroll\n      for (int i = 0; i < thread_m_blocks; i++) {\n        mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);\n        mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);\n      }\n    }\n  };\n\n  // Since we slice across the k dimension of a tile in order to increase the number of warps while keeping the n\n  // dimension of a tile reasonable, we have multiple warps that accumulate their partial sums of the same output\n  // location; which we have to reduce over in the end. We do in shared memory.\n  auto thread_block_reduce = [&] () {\n    constexpr int red_off = threads / b_sh_stride / 2;\n    if (red_off >= 1) {\n      int red_idx = threadIdx.x / b_sh_stride;\n      constexpr int red_sh_stride = b_sh_stride * 4 * 2;\n      constexpr int red_sh_delta = b_sh_stride;\n      int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);\n\n      // Parallel logarithmic shared memory reduction. 
We make sure to avoid any unnecessary read or write iterations,\n      // e.g., for two warps we write only once by warp 1 and read only once by warp 0.\n\n      #pragma unroll\n      for (int m_block = 0; m_block < thread_m_blocks; m_block++) {\n        #pragma unroll\n        for (int i = red_off; i > 0; i /= 2) {\n          if (i <= red_idx && red_idx < 2 * i) {\n            #pragma unroll\n            for (int j = 0; j < 4 * 2; j++) {\n              int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i);\n              if (i < red_off) {\n                float* c_rd = reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);\n                float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);\n                #pragma unroll\n                for (int k = 0; k < 4; k++)\n                  reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k];\n              }\n              sh[red_sh_wr] = reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];\n            }\n          }\n          __syncthreads();\n        }\n        if (red_idx == 0) {\n          #pragma unroll\n          for (int i = 0; i < 4 * 2; i++) {\n            float* c_rd = reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);\n            #pragma unroll\n            for (int j = 0; j < 4; j++)\n              reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] += c_rd[j];\n          }\n        }\n        __syncthreads();\n      }\n    }\n  };\n\n  // Since multiple threadblocks may process parts of the same column slice, we finally have to globally reduce over\n  // the results. 
As the striped partioning minimizes the number of such reductions and our outputs are usually rather\n  // small, we perform this reduction serially in L2 cache.\n  auto global_reduce = [&] (bool first = false, bool last = false) {\n    // We are very careful here to reduce directly in the output buffer to maximize L2 cache utilization in this step.\n    // To do this, we write out results in FP16 (but still reduce with FP32 compute).\n    constexpr int active_threads = 32 * thread_n_blocks / 4;\n    if (threadIdx.x < active_threads) {\n      int c_gl_stride = prob_n / 8;\n      int c_gl_wr_delta_o = 8 * c_gl_stride;\n      int c_gl_wr_delta_i = 4 * (active_threads / 32);\n      int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4;\n      c_gl_wr += (2 * thread_n_blocks) * slice_col;\n      constexpr int c_sh_wr_delta = active_threads;\n      int c_sh_wr = threadIdx.x;\n\n      int row = (threadIdx.x % 32) / 4;\n\n      if (!first) {\n        // Interestingly, doing direct global accesses here really seems to mess up the compiler and lead to slowdowns,\n        // hence we also use async-copies even though these fetches are not actually asynchronous.\n        #pragma unroll\n        for (int i = 0; i < thread_m_blocks * 4; i++) {\n          cp_async4_pred(\n            &sh[c_sh_wr + c_sh_wr_delta * i],\n            &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)],\n            i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m\n          );\n        }\n        cp_async_fence();\n        cp_async_wait<0>();\n      }\n\n      #pragma unroll\n      for (int i = 0; i < thread_m_blocks * 4; i++) {\n        if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {\n          if (!first) {\n            int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];\n            #pragma unroll\n            for (int j = 0; j < 2 * 4; j++) {\n              reinterpret_cast<float*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 
* j + (i % 4)] += __half2float(\n                reinterpret_cast<__half*>(&c_red)[j]\n              );\n            }\n          }\n          if (!last) {\n            int4 c;\n            #pragma unroll\n            for (int j = 0; j < 2 * 4; j++) {\n              reinterpret_cast<__half*>(&c)[j] = __float2half(\n                reinterpret_cast<float*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]\n              );\n            }\n            C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = c;\n          }\n        }\n      }\n    }\n  };\n\n  // Write out the reduce final result in the correct layout. We only actually reshuffle matrix fragments in this step,\n  // the reduction above is performed in fragment layout.\n  auto write_result = [&] () {\n    int c_gl_stride = prob_n / 8;\n    constexpr int c_sh_stride = 2 * thread_n_blocks + 1;\n    int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));\n    constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks));\n\n    int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks));\n    c_gl_wr += (2 * thread_n_blocks) * slice_col;\n    int c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;\n    c_sh_wr += 32 * (threadIdx.x / 32);\n    int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks));\n\n    int c_gl_wr_end = c_gl_stride * prob_m;\n\n    // We first reorder in shared memory to guarantee the most efficient final global write patterns\n    auto write = [&] (int idx, float c0, float c1, FragS& s) {\n      half2 res = __halves2half2(__float2half(c0), __float2half(c1));\n      if (group_blocks == -1) // for per-column quantization we finally apply the scale here\n        res = __hmul2(res, s[0]);\n      ((half2*) sh)[idx] = res;\n    };\n    if (threadIdx.x / 32 < thread_n_blocks / 4) {\n      #pragma unroll\n      for (int i = 0; i < 
thread_m_blocks; i++) {\n        #pragma unroll\n        for (int j = 0; j < 4; j++) {\n          int wr = c_sh_wr + 8 * j;\n          write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);\n          write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);\n          write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);\n          write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);\n        }\n        c_sh_wr += 16 * (4 * c_sh_stride);\n      }\n    }\n    __syncthreads();\n\n    #pragma unroll\n    for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) {\n      if (c_gl_wr < c_gl_wr_end) {\n        C[c_gl_wr] = sh[c_sh_rd];\n        c_gl_wr += c_gl_wr_delta;\n        c_sh_rd += c_sh_rd_delta;\n      }\n    }\n  };\n\n  // Start global fetch and register load pipelines.\n  auto start_pipes = [&] () {\n    #pragma unroll\n    for (int i = 0; i < stages - 1; i++)\n      fetch_to_shared(i, i, i < slice_iters);\n    zero_accums();\n    wait_for_stage();\n    fetch_to_registers(0, 0);\n    a_gl_rd += a_gl_rd_delta_o * (stages - 1);\n  };\n  start_pipes();\n\n  // Main loop.\n  while (slice_iters) {\n    // We unroll over both the global fetch and the register load pipeline to ensure all shared memory accesses are\n    // static. 
Note that both pipelines have even length meaning that the next iteration will always start at index 0.\n    #pragma unroll\n    for (int pipe = 0; pipe < stages;) {\n      #pragma unroll\n      for (int k = 0; k < b_sh_wr_iters; k++) {\n        fetch_to_registers(k + 1, pipe % stages);\n        if (k == b_sh_wr_iters - 2) {\n          fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages);\n          pipe++;\n          wait_for_stage();\n        }\n        matmul(k);\n      }\n      slice_iters--;\n      if (slice_iters == 0)\n        break;\n    }\n    a_gl_rd += a_gl_rd_delta_o * stages;\n\n    // Process results and, if necessary, proceed to the next column slice. While this pattern may not be the most\n    // readable, other ways of writing the loop seemed to result in noticeably worse performance after compilation.\n    if (slice_iters == 0) {\n      cp_async_wait<0>();\n      bool last = slice_idx == slice_count - 1;\n      // For per-column scales, we only fetch them here in the final step before write-out\n      if (group_blocks == -1 && last) {\n        if (s_sh_wr_pred) {\n          cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);\n          // ADDED: fetch scaled zero pointers too\n          cp_async4(&sh_sz[s_sh_wr], &sz[s_gl_rd]);\n        }\n        cp_async_fence();\n      }\n      thread_block_reduce();\n      if (group_blocks == -1 && last) {\n        cp_async_wait<0>();\n        __syncthreads();\n        if (threadIdx.x / 32 < thread_n_blocks / 4) {\n          reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];\n          reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];\n          // ADDED: load scaled zero pointers too\n          reinterpret_cast<int4*>(&frag_sz)[0] = sh_sz[s_sh_rd + 0];\n          reinterpret_cast<int4*>(&frag_sz)[1] = sh_sz[s_sh_rd + 4];\n        }\n      }\n      if (slice_count > 1) { // only globally reduce if there is more than one block in a slice\n        barrier_acquire(&locks[slice_col], slice_idx);\n     
   global_reduce(slice_idx == 0, last);\n        barrier_release(&locks[slice_col], last);\n      }\n      if (last) // only the last block in a slice actually writes the result\n        write_result();\n      slice_row = 0;\n      slice_col_par++;\n      slice_col++;\n      init_slice();\n      if (slice_iters) {\n        a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);\n        #pragma unroll\n        for (int i = 0; i < b_sh_wr_iters; i++)\n          B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;\n        if (slice_col == 0) {\n          #pragma unroll\n          for (int i = 0; i < b_sh_wr_iters; i++)\n            B_ptr[i] -= b_gl_stride;\n        }\n        s_gl_rd = s_sh_stride * slice_col + threadIdx.x;\n        start_pipes();\n      }\n    }\n  }\n}\n\n\n// 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per schedule allows some more\n// latency hiding. At the same time, we want relatively few warps to have many registers per warp and small tiles.\nconst int THREADS = 256;\nconst int STAGES = 4; // 4 pipeline stages fit into shared memory\nconst int SHARED_MEM = 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0)\n\n// ADDED: add scaled zero pointer\n#define CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS) \\\n  else if ( \\\n    thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && thread_k_blocks == THREAD_K_BLOCKS && \\\n    group_blocks == GROUP_BLOCKS \\\n  ) { \\\n    cudaFuncSetAttribute( \\\n      Marlin<THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \\\n      cudaFuncAttributeMaxDynamicSharedMemorySize, \\\n      SHARED_MEM \\\n    ); \\\n    Marlin< \\\n      THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS \\\n    ><<<blocks, THREADS, SHARED_MEM, stream>>>( \\\n      A_ptr, B_ptr, C_ptr, s_ptr, sz_ptr,\\\n      prob_m, prob_n, prob_k, 
\\\n      locks \\\n    ); \\\n  }\n\nconst int ERR_PROB_SHAPE = 1;\nconst int ERR_KERN_SHAPE = 2;\n\n// ADDED: add scaled zero pointer\nint marlin_cuda(\n  const void* A,\n  const void* B,\n        void* C,\n        void* s,\n        void* sz,\n  int prob_m,\n  int prob_n,\n  int prob_k,\n  void* workspace,\n  int groupsize = -1,\n  int dev = 0,\n  cudaStream_t stream = 0,\n  int thread_k = -1,\n  int thread_n = -1,\n  int sms = -1,\n  int max_par = 16\n) {\n  int tot_m = prob_m;\n  int tot_m_blocks = ceildiv(tot_m, 16);\n  int pad = 16 * tot_m_blocks - tot_m;\n\n  if (sms == -1)\n    cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);\n  if (thread_k == -1 || thread_n == -1) {\n    if (prob_m <= 16) {\n      // For small batch sizes, better partitioning is slightly more important than better compute utilization\n      thread_k = 128;\n      thread_n = 128;\n    } else {\n      thread_k = 64;\n      thread_n = 256;\n    }\n  }\n\n  int thread_k_blocks = thread_k / 16;\n  int thread_n_blocks = thread_n / 16;\n  int group_blocks = (groupsize == -1) ? 
-1 : groupsize / 16;\n  int blocks = sms;\n\n  if (prob_n % thread_n != 0 || prob_k % thread_k != 0 || (group_blocks != -1 && prob_k % group_blocks != 0))\n    return ERR_PROB_SHAPE;\n  if (prob_m == 0 || prob_n == 0 || prob_k == 0)\n    return 0;\n\n  const int4* A_ptr = (const int4*) A;\n  const int4* B_ptr = (const int4*) B;\n  int4* C_ptr = (int4*) C;\n  const int4* s_ptr = (const int4*) s;\n  // ADDED: add scaled zero pointer\n  const int4* sz_ptr = (const int4*) sz;\n\n  int cols = prob_n / thread_n;\n  int* locks = (int*) workspace;\n\n  int ret = 0;\n  for (int i = 0; i < tot_m_blocks; i += 4) {\n    int thread_m_blocks = tot_m_blocks - i;\n    prob_m = tot_m - 16 * i;\n    int par = 1;\n    if (thread_m_blocks > 4) {\n      // Note that parallel > 1 currently only works for inputs without any padding\n      par = (16 * thread_m_blocks - pad) / 64;\n      if (par > max_par)\n        par = max_par;\n      prob_m = 64 * par;\n      i += 4 * (par - 1);\n      thread_m_blocks = 4;\n    }\n\n    // For compilation speed, we only define the kernel configurations that have seemed useful (in terms of performance)\n    // in our testing, however many more are, in principle, possible.\n    if (false) {}\n    CALL_IF(1,  8,  8, -1)\n    CALL_IF(1,  8,  8,  8)\n    CALL_IF(1, 16,  4, -1)\n    CALL_IF(1, 16,  4,  8)\n    CALL_IF(2, 16,  4, -1)\n    CALL_IF(2, 16,  4,  8)\n    CALL_IF(3, 16,  4, -1)\n    CALL_IF(3, 16,  4,  8)\n    CALL_IF(4, 16,  4, -1)\n    CALL_IF(4, 16,  4,  8)\n    else\n      ret = ERR_KERN_SHAPE;\n\n    A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;\n    C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;\n  }\n\n  return ret;\n}\n\n\n#endif\n"
  },
  {
    "path": "optimum/quanto/library/extensions/cuda/marlin/marlin_cuda_kernel.cuh",
    "content": "/*\n * Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at)\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *         http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n#include <cuda_runtime.h>\n\nint marlin_cuda(\n  const void* A,\n  const void* B,\n        void* C,\n        void* s,\n        void* sz, // ADDED: add scaled zero point\n  int prob_m,\n  int prob_n,\n  int prob_k,\n  void* workspace,\n  int groupsize = -1,\n  int dev = 0,\n  cudaStream_t stream = 0,\n  int thread_k = -1,\n  int thread_n = -1,\n  int sms = -1,\n  int max_par = 16\n);\n"
  },
  {
    "path": "optimum/quanto/library/extensions/cuda/pybind_module.cpp",
    "content": "// Copyright 2024 The HuggingFace Team. All rights reserved.\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#include <torch/extension.h>\n#include \"awq/v2/gemm_cuda.h\"\n#include \"awq/v2/gemv_cuda.h\"\n#include \"unpack.h\"\n#include \"marlin/fp8_marlin.cuh\"\n#include \"marlin/gptq_marlin_repack.cuh\"\n#include \"marlin/marlin_cuda.h\"\n\n// !IMPORTANT! Some python objects such as dtype, device, are not mapped to C++ types,\n// and need to be explicitly converted using dedicated helpers before calling a C++ method.\n// As a consequence, when an operation takes such an object as parameter, instead\n// of creating a binding directly to the C++ method, you must create a binding to a\n// lambda method that converts the unmapped types and calls the C++ method.\n// See the binding of quantize_symmetric for instance.\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"awq_v2_gemm_f16i4\", &awq_v2_gemm_f16i4, \"awq_v2_gemm_f16i4\");\n  m.def(\"awq_v2_gemv_f16i4\", &awq_v2_gemv_f16i4, \"awq_v2_gemv_f16i4\");\n  m.def(\"gptq_marlin_repack\", &gptq_marlin_repack, \"gptq_marlin_repack\");\n  m.def(\"fp8_marlin_gemm\", &fp8_marlin_gemm, \"fp8_marlin_gemm\");\n  m.def(\"marlin_gemm_f16i4\", &mul, \"marlin_gemm_f16i4\");\n  m.def(\"unpack\", &unpack, \"unpack\");\n}\n"
  },
  {
    "path": "optimum/quanto/library/extensions/cuda/unpack.cu",
    "content": "// Copyright 2024 The HuggingFace Team. All rights reserved.\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#include <torch/extension.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <c10/cuda/CUDAException.h>\n\ninline  unsigned int cdiv(unsigned int a, unsigned int b) { return (a + b - 1) / b;}\n#define BLOCK_SIZE 256\n\nusing namespace at;\n\n\nstatic torch::Tensor allocate_output(const torch::Tensor& input, int bits) {\n    int n_packed = 8 / bits;\n    auto output_shape = input.sizes().vec();\n    output_shape[0] = output_shape[0] * n_packed;\n    return torch::empty(output_shape, input.options());\n}\n\n__global__ void unpack_4bit_kernel(unsigned char* input, unsigned char* output, int n) {\n\tint i = blockIdx.x*blockDim.x + threadIdx.x;\n\tif(i>=n) return;\n\n\toutput[i]     = (input[i] & 0x0F);\n\toutput[i + n] = (input[i] & 0xF0) >> 4;\n}\n\nstatic torch::Tensor unpack_4bit(const torch::Tensor& input){\n\n\tauto output = allocate_output(input, 4);\n\n    const auto numel = input.numel();\n\tint blocks = cdiv(numel, BLOCK_SIZE);\n\tunpack_4bit_kernel<<<blocks, BLOCK_SIZE>>>(\n        input.data_ptr<unsigned char>(),\n        output.data_ptr<unsigned char>(),\n        numel\n    );\n\n\tC10_CUDA_KERNEL_LAUNCH_CHECK();\n\n\treturn output;\n}\n\n__global__ void unpack_2bit_kernel(unsigned char* input, unsigned char* output, int n) {\n\tint i = blockIdx.x*blockDim.x + threadIdx.x;\n\tif(i>=n) return;\n\n\toutput[i]  
     = (input[i] & 0x03);\n\toutput[i + n]   = (input[i] & 0x0C) >> 2;\n\toutput[i + n*2] = (input[i] & 0x30) >> 4;\n\toutput[i + n*3] = (input[i] & 0xC0) >> 6;\n}\n\nstatic torch::Tensor unpack_2bit(const torch::Tensor& input){\n\n\tauto output = allocate_output(input, 2);\n\n    const auto numel = input.numel();\n\tint blocks = cdiv(numel, BLOCK_SIZE);\n\tunpack_2bit_kernel<<<blocks, BLOCK_SIZE>>>(\n        input.data_ptr<unsigned char>(),\n        output.data_ptr<unsigned char>(),\n        numel\n    );\n\n\tC10_CUDA_KERNEL_LAUNCH_CHECK();\n\n\treturn output;\n}\n\ntorch::Tensor unpack(torch::Tensor &t, int bits) {\n    TORCH_CHECK(t.scalar_type() == torch::kUInt8, \"Unsupported data type: \", t.scalar_type());\n    TORCH_CHECK(t.device().is_cuda(), \"t must be a CUDA tensor.\");\n    TORCH_CHECK(t.is_contiguous(), \"t must be contiguous.\");\n    switch(bits) {\n      case 4:\n        return unpack_4bit(t);\n      case 2:\n        return unpack_2bit(t);\n      default:\n        throw std::invalid_argument(\"Can only unpack 2-bit or 4-bit tensors.\");\n    }\n}\n"
  },
  {
    "path": "optimum/quanto/library/extensions/cuda/unpack.h",
    "content": "// Copyright 2024 The HuggingFace Team. All rights reserved.\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#include <torch/extension.h>\n\ntorch::Tensor unpack(torch::Tensor &t, int bits);\n"
  },
  {
    "path": "optimum/quanto/library/extensions/extension.py",
    "content": "import os\nimport shutil\nimport warnings\nfrom typing import List\n\nimport torch\nfrom torch.utils.cpp_extension import load\n\n\n__all__ = [\"is_extension_available\", \"get_extension\"]\n\n\nclass Extension(object):\n    def __init__(\n        self,\n        name: str,\n        root_dir: str,\n        sources: List[str],\n        extra_cflags: List[str] = None,\n        extra_cuda_cflags: List[str] = None,\n    ):\n        self.name = name\n        self.sources = [f\"{root_dir}/{source}\" for source in sources]\n        self.extra_cflags = extra_cflags\n        self.extra_cuda_cflags = extra_cuda_cflags\n        self.build_directory = os.path.join(root_dir, \"build\")\n        self._lib = None\n\n    @property\n    def lib(self):\n        if self._lib is None:\n            # We only load the extension when the lib is required\n            version_file = os.path.join(self.build_directory, \"pytorch_version.txt\")\n            if os.path.exists(version_file):\n                # The extension has already been built: check the torch version for which it was built\n                with open(version_file, \"r\") as f:\n                    pytorch_build_version = f.read().rstrip()\n                    if pytorch_build_version != torch.__version__:\n                        shutil.rmtree(self.build_directory)\n                        warnings.warn(\n                            f\"{self.name} was compiled with pytorch {pytorch_build_version}, but {torch.__version__} is installed: it will be recompiled.\"\n                        )\n            os.makedirs(self.build_directory, exist_ok=True)\n            self._lib = load(\n                name=self.name,\n                sources=self.sources,\n                extra_cflags=self.extra_cflags,\n                extra_cuda_cflags=self.extra_cuda_cflags,\n                build_directory=self.build_directory,\n            )\n            if not os.path.exists(version_file):\n                with 
open(version_file, \"w\") as f:\n                    f.write(torch.__version__)\n        return self._lib\n\n\n_extensions = {}\n\n\ndef register_extension(extension: Extension):\n    assert extension.name not in _extensions\n    _extensions[extension.name] = extension\n\n\ndef get_extension(extension_type: str):\n    \"\"\"Get an extension\n\n    Args:\n        extension_type (`str`):\n            The extension type.\n    Returns:\n        The corresponding extension.\n    \"\"\"\n    return _extensions[extension_type]\n\n\ndef is_extension_available(extension_type: str):\n    \"\"\"Check if an extension is available\n\n    Args:\n        extension_type (`str`):\n            The extension type.\n    Returns:\n        True if the extension is available.\n    \"\"\"\n    return extension_type in _extensions\n"
  },
  {
    "path": "optimum/quanto/library/extensions/hip/__init__.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nimport torch\n\nfrom ..extension import Extension, register_extension\n\n\n__all__ = []\n\n\next = Extension(\n    \"quanto_hip\",\n    root_dir=os.path.dirname(__file__),\n    sources=[\"unpack.cu\", \"pybind_module.cpp\"],\n    extra_cflags=[\"-std=c++17\"],\n)\nregister_extension(ext)\n\n\n@torch.library.impl(\"quanto::unpack\", [\"CUDA\"])\ndef unpack_hip(t: torch.Tensor, bits: int):\n    return ext.lib.unpack(t, bits)\n"
  },
  {
    "path": "optimum/quanto/library/extensions/hip/pybind_module.cpp",
    "content": "// Copyright 2024 The HuggingFace Team. All rights reserved.\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#include <torch/extension.h>\n#include \"unpack.h\"\n\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"unpack\", &unpack, \"unpack\");\n}\n"
  },
  {
    "path": "optimum/quanto/library/extensions/hip/unpack.cu",
    "content": "// Copyright 2024 The HuggingFace Team. All rights reserved.\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#include <torch/extension.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <c10/cuda/CUDAException.h>\n\ninline  unsigned int cdiv(unsigned int a, unsigned int b) { return (a + b - 1) / b;}\n#define BLOCK_SIZE 256\n\nusing namespace at;\n\n\nstatic torch::Tensor allocate_output(const torch::Tensor& input, int bits) {\n    int n_packed = 8 / bits;\n    auto output_shape = input.sizes().vec();\n    output_shape[0] = output_shape[0] * n_packed;\n    return torch::empty(output_shape, input.options());\n}\n\n__global__ void unpack_4bit_kernel(unsigned char* input, unsigned char* output, int n) {\n\tint i = blockIdx.x*blockDim.x + threadIdx.x;\n\tif(i>=n) return;\n\n\toutput[i]     = (input[i] & 0x0F);\n\toutput[i + n] = (input[i] & 0xF0) >> 4;\n}\n\nstatic torch::Tensor unpack_4bit(const torch::Tensor& input){\n\n\tauto output = allocate_output(input, 4);\n\n    const auto numel = input.numel();\n\tint blocks = cdiv(numel, BLOCK_SIZE);\n\tunpack_4bit_kernel<<<blocks, BLOCK_SIZE>>>(\n        input.data_ptr<unsigned char>(),\n        output.data_ptr<unsigned char>(),\n        numel\n    );\n\n\tC10_CUDA_KERNEL_LAUNCH_CHECK();\n\n\treturn output;\n}\n\n__global__ void unpack_2bit_kernel(unsigned char* input, unsigned char* output, int n) {\n\tint i = blockIdx.x*blockDim.x + threadIdx.x;\n\tif(i>=n) return;\n\n\toutput[i]  
     = (input[i] & 0x03);\n\toutput[i + n]   = (input[i] & 0x0C) >> 2;\n\toutput[i + n*2] = (input[i] & 0x30) >> 4;\n\toutput[i + n*3] = (input[i] & 0xC0) >> 6;\n}\n\nstatic torch::Tensor unpack_2bit(const torch::Tensor& input){\n\n\tauto output = allocate_output(input, 2);\n\n    const auto numel = input.numel();\n\tint blocks = cdiv(numel, BLOCK_SIZE);\n\tunpack_2bit_kernel<<<blocks, BLOCK_SIZE>>>(\n        input.data_ptr<unsigned char>(),\n        output.data_ptr<unsigned char>(),\n        numel\n    );\n\n\tC10_CUDA_KERNEL_LAUNCH_CHECK();\n\n\treturn output;\n}\n\ntorch::Tensor unpack(torch::Tensor &t, int bits) {\n    TORCH_CHECK(t.scalar_type() == torch::kUInt8, \"Unsupported data type: \", t.scalar_type());\n    TORCH_CHECK(t.device().is_cuda(), \"t must be a CUDA tensor.\");\n    TORCH_CHECK(t.is_contiguous(), \"t must be contiguous.\");\n    switch(bits) {\n      case 4:\n        return unpack_4bit(t);\n      case 2:\n        return unpack_2bit(t);\n      default:\n        throw std::invalid_argument(\"Can only unpack 2-bit or 4-bit tensors.\");\n    }\n}\n"
  },
  {
    "path": "optimum/quanto/library/extensions/hip/unpack.h",
    "content": "// Copyright 2024 The HuggingFace Team. All rights reserved.\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#include <torch/extension.h>\n\ntorch::Tensor unpack(torch::Tensor &t, int bits);\n"
  },
  {
    "path": "optimum/quanto/library/extensions/mps/README.md",
    "content": "# Quanto Metal Performance Shaders extension\n\nTo add a new implementation for an operation defined in `library/ops.py`:\n\n- add the corresponding `.mm` file to the list of sources in `__init__.py`,\n- add a binding to `pybind_module.cpp`,\n- provide an implementation calling the binding in `__init__.py`.\n\nNote: torch JIT extensions for MPS require the xcode command-line tools.\n"
  },
  {
    "path": "optimum/quanto/library/extensions/mps/__init__.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nimport torch\n\nfrom ..extension import Extension, register_extension\n\n\n__all__ = []\n\n\next = Extension(\n    \"quanto_mps\",\n    root_dir=os.path.dirname(__file__),\n    sources=[\"unpack.mm\", \"pybind_module.cpp\"],\n    extra_cflags=[\"-std=c++17\"],\n)\nregister_extension(ext)\n\n\n@torch.library.impl(\"quanto::unpack\", \"MPS\")\ndef unpack_mps(t: torch.Tensor, bits: int):\n    return ext.lib.unpack(t, bits)\n"
  },
  {
    "path": "optimum/quanto/library/extensions/mps/pybind_module.cpp",
    "content": "// Copyright 2024 The HuggingFace Team. All rights reserved.\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#include <torch/extension.h>\n#include \"unpack.h\"\n\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"unpack\", &unpack, \"unpack\");\n}\n"
  },
  {
    "path": "optimum/quanto/library/extensions/mps/unpack.h",
    "content": "// Copyright 2024 The HuggingFace Team. All rights reserved.\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#include <torch/extension.h>\n\ntorch::Tensor unpack(const torch::Tensor &input, int bits);\n"
  },
  {
    "path": "optimum/quanto/library/extensions/mps/unpack.mm",
    "content": "// Copyright 2024 The HuggingFace Team. All rights reserved.\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#include \"unpack.h\"\n#include <torch/extension.h>\n\n#import <Foundation/Foundation.h>\n#import <Metal/Metal.h>\n\n// Defines a Metal custom kernel to mask and shift a buffer element-wise.\nstatic char *MASK_AND_SHIFT = R\"MPS_MASK&SHIFT(\n#include <metal_stdlib>\nusing namespace metal;\n\n[[host_name(\"mask_and_rshift\")]]\nkernel void mask_and_rshift(constant uint8_t*     input  [[buffer(0)]],\n                            device   uint8_t*     output [[buffer(1)]],\n                            constant uint8_t&     mask   [[buffer(2)]],\n                            constant int&       shift  [[buffer(3)]],\n                            uint index [[thread_position_in_grid]]) {\n    output[index] = (input[index] & mask) >> shift;\n}\n\n)MPS_MASK&SHIFT\";\n\n// Helper function to retrieve the `MTLBuffer` from a `torch::Tensor`.\nstatic inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor& tensor) {\n  return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());\n}\n\ntorch::Tensor& mask_and_shift(const torch::Tensor& input, torch::Tensor& output, uint8_t mask, int shift) {\n    @autoreleasepool {\n        id<MTLDevice> device = MTLCreateSystemDefaultDevice();\n        NSError *error = nil;\n\n        // Set the number of threads equal to the number of elements within the input tensor.\n        int num_threads 
= input.numel();\n\n        // Load the custom mask and shift shader.\n        id<MTLLibrary> library = [device newLibraryWithSource:[NSString stringWithUTF8String:MASK_AND_SHIFT]\n                                  options:nil\n                                  error:&error];\n        TORCH_CHECK(library, \"Failed to to create custom kernel library, error: \", error.localizedDescription.UTF8String);\n\n        id<MTLFunction> kernel = [library newFunctionWithName:[NSString stringWithUTF8String:\"mask_and_rshift\"]];\n        TORCH_CHECK(kernel, \"Failed to create function state object for mask_and_rshift\");\n\n        // Create a compute pipeline state object for the soft shrink kernel.\n        id<MTLComputePipelineState> pso = [device newComputePipelineStateWithFunction:kernel error:&error];\n        TORCH_CHECK(pso, error.localizedDescription.UTF8String);\n\n        // This is required if torch already encoded something in the command buffer\n        torch::mps::synchronize();\n\n        // Get a reference to the command buffer for the MPS stream.\n        id<MTLCommandBuffer> command_buffer = torch::mps::get_command_buffer();\n        TORCH_CHECK(command_buffer, \"Failed to retrieve command buffer reference\");\n\n        // Get a reference to the dispatch queue for the MPS stream, which encodes the synchronization with the CPU.\n        dispatch_queue_t serial_queue = torch::mps::get_dispatch_queue();\n\n        dispatch_sync(serial_queue, ^(){\n            // Start a compute pass.\n            id<MTLComputeCommandEncoder> compute_encoder = [command_buffer computeCommandEncoder];\n            TORCH_CHECK(compute_encoder, \"Failed to create compute command encoder\");\n\n            // Encode the pipeline state object and its parameters.\n            [compute_encoder setComputePipelineState:pso];\n            [compute_encoder setBuffer:getMTLBufferStorage(input) offset:input.storage_offset() * input.element_size() atIndex:0];\n            [compute_encoder 
setBuffer:getMTLBufferStorage(output) offset:output.storage_offset() * output.element_size() atIndex:1];\n            [compute_encoder setBytes:&mask length:sizeof(uint8_t) atIndex:2];\n            [compute_encoder setBytes:&shift length:sizeof(int) atIndex:3];\n\n            MTLSize grid_size = MTLSizeMake(num_threads, 1, 1);\n\n            // Calculate a thread group size.\n            NSUInteger thread_group_size = pso.maxTotalThreadsPerThreadgroup;\n            if (thread_group_size > num_threads) {\n                thread_group_size = num_threads;\n            }\n            MTLSize mtl_size = MTLSizeMake(thread_group_size, 1, 1);\n\n            // Encode the compute command.\n            [compute_encoder dispatchThreads:grid_size\n                      threadsPerThreadgroup:mtl_size];\n\n            [compute_encoder endEncoding];\n\n            // Commit the work.\n            torch::mps::commit();\n        });\n\n        torch::mps::synchronize();\n    }\n\n    return output;\n}\n\ntorch::Tensor unpack_4bit(const torch::Tensor &input) {\n\n    torch::Tensor output = torch::empty_like(input);\n    mask_and_shift(input, output, 0x0F, 0);\n    torch::Tensor output1 = torch::empty_like(input);\n    mask_and_shift(input, output1, 0xF0, 4);\n    return torch::cat({output, output1}, 0);\n}\n\ntorch::Tensor unpack_2bit(const torch::Tensor &input) {\n\n    torch::Tensor output = torch::empty_like(input);\n    mask_and_shift(input, output, 0x03, 0);\n    torch::Tensor output1 = torch::empty_like(input);\n    mask_and_shift(input, output1, 0x0C, 2);\n    torch::Tensor output2 = torch::empty_like(input);\n    mask_and_shift(input, output2, 0x30, 4);\n    torch::Tensor output3 = torch::empty_like(input);\n    mask_and_shift(input, output3, 0xC0, 6);\n    return torch::cat({output, output1, output2, output3}, 0);\n}\n\n// C++ op dispatching the Metal unpack operation.\ntorch::Tensor unpack(const torch::Tensor &input, int bits) {\n    // Check whether the input tensor 
resides on the MPS device and whether it's contiguous.\n    TORCH_CHECK(input.device().is_mps(), \"input must be a MPS tensor\");\n    TORCH_CHECK(input.is_contiguous(), \"input must be contiguous\");\n\n    // Check the supported data types for soft shrink.\n    TORCH_CHECK(input.scalar_type() == torch::kUInt8, \"Unsupported data type: \", input.scalar_type());\n\n    switch(bits) {\n      case 4:\n        return unpack_4bit(input);\n      case 2:\n        return unpack_2bit(input);\n      default:\n        throw std::invalid_argument(\"Can only unpack 2-bit or 4-bit tensors.\");\n    }\n}\n"
  },
  {
    "path": "optimum/quanto/library/extensions/xpu/__init__.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n# Copyright 2024 Intel Corporation. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\n\nimport torch\nfrom packaging import version\n\nfrom ..extension import Extension, register_extension\n\n\n__all__ = []\n\n\nmodule_path = os.path.dirname(__file__)\nsources = [\n    \"unpack.sycl\",\n    \"pybind_module.cpp\",\n]\next = Extension(\n    \"quanto_xpu\",\n    root_dir=os.path.dirname(__file__),\n    sources=sources,\n)\nregister_extension(ext)\n\n\n@torch.library.impl(\"quanto::unpack\", \"XPU\")\ndef unpack_xpu(t: torch.Tensor, bits: int):\n    return ext.lib.unpack(t, bits)\n\n\nif version.parse(torch.__version__).release >= version.parse(\"2.8.0\").release:\n    torch.library.define(\n        \"quanto::gemm_f16i4_awq\",\n        \"(Tensor input,\"\n        \" Tensor other,\"\n        \" Tensor other_scale,\"\n        \" Tensor other_shift,\"\n        \" int rows,\"\n        \" int out_cols,\"\n        \" int in_cols,\"\n        \" int bits,\"\n        \" int group_size)\"\n        \" -> Tensor\",\n    )\n\n    @torch.library.impl(\"quanto::gemm_f16i4_awq\", \"XPU\")\n    def gemm_f16i4_awq(\n        input: torch.Tensor,\n        other: torch.Tensor,\n        scales: torch.Tensor,\n        shift: torch.Tensor,\n        rows: int,\n        out_cols: int,\n        in_cols: int,\n        bits: int,\n        group_size: int,\n    ):\n        orig_act_size 
= input.size()\n        orig_dtype = input.dtype\n\n        input = input.reshape(-1, input.shape[-1])\n\n        # XPU does not support float32 for now.\n        if input.dtype == torch.float32:\n            input = input.to(torch.bfloat16)\n        if scales.dtype != input.dtype:\n            scales = scales.to(input.dtype)\n\n        y = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(input, other, group_size, scales, shift)\n        # remove out_feature padding\n        y = y[:, :out_cols]\n        y = y.reshape(*orig_act_size[:-1], out_cols)\n\n        return y.to(orig_dtype)\n"
  },
  {
    "path": "optimum/quanto/library/extensions/xpu/pybind_module.cpp",
    "content": "// Copyright 2024 The HuggingFace Team. All rights reserved.\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#include <torch/extension.h>\n#include \"unpack.h\"\n\n// !IMPORTANT! Some python objects such as dtype, device, are not mapped to C++ types,\n// and need to be explicitly converted using dedicated helpers before calling a C++ method.\n// As a consequence, when an operation takes such an object as parameter, instead\n// of creating a binding directly to the C++ method, you must create a binding to a\n// lambda method that converts the unmapped types and calls the C++ method.\n// See the binding of quantize_symmetric for instance.\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"unpack\", &unpack, \"unpack\");\n}\n"
  },
  {
    "path": "optimum/quanto/library/extensions/xpu/unpack.h",
    "content": "// Copyright 2024 The HuggingFace Team. All rights reserved.\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#include <torch/extension.h>\n\ntorch::Tensor unpack(torch::Tensor &t, int bits);\n"
  },
  {
    "path": "optimum/quanto/library/extensions/xpu/unpack.sycl",
    "content": "// Copyright 2024 The HuggingFace Team. All rights reserved.\n// Copyright 2024 Intel Corporation. All rights reserved.\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#include <torch/extension.h>\n#include <sycl/sycl.hpp>\n#include <c10/xpu/XPUStream.h>\n\n\ninline unsigned int cdiv(unsigned int a, unsigned int b) { return (a + b - 1) / b;}\n#define BLOCK_SIZE 256\n\nusing namespace at;\n\n\nstatic torch::Tensor allocate_output(const torch::Tensor& input, int bits) {\n    int n_packed = 8 / bits;\n    auto output_shape = input.sizes().vec();\n    output_shape[0] = output_shape[0] * n_packed;\n    return torch::empty(output_shape, input.options());\n}\n\nvoid unpack_4bit_kernel(unsigned char* input, unsigned char* output, int n,\n                        const sycl::nd_item<3> &item_ct1) {\n    int i = item_ct1.get_group(2) * item_ct1.get_local_range(2) +\n            item_ct1.get_local_id(2);\n    if (i>=n) return;\n\n    output[i]     = (input[i] & 0x0F);\n    output[i + n] = (input[i] & 0xF0) >> 4;\n}\n\nclass Unpack4BitKrn {\npublic:\n    void operator()(sycl::nd_item<3> item_ct1) const {\n        unpack_4bit_kernel(ct0, ct1, numel, item_ct1);\n    }\n    Unpack4BitKrn(unsigned char* _ct0, unsigned char* _ct1, int64_t _numel):\n        ct0(_ct0),\n        ct1(_ct1),\n        numel(_numel)\n    {}\nprivate:\n    unsigned char* ct0;\n    unsigned char* ct1;\n    int64_t numel;\n};\n\nstatic torch::Tensor unpack_4bit(const 
torch::Tensor& input){\n    auto output = allocate_output(input, 4);\n\n    const auto numel = input.numel();\n    int blocks = cdiv(numel, BLOCK_SIZE);\n\n    sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();\n\n    auto krn = [&](sycl::handler &cgh) {\n        auto input_data_ptr_unsigned_char_ct0 =\n            input.data_ptr<unsigned char>();\n        auto output_data_ptr_unsigned_char_ct1 =\n            output.data_ptr<unsigned char>();\n\n\tUnpack4BitKrn krn2(input_data_ptr_unsigned_char_ct0, output_data_ptr_unsigned_char_ct1, numel);\n\n        cgh.parallel_for<Unpack4BitKrn>(\n            sycl::nd_range<3>(\n\t\tsycl::range<3>(1, 1, blocks) * sycl::range<3>(1, 1, BLOCK_SIZE),\n                sycl::range<3>(1, 1, BLOCK_SIZE)),\n\t    krn2);\n    };\n    queue.submit(krn);\n    return output;\n}\n\nvoid unpack_2bit_kernel(unsigned char* input, unsigned char* output, int n,\n                        const sycl::nd_item<3> &item_ct1) {\n    int i = item_ct1.get_group(2) * item_ct1.get_local_range(2) +\n            item_ct1.get_local_id(2);\n    if (i>=n) return;\n\n    output[i]       = (input[i] & 0x03);\n    output[i + n]   = (input[i] & 0x0C) >> 2;\n    output[i + n*2] = (input[i] & 0x30) >> 4;\n    output[i + n*3] = (input[i] & 0xC0) >> 6;\n}\n\nclass Unpack2BitKrn {\npublic:\n    void operator()(sycl::nd_item<3> item_ct1) const {\n        unpack_2bit_kernel(ct0, ct1, numel, item_ct1);\n    }\n    Unpack2BitKrn(unsigned char* _ct0, unsigned char* _ct1, int64_t _numel):\n        ct0(_ct0),\n        ct1(_ct1),\n        numel(_numel)\n    {}\nprivate:\n    unsigned char* ct0;\n    unsigned char* ct1;\n    int64_t numel;\n};\n\nstatic torch::Tensor unpack_2bit(const torch::Tensor& input){\n    auto output = allocate_output(input, 2);\n\n    const auto numel = input.numel();\n    int blocks = cdiv(numel, BLOCK_SIZE);\n    sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();\n    auto krn = [&](sycl::handler &cgh) {\n        auto 
input_data_ptr_unsigned_char_ct0 =\n            input.data_ptr<unsigned char>();\n        auto output_data_ptr_unsigned_char_ct1 =\n            output.data_ptr<unsigned char>();\n\n\tUnpack2BitKrn krn2(input_data_ptr_unsigned_char_ct0, output_data_ptr_unsigned_char_ct1, numel);\n\n        cgh.parallel_for<Unpack2BitKrn>(\n            sycl::nd_range<3>(\n\t\tsycl::range<3>(1, 1, blocks) * sycl::range<3>(1, 1, BLOCK_SIZE),\n                sycl::range<3>(1, 1, BLOCK_SIZE)),\n\t    krn2);\n    };\n    queue.submit(krn);\n    return output;\n}\n\ntorch::Tensor unpack(torch::Tensor &t, int bits) {\n    TORCH_CHECK(t.scalar_type() == torch::kUInt8, \"Unsupported data type: \", t.scalar_type());\n    TORCH_CHECK(t.device().is_xpu(), \"t must be an XPU tensor.\");\n    TORCH_CHECK(t.is_contiguous(), \"t must be contiguous.\");\n    switch(bits) {\n      case 4:\n        return unpack_4bit(t);\n      case 2:\n        return unpack_2bit(t);\n      default:\n        throw std::invalid_argument(\"Can only unpack 2-bit or 4-bit tensors.\");\n    }\n}\n"
  },
  {
    "path": "optimum/quanto/library/qbytes_mm.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom packaging import version\n\n\n__all__ = []\n\n\ntorch.library.define(\"quanto::qbytes_mm\", \"(Tensor A, Tensor B, Tensor scales) -> Tensor\")\n\n\ndef qbytes_mm(activations: torch.Tensor, weights: torch.Tensor, output_scales: torch.Tensor) -> torch.Tensor:\n    activations = activations.to(output_scales.dtype)\n    if weights.dtype.is_floating_point:\n        # Float8 requires an explicit promotion\n        weights = weights.to(output_scales.dtype)\n    # Apply the scale to the weights before the matrix multiplication to put them back\n    # into their initial numerical range and avoid overflows\n    weights = output_scales * weights\n    return torch.matmul(activations, weights.t())\n\n\ndef qbytes_int_mm(activations: torch.Tensor, weights: torch.Tensor, output_scales: torch.Tensor) -> torch.Tensor:\n    in_features = activations.shape[-1]\n    out_features = weights.shape[0]\n    # torch._int_mm works on transposed weights, i.e (in_features, out_features)\n    weights = weights.t()\n    if activations.ndim == 2:\n        out_data = torch._int_mm(activations, weights)\n    else:\n        output_shape = activations.shape[:-1] + (out_features,)\n        out_data = torch._int_mm(activations.reshape(-1, in_features), weights)\n        out_data = out_data.reshape(output_shape)\n    # We must evaluate the output as 
float32 because the multiplication\n    # of the int32 data by the scales might overflow\n    fp32_output = out_data.to(torch.float32) * output_scales.t()\n    return fp32_output.to(output_scales.dtype)\n\n\ndef qbytes_int8pack_mm(activations: torch.Tensor, weights: torch.Tensor, output_scales: torch.Tensor) -> torch.Tensor:\n    # torch._weight_int8pack_mm expects a vector of scales\n    output_scales = output_scales.flatten()\n    if activations.ndim == 2:\n        return torch._weight_int8pack_mm(activations, weights, output_scales)\n    else:\n        in_features = activations.shape[-1]\n        out_features = weights.shape[0]\n        output_shape = activations.shape[:-1] + (out_features,)\n        out_data = torch._weight_int8pack_mm(activations.reshape(-1, in_features), weights, output_scales)\n        return out_data.reshape(output_shape)\n\n\n@torch.library.impl(\"quanto::qbytes_mm\", \"default\")\ndef qbytes_mm_impl_default(\n    activations: torch.Tensor, weights: torch.Tensor, output_scales: torch.Tensor\n) -> torch.Tensor:\n    return qbytes_mm(activations, weights, output_scales)\n\n\n@torch.library.impl(\"quanto::qbytes_mm\", \"CUDA\")\ndef qbytes_mm_impl_cuda(activations: torch.Tensor, weights: torch.Tensor, output_scales: torch.Tensor) -> torch.Tensor:\n    assert activations.ndim in (2, 3)\n    in_features = activations.shape[-1]\n    tokens = activations.shape[0] if activations.ndim == 2 else activations.shape[0] * activations.shape[1]\n    out_features = weights.shape[0]\n    if (\n        activations.dtype == torch.int8\n        and weights.dtype == torch.int8\n        and tokens > 16\n        and tokens % 8 == 0\n        and in_features % 8 == 0\n        and out_features % 8 == 0\n    ):\n        return qbytes_int_mm(activations, weights, output_scales)\n    return qbytes_mm(activations, weights, output_scales)\n\n\n@torch.library.impl(\"quanto::qbytes_mm\", \"CPU\")\ndef qbytes_mm_impl_cpu(activations: torch.Tensor, weights: torch.Tensor, 
output_scales: torch.Tensor) -> torch.Tensor:\n    if (\n        # FIXME: accuracy issues with 2.4.x\n        version.parse(torch.__version__).release >= version.parse(\"2.6.0\").release\n        and activations.dtype == torch.int8\n        and weights.dtype == torch.int8\n    ):\n        return qbytes_int_mm(activations, weights, output_scales)\n    in_features = activations.shape[-1]\n    if activations.dtype == torch.bfloat16 and weights.dtype == torch.int8 and in_features % 4 == 0:\n        if type(activations) is not torch.Tensor:\n            activations = activations.dequantize()\n        return qbytes_int8pack_mm(activations, weights, output_scales)\n    return qbytes_mm(activations, weights, output_scales)\n\n\n@torch.library.impl(\"quanto::qbytes_mm\", \"MPS\")\ndef qbytes_mm_impl_mps(activations: torch.Tensor, weights: torch.Tensor, output_scales: torch.Tensor) -> torch.Tensor:\n    in_features = activations.shape[-1]\n    out_features = weights.shape[0]\n    if (\n        version.parse(torch.__version__).release >= version.parse(\"2.4.0\").release\n        and activations.dtype == torch.bfloat16\n        and weights.dtype == torch.int8\n        and in_features % 32 == 0\n        and out_features % 32 == 0\n    ):\n        if type(activations) is not torch.Tensor:\n            activations = activations.dequantize()\n        return qbytes_int8pack_mm(activations, weights, output_scales)\n    return qbytes_mm(activations, weights, output_scales)\n"
  },
  {
    "path": "optimum/quanto/library/quantize.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Union\n\nimport torch\n\nfrom ..tensor import dtype_info, group\n\n\ntorch.library.define(\n    \"quanto::quantize_symmetric\", \"(Tensor base, ScalarType dtype, int? axis, Tensor scale) -> Tensor\"\n)\n\n\n@torch.library.impl(\"quanto::quantize_symmetric\", \"default\")\ndef quantize_symmetric(\n    base: torch.Tensor, dtype: torch.dtype, axis: Union[int, None], scale: torch.Tensor\n) -> torch.Tensor:\n    # Sanity checks\n    if axis is None:\n        if scale.ndim > 0:\n            raise ValueError(\"Scale must be a scalar when quantizing per-tensor\")\n    else:\n        if base.ndim == 1:\n            raise ValueError(\"1D Tensors cannot be quantized per-axis\")\n        if axis == base.ndim - 1:\n            # Align on the general convention to index the last dimension\n            axis = -1\n        if axis not in (0, -1):\n            raise ValueError(\"Quantization is only supported along the first or last axis.\")\n        if base.shape[axis] == 1:\n            raise ValueError(f\"Cannot quantize Tensor of shape {base.shape} along axis {axis} of size 1\")\n        if torch.squeeze(scale).ndim > 1:\n            raise ValueError(\"Quantizing along multiple axis is not supported\")\n        if scale.ndim != base.ndim:\n            raise ValueError(\n                \"When quantizing per-axis, the scale must be 
broadcastable to the base (Tip: try to add missing dims of length one).\"\n            )\n    data = base / scale\n    if not dtype.is_floating_point:\n        data = torch.round(data)\n    info = dtype_info(dtype)\n    return torch.clamp(data, min=info.min, max=info.max).to(dtype)\n\n\ntorch.library.define(\n    \"quanto::quantize_affine\",\n    \"(Tensor base, int bits, int axis, int? group_size, Tensor scale, Tensor shift) -> Tensor\",\n)\n\n\n@torch.library.impl(\"quanto::quantize_affine\", \"default\")\ndef quantize_affine(\n    base: torch.Tensor, bits: int, axis: int, group_size: Union[int, None], scale: torch.Tensor, shift: torch.Tensor\n) -> torch.Tensor:\n    if axis not in (0, -1):\n        raise ValueError(\"axis parameter must be 0 (first axis) or -1 (last axis)\")\n    if group_size is not None:\n        base = group(base, axis=axis, group_size=group_size)\n    if shift.dtype.is_floating_point:\n        data = torch.round((base + shift) / scale)\n    else:\n        # Shift is an integer representing zero (i.e. zero-point)\n        data = torch.round(base / scale) + shift\n\n    return torch.clamp(data, min=0, max=2**bits - 1).to(torch.uint8)\n"
  },
  {
    "path": "optimum/quanto/library/unpack.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\n\n\ntorch.library.define(\"quanto::unpack\", \"(Tensor self, int bits) -> Tensor\")\n\n\n@torch.library.impl(\"quanto::unpack\", \"default\")\ndef unpack(packed: torch.Tensor, bits: int) -> torch.Tensor:\n    \"\"\"\n    Un-Pack int4 / int2 weights (packed in a uint8) into a torch.uint8 tensor\n    What un-packing means? Assume we have packed 4 2-bit values in 8-bit\n    (because torch does not have native support for 2-bit datatypes)\n\n    > 1110 0100\n\n    Unpacking them means retrieving the original 4 2-bit values:\n\n    > 0000 0011 | 0000 0010 | 0000 0001 | 0000 0000\n\n    Args:\n        packed (`torch.Tensor`):\n            The packed tensor in `torch.uint8` precision\n        bits (`int`):\n            The number of bits per encoded value. Can be 2 or 4.\n    \"\"\"\n    unpacked = []\n    values_per_item = 8 // bits\n\n    def rshift(t: torch.Tensor, bits: int):\n        if t.device.type == \"mps\":\n            # rshift is not supported on MPS device\n            return t // (2**bits)\n        return t >> bits\n\n    # Unpack each set of values independently\n    for i in range(values_per_item):\n        mask = 2 ** (bits * (i + 1)) - 1\n        unpacked.append(rshift(packed & mask, bits * i))\n    # Return the concatenated unpacked tensors\n    return torch.cat(unpacked).to(torch.uint8)\n"
  },
  {
    "path": "optimum/quanto/models/__init__.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport importlib\nimport os\nfrom collections.abc import Mapping\nfrom typing import Any, Dict, List, Optional, Union\n\n\ndef is_transformers_available() -> bool:\n    return importlib.util.find_spec(\"transformers\") is not None\n\n\ndef is_diffusers_available() -> bool:\n    return importlib.util.find_spec(\"diffusers\") is not None\n\n\nif is_transformers_available():\n    from .transformers_models import *\n\n\nif is_diffusers_available():\n    from .diffusers_models import *\n"
  },
  {
    "path": "optimum/quanto/models/diffusers_models.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport json\nimport os\nfrom pathlib import Path\nfrom typing import Any, List, Optional, Union\n\nfrom huggingface_hub import ModelHubMixin, snapshot_download\n\nfrom ..quantize import Optimizer, freeze, qtype, quantization_map, quantize, requantize\nfrom . import is_diffusers_available\n\n\n__all__ = [\"QuantizedDiffusersModel\", \"QuantizedPixArtTransformer2DModel\"]\n\nif not is_diffusers_available():\n    raise ImportError(f\"{__all__} require the diffusers library\")\n\nfrom diffusers import PixArtTransformer2DModel\nfrom diffusers.models.model_loading_utils import load_state_dict\nfrom diffusers.models.modeling_utils import ModelMixin\nfrom diffusers.utils import (\n    CONFIG_NAME,\n    SAFE_WEIGHTS_INDEX_NAME,\n    SAFETENSORS_WEIGHTS_NAME,\n    _get_checkpoint_shard_files,\n    is_accelerate_available,\n)\n\nfrom .shared_dict import ShardedStateDict\n\n\nclass QuantizedDiffusersModel(ModelHubMixin):\n    BASE_NAME = \"quanto\"\n    base_class = None\n\n    def __init__(self, model: ModelMixin):\n        if not isinstance(model, ModelMixin) or len(quantization_map(model)) == 0:\n            raise ValueError(\"The source model must be a quantized diffusers model.\")\n        self._wrapped = model\n\n    def __getattr__(self, name: str) -> Any:\n        \"\"\"If an attribute is not found in this class, look in the wrapped 
module.\"\"\"\n        try:\n            return super().__getattr__(name)\n        except AttributeError:\n            wrapped = self.__dict__[\"_wrapped\"]\n            return getattr(wrapped, name)\n\n    def forward(self, *args, **kwargs):\n        return self._wrapped.forward(*args, **kwargs)\n\n    def __call__(self, *args, **kwargs):\n        return self._wrapped.forward(*args, **kwargs)\n\n    @staticmethod\n    def _qmap_name():\n        return f\"{QuantizedDiffusersModel.BASE_NAME}_qmap.json\"\n\n    @classmethod\n    def quantize(\n        cls,\n        model: ModelMixin,\n        weights: Optional[Union[str, qtype]] = None,\n        activations: Optional[Union[str, qtype]] = None,\n        optimizer: Optional[Optimizer] = None,\n        include: Optional[Union[str, List[str]]] = None,\n        exclude: Optional[Union[str, List[str]]] = None,\n    ):\n        \"\"\"Quantize the specified model\n\n        By default, each layer of the model will be quantized if it is quantizable.\n\n        If include patterns are specified, the layer name must match one of them.\n\n        If exclude patterns are specified, the layer must not match one of them.\n\n        Include or exclude patterns are Unix shell-style wildcards which are NOT regular expressions. 
See\n        https://docs.python.org/3/library/fnmatch.html for more details.\n\n        Note: quantization happens in-place and modifies the original model.\n\n        Note that the resulting quantized model will be frozen: if you wish to do\n        quantization-aware training then you should use `optimum.quanto.quantize` instead,\n        and call `optimum.quanto.freeze` only after the training.\n\n        Args:\n            model (`PreTrainedModel`): the model to quantize.\n            weights (`Optional[Union[str, qtype]]`): the qtype for weights quantization.\n            activations (`Optional[Union[str, qtype]]`): the qtype for activations quantization.\n            include (`Optional[Union[str, List[str]]]`):\n                Patterns constituting the allowlist. If provided, layer names must match at\n                least one pattern from the allowlist.\n            exclude (`Optional[Union[str, List[str]]]`):\n                Patterns constituting the denylist. If provided, layer names must not match\n                any patterns from the denylist.\n        \"\"\"\n        if not isinstance(model, ModelMixin):\n            raise ValueError(\"The source model must be a diffusers model.\")\n\n        quantize(\n            model, weights=weights, activations=activations, optimizer=optimizer, include=include, exclude=exclude\n        )\n        freeze(model)\n        return cls(model)\n\n    @classmethod\n    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs):\n        if cls.base_class is None:\n            raise ValueError(\"The `base_class` attribute needs to be configured.\")\n\n        if not is_accelerate_available():\n            raise ValueError(\"Reloading a quantized diffusers model requires the accelerate library.\")\n        from accelerate import init_empty_weights\n\n        if os.path.isdir(pretrained_model_name_or_path):\n            working_dir = pretrained_model_name_or_path\n        else:\n          
  working_dir = snapshot_download(pretrained_model_name_or_path, **kwargs)\n\n        # Look for a quantization map\n        qmap_path = os.path.join(working_dir, cls._qmap_name())\n        if not os.path.exists(qmap_path):\n            raise ValueError(\n                f\"No quantization map found in {pretrained_model_name_or_path}: is this a quantized model ?\"\n            )\n\n        # Look for original model config file.\n        model_config_path = os.path.join(working_dir, CONFIG_NAME)\n        if not os.path.exists(model_config_path):\n            raise ValueError(f\"{CONFIG_NAME} not found in {pretrained_model_name_or_path}.\")\n\n        with open(qmap_path, \"r\", encoding=\"utf-8\") as f:\n            qmap = json.load(f)\n\n        with open(model_config_path, \"r\", encoding=\"utf-8\") as f:\n            original_model_cls_name = json.load(f)[\"_class_name\"]\n        configured_cls_name = cls.base_class.__name__\n        if configured_cls_name != original_model_cls_name:\n            raise ValueError(\n                f\"Configured base class ({configured_cls_name}) differs from what was derived from the provided configuration ({original_model_cls_name}).\"\n            )\n\n        # Create an empty model\n        config = cls.base_class.load_config(pretrained_model_name_or_path, **kwargs)\n        with init_empty_weights():\n            model = cls.base_class.from_config(config)\n\n        # Look for the index of a sharded checkpoint\n        checkpoint_file = os.path.join(working_dir, SAFE_WEIGHTS_INDEX_NAME)\n        if os.path.exists(checkpoint_file):\n            # Convert the checkpoint path to a list of shards\n            _, sharded_metadata = _get_checkpoint_shard_files(working_dir, checkpoint_file)\n            # Create a mapping for the sharded safetensor files\n            state_dict = ShardedStateDict(working_dir, sharded_metadata[\"weight_map\"])\n        else:\n            # Look for a single checkpoint file\n            
checkpoint_file = os.path.join(working_dir, SAFETENSORS_WEIGHTS_NAME)\n            if not os.path.exists(checkpoint_file):\n                raise ValueError(f\"No safetensor weights found in {pretrained_model_name_or_path}.\")\n            # Get state_dict from model checkpoint\n            state_dict = load_state_dict(checkpoint_file)\n\n        # Requantize and load quantized weights from state_dict\n        requantize(model, state_dict=state_dict, quantization_map=qmap)\n        model.eval()\n        return cls(model)\n\n    def _save_pretrained(self, save_directory: Path) -> None:\n        self._wrapped.save_pretrained(save_directory)\n        # Save quantization map to be able to reload the model\n        qmap_name = os.path.join(save_directory, self._qmap_name())\n        qmap = quantization_map(self._wrapped)\n        with open(qmap_name, \"w\", encoding=\"utf8\") as f:\n            json.dump(qmap, f, indent=4)\n\n\nclass QuantizedPixArtTransformer2DModel(QuantizedDiffusersModel):\n    base_class = PixArtTransformer2DModel\n"
  },
  {
    "path": "optimum/quanto/models/shared_dict.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nfrom collections.abc import Mapping\nfrom typing import Any, Dict\n\nfrom safetensors import safe_open\n\n\nclass ShardedStateDict(Mapping):\n    \"\"\"A pytorch state_dict stored in multiple safetensors files\n\n    This class implements the `collections.abc.Mapping` interface.\n    It can be passed to `torch.nn.Module.load_state_dict()` to recursively\n    load the module tensors.\n    \"\"\"\n\n    def __init__(self, base_dir: str, tensor_index: Dict[str, str]):\n        self._base_dir = base_dir\n        self._index = tensor_index\n        self._handles = {}\n\n    def __iter__(self):\n        yield from self._index\n\n    def __len__(self):\n        return self._index.__len__()\n\n    def __getitem__(self, key: Any) -> Any:\n        filename = self._index.__getitem__(key)\n        if filename not in self._handles:\n            f = safe_open(os.path.join(self._base_dir, filename), framework=\"pytorch\")\n            self._handles[filename] = f\n        f = self._handles[filename]\n        return f.get_tensor(key)\n\n    def __contains__(self, key: object) -> bool:\n        return self._index.__contains__(key)\n\n    def keys(self):\n        return self._index.keys()\n"
  },
  {
    "path": "optimum/quanto/models/transformers_models.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport os\nfrom pathlib import Path\nfrom typing import Any, List, Optional, Union\n\nfrom huggingface_hub import ModelHubMixin, snapshot_download\n\nfrom ..nn import QModuleMixin\nfrom ..quantize import Optimizer, freeze, qtype, quantization_map, quantize, requantize\nfrom . import is_transformers_available\nfrom .shared_dict import ShardedStateDict\n\n\n__all__ = [\"QuantizedTransformersModel\", \"QuantizedModelForCausalLM\"]\n\nif not is_transformers_available():\n    raise ImportError(f\"{__all__} require the transformers library\")\n\nfrom transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel\nfrom transformers.modeling_utils import get_checkpoint_shard_files, load_state_dict\nfrom transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, is_accelerate_available\n\n\nclass QuantizedTransformersModel(ModelHubMixin):\n    BASE_NAME = \"quanto\"\n    auto_class = None\n\n    def __init__(self, model: PreTrainedModel):\n        if not isinstance(model, PreTrainedModel) or len(quantization_map(model)) == 0:\n            raise ValueError(\"The source model must be a quantized transformers model.\")\n        self._wrapped = model\n\n    def __getattr__(self, name: str) -> Any:\n        \"\"\"If an attribute is not found in this class, look in the wrapped module.\"\"\"\n        try:\n            
return super().__getattr__(name)\n        except AttributeError:\n            wrapped = self.__dict__[\"_wrapped\"]\n            return getattr(wrapped, name)\n\n    def forward(self, *args, **kwargs):\n        return self._wrapped.forward(*args, **kwargs)\n\n    def __call__(self, *args, **kwargs):\n        return self._wrapped.forward(*args, **kwargs)\n\n    def __repr__(self):\n        return self._wrapped.__repr__()\n\n    @staticmethod\n    def _qmap_name():\n        return f\"{QuantizedTransformersModel.BASE_NAME}_qmap.json\"\n\n    @classmethod\n    def quantize(\n        cls,\n        model: PreTrainedModel,\n        weights: Optional[Union[str, qtype]] = None,\n        activations: Optional[Union[str, qtype]] = None,\n        optimizer: Optional[Optimizer] = None,\n        include: Optional[Union[str, List[str]]] = None,\n        exclude: Optional[Union[str, List[str]]] = None,\n    ):\n        \"\"\"Quantize the specified model\n\n        By default, each layer of the model will be quantized if is quantizable.\n\n        If include patterns are specified, the layer name must match one of them.\n\n        If exclude patterns are specified, the layer must not match one of them.\n\n        Include or exclude patterns are Unix shell-style wildcards which are NOT regular expressions. 
See\n        https://docs.python.org/3/library/fnmatch.html for more details.\n\n        Note: quantization happens in-place and modifies the original model.\n\n        Note that the resulting quantized model will be frozen: if you wish to do\n        quantization-aware training then you should use `optimum.quanto.quantize` instead,\n        and call `optimum.quanto.freeze` only after the training.\n\n        Args:\n            model (`PreTrainedModel`): the model to quantize.\n            weights (`Optional[Union[str, qtype]]`): the qtype for weights quantization.\n            activations (`Optional[Union[str, qtype]]`): the qtype for activations quantization.\n            include (`Optional[Union[str, List[str]]]`):\n                Patterns constituting the allowlist. If provided, layer names must match at\n                least one pattern from the allowlist.\n            exclude (`Optional[Union[str, List[str]]]`):\n                Patterns constituting the denylist. If provided, layer names must not match\n                any patterns from the denylist.\n        \"\"\"\n        if not isinstance(model, PreTrainedModel):\n            raise ValueError(\"The source model must be a transformers model.\")\n        quantize(\n            model, weights=weights, activations=activations, optimizer=optimizer, include=include, exclude=exclude\n        )\n        freeze(model)\n        return cls(model)\n\n    @classmethod\n    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs):\n        if cls.auto_class is None:\n            raise ValueError(\n                \"Quantized models cannot be reloaded using {cls}: use a specialized quantized class such as QuantizedModelForCausalLM instead.\"\n            )\n        if not is_accelerate_available():\n            raise ValueError(\"Reloading a quantized transformers model requires the accelerate library.\")\n        from accelerate import init_empty_weights\n\n        if 
os.path.isdir(pretrained_model_name_or_path):\n            working_dir = pretrained_model_name_or_path\n        else:\n            working_dir = snapshot_download(pretrained_model_name_or_path, **kwargs)\n\n        # Look for a quantization map\n        qmap_path = os.path.join(working_dir, cls._qmap_name())\n        if not os.path.exists(qmap_path):\n            raise ValueError(\n                f\"No quantization map found in {pretrained_model_name_or_path}: is this a quantized model ?\"\n            )\n        with open(qmap_path, \"r\", encoding=\"utf-8\") as f:\n            qmap = json.load(f)\n        # Create an empty model\n        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)\n        with init_empty_weights():\n            model = cls.auto_class.from_config(config)\n        # Look for the index of a sharded checkpoint\n        checkpoint_file = os.path.join(working_dir, SAFE_WEIGHTS_INDEX_NAME)\n        if os.path.exists(checkpoint_file):\n            # Convert the checkpoint path to a list of shards\n            checkpoint_file, sharded_metadata = get_checkpoint_shard_files(working_dir, checkpoint_file)\n            # Create a mapping for the sharded safetensor files\n            state_dict = ShardedStateDict(working_dir, sharded_metadata[\"weight_map\"])\n        else:\n            # Look for a single checkpoint file\n            checkpoint_file = os.path.join(working_dir, SAFE_WEIGHTS_NAME)\n            if not os.path.exists(checkpoint_file):\n                raise ValueError(f\"No safetensor weights found in {pretrained_model_name_or_path}.\")\n            # Get state_dict from model checkpoint\n            state_dict = load_state_dict(checkpoint_file)\n        # Requantize and load quantized weights from state_dict\n        requantize(model, state_dict=state_dict, quantization_map=qmap)\n        if getattr(model.config, \"tie_word_embeddings\", True):\n            # Tie output weight embeddings to input weight 
embeddings\n            # Note that if they were quantized they would NOT be tied\n            model.tie_weights()\n        # Set model in evaluation mode as it is done in transformers\n        model.eval()\n        return cls(model)\n\n    def _save_pretrained(self, save_directory: Path) -> None:\n        model = self._wrapped\n        if getattr(model.config, \"tie_word_embeddings\", True):\n            # The original model had tied embedding inputs and outputs\n            if isinstance(model.get_input_embeddings(), QModuleMixin) or isinstance(\n                model.get_output_embeddings(), QModuleMixin\n            ):\n                # At least one of the two is quantized, so they are not tied anymore\n                model.config.tie_word_embeddings = False\n        self._wrapped.save_pretrained(save_directory, safe_serialization=True)\n        # Save quantization map to be able to reload the model\n        qmap_name = os.path.join(save_directory, self._qmap_name())\n        qmap = quantization_map(self._wrapped)\n        with open(qmap_name, \"w\", encoding=\"utf8\") as f:\n            json.dump(qmap, f, indent=4)\n\n\nclass QuantizedModelForCausalLM(QuantizedTransformersModel):\n    auto_class = AutoModelForCausalLM\n"
  },
  {
    "path": "optimum/quanto/nn/__init__.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .qconv2d import *\nfrom .qlayernorm import *\nfrom .qlinear import *\nfrom .qmodule import *\n"
  },
  {
    "path": "optimum/quanto/nn/qconv2d.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Optional\n\nimport torch\n\nfrom ..tensor import Optimizer, qtype\nfrom .qmodule import QModuleMixin, register_qmodule\n\n\n__all__ = [\"QConv2d\"]\n\n\n@register_qmodule(torch.nn.Conv2d)\nclass QConv2d(QModuleMixin, torch.nn.Conv2d):\n    @classmethod\n    def qcreate(\n        cls,\n        module,\n        weights: qtype,\n        activations: Optional[qtype] = None,\n        optimizer: Optional[Optimizer] = None,\n        device: Optional[torch.device] = None,\n    ):\n        return cls(\n            in_channels=module.in_channels,\n            out_channels=module.out_channels,\n            kernel_size=module.kernel_size,\n            stride=module.stride,\n            padding=module.padding,\n            dilation=module.dilation,\n            groups=module.groups,\n            bias=module.bias is not None,\n            padding_mode=module.padding_mode,\n            dtype=module.weight.dtype,\n            device=device,\n            weights=weights,\n            activations=activations,\n            optimizer=optimizer,\n        )\n\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        return self._conv_forward(input, self.qweight, self.bias)\n"
  },
  {
    "path": "optimum/quanto/nn/qlayernorm.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Optional\n\nimport torch\n\nfrom ..tensor import Optimizer, qtype\nfrom .qmodule import QModuleMixin, register_qmodule\n\n\n__all__ = [\"QLayerNorm\"]\n\n\n@register_qmodule(torch.nn.LayerNorm)\nclass QLayerNorm(QModuleMixin, torch.nn.LayerNorm):\n    @classmethod\n    def qcreate(\n        cls,\n        module,\n        weights: Optional[qtype] = None,\n        activations: Optional[qtype] = None,\n        optimizer: Optional[Optimizer] = None,\n        device: Optional[torch.device] = None,\n    ):\n        if activations is None:\n            return None\n        dtype = None if module.weight is None else module.weight.dtype\n        return cls(\n            module.normalized_shape,\n            module.eps,\n            module.elementwise_affine,\n            module.bias is not None,\n            dtype=dtype,\n            device=device,\n            weights=None,  # We never quantize QLayerNorm weights\n            activations=activations,\n            optimizer=None,  # We never quantize QLayerNorm weights\n        )\n\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        return torch.nn.functional.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)\n"
  },
  {
    "path": "optimum/quanto/nn/qlinear.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Optional\n\nimport torch\n\nfrom ..tensor import Optimizer, qtype\nfrom .qmodule import QModuleMixin, register_qmodule\n\n\n__all__ = [\"QLinear\"]\n\n\n@register_qmodule(torch.nn.Linear)\nclass QLinear(QModuleMixin, torch.nn.Linear):\n    @classmethod\n    def qcreate(\n        cls,\n        module,\n        weights: qtype,\n        activations: Optional[qtype] = None,\n        optimizer: Optional[Optimizer] = None,\n        device: Optional[torch.device] = None,\n    ):\n        return cls(\n            module.in_features,\n            module.out_features,\n            module.bias is not None,\n            dtype=module.weight.dtype,\n            device=device,\n            weights=weights,\n            activations=activations,\n            optimizer=optimizer,\n            quantize_input=True,\n        )\n\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        return torch.nn.functional.linear(input, self.qweight, bias=self.bias)\n"
  },
  {
    "path": "optimum/quanto/nn/qmodule.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom abc import ABC\nfrom typing import Optional, Union\n\nimport torch\n\nfrom ..tensor import (\n    AbsmaxOptimizer,\n    ActivationQBytesTensor,\n    MaxOptimizer,\n    Optimizer,\n    QTensor,\n    SymmetricOptimizer,\n    WeightQBitsTensor,\n    WeightQBytesTensor,\n    qint2,\n    qint4,\n    qtype,\n    qtypes,\n    quantize_activation,\n    quantize_weight,\n)\n\n\n__all__ = [\"QModuleMixin\", \"register_qmodule\", \"quantize_module\"]\n\n\n_QMODULE_TABLE = {}\n\n\ndef register_qmodule(module_cls):\n    \"\"\"\n    Used for registering a new quantized module.\n\n    The QModule must implement two abstract methods:\n\n    - qcreate: class method to instantiate a new QModule from an nn.Module, without copying its weights,\n    - forward: instance method for quantized inference.\n\n    The code to register a new module looks like:\n\n    ```\n    @register_qmodule(<base torch.nn.Module>)\n    class MyQModule(QModuleMixin, <base torch.nn.Module>):\n        <implementation>\n\n        @classmethod\n        def qcreate(cls,\n                    module: torch.nn.Module,\n                    weights: Optional[qtype],\n                    activations: Optional[qtype] = None,\n                    optimizer: Optional[Optimizer] = None):\n            ...\n\n        def forward(self, input: torch.Tensor) -> torch.Tensor:\n            ...\n  
  ```\n\n    \"\"\"\n\n    def wrapper(cls):\n        _QMODULE_TABLE[module_cls] = cls\n        return cls\n\n    return wrapper\n\n\ndef quantize_module(\n    module,\n    weights: Optional[Union[qtype, str]] = None,\n    activations: Optional[Union[qtype, str]] = None,\n    optimizer: Optional[Optimizer] = None,\n):\n    for cls in _QMODULE_TABLE:\n        if isinstance(module, cls):\n            qcls = _QMODULE_TABLE[cls]\n            return qcls.from_module(module, weights=weights, activations=activations, optimizer=optimizer)\n    return None\n\n\nclass QModuleMixin(ABC):\n    def __init__(\n        self,\n        *args,\n        weights: Optional[Union[qtype, str]] = None,\n        activations: Optional[Union[qtype, str]] = None,\n        optimizer: Optional[Optimizer] = None,\n        quantize_input: Optional[bool] = False,\n        device: Optional[torch.device] = None,\n        **kwargs,\n    ):\n        # The tests below are meant to help people writing their own quantized Module class\n        mro = self.__class__.__mro__\n        if torch.nn.Module not in mro:\n            raise TypeError(\"Quantized modules must inherit from a torch.nn.Module class\")\n        if mro.index(__class__) > mro.index(torch.nn.Module):\n            raise TypeError(\n                \"QModuleMixin must be placed before any torch.nn.Module class in quantized module inheritance.\"\n            )\n        # This will setup the torch.nn.Module\n        super().__init__(*args, device=device, **kwargs)\n        if weights is not None and not isinstance(weights, qtype):\n            weights = qtypes[weights]\n        if activations is not None and not isinstance(activations, qtype):\n            activations = qtypes[activations]\n        self.weight_qtype = weights\n        self.weight_group_size = None\n        if self.weight_qtype in (qint2, qint4):\n            out_features = self.weight.shape[0]\n            in_features = self.weight.numel() // out_features\n            
group_size = 128\n            if in_features > group_size:\n                while in_features % group_size != 0 and group_size > 32:\n                    group_size -= 32\n                if in_features % group_size == 0:\n                    self.weight_group_size = group_size\n        self.activation_qtype = activations\n        self._quantize_hooks = {}\n        if activations is not None:\n            if quantize_input:\n                self._quantize_hooks[\"input\"] = self.register_forward_pre_hook(self.quantize_input)\n            self._quantize_hooks[\"output\"] = self.register_forward_hook(self.quantize_output)\n        if optimizer is None and self.weight_qtype is not None:\n            optimizer = AbsmaxOptimizer() if self.weight_qtype.bits == 8 else MaxOptimizer()\n        self.optimizer = optimizer\n        scale_dtype = torch.float32 if self.weight is None else self.weight.dtype\n        self.register_buffer(\"input_scale\", torch.ones((), dtype=scale_dtype, device=device))\n        self.register_buffer(\"output_scale\", torch.ones((), dtype=scale_dtype, device=device))\n\n    def disable_output_quantization(self):\n        if \"output\" in self._quantize_hooks:\n            self._quantize_hooks[\"output\"].remove()\n\n    def _save_to_state_dict(self, destination, prefix, keep_vars):\n        if self.weight_qtype is None or not self.frozen:\n            # Save standard weight Tensor\n            destination[prefix + \"weight\"] = (\n                self.weight if (self.weight is None or keep_vars) else self.weight.detach()\n            )\n        else:\n            # Save QTensor using dedicated method\n            self.weight.save_to_state_dict(destination, prefix + \"weight.\", keep_vars)\n        if self.bias is not None:\n            destination[prefix + \"bias\"] = self.bias if keep_vars else self.bias.detach()\n        destination[prefix + \"input_scale\"] = self.input_scale if keep_vars else self.input_scale.detach()\n        
destination[prefix + \"output_scale\"] = self.output_scale if keep_vars else self.output_scale.detach()\n\n    def _load_from_state_dict(\n        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs\n    ):\n        weight_name = prefix + \"weight\"\n        if self.weight_qtype is not None and weight_name not in state_dict:\n            # The weight Tensor is not present because it is a flattened QTensor\n            weight_prefix = weight_name + \".\"\n            # note: deserialized_weight can be None if a key is missing in the state_dict\n            if self.weight_qtype.bits == 8:\n                deserialized_weight = WeightQBytesTensor.load_from_state_dict(\n                    state_dict,\n                    weight_prefix,\n                    qtype=self.weight_qtype,\n                    axis=0,\n                    size=self.weight.size(),\n                    stride=self.weight.stride(),\n                    activation_qtype=self.activation_qtype,\n                    missing_keys=missing_keys,\n                )\n            else:\n                deserialized_weight = WeightQBitsTensor.load_from_state_dict(\n                    state_dict,\n                    weight_prefix,\n                    qtype=self.weight_qtype,\n                    axis=0,\n                    group_size=self.weight_group_size,\n                    size=self.weight.size(),\n                    stride=self.weight.stride(),\n                    missing_keys=missing_keys,\n                )\n            if deserialized_weight is not None:\n                deserialized_weight = deserialized_weight.optimize()\n\n            assign_to_params_buffers = local_metadata.get(\"assign_to_params_buffers\", False)\n            if assign_to_params_buffers and (deserialized_weight is not None):\n                self.weight = torch.nn.Parameter(deserialized_weight)\n            elif deserialized_weight is not None:\n                if 
type(self.weight.data) is not type(deserialized_weight):\n                    # Reloading frozen weights into unfrozen module: move to the correct device and force assignment\n                    self.weight = torch.nn.Parameter(deserialized_weight.to(self.weight.device))\n                else:\n                    # FIXME: here we should copy frozen weights into frozen module, but this leads to grad error\n                    self.weight = torch.nn.Parameter(deserialized_weight.to(self.weight.device))\n\n        super()._load_from_state_dict(\n            state_dict, prefix, local_metadata, False, missing_keys, unexpected_keys, error_msgs\n        )\n\n    @classmethod\n    def from_module(\n        cls,\n        module: torch.nn.Module,\n        weights: Optional[qtype] = None,\n        activations: Optional[qtype] = None,\n        optimizer: Optional[Optimizer] = None,\n    ):\n        # Create the quantized module on the meta device to prevent weights intialization\n        qmodule = cls.qcreate(module, weights, activations, optimizer, device=\"meta\")\n        if qmodule is None:\n            return None\n        # Move the quantized module to the target device, but with empty weights\n        device = torch.device(\"cpu\") if module.weight is None else module.weight.device\n        qmodule = qmodule.to_empty(device=device)\n        # Set scales that were initialized to empty values\n        qmodule.input_scale = torch.ones_like(qmodule.input_scale)\n        qmodule.output_scale = torch.ones_like(qmodule.output_scale)\n        with torch.no_grad():\n            qmodule.weight = module.weight\n            if module.bias is not None:\n                qmodule.bias = module.bias\n\n        return qmodule.to(device)\n\n    @classmethod\n    def qcreate(\n        cls,\n        module: torch.nn.Module,\n        weights: Optional[qtype],\n        activations: Optional[qtype] = None,\n        optimizer: Optional[Optimizer] = None,\n        device: 
Optional[torch.device] = None,\n    ):\n        raise NotImplementedError\n\n    @property\n    def qweight(self):\n        \"\"\"Return the module quantized weight\n\n        When the module is frozen or does not quantize its weight parameter, it simply\n        returns the weight.\n        When the module is not frozen, this property is required to add the dynamic quantization\n        of the weight parameter to the graph and allow gradients to be propagated to the\n        underlying weight float values.\n        \"\"\"\n        if self.weight_qtype is None:\n            # QModule that does not quantize its weights\n            return None\n        if isinstance(self.weight, QTensor):\n            # Frozen QModule\n            return self.weight\n        # Quantize dynamically the weights per-axis\n        if isinstance(self.optimizer, SymmetricOptimizer):\n            scale = self.optimizer(self.weight, qtype=self.weight_qtype, axis=0)\n            shift = None\n        else:\n            optimizer_kwargs = {\"qtype\": self.weight_qtype, \"axis\": 0, \"group_size\": self.weight_group_size}\n            if self.weight.device.type == \"xpu\":\n                optimizer_kwargs.update({\"zeropoint\": True})\n            scale, shift = self.optimizer(self.weight, **optimizer_kwargs)\n\n        return quantize_weight(\n            self.weight,\n            qtype=self.weight_qtype,\n            axis=0,\n            scale=scale,\n            shift=shift,\n            group_size=self.weight_group_size,\n            activation_qtype=self.activation_qtype,\n        )\n\n    def qforward(self, input: torch.Tensor) -> torch.Tensor:\n        raise NotImplementedError\n\n    def quantize_input(self, module: torch.nn.Module, input: torch.Tensor) -> torch.Tensor:\n        input = input[0]\n        if isinstance(input, ActivationQBytesTensor):\n            if input.qtype != self.activation_qtype:\n                raise ValueError(\n                    \"Models with heterogeneous 
quantized activations are not supported:\"\n                    f\" expected {self.activation_qtype.name} input but got {input.qtype.name} instead.\"\n                )\n        else:\n            input = quantize_activation(input, qtype=self.activation_qtype, scale=self.input_scale)\n        return input\n\n    def quantize_output(\n        self,\n        module: torch.nn.Module,\n        input: torch.Tensor,\n        output: torch.Tensor,\n    ) -> torch.Tensor:\n        return quantize_activation(output, qtype=self.activation_qtype, scale=self.output_scale)\n\n    def freeze(self):\n        qweight = self.qweight\n        if qweight is not None:\n            # Replace float weights by quantized weights\n            self.weight = torch.nn.Parameter(qweight)\n\n    @property\n    def frozen(self):\n        return isinstance(self.weight, QTensor)\n"
  },
  {
    "path": "optimum/quanto/quantize.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom fnmatch import fnmatch\nfrom typing import Any, Dict, List, Optional, Union\n\nimport torch\n\nfrom .nn import QModuleMixin, quantize_module\nfrom .tensor import Optimizer, qtype\n\n\n__all__ = [\"quantize\", \"freeze\", \"requantize\", \"quantization_map\"]\n\n\ndef set_module_by_name(parent_module, name, child_module):\n    module_names = name.split(\".\")\n    if len(module_names) == 1:\n        setattr(parent_module, name, child_module)\n    else:\n        parent_module_name = name[: name.rindex(\".\")]\n        parent_module = parent_module.get_submodule(parent_module_name)\n        setattr(parent_module, module_names[-1], child_module)\n\n\ndef _quantize_submodule(\n    model: torch.nn.Module,\n    name: str,\n    module: torch.nn.Module,\n    weights: Optional[Union[str, qtype]] = None,\n    activations: Optional[Union[str, qtype]] = None,\n    optimizer: Optional[Optimizer] = None,\n):\n    qmodule = quantize_module(module, weights=weights, activations=activations, optimizer=optimizer)\n    if qmodule is not None:\n        set_module_by_name(model, name, qmodule)\n        qmodule.name = name\n        for name, param in module.named_parameters():\n            # Save device memory by clearing parameters\n            setattr(module, name, None)\n            del param\n\n\ndef quantize(\n    model: torch.nn.Module,\n    
weights: Optional[Union[str, qtype]] = None,\n    activations: Optional[Union[str, qtype]] = None,\n    optimizer: Optional[Optimizer] = None,\n    include: Optional[Union[str, List[str]]] = None,\n    exclude: Optional[Union[str, List[str]]] = None,\n):\n    \"\"\"Quantize the specified model submodules\n\n    Recursively quantize the submodules of the specified parent model.\n\n    Only modules that have quantized counterparts will be quantized.\n\n    If include patterns are specified, the submodule name must match one of them.\n\n    If exclude patterns are specified, the submodule must not match one of them.\n\n    Include or exclude patterns are Unix shell-style wildcards which are NOT regular expressions. See\n    https://docs.python.org/3/library/fnmatch.html for more details.\n\n    Note: quantization happens in-place and modifies the original model and its descendants.\n\n    Args:\n        model (`torch.nn.Module`): the model whose submodules will be quantized.\n        weights (`Optional[Union[str, qtype]]`): the qtype for weights quantization.\n        activations (`Optional[Union[str, qtype]]`): the qtype for activations quantization.\n        include (`Optional[Union[str, List[str]]]`):\n            Patterns constituting the allowlist. If provided, module names must match at\n            least one pattern from the allowlist.\n        exclude (`Optional[Union[str, List[str]]]`):\n            Patterns constituting the denylist. 
If provided, module names must not match\n            any patterns from the denylist.\n    \"\"\"\n    if include is not None:\n        include = [include] if isinstance(include, str) else include\n    if exclude is not None:\n        exclude = [exclude] if isinstance(exclude, str) else exclude\n    for name, m in model.named_modules():\n        if include is not None and not any(fnmatch(name, pattern) for pattern in include):\n            continue\n        if exclude is not None and any(fnmatch(name, pattern) for pattern in exclude):\n            continue\n        _quantize_submodule(model, name, m, weights=weights, activations=activations, optimizer=optimizer)\n\n\ndef requantize(\n    model: torch.nn.Module,\n    state_dict: Dict[str, Any],\n    quantization_map: Dict[str, Dict[str, str]],\n    device: torch.device = None,\n):\n    if device is None:\n        device = next(model.parameters()).device\n        if device.type == \"meta\":\n            device = torch.device(\"cpu\")\n\n    # Quantize the model with parameters from the quantization map\n    for name, m in model.named_modules():\n        qconfig = quantization_map.get(name, None)\n        if qconfig is not None:\n            weights = qconfig[\"weights\"]\n            if weights == \"none\":\n                weights = None\n            activations = qconfig[\"activations\"]\n            if activations == \"none\":\n                activations = None\n            _quantize_submodule(model, name, m, weights=weights, activations=activations)\n\n    # Move model parameters and buffers to CPU before materializing quantized weights\n    for name, m in model.named_modules():\n\n        def move_tensor(t, device):\n            if t.device.type == \"meta\":\n                return torch.empty_like(t, device=device)\n            return t.to(device)\n\n        for name, param in m.named_parameters(recurse=False):\n            setattr(m, name, torch.nn.Parameter(move_tensor(param, \"cpu\")))\n        for name, 
param in m.named_buffers(recurse=False):\n            setattr(m, name, move_tensor(param, \"cpu\"))\n\n    # Move to target device\n    model.to(device)\n    # Load the quantized model weights\n    model.load_state_dict(state_dict, strict=False)\n\n\ndef freeze(model):\n    for name, m in model.named_modules():\n        if isinstance(m, QModuleMixin):\n            m.freeze()\n\n\ndef quantization_map(model: torch.nn.Module) -> Dict[str, Dict[str, str]]:\n    \"\"\"Returns the quantization map of a module\n\n    The quantization map is a dictionary of quantization parameters indexed\n    by the module submodule names (including prefix).\n\n    This is mainly used for serialization.\n\n    Args:\n        model (`torch.nn.Module`): the root module to map.\n\n    Returns:\n        a dictionary of quantization parameters indexed by layer names.\n    \"\"\"\n    config = {}\n    for name, m in model.named_modules():\n        if isinstance(m, QModuleMixin):\n            config[name] = {\n                \"weights\": \"none\" if m.weight_qtype is None else m.weight_qtype.name,\n                \"activations\": \"none\" if m.activation_qtype is None else m.activation_qtype.name,\n            }\n    return config\n"
  },
  {
    "path": "optimum/quanto/subpackage/__init__.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .commands import *\n"
  },
  {
    "path": "optimum/quanto/subpackage/commands/__init__.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .base import *\n"
  },
  {
    "path": "optimum/quanto/subpackage/commands/base.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom optimum.commands.base import BaseOptimumCLICommand, CommandInfo\nfrom optimum.commands.optimum_cli import optimum_cli_subcommand\n\nfrom .quantize import QuantizeCommand\n\n\n__all__ = [\"QuantoCommand\"]\n\n\n@optimum_cli_subcommand()\nclass QuantoCommand(BaseOptimumCLICommand):\n    COMMAND = CommandInfo(name=\"quanto\", help=\"Hugging Face models quantization tools\")\n    SUBCOMMANDS = (\n        CommandInfo(\n            name=\"quantize\",\n            help=\"Quantize Hugging Face models.\",\n            subcommand_class=QuantizeCommand,\n        ),\n    )\n"
  },
  {
    "path": "optimum/quanto/subpackage/commands/quantize.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Hugging Face models quantization command-line interface class.\"\"\"\n\nfrom typing import TYPE_CHECKING\n\nimport torch\nfrom optimum.commands.base import BaseOptimumCLICommand\nfrom optimum.exporters.tasks import TasksManager\n\nfrom ...models import QuantizedTransformersModel\n\n\nif TYPE_CHECKING:\n    from argparse import ArgumentParser\n\n\nSUPPORTED_LIBRARIES = [\"transformers\"]\n\n\ndef parse_quantize_args(parser: \"ArgumentParser\"):\n    required_group = parser.add_argument_group(\"Required arguments\")\n    required_group.add_argument(\n        \"output\",\n        type=str,\n        help=\"The path to save the quantized model.\",\n    )\n    required_group.add_argument(\n        \"-m\",\n        \"--model\",\n        type=str,\n        required=True,\n        help=\"Hugging Face Hub model id or path to a local model.\",\n    )\n    required_group.add_argument(\n        \"--weights\",\n        type=str,\n        default=\"int8\",\n        choices=[\"int2\", \"int4\", \"int8\", \"float8\"],\n        help=\"The Hugging Face library to use to load the model.\",\n    )\n\n    optional_group = parser.add_argument_group(\"Optional arguments\")\n    optional_group.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        help=\"The Hugging Face model revision.\",\n    )\n    
optional_group.add_argument(\n        \"--trust_remote_code\",\n        action=\"store_true\",\n        default=False,\n        help=\"Trust remote code when loading the model.\",\n    )\n    optional_group.add_argument(\n        \"--library\",\n        type=str,\n        default=None,\n        choices=SUPPORTED_LIBRARIES,\n        help=\"The Hugging Face library to use to load the model.\",\n    )\n    optional_group.add_argument(\n        \"--task\",\n        type=str,\n        default=None,\n        help=\"The model task (useful for models supporting multiple tasks).\",\n    )\n    optional_group.add_argument(\n        \"--torch_dtype\",\n        type=str,\n        default=\"auto\",\n        choices=[\"auto\", \"fp16\", \"bf16\"],\n        help=\"The torch dtype to use when loading the model weights.\",\n    )\n    optional_group.add_argument(\n        \"--device\",\n        type=str,\n        default=\"cpu\",\n        help=\"The device to use when loading the model.\",\n    )\n\n\nclass QuantizeCommand(BaseOptimumCLICommand):\n    @staticmethod\n    def parse_args(parser: \"ArgumentParser\"):\n        return parse_quantize_args(parser)\n\n    def run(self):\n        model_name_or_path = self.args.model\n        library_name = self.args.library\n        if library_name is None:\n            library_name = TasksManager.infer_library_from_model(model_name_or_path)\n        if library_name not in SUPPORTED_LIBRARIES:\n            raise ValueError(\n                f\"{library_name} models are not supported by this CLI, but can be quantized using the python API directly.\"\n            )\n        task = self.args.task\n        if task is None:\n            task = TasksManager.infer_task_from_model(model_name_or_path)\n        torch_dtype = self.args.torch_dtype\n        if torch_dtype != \"auto\":\n            torch_dtype = torch.float16 if self.args.torch_dtype == \"fp16\" else torch.bfloat16\n        model = TasksManager.get_model_from_task(\n            task,\n   
         model_name_or_path,\n            revision=self.args.revision,\n            trust_remote_code=self.args.trust_remote_code,\n            framework=\"pt\",\n            torch_dtype=torch_dtype,\n            device=torch.device(self.args.device),\n            library_name=library_name,\n            low_cpu_mem_usage=True,\n        )\n        weights = f\"q{self.args.weights}\"\n        qmodel = QuantizedTransformersModel.quantize(model, weights=weights)\n        qmodel.save_pretrained(self.args.output)\n"
  },
  {
    "path": "optimum/quanto/tensor/__init__.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .activations import *\nfrom .core import *\nfrom .grouped import *\nfrom .optimizers import *\nfrom .qbits import *\nfrom .qbytes import *\nfrom .qtensor import *\nfrom .qtype import *\nfrom .weights import *\n"
  },
  {
    "path": "optimum/quanto/tensor/activations/__init__.py",
    "content": "from .qbytes import *\nfrom .quantization import *\n"
  },
  {
    "path": "optimum/quanto/tensor/activations/qbytes.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport ast\n\nimport torch\nfrom torch.autograd import Function\n\nfrom ..qbytes import QBytesTensor\nfrom ..qtensor import qfallback\nfrom ..qtype import qtype, qtypes\n\n\n__all__ = [\"ActivationQBytesTensor\"]\n\n\nclass ActivationQBytesQuantizer(Function):\n    @staticmethod\n    def forward(ctx, base: torch.Tensor, qtype: qtype, scale: torch.Tensor) -> torch.Tensor:\n        if qtype.bits != 8:\n            raise ValueError(\"QBytesTensor can only be of 8-bit qtype\")\n        size = base.size()\n        stride = base.stride()\n        data = torch.ops.quanto.quantize_symmetric(base, dtype=qtype.dtype, axis=None, scale=scale)\n        # The instantiation of the quantized tensor must happen within the context of the Function\n        # for the autograd magic to work.\n        return ActivationQBytesTensor(qtype, size, stride, data, scale)\n\n    @staticmethod\n    def backward(ctx, gO):\n        # For autograd, quantization is a no-op\n        return gO, None, None, None, None, None\n\n\nclass ActivationQBytesTensor(QBytesTensor):\n    @staticmethod\n    def __new__(cls, qtype, size, stride, data, scale, requires_grad=False):\n        assert data.device == scale.device\n        return torch.Tensor._make_wrapper_subclass(\n            cls, size, strides=stride, dtype=scale.dtype, device=data.device, requires_grad=requires_grad\n      
  )\n\n    def __init__(self, qtype, size, stride, data, scale, requires_grad=False):\n        super().__init__(qtype, None, size, stride, data, scale, requires_grad)\n\n    @classmethod\n    def quantize(cls, base: torch.Tensor, qtype: qtype, scale: torch.Tensor) -> torch.Tensor:\n        return ActivationQBytesQuantizer.apply(base, qtype, scale)\n\n    def __tensor_flatten__(self):\n        inner_tensors = [\"_data\", \"_scale\"]\n        meta = {\n            \"qtype\": self._qtype.name,\n            \"size\": str(list(self.size())),\n            \"stride\": str(list(self.stride())),\n        }\n        return inner_tensors, meta\n\n    @staticmethod\n    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):\n        assert len(inner_tensors) == 2\n        assert len(meta) == 3\n        data, scale = inner_tensors[\"_data\"], inner_tensors[\"_scale\"]\n        # Meta should only contain strings, AST compatible except qtype\n        qtype = qtypes[meta[\"qtype\"]]\n        size = ast.literal_eval(meta[\"size\"])\n        stride = ast.literal_eval(meta[\"stride\"])\n        return ActivationQBytesTensor(qtype, size, stride, data, scale)\n\n    @classmethod\n    def __torch_dispatch__(cls, op, types, args, kwargs=None):\n        from .qbytes_ops import get_qbytestensor_op_dispatch\n\n        kwargs = kwargs or {}\n        # Do not use directly op, but rather its overload\n        op = op.overloadpacket\n        qdispatch = get_qbytestensor_op_dispatch(op)\n        if qdispatch is not None:\n            return qdispatch(*args, **kwargs)\n        # No dispatch available: qfallback\n        return qfallback(op, *args, **kwargs)\n"
  },
  {
    "path": "optimum/quanto/tensor/activations/qbytes_ops.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport numbers\nfrom functools import partial\nfrom typing import Callable, List\n\nimport torch\n\nfrom ..core import dtype_info\nfrom ..qtensor import QTensor, qfallback\nfrom ..qtype import qint8\nfrom .qbytes import ActivationQBytesTensor\nfrom .quantization import quantize_activation\n\n\n__all__ = [\"get_qbytestensor_op_dispatch\", \"register_qbytestensor_op\"]\n\n\n_QBYTESTENSOR_OP_TABLE = {}\n\n\ndef register_qbytestensor_op(aten_ops: List[Callable]):\n    \"\"\"\n    Used for registering a new __torch_dispatch__ aten operation to QBytesTensor.\n\n    The code to register a new operation looks like:\n\n    @register_qbytestensor_op(list_of_ops)\n    def foo(op, *args, **kwargs):\n        <implementation>\n    \"\"\"\n\n    def wrapper(op):\n        for aten_op in aten_ops:\n            _QBYTESTENSOR_OP_TABLE[aten_op] = partial(op, aten_op)\n\n    return wrapper\n\n\ndef get_qbytestensor_op_dispatch(aten_op):\n    return _QBYTESTENSOR_OP_TABLE.get(aten_op, None)\n\n\ndef is_scalar(t):\n    return isinstance(t, numbers.Number) or type(t) is torch.Tensor and len(t.shape) == 0\n\n\n@register_qbytestensor_op([torch.ops.aten._to_copy, torch.ops.aten.to])\ndef _to_copy(op, t, dtype=None, **kwargs):\n    # For data, ignore dtype and use the inner type instead\n    out_data = op(t._data, dtype=t._data.dtype, **kwargs)\n    # Apply the 
 new dtype on the scale only\n    out_scale = op(t._scale, dtype=dtype, **kwargs)\n    return ActivationQBytesTensor(t.qtype, t.size(), t.stride(), out_data, out_scale)\n\n\n@register_qbytestensor_op([torch.ops.aten.detach])\ndef detach(op, t):\n    # Detach both data and scale\n    out_data = op(t._data)\n    out_scale = op(t._scale)\n    return ActivationQBytesTensor(t.qtype, t.size(), t.stride(), out_data, out_scale)\n\n\n@register_qbytestensor_op([torch.ops.aten.cat])\ndef cat(op, inputs, dim=0):\n    if len(inputs) == 2:\n        t1, t2 = inputs\n        # Only quantized tensors with identical scalar scales can be concatenated\n        if (\n            isinstance(t1, ActivationQBytesTensor)\n            and isinstance(t2, ActivationQBytesTensor)\n            and torch.equal(t1._scale, t2._scale)\n            and t1.qtype == t2.qtype\n        ):\n            if t1.qtype.is_floating_point or t2.qtype.is_floating_point:\n                # Cat is not supported for float8\n                return qfallback(op, inputs, dim)\n            out_data = op([t1._data, t2._data], dim)\n            return ActivationQBytesTensor(t1.qtype, out_data.size(), out_data.stride(), out_data, t1._scale)\n    return qfallback(op, inputs, dim)\n\n\n@register_qbytestensor_op([torch.ops.aten.lt])\ndef lt(op, input, other):\n    # Only quantized tensors with identical scales can be compared\n    if (\n        isinstance(input, ActivationQBytesTensor)\n        and isinstance(other, ActivationQBytesTensor)\n        and torch.equal(input._scale, other._scale)\n    ):\n        return op(input._data, other._data)\n    return qfallback(op, input, other)\n\n\n@register_qbytestensor_op([torch.ops.aten.clone])\ndef clone(op, t, memory_format=torch.preserve_format):\n    # We need to restore the data original shape before cloning to get the correct strides\n    data_shape = t._data.shape\n    out_data = t._data.reshape(t.shape)\n    out_data = op(out_data, memory_format=memory_format)\n    out_stride 
= out_data.stride()\n    out_data = out_data.reshape(data_shape)\n    out_scale = op(t._scale, memory_format=memory_format)\n    return ActivationQBytesTensor(t.qtype, t.size(), out_stride, out_data, out_scale)\n\n\n@register_qbytestensor_op([torch.ops.aten.copy_])\ndef copy_(op, dest, src):\n    assert dest.qtype == src.qtype\n    dest._data = op(dest._data, src._data)\n    dest._scale = op(dest._scale, src._scale)\n    return dest\n\n\n@register_qbytestensor_op([torch.ops.aten.div])\ndef div(op, input, other):\n    if not is_scalar(other):\n        return op(input.dequantize(), other)\n    # We just divide the scale\n    return ActivationQBytesTensor(input.qtype, input.size(), input.stride(), input._data, op(input._scale, other))\n\n\n@register_qbytestensor_op([torch.ops.aten.neg])\ndef neg(op, input, *args, **kwargs):\n    if input.qtype.is_floating_point:\n        # Neg is not supported for float8\n        return op(input.dequantize(), *args, **kwargs)\n    out_data = op(input._data, *args, **kwargs)\n    return ActivationQBytesTensor(input.qtype, input.size(), input.stride(), out_data, input._scale)\n\n\n@register_qbytestensor_op(\n    [\n        torch.ops.aten.expand,\n        torch.ops.aten.permute,\n        torch.ops.aten.select,\n        torch.ops.aten.slice,\n        torch.ops.aten.unsqueeze,\n    ]\n)\ndef unary_type_agnostic_op(op, input, *args, **kwargs):\n    if input.axis is not None:\n        return op(input.dequantize(), *args, **kwargs)\n    # When quantization is per-tensor, these operations can be transparently applied\n    # without modifying the scale.\n    out_data = op(input._data, *args, **kwargs)\n    return ActivationQBytesTensor(input.qtype, out_data.size(), out_data.stride(), out_data, input._scale)\n\n\n@register_qbytestensor_op([torch.ops.aten.is_same_size])\ndef is_same_size(op, input, other):\n    a = input._data if isinstance(input, ActivationQBytesTensor) else input\n    b = other._data if isinstance(other, ActivationQBytesTensor) 
else other\n    return op(a, b)\n\n\ndef cannot_mm(t: QTensor):\n    \"\"\"True if the QTensor data cannot be passed to an mm op\"\"\"\n    return t.axis is not None and t.size() != t._data.size()\n\n\n@register_qbytestensor_op([torch.ops.aten.bmm])\ndef bmm(op, input, other):\n    if not isinstance(input, ActivationQBytesTensor):\n        return op(input, other.dequantize())\n    if not isinstance(other, QTensor) or input.axis is not None:\n        return op(input.dequantize(), other)\n    if input.qtype != qint8 or other.qtype != qint8 or cannot_mm(other):\n        return qfallback(op, input, other)\n    # Cast data to float32 and do the operation\n    out_data = op(input._data.to(torch.float32), other._data.to(torch.float32))\n    out_scale = (input._scale * other._scale).to(torch.float32)\n    return (out_data * out_scale).to(input._scale.dtype)\n\n\n@register_qbytestensor_op([torch.ops.aten.mul])\ndef mul(op, input, other):\n    # If one of the multiplicands is a scalar, just multiply the scale\n    if is_scalar(input):\n        return ActivationQBytesTensor(other.qtype, other.size(), other.stride(), other._data, input * other._scale)\n    if is_scalar(other):\n        return ActivationQBytesTensor(input.qtype, input.size(), input.stride(), input._data, other * input._scale)\n    return qfallback(op, input, other)\n\n\n@register_qbytestensor_op([torch.ops.aten.relu])\ndef relu(op, input):\n    if input.qtype.is_floating_point:\n        # Relu is not supported for float8 types\n        return qfallback(op, input)\n    out_data = op(input._data)\n    return ActivationQBytesTensor(input.qtype, input.size(), input.stride(), out_data, input._scale)\n\n\n@register_qbytestensor_op([torch.ops.aten._softmax])\ndef _softmax(op, input, dim, half_to_float):\n    # Softmax must be performed in float\n    float_data = op(input.dequantize(), dim, half_to_float)\n    # Since softmax is normalized, we know the optimal scale\n\n    out_scale = torch.tensor(1 / 
 dtype_info(input.qtype.dtype).max, dtype=input._scale.dtype).to(input.device)\n    return quantize_activation(float_data, qtype=input.qtype, scale=out_scale)\n\n\n@register_qbytestensor_op([torch.ops.aten.stack])\ndef stack(op, inputs, dim=0):\n    if len(inputs) == 2:\n        t1, t2 = inputs\n        # Only quantized tensors with identical scales can be stacked\n        if (\n            isinstance(t1, ActivationQBytesTensor)\n            and isinstance(t2, ActivationQBytesTensor)\n            and t1.axis is None\n            and t2.axis is None\n            and torch.equal(t1._scale, t2._scale)\n            and t1.qtype == t2.qtype\n        ):\n            out_data = op([t1._data, t2._data], dim)\n            return ActivationQBytesTensor(t1.qtype, out_data.size(), out_data.stride(), out_data, t1._scale)\n    return qfallback(op, inputs, dim)\n\n\n@register_qbytestensor_op([torch.ops.aten.split])\ndef split(op, input, *args, **kwargs):\n    if input.axis is not None:\n        return qfallback(op, input, *args, **kwargs)\n    out_datas = op(input._data, *args, **kwargs)\n    return [\n        ActivationQBytesTensor(input.qtype, input.size(), input.stride(), out_data, input._scale)\n        for out_data in out_datas\n    ]\n\n\n@register_qbytestensor_op([torch.ops.aten.transpose])\ndef transpose(op, input, *args):\n    out_data = op(input._data, *args)\n    out_size = out_data.size()\n    out_stride = out_data.stride()\n    out_scale = input._scale\n    return ActivationQBytesTensor(input.qtype, out_size, out_stride, out_data, out_scale)\n\n\n@register_qbytestensor_op([torch.ops.aten.t])\ndef transpose2d(op, input):\n    out_data = op(input._data)\n    out_scale = input._scale\n    # Manually reverse size and stride because we cannot trust the out_data shape\n    dim0, dim1 = input.size()\n    out_size = torch.Size([dim1, dim0])\n    out_stride = input.stride()[::-1]\n    return ActivationQBytesTensor(input.qtype, out_size, out_stride, out_data, 
out_scale)\n\n\n@register_qbytestensor_op([torch.ops.aten.view, torch.ops.aten._unsafe_view])\ndef view(op, input, *shape):\n    if input.axis is None:\n        # The view is transparent for QTensor with scalar scales\n        out_data = op(input._data, *shape)\n        return ActivationQBytesTensor(input.qtype, out_data.size(), out_data.stride(), out_data, input._scale)\n    return qfallback(op, input, *shape)\n\n\n@register_qbytestensor_op([torch.ops.aten.where])\ndef where(op, condition, input, other):\n    if isinstance(condition, QTensor) or isinstance(other, QTensor):\n        raise NotImplementedError\n    float_data = op(condition, input.dequantize(), other)\n    if input.axis is None:\n        # We requantize with the input scale\n        return quantize_activation(float_data, qtype=input.qtype, scale=input._scale)\n    return float_data\n"
  },
  {
    "path": "optimum/quanto/tensor/activations/quantization.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\n\nfrom ..qtype import qtype\nfrom .qbytes import ActivationQBytesTensor\n\n\n__all__ = [\"quantize_activation\"]\n\n\ndef quantize_activation(t: torch.Tensor, qtype: qtype, scale: torch.Tensor):\n    \"\"\"Quantize an activation Tensor.\n\n    Activations are always quantized per-tensor with a scalar scale.\n\n    Args:\n        base (`torch.Tensor`): the Tensor to quantize\n        qtype (`quanto.qtype`): The target quantization type\n        scale (`torch.Tensor`): The scalar quantization scale\n\n    Returns:\n        A quantized Tensor.\n    \"\"\"\n    if scale.numel() != 1:\n        raise ValueError(\"Parameter scale must be a scalar because activations can only be quantized per-tensor\")\n    return ActivationQBytesTensor.quantize(t, qtype, scale)\n"
  },
  {
    "path": "optimum/quanto/tensor/core.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport torch\n\n\n__all__ = [\"axis_to_dim\", \"dtype_info\"]\n\n\ndef dtype_info(dtype):\n    info = torch.finfo if dtype.is_floating_point else torch.iinfo\n    return info(dtype)\n\n\ndef axis_to_dim(t, axis):\n    dim = list(range(t.ndim))\n    if axis == -1:\n        dim = dim[:-1]\n    else:\n        dim.remove(axis)\n    return dim\n"
  },
  {
    "path": "optimum/quanto/tensor/function.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\n\n\n__all__ = [\"QuantizedLinearFunction\"]\n\n\nclass QuantizedLinearFunction(torch.autograd.Function):\n    \"\"\"Quantized linear function.\n\n    This is a quantized implementation of torch.nn.functional.linear.\n\n    It defines explicitly the backward pass instead of letting pytorch\n    build it by combining the gradients of the underlying quantized operations.\n\n    This has two main benefits:\n\n    - this saves computations,\n    - this allows to use operations that do not have a registered backward method,\n    such as quanto custom operations.\n\n    The drawback is that the extra tensors involved in the quantization graph, such as\n    the scales and shift, cannot be trained.\n    This is however consistent with the quanto quantizers backward pass, that returns\n    a zero gradient for these tensors.\n    \"\"\"\n\n    @staticmethod\n    def forward(ctx, input, other, bias=None):\n        ctx.save_for_backward(input, other)\n        output = torch.matmul(input, other.t())\n        if bias is not None:\n            output = output + bias\n        return output\n\n    def backward(ctx, gO):\n        input_gO = other_gO = bias_gO = None\n        input, other = ctx.saved_tensors\n        out_features, in_features = other.shape\n        if ctx.needs_input_grad[0]:\n            # grad(A@(B.t()) = gO => grad(A) = 
gO@(B.t().t()) = gO@B\n            input_gO = torch.matmul(gO, other)\n        if ctx.needs_input_grad[1]:\n            # grad(B@A.t()) = gO.t() => grad(B) = gO.t()@(A.t().t()) = gO.t()@A\n            other_gO = torch.matmul(gO.view(-1, out_features).t(), input.view(-1, in_features))\n        if ctx.needs_input_grad[2]:\n            # Bias gradient is the sum on all dimensions but the last one\n            dim = tuple(range(gO.ndim - 1))\n            bias_gO = gO.sum(dim)\n        return input_gO, other_gO, bias_gO\n"
  },
  {
    "path": "optimum/quanto/tensor/grouped.py",
    "content": "import math\nfrom typing import List\n\nimport torch\n\n\n__all__ = [\"group\", \"ungroup\", \"grouped_shape\"]\n\n\ndef grouped_shape(shape: List, axis: int, group_size: int) -> List:\n    if axis not in (0, -1):\n        raise ValueError(\"Axis must be 0 or -1 for group-wise quantization\")\n    n_groups = math.prod(shape) // group_size\n    return (n_groups, group_size) if axis == 0 else (group_size, n_groups)\n\n\ndef group(base: torch.Tensor, axis: int, group_size: int):\n    if axis not in (0, -1):\n        raise ValueError(\"Axis must be 0 or -1 for group-wise quantization\")\n    # In standard per-axis quantization, we have one scale per axis dim\n    axis_dim = base.shape[axis]\n    # This scale is evaluated over axis_numel items for each feature along axis\n    axis_numel = base.numel() // axis_dim\n    if group_size > axis_numel or axis_numel % group_size != 0:\n        raise ValueError(f\"Group size ({group_size}) must be a divisor of ({axis_numel})\")\n    # Group-wise quantization further splits axis_numel into multiple groups per axis\n    axis_groups = axis_numel // group_size\n    if axis == 0:\n        # Easy-peasy: we simply need to reshape to (axis_dim * axis_groups, group_size)\n        return base.reshape([-1, group_size])\n    # More difficult: reshape to (group_size, axis_dim * axis_groups)\n    # First, split by groups, preserving the axis dimension\n    grouped = base.reshape((axis_groups, group_size, axis_dim))\n    # Permute to (group_size, axis_dim, axis_groups)\n    grouped = grouped.permute(1, 2, 0)\n    return grouped.reshape(group_size, axis_dim * axis_groups)\n\n\ndef ungroup(grouped: torch.Tensor, axis: int, orig_shape: torch.Size):\n    if grouped.shape == orig_shape:\n        return grouped\n    if axis == 0:\n        # No transposition required, just reshape\n        return grouped.reshape(orig_shape)\n    group_size = grouped.shape[0] if axis == -1 else grouped.shape[-1]\n    axis_dim = orig_shape[axis]\n    
axis_groups = grouped.numel() // axis_dim // group_size\n    ungrouped = grouped.reshape(group_size, axis_dim, axis_groups)\n    # Permute to (axis_groups, group_size, axis_dim)\n    ungrouped = ungrouped.permute(2, 0, 1)\n    return ungrouped.reshape(orig_shape)\n"
  },
  {
    "path": "optimum/quanto/tensor/optimizers/__init__.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom .absmax_optimizer import *\nfrom .affine_optimizer import *\nfrom .hqq_optimizer import *\nfrom .max_optimizer import *\nfrom .optimizer import *\nfrom .symmetric_optimizer import *\n"
  },
  {
    "path": "optimum/quanto/tensor/optimizers/absmax_optimizer.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Optional, Tuple, Union\n\nimport torch\n\nfrom ..qtype import qtype\nfrom .symmetric_optimizer import SymmetricOptimizer\n\n\n__all__ = [\"AbsmaxOptimizer\"]\n\n\nclass AbsmaxOptimizer(SymmetricOptimizer):\n    def optimize(\n        self, base: torch.Tensor, qtype: qtype, axis: Optional[int] = None\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        base = torch.abs(base)\n        if axis is None:\n            rmax = torch.max(base)\n        else:\n            dim = list(range(1, base.ndim)) if (axis == 0) else list(range(0, base.ndim - 1))\n            rmax = torch.amax(torch.abs(base), dim=dim, keepdim=True)\n        return rmax / qtype.qmax\n"
  },
  {
    "path": "optimum/quanto/tensor/optimizers/affine_optimizer.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Optional, Tuple\n\nimport torch\n\nfrom ..grouped import group\nfrom ..qtype import qtype\nfrom .optimizer import Optimizer\n\n\n__all__ = [\"AffineOptimizer\"]\n\n\nclass AffineOptimizer(Optimizer):\n    def __call__(\n        self,\n        base: torch.Tensor,\n        qtype: qtype,\n        axis: int,\n        group_size: Optional[int] = None,\n        zeropoint: bool = False,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Args:\n            base (`torch.Tensor`): the weight Tensor to quantize\n            qtype (`quanto.qtype`): The target quantization type\n            axis ('int`): The quantization axis (0 or -1)\n            group_size (`Optional[int]`): The quantization group size\n            zeropoint (`bool`): Allow an exact representation of zero. If True, the shifts are stored as\n                integer instead of float, which results in a slightly smaller model, but might also reduce\n                the model performance. 
Defaults to False.\n        Returns:\n            A tuple of scale, shift Tensor.\n        \"\"\"\n        if axis not in [0, -1]:\n            raise ValueError(\"axis parameter must be 0 (first axis) or -1 (last axis)\")\n        if group_size is not None:\n            base = group(base, axis, group_size)\n        if axis is not None and base.shape[axis] == 1:\n            axis = None\n        scale, shift = self.optimize(base, qtype, axis)\n        assert scale.dtype == base.dtype\n        assert shift.dtype == base.dtype\n        if zeropoint:\n            # Round shift to make sure zero can be represented exactly using 'shift' as quantized value\n            shift = torch.clamp(torch.round(shift / scale), 0, 2**qtype.bits - 1)\n            shift = shift.to(torch.int8) if base.device.type == \"xpu\" else shift.to(torch.uint8)\n        return scale, shift\n\n    def optimize(self, base: torch.Tensor, qtype: qtype, axis: int) -> Tuple[torch.Tensor, torch.Tensor]:\n        raise NotImplementedError\n"
  },
  {
    "path": "optimum/quanto/tensor/optimizers/hqq_optimizer.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Optional, Tuple, Union\n\nimport torch\n\nfrom ..qtype import qtype\nfrom ..weights import quantize_weight\nfrom .max_optimizer import MaxOptimizer\n\n\n__all__ = [\"HqqOptimizer\"]\n\n\n# Shrinking operator\ndef shrink_lp_op(x: torch.Tensor, beta: float, lp_norm: float) -> torch.Tensor:\n    if lp_norm == 1:\n        return torch.sign(x) * torch.nn.functional.relu(torch.abs(x) - 1.0 / beta)\n    else:\n        return torch.sign(x) * torch.nn.functional.relu(\n            torch.abs(x) - (1.0 / beta) * torch.pow(torch.abs(x), lp_norm - 1)\n        )\n\n\nclass HqqOptimizer(MaxOptimizer):\n    \"\"\"Implementation of the HQQ algorithm\n\n    This is an implementation of the algorithm described in \"Half-Quadratic Quantization of Large Machine Learning Models\",\n    by Hicham Badri and Appu Shaji (https://mobiusml.github.io/hqq_blog/).\n    This is an adaptation of the original implementation at https://github.com/mobiusml/hqq.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        lp_norm: Optional[float] = 0.7,\n        beta: Optional[int] = 1e1,\n        kappa: Optional[float] = 1.01,\n        iters: Optional[int] = 20,\n        verbose: Optional[bool] = False,\n    ) -> None:\n        self.lp_norm = lp_norm\n        self.beta = beta\n        self.kappa = kappa\n        self.iters = iters\n        self.verbose = 
verbose\n\n    def optimize(\n        self, base: torch.Tensor, qtype: qtype, axis: int\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        scale, shift = super().optimize(base, qtype, axis)\n        best_error = None\n        beta = self.beta\n        base_q = quantize_weight(base, qtype=qtype, axis=axis, scale=scale, shift=shift)\n        for i in range(self.iters):\n            error = base - base_q\n            if best_error is None:\n                best_error = float(torch.abs(base - base_q).mean())\n                if self.verbose:\n                    print(f\"Start error: {best_error:.6f}\")\n            e = shrink_lp_op(error, beta, self.lp_norm)\n            mean_axis = 0 if axis == -1 else -1\n            hqq_shift = torch.mean(base_q._data * scale - (base - e), axis=mean_axis, keepdim=True)\n            base_q = quantize_weight(base, qtype=qtype, axis=axis, scale=scale, shift=hqq_shift)\n            mean_error = float(torch.abs(base - base_q).mean())\n            if self.verbose:\n                print(f\"HQQ error at it #{i}: {mean_error:.6f}\")\n            if mean_error < best_error:\n                best_error = mean_error\n                shift = hqq_shift\n                beta *= self.kappa\n            else:\n                break\n\n        return scale, shift\n"
  },
  {
    "path": "optimum/quanto/tensor/optimizers/max_optimizer.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Tuple, Union\n\nimport torch\n\nfrom ..qtype import qtype\nfrom .affine_optimizer import AffineOptimizer\n\n\n__all__ = [\"MaxOptimizer\"]\n\n\nclass MaxOptimizer(AffineOptimizer):\n    def optimize(\n        self, base: torch.Tensor, qtype: qtype, axis: int\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        dim = list(range(1, base.ndim)) if (axis == 0) else list(range(0, base.ndim - 1))\n        rmin = torch.amin(base, dim=dim, keepdim=True)\n        rmax = torch.amax(base, dim=dim, keepdim=True)\n        qmin = -(2 ** (qtype.bits - 1))\n        qmax = 2 ** (qtype.bits - 1) - 1\n        scale = (rmax - rmin) / (qmax - qmin)\n        shift = -rmin\n        return scale, shift\n"
  },
  {
    "path": "optimum/quanto/tensor/optimizers/optimizer.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom abc import ABC\nfrom typing import Optional, Tuple, Union\n\nimport torch\n\n\n__all__ = [\"Optimizer\"]\n\n\nclass Optimizer(ABC):\n    def __call__(\n        self, base: torch.Tensor, bits: int, axis: int, group_size: Optional[int] = None\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        raise NotImplementedError\n"
  },
  {
    "path": "optimum/quanto/tensor/optimizers/symmetric_optimizer.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Optional\n\nimport torch\n\nfrom ..qtype import qtype\nfrom .optimizer import Optimizer\n\n\n__all__ = [\"SymmetricOptimizer\"]\n\n\nclass SymmetricOptimizer(Optimizer):\n    def __call__(self, base: torch.Tensor, qtype: qtype, axis: Optional[int] = None) -> torch.Tensor:\n        if axis not in [None, 0, -1]:\n            raise ValueError(\"axis parameter must be None, 0 (first axis) or -1 (last axis)\")\n        if axis is not None and base.shape[axis] == 1:\n            axis = None\n        scale = self.optimize(base, qtype, axis)\n        assert scale.dtype == base.dtype\n\n        return scale\n\n    def optimize(self, base: torch.Tensor, qmax: float, axis: Optional[int] = None) -> torch.Tensor:\n        raise NotImplementedError\n"
  },
  {
    "path": "optimum/quanto/tensor/packed.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport ast\n\nimport torch\nfrom torch.utils import _pytree as pytree\n\n\n__all__ = [\"PackedTensor\"]\n\n\ndef pack_weights(intweights: torch.Tensor, bits: int) -> torch.Tensor:\n    \"\"\"\n    Pack int4 / int2 weights in a uint8 tensor\n\n    What packing means? Assume we have 4 values that are in 2bit but encoded in 8bit\n    (because torch does not have native support for 2-bit datatypes)\n\n    > 0000 0011 | 0000 0010 | 0000 0001 | 0000 0000\n\n    We can pack them in a single 8-bit uint value\n\n    > 1110 0100\n\n    Therefore instead of saving 4 values in 8-bit precision we save a single value of 8-bit precision saving 24 bits in total.\n\n    Args:\n        intweights (`torch.Tensor`):\n            The un-packed `torch.uint8` tensor\n        bits (`int`):\n            The actual `bits` - can be 2, 4\n    \"\"\"\n    original_shape = intweights.shape\n    values_per_item = 8 // bits\n    row_dim = (original_shape[0] + values_per_item - 1) // values_per_item\n\n    if len(original_shape) == 1:\n        packed_tensor_shape = (row_dim,)\n    else:\n        packed_tensor_shape = (row_dim, *original_shape[1:])\n\n    packed = torch.zeros(packed_tensor_shape, device=intweights.device, dtype=torch.uint8)\n    unpacked = intweights.to(torch.uint8)\n\n    def lshift(t: torch.Tensor, bits: int):\n        if t.device.type == \"mps\":\n   
         # lshift is not supported on MPS device\n            return t * (2**bits)\n        return t << bits\n\n    it = min(values_per_item, (original_shape[0] // row_dim) + 1)\n    for i in range(it):\n        start = i * row_dim\n        end = min(start + row_dim, original_shape[0])\n        packed[: (end - start)] |= lshift(unpacked[start:end], bits * i)\n\n    return packed\n\n\nclass PackedTensor(torch.Tensor):\n    @staticmethod\n    def __new__(cls, data, bits, size, stride, requires_grad=False):\n        # PackedTensor represents uint8 data and can therefore NEVER require gradient\n        assert data.dtype == torch.uint8\n        assert requires_grad is False\n        return torch.Tensor._make_wrapper_subclass(\n            cls, size, strides=stride, dtype=torch.uint8, device=data.device, requires_grad=requires_grad\n        )\n\n    def __init__(self, data, bits, size, stride, requires_grad=False):\n        self._bits = bits\n        self._data = data\n\n    def __repr__(self):\n        autograd_info = (\n            f\", grad_fn={self.grad_fn}\" if self.grad_fn else \", requires_grad=True\" if self.requires_grad else \"\"\n        )\n        return f\"PackedTensor({self._data}, bits={self._bits}, public_dtype={self.dtype}{autograd_info})\"\n\n    @classmethod\n    def pack(cls, t, bits=4):\n        assert bits in (2, 4)\n        # XPU use int8 dtype\n        assert t.dtype in (torch.uint8, torch.int8)\n        data = pack_weights(t, bits)\n        # We need to store size and stride to make sure the unpacked data has the correct shape\n        return PackedTensor(data, bits, t.size(), t.stride())\n\n    def unpack(self):\n        unpacked_data = torch.ops.quanto.unpack(self._data, self._bits)\n        # Adjust the first dimension, as unpacked data may have extra rows if the original shape is not a multiple of 8 // bits\n        return unpacked_data[: self.shape[0]]\n\n    @property\n    def bits(self):\n        return self._bits\n\n    @property\n    def 
dtype(self):\n        return torch.uint8\n\n    @staticmethod\n    def load_from_state_dict(state_dict, prefix, bits, size, stride, missing_keys):\n        if prefix + \"_data\" not in state_dict:\n            missing_keys.append(prefix + \"_data\")\n            return\n\n        inner_tensors_dict = {\"_data\": state_dict.pop(prefix + \"_data\")}\n        meta = [name.replace(prefix, \"\") for name in state_dict.keys() if name.startswith(prefix)]\n        meta = {\"bits\": str(bits), \"size\": str(list(size)), \"stride\": str(stride)}\n        return PackedTensor.__tensor_unflatten__(inner_tensors_dict, meta, None, None)\n\n    def __tensor_flatten__(self):\n        inner_tensors = [\"_data\"]\n        # Since meta can be used for serialization, use only AST compatible strings\n        meta = {\"bits\": str(self._bits), \"size\": str(list(self.size())), \"stride\": str(self.stride())}\n        return inner_tensors, meta\n\n    @staticmethod\n    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):\n        assert len(inner_tensors) == 1\n        assert len(meta) == 3\n        data = inner_tensors[\"_data\"]\n        # Meta should contain only AST compatible strings\n        bits = ast.literal_eval(meta[\"bits\"])\n        size = ast.literal_eval(meta[\"size\"])\n        stride = ast.literal_eval(meta[\"stride\"])\n        return PackedTensor(data, bits, size, stride)\n\n    __torch_function__ = torch._C._disabled_torch_function_impl\n\n    @classmethod\n    def __torch_dispatch__(cls, op, types, args, kwargs=None):\n        # Convert back to tensor before calling any operation except detach\n        if op.overloadpacket is torch.ops.aten.detach:\n            t = args[0]\n            data = op(t._data)\n            return PackedTensor(data, t._bits, t.size(), t.stride())\n        elif op.overloadpacket in (torch.ops.aten._to_copy, torch.ops.aten.to):\n            t = args[0]\n            dtype = kwargs.get(\"dtype\", torch.uint8)\n            if 
dtype != torch.uint8:\n                raise ValueError(f\"PackedTensor are torch.uint8 only and cannot be moved to {dtype}.\")\n            # Move data\n            data = op(t._data, **kwargs)\n            return PackedTensor(data, t._bits, t.size(), t.stride())\n        args, kwargs = pytree.tree_map_only(PackedTensor, lambda x: x.unpack(), (args, kwargs or {}))\n        return op(*args, **kwargs)\n\n    def numpy(self):\n        return self.unpack().cpu().numpy()\n"
  },
  {
    "path": "optimum/quanto/tensor/qbits.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport torch\nfrom torch.autograd import Function\n\nfrom .grouped import ungroup\nfrom .packed import PackedTensor\nfrom .qtensor import QTensor\n\n\n__all__ = [\"QBitsTensor\"]\n\n\nclass QBitsDequantizer(Function):\n    @staticmethod\n    def forward(ctx, t):\n        if isinstance(t._data, PackedTensor):\n            data = t._data.unpack()\n        else:\n            data = t._data\n        shift = t._shift\n        if not shift.dtype.is_floating_point:\n            # Remove shift before multiplying by the scale\n            data = data.to(torch.int8) - shift.to(torch.int8)\n        if t.qtype.is_floating_point:\n            # Upcast explicitly to the scale dtype\n            dqt = t._scale * data.to(t._scale.dtype)\n        else:\n            dqt = t._scale * data\n        if shift.dtype.is_floating_point:\n            # Remove scaled shift\n            dqt -= shift\n        if t.axis is None:\n            return dqt\n        # Restore the original shape (if needed)\n        return ungroup(dqt, axis=t.axis, orig_shape=t.shape)\n\n    @staticmethod\n    def backward(ctx, gO):\n        return gO\n\n\nclass QBitsTensor(QTensor):\n    def __init__(self, qtype, axis, group_size, size, stride, data, scale, shift, requires_grad=False):\n        super().__init__(qtype, axis)\n        self._data = data\n        self._scale = scale\n      
  self._shift = shift\n        self._group_size = group_size\n\n    def __repr__(self):\n        return f\"{type(self).__name__}({self._data}, scale={self._scale}, shift={self._shift}, dtype={self.dtype})\"\n\n    def dequantize(self):\n        return QBitsDequantizer.apply(self)\n"
  },
  {
    "path": "optimum/quanto/tensor/qbytes.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom torch.autograd import Function\n\nfrom .qtensor import QTensor\n\n\n__all__ = [\"QBytesTensor\"]\n\n\nclass QBytesDequantizer(Function):\n    @staticmethod\n    def forward(ctx, t):\n        if t.qtype.is_floating_point:\n            # Upcast explicitly to the scale dtype\n            dqt = t._scale * t._data.to(t._scale.dtype)\n        else:\n            dqt = t._scale * t._data\n        return dqt\n\n    @staticmethod\n    def backward(ctx, gO):\n        # For autograd, dequantization is a no-op\n        return gO\n\n\nclass QBytesTensor(QTensor):\n    def __init__(self, qtype, axis, size, stride, data, scale, requires_grad=False):\n        super().__init__(qtype, axis)\n        self._data = data\n        self._scale = scale\n\n    def __repr__(self):\n        return f\"{self.__class__}({self._data}, scale={self._scale}, dtype={self.dtype})\"\n\n    def dequantize(self):\n        \"\"\"Differentiable dequantization function\"\"\"\n        return QBytesDequantizer.apply(self)\n"
  },
  {
    "path": "optimum/quanto/tensor/qtensor.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport torch\nfrom torch.utils import _pytree as pytree\n\n\n__all__ = [\"QTensor\", \"qfallback\"]\n\n\ndef qfallback(callable, *args, **kwargs):\n    \"\"\"Fallback method for QTensor inputs.\n\n    When a torch function or an aten operation is not supported for the specified\n    QTensor arguments, each QTensor arg or kwarg is dequantized to a torch.Tensor\n    before calling the target function or op.\n    \"\"\"\n    args, kwargs = pytree.tree_map_only(QTensor, lambda x: x.dequantize(), (args, kwargs or {}))\n    return callable(*args, **kwargs)\n\n\nclass QTensor(torch.Tensor):\n    def __init__(self, qtype, axis):\n        self._qtype = qtype\n        self._axis = axis\n\n    def dequantize(self):\n        raise NotImplementedError\n\n    def save_to_state_dict(self, destination, prefix, keep_vars):\n        def serialize_tensor_subclass(t, destination, prefix, keep_vars):\n            inner_tensors, meta = t.__tensor_flatten__()\n            for name in inner_tensors:\n                inner_tensor = getattr(t, name)\n                if type(inner_tensor) is torch.Tensor:\n                    # Leaf Tensor, we can serialize it\n                    destination[prefix + name] = inner_tensor if keep_vars else inner_tensor.detach()\n                else:\n                    # Flatten also this inner Tensor\n                    
serialize_tensor_subclass(inner_tensor, destination, prefix + name + \".\", keep_vars)\n\n        # Recursively flatten QTensor into individual tensors\n        serialize_tensor_subclass(self, destination, prefix, keep_vars)\n\n    @property\n    def axis(self):\n        return self._axis\n\n    @property\n    def qtype(self):\n        return self._qtype\n\n    def numpy(self):\n        return self.dequantize().cpu().numpy()\n\n    def equal(self, other):\n        if type(self) is not type(other):\n            return False\n        self_tensors, self_meta = self.__tensor_flatten__()\n        _, other_meta = other.__tensor_flatten__()\n        for name, value in self_meta.items():\n            if other_meta[name] != value:\n                return False\n        for name in self_tensors:\n            self_t = getattr(self, name)\n            other_t = getattr(other, name)\n            if self_t.device.type == \"cpu\" and self_t.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):\n                # torch.equal is not implemented on CPU for float8 types\n                if self_t.dtype != other_t.dtype:\n                    return False\n                if not torch.equal(self_t.to(torch.float32), other_t.to(torch.float32)):\n                    return False\n            elif not torch.equal(self_t, other_t):\n                return False\n        return True\n"
  },
  {
    "path": "optimum/quanto/tensor/qtype.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom dataclasses import dataclass\n\nimport torch\n\n\n@dataclass\nclass qtype:\n    \"\"\"A quantized type class mimicking torch dtype\"\"\"\n\n    name: str\n    is_floating_point: bool\n    bits: int\n    # This defines the storage dtype\n    dtype: torch.dtype\n    qmin: float\n    qmax: float\n\n    def __str__(self):\n        return f\"quanto.{self.name}\"\n\n    def __hash__(self):\n        return hash(str(self))\n\n\n# Integer qtypes\n\n\ndef qint(bits):\n    qmin = -(2 ** (bits - 1))\n    qmax = 2 ** (bits - 1) - 1\n    return qtype(f\"qint{bits}\", is_floating_point=False, bits=bits, dtype=torch.int8, qmin=qmin, qmax=qmax)\n\n\nqint2 = qint(2)\nqint4 = qint(4)\nqint8 = qint(8)\n\n# Float qtypes\n\n\ndef qfloat(dtype: torch.dtype):\n    finfo = torch.finfo(dtype)\n    qmin = finfo.min\n    qmax = finfo.max\n    return qtype(f\"q{finfo.dtype}\", is_floating_point=True, bits=8, dtype=dtype, qmin=qmin, qmax=qmax)\n\n\nqfloat8_e4m3fn = qfloat(torch.float8_e4m3fn)\nqfloat8_e4m3fnuz = qfloat(torch.float8_e4m3fnuz)\nqfloat8_e5m2 = qfloat(torch.float8_e5m2)\n\n# Alias the float8 representation that has the better support and inference efficiency\nqfloat8 = qfloat8_e4m3fn\n\n# Convenience dict to get a dtype from its name\nqtypes = {name: q for (name, q) in locals().items() if isinstance(q, qtype)}\n\n__all__ = [\"qtype\", \"qtypes\"] + 
[str(name) for name in qtypes.keys()]\n"
  },
  {
    "path": "optimum/quanto/tensor/weights/__init__.py",
    "content": "from .qbits import *\nfrom .qbytes import *\nfrom .quantization import *\n"
  },
  {
    "path": "optimum/quanto/tensor/weights/awq/__init__.py",
    "content": "from .packed import *\nfrom .qbits import *\n"
  },
  {
    "path": "optimum/quanto/tensor/weights/awq/packed.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport ast\nfrom copy import copy\nfrom enum import Enum\n\nimport numpy as np\nimport torch\nfrom torch.utils import _pytree as pytree\n\nfrom ..packing import unpack_int32_to_uint8\n\n\n__all__ = [\"AWQPackedTensor\", \"AWQPacking\"]\n\n\nAWQ_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]\nAWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]\n\n\ndef pack(unpacked: torch.Tensor, reorder=False):\n    \"\"\"\n    Pack uint4 weights in an int32 tensor as expected by AWQ mixed mm kernel\n\n    As compared to the standard packing, this adds an optional permutation of the columns\n    for faster dequantization, as explained in \"Who Says Elephants Can’t Run: Bringing Large\n    Scale MoE Models into Cloud Scale Production\", https://arxiv.org/pdf/2211.10017.\n\n    Args:\n        unpacked (`torch.Tensor`):\n            The un-packed `torch.uint8` tensor\n        reorder (`bool`):\n            Whether columns should be reordered or not before packing.\n\n    Returns:\n        A int32 `torch.Tensor`.\n    \"\"\"\n    bits = 4\n    pack_num = 32 // bits\n    packed = torch.zeros(unpacked.shape[0], unpacked.shape[1] // pack_num, dtype=torch.int32, device=unpacked.device)\n    for col in range(unpacked.shape[1] // pack_num):\n        if reorder:\n            order_map = AWQ_ORDER\n        else:\n            order_map = [0, 1, 2, 3, 4, 5, 6, 7]\n        for i in 
range(pack_num):\n            packed_col = unpacked[:, col * pack_num + order_map[i]].to(torch.int32)\n            packed[:, col] |= packed_col << (i * bits)\n    return packed\n\n\ndef reverse_awq_order(t: torch.Tensor):\n    bits = 4\n    reverse_order_tensor = torch.arange(\n        t.shape[-1],\n        dtype=torch.int32,\n        device=t.device,\n    )\n    reverse_order_tensor = reverse_order_tensor.reshape(-1, 32 // bits)\n    reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]\n    reverse_order_tensor = reverse_order_tensor.reshape(-1)\n\n    t = t[:, reverse_order_tensor]\n\n    return t\n\n\ndef unpack(packed: torch.Tensor, reorder=False):\n    \"\"\"Unpack a packed int32 tensor to a larger uint8 tensor\n\n    Applies pack operations in reverse order (see pack method for details).\n\n    Args:\n        packed (`torch.Tensor`):\n            The packed `torch.int32` tensor\n        reorder (`bool`):\n            Whether columns should be reordered or not.\n\n    Returns:\n        An unpacked uint8 `torch.Tensor` expanded along the second dimension.\n    \"\"\"\n    unpacked = unpack_int32_to_uint8(packed, bits=4)\n    if reorder:\n        unpacked = reverse_awq_order(unpacked)\n    return unpacked\n\n\ndef pack_v2(unpacked: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Pack uint4 weights in an int16 tensor as expected by AWQ second generation mixed mm kernel\n\n    As compared to the standard packing, this adds three specific formatting:\n\n    - permute rows to counter implicit permutation on Turing and Ampere architecture,\n    - permute rows for faster dequantization,\n    - interleave groups of 'interleave' rows for efficient parallel processing.\n\n    Note that this formatting expects a group size of 128.\n\n    Args:\n        unpacked (`torch.Tensor`):\n            The un-packed `torch.uint8` tensor\n\n    Returns:\n        A int16 `torch.Tensor`.\n    \"\"\"\n    assert unpacked.device.type in [\"cuda\", \"xpu\"]\n    assert 
unpacked.ndim == 2\n    N, K = unpacked.shape\n    # These two values are hard-coded in the optimized kernels:\n    # - I represents the 'interleave', i.e. the number of values packed at a single coordinate (16 bits / 4 bits),\n    # - S represents the 'kernel stride', and is related to the group size (TBC).\n    I = 4\n    S = 64\n\n    # 1. For faster dequantization, the tensor rows must be permuted as explained in:\n    # https://github.com/NVIDIA/TensorRT-LLM/blob/035b99e0d09d4f2dfdb949810cf7245112aa4165/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp#L161\n    # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ...] => [0, 1, 8, 9, 16, 17, 24, 25, ...]\n    packed = unpacked.reshape(N, K // 32, 4, 4, 2).permute(0, 1, 3, 2, 4)\n\n    # Reorder each 8 weights for fast dequantization\n    # From: \"Who Says Elephants Can’t Run: Bringing Large Scale MoE Models into Cloud Scale Production\"\n    # https://arxiv.org/pdf/2211.10017\n    # [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7]\n    packed = packed.permute(0, 1, 2, 4, 3)\n    packed = packed.reshape(N, K)\n\n    # 2. 
For efficient parallelization, the rows are grouped and interleaved by blocks of kstride into a single row, as explained in:\n    # https://github.com/NVIDIA/TensorRT-LLM/blob/d37b507f41a87457fe9f10f7459d08f5db235745/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h#L69\n    # interleaving (N, K) -> (N // I, I, K // S, S)\n    packed = packed.reshape(N // I, I, K // S, S)\n    # transpose (N // I, I, K // S, S) -> (N // I, K // S, I, S)\n    packed = packed.permute(0, 2, 1, 3)\n    # reshape (N // I, K // S, I, S) -> (N // I, K // S, S, I)\n    packed = packed.reshape(N // I, K // S, S, I)\n    # Packing (N // I, K // S, S, I) -> (N // I, K // S, S)\n    packed = packed.to(torch.int32)\n    packed = packed[..., 0] | (packed[..., 1] << 4) | (packed[..., 2] << 8) | (packed[..., 3] << 12)\n    # Reshape to (N // I, K // S, S) -> (N // I, K)\n    packed = packed.reshape(N // I, K)\n    return packed.to(torch.int16).contiguous()\n\n\ndef unpack_v2(packed):\n    \"\"\"Unpack a packed int16 tensor to a larger uint8 tensor\n\n    Applies pack operations in reverse order (see pack_v2 method for details).\n    Warning: very slow, to be used for debug only.\n\n    Args:\n        packed (`torch.Tensor`):\n            The packed `torch.int16` tensor\n\n    Returns:\n        An unpacked uint8 `torch.Tensor` expanded along the first dimension.\n    \"\"\"\n    assert packed.device.type in [\"cuda\", \"xpu\"]\n    assert packed.ndim == 2\n    I = 4\n    S = 64\n    N_div_I, K = packed.shape\n    N = N_div_I * I\n    # Reshape (N // I, K) -> (N // I, K // S, S, 1)\n    unpacked = packed.reshape(N // I, K // S, S, 1)\n    # Convert to uint16 (through numpy because not supported by pytorch)\n    unpacked = unpacked.cpu().numpy().astype(np.uint16)\n    # Unpack (N // I, K, S) -> (N // I, K // S, S, I)\n    unpacked = torch.cat(\n        [\n            torch.tensor((unpacked & 0xF).astype(np.uint8)).to(packed.device),\n            torch.tensor(((unpacked & 0xF0) >> 
4).astype(np.uint8)).to(packed.device),\n            torch.tensor(((unpacked & 0xF00) >> 8).astype(np.uint8)).to(packed.device),\n            torch.tensor(((unpacked & 0xF000) >> 12).astype(np.uint8)).to(packed.device),\n        ],\n        axis=-1,\n    )\n    # reshape (N // I, K // S, S, I) -> (N // I, K // S, I, S)\n    unpacked = unpacked.reshape(N // I, K // S, I, S)\n    # transpose (N // I, K // S, I, S) -> (N // I, I, K // S, S)\n    unpacked = unpacked.permute(0, 2, 1, 3)\n    # deinterleaving (N // I, I, K // S, S) -> (N, K)\n    unpacked = unpacked.reshape(N, K)\n\n    # Final steps to reorder (see packing code for explanation)\n    unpacked = unpacked.reshape(N, K // 32, 4, 2, 4).permute(0, 1, 2, 4, 3)\n    unpacked = unpacked.permute(0, 1, 3, 2, 4)\n    unpacked = unpacked.reshape(N, K)\n\n    return unpacked\n\n\nclass AWQPacking(Enum):\n    V1 = 1\n    V2 = 2\n\n\nclass AWQPackedTensor(torch.Tensor):\n    @staticmethod\n    def __new__(cls, data, packing, reorder, size, stride, requires_grad=False):\n        # AWQPackedTensor represents uint8 data and can therefore NEVER require gradient\n        assert data.device.type in [\"cuda\", \"xpu\"]\n        assert data.dtype == torch.int32 if packing == AWQPacking.V1 else torch.int16\n        assert requires_grad is False\n        return torch.Tensor._make_wrapper_subclass(\n            cls, size, strides=stride, dtype=torch.uint8, device=data.device, requires_grad=requires_grad\n        )\n\n    def __init__(self, data, packing, reorder, size, stride, requires_grad=False):\n        self._data = data\n        self._packing = packing\n        self._reorder = reorder\n\n    def __repr__(self):\n        return f\"AWQPackedTensor({self._data}, packing={self._packing}, reorder={self._reorder})\"\n\n    @classmethod\n    def pack(cls, t, packing=AWQPacking.V1, reorder=False):\n        if packing == AWQPacking.V1:\n            data = pack(t, reorder=reorder)\n        else:\n            data = pack_v2(t)\n       
 # We need to store size and stride to make sure the unpacked data has the correct shape\n        return AWQPackedTensor(data, packing, reorder, t.size(), t.stride())\n\n    def unpack(self):\n        if self._packing == AWQPacking.V1:\n            return unpack(self._data, self._reorder)\n        return unpack_v2(self._data)\n\n    @property\n    def dtype(self):\n        return torch.uint8\n\n    def __tensor_flatten__(self):\n        inner_tensors = [\"_data\"]\n        # Since meta can be used for serialization, use only AST compatible strings\n        meta = {\n            \"packing\": str(self._packing),\n            \"reorder\": str(self._reorder),\n            \"size\": str(list(self.size())),\n            \"stride\": str(self.stride()),\n        }\n        return inner_tensors, meta\n\n    @staticmethod\n    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):\n        assert len(inner_tensors) == 1\n        assert len(meta) == 4\n        data = inner_tensors[\"_data\"]\n        # Meta should contain only AST compatible strings\n        packing = ast.literal_eval(meta[\"packing\"])\n        reorder = ast.literal_eval(meta[\"reorder\"])\n        size = ast.literal_eval(meta[\"size\"])\n        stride = ast.literal_eval(meta[\"stride\"])\n        return AWQPackedTensor(data, packing, reorder, size, stride)\n\n    __torch_function__ = torch._C._disabled_torch_function_impl\n\n    @classmethod\n    def __torch_dispatch__(cls, op, types, args, kwargs=None):\n        # Convert back to tensor before calling any operation except detach and move\n        if op.overloadpacket is torch.ops.aten.detach:\n            t = args[0]\n            data = op(t._data)\n            return AWQPackedTensor(data, t._packing, t._reorder, t.size(), t.stride())\n        elif op.overloadpacket in (torch.ops.aten._to_copy, torch.ops.aten.to):\n            t = args[0]\n            dtype = kwargs.get(\"dtype\", torch.uint8)\n            if dtype != torch.uint8:\n      
          raise ValueError(f\"AWQPackedTensor are torch.uint8 only and cannot be moved to {dtype}.\")\n            device = kwargs.get(\"device\", t.device)\n            # AWQPackedTensor can only be moved to CUDA devices\n            if device.type == \"cuda\":\n                data_kwargs = copy(kwargs)\n                data_kwargs[\"dtype\"] = t._data.dtype\n                data = op(t._data, **data_kwargs)\n                return AWQPackedTensor(data, t._packing, t._reorder, t.size(), t.stride())\n        args, kwargs = pytree.tree_map_only(AWQPackedTensor, lambda x: x.unpack(), (args, kwargs or {}))\n        return op(*args, **kwargs)\n\n    def numpy(self):\n        return self.unpack().cpu().numpy()\n"
  },
  {
    "path": "optimum/quanto/tensor/weights/awq/qbits.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport ast\n\nimport torch\nfrom torch.autograd import Function\n\nfrom ...function import QuantizedLinearFunction\nfrom ...grouped import group, ungroup\nfrom ...qtype import qtypes\nfrom ..qbits import WeightQBitsTensor\nfrom .packed import AWQPackedTensor, AWQPacking\n\n\n__all__ = [\"AWQWeightQBitsTensor\"]\n\n\nclass AWQWeightQBitsDequantizer(Function):\n    @staticmethod\n    def forward(ctx, t):\n        unpacked = t._data.unpack()\n        scale = t._scale\n        shift = t._shift\n        unpacked = group(unpacked, axis=0, group_size=t._group_size)\n        n_scales = scale.numel()\n        scale = scale.t().reshape((n_scales, 1))\n        shift = shift.t().reshape((n_scales, 1))\n        if shift.dtype.is_floating_point:\n            # Shift is already scaled and negated on CUDA\n            dqt = scale * unpacked + shift\n        else:\n            # Shift is int type on XPU to support pytorch fused op\n            dqt = (unpacked - shift) * scale\n        return ungroup(dqt, axis=t.axis, orig_shape=t.shape)\n\n    @staticmethod\n    def backward(ctx, gO):\n        return gO\n\n\nclass AWQWeightQBitsLinearFunction(QuantizedLinearFunction):\n    @staticmethod\n    def forward(ctx, input, other, bias):\n        ctx.save_for_backward(input, other)\n        if type(input) is not torch.Tensor:\n            input = 
input.dequantize()\n        out_features, in_features = other.shape\n        rows = input.numel() // in_features\n        output = torch.ops.quanto.gemm_f16i4_awq(\n            input,\n            other._data._data,\n            other._scale,\n            other._shift,\n            rows=rows,\n            out_cols=out_features,\n            in_cols=in_features,\n            bits=4,\n            group_size=other._group_size,\n        )\n        if bias is not None:\n            output = output + bias\n        return output\n\n\nclass AWQWeightQBitsTensor(WeightQBitsTensor):\n    @staticmethod\n    def __new__(cls, qtype, axis, group_size, size, stride, data, scale, shift, requires_grad=False):\n        assert data.device.type in [\"cuda\", \"xpu\"]\n        assert data.device == scale.device\n        assert data.device == shift.device\n        return torch.Tensor._make_wrapper_subclass(\n            cls, size, strides=stride, dtype=scale.dtype, device=data.device, requires_grad=requires_grad\n        )\n\n    def __init__(self, qtype, axis, group_size, size, stride, data, scale, shift, requires_grad=False):\n        # XPU requires awq v1 to support pytorch fused op\n        self.packing_type = AWQPacking.V1 if data.device.type == \"xpu\" else AWQPacking.V2\n        assert axis == 0\n        if not isinstance(data, AWQPackedTensor):\n            assert type(data) is torch.Tensor\n            # Format data, scale and shift for optimized CUDA/XPU gemm\n            ungrouped = ungroup(data, axis=0, orig_shape=size)\n            data = AWQPackedTensor.pack(ungrouped, packing=self.packing_type)\n            out_features, in_features = size\n            scale = scale.reshape(out_features, in_features // group_size).t().contiguous()\n            shift = shift.reshape(out_features, in_features // group_size).t()\n            if not shift.dtype.is_floating_point and data.device.type != \"xpu\":\n                # Integer shift must be scaled\n                shift = scale * 
shift\n            # Shift must be negated\n            shift = shift.contiguous() if data.device.type == \"xpu\" else -shift.contiguous()\n        super().__init__(qtype, axis, group_size, size, stride, data, scale, shift)\n\n    def dequantize(self):\n        return AWQWeightQBitsDequantizer.apply(self)\n\n    def weight_qbits_tensor(self):\n        \"\"\"Convert back to a WeightQBitsTensor\n\n        This is required to make sure only standard packing is used when serializing.\n        \"\"\"\n        data = group(self._data.unpack(), axis=self.axis, group_size=self._group_size)\n        n_scales = self._scale.numel()\n        scale = self._scale.t().reshape((n_scales, 1))\n        shift = self._shift if self._shift.device.type == \"xpu\" else -self._shift\n        shift = shift.t().reshape((n_scales, 1))\n        return WeightQBitsTensor(\n            self._qtype, self._axis, self._group_size, self.size(), self.stride(), data, scale, shift\n        )\n\n    def __tensor_flatten__(self):\n        inner_tensors = [\"_data\", \"_scale\", \"_shift\"]\n        # Since meta can be used for serialization, use only strings\n        meta = {\n            \"qtype\": self._qtype.name,\n            \"axis\": str(self._axis),\n            \"group_size\": str(self._group_size),\n            \"size\": str(list(self.size())),\n            \"stride\": str(list(self.stride())),\n        }\n        return inner_tensors, meta\n\n    @staticmethod\n    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):\n        assert len(inner_tensors) == 3\n        assert len(meta) == 5\n        data, scale, shift = inner_tensors[\"_data\"], inner_tensors[\"_scale\"], inner_tensors[\"_shift\"]\n        # Meta should only contain strings, AST compatible except qtype\n        qtype = qtypes[meta[\"qtype\"]]\n        axis = ast.literal_eval(meta[\"axis\"])\n        group_size = ast.literal_eval(meta[\"group_size\"])\n        size = ast.literal_eval(meta[\"size\"])\n        
stride = ast.literal_eval(meta[\"stride\"])\n        return AWQWeightQBitsTensor(qtype, axis, group_size, size, stride, data, scale, shift)\n\n    @classmethod\n    def __torch_function__(cls, func, types, args=(), kwargs=None):\n        \"\"\"Dispatch torch functions applied on this subtensor\n\n        This method is called whenever a torch function (such as `torch.nn.functional.linear`)\n        is called with at least one parameter corresponding to this subtensor:\n\n        - if a quantized implementation exists for the selected function, it is called,\n        - otherwise, the original implementation is called, deactivating further functional dispatch.\n\n        During the execution of the standard torch function, a second-level of dispatch will\n        happen, but this time directly on individual torch Tensor operations (mainly ATEN).\n        \"\"\"\n        kwargs = kwargs or {}\n        if func is torch.nn.functional.linear:\n\n            def qlinear(input, other, bias=None):\n                return AWQWeightQBitsLinearFunction.apply(input, other, bias)\n\n            return qlinear(*args, **kwargs)\n        # Defer to operations dispatcher\n        with torch._C.DisableTorchFunctionSubclass():\n            return func(*args, **kwargs)\n"
  },
  {
    "path": "optimum/quanto/tensor/weights/marlin/__init__.py",
    "content": "from .fp8 import *\nfrom .int4 import *\nfrom .permutations import *\n"
  },
  {
    "path": "optimum/quanto/tensor/weights/marlin/fp8/__init__.py",
    "content": "from .packed import *\nfrom .qbits import *\n"
  },
  {
    "path": "optimum/quanto/tensor/weights/marlin/fp8/packed.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport ast\nfrom copy import copy\n\nimport torch\nfrom torch.utils import _pytree as pytree\n\n\ndef pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Repack FP8 weights to gptq format (packed int32 elements).\n    \"\"\"\n    assert fp8_tensor.dtype == torch.float8_e4m3fn\n\n    if fp8_tensor.shape[0] % 4 != 0:\n        raise ValueError(f\"Leading tensor dimension is not divisable by 4: {fp8_tensor.shape[0]}\")\n\n    # Reshape to prepare for packing\n    reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])\n\n    # Convert fp8 to uint8 (byte) representation\n    byte_tensor = reshaped.view(torch.uint8)\n\n    # Pack 4 uint8 values into one int32\n    packed = torch.zeros(\n        fp8_tensor.shape[0] // 4,\n        fp8_tensor.shape[1],\n        dtype=torch.int32,\n        device=fp8_tensor.device,\n    )\n\n    for i in range(4):\n        packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8)\n\n    return packed\n\n\ndef unpack_int32_to_fp8(int32_tensor: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Reinterpret a tensor (a, b) of type int32 to a tensor (a * 4, b) of type float8_e4m3fn.\n    \"\"\"\n    bits = 8\n\n    unpacked = []\n    # Unpack each set of values independently\n    for i in range(4):\n        mask = 2 ** (bits * (i + 1)) - 1\n        tmp = (int32_tensor & mask) >> bits * 
i\n        tmp = tmp.to(torch.uint8)\n        unpacked.append(tmp)\n\n    # Return the concatenated unpacked tensors\n    unpacked = torch.cat(unpacked).view(torch.float8_e4m3fn)\n\n    return unpacked\n\n\ndef get_scale_perms() -> torch.Tensor:\n    scale_perm_single = []\n    for i in range(4):\n        scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])\n    return torch.tensor(scale_perm_single, dtype=torch.int64)\n\n\ndef get_row_permutation(n_rows: int) -> torch.Tensor:\n    \"\"\"\n    Generates a tensor of shape (4 * n_rows,) giving the rows mapping to map from marlin-repacked weights to natural order.\n\n    Example: if n_rows = 8, the row mapping from natural to marlin format is\n    rows_idx = [0,  2,  4,  6,\n                16, 18, 20, 22,\n                8, 10, 12, 14,\n                24, 26, 28, 30,\n                1,  3,  5,  7,\n                17, 19, 21, 23,\n                9, 11, 13, 15,\n                25, 27, 29, 31].\n    \"\"\"\n    modulo = n_rows // 4 * 16 - 8\n    b = n_rows // 2\n\n    # Group by 16*k, then by 8 + 16*k\n    rows_idx = [(i * 16) % modulo for i in range(b)]\n    rows_idx[-1] = rows_idx[-2] + 16 if b > 2 else 8\n    rows_idx = torch.tensor(rows_idx)\n\n    # All even indexes, and then all odd indexes.\n    rows_idx = torch.cat((rows_idx, rows_idx + 1))\n\n    # Indexes are grouped by four, each spaced by 2.\n    rows_idx = torch.tile(rows_idx[:, None], (1, 4))\n    rows_idx = rows_idx + torch.tensor([[0, 2, 4, 6]])\n\n    rows_idx = rows_idx.reshape(-1)\n\n    # `rows_idx` holds the mapping of natural rows to marlin rows, so inverse it.\n    rows_idx_rev = torch.empty_like(rows_idx)\n    rows_idx_rev[rows_idx] = torch.arange(len(rows_idx))\n\n    return rows_idx_rev\n\n\ndef get_column_permutation(n_col: int) -> torch.Tensor:\n    \"\"\"\n    Gets the column mapping to map from marlin-repacked weights to natural order.\n\n    The natural order to marlin is: `8 * rest + frac` to `rest + 32 * 
frac`, by blocks of 256 values.\n    \"\"\"\n    tile_size = 256\n    n_blocks = n_col // tile_size\n\n    a = torch.arange(tile_size)\n    rest = a % 8\n    frac = a // 8\n\n    original_index = 32 * rest + frac\n\n    original_index = torch.arange(n_blocks)[:, None] * 256 + original_index\n    original_index = original_index.reshape(-1)\n\n    # The mapping per-column is:\n    #\n    #      64   64   64   64      64   64   64   64       64   64   64   64\n    # ------------------------------------------------------------------------\n    # |    0    1    2    3  |    0    1    2    3   |    0    1    2    3   |\n    # ------------------------------------------------------------------------\n    #\n    # Hence to retrieve column 0, 1, 2, 3 in order, we need to\n    # shuffle the blocks of 64 values.\n    original_index = original_index.reshape(4 * n_blocks, 64)\n\n    # Generate a shuffling as e.g. [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11] for the above.\n    tmp1 = torch.arange(4)\n    tmp1 = tmp1.repeat(n_blocks, 1).T.reshape(-1)  # e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]\n\n    tmp2 = torch.arange(n_blocks) * 4\n    tmp2 = tmp2.repeat(4)  # e.g. 
[0, 4, 8, 0, 4, 8, 0, 4, 8, 0, 4, 8]\n\n    remap_col_index = tmp1 + tmp2\n\n    original_index = original_index[remap_col_index]\n    original_index = original_index.reshape(-1)\n\n    return original_index\n\n\nclass MarlinF8PackedTensor(torch.Tensor):\n    def __new__(cls, data, size, stride, requires_grad=False):\n        assert data.device.type == \"cuda\"\n        assert data.dtype == torch.int32\n        assert requires_grad is False\n        return torch.Tensor._make_wrapper_subclass(\n            cls, size, strides=stride, dtype=torch.int32, device=data.device, requires_grad=requires_grad\n        )\n\n    def __init__(self, data, size, stride, requires_grad=False):\n        self._data = data\n\n    def __repr__(self):\n        return f\"MarlinF8PackedTensor({self._data})\"\n\n    @classmethod\n    def pack(cls, tensor: torch.Tensor):\n        out_features, in_features = tensor.shape\n\n        data_int32 = pack_fp8_as_int32(tensor.T)  # pack fp8 data to int32.\n\n        perm = torch.empty(0, dtype=torch.int, device=tensor.device)\n\n        data_int32 = torch.ops.quanto.pack_fp8_marlin(\n            b_q_weight=data_int32, perm=perm, size_k=in_features, size_n=out_features, num_bits=8\n        )\n\n        return cls(data_int32, size=tensor.size(), stride=tensor.stride())\n\n    def unpack(self) -> torch.Tensor:\n        \"\"\"\n        Reinterprets the packed tensor (a, b) of type int32 and in the marlin order, to a tensor (a * 4, b) of type float8_e4m3fn, in the natural order.\n        \"\"\"\n        float8_data = unpack_int32_to_fp8(self._data)\n\n        # complex indexing is not implemented for 'Float8_e4m3fn'\n        uint8_data = float8_data.view(torch.uint8)\n\n        n_rows, n_col = uint8_data.shape\n\n        # swap columns\n        column_map = get_column_permutation(n_col=n_col)\n\n        uint8_data = uint8_data.T.contiguous()\n        uint8_data = uint8_data[column_map]\n        uint8_data = uint8_data.T.contiguous()\n\n        uint8_data = 
uint8_data.reshape(uint8_data.shape[0] * 4, -1)\n\n        # swap rows\n        row_map = get_row_permutation(n_rows=n_rows)\n\n        uint8_data = uint8_data[row_map]\n\n        float8_data = uint8_data.view(torch.float8_e4m3fn)\n        float8_data = float8_data.T  # As we originally transposed in `pack_fp8_as_int32`\n\n        return float8_data\n\n    @property\n    def dtype(self):\n        return torch.int32\n\n    def __tensor_flatten__(self):\n        inner_tensors = [\"_data\"]\n        # Since meta can be used for serialization, use only AST compatible strings\n        meta = {\n            \"size\": str(list(self.size())),\n            \"stride\": str(self.stride()),\n        }\n        return inner_tensors, meta\n\n    @staticmethod\n    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):\n        assert len(inner_tensors) == 1\n        assert len(meta) == 2\n        data = inner_tensors[\"_data\"]\n        # Meta should contain only AST compatible strings\n        size = ast.literal_eval(meta[\"size\"])\n        stride = ast.literal_eval(meta[\"stride\"])\n        return MarlinF8PackedTensor(data, size, stride)\n\n    __torch_function__ = torch._C._disabled_torch_function_impl\n\n    @classmethod\n    def __torch_dispatch__(cls, op, types, args, kwargs=None):\n        # Convert back to tensor before calling any operation except detach and move\n        if op.overloadpacket is torch.ops.aten.detach:\n            t = args[0]\n            data = op(t._data)\n            return cls(data, t.size(), t.stride())\n        elif op.overloadpacket in (torch.ops.aten._to_copy, torch.ops.aten.to):\n            t = args[0]\n            dtype = kwargs.get(\"dtype\", torch.int32)\n            if dtype != torch.int32:\n                raise ValueError(f\"MarlinF8PackedTensor are torch.int32 only and cannot be moved to {dtype}.\")\n            device = kwargs.get(\"device\", t.device)\n            if device.type == \"cuda\":\n                
data_kwargs = copy(kwargs)\n                data_kwargs[\"dtype\"] = t._data.dtype\n                data = op(t._data, **data_kwargs)\n                return cls(data, t.size(), t.stride())\n            else:\n                return t.unpack().to(device)\n        else:\n            args, kwargs = pytree.tree_map_only(cls, lambda x: x.unpack(), (args, kwargs or {}))\n            return op(*args, **kwargs)\n"
  },
  {
    "path": "optimum/quanto/tensor/weights/marlin/fp8/qbits.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport ast\n\nimport torch\n\nfrom ....function import QuantizedLinearFunction\nfrom ....qtype import qfloat8_e4m3fn, qtypes\nfrom ...qbytes import WeightQBytesTensor\nfrom .packed import MarlinF8PackedTensor, get_scale_perms\n\n\n__all__ = [\"MarlinF8QBytesTensor\"]\n\n\nclass MarlinF8QBytesLinearFunction(QuantizedLinearFunction):\n    @staticmethod\n    def forward(ctx, input, other, bias=None):\n        ctx.save_for_backward(input, other)\n        input_shape = input.shape\n\n        if input.ndim > 2:\n            input = input.reshape(-1, input_shape[-1])\n\n        output = torch.ops.quanto.gemm_f16f8_marlin(\n            input,\n            b_q_weight=other._data._data,\n            b_scales=other._scale,  # .to(input.dtype)\n            workspace=other._workspace,\n            num_bits=8,\n            size_m=input.shape[0],\n            size_n=other._scale.shape[1],\n            size_k=input.shape[1],\n        )\n\n        if len(input_shape) > 2:\n            output = output.reshape(input_shape[:-1] + (other._scale.shape[1],))\n\n        return output\n\n\nclass MarlinF8QBytesTensor(WeightQBytesTensor):\n    @staticmethod\n    def __new__(cls, qtype, axis, size, stride, data, scale, requires_grad=False):\n        assert data.device.type == \"cuda\"\n        assert data.device == scale.device\n        return 
torch.Tensor._make_wrapper_subclass(\n            cls, size, strides=stride, dtype=scale.dtype, device=data.device, requires_grad=requires_grad\n        )\n\n    def __init__(self, qtype, axis, size, stride, data, scale, requires_grad=False):\n        assert axis == 0\n        assert data.ndim == 2\n\n        out_features = size[0]\n        self._workspace = torch.zeros(out_features // 64 * 16, dtype=torch.int, device=data.device)\n\n        # TODO: Here we should use `not isinstance(data, MarlinF8PackedTensor)`, but `torch.compile` is bugged when using that.\n        # Somewhere in the internals of torch.compile, `data` gets converted to a `torch._subclasses.fake_tensor.FakeTensor` not inheriting from `MarlinF8PackedTensor` and torch then goes into the wrong control flow.\n        # Reference: https://pytorch.slack.com/archives/C033H6DJSJU/p1721837684035049\n        if data.dtype != torch.int32:\n            assert scale.shape == (out_features, 1)\n            scale_perm_single = get_scale_perms()\n            scale = scale.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]\n            scale = scale.reshape(-1, out_features).contiguous()\n\n            data_packed = MarlinF8PackedTensor.pack(data)  # pack fp8 data to int32, and apply marlin re-ordering.\n        else:\n            # When freezing (`model.freeze()`), the data is already a MarlinF8PackedTensor and scale is already repacked.\n            data_packed = data\n\n        super().__init__(\n            qtype, axis, size, stride, data_packed, scale, activation_qtype=qfloat8_e4m3fn, requires_grad=requires_grad\n        )\n\n    def dequantize(self):\n        float8_data = self._data.unpack()\n\n        scale_perm_single = get_scale_perms()\n\n        # `scale_perm_single` holds the mapping of natural to marlin, so inverse it here.\n        scale_perm_single_rev = torch.empty_like(scale_perm_single)\n        scale_perm_single_rev[scale_perm_single] = torch.arange(len(scale_perm_single))\n\n        
scale_reordered = self._scale.reshape((-1, len(scale_perm_single_rev)))[:, scale_perm_single_rev]\n        scale_reordered = scale_reordered.reshape(-1, self._scale.shape[1]).contiguous()\n\n        return float8_data.to(scale_reordered.dtype) * scale_reordered.T\n\n    def __repr__(self):\n        return f\"MarlinF8QBytesTensor({self._data}, scale={self._scale}, dtype={self.dtype})\"\n\n    def weight_qbytes_tensor(self):\n        data = self._data.unpack()\n        scale_perm_single = get_scale_perms()\n\n        # `scale_perm_single` holds the mapping of natural to marlin, so inverse it here.\n        scale_perm_single_rev = torch.empty_like(scale_perm_single)\n        scale_perm_single_rev[scale_perm_single] = torch.arange(len(scale_perm_single))\n\n        scale_reordered = self._scale.reshape((-1, len(scale_perm_single_rev)))[:, scale_perm_single_rev]\n        scale_reordered = scale_reordered.reshape(-1, self._scale.shape[1]).t().contiguous()\n        return WeightQBytesTensor(\n            self._qtype, self._axis, self.size(), self.stride(), data, scale_reordered, self.activation_qtype\n        )\n\n    def __tensor_flatten__(self):\n        inner_tensors = [\"_data\", \"_scale\"]\n        meta = {\n            \"qtype\": self._qtype.name,\n            \"axis\": str(self._axis),\n            \"size\": str(list(self.size())),\n            \"stride\": str(list(self.stride())),\n        }\n        return inner_tensors, meta\n\n    @staticmethod\n    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):\n        assert len(inner_tensors) == 2\n        assert len(meta) == 4\n        data, scale = inner_tensors[\"_data\"], inner_tensors[\"_scale\"]\n        # Meta should only contain strings, AST compatible except qtype\n        qtype = qtypes[meta[\"qtype\"]]\n        axis = ast.literal_eval(meta[\"axis\"])\n        size = ast.literal_eval(meta[\"size\"])\n        stride = ast.literal_eval(meta[\"stride\"])\n        return 
MarlinF8QBytesTensor(qtype, axis, size, stride, data, scale)\n\n    @classmethod\n    def __torch_function__(cls, func, types, args=(), kwargs=None):\n        \"\"\"Dispatch torch functions applied on this subtensor\n\n        This method is called whenever a torch function (such as `torch.nn.functional.linear`)\n        is called with at least one parameter corresponding to this subtensor:\n\n        - if a quantized implementation exists for the selected function, it is called,\n        - otherwise, the original implementation is called, deactivating further functional dispatch.\n\n        During the execution of the standard torch function, a second-level of dispatch will\n        happen, but this time directly on individual torch Tensor operations (mainly ATEN).\n        \"\"\"\n        kwargs = kwargs or {}\n        if func is torch.nn.functional.linear:\n\n            def qlinear(input, other, bias=None):\n                return MarlinF8QBytesLinearFunction.apply(input, other, bias)\n\n            return qlinear(*args, **kwargs)\n        elif func is torch.equal:\n            input, other = args\n            return input.equal(other)\n\n        # Defer to operations dispatcher\n        with torch._C.DisableTorchFunctionSubclass():\n            return func(*args, **kwargs)\n"
  },
  {
    "path": "optimum/quanto/tensor/weights/marlin/int4/__init__.py",
    "content": "from .packed import *\nfrom .qbits import *\n"
  },
  {
    "path": "optimum/quanto/tensor/weights/marlin/int4/packed.py",
    "content": "import ast\nfrom copy import copy\n\nimport numpy as np\nimport torch\nfrom torch.utils import _pytree as pytree\n\nfrom ...packing import unpack_int32_to_uint8\nfrom ...reordering import reorder, reverse\n\n\n__all__ = [\"MarlinInt4PackedTensor\"]\n\n\n# From: https://github.com/IST-DASLab/marlin/blob/master/marlin/__init__.py#L40\n# this func does 2 things\n# 1. 1 thread can load 32 4bit == 128bit weights used for mulitple mma instructions at once\n# 2. faster dequant via parallel half2 mul\ndef _get_perm():\n    perm = []\n    # 32 == # of threads in 1 warp\n    for i in range(32):\n        perm1 = []\n        # column id in 16x8 weight block\n        # check https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float\n        col = i // 4\n        # 1 32bit (int32) == 8 4bit, 1 thread has 4 weights per 16x8 & 4bit weights are packed in int32, so needs 2 16x8 == 1 16x16 blocks\n        for block in [0, 1]:\n            # row id in 16x8 weight block\n            # check https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float\n            for row in [\n                2 * (i % 4),\n                2 * (i % 4) + 1,\n                2 * (i % 4 + 4),\n                2 * (i % 4 + 4) + 1,\n            ]:\n                # 8 weights used for 1 thread (16x16 block) are contiguous in memory via interleaving\n                # e.g. T0 uses (0, 16, 128, 144, 8, 24, 136, 152)\n                perm1.append(16 * row + col + 8 * block)\n        # 1 128bit (int4) == 4 32bit, 1 thread loads 128bit at once, so needs 4 16x16 == 1 16x64 blocks\n        for j in range(4):\n            # 32 weights loaded by 1 thread (16x64 block) are contiguous in memory via interleaving\n            # e.g. 
T0 uses ((0 ~ 152) + 0 * 256, (0 ~ 152) + 1 * 256, ..., (0 ~ 152) + 3 * 256)\n            perm.extend([p + 256 * j for p in perm1])\n    perm = np.array(perm)\n    # for faster dequant\n    # check https://arxiv.org/pdf/2211.10017\n    interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])\n    perm = perm.reshape((-1, 8))[:, interleave].ravel()\n    perm = torch.from_numpy(perm)\n    return perm\n\n\n_perm = _get_perm()\n_rev_perm = reverse(_perm)\n\n\n# From: https://github.com/IST-DASLab/marlin/blob/master/marlin/__init__.py#L102\ndef pack(unpacked: torch.Tensor):\n    w = unpacked\n    N, K = w.shape\n    w = unpacked.t()\n    # 16 == tile size, marlin uses 16x16 tile, so 16x16 grouping via interleaving\n    w = w.reshape((K // 16, 16, N // 16, 16))\n    w = w.permute((0, 2, 1, 3))\n    w = w.reshape((K // 16, N * 16))\n    res = w\n    # _perm.numel() == 1024 == 4 16x16, permute weights with 4 16x16 unit for efficient mma + dequant\n    res = res.reshape((-1, _perm.numel()))[:, _perm].reshape(res.shape)\n    p = np.zeros((res.shape[0], res.shape[1] // 8), dtype=np.uint32)\n    res = res.cpu().numpy().astype(np.uint32)\n    for i in range(8):\n        p |= res[:, i::8] << 4 * i\n    p = torch.from_numpy(p.astype(np.int32)).to(w.device)\n    return p\n\n\ndef unpack(packed, orig_shape):\n    N, K = orig_shape\n    # Unpack to recover individual values\n    unpacked = unpack_int32_to_uint8(packed, bits=4).to(torch.uint8)\n    # Recover the original ordering\n    unpacked = reorder(unpacked, _rev_perm)\n    # Apply block permutations in the reverse order\n    unpacked = unpacked.reshape(K // 16, N // 16, 16, 16)\n    unpacked = unpacked.permute((0, 2, 1, 3))\n    unpacked = unpacked.reshape(K, N)\n    return unpacked.t()\n\n\nclass MarlinInt4PackedTensor(torch.Tensor):\n    @staticmethod\n    def __new__(cls, data, size, stride, requires_grad=False):\n        assert data.device.type == \"cuda\"\n        assert data.dtype == torch.int32\n        assert requires_grad is 
False\n        return torch.Tensor._make_wrapper_subclass(\n            cls, size, strides=stride, dtype=torch.uint8, device=data.device, requires_grad=requires_grad\n        )\n\n    def __init__(self, data, size, stride, requires_grad=False):\n        self._data = data\n\n    def __repr__(self):\n        return f\"MarlinInt4PackedTensor({self._data})\"\n\n    @classmethod\n    def pack(cls, t):\n        data = pack(t)\n        return MarlinInt4PackedTensor(data, t.size(), t.stride())\n\n    def unpack(self):\n        return unpack(self._data, self.size())\n\n    @property\n    def dtype(self):\n        return torch.uint8\n\n    def __tensor_flatten__(self):\n        inner_tensors = [\"_data\"]\n        meta = {\n            \"size\": str(list(self.size())),\n            \"stride\": str(self.stride()),\n        }\n        return inner_tensors, meta\n\n    @staticmethod\n    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):\n        assert len(inner_tensors) == 1\n        assert len(meta) == 2\n        data = inner_tensors[\"_data\"]\n        size = ast.literal_eval(meta[\"size\"])\n        stride = ast.literal_eval(meta[\"stride\"])\n        return MarlinInt4PackedTensor(data, size, stride)\n\n    __torch_function__ = torch._C._disabled_torch_function_impl\n\n    @classmethod\n    def __torch_dispatch__(cls, op, types, args, kwargs=None):\n        if op.overloadpacket is torch.ops.aten.detach:\n            t = args[0]\n            data = op(t._data)\n            return MarlinInt4PackedTensor(data, t.size(), t.stride())\n        elif op.overloadpacket in (torch.ops.aten._to_copy, torch.ops.aten.to):\n            t = args[0]\n            dtype = kwargs.get(\"dtype\", torch.uint8)\n            if dtype != torch.uint8:\n                raise ValueError(f\"MarlinInt4PackedTensor are torch.uint8 only and cannot be moved to {dtype}.\")\n            device = kwargs.get(\"device\", t.device)\n            if device.type == \"cuda\":\n                
data_kwargs = copy(kwargs)\n                data_kwargs[\"dtype\"] = t._data.dtype\n                data = op(t._data, **data_kwargs)\n                return MarlinInt4PackedTensor(data, t.size(), t.stride())\n            return t.unpack()\n        args, kwargs = pytree.tree_map_only(MarlinInt4PackedTensor, lambda x: x.unpack(), (args, kwargs or {}))\n        return op(*args, **kwargs)\n\n    def numpy(self):\n        return self.unpack().cpu().numpy()\n"
  },
  {
    "path": "optimum/quanto/tensor/weights/marlin/int4/qbits.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport ast\n\nimport torch\nfrom torch.autograd import Function\n\nfrom ....function import QuantizedLinearFunction\nfrom ....grouped import group, ungroup\nfrom ....qtype import qtypes\nfrom ...qbits import WeightQBitsTensor\nfrom ..permutations import marlin_permute\nfrom .packed import MarlinInt4PackedTensor\n\n\n__all__ = [\"MarlinInt4WeightQBitsTensor\"]\n\n\nclass MarlinQBitsDequantizer(Function):\n    @staticmethod\n    def forward(ctx, t):\n        unpacked = t._data.unpack()\n        scale = t._scale\n        shift = t._shift\n        unpacked = group(unpacked, axis=0, group_size=t._group_size)\n        # Apply inverted permutations\n        scale = marlin_permute(scale, reverse=True)\n        shift = marlin_permute(shift, reverse=True)\n        n_scales = scale.numel()\n        scale = scale.t().reshape((n_scales, 1))\n        shift = shift.t().reshape((n_scales, 1))\n        # Shift is already scaled and negated\n        dqt = scale * unpacked + shift\n        return ungroup(dqt, axis=t.axis, orig_shape=t.shape)\n\n    @staticmethod\n    def backward(ctx, gO):\n        return gO\n\n\nclass MarlinQBitsLinearFunction(QuantizedLinearFunction):\n    @staticmethod\n    def forward(ctx, input, other, bias):\n        ctx.save_for_backward(input, other)\n        if type(input) is not torch.Tensor:\n            input = 
input.dequantize()\n        out_features, in_features = other.shape\n        output = torch.ops.quanto.gemm_f16i4_marlin(\n            input,\n            other._data._data,\n            other._scale,\n            other._shift,\n            other._workspace,\n        )\n        if bias is not None:\n            output = output + bias\n        return output\n\n\nclass MarlinInt4WeightQBitsTensor(WeightQBitsTensor):\n    @staticmethod\n    def __new__(cls, qtype, axis, group_size, size, stride, data, scale, shift, requires_grad=False):\n        assert data.device.type == \"cuda\"\n        assert data.device == scale.device\n        assert data.device == shift.device\n        return torch.Tensor._make_wrapper_subclass(\n            cls, size, strides=stride, dtype=scale.dtype, device=data.device, requires_grad=requires_grad\n        )\n\n    def __init__(self, qtype, axis, group_size, size, stride, data, scale, shift, requires_grad=False):\n        assert axis == 0\n        out_features, in_features = size\n        if not isinstance(data, MarlinInt4PackedTensor):\n            assert type(data) is torch.Tensor\n            # Format data, scale and shift for optimized CUDA gemm\n            ungrouped = ungroup(data, axis=0, orig_shape=size)\n            data = MarlinInt4PackedTensor.pack(ungrouped)\n            scale = scale.reshape(out_features, in_features // group_size).t().contiguous()\n            shift = shift.reshape(out_features, in_features // group_size).t()\n            if not shift.dtype.is_floating_point:\n                # Integer shift must be scaled\n                shift = scale * shift\n            # Shift must be negated\n            shift = -shift.contiguous()\n            # Finally, apply scale and shift permutations\n            scale = marlin_permute(scale)\n            shift = marlin_permute(shift)\n        super().__init__(qtype, axis, group_size, size, stride, data, scale, shift)\n        self._workspace = torch.zeros(out_features // 128 * 16, 
dtype=torch.int, device=data.device)\n\n    def dequantize(self):\n        return MarlinQBitsDequantizer.apply(self)\n\n    def weight_qbits_tensor(self):\n        \"\"\"Convert back to a WeightQBitsTensor\n\n        This is required to make sure only standard packing is used when serializing.\n        \"\"\"\n        data = group(self._data.unpack(), axis=self.axis, group_size=self._group_size)\n        scale = marlin_permute(self._scale, reverse=True)\n        shift = marlin_permute(self._shift, reverse=True)\n        n_scales = scale.numel()\n        scale = scale.t().reshape((n_scales, 1))\n        shift = -shift.t().reshape((n_scales, 1))\n        return WeightQBitsTensor(\n            self._qtype, self._axis, self._group_size, self.size(), self.stride(), data, scale, shift\n        )\n\n    def __tensor_flatten__(self):\n        inner_tensors = [\"_data\", \"_scale\", \"_shift\"]\n        # Since meta can be used for serialization, use only strings\n        meta = {\n            \"qtype\": self._qtype.name,\n            \"axis\": str(self._axis),\n            \"group_size\": str(self._group_size),\n            \"size\": str(list(self.size())),\n            \"stride\": str(list(self.stride())),\n        }\n        return inner_tensors, meta\n\n    @staticmethod\n    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):\n        assert len(inner_tensors) == 3\n        assert len(meta) == 5\n        data, scale, shift = inner_tensors[\"_data\"], inner_tensors[\"_scale\"], inner_tensors[\"_shift\"]\n        # Meta should only contain strings, AST compatible except qtype\n        qtype = qtypes[meta[\"qtype\"]]\n        axis = ast.literal_eval(meta[\"axis\"])\n        group_size = ast.literal_eval(meta[\"group_size\"])\n        size = ast.literal_eval(meta[\"size\"])\n        stride = ast.literal_eval(meta[\"stride\"])\n        return MarlinInt4WeightQBitsTensor(qtype, axis, group_size, size, stride, data, scale, shift)\n\n    @classmethod\n    
def __torch_function__(cls, func, types, args=(), kwargs=None):\n        \"\"\"Dispatch torch functions applied on this subtensor\n\n        This method is called whenever a torch function (such as `torch.nn.functional.linear`)\n        is called with at least one parameter corresponding to this subtensor:\n\n        - if a quantized implementation exists for the selected function, it is called,\n        - otherwise, the original implementation is called, deactivating further functional dispatch.\n\n        During the execution of the standard torch function, a second-level of dispatch will\n        happen, but this time directly on individual torch Tensor operations (mainly ATEN).\n        \"\"\"\n        kwargs = kwargs or {}\n        if func is torch.nn.functional.linear:\n\n            def qlinear(input, other, bias=None):\n                return MarlinQBitsLinearFunction.apply(input, other, bias)\n\n            return qlinear(*args, **kwargs)\n        # Defer to operations dispatcher\n        with torch._C.DisableTorchFunctionSubclass():\n            return func(*args, **kwargs)\n"
  },
  {
    "path": "optimum/quanto/tensor/weights/marlin/permutations.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#         http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport functools\nfrom typing import List, Tuple\n\nimport torch\n\nfrom ..reordering import reorder, reverse\n\n\n__all__ = [\"marlin_permute\"]\n\n\n# https://github.com/IST-DASLab/marlin/blob/2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c/marlin/__init__.py#L40C1-L68C54\n@functools.cache\ndef _get_perms() -> Tuple[List[int], List[int]]:\n    perm = []\n    for i in range(8):\n        perm.extend([i + 8 * j for j in range(8)])\n    perm_single = []\n    for i in range(4):\n        perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])\n    return perm, perm_single\n\n\n@functools.cache\ndef _get_inverted_perms() -> Tuple[List[int], List[int]]:\n    perm, perm_single = _get_perms()\n    return reverse(perm), reverse(perm_single)\n\n\ndef marlin_permute(t: torch.Tensor, reverse=False):\n    perm, perm_single = _get_inverted_perms() if reverse else _get_perms()\n    out_features = t.shape[1]\n    if t.shape[0] == 1:\n        reordered = reorder(t, perm_single)\n    else:\n        reordered = reorder(t, perm)\n    return reordered.reshape((-1, out_features)).contiguous()\n"
  },
  {
    "path": "optimum/quanto/tensor/weights/packing.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\n\n\ndef unpack_int32_to_uint8(packed: torch.Tensor, bits: int):\n    \"\"\"Unpack a packed int32 tensor to a larger uint8 tensor\n\n    Args:\n        packed (`torch.Tensor`):\n            The packed integer tensor\n        bits: (`int`):\n            The number of bits of each packed value.\n\n    Returns:\n        An unpacked uint8 `torch.Tensor` expanded along the last dimension.\n    \"\"\"\n    total_bits = 32\n    shifts = torch.arange(0, total_bits, bits, device=packed.device)\n\n    # Unpack column-wise\n    unpacked = torch.bitwise_right_shift(packed[:, :, None], shifts[None, None, :]).to(\n        torch.int8  # smallest dtype available\n    )\n    unpacked = unpacked.reshape(unpacked.shape[0], -1)\n\n    # Convert to unsigned\n    unpacked = torch.bitwise_and(unpacked, (2**bits) - 1)\n\n    unpacked = unpacked if packed.device.type == \"xpu\" else unpacked.to(torch.uint8)\n\n    return unpacked\n"
  },
  {
    "path": "optimum/quanto/tensor/weights/qbits.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport ast\nfrom typing import Optional\n\nimport torch\nfrom packaging import version\nfrom torch.autograd import Function\n\nfrom ...library import is_extension_available\nfrom ..function import QuantizedLinearFunction\nfrom ..grouped import grouped_shape\nfrom ..packed import PackedTensor\nfrom ..qbits import QBitsTensor\nfrom ..qtensor import qfallback\nfrom ..qtype import qint2, qint4, qtype, qtypes\n\n\n__all__ = [\"WeightQBitsTensor\"]\n\n\nclass WeightsQBitsQuantizer(Function):\n    @staticmethod\n    def forward(\n        ctx,\n        base: torch.Tensor,\n        qtype: qtype,\n        axis: int,\n        group_size: int,\n        scale: torch.Tensor,\n        shift: torch.Tensor,\n        optimized: bool,\n    ):\n        if qtype not in (qint2, qint4):\n            raise ValueError(\"WeightQBitsTensor can only be of qint2 or qint4 qtype\")\n        if axis not in (0, -1):\n            raise ValueError(\"WeightQBitsTensor axis parameter must be 0 (first axis) or -1 (last axis)\")\n        size = base.size()\n        stride = base.stride()\n        data = torch.ops.quanto.quantize_affine(\n            base, bits=qtype.bits, axis=axis, group_size=group_size, scale=scale, shift=shift\n        )\n        if optimized:\n            return WeightQBitsTensor.create(qtype, axis, group_size, size, stride, data, scale, shift)\n        
return WeightQBitsTensor(qtype, axis, group_size, size, stride, data, scale, shift)\n\n    @staticmethod\n    def backward(ctx, gO):\n        # For autograd, quantization is a no-op\n        return gO, None, None, None, None, None, None\n\n\nclass WeightQBitsTensor(QBitsTensor):\n    @staticmethod\n    def create(qtype, axis, group_size, size, stride, data, scale, shift, requires_grad=False):\n        \"\"\"Factory method to create a WeightQBitsTensor\n\n        This selects the most appropriate WeightQBitsTensor based on the configuration.\n\n        Args:\n            axis (`int`):\n                The axis that is preserved by quantization (usually zero for linear weights).\n            group_size (`int`):\n                The group size that further splits the data elements for each index along the quantization axis.\n            size ():\n                The Tensor size.\n            stride():\n                The Tensor stride.\n            data (`torch.Tensor`):\n                The tensor data, either as a raw uint8 torch.Tensor or as a PackedTensor.\n            scale (`torch.Tensor`):\n                The floating point scale expressed as a torch.Tensor.\n            shift (`torch.Tensor`):\n                The shift expressed as a torch.Tensor. It can be either an integer representing zero\n                (i.e. 
zero-point) or a float value.\n            requires_grad (`bool`):\n                If the Tensor must be receive a gradient or not.\n\n        Returns:\n            a `WeightQBitsTensor` (can be a subclass).\n        \"\"\"\n        from .awq import AWQWeightQBitsTensor\n        from .tinygemm import TinyGemmWeightQBitsTensor\n\n        if (\n            qtype == qint4\n            and size[0] >= 128  # FIXME Workaround AWQ GEMM crash (GEMV might work for short inputs)\n            and scale.dtype == torch.float16\n            and axis == 0\n            and group_size == 128\n            and len(size) == 2\n            and (data.device.type == \"cuda\" and torch.version.cuda)\n            and torch.cuda.get_device_capability(data.device)[0] >= 8\n            and is_extension_available(\"quanto_cuda\")\n        ) or (\n            qtype == qint4\n            and axis == 0\n            and group_size == 128\n            and len(size) == 2\n            and data.device.type == \"xpu\"\n            and shift.dtype == torch.int8\n            and version.parse(torch.__version__).release >= version.parse(\"2.8.0\").release\n        ):\n            if type(data) is PackedTensor:\n                data = data.unpack()\n            return AWQWeightQBitsTensor(qtype, axis, group_size, size, stride, data, scale, shift, requires_grad)\n        if (\n            qtype == qint4\n            and scale.dtype == torch.bfloat16\n            and shift.dtype == torch.bfloat16\n            and axis == 0\n            and group_size == 128\n            and len(size) == 2\n        ):\n            if data.device.type == \"cpu\" or (\n                (data.device.type == \"cuda\" and torch.version.cuda)\n                and version.parse(torch.version.cuda).release >= (12, 1)\n                and torch.cuda.get_device_capability(data.device)[0] >= 8\n            ):\n                if type(data) is PackedTensor:\n                    data = data.unpack()\n                return 
TinyGemmWeightQBitsTensor(\n                    qtype, axis, group_size, size, stride, data, (scale, shift), requires_grad\n                )\n\n        return WeightQBitsTensor(qtype, axis, group_size, size, stride, data, scale, shift, requires_grad)\n\n    @staticmethod\n    def __new__(cls, qtype, axis, group_size, size, stride, data, scale, shift, requires_grad=False):\n        assert data.device == scale.device\n        assert data.device == shift.device\n        return torch.Tensor._make_wrapper_subclass(\n            cls, size, strides=stride, dtype=scale.dtype, device=data.device, requires_grad=requires_grad\n        )\n\n    def __init__(self, qtype, axis, group_size, size, stride, data, scale, shift, requires_grad=False):\n        if type(data) is torch.Tensor:\n            data = PackedTensor.pack(data, qtype.bits)\n        super().__init__(qtype, axis, group_size, size, stride, data, scale, shift)\n\n    @classmethod\n    def quantize(\n        cls,\n        base: torch.Tensor,\n        qtype: qtype,\n        axis: int,\n        group_size: int,\n        scale: torch.Tensor,\n        shift: torch.Tensor,\n        optimized: Optional[bool] = True,\n    ):\n        return WeightsQBitsQuantizer.apply(base, qtype, axis, group_size, scale, shift, optimized)\n\n    @staticmethod\n    def load_from_state_dict(state_dict, prefix, qtype, axis, group_size, size, stride, missing_keys):\n        if group_size is None:\n            data_size = size\n            data_stride = stride\n        else:\n            data_size = grouped_shape(size, axis, group_size)\n            assert len(data_size) == 2\n            # In row major, inner dimension (stride 1) is the last one\n            data_stride = (data_size[1], 1)\n        inner_tensors_dict = {\n            \"_data\": PackedTensor.load_from_state_dict(\n                state_dict, prefix + \"_data.\", qtype.bits, data_size, data_stride, missing_keys=missing_keys\n            )\n        }\n        missing = 
inner_tensors_dict[\"_data\"] is None\n        for name in [\"_scale\", \"_shift\"]:\n            if prefix + name not in state_dict:\n                missing_keys.append(prefix + name)\n                missing = True\n            else:\n                inner_tensors_dict[name] = state_dict.pop(prefix + name)\n\n        if missing:  # could not deserialize because of missing keys\n            return None\n\n        meta = {\n            \"qtype\": qtype.name,\n            \"axis\": str(axis),\n            \"group_size\": str(group_size),\n            \"size\": str(list(size)),\n            \"stride\": str(list(stride)),\n        }\n        return WeightQBitsTensor.__tensor_unflatten__(inner_tensors_dict, meta, None, None)\n\n    def optimize(self):\n        \"\"\"Allows to convert an existing WeightQBitsTensor to an optimized subclass\n\n        This is used in particular after reloading a serialized WeightQBitsTensor (which is\n        always saved using the kernel-agnostic packing).\n        \"\"\"\n        if type(self) is not WeightQBitsTensor:\n            return self\n        data = self._data.unpack()\n        # Call dedicated helper to select the best subclass for this device\n        return WeightQBitsTensor.create(\n            self.qtype,\n            self.axis,\n            self._group_size,\n            self.size(),\n            self.stride(),\n            data,\n            self._scale,\n            self._shift,\n            self.requires_grad,\n        )\n\n    def save_to_state_dict(self, destination, prefix, keep_vars):\n        if type(self) is WeightQBitsTensor:\n            super().save_to_state_dict(destination, prefix, keep_vars)\n        else:\n            # Convert back subclass before serializing\n            self.weight_qbits_tensor().save_to_state_dict(destination, prefix, keep_vars)\n\n    def weight_qbits_tensor(self):\n        \"\"\"Convert back a subclass to a WeightQBitsTensor\n\n        This is required to make sure only standard 
 packing is used when serializing.\n        \"\"\"\n        raise NotImplementedError\n\n    def __tensor_flatten__(self):\n        inner_tensors = [\"_data\", \"_scale\", \"_shift\"]\n        # Since meta can be used for serialization, use only strings\n        meta = {\n            \"qtype\": self._qtype.name,\n            \"axis\": str(self._axis),\n            \"group_size\": str(self._group_size),\n            \"size\": str(list(self.size())),\n            \"stride\": str(list(self.stride())),\n        }\n        return inner_tensors, meta\n\n    @staticmethod\n    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):\n        assert len(inner_tensors) == 3\n        assert len(meta) == 5\n        data, scale, shift = inner_tensors[\"_data\"], inner_tensors[\"_scale\"], inner_tensors[\"_shift\"]\n        # Meta should only contain strings, AST compatible except qtype\n        qtype = qtypes[meta[\"qtype\"]]\n        axis = ast.literal_eval(meta[\"axis\"])\n        group_size = ast.literal_eval(meta[\"group_size\"])\n        size = ast.literal_eval(meta[\"size\"])\n        stride = ast.literal_eval(meta[\"stride\"])\n        return WeightQBitsTensor(qtype, axis, group_size, size, stride, data, scale, shift)\n\n    @classmethod\n    def __torch_function__(cls, func, types, args=(), kwargs=None):\n        \"\"\"Dispatch torch functions applied on this subtensor\n\n        This method is called whenever a torch function (such as `torch.nn.functional.linear`)\n        is called with at least one parameter corresponding to this subtensor:\n\n        - if a quantized implementation exists for the selected function, it is called,\n        - otherwise, the original implementation is called, deactivating further functional dispatch.\n\n        During the execution of the standard torch function, a second-level of dispatch will\n        happen, but this time directly on individual torch Tensor operations (mainly ATEN).\n        \"\"\"\n        kwargs = 
kwargs or {}\n        if func is torch.nn.functional.linear:\n\n            def qlinear(input, other, bias=None):\n                return QuantizedLinearFunction.apply(input, other, bias)\n\n            return qlinear(*args, **kwargs)\n        elif func is torch.equal:\n            input, other = args\n            return input.equal(other)\n        # Defer to operations dispatcher\n        with torch._C.DisableTorchFunctionSubclass():\n            return func(*args, **kwargs)\n\n    @classmethod\n    def __torch_dispatch__(cls, op, types, args, kwargs=None):\n        # Do not use directly op, but rather its overload\n        op = op.overloadpacket\n        if op is torch.ops.aten.detach:\n            t = args[0]\n            # Detach is required when copying and deserializing\n            inner_tensor_names, meta = t.__tensor_flatten__()\n            # Detach inner tensors\n            detached_tensors = {}\n            for inner_name in inner_tensor_names:\n                detached_tensors[inner_name] = op(getattr(t, inner_name))\n            return cls.__tensor_unflatten__(detached_tensors, meta, t.size(), t.stride())\n        elif op in [torch.ops.aten._to_copy, torch.ops.aten.to]:\n            t = args[0]\n            dtype = kwargs.pop(\"dtype\", t.dtype)\n            device = kwargs.pop(\"device\", t.device)\n            if dtype is not None and dtype != t.dtype:\n                raise ValueError(\"The dtype of a WeightQBitsTensor cannot be changed\")\n            if type(t) is not WeightQBitsTensor and t.device.type != device.type:\n                # Before moving to another device type, convert back to a WeightQBitsTensor\n                t = t.weight_qbits_tensor()\n            scale = op(t._scale, dtype=dtype, device=device, **kwargs)\n            data = op(t._data, device=device, **kwargs)\n            shift = op(t._shift, device=device, **kwargs)\n            return WeightQBitsTensor.create(t._qtype, t._axis, t._group_size, t.size(), t.stride(), data, 
scale, shift)\n        # No dispatch available: qfallback\n        kwargs = kwargs or {}\n        return qfallback(op, *args, **kwargs)\n"
  },
  {
    "path": "optimum/quanto/tensor/weights/qbytes.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport ast\nfrom typing import Optional\n\nimport torch\nfrom torch.autograd import Function\n\nfrom ...library import is_extension_available\nfrom ..function import QuantizedLinearFunction\nfrom ..qbytes import QBytesTensor\nfrom ..qtensor import qfallback\nfrom ..qtype import qtype, qtypes\n\n\n__all__ = [\"WeightQBytesTensor\"]\n\n\nclass WeightQBytesQuantizer(Function):\n    @staticmethod\n    def forward(\n        ctx, base: torch.Tensor, qtype: qtype, axis: int, scale: torch.Tensor, activation_qtype: qtype, optimized: bool\n    ) -> torch.Tensor:\n        if qtype.bits != 8:\n            raise ValueError(\"QBytesTensor can only be of 8-bit qtype\")\n        data = torch.ops.quanto.quantize_symmetric(base, dtype=qtype.dtype, axis=axis, scale=scale)\n        # The instantiation of the quantized tensor must happen within the context of the Function\n        # for the autograd magic to work.\n\n        if optimized:\n            return WeightQBytesTensor.create(\n                qtype,\n                axis,\n                size=base.size(),\n                stride=base.stride(),\n                data=data,\n                scale=scale,\n                activation_qtype=activation_qtype,\n            )\n        return WeightQBytesTensor(\n            qtype,\n            axis,\n            size=base.size(),\n            
stride=base.stride(),\n            data=data,\n            scale=scale,\n            activation_qtype=activation_qtype,\n        )\n\n    @staticmethod\n    def backward(ctx, gO):\n        # For autograd, quantization is a no-op\n        return gO, None, None, None, None, None, None\n\n\nclass WeightQBytesLinearFunction(QuantizedLinearFunction):\n    @staticmethod\n    def forward(ctx, input, other, bias=None):\n        ctx.save_for_backward(input, other)\n        if isinstance(input, QBytesTensor):\n            output = torch.ops.quanto.qbytes_mm(input._data, other._data, input._scale * other._scale)\n        else:\n            in_features = input.shape[-1]\n            out_features = other.shape[0]\n            output_shape = input.shape[:-1] + (out_features,)\n            output = torch.ops.quanto.qbytes_mm(input.reshape(-1, in_features), other._data, other._scale)\n            output = output.reshape(output_shape)\n        if bias is not None:\n            output = output + bias\n        return output\n\n\nclass WeightQBytesTensor(QBytesTensor):\n    @staticmethod\n    def create(\n        qtype,\n        axis,\n        size,\n        stride,\n        data,\n        scale,\n        activation_qtype: Optional[qtype] = None,\n        requires_grad=False,\n    ):\n        \"\"\"Factory method to create a QBytesTensor\n\n        This selects the most appropriate QBytesTensor based on the configuration.\n\n        Args:\n            axis (`int`):\n                The axis that is preserved by quantization (usually zero for linear weights).\n            size ():\n                The Tensor size.\n            stride():\n                The Tensor stride.\n            data (`torch.Tensor`):\n                The tensor data, either as a raw uint8 torch.Tensor or as a PackedTensor.\n            scale (`torch.Tensor`):\n                The floating point scale expressed as a torch.Tensor.\n            activation_qtype (`qtype`, defaults to `None`):\n                The 
qtype used for the activations. If one needs to use a different tensor subclass e.g. for weights depending on the activations qtype, this argument must be specified accordingly when calling `QBytesTensor.create`.\n            requires_grad (`bool`):\n                If the Tensor must be receive a gradient or not.\n\n        Returns:\n            a `QBytesTensor` (can be a subclass).\n        \"\"\"\n        from .marlin import MarlinF8QBytesTensor\n\n        if (\n            qtype == qtypes[\"qfloat8_e4m3fn\"]\n            and activation_qtype is None\n            and scale.dtype in [torch.float16, torch.bfloat16]\n            and len(size) == 2\n            and (data.device.type == \"cuda\" and torch.version.cuda)\n            and axis == 0\n            and torch.cuda.get_device_capability(data.device)[0] >= 8\n            and is_extension_available(\"quanto_cuda\")\n        ):\n            out_features, in_features = size\n            if (\n                in_features >= 64\n                and out_features >= 64\n                and (\n                    (in_features % 64 == 0 and out_features % 128 == 0)\n                    or (in_features % 128 == 0 and out_features % 64 == 0)\n                )\n            ):\n                return MarlinF8QBytesTensor(qtype, axis, size, stride, data, scale, requires_grad)\n\n        return WeightQBytesTensor(qtype, axis, size, stride, data, scale, activation_qtype, requires_grad)\n\n    @staticmethod\n    def __new__(cls, qtype, axis, size, stride, data, scale, activation_qtype, requires_grad=False):\n        assert data.device == scale.device\n        return torch.Tensor._make_wrapper_subclass(\n            cls, size, strides=stride, dtype=scale.dtype, device=data.device, requires_grad=requires_grad\n        )\n\n    def __init__(self, qtype, axis, size, stride, data, scale, activation_qtype, requires_grad=False):\n        super().__init__(qtype, axis, size, stride, data, scale, requires_grad=requires_grad)\n        
self.activation_qtype = activation_qtype\n\n    @classmethod\n    def quantize(\n        cls,\n        base: torch.Tensor,\n        qtype: qtype,\n        axis: int,\n        scale: torch.Tensor,\n        activation_qtype: Optional[qtype] = None,\n        optimized: Optional[bool] = True,\n    ) -> torch.Tensor:\n        return WeightQBytesQuantizer.apply(base, qtype, axis, scale, activation_qtype, optimized)\n\n    @staticmethod\n    def load_from_state_dict(state_dict, prefix, qtype, axis, size, stride, activation_qtype, missing_keys):\n        inner_tensors_dict = {}\n        missing = False\n        for name in [\"_data\", \"_scale\"]:\n            if prefix + name not in state_dict:\n                missing_keys.append(prefix + name)\n                missing = True\n            else:\n                inner_tensors_dict[name] = state_dict.pop(prefix + name)\n\n        if missing:  # could not deserialize because of missing keys\n            return None\n\n        meta = {\n            \"qtype\": qtype.name,\n            \"axis\": str(axis),\n            \"size\": str(list(size)),\n            \"stride\": str(list(stride)),\n            \"activation_qtype\": \"none\" if activation_qtype is None else activation_qtype.name,\n        }\n        return WeightQBytesTensor.__tensor_unflatten__(inner_tensors_dict, meta, None, None)\n\n    def optimize(self):\n        \"\"\"Allows to convert an existing WeightQBytesTensor to an optimized subclass\n\n        This is used in particular after reloading a serialized WeightQBytesTensor (which is\n        always saved using the kernel-agnostic packing).\n        \"\"\"\n        if type(self) is not WeightQBytesTensor:\n            return self\n        # Call dedicated helper to select the best subclass for this device\n        return WeightQBytesTensor.create(\n            self.qtype,\n            self.axis,\n            self.size(),\n            self.stride(),\n            self._data,\n            self._scale,\n            
self.activation_qtype,\n            self.requires_grad,\n        )\n\n    def save_to_state_dict(self, destination, prefix, keep_vars):\n        if type(self) is WeightQBytesTensor:\n            super().save_to_state_dict(destination, prefix, keep_vars)\n        else:\n            # Convert back subclass before serializing\n            self.weight_qbytes_tensor().save_to_state_dict(destination, prefix, keep_vars)\n\n    def weight_qbytes_tensor(self):\n        \"\"\"Convert back a subclass to a WeightQBytesTensor\n\n        This is required to make sure only standard packing is used when serializing.\n        \"\"\"\n        raise NotImplementedError\n\n    def __tensor_flatten__(self):\n        inner_tensors = [\"_data\", \"_scale\"]\n        meta = {\n            \"qtype\": self._qtype.name,\n            \"axis\": str(self._axis),\n            \"size\": str(list(self.size())),\n            \"stride\": str(list(self.stride())),\n            \"activation_qtype\": \"none\" if self.activation_qtype is None else self.activation_qtype.name,\n        }\n        return inner_tensors, meta\n\n    @staticmethod\n    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):\n        assert len(inner_tensors) == 2\n        assert len(meta) == 5\n        data, scale = inner_tensors[\"_data\"], inner_tensors[\"_scale\"]\n        # Meta should only contain strings, AST compatible except qtype\n        qtype = qtypes[meta[\"qtype\"]]\n        axis = ast.literal_eval(meta[\"axis\"])\n        size = ast.literal_eval(meta[\"size\"])\n        stride = ast.literal_eval(meta[\"stride\"])\n        activation_qtype = None if meta[\"activation_qtype\"] == \"none\" else qtypes[meta[\"activation_qtype\"]]\n        return WeightQBytesTensor(qtype, axis, size, stride, data, scale, activation_qtype)\n\n    @classmethod\n    def __torch_function__(cls, func, types, args=(), kwargs=None):\n        \"\"\"Dispatch torch functions applied on this subtensor\n\n        This method is 
called whenever a torch function (such as `torch.nn.functional.linear`)\n        is called with at least one parameter corresponding to this subtensor:\n\n        - if a quantized implementation exists for the selected function, it is called,\n        - otherwise, the original implementation is called, deactivating further functional dispatch.\n\n        During the execution of the standard torch function, a second-level of dispatch will\n        happen, but this time directly on individual torch Tensor operations (mainly ATEN).\n        \"\"\"\n        kwargs = kwargs or {}\n        if func is torch.nn.functional.linear:\n\n            def qlinear(input, other, bias=None):\n                return WeightQBytesLinearFunction.apply(input, other, bias)\n\n            return qlinear(*args, **kwargs)\n        elif func is torch.equal:\n            input, other = args\n            return input.equal(other)\n        # Defer to operations dispatcher\n        with torch._C.DisableTorchFunctionSubclass():\n            return func(*args, **kwargs)\n\n    @classmethod\n    def __torch_dispatch__(cls, op, types, args, kwargs=None):\n        # Do not use directly op, but rather its overload\n        op = op.overloadpacket\n        if op is torch.ops.aten.detach:\n            t = args[0]\n            # Detach is required when copying and deserializing\n            inner_tensor_names, meta = t.__tensor_flatten__()\n            # Detach inner tensors\n            detached_tensors = {}\n            for inner_name in inner_tensor_names:\n                detached_tensors[inner_name] = op(getattr(t, inner_name))\n            return cls.__tensor_unflatten__(detached_tensors, meta, t.size(), t.stride())\n        elif op in [torch.ops.aten._to_copy, torch.ops.aten.to]:\n            t = args[0]\n            dtype = kwargs.pop(\"dtype\", t.dtype)\n            device = kwargs.pop(\"device\", t.device)\n            if dtype != t.dtype:\n                raise ValueError(\"The dtype of a weights 
Tensor cannot be changed\")\n            if type(t) is not WeightQBytesTensor and t.device.type != device.type:\n                # Before moving to another device type, convert back to a WeightQBytesTensor\n                t = t.weight_qbytes_tensor()\n            out_data = op(t._data, device=device, **kwargs)\n            out_scale = op(t._scale, device=device, **kwargs)\n            return WeightQBytesTensor.create(\n                t.qtype,\n                t.axis,\n                t.size(),\n                t.stride(),\n                out_data,\n                out_scale,\n                activation_qtype=t.activation_qtype,\n                requires_grad=t.requires_grad,\n            )\n        elif op is torch.ops.aten.t and cls is WeightQBytesTensor:\n            t = args[0]\n            out_data = op(t._data)\n            out_scale = t._scale\n            out_axis = t.axis\n            # Manually reverse size and stride because we cannot trust the out_data shape\n            dim0, dim1 = t.size()\n            out_size = torch.Size([dim1, dim0])\n            out_stride = t.stride()[::-1]\n            if t.axis is not None:\n                # We need to transpose also the scale\n                out_scale = op(out_scale)\n                out_axis = 0 if out_axis == -1 else -1\n            return WeightQBytesTensor(t.qtype, out_axis, out_size, out_stride, out_data, out_scale, t.activation_qtype)\n        # No dispatch available: qfallback\n        kwargs = kwargs or {}\n        return qfallback(op, *args, **kwargs)\n"
  },
  {
    "path": "optimum/quanto/tensor/weights/quantization.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Optional\n\nimport torch\n\nfrom ..qtype import qtype\nfrom .qbits import WeightQBitsTensor\nfrom .qbytes import WeightQBytesTensor\n\n\n__all__ = [\"quantize_weight\"]\n\n\ndef quantize_weight(\n    t: torch.Tensor,\n    qtype: qtype,\n    axis: int,\n    scale: torch.Tensor,\n    shift: Optional[torch.Tensor] = None,\n    group_size: Optional[int] = None,\n    activation_qtype: Optional[qtype] = None,\n    optimized: Optional[bool] = True,\n):\n    \"\"\"Quantize a weight Tensor.\n\n    Weights are always quantized per-axis.\n\n    Args:\n        t (`torch.Tensor`): the weight Tensor to quantize\n        qtype (`quanto.qtype`): The target quantization type\n        axis ('int`): The quantization axis (0 or -1)\n        scale (`torch.Tensor`): the quantization scale\n        shift (`Optional[torch.Tensor]`): optional shift to apply\n        group_size (`Optional[int]`): The quantization group size\n        activation_qtype (`Optional[qtype]`, defaults to `None`):\n            Which quantization type is being used for the activations. 
The function `quantize_weight`\n            initializes `torch.Tensor` subclasses that may depend on the activation dtype.\n            `None` corresponds to no quantization.\n        optimized (`Optional[bool]`, defaults to True):\n            If True, the quantization algorithm will select the most efficient kernel\n            for the weights and format the resulting Tensor accordingly.\n            If False, a kernel-agnostic Tensor will be returned (but it can be optimized later\n            explicitly by calling QTensor.optimize() or implicitly by moving it to a specific device).\n    Returns:\n        A quantized Tensor.\n    \"\"\"\n    if axis not in (0, -1):\n        raise ValueError(\"axis parameter must be 0 (first axis) or -1 (last axis)\")\n    if qtype.bits == 8:\n        if shift is not None:\n            raise ValueError(\"shift cannot be specified for 8-bit qtypes\")\n        if group_size is not None:\n            raise ValueError(\"group_size cannot be specified for 8-bit qtypes.\")\n        if axis is not None and t.shape[axis] == 1:\n            # Quantizing along an axis of dimension 1 means quantizing per-tensor\n            axis = None\n        return WeightQBytesTensor.quantize(t, qtype, axis, scale, activation_qtype, optimized)\n    if shift is None:\n        raise ValueError(\"shift must be specified for qtypes lower than 8-bit\")\n    return WeightQBitsTensor.quantize(t, qtype, axis, group_size, scale, shift, optimized)\n"
  },
  {
    "path": "optimum/quanto/tensor/weights/reordering.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import List, Union\n\nimport torch\n\n\n__all__ = [\"reorder\", \"reverse\"]\n\n\ndef reorder(t: torch.Tensor, permutation: Union[torch.Tensor, List[int]]):\n    \"\"\"Reorder a Tensor using a permutation\n\n    Args:\n        t (`torch.Tensor`): the Tensor to reorder\n        permutation (`Union[torch.Tensor, List[int]]`): the permutation to apply\n\n    Returns:\n        The reordered torch.Tensor\n    \"\"\"\n    block_size = permutation.numel() if isinstance(permutation, torch.Tensor) else len(permutation)\n    reordered = t.reshape((-1, block_size))[:, permutation].reshape(t.shape)\n    return reordered.contiguous()\n\n\ndef reverse(permutation: Union[torch.Tensor, List[int]]):\n    \"\"\"Reverse a permutation\n\n    The reversed permutation can be used to revert a reordered Tensor to its original\n    ordering.\n\n    Args:\n        permutation (`Union[torch.Tensor, List[int]]`): the permutation to reverse\n\n    Returns:\n        The reversed permutation\n    \"\"\"\n    block_size = permutation.numel() if isinstance(permutation, torch.Tensor) else len(permutation)\n    reversed = torch.empty((block_size,), dtype=torch.int64)\n    reversed[permutation] = torch.arange(block_size)\n    return reversed\n"
  },
  {
    "path": "optimum/quanto/tensor/weights/tinygemm/__init__.py",
    "content": "from .packed import *\nfrom .qbits import *\n"
  },
  {
    "path": "optimum/quanto/tensor/weights/tinygemm/packed.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport ast\nfrom copy import copy\n\nimport torch\nfrom torch.utils import _pytree as pytree\n\n\n__all__ = [\"TinyGemmPackedTensor\"]\n\n\nclass TinyGemmPackedTensor(torch.Tensor):\n    @staticmethod\n    def __new__(cls, data, size, stride, requires_grad=False):\n        # TinyGemmPackedTensor represents uint8 data and can therefore NEVER require gradient\n        assert requires_grad is False\n        return torch.Tensor._make_wrapper_subclass(\n            cls, size, strides=stride, dtype=torch.uint8, device=data.device, requires_grad=requires_grad\n        )\n\n    def __init__(self, data, size, stride, requires_grad=False):\n        self._data = data\n\n    def __repr__(self):\n        return f\"TinyGemmPackedTensor({self._data})\"\n\n    @classmethod\n    def pack(cls, t):\n        \"\"\"Pack a torch.Tensor for tinygemm kernel\n\n        This packs uint4 weights in an int32 tensor as expected by the torch tinygemm mixed mm kernel\n\n        Args:\n            t (`torch.Tensor`):\n                The un-packed `torch.uint8` tensor\n\n        Returns:\n            A `TinyGemmPackedTensor`.\n        \"\"\"\n        inner_ktiles = 2\n        t = t.to(torch.int32).contiguous()\n        if t.device.type == \"cpu\":\n            data = torch._convert_weight_to_int4pack_for_cpu(t, innerKTiles=inner_ktiles)\n        elif t.device.type == 
\"xpu\":\n            t_uint8 = (t[::, 1::2] << 4 | t[::, ::2]).to(torch.uint8)\n            data = torch._convert_weight_to_int4pack(t_uint8, innerKTiles=inner_ktiles)\n        else:\n            t_uint8 = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8)\n            data = torch._convert_weight_to_int4pack(t_uint8, innerKTiles=inner_ktiles)\n        # We need to store size and stride to make sure the unpacked data has the correct shape\n        return TinyGemmPackedTensor(data, t.size(), t.stride())\n\n    def unpack(self):\n        \"\"\"Unpack the packed tensor to a torch.Tensor\n\n        Packing is device specific and implemented in undocumented dedicated kernels\n        that are synchronized with the corresponding matrix multiplication operation.\n\n        Instead of implementing a dedicated unpacking code, we pass an identity matrix\n        to the mm operation with identity scale and shifts to produce the unpacked uint8 weights.\n\n        Returns:\n            An unpacked uint8 `torch.Tensor` expanded along the second dimension.\n        \"\"\"\n        out_features, in_features = self.size()\n        # We need to pass a group_size to the mm and format the scale and shift accordingly,\n        # although it does not modify the calculation since we use identity scales and shifts.\n        # We arbitrarily choose the smallest group_size to be sure it divides in_features\n        group_size = 32\n        scale_and_shift_shape = (in_features // group_size, out_features, 2)\n        # Initialize identity scale\n        id_scale_and_shift = torch.ones(scale_and_shift_shape, dtype=torch.bfloat16, device=self.device)\n        # Set shift to mid-point, i.e. 
2 **(bits - 1)\n        id_scale_and_shift[:, :, 1] = 8\n\n        identity = torch.eye(in_features, dtype=torch.bfloat16, device=self.device)\n        if self._data.device.type == \"cpu\":\n            unpacked_data = torch._weight_int4pack_mm_for_cpu(identity, self._data, group_size, id_scale_and_shift)\n        else:\n            unpacked_data = torch._weight_int4pack_mm(identity, self._data, group_size, id_scale_and_shift)\n\n        return unpacked_data.t().to(torch.uint8)\n\n    @property\n    def dtype(self):\n        return torch.uint8\n\n    def __tensor_flatten__(self):\n        inner_tensors = [\"_data\"]\n        # Since meta can be used for serialization, use only AST compatible strings\n        meta = {\n            \"size\": str(list(self.size())),\n            \"stride\": str(self.stride()),\n        }\n        return inner_tensors, meta\n\n    @staticmethod\n    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):\n        assert len(inner_tensors) == 1\n        assert len(meta) == 2\n        data = inner_tensors[\"_data\"]\n        # Meta should contain only AST compatible strings\n        size = ast.literal_eval(meta[\"size\"])\n        stride = ast.literal_eval(meta[\"stride\"])\n        return TinyGemmPackedTensor(data, size, stride)\n\n    __torch_function__ = torch._C._disabled_torch_function_impl\n\n    @classmethod\n    def __torch_dispatch__(cls, op, types, args, kwargs=None):\n        # Convert back to tensor before calling any operation except detach and move\n        if op.overloadpacket is torch.ops.aten.detach:\n            t = args[0]\n            data = op(t._data)\n            return TinyGemmPackedTensor(data, t.size(), t.stride())\n        elif op.overloadpacket in (torch.ops.aten._to_copy, torch.ops.aten.to):\n            t = args[0]\n            dtype = kwargs.get(\"dtype\", torch.uint8)\n            if dtype != torch.uint8:\n                raise ValueError(f\"TinyGemmPackedTensor are torch.uint8 only and 
cannot be moved to {dtype}.\")\n            data_kwargs = copy(kwargs)\n            data_kwargs[\"dtype\"] = t._data.dtype\n            if kwargs.get(\"device\", t.device).type != t.device.type:\n                # Packing is device specific, so we need to unpack before moving\n                unpacked = t.unpack()\n                unpacked = op(unpacked, **data_kwargs)\n                return TinyGemmPackedTensor.pack(unpacked)\n            # If we stay on the same device type, just copy/move packed data\n            data = op(t._data, **data_kwargs)\n            return TinyGemmPackedTensor(data, t.size(), t.stride())\n        args, kwargs = pytree.tree_map_only(TinyGemmPackedTensor, lambda x: x.unpack(), (args, kwargs or {}))\n        return op(*args, **kwargs)\n\n    def numpy(self):\n        return self.unpack().cpu().numpy()\n"
  },
  {
    "path": "optimum/quanto/tensor/weights/tinygemm/qbits.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport ast\n\nimport torch\nfrom torch.autograd import Function\n\nfrom ...function import QuantizedLinearFunction\nfrom ...grouped import group, ungroup\nfrom ...qtype import qtypes\nfrom ..qbits import WeightQBitsTensor\nfrom .packed import TinyGemmPackedTensor\n\n\n__all__ = [\"TinyGemmWeightQBitsTensor\"]\n\n\nclass TinyGemmQBitsDequantizer(Function):\n    @staticmethod\n    def forward(ctx, t):\n        # There is no custom dequantize kernel available, so we need to convert back to a QBitsTensor\n        qbt = t.weight_qbits_tensor()\n        return qbt.dequantize()\n\n    @staticmethod\n    def backward(ctx, gO):\n        return gO\n\n\nclass TinyGemmQBitsLinearFunction(QuantizedLinearFunction):\n    @staticmethod\n    def forward(ctx, input, other, bias):\n        ctx.save_for_backward(input, other)\n        if type(input) is not torch.Tensor:\n            input = input.dequantize()\n        in_features = input.shape[-1]\n        out_features = other.shape[0]\n        output_shape = input.shape[:-1] + (out_features,)\n        if input.device.type == \"cpu\":\n            output = torch._weight_int4pack_mm_for_cpu(\n                input.reshape(-1, in_features), other._data._data, other._group_size, other._scale_shift\n            )\n        else:\n            output = torch._weight_int4pack_mm(\n                input.reshape(-1, 
in_features), other._data._data, other._group_size, other._scale_shift\n            )\n        output = output.reshape(output_shape)\n        if bias is not None:\n            output = output + bias\n        return output\n\n\nclass TinyGemmWeightQBitsTensor(WeightQBitsTensor):\n    @staticmethod\n    def __new__(cls, qtype, axis, group_size, size, stride, data, scale_shift, requires_grad=False):\n        if isinstance(scale_shift, torch.Tensor):\n            dtype = scale_shift.dtype\n            assert data.device == scale_shift.device\n        else:\n            assert isinstance(scale_shift, (tuple, list))\n            scale, shift = scale_shift\n            dtype = scale.dtype\n            assert shift.dtype == dtype\n            assert data.device == scale.device\n            assert data.device == shift.device\n        return torch.Tensor._make_wrapper_subclass(\n            cls, size, strides=stride, dtype=dtype, device=data.device, requires_grad=requires_grad\n        )\n\n    def __init__(self, qtype, axis, group_size, size, stride, data, scale_shift, requires_grad=False):\n        assert axis == 0\n        if not isinstance(data, TinyGemmPackedTensor):\n            assert type(data) is torch.Tensor\n            assert isinstance(scale_shift, (tuple, list))\n            # Format data, scale and shift for tinygemm\n            ungrouped = ungroup(data, axis=0, orig_shape=size)\n            self._data = TinyGemmPackedTensor.pack(ungrouped)\n            out_features, in_features = size\n            scale, shift = scale_shift\n            scale = scale.reshape(out_features, in_features // group_size, 1)\n            shift = shift.reshape(out_features, in_features // group_size, 1)\n            if not shift.dtype.is_floating_point:\n                # Integer shift must be scaled\n                shift = scale * shift\n            # The tinygemm kernel actually uses the mid-point of the quantization range as shift\n            min_range = -shift\n            
half_qrange = 2 ** (qtype.bits - 1) * scale\n            # This operation is lossy for bfloat16, and the actual value of shift will be lost\n            shift = min_range + half_qrange\n            # Scale and shift are actually stored in the same tensor\n            self._scale_shift = torch.cat([scale, shift], 2).transpose(0, 1).contiguous()\n        else:\n            self._data = data\n            self._scale_shift = scale_shift\n        self._qtype = qtype\n        self._axis = axis\n        self._group_size = group_size\n\n    def dequantize(self):\n        return TinyGemmQBitsDequantizer.apply(self)\n\n    def weight_qbits_tensor(self):\n        \"\"\"Convert back to a WeightQBitsTensor\n\n        This is required to make sure only standard packing is used when serializing.\n        \"\"\"\n        data = group(self._data.unpack(), axis=self.axis, group_size=self._group_size)\n        n_scales = self._scale_shift.numel() // 2\n        scale = self._scale_shift[:, :, 0].t().reshape((n_scales, 1))\n        shift = self._scale_shift[:, :, 1].t().reshape((n_scales, 1))\n        half_qrange = 2 ** (self.qtype.bits - 1) * scale\n        # This operation is lossy for bfloat16, and the actual value of shift will not be recovered\n        shift = half_qrange - shift\n        return WeightQBitsTensor(\n            self._qtype, self._axis, self._group_size, self.size(), self.stride(), data, scale, shift\n        )\n\n    def __tensor_flatten__(self):\n        inner_tensors = [\"_data\", \"_scale_shift\"]\n        # Since meta can be used for serialization, use only strings\n        meta = {\n            \"qtype\": self._qtype.name,\n            \"axis\": str(self._axis),\n            \"group_size\": str(self._group_size),\n            \"size\": str(list(self.size())),\n            \"stride\": str(list(self.stride())),\n        }\n        return inner_tensors, meta\n\n    @staticmethod\n    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):\n       
 assert len(inner_tensors) == 2\n        assert len(meta) == 5\n        data, scale_shift = inner_tensors[\"_data\"], inner_tensors[\"_scale_shift\"]\n        # Meta should only contain strings, AST compatible except qtype\n        qtype = qtypes[meta[\"qtype\"]]\n        axis = ast.literal_eval(meta[\"axis\"])\n        group_size = ast.literal_eval(meta[\"group_size\"])\n        size = ast.literal_eval(meta[\"size\"])\n        stride = ast.literal_eval(meta[\"stride\"])\n        return TinyGemmWeightQBitsTensor(qtype, axis, group_size, size, stride, data, scale_shift)\n\n    @classmethod\n    def __torch_function__(cls, func, types, args=(), kwargs=None):\n        \"\"\"Dispatch torch functions applied on this subtensor\n\n        This method is called whenever a torch function (such as `torch.nn.functional.linear`)\n        is called with at least one parameter corresponding to this subtensor:\n\n        - if a quantized implementation exists for the selected function, it is called,\n        - otherwise, the original implementation is called, deactivating further functional dispatch.\n\n        During the execution of the standard torch function, a second-level of dispatch will\n        happen, but this time directly on individual torch Tensor operations (mainly ATEN).\n        \"\"\"\n        kwargs = kwargs or {}\n        if func is torch.nn.functional.linear:\n\n            def qlinear(input, other, bias=None):\n                return TinyGemmQBitsLinearFunction.apply(input, other, bias)\n\n            return qlinear(*args, **kwargs)\n        # Defer to operations dispatcher\n        with torch._C.DisableTorchFunctionSubclass():\n            return func(*args, **kwargs)\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[project]\nname = 'optimum-quanto'\ndescription = 'A pytorch quantization backend for optimum.'\nclassifiers = [\n    'Development Status :: 2 - Pre-Alpha',\n    'License :: OSI Approved :: Apache Software License',\n    'Intended Audience :: Developers',\n    'Intended Audience :: Education',\n    'Intended Audience :: Science/Research',\n    'Operating System :: OS Independent',\n    'Programming Language :: Python :: 3.9',\n    'Programming Language :: Python :: 3.10',\n    'Programming Language :: Python :: 3.11',\n    'Topic :: Scientific/Engineering :: Artificial Intelligence'\n]\nkeywords = ['torch', 'quantization']\nrequires-python = '>=3.9.0'\nauthors = [{ name = 'David Corvoysier' }]\nmaintainers = [\n    {name = \"HuggingFace Inc. Special Ops Team\", email=\"hardware@huggingface.co\"},\n]\ndependencies = ['torch>=2.6.0', 'ninja', 'numpy', 'safetensors', 'huggingface_hub']\nlicense = { text = 'Apache-2.0' }\nreadme = 'README.md'\ndynamic = ['version']\n\n[project.urls]\nhomepage = 'https://github.com/huggingface/optimum-quanto'\n\n[project.optional-dependencies]\ndev = ['pytest', 'ruff']\nexamples = [\n    'torchvision',\n    'transformers',\n    'diffusers',\n    'datasets',\n    'accelerate',\n    'sentencepiece',\n    'scipy'\n]\n\n[tool.setuptools.packages.find]\nwhere = [\".\"]\ninclude = [\"optimum*\"]\n\n[tool.setuptools.dynamic]\nversion = {attr = 'optimum.quanto.__version__'}\n\n[build-system]\nrequires = ['setuptools>65.5.1', 'setuptools_scm']\nbuild-backend = 'setuptools.build_meta'\n\n[tool.ruff]\n# Configuration for Ruff\nline-length = 119  # Same line-length as Black had\n\n# Linting rules:\n# Never enforce `E501` (line length violations) and other specific rules.\nlint.ignore = ['C901', 'E501', 'E741']\nlint.select = ['C', 'E', 'F', 'I', 'W']\n\n# Ignore import violations in all `__init__.py` files.\n[tool.ruff.lint.per-file-ignores]\n'__init__.py' = ['E402', 'F401', 'F403', 'F811']\n\n# isort configuration (to sort 
imports)\n[tool.ruff.lint.isort]\nlines-after-imports = 2\nknown-first-party = ['optimum.quanto']\n"
  },
  {
    "path": "setup.sh",
    "content": "#!/bin/bash\n\nNIGHTLY=${1:-0}\nVENV=\".venv\"\nif [ ! -d \"${VENV}\" ]; then\n    python3 -m venv ${VENV}\nfi\n. ${VENV}/bin/activate\nif [ \"$NIGHTLY\" -eq \"0\" ]; then\n    pip install --upgrade torch torchvision torchaudio\nelse\n    pip install --upgrade --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118\nfi\n# Build tools\npip install ruff pytest build\n# For examples\npip install accelerate transformers datasets\n"
  },
  {
    "path": "tests/cli/cli_helpers.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport importlib\n\nimport pytest\n\n\nrequires_optimum_cli = pytest.mark.skipif(\n    importlib.util.find_spec(\"optimum.commands\") is None, reason=\"optimum-cli is required\"\n)\n"
  },
  {
    "path": "tests/cli/test_quantize_cli.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport subprocess\nfrom tempfile import TemporaryDirectory\n\nimport pytest\nfrom cli_helpers import requires_optimum_cli\n\nfrom optimum.quanto import quantization_map\n\n\n@requires_optimum_cli\n@pytest.mark.parametrize(\"weights\", [\"int4\", \"int8\"])\ndef test_export_decoder_cli(weights):\n    from optimum.quanto import QuantizedModelForCausalLM\n\n    model_id = \"facebook/opt-125m\"\n    with TemporaryDirectory() as tempdir:\n        subprocess.run(\n            [\n                \"optimum-cli\",\n                \"quanto\",\n                \"quantize\",\n                \"--model\",\n                model_id,\n                \"--weights\",\n                f\"{weights}\",\n                tempdir,\n            ],\n            shell=False,\n            check=True,\n        )\n        # Verify we can reload the quantized model\n        qmodel = QuantizedModelForCausalLM.from_pretrained(tempdir)\n        qmap = quantization_map(qmodel)\n        for layer_qconfig in qmap.values():\n            assert layer_qconfig[\"weights\"] == f\"q{weights}\"\n"
  },
  {
    "path": "tests/conftest.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\nimport torch\n\n\ndevices = [\"cpu\"]\nif torch.cuda.is_available():\n    devices += [\"cuda\"]\nelif torch.backends.mps.is_available():\n    devices += [\"mps\"]\nelif torch.xpu.is_available():\n    devices += [\"xpu\"]\n\n\n@pytest.fixture(scope=\"module\", params=devices)\ndef device(request):\n    return torch.device(request.param)\n\n\ndef pytest_configure(config):\n    # register additional markers\n    config.addinivalue_line(\"markers\", \"skip_device(type): mark test to be skipped for the specified device type\")\n\n\ndef pytest_runtest_call(item):\n    fixture_name = \"device\"\n    if fixture_name in item.fixturenames:\n        # TODO: should be able to recover the fixture id instead of the actual value\n        fixture_arg = item.funcargs[fixture_name].type\n        skip_marks = {mark.args[0] for mark in item.iter_markers(name=f\"skip_{fixture_name}\")}\n        if fixture_arg in skip_marks:\n            pytest.skip(f\"Test skipped for {fixture_name} {fixture_arg}\")\n"
  },
  {
    "path": "tests/helpers.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport functools\nimport gc\nimport os\n\nimport pytest\nimport torch\nfrom packaging import version\n\nfrom optimum.quanto import (\n    AbsmaxOptimizer,\n    MaxOptimizer,\n    absmax_scale,\n    qint8,\n    quantize_activation,\n    quantize_weight,\n)\n\n\ndef torch_min_version(v):\n    def torch_min_version_decorator(test):\n        @functools.wraps(test)\n        def test_wrapper(*args, **kwargs):\n            if version.parse(torch.__version__) < version.parse(v):\n                pytest.skip(f\"Requires pytorch >= {v}\")\n            test(*args, **kwargs)\n\n        return test_wrapper\n\n    return torch_min_version_decorator\n\n\ndef device_eq(a, b):\n    if a.type != b.type:\n        return False\n    a_index = a.index if a.index is not None else 0\n    b_index = b.index if b.index is not None else 0\n    return a_index == b_index\n\n\ndef random_tensor(shape, dtype=torch.float32, device=\"cpu\"):\n    if dtype.is_floating_point:\n        rand_dtype = dtype if dtype.itemsize > 1 else torch.float16\n        # Generate a random tensor between -1. 
and 1.\n        t = torch.rand(shape, dtype=rand_dtype, device=device) * 2 - 1\n        return t.to(dtype)\n    else:\n        assert dtype == torch.int8\n        return torch.randint(-127, 127, shape, dtype=torch.int8, device=device)\n\n\ndef random_qactivation(shape, qtype=qint8, dtype=torch.float32, device=\"cpu\"):\n    t = random_tensor(shape, dtype, device=device)\n    scale = absmax_scale(t, qtype=qtype)\n    return quantize_activation(t, qtype=qtype, scale=scale)\n\n\ndef random_qweight(shape, qtype, dtype=torch.float32, axis=0, group_size=None, device=\"cpu\"):\n    device = device.type if isinstance(device, torch.device) else device\n    t = random_tensor(shape, dtype, device=device)\n    if qtype.bits == 8:\n        scale = AbsmaxOptimizer()(t, qtype=qtype, axis=axis)\n        shift = None\n    else:\n        optimizer_kwargs = {\"qtype\": qtype, \"axis\": axis, \"group_size\": group_size}\n        if device == \"xpu\":\n            optimizer_kwargs.update({\"zeropoint\": True})\n        scale, shift = MaxOptimizer()(t, **optimizer_kwargs)\n    return quantize_weight(t, qtype=qtype, axis=axis, scale=scale, shift=shift, group_size=group_size, optimized=False)\n\n\ndef assert_similar(a, b, atol=None, rtol=None):\n    \"\"\"Verify that the cosine similarity of the two inputs is close to 1.0 everywhere\"\"\"\n    assert a.dtype == b.dtype\n    assert a.shape == b.shape\n    if atol is None:\n        # We use torch finfo resolution\n        atol = torch.finfo(a.dtype).resolution\n    if rtol is None:\n        # Please refer to that discussion for default rtol values based on the float type:\n        # https://scicomp.stackexchange.com/questions/43111/float-equality-tolerance-for-single-and-half-precision\n        rtol = {torch.float32: 1e-5, torch.float16: 1e-3, torch.bfloat16: 1e-1}[a.dtype]\n    sim = torch.nn.functional.cosine_similarity(a.flatten(), b.flatten(), dim=0)\n    if not torch.allclose(sim, torch.tensor(1.0, dtype=sim.dtype), atol=atol, 
rtol=rtol):\n        max_deviation = torch.min(sim)\n        raise ValueError(f\"Alignment {max_deviation:.8f} deviates too much from 1.0 with atol={atol}, rtol={rtol}\")\n\n\ndef get_device_memory(device):\n    gc.collect()\n    if device.type == \"cuda\":\n        torch.cuda.empty_cache()\n        return torch.cuda.memory_allocated()\n    elif device.type == \"mps\":\n        torch.mps.empty_cache()\n        return torch.mps.current_allocated_memory()\n    elif device.type == \"xpu\":\n        torch.xpu.empty_cache()\n        return torch.xpu.memory_allocated()\n    return None\n\n\n_run_staging = os.getenv(\"HUGGINGFACE_CO_STAGING\", False)\n"
  },
  {
    "path": "tests/library/test_extensions.py",
    "content": "import platform\n\nimport pytest\nimport torch\nfrom packaging import version\n\nfrom optimum.quanto.library.extensions import get_extension, is_extension_available\n\n\ndef _is_xpu_available():\n    # SYCL extension support is added in torch>=2.7 on Linux\n    if platform.system() != \"Linux\":\n        return False\n    if version.parse(torch.__version__).release < version.parse(\"2.7\").release:\n        return False\n    return torch.xpu.is_available()\n\n\nextension_names = [\"quanto_cpp\"]\nif torch.cuda.is_available():\n    if torch.version.cuda:\n        extension_names.append(\"quanto_cuda\")\n    if torch.version.hip:\n        extension_names.append(\"quanto_hip\")\nif torch.backends.mps.is_available():\n    extension_names.append(\"quanto_mps\")\nif _is_xpu_available():\n    extension_names.append(\"quanto_xpu\")\n\n\n@pytest.mark.parametrize(\"extension_name\", extension_names)\ndef test_extension_available(extension_name):\n    assert is_extension_available(extension_name)\n\n\n@pytest.mark.parametrize(\"extension_name\", extension_names)\ndef test_extension_compilation(extension_name):\n    extension = get_extension(extension_name)\n    assert extension.lib is not None\n"
  },
  {
    "path": "tests/library/test_mm.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport pytest\nimport torch\nfrom helpers import assert_similar, random_tensor\n\nfrom optimum.quanto.library.extensions import is_extension_available\nfrom optimum.quanto.tensor.weights.awq import AWQPackedTensor, AWQPacking\nfrom optimum.quanto.tensor.weights.marlin import marlin_permute\nfrom optimum.quanto.tensor.weights.marlin.fp8.packed import get_scale_perms, pack_fp8_as_int32\nfrom optimum.quanto.tensor.weights.marlin.int4.packed import MarlinInt4PackedTensor\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 10, None], ids=[\"single\", \"batched\", \"static\"])\n@pytest.mark.parametrize(\"input_features\", [32, 50])\n@pytest.mark.parametrize(\"output_features\", [48, 50, 64])\n@pytest.mark.parametrize(\"input_dtype\", [None, torch.int8], ids=[\"i-as-out\", \"i-int8\"])\n@pytest.mark.parametrize(\n    \"weight_dtype\", [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.int8], ids=[\"w-float8\", \"w-float8-uz\", \"w-int8\"]\n)\n@pytest.mark.parametrize(\"output_dtype\", [torch.float16, torch.bfloat16], ids=[\"o-fp16\", \"o-bf16\"])\ndef test_qbytes_mm(batch_size, input_features, input_dtype, weight_dtype, output_features, output_dtype, device):\n    if device.type in [\"mps\"] and weight_dtype.is_floating_point:\n        pytest.skip(f\"Float8 types are not supported on {device.type} device\")\n    input_shape = (32, 
input_features)\n    if batch_size is not None:\n        input_shape = (batch_size,) + input_shape\n    if input_dtype is None:\n        input_dtype = output_dtype\n    input = random_tensor(input_shape, dtype=input_dtype, device=device)\n    weight = random_tensor((output_features, input_features), dtype=weight_dtype, device=device)\n    # Use a scale small enough to prevent overflows\n    scale = random_tensor((output_features, 1), dtype=output_dtype, device=device) / 1e3\n    output = torch.ops.quanto.qbytes_mm(input, weight, scale)\n    expected = torch.matmul(input.to(scale.dtype), (weight.to(scale.dtype) * scale).t())\n    assert_similar(expected, output)\n\n\n@pytest.mark.skipif(\n    (not is_extension_available(\"quanto_cuda\") or torch.cuda.get_device_capability()[0] < 8)\n    and not torch.xpu.is_available(),\n    reason=\"The test requires CUDA device >= sm80 or Intel XPU\",\n)\n@pytest.mark.parametrize(\"in_features, out_features\", [(256, 256), (512, 256)])\n@pytest.mark.parametrize(\"batch_size, tokens\", [(4, 1), (10, 128)], ids=[\"gemv\", \"gemm\"])\ndef test_gemm_fp16_int4(batch_size, tokens, in_features, out_features):\n    \"\"\"This test verifies that the GEMM operation is equivalent to torch.mm.\"\"\"\n    bits = 4\n    group_size = 128  # Hard-coded in kernels\n    device = torch.device(0)  # XPU can also share this setting.\n    input_shape = (batch_size, tokens, in_features)\n    # FIXME: does not work if inputs are negative !!??\n    inputs = torch.rand(input_shape, dtype=torch.float16, device=device)\n    qmax = 2**bits\n    other_shape = (out_features, in_features)\n    other_data = torch.randint(0, qmax, other_shape, dtype=torch.uint8, device=device)\n    pack_type = AWQPacking.V1 if device.type == \"xpu\" else AWQPacking.V2\n    packed_other_data = AWQPackedTensor.pack(other_data, packing=pack_type)._data\n    # The GEMM kernel works on transposed scales\n    scales_shape = (in_features // group_size, out_features)\n    other_scales = 
torch.rand(scales_shape, dtype=torch.float16, device=device) / qmax\n    # The GEMM kernel works on transposed, negated and scaled shifts\n    qmin = -(2 ** (bits - 1))\n    qmax = 2 ** (bits - 1)\n    other_shifts = torch.randint(qmin, qmax, scales_shape, dtype=torch.int8, device=device)\n    # Negate and scale, xpu should keep the original int8 shifts\n    other_scaled_shifts = other_shifts if device.type == \"xpu\" else -other_shifts * other_scales\n    # Evaluate mm outputs using the GEMM kernel\n    lib_outputs = torch.ops.quanto.gemm_f16i4_awq(\n        inputs,\n        packed_other_data,\n        other_scales,\n        other_scaled_shifts,\n        rows=inputs.numel() // inputs.shape[-1],\n        out_cols=out_features,\n        in_cols=in_features,\n        bits=4,\n        group_size=group_size,\n    )\n    # Transpose other data and reshape it to align it with transposed scales and zeros\n    other_data_t = other_data.t().reshape(group_size, in_features // group_size, out_features)\n    # Dequantize transposed other\n    other_t = (other_data_t - other_shifts) * other_scales\n    # Reshape it as expected by the matmul\n    other_t = other_t.reshape(in_features, out_features)\n    # Evaluate the matrix multiplication using pytorch float16 mm\n    pt_outputs = torch.matmul(inputs, other_t)\n    # Verify the results are similar\n    assert_similar(lib_outputs, pt_outputs, rtol=5e-3)\n\n\n@pytest.mark.skipif(\n    not is_extension_available(\"quanto_cuda\") or torch.cuda.get_device_capability()[0] < 8,\n    reason=\"CUDA device >= sm80 not available\",\n)\n@pytest.mark.parametrize(\"tokens\", [1, 10, 128])\n@pytest.mark.parametrize(\"in_features, out_features\", [(256, 1024), (512, 2048)])\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16, torch.float16], ids=[\"bf16\", \"fp16\"])\ndef test_fp8_marlin(tokens, in_features, out_features, dtype):\n    device = torch.device(\"cuda\")\n    input_shape = (tokens, in_features)\n    inputs = 
torch.rand(input_shape, dtype=dtype, device=device)\n    other_shape = (in_features, out_features)\n    other_data = torch.rand(other_shape, dtype=dtype, device=device).to(torch.float8_e4m3fn)\n    other_data_int32 = pack_fp8_as_int32(other_data)\n    perm = torch.empty(0, dtype=torch.int, device=device)\n\n    other_data_repack = torch.ops.quanto.pack_fp8_marlin(\n        b_q_weight=other_data_int32, perm=perm, size_k=in_features, size_n=out_features, num_bits=8\n    )\n    other_scale = torch.rand(1, out_features, dtype=dtype, device=device)\n    other_scale_original = other_scale.clone()\n\n    scale_perm_single = get_scale_perms()\n    other_scale = other_scale.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]\n    other_scale = other_scale.reshape(-1, out_features).contiguous()\n\n    workspace = torch.zeros(out_features // 64 * 16, dtype=torch.int, device=device)\n    lib_outputs = torch.ops.quanto.gemm_f16f8_marlin(\n        a=inputs,\n        b_q_weight=other_data_repack,\n        b_scales=other_scale,\n        workspace=workspace,\n        num_bits=8,\n        size_m=tokens,\n        size_n=out_features,\n        size_k=in_features,\n    )\n    # Evaluate the matrix multiplication using pytorch mm\n    other = other_data.to(dtype) * other_scale_original\n    pt_outputs = torch.matmul(inputs.to(dtype), other)\n    # Verify the results are similar\n    assert_similar(lib_outputs, pt_outputs)\n\n\n@pytest.mark.skipif(\n    not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 8,\n    reason=\"CUDA device >= sm80 not available\",\n)\n@pytest.mark.parametrize(\"in_features, out_features\", [(256, 256), (512, 256)])\n@pytest.mark.parametrize(\"batch_size, tokens\", [(1, 16), (10, 128)], ids=[\"small\", \"medium\"])\ndef test_gemm_marlin_fp16_int4(batch_size, tokens, in_features, out_features):\n    bits = 4\n    group_size = 128  # Hard-coded in kernels\n    device = torch.device(\"cuda\")\n    input_shape = (batch_size, tokens, 
in_features)\n    # FIXME: does not work if inputs are negative !!??\n    inputs = torch.rand(input_shape, dtype=torch.float16, device=device)\n    qmax = 2**bits\n    other_shape = (out_features, in_features)\n    other_data = torch.randint(0, qmax, other_shape, dtype=torch.uint8, device=device)\n    # The GEMM kernel works on transposed scales\n    scales_shape = (in_features // group_size, out_features)\n    other_scales = torch.rand(scales_shape, dtype=torch.float16, device=device) / qmax\n    # This kernel works on transposed, negated and scaled zeropoints\n    qmin = -(2 ** (bits - 1))\n    qmax = 2 ** (bits - 1)\n    other_shifts = torch.randint(qmin, qmax, scales_shape, dtype=torch.int8, device=device)\n    # Negate and scale\n    other_scaled_shifts = -other_shifts * other_scales\n    workspace = torch.zeros(out_features // 128 * 16, dtype=torch.int, device=inputs.device)\n    packed_other_data_marlin = MarlinInt4PackedTensor.pack(other_data)._data\n    # Apply scale and shift permutations\n    other_scales_marlin = marlin_permute(other_scales)\n    other_scaled_shifts_marlin = marlin_permute(other_scaled_shifts)\n    lib_outputs = torch.ops.quanto.gemm_f16i4_marlin(\n        inputs, packed_other_data_marlin, other_scales_marlin, other_scaled_shifts_marlin, workspace\n    )\n    # Transpose other data and reshape it to align it with transposed scales and zeros\n    other_data_t = other_data.t().reshape(group_size, in_features // group_size, out_features)\n    # Dequantize transposed other\n    other_t = other_data_t * other_scales + other_scaled_shifts\n    # Reshape it as expected by the matmul\n    other_t = other_t.reshape(in_features, out_features)\n    # Evaluate the matrix multiplication using pytorch float16 mm\n    pt_outputs = torch.matmul(inputs, other_t)\n    # Verify the results are similar\n    assert_similar(lib_outputs, pt_outputs, rtol=1e-3)\n"
  },
  {
    "path": "tests/library/test_quantize.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\nimport torch\nfrom helpers import assert_similar, device_eq, random_tensor\n\nfrom optimum.quanto import (\n    MaxOptimizer,\n    absmax_scale,\n    qfloat8,\n    qfloat8_e4m3fn,\n    qfloat8_e4m3fnuz,\n    qfloat8_e5m2,\n    qint2,\n    qint4,\n    qint8,\n)\nfrom optimum.quanto.tensor.grouped import ungroup\n\n\n@pytest.mark.parametrize(\"input_shape\", [(32, 32), (32, 10, 32)])\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.float32], ids=[\"fp16\", \"fp32\"])\n@pytest.mark.parametrize(\"qtype\", [qint8], ids=[\"qint8\"])\n@pytest.mark.parametrize(\n    \"axis\",\n    [None, 0, -1],\n    ids=[\"per-tensor\", \"first-axis\", \"last-axis\"],\n)\ndef test_symmetric_quantize_int(input_shape, dtype, qtype, axis, device):\n    a = random_tensor(input_shape, dtype=dtype).to(device)\n    scale = absmax_scale(a, qtype=qtype, axis=axis)\n    data = torch.ops.quanto.quantize_symmetric(a, dtype=qtype.dtype, axis=axis, scale=scale)\n    assert data.dtype == qtype.dtype\n    assert device_eq(data.device, device)\n    assert_similar(a, data * scale)\n\n\n@pytest.mark.skip_device(\"mps\")\n@pytest.mark.parametrize(\"input_shape\", [(32, 32), (32, 10, 32)])\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.float32], ids=[\"fp16\", \"fp32\"])\n@pytest.mark.parametrize(\n    \"qtype\",\n    [qfloat8, qfloat8_e4m3fn, 
qfloat8_e4m3fnuz, qfloat8_e5m2],\n    ids=[\"qfloat8\", \"qfloat8_e4m3fn\", \"qfloat8_e4m3fnuz\", \"qfloat8_e5m2\"],\n)\n@pytest.mark.parametrize(\n    \"axis\",\n    [None, 0, -1],\n    ids=[\"per-tensor\", \"first-axis\", \"last-axis\"],\n)\ndef test_symmetric_quantize_float8(input_shape, dtype, qtype, axis, device):\n    a = random_tensor(input_shape, dtype=dtype).to(device)\n    scale = absmax_scale(a, qtype=qtype, axis=axis)\n    data = torch.ops.quanto.quantize_symmetric(a, dtype=qtype.dtype, axis=axis, scale=scale)\n    assert data.dtype == qtype.dtype\n    assert device_eq(data.device, device)\n    assert_similar(a, data.to(dtype) * scale, atol=5e-3)\n\n\n@pytest.mark.parametrize(\"input_shape\", [(32, 32), (32, 10, 32)])\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.float32], ids=[\"fp16\", \"fp32\"])\n@pytest.mark.parametrize(\"qtype\", [qint2, qint4], ids=[\"qint2\", \"qint4\"])\n@pytest.mark.parametrize(\"axis\", [0, -1], ids=[\"first-axis\", \"last-axis\"])\n@pytest.mark.parametrize(\"group_size\", [None, 8], ids=[\"channel-wise\", \"group-wise\"])\n@pytest.mark.parametrize(\"shift_mode\", [\"zeropoint\", \"float\"])\ndef test_affine_quantize(input_shape, dtype, qtype, axis, group_size, shift_mode, device):\n    a = random_tensor(input_shape, dtype=dtype).to(device)\n    scale, shift = MaxOptimizer()(a, qtype=qtype, axis=axis, group_size=group_size)\n    if shift_mode == \"zeropoint\":\n        shift = torch.round(shift / scale).to(torch.int8)\n    data = torch.ops.quanto.quantize_affine(a, qtype.bits, axis, group_size, scale, shift)\n    assert data.dtype == torch.uint8\n    assert device_eq(data.device, device)\n    if shift_mode == \"zeropoint\":\n        qa = (data - shift) * scale\n    else:\n        qa = data * scale - shift\n    atol = {\n        qint4: {\n            \"zeropoint\": 4e-3,\n            \"float\": 3e-3,\n        },\n        qint2: {\n            \"zeropoint\": 6e-2,\n            \"float\": 5e-2,\n        },\n    
}[qtype][shift_mode]\n    if group_size is not None:\n        qa = ungroup(qa, axis=axis, orig_shape=a.shape)\n    assert_similar(a, qa, atol=atol)\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.float32], ids=[\"fp16\", \"fp32\"])\n@pytest.mark.parametrize(\"qtype\", [qint2, qint4], ids=[\"qint2\", \"qint4\"])\ndef test_affine_quantize_integer_tensor(dtype, qtype, device):\n    \"\"\"This test verifies that an integer tensor in the correct range is preserved.\"\"\"\n    bits = qtype.bits\n    qmin = -(2 ** (bits - 1))\n    qmax = 2 ** (bits - 1) - 1\n    a = torch.tensor(range(qmin, qmax + 1), dtype=dtype).to(device)\n    scale, shift = MaxOptimizer()(a, qtype=qtype, axis=0, group_size=None)\n    zeropoint = torch.round(shift / scale)\n    data = torch.ops.quanto.quantize_affine(a, bits, 0, None, scale, zeropoint)\n\n    assert data.dtype == torch.uint8\n    assert device_eq(data.device, device)\n    assert torch.equal(a, data - zeropoint)\n"
  },
  {
    "path": "tests/library/test_unpack.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport pytest\nimport torch\n\nfrom optimum.quanto.tensor.packed import pack_weights\n\n\n@pytest.mark.parametrize(\"bits\", [2, 4], ids=[\"int2\", \"int4\"])\n@pytest.mark.parametrize(\"shape\", [(12,), (32, 32)], ids=[\"vector\", \"matrix\"])\ndef test_unpack(bits, shape, device):\n    qmax = 2**bits\n    a = torch.randint(0, qmax, shape, dtype=torch.uint8).to(device)\n    packed_a = pack_weights(a, bits)\n    unpacked_a = torch.ops.quanto.unpack(packed_a, bits)\n    assert unpacked_a.dtype == torch.uint8\n    assert torch.equal(unpacked_a, a)\n"
  },
  {
    "path": "tests/models/conftest.py",
    "content": "import pytest\nfrom huggingface_hub.constants import _staging_mode\n\n\n@pytest.fixture\ndef staging():\n    \"\"\"A pytest fixture only available in huggingface_hub staging mode\n\n    If the huggingface_hub is not operating in staging mode, tests using\n    that fixture are automatically skipped.\n\n    Returns:\n        a Dict containing a valid staging user and token.\n    \"\"\"\n    if not _staging_mode:\n        pytest.skip(\"requires huggingface_hub staging mode\")\n    return {\n        \"user\": \"__DUMMY_TRANSFORMERS_USER__\",\n        # Not critical, only usable on the sandboxed CI instance.\n        \"token\": \"hf_94wBhPGp6KrrTH3KDchhKpRxZwd6dmHWLL\",\n    }\n\n\n@pytest.fixture(autouse=True)\ndef skip_if_staging(request):\n    if _staging_mode:\n        if \"staging\" not in request.fixturenames:\n            pytest.skip(\"requires huggingface_hub standard mode\")\n"
  },
  {
    "path": "tests/models/test_quantized_model_for_causal_lm.py",
    "content": "import uuid\nfrom tempfile import TemporaryDirectory\n\nimport pytest\nimport torch\nfrom huggingface_hub import delete_repo\n\nfrom optimum.quanto import QModuleMixin, is_transformers_available, qint4, qint8\n\n\ndef quantized_model_for_causal_lm(model_id, qtype, exclude, from_config=False):\n    from transformers import AutoModelForCausalLM, OPTConfig\n\n    from optimum.quanto import QuantizedModelForCausalLM\n\n    if from_config:\n        config = OPTConfig(\n            **{\n                \"activation_dropout\": 0.0,\n                \"activation_function\": \"relu\",\n                \"architectures\": [\"OPTForCausalLM\"],\n                \"attention_dropout\": 0.0,\n                \"bos_token_id\": 2,\n                \"do_layer_norm_before\": True,\n                \"dropout\": 0.1,\n                \"eos_token_id\": 2,\n                \"ffn_dim\": 32,\n                \"hidden_size\": 8,\n                \"init_std\": 0.02,\n                \"layerdrop\": 0.0,\n                \"max_position_embeddings\": 16,\n                \"model_type\": \"opt\",\n                \"num_attention_heads\": 2,\n                \"num_hidden_layers\": 2,\n                \"pad_token_id\": 1,\n                \"prefix\": \"</s>\",\n                \"torch_dtype\": \"float16\",\n                \"use_cache\": True,\n                \"vocab_size\": 64,\n                \"word_embed_proj_dim\": 8,\n            }\n        )\n        model = AutoModelForCausalLM.from_config(config).eval()\n    else:\n        model = AutoModelForCausalLM.from_pretrained(model_id)\n    return QuantizedModelForCausalLM.quantize(model, weights=qtype, exclude=exclude)\n\n\ndef compare_models(a_model, b_model):\n    # Compare tensors\n    for (a_name, a_m), (b_name, b_m) in zip(a_model.named_modules(), b_model.named_modules()):\n        assert a_name == b_name\n        if isinstance(a_m, QModuleMixin):\n            assert isinstance(b_m, QModuleMixin)\n        if isinstance(b_m, 
QModuleMixin):\n            assert isinstance(a_m, QModuleMixin)\n        if isinstance(a_m, QModuleMixin):\n            assert torch.equal(a_m.weight, b_m.weight)\n        for (a_p_name, a_p), (b_p_name, b_p) in zip(a_m.named_parameters(), b_m.named_parameters()):\n            assert a_p_name == b_p_name\n            assert isinstance(a_p, torch.Tensor)\n            assert torch.equal(a_p, b_p)\n    # Compare model outputs\n    inputs = torch.ones((1, 1), dtype=torch.int64)\n    with torch.no_grad():\n        output_a = a_model.forward(inputs)\n        output_b = b_model.forward(inputs)\n    assert torch.equal(output_a.logits, output_b.logits)\n    for i, a_key_value in enumerate(output_a.past_key_values):\n        b_key_value = output_b.past_key_values[i]\n        for j, a_value in enumerate(a_key_value):\n            assert torch.equal(a_value, b_key_value[j])\n\n\n@pytest.mark.skipif(not is_transformers_available(), reason=\"requires transformers\")\n@pytest.mark.parametrize(\"model_id\", [\"facebook/opt-125m\"])\n@pytest.mark.parametrize(\"qtype\", [qint4, qint8], ids=[\"qint4\", \"qint8\"])\n@pytest.mark.parametrize(\"exclude_lm_head\", [True, False], ids=[\"full\", \"no_lm_head\"])\ndef test_quantized_model_for_causal_lm_base(model_id, qtype, exclude_lm_head):\n    from optimum.quanto import QuantizedModelForCausalLM\n\n    exclude = \"lm_head\" if exclude_lm_head else None\n    quantized = quantized_model_for_causal_lm(model_id, qtype, exclude)\n    with TemporaryDirectory() as tmpdir:\n        quantized.save_pretrained(tmpdir)\n        requantized = QuantizedModelForCausalLM.from_pretrained(tmpdir)\n\n    compare_models(quantized, requantized)\n\n\n@pytest.mark.skipif(not is_transformers_available(), reason=\"requires transformers\")\ndef test_quantized_model_for_causal_lm_sharded():\n    from optimum.quanto import QuantizedModelForCausalLM\n\n    model_id = \"facebook/opt-125m\"\n    qtype = qint4\n    quantized = quantized_model_for_causal_lm(model_id, 
qtype, exclude=None)\n    with TemporaryDirectory() as tmpdir:\n        quantized.save_pretrained(tmpdir, max_shard_size=\"100MB\")\n        requantized = QuantizedModelForCausalLM.from_pretrained(tmpdir)\n\n    compare_models(quantized, requantized)\n\n\n@pytest.mark.skipif(not is_transformers_available(), reason=\"requires transformers\")\n@pytest.mark.parametrize(\"in_org\", [True, False], ids=[\"org\", \"user\"])\ndef test_causal_lm_base_push_to_hub(staging, in_org):\n    from optimum.quanto import QuantizedModelForCausalLM\n\n    identifier = uuid.uuid4()\n\n    qtype = qint4\n    exclude = None\n    quantized = quantized_model_for_causal_lm(None, qtype, exclude, from_config=True)\n\n    repo_id = f\"test-model-{identifier}\"\n    if in_org:\n        quantized.push_to_hub(repo_id, token=staging[\"token\"])\n        hub_repo_id = f\"{staging['user']}/{repo_id}\"\n    else:\n        hub_repo_id = f\"valid_org/{repo_id}-org\"\n        quantized.push_to_hub(hub_repo_id, token=staging[\"token\"])\n\n    requantized = QuantizedModelForCausalLM.from_pretrained(hub_repo_id, token=staging[\"token\"])\n\n    compare_models(quantized, requantized)\n\n    delete_repo(hub_repo_id, token=staging[\"token\"])\n\n\n@pytest.mark.skipif(not is_transformers_available(), reason=\"requires transformers\")\n@pytest.mark.parametrize(\"model_id\", [\"facebook/opt-125m\"])\n@pytest.mark.parametrize(\"qtype\", [qint4, qint8], ids=[\"qint4\", \"qint8\"])\ndef test_quantized_model_load_state_dict_non_strict(model_id, qtype):\n    # see issue #278\n    quantized = quantized_model_for_causal_lm(model_id, qtype, exclude=None)\n    sd = quantized.state_dict()\n\n    # delete a key used by both qint4 and qint8 from the state dict\n    key = \"model.decoder.layers.0.self_attn.k_proj.weight._scale\"\n    del sd[key]\n\n    # strict loading should raise a RuntimeError, which is what PyTorch does in this case\n    with pytest.raises(RuntimeError, match=key):\n        
quantized.load_state_dict(sd)\n\n    # non-strict loading should not raise an error\n    result = quantized.load_state_dict(sd, strict=False)\n    assert result.missing_keys == [key]\n"
  },
  {
    "path": "tests/models/test_quantized_model_for_pixart.py",
    "content": "import uuid\nfrom tempfile import TemporaryDirectory\n\nimport pytest\nimport torch\nfrom huggingface_hub import delete_repo\n\nfrom optimum.quanto import QModuleMixin, is_diffusers_available, qint4, qint8\n\n\ndef quantized_model_for_pixart(qtype, exclude):\n    from diffusers import PixArtTransformer2DModel\n\n    from optimum.quanto import QuantizedPixArtTransformer2DModel\n\n    init_dict = {\n        \"sample_size\": 8,\n        \"num_layers\": 1,\n        \"patch_size\": 2,\n        \"attention_head_dim\": 2,\n        \"num_attention_heads\": 2,\n        \"in_channels\": 4,\n        \"cross_attention_dim\": 8,\n        \"out_channels\": 8,\n        \"attention_bias\": True,\n        \"activation_fn\": \"gelu-approximate\",\n        \"num_embeds_ada_norm\": 8,\n        \"norm_type\": \"ada_norm_single\",\n        \"norm_elementwise_affine\": False,\n        \"norm_eps\": 1e-6,\n        \"use_additional_conditions\": False,\n        \"caption_channels\": None,\n    }\n    torch.manual_seed(0)\n    model = PixArtTransformer2DModel(**init_dict).eval()\n\n    return QuantizedPixArtTransformer2DModel.quantize(model, weights=qtype, exclude=exclude)\n\n\ndef compare_models(a_model, b_model):\n    # Compare tensors\n    for (a_name, a_m), (b_name, b_m) in zip(a_model.named_modules(), b_model.named_modules()):\n        assert a_name == b_name\n        if isinstance(a_m, QModuleMixin):\n            assert isinstance(b_m, QModuleMixin)\n        if isinstance(b_m, QModuleMixin):\n            assert isinstance(a_m, QModuleMixin)\n        if isinstance(a_m, QModuleMixin):\n            assert torch.equal(a_m.weight, b_m.weight)\n        for (a_p_name, a_p), (b_p_name, b_p) in zip(a_m.named_parameters(), b_m.named_parameters()):\n            assert a_p_name == b_p_name\n            assert isinstance(a_p, torch.Tensor)\n            assert torch.equal(a_p, b_p)\n        for (a_b_name, a_b), (b_b_name, b_b) in zip(a_m.named_buffers(), b_m.named_buffers()):\n      
      assert a_b_name == b_b_name\n            assert isinstance(a_b, torch.Tensor)\n            assert torch.equal(a_b, b_b)\n\n    # Compare model outputs\n    hidden_states = torch.randn((1, 4, 8, 8))\n    timesteps = torch.tensor([1.0])\n    encoder_hidden_states = torch.randn((1, 8, 8))\n    model_inputs = {\n        \"hidden_states\": hidden_states,\n        \"timestep\": timesteps,\n        \"encoder_hidden_states\": encoder_hidden_states,\n        \"added_cond_kwargs\": {\"aspect_ratio\": None, \"resolution\": None},\n        \"return_dict\": False,\n    }\n\n    with torch.no_grad():\n        output_a = a_model(**model_inputs)[0]\n        output_b = b_model(**model_inputs)[0]\n    assert torch.allclose(output_a, output_b, atol=1e-3, rtol=1e-3)\n\n\n@pytest.mark.skipif(not is_diffusers_available(), reason=\"requires diffusers\")\n@pytest.mark.parametrize(\"qtype\", [qint4, qint8], ids=[\"qint4\", \"qint8\"])\n@pytest.mark.parametrize(\"exclude_proj_out\", [True, False], ids=[\"without_proj_out\", \"with_proj_out\"])\ndef test_quantized_model_for_pixart(qtype, exclude_proj_out):\n    from optimum.quanto import QuantizedPixArtTransformer2DModel\n\n    exclude = \"proj_out\" if exclude_proj_out else None\n    quantized = quantized_model_for_pixart(qtype, exclude)\n    with TemporaryDirectory() as tmpdir:\n        quantized.save_pretrained(tmpdir)\n        requantized = QuantizedPixArtTransformer2DModel.from_pretrained(tmpdir)\n\n    compare_models(quantized, requantized)\n\n\n@pytest.mark.skipif(not is_diffusers_available(), reason=\"requires diffusers\")\n@pytest.mark.parametrize(\"in_org\", [True, False], ids=[\"org\", \"user\"])\ndef test_push_to_hub(staging, in_org):\n    from optimum.quanto import QuantizedPixArtTransformer2DModel\n\n    identifier = uuid.uuid4()\n\n    exclude = None\n    quantized = quantized_model_for_pixart(\"qint8\", exclude)\n    repo_id = f\"test-model-{identifier}\"\n    if in_org:\n        quantized.push_to_hub(repo_id, 
token=staging[\"token\"])\n        hub_repo_id = f\"{staging['user']}/{repo_id}\"\n    else:\n        hub_repo_id = f\"valid_org/{repo_id}-org\"\n        quantized.push_to_hub(hub_repo_id, token=staging[\"token\"])\n\n    requantized = QuantizedPixArtTransformer2DModel.from_pretrained(hub_repo_id, token=staging[\"token\"])\n    compare_models(quantized, requantized)\n\n    delete_repo(hub_repo_id, token=staging[\"token\"])\n"
  },
  {
    "path": "tests/nn/test_calibrate.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\nimport torch\nfrom helpers import random_qactivation\n\nfrom optimum.quanto import Calibration, qfloat8_e4m3fn, qfloat8_e4m3fnuz, qfloat8_e5m2, qint8\nfrom optimum.quanto.nn import QLinear\n\n\ndef _test_calibrate_qlinear(batch_size, tokens, embeddings, use_bias, activations, device):\n    linear = torch.nn.Linear(embeddings, embeddings, bias=use_bias).to(device)\n    qlinear = QLinear.from_module(linear, weights=qint8, activations=activations)\n    qinputs = random_qactivation(\n        (batch_size, tokens, embeddings), qtype=activations, dtype=torch.float32, device=device\n    )\n    # Run a first inference without Calibration\n    with torch.no_grad():\n        qout = qlinear(qinputs)\n    assert torch.all(qlinear.input_scale == 1)\n    assert torch.all(qlinear.output_scale == 1)\n    # Calibrate to adjust input and output scales and set the correct dtype\n    with torch.no_grad(), Calibration():\n        qout = qlinear(qinputs)\n    assert qout.qtype == activations\n    assert torch.any(qlinear.input_scale != 1)\n    assert torch.any(qlinear.output_scale != 1)\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 10])\n@pytest.mark.parametrize(\"tokens, embeddings\", [(32, 32), (10, 32)])\n@pytest.mark.parametrize(\"use_bias\", [True, False], ids=[\"bias\", \"no-bias\"])\ndef 
test_calibrate_qlinear_activations_int8(batch_size, tokens, embeddings, use_bias, device):\n    _test_calibrate_qlinear(batch_size, tokens, embeddings, use_bias, qint8, device)\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 10])\n@pytest.mark.parametrize(\"tokens, embeddings\", [(32, 32), (10, 32)])\n@pytest.mark.parametrize(\"use_bias\", [True, False], ids=[\"bias\", \"no-bias\"])\n@pytest.mark.parametrize(\n    \"activations\",\n    [qfloat8_e5m2, qfloat8_e4m3fn, qfloat8_e4m3fnuz],\n    ids=[\"a-qfloat8-e5m2\", \"a-qfloat8-e4m3\", \"a-qfloat8-e4m3-uz\"],\n)\n@pytest.mark.skip_device(\"mps\")\ndef test_calibrate_qlinear_activations_float8(batch_size, tokens, embeddings, use_bias, activations, device):\n    _test_calibrate_qlinear(batch_size, tokens, embeddings, use_bias, activations, device)\n\n\ndef _test_calibrate_custom_module(activations, device):\n    tokens = 10\n    embeddings = 32\n\n    class TwoLinearModel(torch.nn.Module):\n        def __init__(self, embeddings):\n            super().__init__()\n            self.linear1 = torch.nn.Linear(embeddings, embeddings)\n            self.linear2 = torch.nn.Linear(embeddings, embeddings)\n\n        def forward(self, input):\n            return self.linear2(self.linear1(input))\n\n    model = TwoLinearModel(embeddings).to(device)\n    model.linear1 = QLinear.from_module(model.linear1, weights=qint8, activations=activations)\n    model.linear2 = QLinear.from_module(model.linear2, weights=qint8, activations=activations)\n    qinputs = random_qactivation((1, tokens, embeddings), qtype=activations, dtype=torch.float32, device=device)\n    with torch.no_grad(), Calibration():\n        qout = model(qinputs)\n    assert torch.any(model.linear1.input_scale != 1)\n    assert torch.any(model.linear1.output_scale != 1)\n    assert torch.any(model.linear2.input_scale != 1)\n    assert torch.any(model.linear2.output_scale != 1)\n    assert qout.qtype == activations\n\n\ndef 
test_calibrate_custom_module_activations_int8(device):\n    _test_calibrate_custom_module(qint8, device)\n\n\n@pytest.mark.parametrize(\n    \"activations\",\n    [qfloat8_e5m2, qfloat8_e4m3fn, qfloat8_e4m3fnuz],\n    ids=[\"a-qfloat8-e5m2\", \"a-qfloat8-e4m3\", \"a-qfloat8-e4m3-uz\"],\n)\n@pytest.mark.skip_device(\"mps\")\ndef test_calibrate_custom_module_activations_float8(activations, device):\n    _test_calibrate_custom_module(activations, device)\n"
  },
  {
    "path": "tests/nn/test_qattention.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport math\nfrom typing import Optional\n\nimport pytest\nimport torch\nimport torch.utils.checkpoint\nfrom helpers import assert_similar, random_tensor\nfrom torch import nn\n\nfrom optimum.quanto import Calibration, qfloat8_e4m3fn, qfloat8_e4m3fnuz, qfloat8_e5m2, qint8, quantize\n\n\nclass RotaryEmbedding(nn.Module):\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):\n        super().__init__()\n\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        # Build here to make `torch.jit.trace` work.\n        self._set_cos_sin_cache(\n            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()\n        )\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)\n\n        freqs = torch.outer(t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        
self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n    def forward(self, x, seq_len=None):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        if seq_len > self.max_seq_len_cached:\n            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)\n\n        return (\n            self.cos_cached[:seq_len].to(dtype=x.dtype),\n            self.sin_cached[:seq_len].to(dtype=x.dtype),\n        )\n\n\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):\n    \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n    Args:\n        q (`torch.Tensor`): The query tensor.\n        k (`torch.Tensor`): The key tensor.\n        cos (`torch.Tensor`): The cosine part of the rotary embedding.\n        sin (`torch.Tensor`): The sine part of the rotary embedding.\n        position_ids (`torch.Tensor`):\n            The position indices of the tokens corresponding to the query and key tensors. For example, this can be\n            used to pass offsetted position ids when working with a KV-cache.\n        unsqueeze_dim (`int`, *optional*, defaults to 1):\n            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. 
Similarly, if q and k have\n            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n    Returns:\n        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n    \"\"\"\n    cos = cos[position_ids].unsqueeze(unsqueeze_dim)\n    sin = sin[position_ids].unsqueeze(unsqueeze_dim)\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\nclass Attention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, hidden_size=128, num_heads=4, max_position_embeddings=1024, bias=False):\n        super().__init__()\n        self.hidden_size = hidden_size\n        self.num_heads = num_heads\n        self.head_dim = self.hidden_size // self.num_heads\n        self.max_position_embeddings = max_position_embeddings\n        self.rope_theta = 10000.0\n        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=bias)\n        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=bias)\n        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=bias)\n        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=bias)\n        self.rotary_emb = 
RotaryEmbedding(\n            self.head_dim,\n            max_position_embeddings=self.max_position_embeddings,\n            base=self.rope_theta,\n        )\n\n    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        **kwargs,\n    ) -> torch.Tensor:\n        bsz, q_len, _ = hidden_states.size()\n\n        query_states = self.q_proj(hidden_states)\n        key_states = self.k_proj(hidden_states)\n        value_states = self.v_proj(hidden_states)\n\n        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n\n        kv_seq_len = key_states.shape[-2]\n        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)\n\n        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)\n\n        if attention_mask is not None:\n            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}\"\n                )\n            attn_weights = attn_weights + attention_mask\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        attn_output = attn_output.transpose(1, 
2).contiguous()\n\n        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)\n\n        return self.o_proj(attn_output)\n\n\ndef _test_quantize_attention(device, dtype=torch.float32, weights=qint8, activations=None, atol=None):\n    att = Attention().to(dtype).to(device)\n    batch_size = 10\n    seq_len = 64\n    input_shape = (batch_size, seq_len, att.hidden_size)\n    inputs = random_tensor(input_shape).to(device)\n    with torch.no_grad():\n        outputs = att(inputs)\n    quantize(att, weights=weights, activations=activations)\n    if activations is None:\n        with torch.no_grad():\n            qoutputs = att(inputs)\n    else:\n        with torch.no_grad(), Calibration():\n            qoutputs = att(inputs)\n    assert_similar(outputs, qoutputs, atol=atol)\n\n\n@pytest.mark.parametrize(\"weights\", [qint8], ids=[\"w-qint8\"])\ndef test_quantize_attention_weights_only(weights, device):\n    _test_quantize_attention(device, weights=weights, atol=1e-4)\n\n\n@pytest.mark.skip_device(\"mps\")\ndef test_quantize_attention_weights_only_float8(device):\n    _test_quantize_attention(device, weights=qfloat8_e4m3fn, atol=1e-3)\n\n\n@pytest.mark.parametrize(\"weights\", [qint8], ids=[\"w-qint8\"])\ndef test_quantize_attention_activations_int8(weights, device):\n    _test_quantize_attention(device, weights=weights, activations=qint8, atol=1e-3)\n\n\n@pytest.mark.parametrize(\"weights\", [qint8], ids=[\"w-qint8\"])\n@pytest.mark.parametrize(\n    \"activations\",\n    [qfloat8_e5m2, qfloat8_e4m3fn, qfloat8_e4m3fnuz],\n    ids=[\"a-float8-e5m2\", \"a-float8-e4m3\", \"a-float8-e4m3-uz\"],\n)\n@pytest.mark.skip_device(\"mps\")\ndef test_quantize_attention_activations_float8(weights, activations, device):\n    _test_quantize_attention(device, weights=weights, activations=activations, atol=1e-2)\n"
  },
  {
    "path": "tests/nn/test_qconv2d.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\nimport torch\nfrom helpers import assert_similar, random_qactivation, random_tensor\n\nfrom optimum.quanto import (\n    ActivationQBytesTensor,\n    Calibration,\n    qfloat8_e4m3fn,\n    qfloat8_e4m3fnuz,\n    qfloat8_e5m2,\n    qint4,\n    qint8,\n)\nfrom optimum.quanto.nn import QConv2d\n\n\ndef _test_quantize_conv2d(batch_size, img_shape, out_channels, use_bias, weights, activations, dtype, device):\n    conv2d = torch.nn.Conv2d(img_shape[0], out_channels, kernel_size=3, bias=use_bias).to(dtype).to(device)\n    qconv2d = QConv2d.from_module(conv2d, weights=weights, activations=activations)\n    assert qconv2d.qweight.qtype == weights\n    inputs = random_tensor((batch_size,) + img_shape, dtype=dtype, device=device)\n    # Run an inference with Calibration to get the correct output dtype\n    with torch.no_grad(), Calibration():\n        qout = qconv2d(inputs)\n    if activations is not None:\n        assert isinstance(qout, ActivationQBytesTensor)\n        assert qout.qtype == activations\n    # Align weights with quantized linear weights for comparison\n    conv2d.weight = torch.nn.Parameter(qconv2d.qweight.dequantize())\n    out = conv2d(inputs)\n    # We need to increase atol for float16 dtype\n    dtype_atol = {torch.float32: 1e-4, torch.float16: 1e-3}[dtype]\n    # We also need to increase atol for float8 
qtypes\n    atol = {None: dtype_atol, qint8: dtype_atol, qfloat8_e5m2: 5e-3, qfloat8_e4m3fn: 5e-3, qfloat8_e4m3fnuz: 5e-3}[\n        activations\n    ]\n    assert_similar(out, qout, atol=atol)\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 10])\n@pytest.mark.parametrize(\"img_shape\", [(3, 32, 32), (10, 32, 32)])\n@pytest.mark.parametrize(\"out_channels\", [3, 10])\n@pytest.mark.parametrize(\"use_bias\", [True, False], ids=[\"bias\", \"no-bias\"])\n@pytest.mark.parametrize(\"weights\", [qint4, qint8], ids=[\"w-int4\", \"w-int8\"])\ndef test_quantize_conv2d_float16_activations_int8(batch_size, img_shape, out_channels, use_bias, weights, device):\n    _test_quantize_conv2d(batch_size, img_shape, out_channels, use_bias, weights, qint8, torch.float16, device)\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 10])\n@pytest.mark.parametrize(\"img_shape\", [(3, 32, 32), (10, 32, 32)])\n@pytest.mark.parametrize(\"out_channels\", [3, 10])\n@pytest.mark.parametrize(\"use_bias\", [True, False], ids=[\"bias\", \"no-bias\"])\n@pytest.mark.parametrize(\"weights\", [qint4, qint8], ids=[\"w-int4\", \"w-int8\"])\ndef test_quantize_conv2d_float32_activations_int8(batch_size, img_shape, out_channels, use_bias, weights, device):\n    _test_quantize_conv2d(batch_size, img_shape, out_channels, use_bias, weights, qint8, torch.float32, device)\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 10])\n@pytest.mark.parametrize(\"img_shape\", [(3, 32, 32), (10, 32, 32)])\n@pytest.mark.parametrize(\"out_channels\", [3, 10])\n@pytest.mark.parametrize(\"use_bias\", [True, False], ids=[\"bias\", \"no-bias\"])\n@pytest.mark.parametrize(\"weights\", [qint4, qint8], ids=[\"w-int4\", \"w-int8\"])\n@pytest.mark.parametrize(\n    \"activations\",\n    [qfloat8_e5m2, qfloat8_e4m3fn, qfloat8_e4m3fnuz],\n    ids=[\"a-float8-e5m2\", \"a-float8-e4m3\", \"a-float8-e4m3-uz\"],\n)\n@pytest.mark.skip_device(\"mps\")\ndef test_quantize_conv2d_float16_activations_float8(\n    batch_size, img_shape, 
out_channels, use_bias, weights, activations, device\n):\n    _test_quantize_conv2d(batch_size, img_shape, out_channels, use_bias, weights, activations, torch.float16, device)\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 10])\n@pytest.mark.parametrize(\"img_shape\", [(3, 32, 32), (10, 32, 32)])\n@pytest.mark.parametrize(\"out_channels\", [3, 10])\n@pytest.mark.parametrize(\"use_bias\", [True, False], ids=[\"bias\", \"no-bias\"])\n@pytest.mark.parametrize(\"weights\", [qint4, qint8], ids=[\"w-int4\", \"w-int8\"])\n@pytest.mark.parametrize(\n    \"activations\",\n    [qfloat8_e5m2, qfloat8_e4m3fn, qfloat8_e4m3fnuz],\n    ids=[\"a-float8-e5m2\", \"a-float8-e4m3\", \"a-float8-e4m3-uz\"],\n)\n@pytest.mark.skip_device(\"mps\")\ndef test_quantize_conv2d_float32_activations_float8(\n    batch_size, img_shape, out_channels, use_bias, weights, activations, device\n):\n    _test_quantize_conv2d(batch_size, img_shape, out_channels, use_bias, weights, activations, torch.float32, device)\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 10])\n@pytest.mark.parametrize(\"img_shape\", [(3, 32, 32), (10, 32, 32)])\n@pytest.mark.parametrize(\"out_channels\", [3, 10])\n@pytest.mark.parametrize(\"use_bias\", [True, False], ids=[\"bias\", \"no-bias\"])\n@pytest.mark.parametrize(\"weights\", [qint4, qint8], ids=[\"w-int4\", \"w-int8\"])\ndef test_quantize_conv2d_float16_weight_only(batch_size, img_shape, out_channels, use_bias, weights, device):\n    _test_quantize_conv2d(batch_size, img_shape, out_channels, use_bias, weights, None, torch.float16, device)\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 10])\n@pytest.mark.parametrize(\"img_shape\", [(3, 32, 32), (10, 32, 32)])\n@pytest.mark.parametrize(\"out_channels\", [3, 10])\n@pytest.mark.parametrize(\"use_bias\", [True, False], ids=[\"bias\", \"no-bias\"])\n@pytest.mark.parametrize(\"weights\", [qint4, qint8], ids=[\"w-int4\", \"w-int8\"])\ndef test_quantize_conv2d_float32_weight_only(batch_size, img_shape, out_channels, 
use_bias, weights, device):\n    _test_quantize_conv2d(batch_size, img_shape, out_channels, use_bias, weights, None, torch.float32, device)\n\n\n@pytest.mark.parametrize(\"img_shape\", [(3, 32, 32), (10, 32, 32)])\n@pytest.mark.parametrize(\"out_channels\", [3, 10])\n@pytest.mark.parametrize(\"activations\", [None, qint8], ids=[\"a-float\", \"a-int8\"])\n@pytest.mark.parametrize(\"weights\", [qint4, qint8], ids=[\"w-int4\", \"w-int8\"])\ndef test_qconv2d_gradient(img_shape, out_channels, activations, weights, device):\n    batch_size = 10\n    conv2d = torch.nn.Conv2d(img_shape[0], out_channels, kernel_size=3, bias=True).to(device)\n    qconv2d = QConv2d.from_module(conv2d, weights=weights, activations=activations)\n    assert qconv2d.weight.requires_grad is True\n    assert qconv2d.bias.requires_grad is True\n    # Run an inference with identical inputs\n    qinputs = random_qactivation((batch_size,) + img_shape, dtype=torch.float32).to(device)\n    qout = qconv2d(qinputs)\n    out = conv2d(qinputs.dequantize())\n    # Outputs are not identical because of the quantization\n    assert not torch.equal(qout, out)\n    # Compute gradients and compare\n    gradient = torch.randn(qout.size()).to(device)\n    qout.backward(gradient)\n    out.backward(gradient)\n    # Gradients are nearly identical because they depend only on the input\n    atol = 1e-5\n    assert_similar(qconv2d.weight.grad, conv2d.weight.grad, atol=atol)\n    assert_similar(qconv2d.bias.grad, conv2d.bias.grad, atol=atol)\n"
  },
  {
    "path": "tests/nn/test_qlayernorm.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\nimport torch\nfrom helpers import assert_similar, random_qactivation\n\nfrom optimum.quanto import ActivationQBytesTensor, Calibration, qfloat8_e4m3fn, qfloat8_e4m3fnuz, qfloat8_e5m2, qint8\nfrom optimum.quanto.nn import QLayerNorm\n\n\ndef _test_quantize_layernorm(batch_size, tokens, embeddings, affine, dtype, activations, device):\n    # Instantiate a normalization layer\n    norm = torch.nn.LayerNorm(embeddings, elementwise_affine=affine).to(dtype).to(device)\n    qnorm = QLayerNorm.from_module(norm, activations=activations)\n    qinputs = random_qactivation((batch_size,) + (tokens, embeddings), qtype=activations, dtype=dtype).to(device)\n    # Calibrate to avoid clipping and to set the correct dtype\n    with torch.no_grad(), Calibration():\n        qout = qnorm(qinputs)\n    qout = qnorm(qinputs)\n    assert isinstance(qout, ActivationQBytesTensor)\n    assert qout.dtype == dtype\n    assert qout.qtype == activations\n    # Compare with the float results\n    out = norm(qinputs.dequantize())\n    # We need to increase atol for float16 dtype\n    dtype_atol = {torch.float32: 1e-4, torch.float16: 1e-3}[dtype]\n    # We also need to increase atol for float8 qtypes\n    atol = {qint8: dtype_atol, qfloat8_e5m2: 5e-3, qfloat8_e4m3fn: 5e-3, qfloat8_e4m3fnuz: 5e-3}[activations]\n    assert_similar(out, qout, 
atol=atol)\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 10])\n@pytest.mark.parametrize(\"tokens, embeddings\", [(32, 32), (10, 32)])\n@pytest.mark.parametrize(\"affine\", [True, False], ids=[\"affine\", \"non-affine\"])\ndef test_quantize_layernorm_float16_activations_int8(batch_size, tokens, embeddings, affine, device):\n    _test_quantize_layernorm(batch_size, tokens, embeddings, affine, torch.float16, qint8, device)\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 10])\n@pytest.mark.parametrize(\"tokens, embeddings\", [(32, 32), (10, 32)])\n@pytest.mark.parametrize(\"affine\", [True, False], ids=[\"affine\", \"non-affine\"])\ndef test_quantize_layernorm_float32_activations_int8(batch_size, tokens, embeddings, affine, device):\n    _test_quantize_layernorm(batch_size, tokens, embeddings, affine, torch.float32, qint8, device)\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 10])\n@pytest.mark.parametrize(\"tokens, embeddings\", [(32, 32), (10, 32)])\n@pytest.mark.parametrize(\"affine\", [True, False], ids=[\"affine\", \"non-affine\"])\n@pytest.mark.parametrize(\n    \"activations\",\n    [qfloat8_e5m2, qfloat8_e4m3fn, qfloat8_e4m3fnuz],\n    ids=[\"a-float8-e5m2\", \"a-float8-e4m3\", \"a-float8-e4m3-uz\"],\n)\n@pytest.mark.skip_device(\"mps\")\ndef test_quantize_layernorm_float16_activations_float8(batch_size, tokens, embeddings, affine, activations, device):\n    _test_quantize_layernorm(batch_size, tokens, embeddings, affine, torch.float16, activations, device)\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 10])\n@pytest.mark.parametrize(\"tokens, embeddings\", [(32, 32), (10, 32)])\n@pytest.mark.parametrize(\"affine\", [True, False], ids=[\"affine\", \"non-affine\"])\n@pytest.mark.parametrize(\n    \"activations\",\n    [qfloat8_e5m2, qfloat8_e4m3fn, qfloat8_e4m3fnuz],\n    ids=[\"a-float8-e5m2\", \"a-float8-e4m3\", \"a-float8-e4m3-uz\"],\n)\n@pytest.mark.skip_device(\"mps\")\ndef test_quantize_layernorm_float32_activations_float8(batch_size, tokens, 
embeddings, affine, activations, device):\n    _test_quantize_layernorm(batch_size, tokens, embeddings, affine, torch.float32, activations, device)\n\n\ndef test_quantize_layernorm_no_activation():\n    norm = torch.nn.LayerNorm(32)\n    qnorm = QLayerNorm.from_module(norm, activations=None)\n    assert qnorm is None\n"
  },
  {
    "path": "tests/nn/test_qlinear.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport io\nfrom contextlib import nullcontext\n\nimport pytest\nimport torch\nfrom helpers import assert_similar, random_qactivation, random_tensor\n\nfrom optimum.quanto import (\n    ActivationQBytesTensor,\n    Calibration,\n    absmax_scale,\n    qfloat8,\n    qfloat8_e4m3fn,\n    qfloat8_e4m3fnuz,\n    qfloat8_e5m2,\n    qint4,\n    qint8,\n    quantize_activation,\n)\nfrom optimum.quanto.nn import QLinear\n\n\ndef _test_quantize_linear(batch_size, tokens, embeddings, use_bias, weights, activations, dtype, device, atol=None):\n    linear = torch.nn.Linear(embeddings, embeddings, bias=use_bias).to(dtype).to(device)\n    qlinear = QLinear.from_module(linear, weights=weights, activations=activations)\n    assert qlinear.qweight.qtype == weights\n    input_shape = (batch_size, tokens, embeddings)\n    if activations is not None:\n        qinputs = random_qactivation(input_shape, qtype=activations, dtype=dtype).to(device)\n        inputs = qinputs.dequantize()\n    else:\n        inputs = random_tensor(input_shape, dtype=dtype, device=device)\n    # Run an inference with Calibration to get the correct output dtype\n    context = nullcontext if activations is None else Calibration\n    with torch.no_grad(), context():\n        qout = qlinear(inputs if activations is None else qinputs)\n    if activations is not None:\n        assert 
isinstance(qout, ActivationQBytesTensor)\n        assert qout.qtype == activations\n    # Align linear weights with quantized linear weights for comparison\n    linear.weight = torch.nn.Parameter(qlinear.qweight.dequantize())\n    out = linear(inputs)\n    assert_similar(out, qout, atol=atol)\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 10])\n@pytest.mark.parametrize(\"tokens, embeddings\", [(10, 32), (10, 256)])\n@pytest.mark.parametrize(\"use_bias\", [True, False], ids=[\"bias\", \"no-bias\"])\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16, torch.float16], ids=[\"bf16\", \"fp16\"])\n@pytest.mark.parametrize(\"weights\", [qint4, qint8], ids=[\"w-qint4\", \"w-qint8\"])\ndef test_quantize_linear_float16_activations_int8(batch_size, tokens, embeddings, use_bias, dtype, weights, device):\n    _test_quantize_linear(batch_size, tokens, embeddings, use_bias, weights, qint8, dtype, device)\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 10])\n@pytest.mark.parametrize(\"tokens, embeddings\", [(10, 32), (10, 256)])\n@pytest.mark.parametrize(\"use_bias\", [True, False], ids=[\"bias\", \"no-bias\"])\n@pytest.mark.parametrize(\"weights\", [qint4, qint8], ids=[\"w-qint4\", \"w-qint8\"])\ndef test_quantize_linear_float32_activations_int8(batch_size, tokens, embeddings, use_bias, weights, device):\n    # Default atol for float32 is 1e-6\n    atol = 1e-4\n    _test_quantize_linear(batch_size, tokens, embeddings, use_bias, weights, qint8, torch.float32, device, atol)\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 10])\n@pytest.mark.parametrize(\"tokens, embeddings\", [(10, 32), (10, 256)])\n@pytest.mark.parametrize(\"use_bias\", [True, False], ids=[\"bias\", \"no-bias\"])\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16, torch.float16], ids=[\"bf16\", \"fp16\"])\n@pytest.mark.parametrize(\"weights\", [qint4, qint8], ids=[\"w-qint4\", \"w-qint8\"])\n@pytest.mark.parametrize(\n    \"activations\",\n    [qfloat8_e4m3fn, qfloat8_e4m3fnuz],\n    
ids=[\"a-qfloat8-e4m3\", \"a-float8-e4m3-uz\"],\n)\n@pytest.mark.skip_device(\"mps\")\ndef test_quantize_linear_float16_activations_float8(\n    batch_size, tokens, embeddings, use_bias, dtype, weights, activations, device\n):\n    atol = 5e-3\n    _test_quantize_linear(batch_size, tokens, embeddings, use_bias, weights, activations, dtype, device, atol)\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 10])\n@pytest.mark.parametrize(\"tokens, embeddings\", [(32, 32), (10, 32)])\n@pytest.mark.parametrize(\"use_bias\", [True, False], ids=[\"bias\", \"no-bias\"])\n@pytest.mark.parametrize(\"weights\", [qint4, qint8], ids=[\"w-qint4\", \"w-qint8\"])\n@pytest.mark.parametrize(\n    \"activations\",\n    [qfloat8_e5m2, qfloat8_e4m3fn, qfloat8_e4m3fnuz],\n    ids=[\"a-qfloat8-e5m2\", \"a-qfloat8-e4m3\", \"a-float8-e4m3-uz\"],\n)\n@pytest.mark.skip_device(\"mps\")\ndef test_quantize_linear_float32_activations_float8(\n    batch_size, tokens, embeddings, use_bias, weights, activations, device\n):\n    atol = 5e-3\n    _test_quantize_linear(\n        batch_size, tokens, embeddings, use_bias, weights, activations, torch.float32, device, atol=atol\n    )\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 10])\n@pytest.mark.parametrize(\"tokens, embeddings\", [(10, 32), (10, 256)])\n@pytest.mark.parametrize(\"use_bias\", [True, False], ids=[\"bias\", \"no-bias\"])\n@pytest.mark.parametrize(\"weights\", [qint4, qint8, qfloat8], ids=[\"w-qint4\", \"w-qint8\", \"float8\"])\ndef test_quantize_linear_float16_weight_only(batch_size, tokens, embeddings, use_bias, weights, device):\n    if device.type in [\"mps\"] and weights == qfloat8:\n        pytest.skip(f\"Float8 are not supported on {device.type} device\")\n    atol = None\n    if device.type == \"cuda\" and weights == qfloat8 and embeddings % 64 == 0:\n        # FIXME: accuracy is slightly worse using MARLIN FP8 kernels\n        atol = 1e-2\n    _test_quantize_linear(batch_size, tokens, embeddings, use_bias, weights, None, 
torch.float16, device, atol)\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 10])\n@pytest.mark.parametrize(\"tokens, embeddings\", [(10, 32), (10, 256)])\n@pytest.mark.parametrize(\"use_bias\", [True, False], ids=[\"bias\", \"no-bias\"])\n@pytest.mark.parametrize(\"weights\", [qint4, qint8], ids=[\"w-qint4\", \"w-qint8\"])\ndef test_quantize_linear_float32_weight_only(batch_size, tokens, embeddings, use_bias, weights, device):\n    _test_quantize_linear(batch_size, tokens, embeddings, use_bias, weights, None, torch.float32, device)\n\n\n@pytest.mark.parametrize(\"tokens, embeddings\", [(10, 32), (10, 256)])\n@pytest.mark.parametrize(\"activations\", [None, qint8, qfloat8], ids=[\"a-float\", \"a-qint8\", \"a-float8\"])\n@pytest.mark.parametrize(\"weights\", [qint4, qint8, qfloat8], ids=[\"w-qint4\", \"w-qint8\", \"w-float8\"])\ndef test_qlinear_gradient(tokens, embeddings, activations, weights, device):\n    if device.type in [\"mps\"] and (activations == qfloat8 or weights == qfloat8):\n        pytest.skip(f\"Float8 is not supported on {device.type} device\")\n    batch_size = 10\n    linear = torch.nn.Linear(embeddings, embeddings).to(device)\n    qlinear = QLinear.from_module(linear, weights=weights, activations=activations)\n    assert qlinear.weight.requires_grad is True\n    assert qlinear.bias.requires_grad is True\n    # Run an inference with dynamically quantized inputs\n    inputs = random_tensor((batch_size, tokens, embeddings), dtype=torch.float32, device=device)\n    inputs.requires_grad = True\n    if activations is None:\n        qout = qlinear(inputs)\n        float_inputs = inputs.clone().detach()\n    else:\n        qinputs = quantize_activation(inputs, qtype=activations, scale=absmax_scale(inputs, activations))\n        qout = qlinear(qinputs)\n        # Run an equivalent inference with float inputs\n        float_inputs = qinputs.dequantize().clone().detach()\n    float_inputs.requires_grad = True\n    out = linear(float_inputs)\n    # Outputs 
are not identical because of the quantization\n    assert not torch.equal(qout, out)\n    # Compute gradients and compare\n    gradient = torch.randn(qout.size()).to(device)\n    qout.backward(gradient)\n    out.backward(gradient)\n    # Bias gradients are identical because they don't depend on inputs and weights\n    atol = 1e-6\n    assert_similar(qlinear.bias.grad, linear.bias.grad, atol=atol)\n    # Weights gradients are nearly identical, based on identical inputs through subtly different graphs\n    atol = 1e-5\n    assert_similar(qlinear.weight.grad, linear.weight.grad, atol=atol)\n    # Inputs gradients are slightly different because they depend on the quantized weights\n    atol = {qint8: 1e-5, qint4: 5e-3, qfloat8: 5e-3}[weights]\n    assert_similar(inputs.grad, float_inputs.grad, atol=atol)\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16, torch.float16, torch.float32], ids=[\"bf16\", \"fp16\", \"fp32\"])\n@pytest.mark.parametrize(\"use_bias\", [True, False], ids=[\"bias\", \"no-bias\"])\n@pytest.mark.parametrize(\"weights\", [qint4, qint8, qfloat8], ids=[\"w-int4\", \"w-int8\", \"w-float8\"])\ndef test_move_qlinear(dtype, use_bias, weights, device):\n    linear = torch.nn.Linear(1024, 1024, bias=use_bias).to(dtype)\n    qlinear = QLinear.from_module(linear, weights=weights)\n    qlinear.freeze()\n    qlinear.to(device)\n    inner_tensor_names, _ = qlinear.weight.__tensor_flatten__()\n    for name in inner_tensor_names:\n        assert getattr(qlinear.weight, name).device.type == device.type\n    if use_bias:\n        assert qlinear.bias.device.type == device.type\n\n\n@pytest.mark.parametrize(\"features\", [10, 256], ids=[\"per-axis\", \"per-group\"])\n@pytest.mark.parametrize(\"use_bias\", [True, False], ids=[\"bias\", \"no-bias\"])\n@pytest.mark.parametrize(\"weights\", [qint4, qint8, qfloat8], ids=[\"w-qint4\", \"w-qint8\", \"w-qfloat8\"])\n@pytest.mark.parametrize(\"activations\", [None, qint8, qfloat8], ids=[\"a-float\", \"a-qint8\", 
\"a-qfloat8\"])\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16], ids=[\"fp16\", \"bf16\"])\n@pytest.mark.parametrize(\"weights_only\", [True, False], ids=[\"weights-only\", \"pickle\"])\ndef test_qlinear_serialization(features, use_bias, activations, weights, dtype, weights_only, device):\n    if device.type in [\"mps\"] and (activations == qfloat8 or weights == qfloat8):\n        pytest.skip(f\"Float8 is not supported on {device.type} device\")\n    linear = torch.nn.Linear(features, features, bias=use_bias).to(dtype).to(device)\n    qlinear = QLinear.from_module(linear, weights=weights, activations=activations)\n    if activations is not None:\n        qinputs = random_qactivation((10, 10, features), qtype=activations, dtype=dtype).to(device)\n        with Calibration():\n            qlinear(qinputs)\n    qlinear.freeze()\n    b = io.BytesIO()\n    torch.save(qlinear.state_dict(), b)\n    b.seek(0)\n    state_dict = torch.load(b, weights_only=weights_only)\n    qlinear_reloaded = QLinear(features, features, weights=weights, activations=activations, bias=use_bias).to(device)\n    qlinear_reloaded.load_state_dict(state_dict)\n    assert qlinear_reloaded.weight_qtype == weights\n    w = qlinear.weight\n    w_reloaded = qlinear_reloaded.weight\n    assert torch.equal(w, w_reloaded)\n    if activations is not None:\n        assert qlinear_reloaded.activation_qtype == activations\n        for attr in [\"input_scale\", \"output_scale\"]:\n            v = getattr(qlinear, attr)\n            v_reloaded = getattr(qlinear_reloaded, attr)\n            assert torch.equal(v, v_reloaded)\n"
  },
  {
    "path": "tests/nn/test_qmodule.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\nimport torch\n\nfrom optimum.quanto import QTensor, qint8, qtypes\nfrom optimum.quanto.nn import QLinear\n\n\n@pytest.mark.parametrize(\"in_features\", [8, 16])\n@pytest.mark.parametrize(\"out_features\", [32, 64])\n@pytest.mark.parametrize(\"use_bias\", [True, False], ids=[\"bias\", \"no-bias\"])\n@pytest.mark.parametrize(\"dtype\", [torch.float32, torch.float16], ids=[\"fp32\", \"fp16\"])\ndef test_qmodule_freeze(in_features, out_features, use_bias, dtype):\n    qlinear = QLinear(in_features, out_features, bias=use_bias, weights=qint8).to(dtype)\n    assert not qlinear.frozen\n    assert not isinstance(qlinear.weight, QTensor)\n    assert qlinear.weight.dtype == dtype\n    if use_bias:\n        assert not isinstance(qlinear.bias, QTensor)\n        assert qlinear.bias.dtype == dtype\n    qweight = qlinear.qweight\n    assert isinstance(qweight, QTensor)\n    assert qweight.dtype == dtype\n    assert qweight.qtype == qint8\n    qlinear.freeze()\n    assert qlinear.frozen\n    assert isinstance(qlinear.weight, QTensor)\n    assert qlinear.weight.dtype == dtype\n    assert qlinear.weight.qtype == qint8\n    if use_bias:\n        assert not isinstance(qlinear.bias, QTensor)\n        assert qlinear.bias.dtype == dtype\n\n\n@pytest.mark.parametrize(\"weights\", [\"qint2\", \"qint4\", \"qint8\", 
\"qfloat8\"])\n@pytest.mark.parametrize(\"activations\", [None, \"qint8\", \"qfloat8\"])\ndef test_qmodule_qtype_as_string(weights, activations):\n    qlinear = QLinear(16, 64, weights=weights, activations=activations)\n    assert qlinear.weight_qtype == qtypes[weights]\n    assert qlinear.activation_qtype == (None if activations is None else qtypes[activations])\n"
  },
  {
    "path": "tests/quantize/test_quantize_mlp.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom contextlib import nullcontext\n\nimport pytest\nimport torch\nfrom helpers import assert_similar, get_device_memory, random_tensor\n\nfrom optimum.quanto import (\n    AbsmaxOptimizer,\n    ActivationQBytesTensor,\n    Calibration,\n    MaxOptimizer,\n    QLinear,\n    QTensor,\n    absmax_scale,\n    freeze,\n    qfloat8_e4m3fn,\n    qfloat8_e4m3fnuz,\n    qfloat8_e5m2,\n    qint4,\n    qint8,\n    quantize,\n    quantize_activation,\n)\n\n\nclass MLP(torch.nn.Module):\n    def __init__(self, input_size, output_size, hidden_size):\n        super().__init__()\n        self.input_layer = torch.nn.Linear(input_size, hidden_size)\n        self.mid_layer = torch.nn.Linear(hidden_size, hidden_size)\n        self.output_layer = torch.nn.Linear(hidden_size, output_size)\n\n    def forward(self, inputs):\n        x = torch.nn.functional.relu(self.input_layer(inputs))\n        x = torch.nn.functional.relu(self.mid_layer(x))\n        return torch.nn.functional.softmax(self.output_layer(x), dim=-1)\n\n\ndef check_mlp(model, frozen):\n    assert isinstance(model.input_layer, QLinear)\n    assert isinstance(model.mid_layer, QLinear)\n    assert isinstance(model.output_layer, QLinear)\n    if frozen:\n        assert isinstance(model.input_layer.weight, QTensor)\n        assert isinstance(model.mid_layer.weight, QTensor)\n        assert 
isinstance(model.output_layer.weight, QTensor)\n\n\ndef _test_quantize_mlp(weights, activations, optimizer, frozen, device, atol=1e-6):\n    model = MLP(32, 10, 128).to(device)\n    inputs = random_tensor((1, 32), dtype=torch.float32, device=device)\n    output = model(inputs)\n    quantize(model, weights=weights, activations=activations, optimizer=optimizer)\n    if frozen:\n        freeze(model)\n    check_mlp(model, frozen)\n    if activations is not None:\n        inputs = quantize_activation(inputs, qtype=activations, scale=absmax_scale(inputs))\n        context = Calibration\n    else:\n        context = nullcontext\n    with context():\n        qoutput = model(inputs)\n    if activations is not None:\n        assert isinstance(qoutput, ActivationQBytesTensor)\n    assert_similar(output, qoutput, atol=atol)\n\n\n@pytest.mark.parametrize(\"weights\", [qint8], ids=[\"w-qint8\"])\n@pytest.mark.parametrize(\"frozen\", [True, False], ids=[\"frozen\", \"non-frozen\"])\ndef test_quantize_mlp_weights_only(weights, frozen, device):\n    _test_quantize_mlp(weights, None, None, frozen, device)\n\n\n@pytest.mark.skip_device(\"mps\")\n@pytest.mark.parametrize(\"weights\", [qfloat8_e4m3fn], ids=[\"w-float8_e4m3fn\"])\n@pytest.mark.parametrize(\"frozen\", [True, False], ids=[\"frozen\", \"non-frozen\"])\ndef test_quantize_mlp_weights_only_float8(weights, frozen, device):\n    _test_quantize_mlp(weights, None, None, frozen, device)\n\n\n@pytest.mark.parametrize(\"weights\", [qint8], ids=[\"w-qint8\"])\n@pytest.mark.parametrize(\"frozen\", [True, False], ids=[\"frozen\", \"non-frozen\"])\n@pytest.mark.skip_device(\"mps\")\ndef test_quantize_mlp_int8_activations(weights, frozen, device):\n    _test_quantize_mlp(weights, qint8, None, frozen, device, atol=1e-3)\n\n\n@pytest.mark.parametrize(\"weights\", [qint8], ids=[\"w-qint8\"])\n@pytest.mark.parametrize(\n    \"activations\",\n    [qfloat8_e5m2, qfloat8_e4m3fn, qfloat8_e4m3fnuz],\n    ids=[\"a-qfloat8-e5m2\", 
\"a-qfloat8-e4m3\", \"a-float8-e4m3-uz\"],\n)\n@pytest.mark.parametrize(\"frozen\", [True, False], ids=[\"frozen\", \"non-frozen\"])\n@pytest.mark.skip_device(\"mps\")\ndef test_quantize_mlp_float8_activations(weights, activations, frozen, device):\n    atol = {qfloat8_e4m3fn: 1e-3, qfloat8_e4m3fnuz: 1e-3, qfloat8_e5m2: 1e-2}[activations]\n    _test_quantize_mlp(weights, activations, None, frozen, device, atol=atol)\n\n\n@pytest.mark.skip_device(\"cpu\")\n@pytest.mark.parametrize(\"weights\", [qint8], ids=[\"w-qint8\"])\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.float32], ids=[\"fp16\", \"fp32\"])\n@pytest.mark.parametrize(\"weights_only\", [True, False], ids=[\"weights-only\", \"pickle\"])\ndef test_quantized_mlp_device_memory(weights, dtype, weights_only, device):\n    # We might not start from a clean state\n    base_memory = get_device_memory(device)\n    input_features = 1024\n    hidden_features = 2048\n    output_features = 1024\n    model = MLP(input_features, hidden_features, output_features).to(dtype).to(device)\n    full_precision_memory = get_device_memory(device)\n    assert full_precision_memory > base_memory\n    quantize(model, weights=weights)\n    freeze(model)\n    quantized_memory = get_device_memory(device)\n    assert quantized_memory > base_memory\n    assert quantized_memory < full_precision_memory\n\n\n@pytest.mark.parametrize(\n    \"weights, optimizer\", [[qint8, AbsmaxOptimizer()], [qint4, MaxOptimizer()]], ids=[\"w-qint8\", \"w-qint4\"]\n)\n@pytest.mark.parametrize(\"frozen\", [True, False], ids=[\"frozen\", \"non-frozen\"])\ndef test_quantize_mlp_weights_only_optimizers(weights, optimizer, frozen, device):\n    atol = {qint4: 1e-4, qint8: 1e-6}[weights]\n    _test_quantize_mlp(weights, None, optimizer, frozen, device, atol=atol)\n\n\n@pytest.mark.parametrize(\n    \"weights, optimizer\", [[qint8, MaxOptimizer()], [qint4, AbsmaxOptimizer()]], ids=[\"w-qint8\", \"w-qint4\"]\n)\ndef 
test_quantize_mlp_wrong_optimizer(weights, optimizer, device):\n    with pytest.raises(ValueError):\n        _test_quantize_mlp(weights, None, optimizer, False, device)\n"
  },
  {
    "path": "tests/quantize/test_quantize_patterns.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport torch\n\nfrom optimum.quanto import (\n    qint8,\n    quantize,\n)\nfrom optimum.quanto.nn import QModuleMixin\n\n\nclass MLP(torch.nn.Module):\n    def __init__(self, input_size, output_size, hidden_size):\n        super().__init__()\n        self.input_layer = torch.nn.Linear(input_size, hidden_size)\n        self.mid_layer = torch.nn.Linear(hidden_size, hidden_size)\n        self.output_layer = torch.nn.Linear(hidden_size, output_size)\n\n    def forward(self, inputs):\n        x = torch.nn.functional.relu(self.input_layer(inputs))\n        x = torch.nn.functional.relu(self.mid_layer(x))\n        return self.output_layer(x)\n\n\nclass ClassificationModel(torch.nn.Module):\n    def __init__(self, input_size, output_size, hidden_size, classes):\n        super().__init__()\n        self.model = MLP(input_size, output_size, hidden_size)\n        self.lm_head = torch.nn.Linear(output_size, classes)\n\n    def forward(self, inputs):\n        x = self.model(inputs)\n        return torch.nn.functional.softmax(self.lm_head(x), dim=-1)\n\n\ndef has_children(module: torch.nn.Module):\n    return next(module.children(), None) is not None\n\n\ndef leaf_module_names(module: torch.nn.Module):\n    return [name for name, m in module.named_modules() if not has_children(m)]\n\n\ndef parent_module_names(module: torch.nn.Module):\n    
return [name for name, m in module.named_children() if has_children(m)]\n\n\ndef test_quantize_mlp_include_explicit_layers():\n    model = ClassificationModel(32, 10, 128, 10)\n    include_names = leaf_module_names(model)\n    for include in include_names:\n        model = ClassificationModel(32, 10, 128, 10)\n        quantize(model, weights=qint8, include=include)\n        for name, m in model.named_modules():\n            if name == include:\n                assert isinstance(m, QModuleMixin)\n            else:\n                assert not isinstance(m, QModuleMixin)\n\n\ndef test_quantize_mlp_exclude_explicit_layers():\n    model = ClassificationModel(32, 10, 128, 10)\n    exclude_names = leaf_module_names(model)\n    for exclude in exclude_names:\n        model = ClassificationModel(32, 10, 128, 10)\n        quantize(model, weights=qint8, exclude=exclude)\n        for name, m in model.named_modules():\n            if name == exclude:\n                assert not isinstance(m, QModuleMixin)\n            elif not has_children(m):\n                assert isinstance(m, QModuleMixin)\n\n\ndef test_quantize_mlp_include_layer_patterns():\n    model = ClassificationModel(32, 10, 128, 10)\n    parent_names = parent_module_names(model)\n    for parent_name in parent_names:\n        model = ClassificationModel(32, 10, 128, 10)\n        quantize(model, weights=qint8, include=f\"{parent_name}*\")\n        for name, m in model.named_modules():\n            if name.startswith(parent_name) and not has_children(m):\n                assert isinstance(m, QModuleMixin)\n            else:\n                assert not isinstance(m, QModuleMixin)\n\n\ndef test_quantize_mlp_exclude_layer_patterns():\n    model = ClassificationModel(32, 10, 128, 10)\n    parent_names = parent_module_names(model)\n    for parent_name in parent_names:\n        model = ClassificationModel(32, 10, 128, 10)\n        quantize(model, weights=qint8, exclude=f\"{parent_name}*\")\n        for name, m in 
model.named_modules():\n            if name.startswith(parent_name):\n                assert not isinstance(m, QModuleMixin)\n            elif not has_children(m):\n                assert isinstance(m, QModuleMixin)\n"
  },
  {
    "path": "tests/quantize/test_requantize.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport io\nfrom tempfile import NamedTemporaryFile\n\nimport pytest\nimport torch\nfrom helpers import get_device_memory, random_tensor\nfrom safetensors.torch import load_file, save_file\nfrom test_quantize_mlp import MLP\n\nfrom optimum.quanto import Calibration, freeze, qint4, qint8, quantization_map, quantize, requantize\nfrom optimum.quanto.nn import QModuleMixin\n\n\ndef save_and_reload_state_dict(state_dict, serialization):\n    if serialization == \"safetensors\":\n        with NamedTemporaryFile() as tmp_file:\n            save_file(state_dict, tmp_file.name)\n            return load_file(tmp_file.name)\n    else:\n        b = io.BytesIO()\n        torch.save(state_dict, b)\n        b.seek(0)\n        weights_only = serialization == \"weights_only\"\n        return torch.load(b, weights_only=weights_only)\n\n\n@pytest.mark.parametrize(\n    \"input_features, hidden_features, output_features\",\n    [(32, 10, 128), (1024, 1024, 1024)],\n    ids=[\"small\", \"large\"],\n)\n@pytest.mark.parametrize(\"weights\", [qint4, qint8], ids=[\"w-qint4\", \"w-qint8\"])\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16, torch.float16, torch.float32], ids=[\"bf16\", \"fp16\", \"fp32\"])\n@pytest.mark.parametrize(\"serialization\", [\"weights_only\", \"pickle\", \"safetensors\"])\n@pytest.mark.parametrize(\"activations\", [None, qint8], 
ids=[\"a-none\", \"a-qint8\"])\ndef test_requantize_serialized_model(\n    input_features, hidden_features, output_features, weights, activations, dtype, serialization, device\n):\n    model = MLP(input_features, hidden_features, output_features).to(dtype).to(device)\n    quantize(model, weights=weights, activations=activations)\n    inputs = random_tensor((1, 10, input_features), dtype=dtype).to(device)\n    if activations is not None:\n        with Calibration():\n            model(inputs)\n    freeze(model)\n    qmap = quantization_map(model)\n    model_reloaded = MLP(input_features, hidden_features, output_features).to(device)\n    state_dict = save_and_reload_state_dict(model.state_dict(), serialization)\n    requantize(model_reloaded, state_dict, qmap)\n    for name, module in model.named_modules():\n        if isinstance(module, QModuleMixin):\n            module_reloaded = getattr(model_reloaded, name)\n            assert torch.equal(module_reloaded.weight, module.weight)\n            assert module_reloaded.weight_qtype == module.weight_qtype\n            assert module_reloaded.activation_qtype == module.activation_qtype\n            assert torch.equal(module_reloaded.input_scale, module.input_scale)\n            assert torch.equal(module_reloaded.output_scale, module.output_scale)\n\n\n@pytest.mark.skip_device(\"cpu\")\n@pytest.mark.parametrize(\"weights\", [qint8], ids=[\"w-qint8\"])\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.float32], ids=[\"fp16\", \"fp32\"])\n@pytest.mark.parametrize(\"serialization\", [\"weights_only\", \"pickle\", \"safetensors\"])\ndef test_requantized_model_device_memory(weights, dtype, serialization, device):\n    input_features = 1024\n    hidden_features = 2048\n    output_features = 1024\n    model = MLP(input_features, hidden_features, output_features).to(dtype).to(device)\n    full_precision_memory = get_device_memory(device)\n    quantize(model, weights=weights)\n    freeze(model)\n    qmap = 
quantization_map(model)\n    quantized_memory = get_device_memory(device)\n    assert quantized_memory < full_precision_memory\n    state_dict = save_and_reload_state_dict(model.state_dict(), serialization)\n    # Free device memory\n    del model\n    with torch.device(\"meta\"):\n        reloaded_model = MLP(input_features, hidden_features, output_features).to(dtype)\n    requantize(reloaded_model, state_dict, qmap, device)\n    # Free device memory\n    del state_dict\n    requantized_memory = get_device_memory(device)\n    assert requantized_memory <= quantized_memory\n"
  },
  {
    "path": "tests/tensor/activations/test_activations_compile.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\nimport torch\nfrom helpers import random_tensor\n\nfrom optimum.quanto import ActivationQBytesTensor, absmax_scale, qint8, quantize_activation\n\n\ndef compile_for_device(f, device):\n    # Remove any side-effects form previous compilation\n    torch.compiler.reset()\n    # Inductor relies on Triton for inference which does not support MPS\n    backend = \"aot_eager\" if device == torch.device(\"mps\") else \"inductor\"\n    return torch.compile(f, backend=backend)\n\n\n@pytest.mark.skip(\"Disabled as it is not working (yet ?)\")\n@pytest.mark.parametrize(\"input_shape\", [(2, 10), (10, 32, 32)])\n@pytest.mark.parametrize(\"qtype\", [qint8], ids=[\"qint8\"])\n@pytest.mark.parametrize(\"dtype\", [torch.float32, torch.float16, torch.bfloat16], ids=[\"fp32\", \"fp16\", \"bf16\"])\ndef test_compile_quantize_tensor(input_shape, qtype, dtype, device):\n    if device == torch.device(\"mps\") and dtype == torch.bfloat16:\n        pytest.skip(\"BFloat16 is not supported on MPS\")\n    a = random_tensor(input_shape, dtype=dtype).to(device)\n\n    def f(x, qtype):\n        scale = absmax_scale(x)\n        return quantize_activation(x, qtype=qtype, scale=scale)\n\n    compiled_f = compile_for_device(f, device)\n    qa = compiled_f(a, qtype)\n    assert isinstance(qa, ActivationQBytesTensor)\n    assert qa.qtype == qtype\n    assert 
qa._scale.dtype == dtype\n    assert qa.axis is None\n\n\ndef test_compile_qtensor_to(device):\n    input_shape = (10, 32, 32)\n    a = random_tensor(input_shape).to(device)\n\n    def f(x, dtype):\n        return x.to(dtype)\n\n    compiled_f = compile_for_device(f, device)\n\n    scale = absmax_scale(a)\n    qa = quantize_activation(a, qtype=qint8, scale=scale)\n    cqa = compiled_f(qa, torch.float16)\n    assert isinstance(cqa, ActivationQBytesTensor)\n    assert cqa.qtype == qint8\n    assert cqa._scale.dtype == torch.float16\n"
  },
  {
    "path": "tests/tensor/activations/test_activations_dispatch.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\nimport torch\nfrom helpers import assert_similar, random_qactivation, random_tensor\n\nfrom optimum.quanto import ActivationQBytesTensor, quantize_activation\n\n\n@pytest.mark.parametrize(\"input_shape\", [(10,), (1, 10), (10, 32, 32)])\n@pytest.mark.parametrize(\"scalar\", [1, 0.5, torch.tensor(0.12)], ids=[\"int\", \"float\", \"tensor\"])\ndef test_qactivation_mul_scalar(input_shape, scalar, device):\n    qa = random_qactivation(input_shape, dtype=torch.float32).to(device)\n    if isinstance(scalar, torch.Tensor):\n        scalar = scalar.to(device)\n    qprod = qa * scalar\n    assert isinstance(qprod, ActivationQBytesTensor)\n    prod = qa.dequantize() * scalar\n    assert_similar(prod, qprod)\n    qprod = scalar * qa\n    assert isinstance(qprod, ActivationQBytesTensor)\n    prod = scalar * qa.dequantize()\n    assert_similar(prod, qprod)\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 10])\n@pytest.mark.parametrize(\"tokens, embeddings\", [(5, 5), (32, 32), (10, 32)])\ndef test_qactivation_relu(batch_size, tokens, embeddings, device):\n    qinputs = random_qactivation((batch_size,) + (tokens, embeddings), dtype=torch.float32).to(device)\n    qout = torch.nn.functional.relu(qinputs)\n    assert isinstance(qout, ActivationQBytesTensor)\n    assert torch.equal(qout._data, torch.maximum(qinputs._data, 
torch.zeros((1,)).to(device)))\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 10])\n@pytest.mark.parametrize(\"tokens, embeddings\", [(5, 5), (32, 32), (10, 32)])\ndef test_qactivation_softmax(batch_size, tokens, embeddings, device):\n    qinputs = random_qactivation((batch_size,) + (tokens, embeddings), dtype=torch.float32).to(device)\n    qout = torch.nn.functional.softmax(qinputs, dim=-1)\n    assert isinstance(qout, ActivationQBytesTensor)\n    assert torch.min(qout.dequantize()) >= 0\n    assert torch.max(qout.dequantize()) <= 1\n\n\n@pytest.mark.parametrize(\"input_shape\", [(10,), (10, 32)])\ndef test_qactivation_view(input_shape, device):\n    qinputs = random_qactivation(input_shape, dtype=torch.float32).to(device)\n    qview = qinputs.view((1,) + input_shape)\n    assert isinstance(qview, ActivationQBytesTensor)\n\n\n@pytest.mark.parametrize(\"input_shape\", [(10,), (10, 32)])\ndef test_qactivation_cat(input_shape, device):\n    qinputs = random_qactivation(input_shape, dtype=torch.float32).to(device)\n    other = random_tensor(input_shape, dtype=torch.float32).to(device)\n    # First, quantize other with the same scale\n    qother = quantize_activation(other, qtype=qinputs.qtype, scale=qinputs._scale)\n    qcat = torch.cat([qinputs, qother])\n    assert isinstance(qcat, ActivationQBytesTensor)\n    assert_similar(torch.cat([qinputs.dequantize(), qother.dequantize()]), qcat)\n\n\ndef test_qactivation_transpose_2d(device):\n    input_shape = (4, 6)\n    qinputs = random_qactivation(input_shape).to(device)\n    qtransposed = qinputs.t()\n    assert qtransposed.qtype == qinputs.qtype\n    assert qtransposed.shape == input_shape[::-1]\n    assert torch.equal(qtransposed.dequantize(), qinputs.dequantize().t())\n\n\ndef test_qactivation_transpose(device):\n    input_shape = (10, 32, 64)\n    qinputs = random_qactivation(input_shape).to(device)\n    qtransposed = torch.transpose(qinputs, 1, 2)\n    assert qtransposed.qtype == qinputs.qtype\n    assert 
torch.equal(qtransposed.dequantize(), torch.transpose(qinputs.dequantize(), 1, 2))\n"
  },
  {
    "path": "tests/tensor/activations/test_activations_quantize.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\nimport torch\nfrom helpers import assert_similar, device_eq, random_tensor\n\nfrom optimum.quanto import (\n    ActivationQBytesTensor,\n    absmax_scale,\n    qfloat8,\n    qfloat8_e4m3fn,\n    qfloat8_e4m3fnuz,\n    qfloat8_e5m2,\n    qint8,\n)\n\n\n@pytest.mark.parametrize(\"input_shape\", [(32, 32), (32, 10, 32)])\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.float32], ids=[\"fp16\", \"fp32\"])\n@pytest.mark.parametrize(\"qtype\", [qint8], ids=[\"qint8\"])\ndef test_symmetric_quantize_int(input_shape, dtype, qtype, device):\n    a = random_tensor(input_shape, dtype=dtype).to(device)\n    scale = absmax_scale(a, qtype=qtype, axis=None)\n    qa = ActivationQBytesTensor.quantize(a, qtype, scale)\n    assert isinstance(qa, ActivationQBytesTensor)\n    assert qa.dtype == dtype\n    assert qa.qtype == qtype\n    assert device_eq(qa.device, device)\n    assert_similar(a, qa)\n\n\n@pytest.mark.skip_device(\"mps\")\n@pytest.mark.parametrize(\"input_shape\", [(32, 32), (32, 10, 32)])\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.float32], ids=[\"fp16\", \"fp32\"])\n@pytest.mark.parametrize(\n    \"qtype\",\n    [qfloat8, qfloat8_e4m3fn, qfloat8_e4m3fnuz, qfloat8_e5m2],\n    ids=[\"qfloat8\", \"qfloat8_e4m3fn\", \"qfloat8_e4m3fnuz\", \"qfloat8_e5m2\"],\n)\ndef test_symmetric_quantize_float8(input_shape, 
dtype, qtype, device):\n    a = random_tensor(input_shape, dtype=dtype).to(device)\n    scale = absmax_scale(a, qtype=qtype, axis=None)\n    qa = ActivationQBytesTensor.quantize(a, qtype, scale)\n    assert isinstance(qa, ActivationQBytesTensor)\n    assert qa.dtype == dtype\n    assert qa.qtype == qtype\n    assert device_eq(qa.device, device)\n    assert_similar(a, qa, atol=5e-3)\n"
  },
  {
    "path": "tests/tensor/ops/test_linear_dispatch.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\nimport torch\nfrom helpers import assert_similar, random_qactivation, random_qweight, random_tensor\n\nfrom optimum.quanto import qint2, qint4, qint8\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 10])\n@pytest.mark.parametrize(\"tokens, embeddings\", [(5, 5), (32, 32), (10, 32)])\n@pytest.mark.parametrize(\"use_bias\", [True, False], ids=[\"bias\", \"no-bias\"])\n@pytest.mark.parametrize(\"dtype\", [torch.float32, torch.float16], ids=[\"fp32\", \"fp16\"])\n@pytest.mark.parametrize(\"activation_qtype\", [None, qint8], ids=[\"a-none\", \"a-qint8\"])\n@pytest.mark.parametrize(\"weight_qtype\", [qint2, qint4, qint8], ids=[\"w-qint2\", \"w-qint4\", \"w-qint8\"])\ndef test_qactivation_qweight_linear(\n    batch_size, tokens, embeddings, use_bias, dtype, activation_qtype, weight_qtype, device\n):\n    input_shape = (batch_size, tokens, embeddings)\n    if activation_qtype is None:\n        inputs = random_tensor(input_shape, dtype=dtype).to(device)\n    else:\n        inputs = random_qactivation(input_shape, qtype=activation_qtype, dtype=dtype).to(device)\n    qweight = random_qweight((embeddings, embeddings), qtype=weight_qtype, dtype=dtype, axis=0).to(device)\n    bias = random_tensor((embeddings,), dtype=dtype).to(device) if use_bias else None\n    qout = torch.nn.functional.linear(inputs, qweight, bias)\n    if 
activation_qtype is not None:\n        inputs = inputs.dequantize()\n    out = torch.nn.functional.linear(inputs, qweight.dequantize(), bias)\n    assert_similar(out, qout)\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 10])\n@pytest.mark.parametrize(\"tokens, embeddings\", [(256, 256)])\n@pytest.mark.parametrize(\"use_bias\", [True, False], ids=[\"bias\", \"no-bias\"])\ndef test_linear_fp16_int4(batch_size, tokens, embeddings, use_bias, device):\n    dtype = torch.float16\n    weight_qtype = qint4\n    inputs = torch.rand((batch_size,) + (tokens, embeddings), dtype=dtype, device=device)\n    qweight = random_qweight((embeddings, embeddings), weight_qtype, dtype=dtype, axis=0, group_size=128).to(device)\n    bias = random_tensor((embeddings,), dtype=dtype).to(device) if use_bias else None\n    qout = torch.nn.functional.linear(inputs, qweight, bias)\n    out = torch.nn.functional.linear(inputs, qweight.dequantize(), bias)\n    assert_similar(out, qout)\n\n\n@pytest.mark.skip_device(\"mps\")  # Only available with pytorch 2.4\n@pytest.mark.parametrize(\"batch_size\", [1, 10])\n@pytest.mark.parametrize(\"tokens, embeddings\", [(256, 256)])\n@pytest.mark.parametrize(\"use_bias\", [True, False], ids=[\"bias\", \"no-bias\"])\ndef test_linear_bf16_int4(batch_size, tokens, embeddings, use_bias, device):\n    dtype = torch.bfloat16\n    weight_qtype = qint4\n    input_shape = (batch_size, tokens, embeddings)\n    inputs = torch.rand(input_shape, dtype=dtype, device=device)\n    weight_shape = (embeddings, embeddings)\n    qweight = random_qweight(weight_shape, weight_qtype, dtype=dtype, axis=0, group_size=128, device=device)\n    bias = random_tensor((embeddings,), dtype=dtype).to(device) if use_bias else None\n    qout = torch.nn.functional.linear(inputs, qweight, bias)\n    out = torch.nn.functional.linear(inputs, qweight.dequantize(), bias)\n    assert_similar(out, qout)\n"
  },
  {
    "path": "tests/tensor/ops/test_mm_dispatch.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\nimport torch\nfrom helpers import assert_similar, random_qactivation, random_qweight\n\nfrom optimum.quanto import qint8\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float32, torch.float16], ids=[\"fp32\", \"fp16\"])\n@pytest.mark.parametrize(\"in_features\", [5, 16, 24])\n@pytest.mark.parametrize(\"hidden\", [5, 16, 24])\n@pytest.mark.parametrize(\"out_features\", [5, 16, 24])\ndef test_qactivation_qweight_matmul(dtype, in_features, hidden, out_features, device):\n    qa = random_qactivation((in_features, hidden), qint8, dtype=dtype).to(device)\n    qb = random_qweight((hidden, out_features), qint8, dtype=dtype, axis=-1).to(device)\n    qmatmul = torch.matmul(qa, qb)\n    # The outputs should be almost identical if we use the dequantized inputs\n    matmul = torch.matmul(qa.dequantize(), qb.dequantize())\n    assert_similar(matmul, qmatmul)\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float32, torch.float16], ids=[\"fp32\", \"fp16\"])\n@pytest.mark.parametrize(\"batch_size\", [1, 10])\n@pytest.mark.parametrize(\"a_shape, b_shape\", [[(16, 32), (32, 24)], [(5, 10), (10, 6)]])\ndef test_qactivation_qactivation_bmm(dtype, batch_size, a_shape, b_shape, device):\n    qa = random_qactivation((batch_size,) + a_shape, qint8, dtype=dtype).to(device)\n    qb = random_qactivation((batch_size,) + b_shape, qint8, 
dtype=dtype).to(device)\n    qbmm = torch.bmm(qa, qb)\n    # The outputs should be almost identical if we use the dequantized inputs\n    bmm = torch.bmm(qa.dequantize(), qb.dequantize())\n    assert_similar(bmm, qbmm)\n"
  },
  {
    "path": "tests/tensor/optimizers/test_hqq_optimizer.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\nimport torch\nfrom helpers import random_tensor\n\nfrom optimum.quanto import (\n    HqqOptimizer,\n    MaxOptimizer,\n    WeightQBitsTensor,\n    qint2,\n    qint4,\n)\n\n\ndef compare_quantized_tensor(a, qtype, axis, group_size, scale, shift):\n    qa = WeightQBitsTensor.quantize(a, qtype, axis, group_size, scale, shift)\n    # Evaluate mean absolute error\n    mean_error = torch.mean(torch.abs(a - qa))\n    # Also evaluate cosine similarity\n    sim = torch.nn.functional.cosine_similarity(a.flatten(), qa.flatten(), dim=0)\n    return mean_error, sim\n\n\n@pytest.mark.parametrize(\"input_shape\", [(1024, 1024)])\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16, torch.float16], ids=[\"bf16\", \"fp16\"])\n@pytest.mark.parametrize(\"qtype\", [qint2, qint4], ids=[\"qint2\", \"qint4\"])\n@pytest.mark.parametrize(\"axis\", [0, -1], ids=[\"first-axis\", \"last-axis\"])\n@pytest.mark.parametrize(\"group_size\", [32, 64, 128])\ndef test_hqq_optimizer(input_shape, dtype, qtype, axis, group_size, device):\n    a = random_tensor(input_shape, dtype=dtype).to(device)\n    max_scale, max_shift = MaxOptimizer()(a, qtype=qtype, axis=axis, group_size=group_size)\n    max_mean_error, max_sim = compare_quantized_tensor(a, qtype, axis, group_size, max_scale, max_shift)\n    hqq_scale, hqq_shift = HqqOptimizer()(a, qtype=qtype, 
axis=axis, group_size=group_size)\n    hqq_mean_error, hqq_sim = compare_quantized_tensor(a, qtype, axis, group_size, hqq_scale, hqq_shift)\n    # HQQ optimizes the mean error, so it should be lower\n    assert hqq_mean_error <= max_mean_error\n    # FIXME: HQQ cosine similarity should be also closer to 1\n"
  },
  {
    "path": "tests/tensor/test_absmax.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\nimport torch\nfrom helpers import random_tensor\n\nfrom optimum.quanto import absmax_scale, qfloat8, qint8\n\n\n@pytest.mark.parametrize(\"input_shape\", [(10,), (1, 10), (2, 10), (10, 32, 32)])\n@pytest.mark.parametrize(\"qtype\", [qint8, qfloat8], ids=[\"qint8\", \"qfloat8\"])\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.float32], ids=[\"fp16\", \"fp32\"])\n@pytest.mark.parametrize(\"axis\", [None, 0, -1], ids=[\"per-tensor\", \"first-axis\", \"last-axis\"])\ndef test_absmax_scale(input_shape, axis, dtype, qtype, device):\n    if device.type == \"mps\" and qtype.is_floating_point:\n        pytest.skip(\"Float8 are not supported on MPS device\")\n    a = random_tensor(input_shape, dtype=dtype).to(device)\n    scale = absmax_scale(a, qtype, axis)\n    assert scale.dtype == dtype\n    if axis is None:\n        assert scale.ndim == 0\n    else:\n        assert scale.ndim == a.ndim\n        sscale = torch.squeeze(scale)\n        if a.ndim == 1 or a.shape[axis] == 1:\n            # Quantization is actually per-tensor as the axis dim is 1\n            assert sscale.ndim == 0\n        else:\n            assert sscale.ndim == 1\n"
  },
  {
    "path": "tests/tensor/test_packed_tensor.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport io\n\nimport pytest\nimport torch\nfrom helpers import device_eq\n\nfrom optimum.quanto.tensor.packed import PackedTensor\n\n\n@pytest.mark.parametrize(\"shape\", [(10,), (12,), (10, 10), (12, 10), (32, 32)])\n@pytest.mark.parametrize(\"bits\", [2, 4], ids=[\"int2\", \"int4\"])\ndef test_pack_tensor(shape, bits, device):\n    \"\"\"This test verifies that an integer tensor in the correct range is preserved.\"\"\"\n    qmax = 2**bits\n    t = torch.randint(0, qmax, shape, dtype=torch.uint8).to(device)\n    packed = PackedTensor.pack(t, bits=bits)\n\n    assert isinstance(packed, PackedTensor)\n    assert packed.dtype == torch.uint8\n    assert device_eq(packed.device, device)\n    assert torch.equal(t, packed.unpack())\n\n\n@pytest.mark.parametrize(\"bits\", [2, 4], ids=[\"int2\", \"int4\"])\ndef test_packed_tensor_serialization(bits, device):\n    qmax = 2**bits\n    shape = (10, 32)\n    t = torch.randint(0, qmax, shape, dtype=torch.uint8).to(device)\n    packed = PackedTensor.pack(t, bits=bits)\n    b = io.BytesIO()\n    torch.save(packed, b)\n    b.seek(0)\n    packed_reloaded = torch.load(b, weights_only=False)\n    assert isinstance(packed_reloaded, PackedTensor)\n    assert packed_reloaded.shape == packed.shape\n    assert packed_reloaded.dtype == packed.dtype\n    assert packed_reloaded.bits == packed.bits\n    assert 
torch.equal(packed_reloaded._data, packed._data)\n    assert torch.equal(t, packed_reloaded.unpack())\n"
  },
  {
    "path": "tests/tensor/weights/optimized/test_awq_packed_tensor.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport numpy as np\nimport pytest\nimport torch\nfrom helpers import device_eq\n\nfrom optimum.quanto.tensor.weights.awq import AWQPackedTensor, AWQPacking\n\n\n@pytest.mark.skip_device(\"cpu\")\n@pytest.mark.skip_device(\"mps\")\n@pytest.mark.parametrize(\"in_features\", [128, 256, 512, 1024])\n@pytest.mark.parametrize(\"out_features\", [128, 256, 512, 1024])\n@pytest.mark.parametrize(\"random\", [True, False])\n@pytest.mark.parametrize(\"packing, reorder\", [(AWQPacking.V1, True), (AWQPacking.V1, False), (AWQPacking.V2, False)])\ndef test_pack_awq_tensor(in_features, out_features, random, packing, reorder, device):\n    bits = 4\n    qmax = 2**bits\n    shape = (out_features, in_features)\n    if random:\n        t = torch.randint(0, qmax, shape, dtype=torch.uint8).to(device)\n    else:\n        numel = np.prod(shape)\n        t = torch.tensor(range(numel), dtype=torch.int32)\n        t = (t % qmax).reshape(shape).to(torch.uint8).to(device)\n    packed = AWQPackedTensor.pack(t, packing=packing, reorder=reorder)\n    assert isinstance(packed, AWQPackedTensor)\n    assert packed._packing == packing\n    assert packed._reorder == reorder\n    assert device_eq(packed.device, device)\n    assert torch.equal(t, packed.unpack())\n\n\n@pytest.mark.skip_device(\"cpu\")\n@pytest.mark.skip_device(\"mps\")\n@pytest.mark.parametrize(\"packing, 
reorder\", [(AWQPacking.V1, True), (AWQPacking.V2, False)])\ndef test_move_awq_tensor(packing, reorder, device):\n    shape = (256, 256)\n    bits = 4\n    qmax = 2**bits\n    numel = np.prod(shape)\n    t = torch.tensor(range(numel), dtype=torch.int32)\n    t = (t % qmax).reshape(shape).to(torch.uint8).to(device)\n    packed = AWQPackedTensor.pack(t, packing=packing, reorder=reorder)\n    assert packed._packing == packing\n    assert packed._reorder == reorder\n    moved = packed.to(device)\n    assert isinstance(moved, AWQPackedTensor)\n    assert moved._packing == packing\n    assert moved._reorder == reorder\n    # AWQ packed tensors are unpacked when moved out of CUDA or XPU device\n    moved = packed.to(\"cpu\")\n    assert type(moved) is torch.Tensor\n"
  },
  {
    "path": "tests/tensor/weights/optimized/test_awq_weight_qbits_tensor.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\nimport torch\nfrom helpers import device_eq, random_qweight\nfrom tensor.weights.weight_helpers import check_weight_qtensor_linear\n\nfrom optimum.quanto import qint4\nfrom optimum.quanto.library.extensions import is_extension_available\nfrom optimum.quanto.tensor.weights import WeightQBitsTensor\nfrom optimum.quanto.tensor.weights.awq import AWQWeightQBitsTensor\n\n\n@pytest.mark.skip_device(\"cpu\")\n@pytest.mark.skip_device(\"mps\")\n@pytest.mark.parametrize(\"in_features\", [128, 256, 512, 1024])\n@pytest.mark.parametrize(\"out_features\", [128, 256, 512, 1024])\ndef test_awq_weight_qbits_tensor_from_qbits_tensor(in_features, out_features, device):\n    qtype = qint4\n    group_size = 128\n    dtype = torch.float16\n    shape = (out_features, in_features)\n    qbt = random_qweight(shape, qtype, dtype, group_size=group_size, device=device)\n    # Create a AWQWeightQBitsTensor from the WeightQBitsTensor members\n    awqbt = AWQWeightQBitsTensor(\n        qtype=qbt.qtype,\n        axis=qbt.axis,\n        group_size=qbt._group_size,\n        size=qbt.size(),\n        stride=qbt.stride(),\n        data=qbt._data.unpack(),\n        scale=qbt._scale,\n        shift=qbt._shift,\n    )\n    assert awqbt.dtype == dtype\n    assert awqbt.qtype == qtype\n    assert awqbt.shape == shape\n    assert device_eq(awqbt.device, 
device)\n    # Verify the dequantized tensors are identical\n    assert torch.equal(awqbt.dequantize(), qbt.dequantize())\n    # Now verify that we can reconstruct the WeightQBitsTensor\n    new_qbt = awqbt.weight_qbits_tensor()\n    assert type(new_qbt) is WeightQBitsTensor\n    assert new_qbt.dtype == dtype\n    assert new_qbt.qtype == qtype\n    assert new_qbt.shape == shape\n    assert torch.equal(new_qbt._data, qbt._data)\n    assert torch.equal(new_qbt._scale, qbt._scale)\n    assert torch.equal(new_qbt._shift, qbt._shift)\n\n\n@pytest.mark.skip_device(\"cpu\")\n@pytest.mark.skip_device(\"mps\")\ndef test_awq_weight_qbits_tensor_move(device):\n    qtype = qint4\n    group_size = 128\n    dtype = torch.float16\n    shape = (1024, 1024)\n    # Create an AWQWeightQBitsTensor from a QBitsTensor on CUDA or XPU\n    qbt = random_qweight(shape, qtype, dtype, group_size=group_size, device=device)\n    awqbt = AWQWeightQBitsTensor(\n        qtype=qbt.qtype,\n        axis=qbt.axis,\n        group_size=qbt._group_size,\n        size=qbt.size(),\n        stride=qbt.stride(),\n        data=qbt._data.unpack(),\n        scale=qbt._scale,\n        shift=qbt._shift,\n    )\n    # Move to device, dequantize and compare\n    moved_qbt = awqbt.to(device)\n    assert isinstance(moved_qbt, WeightQBitsTensor)\n    if device.type not in [\"cuda\", \"xpu\"]:\n        assert type(moved_qbt) is not AWQWeightQBitsTensor\n    assert awqbt.dtype == moved_qbt.dtype\n    assert awqbt.qtype == moved_qbt.qtype\n    assert awqbt.shape == moved_qbt.shape\n    assert torch.equal(awqbt.dequantize().to(device), moved_qbt.dequantize())\n\n\ndef _test_awq_weight_qbits_tensor_linear(\n    dtype, weight_qtype, group_size, batch_size, tokens, in_features, out_features, use_bias\n):\n    # Create an AWQWeightQBitsTensor from a QBitsTensor on CUDA\n    qbt = random_qweight(\n        (out_features, in_features), weight_qtype, dtype, group_size=group_size, device=torch.device(0)\n    )\n    awq_qweight = 
AWQWeightQBitsTensor(\n        qtype=qbt.qtype,\n        axis=qbt.axis,\n        group_size=qbt._group_size,\n        size=qbt.size(),\n        stride=qbt.stride(),\n        data=qbt._data.unpack(),\n        scale=qbt._scale,\n        shift=qbt._shift,\n    )\n    check_weight_qtensor_linear(awq_qweight, batch_size, tokens, use_bias)\n\n\n@pytest.mark.skipif(\n    (not is_extension_available(\"quanto_cuda\") or torch.cuda.get_device_capability()[0] < 8)\n    and not torch.xpu.is_available(),\n    reason=\"The test requires CUDA device >= sm80 or Intel XPU\",\n)\n@pytest.mark.parametrize(\"batch_size\", [1, 2])\n@pytest.mark.parametrize(\"tokens\", [16, 32, 48, 64])\n@pytest.mark.parametrize(\"in_features\", [256, 512, 1024, 4096, 16384])\n@pytest.mark.parametrize(\"out_features\", [256, 512, 1024, 2048, 4096])\n@pytest.mark.parametrize(\"use_bias\", [True, False], ids=[\"bias\", \"no-bias\"])\ndef test_awq_weight_qbits_tensor_linear(batch_size, tokens, in_features, out_features, use_bias):\n    dtype = torch.float16\n    weight_qtype = qint4\n    group_size = 128\n    _test_awq_weight_qbits_tensor_linear(\n        dtype, weight_qtype, group_size, batch_size, tokens, in_features, out_features, use_bias\n    )\n"
  },
  {
    "path": "tests/tensor/weights/optimized/test_marlin_fp8_packed_tensor.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport numpy as np\nimport pytest\nimport torch\nfrom helpers import device_eq\n\nfrom optimum.quanto.library.extensions import is_extension_available\nfrom optimum.quanto.tensor.weights.marlin.fp8 import MarlinF8PackedTensor\n\n\ndef get_fp8_tensor(shape, device, random=False):\n    # We will initialize float8 from an uint8 tensor\n    qmax = 2**8\n    if random:\n        t = torch.randint(0, qmax, shape, dtype=torch.uint8).to(device)\n    else:\n        numel = np.prod(shape)\n        t = torch.tensor(range(numel), dtype=torch.int32)\n        t = (t % qmax).reshape(shape).to(torch.uint8).to(device)\n    # Remove values that would be interpreted as nans in float8.\n    t[t == 127] = 0\n    t[t == 255] = 0\n    return t.view(torch.float8_e4m3fn).to(device)\n\n\n@pytest.mark.skipif(not is_extension_available(\"quanto_cuda\"), reason=\"CUDA extension is not available\")\n@pytest.mark.parametrize(\"in_features\", [128, 256, 512, 1024])\n@pytest.mark.parametrize(\"out_features\", [128, 256, 512, 1024])\n@pytest.mark.parametrize(\"random\", [True, False])\ndef test_pack_marlin_fp8_tensor(in_features, out_features, random):\n    shape = (out_features, in_features)\n    device = torch.device(\"cuda\")\n    t = get_fp8_tensor(shape, device, random)\n    packed = MarlinF8PackedTensor.pack(t)\n    assert isinstance(packed, 
MarlinF8PackedTensor)\n    assert device_eq(packed.device, device)\n    assert torch.equal(t, packed.unpack())\n\n\n@pytest.mark.skipif(not is_extension_available(\"quanto_cuda\"), reason=\"CUDA extension is not available\")\ndef test_move_marlin_fp8_tensor():\n    shape = (256, 256)\n    device = torch.device(\"cuda\")\n    t = get_fp8_tensor(shape, device)\n    packed = MarlinF8PackedTensor.pack(t)\n    moved = packed.to(\"cuda\")\n    assert isinstance(moved, MarlinF8PackedTensor)\n    # Marlin FP8 tensors are unpacked when moved out of CUDA device\n    moved = packed.to(\"cpu\")\n    assert type(moved) is torch.Tensor\n    assert torch.equal(t, moved.to(\"cuda\"))\n"
  },
  {
    "path": "tests/tensor/weights/optimized/test_marlin_int4_packed_tensor.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport numpy as np\nimport pytest\nimport torch\nfrom helpers import device_eq\n\nfrom optimum.quanto.tensor.weights.marlin.int4 import MarlinInt4PackedTensor\n\n\ndef get_uint4_tensor(shape, device, random=False):\n    qmax = 2**4\n    if random:\n        t = torch.randint(0, qmax, shape, dtype=torch.uint8).to(device)\n    else:\n        numel = np.prod(shape)\n        t = torch.tensor(range(numel), dtype=torch.int32)\n        t = (t % qmax).reshape(shape).to(torch.uint8).to(device)\n    return t\n\n\n@pytest.mark.skipif(not torch.cuda.is_available(), reason=\"CUDA not available\")\n@pytest.mark.parametrize(\"in_features\", [128, 256, 512, 1024])\n@pytest.mark.parametrize(\"out_features\", [128, 256, 512, 1024])\n@pytest.mark.parametrize(\"random\", [True, False])\ndef test_pack_marlin_int4_tensor(in_features, out_features, random):\n    shape = (out_features, in_features)\n    device = torch.device(\"cuda\")\n    t = get_uint4_tensor(shape, device, random)\n    packed = MarlinInt4PackedTensor.pack(t)\n    assert isinstance(packed, MarlinInt4PackedTensor)\n    assert device_eq(packed.device, device)\n    assert torch.equal(t, packed.unpack())\n\n\n@pytest.mark.skipif(not torch.cuda.is_available(), reason=\"CUDA not available\")\ndef test_move_marlin_int4_packed_tensor(device):\n    shape = (256, 256)\n    device = 
torch.device(\"cuda\")\n    t = get_uint4_tensor(shape, device)\n    packed = MarlinInt4PackedTensor.pack(t)\n    moved = packed.to(\"cuda\")\n    assert isinstance(moved, MarlinInt4PackedTensor)\n    # Marlin int4 tensors are unpacked when moved out of CUDA device\n    moved = packed.to(\"cpu\")\n    assert type(moved) is torch.Tensor\n    assert torch.equal(t, moved.to(\"cuda\"))\n"
  },
  {
    "path": "tests/tensor/weights/optimized/test_marlin_int4_weight_qbits_tensor.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\nimport torch\nfrom helpers import device_eq, random_qweight\nfrom tensor.weights.weight_helpers import check_weight_qtensor_linear\n\nfrom optimum.quanto import qint4\nfrom optimum.quanto.library.extensions import is_extension_available\nfrom optimum.quanto.tensor.weights import WeightQBitsTensor\nfrom optimum.quanto.tensor.weights.marlin.int4 import MarlinInt4WeightQBitsTensor\n\n\n@pytest.mark.skipif(\n    not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 8, reason=\"CUDA >= sm80 not available\"\n)\n@pytest.mark.parametrize(\"in_features\", [128, 256, 512, 1024])\n@pytest.mark.parametrize(\"out_features\", [128, 256, 512, 1024])\ndef test_marlin_int4_weight_qbits_tensor_from_qbits_tensor(in_features, out_features):\n    qtype = qint4\n    group_size = 128\n    dtype = torch.float16\n    shape = (out_features, in_features)\n    device = torch.device(\"cuda\")\n    qbt = random_qweight(shape, qtype, dtype, group_size=group_size, device=device)\n    # Create a MarlinInt4WeightQBitsTensor from the WeightQBitsTensor members\n    marlinqbt = MarlinInt4WeightQBitsTensor(\n        qtype=qbt.qtype,\n        axis=qbt.axis,\n        group_size=qbt._group_size,\n        size=qbt.size(),\n        stride=qbt.stride(),\n        data=qbt._data.unpack(),\n        scale=qbt._scale,\n        shift=qbt._shift,\n    
)\n    assert marlinqbt.dtype == dtype\n    assert marlinqbt.qtype == qtype\n    assert marlinqbt.shape == shape\n    assert device_eq(marlinqbt.device, device)\n    # Verify the dequantized tensors are identical\n    assert torch.equal(marlinqbt.dequantize(), qbt.dequantize())\n    # Now verify that we can reconstruct the WeightQBitsTensor\n    new_qbt = marlinqbt.weight_qbits_tensor()\n    assert type(new_qbt) is WeightQBitsTensor\n    assert new_qbt.dtype == dtype\n    assert new_qbt.qtype == qtype\n    assert new_qbt.shape == shape\n    assert torch.equal(new_qbt._data, qbt._data)\n    assert torch.equal(new_qbt._scale, qbt._scale)\n    assert torch.equal(new_qbt._shift, qbt._shift)\n\n\n@pytest.mark.skipif(not torch.cuda.is_available(), reason=\"CUDA not available\")\ndef test_marlin_int4_weight_qbits_tensor_move(device):\n    qtype = qint4\n    group_size = 128\n    dtype = torch.float16\n    shape = (1024, 1024)\n    device = torch.device(\"cuda\")\n    # Create an MarlinInt4WeightQBitsTensor from a QBitsTensor on CUDA\n    qbt = random_qweight(shape, qtype, dtype, group_size=group_size, device=torch.device(\"cuda\"))\n    marlinqbt = MarlinInt4WeightQBitsTensor(\n        qtype=qbt.qtype,\n        axis=qbt.axis,\n        group_size=qbt._group_size,\n        size=qbt.size(),\n        stride=qbt.stride(),\n        data=qbt._data.unpack(),\n        scale=qbt._scale,\n        shift=qbt._shift,\n    )\n    # Move to device, dequantize and compare\n    moved_qbt = marlinqbt.to(device)\n    assert isinstance(moved_qbt, WeightQBitsTensor)\n    if device.type != \"cuda\":\n        assert type(moved_qbt) is not MarlinInt4WeightQBitsTensor\n    assert marlinqbt.dtype == moved_qbt.dtype\n    assert marlinqbt.qtype == moved_qbt.qtype\n    assert marlinqbt.shape == moved_qbt.shape\n    assert torch.equal(marlinqbt.dequantize().to(device), moved_qbt.dequantize())\n\n\ndef _test_marlin_int4_weight_qbits_tensor_linear(\n    dtype, weight_qtype, group_size, batch_size, 
tokens, in_features, out_features, use_bias\n):\n    # Create an MarlinInt4WeightQBitsTensor from a QBitsTensor on CUDA\n    qbt = random_qweight(\n        (out_features, in_features), weight_qtype, dtype, group_size=group_size, device=torch.device(\"cuda\")\n    )\n    marlin_qweight = MarlinInt4WeightQBitsTensor(\n        qtype=qbt.qtype,\n        axis=qbt.axis,\n        group_size=qbt._group_size,\n        size=qbt.size(),\n        stride=qbt.stride(),\n        data=qbt._data.unpack(),\n        scale=qbt._scale,\n        shift=qbt._shift,\n    )\n    check_weight_qtensor_linear(marlin_qweight, batch_size, tokens, use_bias)\n\n\n@pytest.mark.skipif(\n    not is_extension_available(\"quanto_cuda\") or torch.cuda.get_device_capability()[0] < 8,\n    reason=\"CUDA >= sm80 not available\",\n)\n@pytest.mark.parametrize(\"batch_size\", [1, 2])\n@pytest.mark.parametrize(\"tokens\", [16, 32])\n@pytest.mark.parametrize(\"in_features\", [1024])\n@pytest.mark.parametrize(\"out_features\", [1024, 2048, 4096])\n@pytest.mark.parametrize(\"use_bias\", [True, False], ids=[\"bias\", \"no-bias\"])\ndef test_marlin_int4_weight_qbits_tensor_linear(batch_size, tokens, in_features, out_features, use_bias):\n    dtype = torch.float16\n    weight_qtype = qint4\n    group_size = 128\n    _test_marlin_int4_weight_qbits_tensor_linear(\n        dtype, weight_qtype, group_size, batch_size, tokens, in_features, out_features, use_bias\n    )\n\n\n@pytest.mark.xfail(reason=\"Bug in Marlin kernel\", strict=False)\n@pytest.mark.skipif(\n    not is_extension_available(\"quanto_cuda\") or torch.cuda.get_device_capability()[0] < 8,\n    reason=\"CUDA >= sm80 not available\",\n)\n@pytest.mark.parametrize(\"batch_size\", [1, 2])\n@pytest.mark.parametrize(\"tokens\", [48, 64])\n# @pytest.mark.parametrize(\"in_features\", [1024, 2048, 4096, 16384])\n@pytest.mark.parametrize(\"in_features\", [4096, 16384])\n@pytest.mark.parametrize(\"out_features\", [2048, 4096])\ndef 
test_marlin_int4_weight_qbits_tensor_linear_failing(batch_size, tokens, in_features, out_features):\n    dtype = torch.float16\n    weight_qtype = qint4\n    group_size = 128\n    _test_marlin_int4_weight_qbits_tensor_linear(\n        dtype, weight_qtype, group_size, batch_size, tokens, in_features, out_features, use_bias=False\n    )\n"
  },
  {
    "path": "tests/tensor/weights/optimized/test_marlin_qbytes_tensor.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\nimport torch\n\nfrom optimum.quanto import qfloat8_e4m3fn\nfrom optimum.quanto.library.extensions import is_extension_available\nfrom optimum.quanto.tensor.weights.marlin import MarlinF8QBytesTensor\n\n\n@pytest.mark.skipif(\n    not is_extension_available(\"quanto_cuda\") or torch.cuda.get_device_capability()[0] < 8,\n    reason=\"CUDA >= sm80 not available\",\n)\n@pytest.mark.parametrize(\"in_features\", [16, 32, 48, 64])\n@pytest.mark.parametrize(\"out_features\", [64, 128, 192, 256])\ndef test_pack_unpack(in_features: int, out_features: int):\n    data = torch.randint(0, 256, size=(out_features, in_features), dtype=torch.uint8, device=\"cuda\")\n\n    # Remove nans.\n    data[data == 127] = 0\n    data[data == 255] = 0\n\n    data = data.view(torch.float8_e4m3fn)\n\n    qtype = qfloat8_e4m3fn\n    axis = 0\n    size = data.shape\n    stride = data.stride()\n    scale = torch.rand((out_features, 1), dtype=torch.float16, device=\"cuda\")\n    marlin_tensor = MarlinF8QBytesTensor(qtype, axis, size, stride, data, scale)\n\n    data_dequantized = marlin_tensor.dequantize()\n\n    assert torch.all((data.to(torch.float16) * scale - data_dequantized).abs() < 1e-4)\n"
  },
  {
    "path": "tests/tensor/weights/optimized/test_tinygemm_packed_tensor.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport numpy as np\nimport pytest\nimport torch\nfrom helpers import device_eq\nfrom packaging import version\n\nfrom optimum.quanto.tensor.weights.tinygemm import TinyGemmPackedTensor\n\n\n@pytest.mark.skip_device(\"mps\")  # Only available with pytorch 2.4\n@pytest.mark.parametrize(\"in_features\", [128, 256, 512, 1024])\n@pytest.mark.parametrize(\"out_features\", [128, 256, 512, 1024])\n@pytest.mark.parametrize(\"random\", [True, False])\ndef test_pack_tinygemm_tensor(in_features, out_features, random, device):\n    if device.type == \"cuda\":\n        if torch.version.hip:\n            pytest.skip(reason=\"TinyGemm is not supported on ROCm devices\")\n        if version.parse(torch.version.cuda).release < (12, 1):\n            pytest.skip(reason=\"CUDA runtime must be at least 12.1\")\n        if torch.cuda.get_device_capability()[0] < 8:\n            pytest.skip(reason=\"CUDA device >= sm80 not available\")\n    bits = 4\n    qmax = 2**bits\n    shape = (out_features, in_features)\n    if random:\n        t = torch.randint(0, qmax, shape, dtype=torch.uint8).to(device)\n    else:\n        numel = np.prod(shape)\n        t = torch.tensor(range(numel), dtype=torch.int32)\n        t = (t % qmax).reshape(shape).to(torch.uint8).to(device)\n    packed = TinyGemmPackedTensor.pack(t)\n    assert isinstance(packed, TinyGemmPackedTensor)\n  
  assert device_eq(packed.device, device)\n    assert torch.equal(t, packed.unpack())\n\n\n@pytest.mark.skip_device(\"mps\")  # Only available with pytorch 2.4\ndef test_move_tinygemm_packed_tensor(device):\n    if device.type == \"cuda\":\n        if torch.version.hip:\n            pytest.skip(reason=\"TinyGemm is not supported on ROCm devices\")\n        if version.parse(torch.version.cuda).release < (12, 1):\n            pytest.skip(reason=\"CUDA runtime must be at least 12.1\")\n        if torch.cuda.get_device_capability()[0] < 8:\n            pytest.skip(reason=\"CUDA device >= sm80 not available\")\n    shape = (256, 256)\n    bits = 4\n    qmax = 2**bits\n    numel = np.prod(shape)\n    t = torch.tensor(range(numel), dtype=torch.int32)\n    t = (t % qmax).reshape(shape).to(torch.uint8)\n    packed = TinyGemmPackedTensor.pack(t)\n    moved = packed.to(device)\n    assert torch.equal(t.to(device), moved.unpack())\n"
  },
  {
    "path": "tests/tensor/weights/optimized/test_tinygemm_weight_qbits_tensor.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\nimport torch\nfrom helpers import assert_similar, device_eq, random_qweight, random_tensor\nfrom packaging import version\n\nfrom optimum.quanto import qint4\nfrom optimum.quanto.tensor.weights import WeightQBitsTensor\nfrom optimum.quanto.tensor.weights.tinygemm import TinyGemmWeightQBitsTensor\n\n\n@pytest.mark.skip_device(\"mps\")  # Only available with pytorch 2.4\n@pytest.mark.parametrize(\"in_features\", [128, 256, 512, 1024])\n@pytest.mark.parametrize(\"out_features\", [128, 256, 512, 1024])\ndef test_tinygemm_weight_qbits_tensor_from_qbits_tensor(in_features, out_features, device):\n    if device.type == \"cuda\":\n        if torch.version.hip:\n            pytest.skip(reason=\"TinyGemm not available for ROCm devices\")\n        if version.parse(torch.version.cuda).release < (12, 1):\n            pytest.skip(reason=\"CUDA runtime must be at least 12.1\")\n        if torch.cuda.get_device_capability()[0] < 8:\n            pytest.skip(reason=\"CUDA device >= sm80 not available\")\n    qtype = qint4\n    group_size = 128\n    dtype = torch.bfloat16\n    shape = (out_features, in_features)\n    qbt = random_qweight(shape, qtype, dtype, group_size=group_size, device=device)\n    # Create a TinyGemmWeightQBitsTensor from the WeightQBitsTensor members\n    tgqbt = TinyGemmWeightQBitsTensor(\n        qtype=qbt.qtype,\n    
    axis=qbt.axis,\n        group_size=qbt._group_size,\n        size=qbt.size(),\n        stride=qbt.stride(),\n        data=qbt._data.unpack(),\n        scale_shift=(qbt._scale, qbt._shift),\n    )\n    assert tgqbt.dtype == dtype\n    assert tgqbt.qtype == qtype\n    assert tgqbt.shape == shape\n    assert device_eq(tgqbt.device, device)\n    # Verify that we can reconstruct the WeightQBitsTensor\n    new_qbt = tgqbt.weight_qbits_tensor()\n    assert type(new_qbt) is WeightQBitsTensor\n    assert new_qbt.dtype == dtype\n    assert new_qbt.qtype == qtype\n    assert new_qbt.shape == shape\n    assert torch.equal(new_qbt._data, qbt._data)\n    assert torch.equal(new_qbt._scale, qbt._scale)\n    # FIXME: we cannot guarantee an exact match because of the addition/removal of the mid-point\n    # which is lossy in bfloat16 (a + b - b != a)\n    assert_similar(new_qbt._shift, qbt._shift)\n    # Verify the dequantized tensors are similar\n    assert_similar(tgqbt.dequantize(), qbt.dequantize())\n\n\n@pytest.mark.skip_device(\"mps\")  # Only available with pytorch 2.4\ndef test_tinygemm_weight_qbits_tensor_move(device):\n    qtype = qint4\n    group_size = 128\n    dtype = torch.bfloat16\n    shape = (1024, 1024)\n    # Create a TinyGemmWeightQBitsTensor from a QBitsTensor on CPU\n    qbt = random_qweight(shape, qtype, dtype, group_size=group_size, device=torch.device(\"cpu\"))\n    tgqbt_cpu = TinyGemmWeightQBitsTensor(\n        qtype=qbt.qtype,\n        axis=qbt.axis,\n        group_size=qbt._group_size,\n        size=qbt.size(),\n        stride=qbt.stride(),\n        data=qbt._data.unpack(),\n        scale_shift=(qbt._scale, qbt._shift),\n    )\n    # Move to device, dequantize and compare\n    tgqbt = tgqbt_cpu.to(device)\n    assert isinstance(tgqbt, WeightQBitsTensor)\n    assert tgqbt.dtype == tgqbt_cpu.dtype\n    assert tgqbt.qtype == tgqbt_cpu.qtype\n    assert tgqbt.shape == tgqbt_cpu.shape\n    assert torch.equal(tgqbt.dequantize().cpu(), 
tgqbt_cpu.dequantize())\n\n\n@pytest.mark.skip_device(\"mps\")  # Only available with pytorch 2.4\n@pytest.mark.parametrize(\"batch_size\", [1, 2])\n@pytest.mark.parametrize(\"tokens\", [256, 512])\n@pytest.mark.parametrize(\"embeddings\", [256, 512, 1024, 4096])\n@pytest.mark.parametrize(\"use_bias\", [True, False], ids=[\"bias\", \"no-bias\"])\ndef test_tinygemm_weight_qbits_tensor_linear(batch_size, tokens, embeddings, use_bias, device):\n    if device.type == \"cuda\":\n        if torch.version.hip:\n            pytest.skip(reason=\"TinyGemm not available for ROCm devices\")\n        if version.parse(torch.version.cuda).release < (12, 1):\n            pytest.skip(reason=\"CUDA runtime must be at least 12.1\")\n        if torch.cuda.get_device_capability()[0] < 8:\n            pytest.skip(reason=\"CUDA device >= sm80 not available\")\n    qtype = qint4\n    group_size = 128\n    dtype = torch.bfloat16\n    inputs = torch.rand((batch_size,) + (tokens, embeddings), dtype=dtype, device=device)\n    # Create a TinyGemmWeightQBitsTensor from a QBitsTensor\n    qbt = random_qweight((tokens, embeddings), qtype, dtype, group_size=group_size, device=device)\n    tinygemm_qweight = TinyGemmWeightQBitsTensor(\n        qtype=qbt.qtype,\n        axis=qbt.axis,\n        group_size=qbt._group_size,\n        size=qbt.size(),\n        stride=qbt.stride(),\n        data=qbt._data.unpack(),\n        scale_shift=(qbt._scale, qbt._shift),\n    )\n    bias = random_tensor((tokens,), dtype=dtype).to(device) if use_bias else None\n    qout = torch.nn.functional.linear(inputs, tinygemm_qweight, bias)\n    out = torch.nn.functional.linear(inputs, qbt.dequantize(), bias)\n    assert_similar(out, qout)\n"
  },
  {
    "path": "tests/tensor/weights/test_weight_qbits_tensor.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport io\n\nimport pytest\nimport torch\nfrom helpers import random_qweight, random_tensor\n\nfrom optimum.quanto import MaxOptimizer, WeightQBitsTensor, qint2, qint4, quantize_weight\n\n\n@pytest.mark.parametrize(\"qtype\", [qint2, qint4], ids=[\"int2\", \"int4\"])\n@pytest.mark.parametrize(\"axis\", [0, -1], ids=[\"first-axis\", \"last-axis\"])\ndef test_weight_qbits_tensor_serialization(qtype, axis):\n    qa = random_qweight((5, 5), qtype=qtype, axis=axis)\n    b = io.BytesIO()\n    torch.save(qa, b)\n    b.seek(0)\n    qa_reloaded = torch.load(b, weights_only=False)\n    assert isinstance(qa_reloaded, WeightQBitsTensor)\n    assert qa_reloaded.qtype == qa.qtype\n    assert qa_reloaded.dtype == qa.dtype\n    assert torch.equal(qa_reloaded._data, qa._data)\n    assert torch.equal(qa_reloaded._scale, qa._scale)\n    assert torch.equal(qa_reloaded._shift, qa._shift)\n\n\n@pytest.mark.parametrize(\"qtype\", [qint2, qint4], ids=[\"int2\", \"int4\"])\n@pytest.mark.parametrize(\"axis\", [0, -1], ids=[\"first-axis\", \"last-axis\"])\n@pytest.mark.parametrize(\"group_size\", [None, 16], ids=[\"channel-wise\", \"group-wise\"])\ndef test_weight_qbits_tensor_requires_grad(qtype, axis, group_size, device):\n    weight = random_tensor((32, 32), dtype=torch.float32).to(device)\n    weight.requires_grad = True\n    scale, shift = 
MaxOptimizer()(weight, qtype=qtype, axis=axis, group_size=group_size)\n    qweight = quantize_weight(weight, qtype=qtype, axis=axis, scale=scale, shift=shift, group_size=group_size)\n    assert qweight.requires_grad is True\n\n\n@pytest.mark.parametrize(\"qtype\", [qint2, qint4], ids=[\"int2\", \"int4\"])\n@pytest.mark.parametrize(\"axis\", [0, -1], ids=[\"first-axis\", \"last-axis\"])\n@pytest.mark.parametrize(\"group_size\", [None, 16], ids=[\"channel-wise\", \"group-wise\"])\ndef test_weight_qbits_tensor_backward(qtype, axis, group_size, device):\n    weight = random_tensor((32, 32), dtype=torch.float32).to(device)\n    weight.requires_grad = True\n    scale, shift = MaxOptimizer()(weight, qtype=qtype, axis=axis, group_size=group_size)\n    qweight = quantize_weight(weight, qtype=qtype, axis=axis, scale=scale, shift=shift, group_size=group_size)\n    gradient = torch.randn((32, 32)).to(device)\n    # Backpropagate gradient to the inner float weights\n    qweight.dequantize().backward(gradient)\n    assert torch.equal(weight.grad, gradient)\n"
  },
  {
    "path": "tests/tensor/weights/test_weight_qbits_tensor_dispatch.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\nimport torch\nfrom helpers import assert_similar, random_qweight, random_tensor\nfrom tensor.weights.weight_helpers import check_weight_qtensor_linear\n\nfrom optimum.quanto import MaxOptimizer, QBitsTensor, qint2, qint4, quantize_weight\n\n\n@pytest.mark.parametrize(\"group_size\", [None, 128], ids=[\"channel-wise\", \"group-wise\"])\n@pytest.mark.parametrize(\"dtype\", [torch.float32, torch.float16], ids=[\"fp32\", \"fp16\"])\ndef test_qbitstensor_to_device(dtype, group_size, device):\n    qa = random_qweight((256, 512), dtype=dtype, qtype=qint4, group_size=group_size, device=\"cpu\")\n    # Keep a copy of the dequantized Tensor as a reference\n    dqa = qa.dequantize()\n    # Move to the target device\n    moved_qa = qa.to(device)\n    assert isinstance(moved_qa, QBitsTensor)\n    assert moved_qa.device.type == device.type\n    assert moved_qa._data.device.type == device.type\n    assert moved_qa._scale.device.type == device.type\n    assert moved_qa._shift.device.type == device.type\n    moved_dqa = moved_qa.dequantize().to(\"cpu\")\n    if type(moved_qa) is not QBitsTensor:\n        # Since we use an optimized packing, the order of operations during\n        # dequantization might differ, but the moved dequantized Tensor should be nearly identical\n        assert_similar(moved_dqa, dqa)\n    else:\n        assert 
torch.equal(moved_dqa, dqa)\n\n\ndef test_qbitstensor_detach():\n    qa = random_qweight((32, 32), qtype=qint4)\n    dqa = qa.detach()\n    assert isinstance(dqa, QBitsTensor)\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16, torch.float16, torch.float32], ids=[\"bf16\", \"fp16\", \"fp32\"])\n@pytest.mark.parametrize(\"qtype\", [qint2, qint4])\n@pytest.mark.parametrize(\"axis\", [0, -1], ids=[\"first-axis\", \"last-axis\"])\ndef test_qbitstensor_equal(dtype, qtype, axis, device):\n    a = random_tensor((1024, 1024), dtype=dtype, device=device)\n    scale, shift = MaxOptimizer()(a, qtype=qtype, axis=axis, group_size=128)\n    qa1 = quantize_weight(a, qtype=qtype, axis=axis, scale=scale, shift=shift, group_size=128)\n    qa2 = quantize_weight(a, qtype=qtype, axis=axis, scale=scale, shift=shift, group_size=128)\n    assert torch.equal(qa1, qa2)\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16], ids=[\"fp16\", \"bf16\"])\n@pytest.mark.parametrize(\"batch_size\", [1, 2])\n@pytest.mark.parametrize(\"tokens\", [16, 32])\n@pytest.mark.parametrize(\"in_features\", [256, 512])\n@pytest.mark.parametrize(\"out_features\", [256, 512])\n@pytest.mark.parametrize(\"use_bias\", [True, False], ids=[\"bias\", \"no-bias\"])\ndef test_weight_qbits_tensor_linear(dtype, batch_size, tokens, in_features, out_features, use_bias, device):\n    weight_qtype = qint4\n    group_size = 128\n    # Create a QBitsTensor\n    qbt = random_qweight((out_features, in_features), weight_qtype, dtype, group_size=group_size, device=device)\n    check_weight_qtensor_linear(qbt, batch_size, tokens, use_bias)\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.bfloat16], ids=[\"fp16\", \"bf16\"])\n@pytest.mark.parametrize(\"batch_size\", [1, 2])\n@pytest.mark.parametrize(\"tokens\", [16, 32, 48, 64])\n@pytest.mark.parametrize(\"in_features\", [1024, 4096, 16384])\n@pytest.mark.parametrize(\"out_features\", [1024, 2048, 4096])\n@pytest.mark.parametrize(\"use_bias\", 
[True, False], ids=[\"bias\", \"no-bias\"])\ndef test_weight_qbits_tensor_linear_gpu(dtype, batch_size, tokens, in_features, out_features, use_bias):\n    if torch.cuda.is_available():\n        device = torch.device(\"cuda\")\n    elif torch.xpu.is_available():\n        device = torch.device(\"xpu\")\n    else:\n        pytest.skip(reason=\"Test is too slow on non-GPU devices\")\n\n    weight_qtype = qint4\n    group_size = 128\n    # Create a QBitsTensor\n    qbt = random_qweight((out_features, in_features), weight_qtype, dtype, group_size=group_size, device=device)\n    check_weight_qtensor_linear(qbt, batch_size, tokens, use_bias)\n"
  },
  {
    "path": "tests/tensor/weights/test_weight_qbits_tensor_instantiate.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport pytest\nimport torch\n\nfrom optimum.quanto import qint2, qint4\nfrom optimum.quanto.tensor.weights import WeightQBitsTensor\n\n\ndef random_data_scale_shift(input_shape, dtype, qtype, axis, group_size):\n    out_features, in_features = input_shape\n    n_groups = in_features * out_features // group_size\n    data_shape = (n_groups, group_size) if axis == 0 else (group_size, n_groups)\n    scale_shape = (n_groups, 1) if axis == 0 else (1, n_groups)\n    min_value = -(2 ** (qtype.bits - 1))\n    max_value = 2 ** (qtype.bits - 1) - 1\n    data = torch.randint(max_value - min_value + 1, data_shape, dtype=torch.uint8)\n    scale = torch.full(scale_shape, 1.0 / -min_value, dtype=dtype)\n    shift = torch.ones(scale_shape, dtype=dtype)\n    return data, scale, shift\n\n\n@pytest.mark.parametrize(\"input_shape, group_size\", [[(32, 32), 16], [(1024, 1024), 128]])\n@pytest.mark.parametrize(\"axis\", [0, -1], ids=[\"first-axis\", \"last-axis\"])\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16, torch.float16, torch.float32], ids=[\"bf16\", \"fp16\", \"fp32\"])\n@pytest.mark.parametrize(\"qtype\", [qint2, qint4], ids=[\"qint2\", \"qint4\"])\ndef test_weight_qbits_tensor_instantiate(input_shape, dtype, qtype, axis, group_size, device):\n    data, scale, shift = random_data_scale_shift(input_shape, dtype, qtype, axis, group_size)\n    
input_stride = torch.ones(input_shape).stride()\n    qa = WeightQBitsTensor(qtype, axis, group_size, input_shape, input_stride, data, scale=scale, shift=shift).to(\n        device\n    )\n    assert torch.max(torch.abs(qa.dequantize())) <= 1\n    assert qa.dtype == dtype\n    assert qa.qtype == qtype\n    assert qa.shape == input_shape\n\n\n@pytest.mark.parametrize(\"input_shape, group_size\", [[(32, 32), 16], [(1024, 1024), 128]])\n@pytest.mark.parametrize(\"axis\", [0, -1], ids=[\"first-axis\", \"last-axis\"])\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16, torch.float16, torch.float32], ids=[\"bf16\", \"fp16\", \"fp32\"])\n@pytest.mark.parametrize(\"qtype\", [qint2, qint4], ids=[\"qint2\", \"qint4\"])\ndef test_weight_qbits_tensor_equal(input_shape, dtype, qtype, axis, group_size, device):\n    data, scale, shift = random_data_scale_shift(input_shape, dtype, qtype, axis, group_size)\n    qa = WeightQBitsTensor(qtype, axis, group_size, data.size(), data.stride(), data, scale=scale, shift=shift).to(\n        device\n    )\n    qb = WeightQBitsTensor(\n        qtype, axis, group_size, data.size(), data.stride(), data.clone(), scale=scale.clone(), shift=shift.clone()\n    ).to(device)\n    assert qa.equal(qb)\n"
  },
  {
    "path": "tests/tensor/weights/test_weight_qbits_tensor_quantize.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\nimport torch\nfrom helpers import assert_similar, device_eq, random_tensor\n\nfrom optimum.quanto import (\n    MaxOptimizer,\n    qint2,\n    qint4,\n)\nfrom optimum.quanto.tensor.weights import WeightQBitsTensor\n\n\n@pytest.mark.parametrize(\"input_shape\", [(32, 32), (32, 10, 32)])\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.float32], ids=[\"fp16\", \"fp32\"])\n@pytest.mark.parametrize(\"qtype\", [qint2, qint4], ids=[\"qint2\", \"qint4\"])\n@pytest.mark.parametrize(\"axis\", [0, -1], ids=[\"first-axis\", \"last-axis\"])\n@pytest.mark.parametrize(\"group_size\", [None, 8], ids=[\"channel-wise\", \"group-wise\"])\n@pytest.mark.parametrize(\"shift_mode\", [\"zeropoint\", \"float\"])\ndef test_weight_qbits_tensor_quantize(input_shape, dtype, qtype, axis, group_size, shift_mode, device):\n    a = random_tensor(input_shape, dtype=dtype).to(device)\n    scale, shift = MaxOptimizer()(a, qtype=qtype, axis=axis, group_size=group_size)\n    if shift_mode == \"zeropoint\":\n        shift = torch.round(shift / scale).to(torch.int8)\n    qa = WeightQBitsTensor.quantize(a, qtype, axis, group_size, scale, shift)\n    assert isinstance(qa, WeightQBitsTensor)\n    assert qa.dtype == dtype\n    assert qa.qtype == qtype\n    assert device_eq(qa.device, device)\n    atol = {\n        qint4: {\n            \"zeropoint\": 
4e-3,\n            \"float\": 3e-3,\n        },\n        qint2: {\n            \"zeropoint\": 6e-2,\n            \"float\": 5e-2,\n        },\n    }[qtype][shift_mode]\n    assert_similar(a, qa, atol=atol)\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.float32], ids=[\"fp16\", \"fp32\"])\n@pytest.mark.parametrize(\"qtype\", [qint2, qint4], ids=[\"qint2\", \"qint4\"])\ndef test_weight_qbits_tensor_quantize_integer_tensor(dtype, qtype, device):\n    \"\"\"This test verifies that an integer tensor in the correct range is preserved.\"\"\"\n    bits = qtype.bits\n    qmin = -(2 ** (bits - 1))\n    qmax = 2 ** (bits - 1) - 1\n    a = torch.tensor(range(qmin, qmax + 1), dtype=dtype).to(device)\n    scale, shift = MaxOptimizer()(a, qtype=qtype, axis=0, group_size=None)\n    zeropoint = torch.round(shift / scale)\n    qa = WeightQBitsTensor.quantize(a, qtype, 0, None, scale, zeropoint)\n\n    assert qa._data.dtype == torch.uint8\n    assert isinstance(qa, WeightQBitsTensor)\n    assert qa.dtype == dtype\n    assert qa.qtype == qtype\n    assert device_eq(qa.device, device)\n    assert torch.equal(a, qa.dequantize())\n"
  },
  {
    "path": "tests/tensor/weights/test_weight_qbytes_tensor_backward.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport torch\nfrom helpers import random_tensor\n\nfrom optimum.quanto import AbsmaxOptimizer, qint8, quantize_weight\n\n\ndef test_weight_qbytes_tensor_requires_grad(device):\n    w = random_tensor((10, 10), dtype=torch.float32).to(device)\n    w.requires_grad = True\n    scale = AbsmaxOptimizer()(w, qtype=qint8, axis=0)\n    qw = quantize_weight(w, qtype=qint8, axis=0, scale=scale)\n    assert qw.requires_grad is True\n\n\ndef test_weight_qbytes_tensor_backward(device):\n    w = random_tensor((10, 10), dtype=torch.float32).to(device)\n    w.requires_grad = True\n    scale = AbsmaxOptimizer()(w, qtype=qint8, axis=0)\n    qw = quantize_weight(w, qtype=qint8, axis=0, scale=scale)\n    gradient = torch.randn((10, 10)).to(device)\n    # Backpropagate gradient to the inner float weights\n    qw.dequantize().backward(gradient)\n    assert torch.equal(w.grad, gradient)\n\n\ndef test_weight_qbytes_tensor_chained_backward(device):\n    a = random_tensor((10, 10), dtype=torch.float32).to(device)\n    a.requires_grad = True\n    scale = AbsmaxOptimizer()(a, qtype=qint8, axis=0)\n    qa = quantize_weight(a, qtype=qint8, axis=0, scale=scale)\n    b = random_tensor((10, 10), dtype=torch.float32).to(device)\n    b.requires_grad = True\n    scale = AbsmaxOptimizer()(b, qtype=qint8, axis=0)\n    qb = quantize_weight(b, qtype=qint8, axis=0, 
scale=scale)\n    # Evaluate the product\n    prod = qa * qb\n    # Backpropagate\n    gradient = torch.randn((10, 10)).to(device)\n    prod.backward(gradient)\n    assert torch.allclose(a.grad, qb.dequantize() * gradient)\n    assert torch.allclose(b.grad, qa.dequantize() * gradient)\n"
  },
  {
    "path": "tests/tensor/weights/test_weight_qbytes_tensor_dispatch.py",
    "content": "import pytest\nimport torch\nfrom helpers import random_qweight, random_tensor\n\nfrom optimum.quanto import AbsmaxOptimizer, WeightQBytesTensor, qint8, quantize_weight\n\n\ndef test_weight_qytes_tensor_to_device(device):\n    qa = random_qweight((32, 32), qtype=qint8, dtype=torch.float)\n    qa = qa.to(device)\n    assert isinstance(qa, WeightQBytesTensor)\n    assert qa.device.type == device.type\n    assert qa._data.device.type == device.type\n    assert qa._scale.device.type == device.type\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16, torch.float16, torch.float32], ids=[\"bf16\", \"fp16\", \"fp32\"])\n@pytest.mark.parametrize(\"qtype\", [qint8])\n@pytest.mark.parametrize(\"axis\", [0, -1], ids=[\"first-axis\", \"last-axis\"])\ndef test_weight_qbytes_tensor_equal(dtype, qtype, axis, device):\n    a = random_tensor((32, 32), dtype=dtype, device=device)\n    scale = AbsmaxOptimizer()(a, qtype=qtype, axis=axis)\n    qa1 = quantize_weight(a, qtype=qtype, axis=axis, scale=scale)\n    qa2 = quantize_weight(a, qtype=qtype, axis=axis, scale=scale)\n    assert torch.equal(qa1, qa2)\n\n\n@pytest.mark.parametrize(\"axis\", [0, -1], ids=[\"first-axis\", \"last-axis\"])\n@pytest.mark.parametrize(\"qtype\", [qint8])\ndef test_weight_qbytes_tensor_transpose_contiguous(axis, qtype, device):\n    input_shape = (16, 32)\n    qa = random_qweight(input_shape, axis=axis, qtype=qtype, dtype=torch.float32).to(device)\n    assert qa.is_contiguous()\n    tqa = qa.t()\n    assert isinstance(tqa, WeightQBytesTensor)\n    assert not tqa.is_contiguous()\n    tqa = tqa.contiguous()\n    assert tqa.is_contiguous()\n\n\n@pytest.mark.parametrize(\"axis\", [0, -1], ids=[\"first-axis\", \"last-axis\"])\n@pytest.mark.parametrize(\"qtype\", [qint8])\ndef test_weight_qbytes_tensor_transposed_stride(axis, qtype, device):\n    input_shape = (16, 32)\n    a = random_tensor(input_shape, dtype=torch.float32).to(device)\n    scale = AbsmaxOptimizer()(a, qtype=qtype, 
axis=axis)\n    qa = quantize_weight(a, qtype=qtype, axis=axis, scale=scale)\n    assert qa.stride() == a.stride()\n    ta = a.t()\n    tqa = qa.t()\n    assert isinstance(tqa, WeightQBytesTensor)\n    assert tqa.stride() == ta.stride()\n"
  },
  {
    "path": "tests/tensor/weights/test_weight_qbytes_tensor_instantiate.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport pytest\nimport torch\n\nfrom optimum.quanto import WeightQBytesTensor, qfloat8, qint8\n\n\ndef random_data_scale(input_shape, dtype, qtype):\n    if qtype.is_floating_point:\n        min_value = torch.finfo(qtype.dtype).min\n        max_value = torch.finfo(qtype.dtype).max\n        data = (torch.rand(input_shape) * max_value + min_value).to(qtype.dtype)\n    else:\n        max_value = torch.iinfo(qtype.dtype).max\n        data = torch.randint(-max_value, max_value, input_shape, dtype=qtype.dtype)\n    scale = torch.tensor(1.0 / max_value, dtype=dtype)\n    return data, scale\n\n\n@pytest.mark.parametrize(\"input_shape\", [(10,), (1, 10), (10, 32, 32)])\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16, torch.float16, torch.float32], ids=[\"bf16\", \"fp16\", \"fp32\"])\n@pytest.mark.parametrize(\"qtype\", [qint8, qfloat8], ids=[\"qint8\", \"qfloat8\"])\ndef test_qbytestensor_instantiate(input_shape, dtype, qtype, device):\n    if qtype.is_floating_point and device.type == \"mps\":\n        pytest.skip(\"float8 types are not supported on MPS device\")\n    data, scale = random_data_scale(input_shape, dtype, qtype)\n    qa = WeightQBytesTensor(qtype, None, data.size(), data.stride(), data, scale=scale, activation_qtype=None).to(\n        device\n    )\n    assert torch.max(torch.abs(qa.dequantize())) <= 1\n    assert qa.dtype 
== dtype\n    assert qa.qtype == qtype\n    assert qa.shape == input_shape\n\n\n@pytest.mark.parametrize(\"input_shape\", [(10,), (1, 10), (10, 32, 32)])\n@pytest.mark.parametrize(\"dtype\", [torch.bfloat16, torch.float16, torch.float32], ids=[\"bf16\", \"fp16\", \"fp32\"])\n@pytest.mark.parametrize(\"qtype\", [qint8], ids=[\"qint8\"])\ndef test_qbytestensor_equal(input_shape, dtype, qtype, device):\n    data, scale = random_data_scale(input_shape, dtype, qtype)\n    qa = WeightQBytesTensor(qtype, None, data.size(), data.stride(), data, scale=scale, activation_qtype=None).to(\n        device\n    )\n    qb = WeightQBytesTensor(\n        qtype, None, data.size(), data.stride(), data.clone(), scale=scale, activation_qtype=None\n    ).to(device)\n    assert qa.equal(qb)\n"
  },
  {
    "path": "tests/tensor/weights/test_weight_qbytes_tensor_quantize.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport pytest\nimport torch\nfrom helpers import assert_similar, device_eq, random_qweight, random_tensor\n\nfrom optimum.quanto import (\n    WeightQBytesTensor,\n    absmax_scale,\n    qfloat8,\n    qfloat8_e4m3fn,\n    qfloat8_e4m3fnuz,\n    qfloat8_e5m2,\n    qint8,\n)\n\n\n@pytest.mark.parametrize(\"input_shape\", [(32, 32), (32, 10, 32)])\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.float32], ids=[\"fp16\", \"fp32\"])\n@pytest.mark.parametrize(\"qtype\", [qint8], ids=[\"qint8\"])\n@pytest.mark.parametrize(\n    \"axis\",\n    [None, 0, -1],\n    ids=[\"per-tensor\", \"first-axis\", \"last-axis\"],\n)\ndef test_symmetric_quantize_int(input_shape, dtype, qtype, axis, device):\n    a = random_tensor(input_shape, dtype=dtype).to(device)\n    scale = absmax_scale(a, qtype=qtype, axis=axis)\n    qa = WeightQBytesTensor.quantize(a, qtype, axis, scale)\n    assert isinstance(qa, WeightQBytesTensor)\n    assert qa.dtype == dtype\n    assert qa.qtype == qtype\n    assert device_eq(qa.device, device)\n    assert_similar(a, qa)\n\n\n@pytest.mark.skip_device(\"mps\")\n@pytest.mark.parametrize(\"input_shape\", [(32, 32), (32, 10, 32)])\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.float32], ids=[\"fp16\", \"fp32\"])\n@pytest.mark.parametrize(\n    \"qtype\",\n    [qfloat8, qfloat8_e4m3fn, qfloat8_e4m3fnuz, 
qfloat8_e5m2],\n    ids=[\"qfloat8\", \"qfloat8_e4m3fn\", \"qfloat8_e4m3fnuz\", \"qfloat8_e5m2\"],\n)\n@pytest.mark.parametrize(\n    \"axis\",\n    [None, 0, -1],\n    ids=[\"per-tensor\", \"first-axis\", \"last-axis\"],\n)\ndef test_symmetric_quantize_float8(input_shape, dtype, qtype, axis, device):\n    a = random_tensor(input_shape, dtype=dtype).to(device)\n    scale = absmax_scale(a, qtype=qtype, axis=axis)\n    qa = WeightQBytesTensor.quantize(a, qtype, axis, scale)\n    assert isinstance(qa, WeightQBytesTensor)\n    assert qa.dtype == dtype\n    assert qa.qtype == qtype\n    assert device_eq(qa.device, device)\n    assert_similar(a, qa, atol=5e-3)\n\n\n@pytest.mark.parametrize(\"axis\", [0, -1], ids=[\"first-axis\", \"last-axis\"])\ndef test_quantize_weight_axis_dim_1(axis, device):\n    input_shape = (1, 32) if axis == 0 else (32, 1)\n    qa = random_qweight(input_shape, dtype=torch.float32, qtype=qint8, axis=axis, device=device)\n    # Quantizing along an axis of dimension 1 actually means per-tensor\n    assert qa.axis is None\n"
  },
  {
    "path": "tests/tensor/weights/test_weight_qbytes_tensor_serialization.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport io\n\nimport pytest\nimport torch\nfrom helpers import random_qweight\n\nfrom optimum.quanto import qfloat8, qint8\n\n\n@pytest.mark.parametrize(\"input_shape\", [(10, 10), (10, 32, 32)])\n@pytest.mark.parametrize(\"qtype\", [qint8, qfloat8], ids=[\"qint8\", \"qfloat8\"])\n@pytest.mark.parametrize(\"dtype\", [torch.float16, torch.float32], ids=[\"fp16\", \"fp32\"])\n@pytest.mark.parametrize(\"axis\", [0, -1], ids=[\"first-axis\", \"last-axis\"])\ndef test_weights_qbytes_tensor_serialization(input_shape, qtype, dtype, axis):\n    qinputs = random_qweight(input_shape, dtype=dtype, qtype=qtype, axis=axis)\n    b = io.BytesIO()\n    torch.save(qinputs, b)\n    b.seek(0)\n    qinputs_reloaded = torch.load(b, weights_only=False)\n    assert qinputs_reloaded.qtype == qtype\n    assert torch.equal(qinputs_reloaded._scale, qinputs._scale)\n    if qtype.is_floating_point:\n        # Equality is not supported for float8\n        assert torch.equal(qinputs_reloaded._data.to(torch.float32), qinputs._data.to(torch.float32))\n    else:\n        assert torch.equal(qinputs_reloaded._data, qinputs._data)\n    # We cannot test dtype directly as it is not correctly set by torch.load\n    assert qinputs_reloaded._scale.dtype == dtype\n    assert qinputs_reloaded.axis == qinputs.axis\n"
  },
  {
    "path": "tests/tensor/weights/weight_helpers.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom helpers import assert_similar, random_tensor\n\n\ndef check_weight_qtensor_linear(qweight, batch_size, tokens, use_bias, rel_max_err=0.0):\n    dtype = qweight.dtype\n    device = qweight.device\n    out_features, in_features = qweight.shape\n    inputs = torch.rand((batch_size, tokens, in_features), dtype=dtype, device=device)\n    bias = random_tensor((out_features,), dtype=dtype, device=device) if use_bias else None\n    qout = torch.nn.functional.linear(inputs, qweight, bias)\n    out = torch.nn.functional.linear(inputs, qweight.dequantize(), bias)\n    # Verify global alignment\n    assert_similar(out, qout)\n    # Also look for outliers\n    mean_val = out.abs().max()\n    max_err = (out - qout).abs().max()\n    rel_max_err = max_err / mean_val\n    # These values were evaluated empirically without any optimized kernels.\n    rtol = {\"cpu\": 1e-2, \"cuda\": 2e-2, \"mps\": 1e-2, \"xpu\": 2e-2}[device.type]\n    assert rel_max_err < rtol, (\n        f\"Maximum error {max_err:.2f} is too high for input of mean value {mean_val:.2f} ({rel_max_err * 100:.2f} %)\"\n    )\n"
  }
]