Repository: huggingface/quanto Branch: main Commit: ef3aafb30e6b Files: 207 Total size: 766.4 KB Directory structure: gitextract_e7pf933s/ ├── .github/ │ ├── CODEOWNERS │ ├── PULL_REQUEST_TEMPLATE.md │ └── workflows/ │ ├── check-commits.yml │ ├── linux-cpu-tests.yml │ ├── linux-cuda-tests.yml │ ├── linux-examples.yml │ ├── python-quality.yml │ ├── security.yml │ └── stale.yml ├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── README.md ├── bench/ │ ├── generation/ │ │ ├── README.md │ │ ├── evaluate_configurations.py │ │ ├── evaluate_many_models.sh │ │ ├── evaluate_model.py │ │ ├── gen_barchart.py │ │ ├── metrics/ │ │ │ ├── __init__.py │ │ │ ├── latency.py │ │ │ ├── perplexity.py │ │ │ └── prediction.py │ │ └── setup/ │ │ ├── __init__.py │ │ ├── awq.py │ │ ├── bnb.py │ │ ├── hqq.py │ │ └── quanto.py │ ├── kernels/ │ │ ├── benchmark.py │ │ ├── benchmark_marlin_fp8.py │ │ └── benchmark_w4a16.py │ └── torch_kernels/ │ ├── README.md │ ├── test_int_mm.py │ ├── test_int_mm_inductor.py │ ├── test_weight_int4pack_mm.py │ └── test_weight_int8pack_mm.py ├── examples/ │ ├── nlp/ │ │ ├── text-classification/ │ │ │ └── sst2/ │ │ │ └── quantize_sst2_model.py │ │ └── text-generation/ │ │ └── quantize_causal_lm_model.py │ ├── speech/ │ │ └── speech_recognition/ │ │ ├── quantize_asr_model.py │ │ └── requirements.txt │ └── vision/ │ ├── StableDiffusion/ │ │ ├── README.md │ │ ├── quantize_StableDiffusion.py │ │ └── requirements.txt │ ├── image-classification/ │ │ ├── mnist/ │ │ │ └── quantize_mnist_model.py │ │ └── pets/ │ │ └── quantize_vit_model.py │ ├── object-detection/ │ │ └── quantize_owl_model.py │ └── text-to-image/ │ └── quantize_pixart_sigma.py ├── external/ │ ├── awq/ │ │ ├── conftest.py │ │ ├── pack_intweight.py │ │ ├── packing_utils.py │ │ ├── test_awq_kernels.py │ │ ├── test_awq_packing.py │ │ └── test_awq_quantize.py │ └── smoothquant/ │ ├── README.md │ └── smoothquant.py ├── optimum/ │ └── quanto/ │ ├── __init__.py │ ├── calibrate.py │ ├── library/ │ │ ├── 
README.md │ │ ├── __init__.py │ │ ├── extensions/ │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── cpp/ │ │ │ │ ├── README.md │ │ │ │ ├── __init__.py │ │ │ │ ├── pybind_module.cpp │ │ │ │ ├── unpack.cpp │ │ │ │ └── unpack.h │ │ │ ├── cuda/ │ │ │ │ ├── README.md │ │ │ │ ├── __init__.py │ │ │ │ ├── awq/ │ │ │ │ │ ├── dequantize.cuh │ │ │ │ │ └── v2/ │ │ │ │ │ ├── gemm_cuda.cu │ │ │ │ │ ├── gemm_cuda.h │ │ │ │ │ ├── gemv_cuda.cu │ │ │ │ │ ├── gemv_cuda.h │ │ │ │ │ └── semaphore.h │ │ │ │ ├── marlin/ │ │ │ │ │ ├── COPYRIGHT │ │ │ │ │ ├── fp8_marlin.cu │ │ │ │ │ ├── fp8_marlin.cuh │ │ │ │ │ ├── gptq_marlin.cuh │ │ │ │ │ ├── gptq_marlin_dtypes.cuh │ │ │ │ │ ├── gptq_marlin_repack.cu │ │ │ │ │ ├── gptq_marlin_repack.cuh │ │ │ │ │ ├── marlin_cuda.cpp │ │ │ │ │ ├── marlin_cuda.h │ │ │ │ │ ├── marlin_cuda_kernel.cu │ │ │ │ │ └── marlin_cuda_kernel.cuh │ │ │ │ ├── pybind_module.cpp │ │ │ │ ├── unpack.cu │ │ │ │ └── unpack.h │ │ │ ├── extension.py │ │ │ ├── hip/ │ │ │ │ ├── __init__.py │ │ │ │ ├── pybind_module.cpp │ │ │ │ ├── unpack.cu │ │ │ │ └── unpack.h │ │ │ ├── mps/ │ │ │ │ ├── README.md │ │ │ │ ├── __init__.py │ │ │ │ ├── pybind_module.cpp │ │ │ │ ├── unpack.h │ │ │ │ └── unpack.mm │ │ │ └── xpu/ │ │ │ ├── __init__.py │ │ │ ├── pybind_module.cpp │ │ │ ├── unpack.h │ │ │ └── unpack.sycl │ │ ├── qbytes_mm.py │ │ ├── quantize.py │ │ └── unpack.py │ ├── models/ │ │ ├── __init__.py │ │ ├── diffusers_models.py │ │ ├── shared_dict.py │ │ └── transformers_models.py │ ├── nn/ │ │ ├── __init__.py │ │ ├── qconv2d.py │ │ ├── qlayernorm.py │ │ ├── qlinear.py │ │ └── qmodule.py │ ├── quantize.py │ ├── subpackage/ │ │ ├── __init__.py │ │ └── commands/ │ │ ├── __init__.py │ │ ├── base.py │ │ └── quantize.py │ └── tensor/ │ ├── __init__.py │ ├── activations/ │ │ ├── __init__.py │ │ ├── qbytes.py │ │ ├── qbytes_ops.py │ │ └── quantization.py │ ├── core.py │ ├── function.py │ ├── grouped.py │ ├── optimizers/ │ │ ├── __init__.py │ │ ├── absmax_optimizer.py │ │ ├── affine_optimizer.py │ 
│ ├── hqq_optimizer.py │ │ ├── max_optimizer.py │ │ ├── optimizer.py │ │ └── symmetric_optimizer.py │ ├── packed.py │ ├── qbits.py │ ├── qbytes.py │ ├── qtensor.py │ ├── qtype.py │ └── weights/ │ ├── __init__.py │ ├── awq/ │ │ ├── __init__.py │ │ ├── packed.py │ │ └── qbits.py │ ├── marlin/ │ │ ├── __init__.py │ │ ├── fp8/ │ │ │ ├── __init__.py │ │ │ ├── packed.py │ │ │ └── qbits.py │ │ ├── int4/ │ │ │ ├── __init__.py │ │ │ ├── packed.py │ │ │ └── qbits.py │ │ └── permutations.py │ ├── packing.py │ ├── qbits.py │ ├── qbytes.py │ ├── quantization.py │ ├── reordering.py │ └── tinygemm/ │ ├── __init__.py │ ├── packed.py │ └── qbits.py ├── pyproject.toml ├── setup.sh └── tests/ ├── cli/ │ ├── cli_helpers.py │ └── test_quantize_cli.py ├── conftest.py ├── helpers.py ├── library/ │ ├── test_extensions.py │ ├── test_mm.py │ ├── test_quantize.py │ └── test_unpack.py ├── models/ │ ├── conftest.py │ ├── test_quantized_model_for_causal_lm.py │ └── test_quantized_model_for_pixart.py ├── nn/ │ ├── test_calibrate.py │ ├── test_qattention.py │ ├── test_qconv2d.py │ ├── test_qlayernorm.py │ ├── test_qlinear.py │ └── test_qmodule.py ├── quantize/ │ ├── test_quantize_mlp.py │ ├── test_quantize_patterns.py │ └── test_requantize.py └── tensor/ ├── activations/ │ ├── test_activations_compile.py │ ├── test_activations_dispatch.py │ └── test_activations_quantize.py ├── ops/ │ ├── test_linear_dispatch.py │ └── test_mm_dispatch.py ├── optimizers/ │ └── test_hqq_optimizer.py ├── test_absmax.py ├── test_packed_tensor.py └── weights/ ├── optimized/ │ ├── test_awq_packed_tensor.py │ ├── test_awq_weight_qbits_tensor.py │ ├── test_marlin_fp8_packed_tensor.py │ ├── test_marlin_int4_packed_tensor.py │ ├── test_marlin_int4_weight_qbits_tensor.py │ ├── test_marlin_qbytes_tensor.py │ ├── test_tinygemm_packed_tensor.py │ └── test_tinygemm_weight_qbits_tensor.py ├── test_weight_qbits_tensor.py ├── test_weight_qbits_tensor_dispatch.py ├── test_weight_qbits_tensor_instantiate.py ├── 
test_weight_qbits_tensor_quantize.py ├── test_weight_qbytes_tensor_backward.py ├── test_weight_qbytes_tensor_dispatch.py ├── test_weight_qbytes_tensor_instantiate.py ├── test_weight_qbytes_tensor_quantize.py ├── test_weight_qbytes_tensor_serialization.py └── weight_helpers.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/CODEOWNERS ================================================ * @dacorvo @sunmarc ================================================ FILE: .github/PULL_REQUEST_TEMPLATE.md ================================================ # What does this PR do? Fixes # (issue) ## Before submitting - [ ] Did you read the [contributor guideline](https://github.com/huggingface/optimum-quanto/blob/main/CONTRIBUTING.md#create-a-pull-request), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you run all tests locally and make sure they pass. - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. 
================================================ FILE: .github/workflows/check-commits.yml ================================================ name: Check Commits on: [workflow_call] jobs: build: name: Check commits runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - uses: huggingface/action-check-commits@v1.0.0 with: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} max-commits: "10" min-words: "3" forbidden-words: "fixup" ================================================ FILE: .github/workflows/linux-cpu-tests.yml ================================================ name: Linux CPU tests on: push: branches: - main paths: - "optimum/quanto/**" - "tests/**" - "pyproject.toml" pull_request: types: [assigned, opened, synchronize, reopened] paths: - "optimum/quanto/**" - "tests/**" - "pyproject.toml" jobs: check-commits: uses: ./.github/workflows/check-commits.yml python-quality: uses: ./.github/workflows/python-quality.yml test-ubuntu-cpu: needs: [check-commits, python-quality] runs-on: ubuntu-latest strategy: fail-fast: false matrix: python-version: ["3.9", "3.11"] steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@e9aba2c848f5ebd159c070c61ea2c4e2b122355e # v2 with: python-version: ${{ matrix.python-version }} - name: Build and install quanto run: | pip install --upgrade pip pip install -e .[dev] - name: Run base tests run: | python -m pytest tests --ignore=tests/models --ignore=tests/cli - name: Run models tests run: | pip install accelerate transformers diffusers python -m pytest tests/models - name: Run CLI tests run: | pip install optimum python -m pytest tests/cli run_staging_tests: needs: [check-commits, python-quality] runs-on: ubuntu-latest strategy: fail-fast: false matrix: python-version: ["3.9", "3.11"] steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Set up Python ${{ matrix.python-version }} uses: 
actions/setup-python@e9aba2c848f5ebd159c070c61ea2c4e2b122355e # v2 with: python-version: ${{ matrix.python-version }} - name: Build and install quanto run: | pip install --upgrade pip pip install -e .[dev] - name: Run models hub tests run: | pip install accelerate transformers diffusers HUGGINGFACE_CO_STAGING=true python -m pytest tests/models -k "hub" ================================================ FILE: .github/workflows/linux-cuda-tests.yml ================================================ name: Linux CUDA tests on: push: branches: - main paths: - "optimum/quanto/**" - "tests/**" - "pyproject.toml" pull_request: types: [assigned, opened, synchronize, reopened] paths: - "optimum/quanto/**" - "tests/**" - "pyproject.toml" jobs: check-commits: uses: ./.github/workflows/check-commits.yml python-quality: uses: ./.github/workflows/python-quality.yml test-ubuntu-cuda: needs: [check-commits, python-quality] runs-on: group: aws-g5-4xlarge-plus strategy: fail-fast: false matrix: cuda-version: ["11.8", "12.4", "12.6"] container: image: pytorch/pytorch:2.6.0-cuda${{ matrix.cuda-version }}-cudnn9-devel options: --gpus 0 steps: - uses: actions/checkout@v2 - name: Check CUDA installation run: | nvcc -V - name: Build and install quanto run: | pip install --upgrade pip pip install -e .[dev] - name: Run base tests run: | python -m pytest tests --ignore=tests/models --ignore=tests/cli - name: Run models tests run: | pip install accelerate transformers diffusers python -m pytest tests/models - name: Run CLI tests run: | pip install optimum python -m pytest tests/cli ================================================ FILE: .github/workflows/linux-examples.yml ================================================ name: Linux examples (CPU, CUDA) on: push: branches: - main paths: - "optimum/quanto/**" - "examples/**" - "pyproject.toml" pull_request: types: [assigned, opened, synchronize, reopened] paths: - "optimum/quanto/**" - "examples/**" - "pyproject.toml" jobs: check-commits: uses: 
./.github/workflows/check-commits.yml python-quality: uses: ./.github/workflows/python-quality.yml run-examples: needs: [check-commits, python-quality] runs-on: group: aws-g5-4xlarge-plus strategy: fail-fast: false matrix: device: ["cpu", "cuda"] container: image: pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel options: --gpus 0 steps: - uses: actions/checkout@v2 - name: Check CUDA installation run: | nvcc -V - name: Build and install packages run: | pip install --upgrade pip pip install -e .[examples] # Run examples - name: Run MNIST classification example run: | for w in int4 int8 float8; do \ for a in none int8 float8; do \ python examples/vision/image-classification/mnist/quantize_mnist_model.py \ --weights $w --activations $a --device ${{ matrix.device }}; \ done; \ done - name: Run OWL detection example run: | for w in int4 int8 float8; do \ python examples/vision/object-detection/quantize_owl_model.py \ --image http://images.cocodataset.org/val2017/000000039769.jpg \ --texts "a photo of a cat" "a remote" \ --weights $w --device ${{ matrix.device }}; \ done - name: Run text-classification example run: | for w in int4 int8; do \ for a in none int8; do \ python examples/nlp/text-classification/sst2/quantize_sst2_model.py \ --weights $w --activations $a --device ${{ matrix.device }}; \ done; \ done - name: Run text-to-image example if: ${{ matrix.device == 'cuda'}} run: | for w in int4 int8 fp8; do \ python examples/vision/text-to-image/quantize_pixart_sigma.py \ --qtype $w --device ${{ matrix.device }}; \ done ================================================ FILE: .github/workflows/python-quality.yml ================================================ name: Python code quality on: [workflow_call] jobs: check_code_quality: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - name: Set up Python uses: actions/setup-python@v2 with: python-version: 3.9 - name: Install dependencies run: | pip install --upgrade pip pip install .[dev] - run: ruff format bench examples 
optimum tests --diff - run: ruff check --show-fixes bench examples optimum tests ================================================ FILE: .github/workflows/security.yml ================================================ name: Security Checks on: push: permissions: contents: read jobs: secrets: runs-on: ubuntu-latest steps: - shell: bash env: REF_NAME: ${{ github.ref_name }} HEAD_REF: ${{ github.event.pull_request.head.ref }} run: | if [ "${{ github.event_name }}" == "push" ]; then echo "depth=$(($(jq length <<< '${{ toJson(github.event.commits) }}') + 2))" >> $GITHUB_ENV echo "branch=$REF_NAME" >> $GITHUB_ENV fi if [ "${{ github.event_name }}" == "pull_request" ]; then echo "depth=$((${{ github.event.pull_request.commits }}+2))" >> $GITHUB_ENV echo "branch=$HEAD_REF" >> $GITHUB_ENV fi - name: Checkout code uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{env.branch}} fetch-depth: ${{env.depth}} - name: Scan for secrets uses: trufflesecurity/trufflehog@6bd2d14f7a4bc1e569fa3550efa7ec632a4fa67b # main ================================================ FILE: .github/workflows/stale.yml ================================================ name: 'Close stale issues and PRs' on: schedule: - cron: '30 1 * * *' workflow_dispatch: permissions: issues: write pull-requests: write jobs: stale: runs-on: ubuntu-latest steps: - uses: actions/stale@v9 with: stale-issue-message: 'This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.' stale-pr-message: 'This PR is stale because it has been open 15 days with no activity. Remove stale label or comment or this will be closed in 5 days.' close-issue-message: 'This issue was closed because it has been stalled for 5 days with no activity.' close-pr-message: 'This PR was closed because it has been stalled for 5 days with no activity.' 
days-before-issue-stale: 30 days-before-pr-stale: 15 days-before-issue-close: 5 days-before-pr-close: 5 ================================================ FILE: .gitignore ================================================ __pycache__ .pytest_cache *.egg-info dist .venv build/ ================================================ FILE: CONTRIBUTING.md ================================================ # Contribute to optimum-quanto Everyone is welcome to contribute, and we value everybody's contribution. Code contributions are not the only way to help the community. Answering questions, helping others, and improving the documentation are also immensely valuable. It also helps us if you spread the word! Reference the library in blog posts about the awesome projects it made possible, shout out on Twitter every time it has helped you, or simply ⭐️ the repository to say thank you. However you choose to contribute, please be mindful and respect our [code of conduct](https://github.com/huggingface/transformers/blob/main/CODE_OF_CONDUCT.md). **This guide is directly inspired by [transformers guide to contributing](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md).** ## Ways to contribute There are several ways you can contribute: * Fix outstanding issues with the existing code. * Submit issues related to bugs or desired new features. * Implement new kernels. > All contributions are equally valuable to the community. 🥰 ## Fixing outstanding issues If you notice an issue with the existing code and have a fix in mind, feel free to [start contributing](https://github.com/huggingface/optimum-quanto/blob/main/CONTRIBUTING.md/#create-a-pull-request) and open a Pull Request! ## Submitting a bug-related issue or feature request Do your best to follow these guidelines when submitting a bug-related issue or a feature request. It will make it easier for us to come back to you quickly and with good feedback. ### Did you find a bug? 
The `optimum-quanto` backend will become more robust and reliable thanks to users who will report the problems they encounter. Before you report an issue, we would really appreciate it if you could **make sure the bug was not already reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the library itself, and not your code. If you're unsure whether the bug is in your code or the library, please ask in the [forum](https://discuss.huggingface.co/) first. This helps us respond quicker to fixing issues related to the library versus general questions. Once you've confirmed the bug hasn't already been reported, please include the following information in your issue so we can quickly resolve it: * Your **OS type and version** and **Python** and **PyTorch** versions. * A short, self-contained, code snippet that allows us to reproduce the bug in less than 30s. * The *full* traceback if an exception is raised. * Attach any other additional information, like screenshots, you think may help. ### Do you want a new feature? If there is a new feature you'd like to see, please open an issue and describe: 1. What is the *motivation* behind this feature? Is it related to a problem or frustration with the library? Is it a feature related to something you need for a project? Is it something you worked on and think it could benefit the community? Whatever it is, we'd love to hear about it! 2. Describe your requested feature in as much detail as possible. The more you can tell us about it, the better we'll be able to help you. 3. Provide a *code snippet* that demonstrates the features usage. 4. If the feature is related to a paper, please include a link. If your issue is well written we're already 80% of the way there by the time you create it. ## Do you want to implement a new kernel? With the constant evolution of hardware backends, there is always a need for updating the kernels for better performance. 
* The hardware configuration(s) it will apply to. * If any, a short description of the novel techniques that should be used to implement the kernel. If you are willing to contribute the kernel yourself, let us know so we can help you add it to `optimum-quanto`! ## Create a Pull Request Before writing any code, we strongly advise you to search through the existing PRs or issues to make sure nobody is already working on the same thing. If you are unsure, it is always a good idea to open an issue to get some feedback. You will need basic `git` proficiency to contribute. While `git` is not the easiest tool to use, it has the greatest manual. Type `git --help` in a shell and enjoy! If you prefer books, [Pro Git](https://git-scm.com/book/en/v2) is a very good reference. You'll need **Python 3.8** or above to contribute. Follow the steps below to start contributing: 1. Fork the [repository](https://github.com/huggingface/optimum-quanto) by clicking on the **[Fork](https://github.com/huggingface/optimum-quanto/fork)** button on the repository's page. This creates a copy of the code under your GitHub user account. 2. Clone your fork to your local disk, and add the base repository as a remote: ```bash git clone git@github.com:/optimum-quanto.git cd optimum-quanto git remote add upstream https://github.com/huggingface/optimum-quanto.git ``` 3. Create a new branch to hold your development changes: ```bash git checkout -b a-descriptive-name-for-my-changes ``` 🚨 **Do not** work on the `main` branch! 4. Set up a development environment by running the following command in a virtual environment: ```bash pip install -e ".[dev]" ``` If `optimum-quanto` was already installed in the virtual environment, remove it with `pip uninstall optimum-quanto` before reinstalling it in editable mode with the `-e` flag. 5. Develop the features in your branch. As you work on your code, you should make sure the test suite passes. 
Run the tests impacted by your changes like this: ```bash pytest tests/<test_file>.py ```
Now you can go to your fork of the repository on GitHub and click on **Pull Request** to open a pull request. Make sure you tick off all the boxes on our [checklist](https://github.com/huggingface/optimum-quanto/blob/main/CONTRIBUTING.md/#pull-request-checklist) below. When you're ready, you can send your changes to the project maintainers for review. 7. It's ok if maintainers request changes, it happens to our core contributors too! So everyone can see the changes in the pull request, work in your local branch and push the changes to your fork. They will automatically appear in the pull request. ### Pull request checklist ☐ The pull request title should summarize your contribution.
☐ If your pull request addresses an issue, please mention the issue number in the pull request description to make sure they are linked (and people viewing the issue know you are working on it).
☐ To indicate a work in progress please prefix the title with `[WIP]`. These are useful to avoid duplicated work, and to differentiate it from PRs ready to be merged.
☐ Make sure existing tests pass.
☐ If adding a new feature, also add tests for it.
☐ All public methods must have informative docstrings.
### Tests An extensive test suite is included to test the library behavior in the [tests](https://github.com/huggingface/optimum-quanto/tree/main/tests) folder. From the root of the repository, specify a *path to a subfolder or a test file* to run the test. ```bash python -m pytest -sv ./tests//.py ``` You can run all tests by typing: ```bash make test ``` ### Style guide For documentation strings, `optimum-quanto` follows the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html). Check `transformers` [documentation writing guide](https://github.com/huggingface/transformers/tree/main/docs#writing-documentation---specification) for more information. ================================================ FILE: LICENSE ================================================ Copyright 2023 - The Hugging Face team. All rights reserved. Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. 
Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: Makefile ================================================ .PHONY: check test style check_dirs := optimum tests bench examples check: ruff check --show-fixes ${check_dirs} ruff format ${check_dirs} --diff style: ruff check ${check_dirs} --fix ruff format ${check_dirs} test: python -m pytest -sv tests ================================================ FILE: README.md ================================================ # Optimum Quanto > This project is currently in maintenance mode. We accept pull requests only for minor bug fixes, documentation improvements, and other maintenance tasks. Major new features or breaking changes are unlikely to be merged. For production-ready quantization features or active development, consider alternative projects such as [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) or [torchAO](https://github.com/pytorch/ao). 🤗 Optimum Quanto is a pytorch quantization backend for [optimum](https://huggingface.co/docs/optimum/en/index). 
It has been designed with versatility and simplicity in mind: - all features are available in eager mode (works with non-traceable models), - quantized models can be placed on any device (including CUDA and MPS), - automatically inserts quantization and dequantization stubs, - automatically inserts quantized functional operations, - automatically inserts quantized modules (see below the list of supported modules), - provides a seamless workflow from a float model to a dynamic to a static quantized model, - serialization compatible with pytorch `weight_only` and 🤗 `safetensors`, - accelerated matrix multiplications on CUDA devices (int8-int8, fp16-int4, bf16-int8, bf16-int4), - supports int2, int4, int8 and float8 weights, - supports int8 and float8 activations. Features yet to be implemented: - dynamic activations smoothing, - kernels for all mixed matrix multiplications on all devices, - compatibility with [torch compiler](https://pytorch.org/docs/stable/torch.compiler.html) (aka dynamo). ## Performances In a nutshell: - accuracy: models compiled with `int8`/`float8` weights and `float8` activations are very close to the full-precision models, - latency: whenever optimized kernels are available, the inference of quantized model is comparable with the full-precision models when quantizing only the model weights, - device memory: approximately divided by float bits / integer bits. The paragraph below is just an example. Please refer to the `bench` folder for detailed results per use-case of model. ### meta-llama/Meta-Llama-3.1-8B
meta-llama/Meta-Llama-3.1-8B WikiText perplexity
meta-llama/Meta-Llama-3.1-8B Latency
## Installation Optimum Quanto is available as a pip package. ```sh pip install optimum-quanto ``` ## Quantization workflow for Hugging Face models `optimum-quanto` provides helper classes to quantize, save and reload Hugging Face quantized models. ### LLM models The first step is to quantize the model ```python from transformers import AutoModelForCausalLM from optimum.quanto import QuantizedModelForCausalLM, qint4 model = AutoModelForCausalLM.from_pretrained('meta-llama/Meta-Llama-3-8B') qmodel = QuantizedModelForCausalLM.quantize(model, weights=qint4, exclude='lm_head') ``` Note: the model quantized weights will be frozen. If you want to keep them unfrozen to train them you need to use `optimum.quanto.quantize` directly. The quantized model can be saved using `save_pretrained`: ```python qmodel.save_pretrained('./Llama-3-8B-quantized') ``` It can later be reloaded using `from_pretrained`: ```python from optimum.quanto import QuantizedModelForCausalLM qmodel = QuantizedModelForCausalLM.from_pretrained('Llama-3-8B-quantized') ``` ### Diffusers models You can quantize any of the submodels inside a diffusers pipeline and seamlessly include them later in another pipeline. Here we quantize the `transformer` of a `Pixart` pipeline. 
```python from diffusers import PixArtTransformer2DModel from optimum.quanto import QuantizedPixArtTransformer2DModel, qfloat8 model = PixArtTransformer2DModel.from_pretrained("PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", subfolder="transformer") qmodel = QuantizedPixArtTransformer2DModel.quantize(model, weights=qfloat8) qmodel.save_pretrained("./pixart-sigma-fp8") ``` Later, we can reload the quantized model and recreate the pipeline: ```python from diffusers import PixArtTransformer2DModel from optimum.quanto import QuantizedPixArtTransformer2DModel transformer = QuantizedPixArtTransformer2DModel.from_pretrained("./pixart-sigma-fp8") transformer.to(device="cuda") pipe = PixArtSigmaPipeline.from_pretrained( "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", transformer=None, torch_dtype=torch.float16, ).to("cuda") pipe.transformer = transformer ``` ## Quantization workflow for vanilla pytorch models (low-level API) One thing to keep in mind when using the low-level quanto API is that by default models weights are dynamically quantized: an explicit call must be made to 'freeze' the quantized weights. A typical quantization workflow would consist of the following steps: **1. Quantize** The first step converts a standard float model into a dynamically quantized model. ```python from optimum.quanto import quantize, qint8 quantize(model, weights=qint8, activations=qint8) ``` At this stage, only the inference of the model is modified to dynamically quantize the weights. **2. Calibrate (optional if activations are not quantized)** Quanto supports a calibration mode that allows to record the activation ranges while passing representative samples through the quantized model. ```python from optimum.quanto import Calibration with Calibration(momentum=0.9): model(samples) ``` This automatically activates the quantization of the activations in the quantized modules. **3. 
Tune, aka Quantization-Aware-Training (optional)** If the performance of the model degrades too much, one can tune it for a few epochs to recover the float model performance. ```python import torch model.train() for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(device), target.to(device) optimizer.zero_grad() output = model(data).dequantize() loss = torch.nn.functional.nll_loss(output, target) loss.backward() optimizer.step() ``` **4. Freeze integer weights** When freezing a model, its float weights are replaced by quantized integer weights. ```python from optimum.quanto import freeze freeze(model) ``` **5. Serialize quantized model** Quantized models weights can be serialized to a `state_dict`, and saved to a file. Both `pickle` and `safetensors` (recommended) are supported. ```python from safetensors.torch import save_file save_file(model.state_dict(), 'model.safetensors') ``` In order to be able to reload these weights, you also need to store the quantized model quantization map. ```python import json from optimum.quanto import quantization_map with open('quantization_map.json', 'w') as f: json.dump(quantization_map(model), f) ``` **5. Reload a quantized model** A serialized quantized model can be reloaded from a `state_dict` and a `quantization_map` using the `requantize` helper. Note that you need first to instantiate an empty model. ```python import json from safetensors.torch import load_file from optimum.quanto import requantize state_dict = load_file('model.safetensors') with open('quantization_map.json', 'r') as f: quantization_map = json.load(f) # Create an empty model from your modeling code and requantize it with torch.device('meta'): new_model = ... requantize(new_model, state_dict, quantization_map, device=torch.device('cuda')) ``` Please refer to the [examples](https://github.com/huggingface/quanto/tree/main/examples) for instantiations of that workflow. 
## Design overview ### Tensors At the heart of quanto is a Tensor subclass that corresponds to: - the projection of a source Tensor into the optimal range for a given destination type, - the mapping of projected values to the destination type. For floating-point destination types, the mapping is done by the native pytorch cast (i.e. `Tensor.to()`). For integer destination types, the mapping is a simple rounding operation (i.e. `torch.round()`). The goal of the projection is to increase the accuracy of the conversion by minimizing the number of: - saturated values (i.e. mapped to the destination type min/max), - zeroed values (because they are below the smallest number that can be represented by the destination type) The projection is symmetric per-tensor or per-channel for `int8` and `float8`, and group-wise affine (with a shift or 'zero-point') for lower bitwidth. One of the benefits of using a lower-bitwidth representation is that you will be able to take advantage of accelerated operations for the destination type, which is typically faster than their higher precision equivalents. Quanto does not support the conversion of a Tensor using mixed destination types. ### Modules Quanto provides a generic mechanism to replace `torch` modules by `optimum-quanto` modules that are able to process quanto tensors. `optimum-quanto` modules dynamically convert their weights until a model is frozen, which slows down inference a bit but is required if the model needs to be tuned. Weights are usually quantized per-channel along the first dimension (output features). Biases are not converted to preserve the accuracy of a typical `addmm` operation. Explanation: to be consistent with the unquantized arithmetic operations, biases would need to be quantized with a scale that is equal to the product of the input and weight scales, which leads to a ridiculously small scale, and conversely requires a very high bitwidth to avoid clipping. 
Typically, with `int8` inputs and weights, biases would need to be quantized with at least `12` bits, i.e. in `int16`. Since most biases are today `float16`, this is a waste of time. Activations are dynamically quantized per-tensor using static scales (defaults to the range `[-1, 1]`). To preserve accuracy, the model needs to be calibrated to evaluate the best activation scales (using a momentum). The following modules can be quantized: - [Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) (QLinear). Weights are always quantized, and biases are not quantized. Inputs and outputs can be quantized. - [Conv2d](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html) (QConv2D). Weights are always quantized, and biases are not quantized. Inputs and outputs can be quantized. - [LayerNorm](https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html), Weights and biases are __not__ quantized. Outputs can be quantized. ## Pitfalls to avoid when quantizing activations Activations are always quantized per-tensor because most linear algebra operations in a model graph are not compatible with per-axis inputs: you simply cannot add numbers that are not expressed in the same base (`you cannot add apples and oranges`). Weights involved in matrix multiplications are, on the contrary, always quantized along their first axis, because all output features are evaluated independently from one another. The outputs of a quantized matrix multiplication will anyway always be dequantized, even if activations are quantized, because: - the resulting accumulated values are expressed with a much higher bitwidth (typically `int32` or `float32`) than the activation bitwidth (typically `int8` or `float8`), - they might be combined with a `float` bias. Quantizing activations per-tensor to `int8` can lead to serious quantization errors if the corresponding tensors contain large outlier values. 
Typically, this will lead to quantized tensors with most values set to zero (except the outliers). A possible solution to work around that issue is to 'smooth' the activations statically as illustrated by [SmoothQuant](https://github.com/mit-han-lab/smoothquant). You can find a script to smooth some model architectures under [external/smoothquant](external/smoothquant). A better option is to represent activations using `float8`. ================================================ FILE: bench/generation/README.md ================================================ # Quanto generation benchmark This repository contains scripts to evaluate the performances of quantized models using three metrics: - `latency.py` evaluates the latency per generated token, - `prediction.py` evaluates the accuracy when predicting the last token of prompts from the [Lambada dataset](https://huggingface.co/datasets/lambada), - `perplexity.py` evaluates the perplexity of the model on the [WikiText dataset](https://huggingface.co/datasets/wikitext), as defined in the [transformers documentation](https://huggingface.co/docs/transformers/en/perplexity). A `evaluate_model.py` utility script is also provided to evaluate the metrics on a specific model for several quantization configurations, and output the result to a `png` barchart and/or a `json` file. Note: the language modeling head (lm_head) of the tested models is not quantized. The paragraphs below display results for some popular models on a NVIDIA A10 GPU. ## meta-llama/Meta-Llama-3.1-8B
meta-llama/Meta-Llama-3.1-8B Lambada prediction accuracy
meta-llama/Meta-Llama-3.1-8B WikiText perplexity
meta-llama/Meta-Llama-3.1-8B Latency
## mistralai/Mistral-7B-Instruct-v0.3
mistralai/Mistral-7B-Instruct-v0.3 Lambada prediction accuracy
mistralai/Mistral-7B-Instruct-v0.3 WikiText perplexity
mistralai/Mistral-7B-Instruct-v0.3 Latency
## google/gemma-2b
google-gemma-2b Lambada prediction accuracy
google-gemma-2b WikiText perplexity
google-gemma-2b Latency
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json

import torch
from evaluate_model import evaluate
from gen_barchart import gen_barchart
from transformers import AutoConfig

from optimum.quanto import qtype


def evaluate_model_configurations(
    model_id: str, metric: str, device: torch.device, batch_size: int = 32, dtype: torch.dtype = torch.float16
):
    """Evaluate `metric` for `model_id` over a grid of quanto weight/activation configurations.

    The unquantized model is evaluated first as a reference, then every combination of
    int4/int8/float8 weights with none/float8 activations.

    Returns:
        dict mapping a configuration name (e.g. "Wi4Af8") to the metric value.
    """
    weights = [
        "int4",
        "int8",
        "float8",
    ]
    activations = [
        "none",
        "float8",
    ]

    def short_name(name: str) -> str:
        # FIX: the parameter is a qtype *name* (a str such as "int4"), not a qtype
        # instance: annotate it as str and stop shadowing the imported `qtype` symbol.
        return {
            "none": "f16" if dtype == torch.float16 else "bf16",
            "int4": "i4",
            "int8": "i8",
            "float8": "f8",
        }[name]

    results = {}

    # Evaluate the unquantized float16/bfloat16 model first as the reference.
    config_name = f"W{short_name('none')}A{short_name('none')}"
    print(f"{model_id}[{config_name}]:")
    results[config_name] = evaluate(model_id, metric, "quanto", "none", "none", batch_size, device, dtype)
    # Evaluate all quantized configurations.
    for w in weights:
        for a in activations:
            config_name = f"W{short_name(w)}A{short_name(a)}"
            print(f"{model_id}[{config_name}]:")
            results[config_name] = evaluate(model_id, metric, "quanto", w, a, batch_size, device, dtype)

    return results


def main():
    """Command-line entry point: evaluate one model and optionally dump json/png results."""
    parser = argparse.ArgumentParser(description="Evaluate quantized model predictions on Lambada Dataset")
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument(
        "--model",
        type=str,
        default="facebook/opt-350m",
        help="The name of the trained Model.",
    )
    parser.add_argument("--device", type=str, default=None, help="The device to use for generation.")
    parser.add_argument("--metric", type=str, default="prediction", choices=["latency", "prediction", "perplexity"])
    parser.add_argument("--batch_size", type=int, default=32, help="The batch size during evaluation.")
    parser.add_argument("--dtype", type=str, help="Use the following dtype to load the model.")
    parser.add_argument("--json", action="store_true", help="Dump the results to a json file.")
    parser.add_argument("--png", action="store_true", help="Generate a PNG.")
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    if args.device is None:
        # Pick the best available accelerator, falling back to CPU.
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        elif torch.xpu.is_available():
            device = torch.device("xpu")
        else:
            device = torch.device("cpu")
    else:
        device = torch.device(args.device)

    if args.dtype is None:
        config = AutoConfig.from_pretrained(args.model)
        # FIX: config.torch_dtype can be present but explicitly None, in which case
        # the previous getattr() default was bypassed: fall back to float16 too.
        dtype = getattr(config, "torch_dtype", None) or torch.float16
    else:
        dtype = torch.float16 if args.dtype == "fp16" else torch.bfloat16

    results = evaluate_model_configurations(args.model, args.metric, device, batch_size=args.batch_size, dtype=dtype)

    if args.json:
        model_name = args.model.split("/")[-1]
        json_path = f"{model_name}-{args.metric}.json"
        with open(json_path, "w") as fp:
            json.dump({model_name: results}, fp, indent=4)
    if args.png:
        if args.metric == "latency":
            title = f"{args.model}: Mean latency per token"
            label = "Latency (ms)"
        elif args.metric == "prediction":
            title = f"{args.model}: Prediction accuracy on Lambada dataset"
            label = "Accuracy"
        elif args.metric == "perplexity":
            title = f"{args.model}: Perplexity evaluated on WikiText dataset"
            label = "Perplexity"
        gen_barchart(args.model, title, label, results, dtype)


if __name__ == "__main__":
    main()
#!/bin/bash
# Evaluate prediction, perplexity and latency metrics for a fixed set of models
# by delegating to evaluate_configurations.py located next to this script.

# Absolute path to this script, e.g. /home/user/bin/foo.sh
SCRIPT=$(readlink -f "$0")
# Absolute path this script is in, thus /home/user/bin
SCRIPT_PATH=$(dirname "$SCRIPT")

models=(
    google/gemma-2b
    meta-llama/Meta-Llama-3.1-8B
    mistralai/Mistral-7B-Instruct-v0.3
)

# FIX: quote all expansions (ShellCheck SC2068/SC2086) so model ids and paths
# containing special characters are passed as single arguments.
for m in "${models[@]}"; do
    for metric in prediction perplexity latency; do
        python "${SCRIPT_PATH}/evaluate_configurations.py" --model "$m" --metric "$metric" --png --json --batch_size 16
    done
done
import argparse
import importlib

import torch
from datasets import load_dataset
from metrics.latency import latency
from metrics.perplexity import perplexity
from metrics.prediction import prediction_accuracy

# Optional quantizer backends: only import a setup helper if its package is installed.
if importlib.util.find_spec("awq") is not None:
    from setup.awq import setup as awq_setup
if importlib.util.find_spec("bitsandbytes") is not None:
    from setup.bnb import setup as bnb_setup
if importlib.util.find_spec("hqq") is not None:
    from setup.hqq import setup as hqq_setup
from setup.quanto import setup as quanto_setup
from transformers import AutoConfig


@torch.no_grad()
def calibrate(model, tokenizer, batch_size, batches):
    """Feed `batches` batches of the lambada validation split through `model`.

    Used to record activation ranges; stops once batch_size * batches samples
    have been processed.
    """
    samples = batch_size * batches
    cal_dataset = load_dataset("lambada", split=["validation"])[0]
    model.eval()
    total = 0
    for batch in cal_dataset.iter(batch_size=batch_size):
        inputs = tokenizer(batch["text"], return_tensors="pt", padding=True)
        input_ids = inputs.input_ids.to(model.device)
        attention_mask = inputs.attention_mask.to(model.device)
        model(input_ids, attention_mask=attention_mask)
        total += input_ids.size(0)
        if total >= samples:
            break


def evaluate(
    model_id: str,
    metric: str,
    quantizer: str,
    weights: str,
    activations: str,
    batch_size: int,
    device: torch.device,
    dtype: torch.dtype = None,
):
    """Quantize `model_id` with the selected `quantizer` and evaluate `metric`.

    Args:
        metric: one of "latency", "prediction", "perplexity".
        quantizer: one of "quanto", "awq", "bnb", "hqq".
        weights / activations: qtype names ("none", "int4", ...).

    Returns:
        The metric value.

    Raises:
        ValueError: if `quantizer` or `metric` is not supported.
    """
    if quantizer == "quanto":
        if dtype is None:
            config = AutoConfig.from_pretrained(model_id)
            dtype = getattr(config, "torch_dtype", torch.float16)
        model, tokenizer = quanto_setup(model_id, weights, activations, batch_size, device, dtype)
    elif quantizer == "awq":
        model, tokenizer = awq_setup(model_id, weights, activations, group_size=128)
    elif quantizer == "bnb":
        model, tokenizer = bnb_setup(model_id, weights, activations, device)
    elif quantizer == "hqq":
        model, tokenizer = hqq_setup(model_id, weights, activations, device)
    else:
        raise ValueError(f"Unsupported quantizer {quantizer}")
    dtype = next(model.parameters()).dtype
    weights = dtype if weights == "none" else weights
    activations = dtype if activations == "none" else activations
    print(f"Evaluating {model_id} {metric} with {weights} weights and {activations} activations.")
    if metric == "latency":
        return latency(model, tokenizer, device, batch_size=1, prompt_length=512, nb_tokens=512, iterations=3)
    elif metric == "prediction":
        return prediction_accuracy(model, tokenizer, batch_size)
    elif metric == "perplexity":
        return perplexity(model, tokenizer)
    # FIX: an unknown metric previously fell through and silently returned None.
    raise ValueError(f"Unsupported metric {metric}")


def main():
    """Command-line entry point: evaluate a single (model, quantizer, metric) combination."""
    parser = argparse.ArgumentParser(description="Evaluate quantized model metrics")
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument(
        "--model",
        type=str,
        default="facebook/opt-350m",
        help="The name of the trained Model.",
    )
    parser.add_argument("--device", type=str, default=None, help="The device to use for generation.")
    parser.add_argument("--metric", type=str, default="prediction", choices=["latency", "prediction", "perplexity"])
    parser.add_argument("--quantizer", type=str, default="quanto", choices=["quanto", "awq", "bnb", "hqq"])
    parser.add_argument(
        "--weights",
        type=str,
        default="none",
        choices=["none", "int4", "int8", "float8"],
    )
    parser.add_argument(
        "--activations",
        type=str,
        default="none",
        choices=["none", "int8", "float8"],
    )
    parser.add_argument("--batch_size", type=int, default=32, help="The batch size during evaluation.")
    parser.add_argument(
        "--dtype",
        type=str,
        default="none",
        choices=["none", "fp16", "bf16"],
    )
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    if args.device is None:
        # Pick the best available accelerator, falling back to CPU.
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        elif torch.xpu.is_available():
            device = torch.device("xpu")
        else:
            device = torch.device("cpu")
    else:
        device = torch.device(args.device)
    dtype = {"none": None, "fp16": torch.float16, "bf16": torch.bfloat16}[args.dtype]
    evaluate(args.model, args.metric, args.quantizer, args.weights, args.activations, args.batch_size, device, dtype)


if __name__ == "__main__":
    main()
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json

import matplotlib.pyplot as plt
import numpy as np
import torch


def save_bar_chart(title, labels, ylabel, series, save_path):
    """Render a grouped bar chart and save it to `save_path`.

    Args:
        labels: one label per x-axis group.
        series: dict mapping a series name to a list of values (one per label).
    """
    x = np.arange(len(labels))  # the label locations
    width = 0.15  # the width of the bars
    multiplier = 0

    fig, ax = plt.subplots(layout="constrained")
    fig.set_figwidth(10)

    max_value = 0

    for attribute, measurement in series.items():
        max_value = max(max_value, max(measurement))
        offset = width * multiplier
        rects = ax.bar(x + offset, measurement, width, label=attribute)
        ax.bar_label(rects, padding=5)
        multiplier += 1

    # Add some text for labels, title and custom x-axis tick labels, etc.
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_xticks(x + width, labels)
    ax.legend(loc="upper left", ncols=4)
    # Leave headroom above the tallest bar for the value labels.
    ax.set_ylim(0, max_value * 1.2)

    plt.savefig(save_path)


def gen_barchart(model_id, title, label, results, dtype):
    """Generate a PNG bar chart from a results dict keyed by "W<w>A<a>" configuration names."""
    dtype_str = "f16" if dtype is torch.float16 else "bf16"
    activations = (dtype_str, "f8")
    weights = ("i4", "i8", "f8")
    series = {}
    # The unquantized model result is repeated for each activation group as a reference.
    reference = round(results[f"W{dtype_str}A{dtype_str}"], 2)
    series[f"Weights {dtype_str}"] = [
        reference,
    ] * len(activations)
    for w in weights:
        name = f"Weights {w}"
        series[name] = []
        for a in activations:
            result = results[f"W{w}A{a}"]
            series[name].append(round(result, 2))
    model_name = model_id.replace("/", "-")
    metric_name = label.replace(" ", "_").replace("(", "_").replace(")", "_")
    save_bar_chart(
        title=title,
        labels=[f"Activations {a}" for a in activations],
        series=series,
        ylabel=label,
        save_path=f"{model_name}_{dtype_str}_{metric_name}.png",
    )


def main():
    """Command-line entry point: regenerate charts from a saved benchmark json file."""
    parser = argparse.ArgumentParser()
    parser.add_argument("benchmark", type=str, help="A benchmark result file (.json).")
    parser.add_argument("--title", type=str, required=True, help="The graph title.")
    parser.add_argument("--label", type=str, required=True, help="The graph vertical label.")
    args = parser.parse_args()
    with open(args.benchmark) as f:
        benchmark = json.load(f)
    for model_id, results in benchmark.items():
        # FIX: gen_barchart() requires a dtype argument that this call used to omit,
        # raising a TypeError on every run. The json does not store the dtype, but it
        # can be recovered from the reference configuration key.
        dtype = torch.bfloat16 if "Wbf16Abf16" in results else torch.float16
        gen_barchart(model_id, args.title, args.label, results, dtype)


if __name__ == "__main__":
    main()
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: bench/generation/metrics/latency.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import gc
import time

import numpy as np
import torch
from tqdm.auto import tqdm
from transformers import GenerationConfig


def latency(model, tokenizer, device, batch_size=1, prompt_length=512, nb_tokens=512, iterations=10):
    """Measure the mean per-token generation latency (ms) of `model` on `device`.

    Runs `iterations` greedy generations of exactly `nb_tokens` tokens from a
    random prompt of `prompt_length` tokens and averages device-timed latencies.
    """

    def synchronize(device):
        # Block until all queued work on the target device has completed.
        if device.type == "cuda":
            torch.cuda.synchronize()
        elif device.type == "mps":
            torch.mps.synchronize()
        elif device.type == "xpu":
            torch.xpu.synchronize()
        else:
            torch.cpu.synchronize()

    def timing_event(device):
        # Return a device-appropriate event exposing record()/elapsed_time().
        if device.type == "cuda":
            return torch.cuda.Event(enable_timing=True)
        elif device.type == "mps":
            return torch.mps.Event(enable_timing=True)
        elif device.type == "xpu":
            return torch.xpu.Event(enable_timing=True)

        class CPUEvent:
            # Wall-clock fallback mimicking the CUDA event API.
            def __init__(self):
                self.time = None

            def record(self):
                self.time = time.time()

            def elapsed_time(self, other):
                assert self.time is not None
                assert other.time is not None
                return (other.time - self.time) * 1000

        return CPUEvent()

    generation_config = GenerationConfig(
        max_new_tokens=nb_tokens,
        min_new_tokens=nb_tokens,
        use_cache=True,
        pad_token_id=tokenizer.pad_token_id,
        num_beams=1,
        do_sample=False,
        eos_token_id=None,  # This is required for min_new_tokens to actually have an effect.
    )
    if getattr(model, "generation_config", None) is not None:
        # greedy_search falls back on this eos_token_id that we need to set to None
        # as well for min_new_tokens to have an effect.
        model.generation_config.eos_token_id = None

    synchronize(device)
    if device.type == "cuda":
        torch.cuda.reset_peak_memory_stats()
    elif device.type == "xpu":
        torch.xpu.reset_peak_memory_stats()

    memory = get_device_memory(device)
    if memory is not None:
        print(f"Device memory: {memory / (2**30):.4f} GB")

    latencies = []
    # Random token ids in [1, vocab_size - 1) avoid special token 0 and the last id.
    input_ids = torch.randint(1, model.config.vocab_size - 1, size=(batch_size, prompt_length)).to(device)
    masks = torch.ones(batch_size, prompt_length, dtype=torch.int32).to(device)

    for _ in tqdm(range(iterations)):
        start_event = timing_event(device)
        end_event = timing_event(device)
        synchronize(device)
        start_event.record()
        _ = model.generate(input_ids, attention_mask=masks, generation_config=generation_config)
        end_event.record()
        synchronize(device)
        latencies.append(start_event.elapsed_time(end_event))

    if device.type == "cuda":
        peak_memory = torch.cuda.max_memory_allocated()
        print(f"Peak memory during benchmark: {peak_memory / (2**30):.4f} GB")
    elif device.type == "xpu":
        peak_memory = torch.xpu.max_memory_allocated()
        print(f"Peak memory during benchmark: {peak_memory / (2**30):.4f} GB")

    # Exactly min_new_tokens tokens are generated per call (eos disabled above).
    mean_latency = np.mean(latencies) / generation_config.min_new_tokens
    print(f"Average latency per token: {mean_latency} ms")
    return mean_latency


def get_device_memory(device):
    """Return the currently allocated memory (bytes) on `device`, or None if unknown.

    Caches are emptied first so the figure reflects live allocations only.
    """
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()
        return torch.cuda.memory_allocated()
    elif device.type == "mps":
        torch.mps.empty_cache()
        return torch.mps.current_allocated_memory()
    elif device.type == "xpu":
        torch.xpu.empty_cache()
        return torch.xpu.memory_allocated()
    return None


# ================================================
# FILE: bench/generation/metrics/perplexity.py
# ================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys

import numpy as np
import torch
from datasets import load_dataset
from tqdm import tqdm


class Perplexity:
    """
    A class for calculating the perplexity of a language model.
    """

    def __init__(self, model, tokenizer, dataset_path="wikitext", dataset_name=None, split="test", text_column="text"):
        """
        Calculate perplexity using the same method as seen in llama.cpp.

        Parameters
        ----------
        model : AutoModelForCausalLM
            The language model for which the perplexity is calculated.
        tokenizer : AutoTokenizer
            The tokenizer corresponding to the model.
        dataset_path : str, optional
            The path to the dataset on the Hugging Face dataset hub. Default is 'wikitext'.
        dataset_name : str, optional
            The name of the dataset. Default is None.
        split : str, optional
            The split of the dataset to use. Default is 'test'.
        text_column : str, optional
            The name of the column in the dataset that contains the text data. Default is 'text'.
        """
        self._model = model
        self._tokenizer = tokenizer
        self._dataset_path = dataset_path
        self._dataset_name = dataset_name
        self._split = split
        self._text_column = text_column
        self._text = self._prepare_data()

    def _prepare_data(self):
        """Load the dataset and join its text column into a single string."""
        if self._dataset_path == "wikitext":
            # Default to the raw wikitext-2 configuration.
            self._dataset_name = "wikitext-2-raw-v1"
        data = load_dataset(self._dataset_path, self._dataset_name, split=self._split)
        # Empty lines become " \n" so paragraph breaks survive the join.
        text_list = [" \n" if s == "" else s for s in data[self._text_column]]
        return "".join(text_list)

    @staticmethod
    def softmax(logits):
        """Numerically-stable softmax of `logits` (max is subtracted before exp)."""
        e_x = np.exp(logits - np.max(logits))
        return e_x / e_x.sum(axis=0)

    def calculate_perplexity(self, n_ctx=512, n_batch=512):
        """
        Calculate the perplexity of the language model.

        Parameters
        ----------
        n_ctx : int
            The context size.
        n_batch : int
            The batch size.

        Returns
        -------
        list
            The running perplexity after each processed context window.
        """
        # Tokenize the whole corpus at once; lift the model max length check first.
        self._tokenizer.model_max_length = sys.maxsize
        tokens = self._tokenizer(self._text, truncation=False, return_tensors="pt").input_ids.to(self._model.device)

        nll = 0.0  # Running negative log likelihood
        count = 0  # Number of scored tokens
        curr_ppl = 0
        all_perplexity = []

        with tqdm(range(len(tokens[0]) // n_ctx), desc="Perplexity: - ") as progress:
            for i in progress:
                # Score one context window, then report the running perplexity.
                nll, count = self._process_batch(i, n_ctx, n_batch, tokens, nll, count)
                curr_ppl = np.exp(nll / count)
                all_perplexity.append(curr_ppl)
                progress.set_description(f"Perplexity: {curr_ppl:.4f}")
        return all_perplexity

    def _process_batch(self, i, n_ctx, n_batch, tokens, nll, count):
        """
        Score the i-th context window of `tokens`.

        Returns the updated (nll, count) accumulators.
        """
        start = i * n_ctx
        end = start + n_ctx
        num_batches = (n_ctx + n_batch - 1) // n_batch

        logits = []
        for j in range(num_batches):
            batch_start = start + j * n_batch
            batch_size = min(end - batch_start, n_batch)
            token_org = tokens[0][batch_start].item()
            if j == 0:
                # Replace the first token with the BOS token
                tokens[0][batch_start] = self._tokenizer.bos_token_id
            # Compute the logits for the current batch of tokens
            batch_logits = self._compute_batch_logits(tokens, batch_start, batch_size)
            tokens[0][batch_start] = token_org
            logits.append(batch_logits)

        # We rely on the fact that attention in the forward pass only looks at previous
        # tokens here, so the logits returned for each token are an accurate representation
        # of what the model would have predicted at that point.
        #
        # Example, we have a context window of 512, we will compute perplexity for each of the
        # last 256 tokens. Then, we split the input up into context window size chunks to
        # process the entire prompt.
        for j in range(min(512, n_ctx // 2), n_ctx - 1):
            tok_logits = logits[0][0][j].cpu().numpy()
            # Probability the model assigned to the actual next token.
            prob = self.softmax(tok_logits)[tokens[0][start + j + 1]]
            nll += -np.log(prob, where=prob > 0)
            count += 1
        return nll, count

    def _compute_batch_logits(self, tokens, batch_start, batch_size):
        """Run the model on tokens[batch_start : batch_start + batch_size] and return detached logits."""
        # Compute the logits without keeping track of gradients
        with torch.no_grad():
            outputs = self._model(tokens[:, batch_start : batch_start + batch_size])
        return outputs.logits.detach()


def perplexity(
    model,
    tokenizer,
    stride: int = 512,
):
    """Return the mean running perplexity of `model` over the default corpus."""
    print("Evaluating perplexity")
    ppl = Perplexity(model, tokenizer)
    return np.mean(ppl.calculate_perplexity(n_ctx=stride))


# ================================================
# FILE: bench/generation/metrics/prediction.py
# ================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time

import torch
from datasets import load_dataset


@torch.no_grad()
def prediction_accuracy(model, tokenizer, batch_size, samples=None):
    """Evaluate last-token prediction accuracy on the lambada test split.

    The task is to predict the last token of each input sequence; stops early
    once `samples` sequences have been evaluated (when provided).
    """
    test_dataset = load_dataset("lambada", split=["test"])[0]
    model.eval()
    total, hit = 0, 0
    start = time.time()
    for batch in test_dataset.iter(batch_size=batch_size):
        inputs = tokenizer(batch["text"], return_tensors="pt", padding=True)
        input_ids = inputs.input_ids.to(model.device)
        attention_mask = inputs.attention_mask.to(model.device)
        labels = input_ids[:, -1]
        # Pass only the first tokens; the model must predict the held-out last one.
        outputs = model(input_ids[:, :-1], attention_mask=attention_mask[:, :-1])
        preds = outputs.logits[:, -1, :].argmax(dim=-1)
        total += labels.size(0)
        hit += (preds == labels).sum().item()
        if samples is not None and total >= samples:
            break
    end = time.time()
    acc = hit / total
    print(f"{total} sequences evaluated in {end - start:.2f} s. accuracy = {acc:.2f}")
    return acc


# ================================================
# FILE: bench/generation/setup/__init__.py
# ================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# ================================================
# FILE: bench/generation/setup/awq.py
# ================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer


def prepare_inputs_for_generation(input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
    """Build the model inputs for one generation step (transformers 4.36.2 behaviour).

    Trims `input_ids` to the tokens not already covered by the KV cache, builds
    `position_ids` from the attention mask, and only forwards `inputs_embeds`
    on the first step.
    """
    if past_key_values is not None:
        cache_length = past_length = past_key_values[0][0].shape[2]
        max_cache_length = None

        # Keep only the unprocessed tokens:
        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
        # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
        # input)
        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
        # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
        # input_ids based on the past_length.
        elif past_length < input_ids.shape[1]:
            input_ids = input_ids[:, past_length:]
        # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

        # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
        if (
            max_cache_length is not None
            and attention_mask is not None
            and cache_length + input_ids.shape[1] > max_cache_length
        ):
            attention_mask = attention_mask[:, -max_cache_length:]

    position_ids = kwargs.get("position_ids", None)
    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1] :]

    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs.update(
        {
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        }
    )
    return model_inputs


def setup(model_id: str, weights: str, activations: str, group_size: int = 64, version="GEMV_FAST"):
    """Quantize `model_id` with AWQ int4 weights and return (model, tokenizer).

    Raises:
        ValueError: if activation quantization is requested or weights != "int4".
    """
    if activations != "none":
        # Fixed copy-paste error: this message previously said "HQQ".
        raise ValueError("Activation quantization is not supported by AWQ")
    if weights != "int4":
        raise ValueError("AWQ only supports int4 weights.")
    quant_config = {"zero_point": True, "q_group_size": group_size, "w_bit": 4, "version": version}
    # Load model
    model = AutoAWQForCausalLM.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "left"
    # Quantize
    model.quantize(tokenizer, quant_config=quant_config)
    # We need to save otherwise it doesn't work
    quant_path = model_id.replace("/", "-") + f"_{group_size}_{version}"
    model.save_quantized(quant_path)
    # Reload model
    model = AutoAWQForCausalLM.from_quantized(quant_path)
    # Hack: force transformers 4.36.2 behaviour
    model.model.prepare_inputs_for_generation = prepare_inputs_for_generation
    # Hack because AWQ models are not transformers models
    model.device = next(model.parameters()).device
    return model, tokenizer


# ================================================
# FILE: bench/generation/setup/bnb.py
# ================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig


def setup(
    model_id: str,
    weights: str,
    activations: str,
    device: torch.device,
):
    """Load `model_id` with BitsAndBytes weight quantization and return (model, tokenizer).

    Raises:
        ValueError: if activation quantization is requested or weights is not
            "int4" / "int8".
    """
    if activations != "none":
        raise ValueError("Activation quantization is not supported by BitsAndBytes")
    if weights == "int4":
        quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="fp4")
    elif weights == "int8":
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    else:
        raise ValueError("BitsAndBytes only supports int4 and int8 weights.")
    # float16 compute is only available off-CPU.
    dtype = torch.float32 if device.type == "cpu" else torch.float16
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "left"
    quantization_config.bnb_4bit_compute_dtype = dtype
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=dtype, low_cpu_mem_usage=True, quantization_config=quantization_config
    )
    return model, tokenizer


# ================================================
# FILE: bench/generation/setup/hqq.py
# ================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from hqq.core.quantize import BaseQuantizeConfig
from hqq.engine.hf import HQQModelForCausalLM
from transformers import AutoTokenizer


def setup(model_id: str, weights: str, activations: str, device: torch.device, group_size: int = 64):
    """Quantize `model_id` with HQQ and return (model, tokenizer).

    Raises:
        ValueError: if activation quantization is requested or weights is not
            "int4" / "int8".
    """
    if activations != "none":
        raise ValueError("Activation quantization is not supported by HQQ")
    if weights == "int4":
        quant_config = BaseQuantizeConfig(nbits=4, group_size=group_size)
    elif weights == "int8":
        quant_config = BaseQuantizeConfig(nbits=8, group_size=group_size)
    else:
        raise ValueError("HQQ only supports int4 and int8 weights.")
    # Load model
    model = HQQModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
    # Quantize
    model.quantize_model(quant_config=quant_config, compute_dtype=torch.float16, device=device)
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "left"
    return model, tokenizer


# ================================================
# FILE: bench/generation/setup/quanto.py
# ================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.quanto import Calibration, freeze, qfloat8, qint4, qint8, quantize


@torch.no_grad()
def calibrate(model, tokenizer, batch_size, batches):
    """Run `batches` forward passes over lambada validation samples to calibrate activations."""
    samples = batch_size * batches
    cal_dataset = load_dataset("lambada", split=["validation"])[0]
    model.eval()
    total = 0
    for batch in cal_dataset.iter(batch_size=batch_size):
        inputs = tokenizer(batch["text"], return_tensors="pt", padding=True)
        input_ids = inputs.input_ids.to(model.device)
        attention_mask = inputs.attention_mask.to(model.device)
        model(input_ids, attention_mask=attention_mask)
        total += input_ids.size(0)
        if total >= samples:
            break


def setup(
    model_id: str,
    weights: str,
    activations: str,
    batch_size: int,
    device: torch.device,
    dtype: torch.dtype,
):
    """Load `model_id`, quantize it with quanto and return (model, tokenizer).

    `weights` / `activations` are keywords ("none", "int4", "int8", "float8")
    mapped to quanto qtypes; activation quantization triggers a calibration pass.
    """
    weights = keyword_to_qtype(weights)
    activations = keyword_to_qtype(activations)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "left"
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, low_cpu_mem_usage=True).to(device)
    if weights is not None or activations is not None:
        print("Quantizing")
        start = time.time()
        # Quantize the inner transformer when the wrapper exposes one (keeps the
        # LM head unquantized).
        quantization_root = model
        if hasattr(model, "model"):
            quantization_root = model.model
        quantize(quantization_root, weights=weights, activations=activations)
        if activations is not None:
            print("Calibrating")
            with Calibration():
                calibrate(model, tokenizer, batch_size, batches=4)
        print("Freezing")
        freeze(model)
        print(f"Finished: {time.time() - start:.2f}")
    return model, tokenizer


def keyword_to_qtype(k):
    """Map a command-line keyword to the matching quanto qtype (None for "none")."""
    return {
        "none": None,
        "int4": qint4,
        "int8": qint8,
        "float8": qfloat8,
    }[k]


# ================================================
# FILE: bench/kernels/benchmark.py
# ================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import time
from contextlib import nullcontext

import numpy as np
import torch
from tqdm.auto import tqdm

from optimum.quanto.library import disable_extensions


def get_unpack_bench(bits, device):
    """Return a zero-argument callable benchmarking quanto's unpack op for `bits`."""
    qmax = 2**bits
    a = torch.randint(0, qmax, [10240, 10240], dtype=torch.uint8).to(device)

    def bench_fn():
        return torch.ops.quanto.unpack(a, bits)

    return bench_fn


def timing(get_bench_func, device, iterations=10):
    """Time a benchmark function with and without quanto's compiled extensions.

    Returns (mean_latency_without_extensions_ms, mean_latency_with_extensions_ms).
    """

    def synchronize(device):
        # Block until all queued work on the target device has completed.
        if device.type == "cuda":
            torch.cuda.synchronize()
        elif device.type == "mps":
            torch.mps.synchronize()
        elif device.type == "xpu":
            torch.xpu.synchronize()
        else:
            torch.cpu.synchronize()

    def timing_event(device):
        # Return a device-appropriate event exposing record()/elapsed_time().
        if device.type == "cuda":
            return torch.cuda.Event(enable_timing=True)
        elif device.type == "mps":
            return torch.mps.Event(enable_timing=True)
        elif device.type == "xpu":
            return torch.xpu.Event(enable_timing=True)

        class CPUEvent:
            def __init__(self):
                self.time = None

            def record(self):
                self.time = time.time()

            def elapsed_time(self, other):
                assert self.time is not None
                assert other.time is not None
                return (other.time - self.time) * 1000

        return CPUEvent()

    synchronize(device)
    bench_func = get_bench_func(device)
    # Warmup to load library
    bench_func()

    latencies = np.empty((iterations, 2))
    for i in tqdm(range(iterations)):
        # Column 0: extensions disabled (pure python); column 1: extensions enabled.
        for j, context in enumerate([disable_extensions(), nullcontext()]):
            start_event = timing_event(device)
            end_event = timing_event(device)
            synchronize(device)
            start_event.record()
            with context:
                bench_func()
            end_event.record()
            synchronize(device)
            latencies[i, j] = start_event.elapsed_time(end_event)
    return np.mean(latencies[:, 0]), np.mean(latencies[:, 1])


GET_BENCH_FUNCTIONS = {
    "unpack_2bit": lambda device: get_unpack_bench(2, device),
    "unpack_4bit": lambda device: get_unpack_bench(4, device),
}


def main():
    """Parse arguments, pick a device and run the selected kernel benchmark(s)."""
    parser = argparse.ArgumentParser(description="Kernel benchmark")
    parser.add_argument("--kernel", type=str, default=None, help="The kernel to benchmark. None to test all of them")
    parser.add_argument("--device", type=str, default=None, help="The device to use for benchmark.")
    parser.add_argument("--it", type=int, default=10, help="The number of benchmark iterations")
    args = parser.parse_args()
    if args.device is None:
        # Prefer the fastest available backend.
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        elif torch.xpu.is_available():
            device = torch.device("xpu")
        else:
            device = torch.device("cpu")
    else:
        device = torch.device(args.device)
    all_kernels = GET_BENCH_FUNCTIONS.keys()
    kernels = all_kernels if args.kernel is None else [args.kernel]
    for kernel in kernels:
        get_bench_fn = GET_BENCH_FUNCTIONS[kernel]
        python_ms, ext_ms = timing(get_bench_fn, device, iterations=args.it)
        ratio = python_ms / ext_ms
        print(f"\n{kernel}[{device.type}]: python = {python_ms:.3f} ms, ext = {ext_ms:.3f} ms, ratio = {ratio:.1f}x")


if __name__ == "__main__":
    main()


# ================================================
# FILE: bench/kernels/benchmark_marlin_fp8.py
# ================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from typing import Optional

import numpy as np
import torch

from optimum.quanto.tensor.weights.marlin.packed import pack_fp8_as_int32

M_SHAPES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
N_SHAPES = [4096]
K_SHAPES = [4096]


def run_benchmark(
    m: Optional[int],
    n: Optional[int],
    k: Optional[int],
    n_runs: int,
    n_warmup: int,
    dtype: torch.dtype = torch.float16,
):
    """Benchmark the marlin fp8 GEMM kernel against a plain torch.matmul.

    A is (m, k) activations, B is a (k, n) fp8 weight. The first `n_warmup`
    iterations are discarded. Returns (mean_latency_torch, mean_latency_marlin_fp8)
    in milliseconds.
    """
    print(f"\n----------- m={m}, n={n}, k={k}")
    n_tokens = m
    in_features = k
    out_features = n
    assert m is not None

    device = torch.device("cuda")
    inputs = torch.rand(n_tokens, in_features, dtype=dtype, device=device)

    other_shape = (in_features, out_features)
    other_data = torch.rand(other_shape, dtype=dtype, device=device).to(torch.float8_e4m3fn)
    other_data_int32 = pack_fp8_as_int32(other_data)
    perm = torch.empty(0, dtype=torch.int, device=device)
    other_data_repack = torch.ops.quanto.gptq_marlin_repack(
        b_q_weight=other_data_int32, perm=perm, size_k=in_features, size_n=out_features, num_bits=8
    )
    other_scale = torch.rand(1, dtype=dtype, device=device)
    other_scale = other_scale.repeat(1, out_features)

    workspace = torch.zeros(out_features // 64 * 16, dtype=torch.int, device=device)

    latencies_marlin_fp8 = []
    latencies_torch = []
    with torch.no_grad():
        for i in range(n_runs):
            # Time the marlin fp8 kernel.
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            torch.cuda.synchronize(device)
            start_event.record()
            _ = torch.ops.quanto.fp8_marlin_gemm(
                a=inputs,
                b_q_weight=other_data_repack,
                b_scales=other_scale,
                workspace=workspace,
                num_bits=8,
                size_m=n_tokens,
                size_n=out_features,
                size_k=in_features,
            )
            end_event.record()
            torch.cuda.synchronize(device)
            latency_ms = start_event.elapsed_time(end_event)
            if i >= n_warmup:
                latencies_marlin_fp8.append(latency_ms)

            # Time the reference: dequantize then dense matmul.
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            torch.cuda.synchronize(device)
            start_event.record()
            other = other_data.to(dtype) * other_scale
            _ = torch.matmul(inputs, other)
            end_event.record()
            torch.cuda.synchronize(device)
            latency_ms = start_event.elapsed_time(end_event)
            if i >= n_warmup:
                latencies_torch.append(latency_ms)

    mean_latency_torch = np.mean(latencies_torch)
    mean_latency_marlin_fp8 = np.mean(latencies_marlin_fp8)
    print("mean_latency_torch:", mean_latency_torch)
    print("mean_latency_marlin_fp8:", mean_latency_marlin_fp8)
    return mean_latency_torch, mean_latency_marlin_fp8


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Marlin FP8 kernel benchmark")
    parser.add_argument("--nruns", type=int, default=20, help="The number of benchmark iterations")
    parser.add_argument("--nwarmup", type=int, default=2, help="The number of warmup iterations (deducted from nruns)")
    parser.add_argument(
        "--m",
        type=int,
        help="m dimension of A=m*k",
        default=None,
    )
    parser.add_argument(
        "--n",
        type=int,
        help="n dimension of B=k*n (out_features)",
        default=None,
    )
    parser.add_argument(
        "--k",
        type=int,
        help="k dimension of A=m*k and B=k*n (in_features), hidden_size",
        default=None,
    )
    args = parser.parse_args()

    if args.m is not None:
        # A single user-provided shape.
        def shape_generator():
            yield (args.m, args.n, args.k)

    else:
        # Sweep the predefined shape grid.
        def shape_generator():
            for m in M_SHAPES:
                for n in N_SHAPES:
                    for k in K_SHAPES:
                        yield (m, n, k)

    result = "m,n_out,k_in,torch_latency_ms,marlin_fp8_latency_ms\n"
    for m, n, k in shape_generator():
        mean_latency_torch, mean_latency_marlin_fp8 = run_benchmark(m, n, k, args.nruns, args.nwarmup)
        result += (
            ",".join(
                [
                    str(m),
                    str(n),
                    str(k),
                    f"{mean_latency_torch:.4f}",
                    f"{mean_latency_marlin_fp8:.4f}",
                ]
            )
            + "\n"
        )
    print("\nResults:")
    print(result)


# ================================================
# FILE: bench/kernels/benchmark_w4a16.py
# ================================================
# From: https://github.com/IST-DASLab/marlin/blob/master/bench.py
import argparse
import time

import torch

from optimum.quanto.tensor.weights.awq import AWQPackedTensor, AWQPacking
from optimum.quanto.tensor.weights.marlin import marlin_permute
from optimum.quanto.tensor.weights.marlin.int4 import MarlinInt4PackedTensor


def benchmark(f, warmup=1, iter=10):
    """Return the mean wall-clock duration (s) of `f()` over `iter` timed calls."""
    for i in range(warmup + iter):
        f()
        # We do not synchronize here in order to hide the kernel launch overhead during benchmarkining as this will also
        # happen during realistic model inference as many launches are submitted to the kernel queue.
        if i == warmup - 1:
            torch.cuda.synchronize()
            tick = time.time()
    torch.cuda.synchronize()
    res = (time.time() - tick) / iter
    # Make sure there is enough to "cool down" the GPU in between benchmarks to avoid throttling for later runs when
    # we execute many benchmarks consecutively
    time.sleep(1.0)
    return res


def get_problem(m, n, k, groupsize=128):
    """Build one (m, k) x (k, n) problem with AWQ-packed and Marlin-packed int4 weights."""
    dev = torch.device("cuda:0")
    A = torch.rand((m, k), dtype=torch.half, device=dev)
    B_4bit = torch.randint(0, 2**4, (n, k), dtype=torch.uint8, device=dev)
    B_awq = AWQPackedTensor.pack(B_4bit, packing=AWQPacking.V2)._data
    B_marlin = MarlinInt4PackedTensor.pack(B_4bit)._data
    B_ref = torch.rand((k, n), dtype=torch.half, device=dev)
    s = torch.rand((k // groupsize, n), dtype=torch.half, device=dev) / 2**4
    s_marlin = marlin_permute(s)
    z = torch.randint(-(2 ** (4 - 1)), 2 ** (4 - 1), (k // groupsize, n), dtype=torch.int8, device=dev)
    sz = -z * s
    sz_marlin = marlin_permute(sz)
    torch.cuda.synchronize()
    return A, B_ref, B_awq, B_marlin, s, s_marlin, sz, sz_marlin


def benchmark_dense(A, B, m, n, k):
    """Benchmark a dense fp16 matmul and report throughput figures."""
    res = benchmark(lambda: torch.matmul(A, B))
    return {
        "s": res,
        "TFLOP/s": 2 * A.numel() * n / res / 10**12,
        "GB/s": (2 * A.numel() + 2 * B.numel() + 2 * (m * n)) / res / 10**9,
    }


def benchmark_awq(A, B, s, sz, m, n, k):
    """Benchmark the AWQ w4a16 GEMM kernel and report throughput figures."""
    res = benchmark(
        lambda: torch.ops.quanto.gemm_f16i4_awq(A, B, s, sz, rows=m, out_cols=n, in_cols=k, bits=4, group_size=128)
    )
    return {
        "s": res,
        "TFLOP/s": 2 * (m * k) * n / res / 10**12,
        "GB/s": (2 * A.numel() + 2 * B.numel() + 2 * (m * n) + 2 * s.numel() + 2 * sz.numel()) / res / 10**9,
    }


def benchmark_marlin(A, B, s, sz, m, n, k):
    """Benchmark the Marlin w4a16 GEMM kernel and report throughput figures."""
    workspace = torch.zeros(n // 128 * 16, dtype=torch.int, device=torch.device("cuda:0"))
    res = benchmark(lambda: torch.ops.quanto.gemm_f16i4_marlin(A, B, s, sz, workspace))
    return {
        "s": res,
        "TFLOP/s": 2 * (m * k) * n / res / 10**12,
        "GB/s": (2 * A.numel() + 4 * B.numel() + 2 * (m * n) + 2 * s.numel() + 2 * sz.numel()) / res / 10**9,
    }


MODELS = {
    "Llama7B": [(4096, 3 * 4096), (4096, 4096), (4096, 2 * 10752), (10752, 4096)],
    "Llama13B": [(5120, 3 * 5120), (5120, 5120), (5120, 2 * 13568), (13568, 5120)],
    "Llama33B": [(6656, 3 * 6656), (6656, 6656), (6656, 2 * 17664), (17664, 6656)],
    "Llama65B": [(8192, 3 * 8192), (8192, 8192), (8192, 2 * 21760), (21760, 8192)],
    "Falcon180B": [
        # Note that parallel attention and FC allows layer fusions
        (14848, 14848 * 5 + 1024),
        (14848 * 5, 14848),
    ],
}


def run_benchmark(model, tokens=None):
    """Benchmark AWQ and Marlin kernels over every layer shape of `model`.

    Totals are time-weighted averages across the model's layers.
    """
    if tokens is None:
        tokens = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
    elif not isinstance(tokens, (list, tuple)):
        tokens = [tokens]
    groupsize = 128
    layers = MODELS[model]
    print(model)
    for m in tokens:
        tot_awq = {"s": 0, "TFLOP/s": 0, "GB/s": 0, "speedup": 0}
        tot_marlin = {"s": 0, "TFLOP/s": 0, "GB/s": 0, "speedup": 0}
        for layer in layers:
            k, n = layer
            A, B_ref, B_awq, B_marlin, s, s_marlin, sz, sz_marlin = get_problem(m, n, k, groupsize)
            res_d = benchmark_dense(A, B_ref, m, n, k)
            res_awq = benchmark_awq(A, B_awq, s, sz, m, n, k)
            res_awq["speedup"] = res_d["s"] / res_awq["s"]
            tot_awq["s"] += res_awq["s"]
            # Accumulate time-weighted metrics; normalized after the layer loop.
            for key in tot_awq:
                if key != "s":
                    tot_awq[key] += res_awq[key] * res_awq["s"]
            res_marlin = benchmark_marlin(A, B_marlin, s_marlin, sz_marlin, m, n, k)
            res_marlin["speedup"] = res_d["s"] / res_marlin["s"]
            tot_marlin["s"] += res_marlin["s"]
            for key in tot_marlin:
                if key != "s":
                    tot_marlin[key] += res_marlin[key] * res_marlin["s"]
        for key in tot_awq:
            if key != "s":
                tot_awq[key] /= tot_awq["s"]
        for key in tot_marlin:
            if key != "s":
                tot_marlin[key] /= tot_marlin["s"]
        print(
            "AWQ, tokens=%04d: s=%.5f, TFLOP/s=%07.3f, GB/s=%08.3f, speedup=%.2f"
            % (m, tot_awq["s"], tot_awq["TFLOP/s"], tot_awq["GB/s"], tot_awq["speedup"])
        )
        print(
            "Marlin, batch=%04d: s=%.5f, TFLOP/s=%07.3f, GB/s=%08.3f, speedup=%.2f"
            % (m, tot_marlin["s"], tot_marlin["TFLOP/s"], tot_marlin["GB/s"], tot_marlin["speedup"])
        )


def main():
    """Parse arguments and benchmark the selected model configuration(s)."""
    parser = argparse.ArgumentParser(description="W4A16 Matrix Multiplication Kernel benchmark")
    parser.add_argument(
        "--model", type=str, default=None, help="The model configuration to benchmark. None to test all of them."
    )
    parser.add_argument(
        "--tokens",
        type=int,
        default=None,
        help="The numbers of input tokens used to benchmark. None to test a predefined range.",
    )
    args = parser.parse_args()
    models = MODELS if args.model is None else [args.model]
    for model in models:
        run_benchmark(model, args.tokens)
        print()


if __name__ == "__main__":
    main()


# ================================================
# FILE: bench/torch_kernels/README.md
# ================================================
# This contains a few scripts to test pytorch kernels that are relevant for quantization.


# ================================================
# FILE: bench/torch_kernels/test_int_mm.py
# ================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
def main():
    """Compare torch._int_mm against a floating point matmul on one device."""
    parser = argparse.ArgumentParser(description="Torch integer matmul benchmark")
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument("--device", type=str, default=None, help="The device to use for the test.")
    parser.add_argument("--it", type=int, default=100, help="Number of iterations for average")
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    # Honor an explicit device request, otherwise pick the best available one.
    if args.device is not None:
        device = torch.device(args.device)
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.xpu.is_available():
        device = torch.device("xpu")
    else:
        device = torch.device("cpu")

    def avg_time(f, it):
        # Mean wall-clock time of one call, in seconds.
        return timeit.Timer(f).timeit(it) / it

    # Restrictions for accelerated integer matmul:
    # - input matrices must be 2D
    # - the collapsing dimension must be a multiple of 8
    lhs = torch.randint(1, 10, [2400, 3200]).type(torch.int8).to(device)
    rhs = torch.randint(1, 10, [3200, 4800]).type(torch.int8).to(device)

    print(f"Evaluating integer matmul on {device.type}:")
    # Warmup (slow)
    torch._int_mm(lhs, rhs)
    # Average on several calls
    t = avg_time(lambda: torch._int_mm(lhs, rhs), args.it) * 1000
    print(f"Average inference on {args.it} iterations: {t:.4f} ms")

    # Convert the inputs to float for the baseline; matrix multiplication is
    # not supported for float16 on CPU, so fall back to float32 there.
    float_dtype = torch.float32 if device.type == "cpu" else torch.float16
    lhs = lhs.to(float_dtype)
    rhs = rhs.to(float_dtype)

    print(f"Evaluating {lhs.dtype} matmul on {device.type}:")
    # Warmup (slow)
    torch.matmul(lhs, rhs)
    # Average on several calls
    t = avg_time(lambda: torch.matmul(lhs, rhs), args.it) * 1000
    print(f"Average inference on {args.it} iterations: {t:.4f} ms")
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import argparse import timeit import torch def _group_quantize_tensor(w, n_bit=4, q_group_size=16): assert w.dim() == 2 w = w.transpose(0, 1).contiguous() assert q_group_size > 1 assert w.shape[-1] % q_group_size == 0 to_quant = w.reshape(-1, q_group_size) assert torch.isnan(to_quant).sum() == 0 max_val = to_quant.amax(dim=1, keepdim=True) min_val = to_quant.amin(dim=1, keepdim=True) max_int = 2**n_bit - 1 min_int = 0 scales = (max_val - min_val).clamp(min=1e-6) / max_int assert torch.isnan(scales).sum() == 0 zeros = min_val + scales * (2 ** (n_bit - 1)) assert torch.isnan(zeros).sum() == 0 out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int) assert torch.isnan(out).sum() == 0 out = out.to(dtype=torch.int32).reshape(w.shape) # Scales and zeros for the same q-group should be contiguous, so we can # load as a 32-bit word scales = scales.view(w.shape[0], -1) zeros = zeros.view(w.shape[0], -1) scales_and_zeros = ( torch.cat( [ scales.reshape(scales.size(0), scales.size(1), 1), zeros.reshape(zeros.size(0), zeros.size(1), 1), ], 2, ) .transpose(0, 1) .contiguous() ) return out, scales_and_zeros def main(): parser = argparse.ArgumentParser(description="Torch quantized int4 weight matmul benchmark") parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") parser.add_argument("--dtype", type=str, default="fp16", choices=["fp16", "bf16"], help="floating point type") parser.add_argument("--device", type=str, default=None, help="The device to use for the test.") parser.add_argument("--it", type=int, default=10, 
help="Number of iterations for average") args = parser.parse_args() torch.manual_seed(args.seed) if args.device is None: if torch.cuda.is_available(): device = torch.device("cuda") elif torch.backends.mps.is_available(): device = torch.device("mps") elif torch.xpu.is_available(): device = torch.device("xpu") else: device = torch.device("cpu") else: device = torch.device(args.device) def avg_time(f, it): return timeit.Timer(f).timeit(it) / it dtype = {"fp16": torch.float16, "bf16": torch.bfloat16}[args.dtype] A = torch.rand([2400, 3200], dtype=dtype, device=device) B = torch.rand([3200, 4800], dtype=dtype, device=device) group_size = 128 B_int32, B_scale_and_zeros = _group_quantize_tensor(B, n_bit=4, q_group_size=group_size) if device.type == "cpu": B_packed = torch._convert_weight_to_int4pack_for_cpu(B_int32, innerKTiles=2) else: B_uint8 = (B_int32[::, ::2] << 4 | B_int32[::, 1::2]).to(torch.uint8) B_packed = torch._convert_weight_to_int4pack(B_uint8, innerKTiles=2) # Check quantized mm is close to float mm if device.type == "cpu": qout = torch._weight_int4pack_mm_for_cpu(A, B_packed, group_size, B_scale_and_zeros) else: qout = torch._weight_int4pack_mm(A, B_packed, group_size, B_scale_and_zeros) out = torch.mm(A, B) mean_err = ((qout - out).abs() / out.abs()).mean() print(mean_err) print(f"Evaluating quantized int4 matmul on {device.type}:") # Warmup (slow) if device.type == "cpu": torch._weight_int4pack_mm_for_cpu(A, B_packed, group_size, B_scale_and_zeros) else: torch._weight_int4pack_mm(A, B_packed, group_size, B_scale_and_zeros) # Average on several calls if device.type == "cpu": t = ( avg_time(lambda: torch._weight_int4pack_mm_for_cpu(A, B_packed, group_size, B_scale_and_zeros), args.it) * 1000 ) else: t = avg_time(lambda: torch._weight_int4pack_mm(A, B_packed, group_size, B_scale_and_zeros), args.it) * 1000 print(f"Average inference on {args.it} iterations: {t:.4f} ms") print(f"Evaluating {A.dtype} matmul on {device.type}:") # Warmup (slow) torch.mm(A, B) # 
def main():
    """Compare torch._weight_int8pack_mm against an equivalent float matmul."""
    parser = argparse.ArgumentParser(description="Torch quantized int8 weight matmul benchmark")
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument("--device", type=str, default=None, help="The device to use for the test.")
    parser.add_argument("--it", type=int, default=10, help="Number of iterations for average")
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    # Honor an explicit device request, otherwise pick the best available one.
    if args.device is not None:
        device = torch.device(args.device)
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.xpu.is_available():
        device = torch.device("xpu")
    else:
        device = torch.device("cpu")

    def avg_time(f, it):
        # Mean wall-clock time of one call, in seconds.
        return timeit.Timer(f).timeit(it) / it

    activations = torch.rand([2400, 3200], dtype=torch.bfloat16, device=device)
    weight = torch.randint(-128, 127, [4800, 3200], dtype=torch.int8, device=device)
    weight_scale = torch.rand([4800], dtype=torch.bfloat16, device=device)

    print(f"Evaluating quantized int8 matmul on {device.type}:")
    # Warmup (slow)
    torch._weight_int8pack_mm(activations, weight, weight_scale)
    # Average on several calls
    t = avg_time(lambda: torch._weight_int8pack_mm(activations, weight, weight_scale), args.it) * 1000
    print(f"Average inference on {args.it} iterations: {t:.4f} ms")

    # Convert weights to float
    weight = weight.to(torch.bfloat16).t()
    print(f"Evaluating {activations.dtype} matmul on {device.type}:")
    # Warmup (slow)
    torch.matmul(activations, weight) * weight_scale
    # Average on several calls
    t = avg_time(lambda: torch.matmul(activations, weight) * weight_scale, args.it) * 1000
    print(f"Average inference on {args.it} iterations: {t:.4f} ms")
def main():
    """Quantize an SST2 sequence classifier and compare it to the float model.

    Evaluates the float model, quantizes it (calibrating activations when
    requested), evaluates the quantized model, then round-trips the quantized
    state_dict through serialization and evaluates the RELOADED model so that
    serialization correctness is actually exercised.
    """
    parser = argparse.ArgumentParser(description="Transformers SST2 Example")
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument(
        "--model",
        type=str,
        default="distilbert-base-uncased-finetuned-sst-2-english",
        help="The name of the trained Model.",
    )
    parser.add_argument("--samples", type=int, default=872, help="The number of sst2 samples to use for evaluation.")
    parser.add_argument("--batch_size", type=int, default=100, help="The batch size to use for evaluation.")
    parser.add_argument("--weights", type=str, default="int8", choices=["int4", "int8"])
    parser.add_argument("--activations", type=str, default="int8", choices=["none", "int8"])
    parser.add_argument("--device", type=str, default=None, help="The device to use for evaluation.")
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    if args.device is None:
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        elif torch.xpu.is_available():
            device = torch.device("xpu")
        else:
            device = torch.device("cpu")
    else:
        device = torch.device(args.device)

    model = AutoModelForSequenceClassification.from_pretrained(args.model).to(device)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    dataset = load_dataset("sst2", split=f"validation[:{args.samples}]")

    print("Float model")
    evaluate_model(model, tokenizer, dataset, device, args.batch_size)

    weights = keyword_to_itype(args.weights)
    activations = keyword_to_itype(args.activations)
    quantize(model, weights=weights, activations=activations)
    if activations is not None:
        # Activation ranges are recorded while running inference under the
        # Calibration context.
        print("Calibrating ...")
        with Calibration():
            evaluate_model(model, tokenizer, dataset, device, args.batch_size)
    freeze(model)
    print(f"Quantized model (w: {args.weights}, a: {args.activations})")
    evaluate_model(model, tokenizer, dataset, device, args.batch_size)

    # Round-trip the quantized weights through (de)serialization in memory.
    b = io.BytesIO()
    torch.save(model.state_dict(), b)
    b.seek(0)
    state_dict = torch.load(b)
    model_reloaded = AutoModelForSequenceClassification.from_pretrained(args.model).to(device)
    quantize(model_reloaded, weights=weights, activations=activations)
    model_reloaded.load_state_dict(state_dict)
    print("Serialized quantized model")
    # Fix: evaluate the reloaded model; the original code re-evaluated the
    # already-tested `model`, so the serialized weights were never checked.
    evaluate_model(model_reloaded, tokenizer, dataset, device, args.batch_size)
def main():
    """Quantize a causal LM, optionally calibrate activations, and generate.

    Loads the model at the requested precision, quantizes weights (and
    activations when requested), calibrates on a shuffled slice of the
    lambada validation split, then runs sampled generation on the prompt.
    """
    parser = argparse.ArgumentParser(description="Transformers Causal LM Example")
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument(
        "--model",
        type=str,
        default="facebook/opt-350m",
        help="The name of the trained Model.",
    )
    parser.add_argument("--prompt", type=str, default="One of my fondest memory is", help="The generation prompt.")
    parser.add_argument("--max_new_tokens", type=int, default=20, help="The maximum number of tokens to generate.")
    parser.add_argument("--batch_size", type=int, default=32, help="The batch_size for evaluation (and calibration).")
    parser.add_argument("--validation_batch", type=int, default=4, help="The number of batch to use for calibration.")
    parser.add_argument(
        "--load_dtype",
        type=str,
        default="float16",
        choices=["float16", "float32", "bfloat16"],
        help="Precision to load the initial model",
    )
    parser.add_argument(
        "--weights",
        type=str,
        default="int8",
        choices=["int4", "int8", "float8"],
    )
    parser.add_argument(
        "--activations",
        type=str,
        default="int8",
        choices=["none", "int8", "float8"],
    )
    parser.add_argument("--device", type=str, default=None, help="The device to use for generation.")
    # NOTE: store_false means args.no_streamline defaults to True (streamline
    # on) and becomes False when the flag is passed — the dest name is
    # misleading but kept for CLI compatibility.
    parser.add_argument(
        "--no-streamline",
        action="store_false",
        help="Do not remove consecutive quantize/dequantize (not recommended).",
    )
    parser.add_argument(
        "--debug", action="store_true", help="Provide detailed feedback on the console during calibration."
    )
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    if args.device is None:
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        elif torch.xpu.is_available():
            device = torch.device("xpu")
        else:
            device = torch.device("cpu")
    else:
        device = torch.device(args.device)

    torch_dtype = (
        torch.float16
        if args.load_dtype == "float16"
        else torch.bfloat16
        if args.load_dtype == "bfloat16"
        else torch.float32
    )
    model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch_dtype, low_cpu_mem_usage=True).to(
        device
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "left"
    cal_dataset = load_dataset("lambada", split=["validation"])[0]
    print(f"{args.model} (w: {args.weights}, a: {args.activations})")
    weights = keyword_to_itype(args.weights)
    activations = keyword_to_itype(args.activations)
    qmodel = QuantizedModelForCausalLM.quantize(model, weights=weights, activations=activations)
    if activations is not None:
        print("Calibrating ...")
        # Fix: datasets.Dataset.shuffle is NOT in-place — it returns a new
        # dataset. The original discarded the result, so calibration always
        # ran on the unshuffled split.
        cal_dataset = cal_dataset.shuffle(args.seed)
        with Calibration(streamline=args.no_streamline, debug=args.debug):
            cal_samples = args.batch_size * args.validation_batch
            calibrate(qmodel, tokenizer, cal_dataset, device, args.batch_size, samples=cal_samples)
    generate(qmodel, tokenizer, device, args.prompt, args.max_new_tokens)
def main():
    """Quantize a Whisper ASR model and compare WER against the float model.

    Evaluates the float model, quantizes it (calibrating activations when
    requested), evaluates the quantized model, then round-trips the quantized
    state_dict through serialization and evaluates the RELOADED model so that
    serialization correctness is actually exercised.
    """
    parser = argparse.ArgumentParser(description="Transformers Whisper Example")
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument(
        "--model",
        type=str,
        default="openai/whisper-medium",
        help="The name of the trained Model.",
    )
    parser.add_argument(
        "--samples", type=int, default=872, help="The number of librispeech samples to use for evaluation."
    )
    parser.add_argument("--batch_size", type=int, default=10, help="The batch size to use for evaluation.")
    parser.add_argument("--weights", type=str, default="int8", choices=["int4", "int8"])
    parser.add_argument("--activations", type=str, default="int8", choices=["none", "int8"])
    parser.add_argument("--device", type=str, default=None, help="The device to use for evaluation.")
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    if args.device is None:
        if torch.cuda.is_available():
            device = torch.device("cuda")
            print("USING CUDA")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")
            print("USING CPU")
    else:
        device = torch.device(args.device)

    model = WhisperForConditionalGeneration.from_pretrained(args.model).to(device)
    model.config.forced_decoder_ids = None
    processor = WhisperProcessor.from_pretrained(args.model)

    dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    processed_dataset = dataset.map(lambda x: map_to_feats(x, processor))
    wer = load("wer")

    print("Float model:")
    evaluate_model(model, processor, processed_dataset, wer, args.batch_size)

    weights = keyword_to_itype(args.weights)
    activations = keyword_to_itype(args.activations)
    quantize(model, weights=weights, activations=activations)
    if activations is not None:
        # Activation ranges are recorded while running inference under the
        # Calibration context.
        print("Calibrating ...")
        with Calibration():
            evaluate_model(model, processor, processed_dataset, wer, args.batch_size)
    freeze(model)
    print(f"Quantized model (w: {args.weights}, a: {args.activations})")
    evaluate_model(model, processor, processed_dataset, wer, args.batch_size)

    # Round-trip the quantized weights through (de)serialization in memory.
    b = io.BytesIO()
    torch.save(model.state_dict(), b)
    b.seek(0)
    state_dict = torch.load(b)
    model_reloaded = WhisperForConditionalGeneration.from_pretrained(args.model).to(device)
    quantize(model_reloaded, weights=weights, activations=activations)
    model_reloaded.load_state_dict(state_dict)
    print("Serialized quantized model")
    # Fix: evaluate the reloaded model; the original code re-evaluated the
    # already-tested `model`, so the serialized weights were never checked.
    evaluate_model(model_reloaded, processor, processed_dataset, wer, args.batch_size)
def run_inference(pipe, batch_size=1):
    """Run one generation pass with `pipe`, discarding the images (timing only).

    Fix: honor the `batch_size` parameter instead of silently reading the
    module-level `args.batch_size` (the parameter was previously dead).
    Existing callers pass `args.batch_size`, so behavior is unchanged for
    them. The prompt and step count still come from the parsed CLI `args`,
    as in the rest of this script.
    """
    _ = pipe(
        prompt=args.prompt,
        num_inference_steps=args.num_inference_steps,
        num_images_per_prompt=batch_size,
        # Fixed seed so repeated timing runs do identical work.
        generator=torch.manual_seed(0),
    )
bytes_to_giga_bytes(torch.cuda.max_memory_allocated()) # in GBs. elif device.type == "xpu": memory = bytes_to_giga_bytes(torch.xpu.max_memory_allocated()) # in GBs. else: memory = 0 get_device_memory(device) print( f"batch_size: {args.batch_size}, torch_dtype: {args.torch_dtype}, unet_dtype: {args.unet_qtype} in {time} seconds." ) print(f"Memory: {memory}GB.") img_name = f"bs@{args.batch_size}-dtype@{args.torch_dtype}-unet_dtype@{args.unet_qtype}.png" image = pipeline( prompt=args.prompt, num_inference_steps=NUM_INFERENCE_STEPS, num_images_per_prompt=args.batch_size, ).images[0] image.save(img_name) ================================================ FILE: examples/vision/StableDiffusion/requirements.txt ================================================ quanto diffusers torch transformers accelerate wandb ================================================ FILE: examples/vision/image-classification/mnist/quantize_mnist_model.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import argparse
import time
from tempfile import NamedTemporaryFile

import torch
import torch.nn.functional as F
from accelerate import init_empty_weights
from safetensors.torch import load_file, save_file
from torchvision import datasets, transforms
from transformers import AutoConfig, AutoModel

from optimum.quanto import (
    Calibration,
    QTensor,
    freeze,
    qfloat8,
    qint4,
    qint8,
    quantization_map,
    quantize,
    requantize,
)


def test(model, device, test_loader):
    """Evaluate the model on the test set, printing timing, loss and accuracy."""
    model.to(device)
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        start = time.time()
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Quantized activations must be dequantized before computing metrics.
            if isinstance(output, QTensor):
                output = output.dequantize()
            test_loss += F.nll_loss(output, target, reduction="sum").item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
        end = time.time()
    test_loss /= len(test_loader.dataset)
    print(
        "\nTest set evaluated in {:.2f} s: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            end - start, test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset)
        )
    )


def train(log_interval, model, device, train_loader, optimizer, epoch):
    """Train the model for one epoch, logging the loss every log_interval batches."""
    model.to(device)
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        # Quantized activations must be dequantized before computing the loss.
        if isinstance(output, QTensor):
            output = output.dequantize()
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(data),
                    len(train_loader.dataset),
                    100.0 * batch_idx / len(train_loader),
                    loss.item(),
                )
            )


def keyword_to_itype(k):
    # Map a command-line keyword to the corresponding quanto qtype (or None).
    return {"none": None, "int4": qint4, "int8": qint8, "float8": qfloat8}[k]


def main():
    # Training settings
    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
    parser.add_argument(
        "--batch-size", type=int, default=250, metavar="N", help="input batch size for testing (default: 250)"
    )
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument("--model", type=str, default="dacorvo/mnist-mlp", help="The name of the trained Model.")
    parser.add_argument("--weights", type=str, default="int8", choices=["int4", "int8", "float8"])
    parser.add_argument("--activations", type=str, default="int8", choices=["none", "int8", "float8"])
    parser.add_argument("--device", type=str, default=None, help="The device to use for evaluation.")
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    # Pick the best available accelerator unless one was requested explicitly.
    if args.device is None:
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        elif torch.xpu.is_available():
            device = torch.device("xpu")
        else:
            device = torch.device("cpu")
    else:
        device = torch.device(args.device)

    dataset_kwargs = {"batch_size": args.batch_size}
    if torch.cuda.is_available() or torch.xpu.is_available():
        backend_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True}
        dataset_kwargs.update(backend_kwargs)
    # Flatten the normalized images so the MLP can consume them directly.
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
            transforms.Lambda(lambda x: torch.flatten(x)),
        ]
    )
    dataset1 = datasets.MNIST("./data", train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1, **dataset_kwargs)
    dataset2 = datasets.MNIST("./data", train=False, download=True, transform=transform)
    test_loader = torch.utils.data.DataLoader(dataset2, **dataset_kwargs)
    model = AutoModel.from_pretrained(args.model, trust_remote_code=True)
    model.eval()
    print("Float model")
    test(model, device, test_loader)
    weights = keyword_to_itype(args.weights)
    activations = keyword_to_itype(args.activations)
    quantize(model, weights=weights, activations=activations)
    if activations is not None:
        # Calibrate activation ranges with a pass over the test set.
        print("Calibrating ...")
        with Calibration():
            test(model, device, test_loader)
    print(f"Quantized model (w: {args.weights}, a: {args.activations})")
    test(model, device, test_loader)
    print("Tuning quantized model for one epoch")
    optimizer = torch.optim.Adadelta(model.parameters(), lr=0.5)
    train(50, model, device, train_loader, optimizer, 1)
    print("Quantized tuned model")
    test(model, device, test_loader)
    print("Quantized frozen model")
    freeze(model)
    test(model, device, test_loader)
    # Serialize model to a state_dict, save it to disk and reload it
    with NamedTemporaryFile() as tmp_file:
        save_file(model.state_dict(), tmp_file.name)
        state_dict = load_file(tmp_file.name)
    # NOTE(review): this eager reload is immediately overwritten by the
    # empty-weights instantiation below and looks redundant — confirm.
    model_reloaded = AutoModel.from_pretrained(args.model, trust_remote_code=True)
    # Create an empty model
    config = AutoConfig.from_pretrained(args.model, trust_remote_code=True)
    with init_empty_weights():
        model_reloaded = AutoModel.from_config(config, trust_remote_code=True)
    # Requantize it using the serialized state_dict
    requantize(model_reloaded, state_dict, quantization_map(model), device)
    print("Serialized quantized model")
    test(model_reloaded, device, test_loader)


if __name__ == "__main__":
    main()


================================================
FILE: examples/vision/image-classification/pets/quantize_vit_model.py
================================================
import argparse
import time
from tempfile import NamedTemporaryFile

import torch
import torch.nn.functional as F
from accelerate import init_empty_weights
from datasets import load_dataset
from safetensors.torch import load_file, save_file
from transformers import (
    ViTConfig,
    ViTForImageClassification,
    ViTImageProcessor,
)

from optimum.quanto import (
    Calibration,
    QTensor,
    freeze,
    qfloat8,
    qint4,
    qint8,
    quantization_map,
    quantize,
    requantize,
)


def test(model, device, test_loader):
    """Evaluate the ViT classifier on the test set, printing timing, loss and accuracy."""
    model.to(device)
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        start = time.time()
        for batch in test_loader:
            data, target = batch["pixel_values"], batch["labels"]
            data, target =
        inputs["labels"] = data_batch["label"]
        return inputs

    ds = load_dataset("rokmr/pets")
    prepared_ds = ds.with_transform(transform)
    test_loader = torch.utils.data.DataLoader(prepared_ds["test"], **dataset_kwargs)

    print("Model before quantization...")
    test(model, device, test_loader)

    weights = keyword_to_itype(args.weights)
    activations = keyword_to_itype(args.activations)
    quantize(model, weights=weights, activations=activations)
    if activations is not None:
        # Calibrate activation ranges with a pass over the test set.
        print("Calibrating ...")
        with Calibration():
            test(model, device, test_loader)
    print(f"Quantized model (w: {args.weights}, a: {args.activations})")
    test(model, device, test_loader)
    print("Quantized frozen model")
    freeze(model)
    test(model, device, test_loader)
    # Serialize model to a state_dict, save it to disk and reload it
    with NamedTemporaryFile() as tmp_file:
        save_file(model.state_dict(), tmp_file.name)
        state_dict = load_file(tmp_file.name)
    # NOTE(review): this eager reload is immediately overwritten below and
    # looks redundant — confirm.
    model_reloaded = ViTForImageClassification.from_pretrained(model_name)
    # Create an empty model
    config = ViTConfig.from_pretrained(model_name)
    with init_empty_weights():
        # NOTE(review): the mnist example uses from_config here, while this one
        # still calls from_pretrained under init_empty_weights — verify that the
        # weights really are instantiated empty as the comment above claims.
        model_reloaded = ViTForImageClassification.from_pretrained(model_name, config=config)
    # Requantize it using the serialized state_dict
    requantize(model_reloaded, state_dict, quantization_map(model), device)
    print("Serialized quantized model")
    test(model_reloaded, device, test_loader)


if __name__ == "__main__":
    main()


================================================
FILE: examples/vision/object-detection/quantize_owl_model.py
================================================
import argparse
import gc

import numpy as np
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Owlv2ForObjectDetection
from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD

from optimum.quanto import freeze, qfloat8, qint4, qint8, quantize


def detect(model, processor, image, texts):
    """Run OWLv2 detection on image for the given text queries and print the boxes found."""
    inputs = processor(text=texts, images=image, return_tensors="pt").to(model.device)

    # forward pass
    with torch.no_grad():
        outputs = model(**inputs)

    # Note: boxes need to be visualized on the padded, unnormalized image
    # hence we'll set the target image sizes (height, width) based on that
    def get_preprocessed_image(pixel_values):
        # Undo the CLIP normalization and return the padded image as a PIL image.
        pixel_values = pixel_values.squeeze().cpu().numpy()
        unnormalized_image = (pixel_values * np.array(OPENAI_CLIP_STD)[:, None, None]) + np.array(OPENAI_CLIP_MEAN)[
            :, None, None
        ]
        unnormalized_image = (unnormalized_image * 255).astype(np.uint8)
        unnormalized_image = np.moveaxis(unnormalized_image, 0, -1)
        unnormalized_image = Image.fromarray(unnormalized_image)
        return unnormalized_image

    unnormalized_image = get_preprocessed_image(inputs.pixel_values)

    target_sizes = torch.Tensor([unnormalized_image.size[::-1]])
    # Convert outputs (bounding boxes and class logits) to final bounding boxes and scores
    results = processor.post_process_object_detection(outputs=outputs, threshold=0.2, target_sizes=target_sizes)

    i = 0  # Retrieve predictions for the first image for the corresponding text queries
    text = texts[i]
    boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]
    if len(boxes) == 0:
        print("None of the specified labels were detected")
        return
    for box, score, label in zip(boxes, scores, labels):
        box = [round(i, 2) for i in box.tolist()]
        print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")


def get_device_memory(device):
    # Return currently allocated accelerator memory in bytes (None for CPU),
    # emptying the cache first for an accurate reading.
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()
        return torch.cuda.memory_allocated()
    elif device.type == "mps":
        torch.mps.empty_cache()
        return torch.mps.current_allocated_memory()
    elif device.type == "xpu":
        torch.xpu.empty_cache()
        return torch.xpu.memory_allocated()
    return None


def keyword_to_qtype(k):
    # Map a command-line keyword to the corresponding quanto qtype (or None).
    return {"none": None, "int4": qint4, "int8": qint8, "float8": qfloat8}[k]


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="google/owlv2-base-patch16")
    parser.add_argument("--image", type=str, required=True)
parser.add_argument("--texts", type=str, nargs="+", required=True) parser.add_argument("--weights", type=str, default="none", choices=["none", "int4", "int8", "float8"]) parser.add_argument("--exclude-heads", action="store_true", help="Do not quantize detection heads") parser.add_argument("--device", type=str, default=None, help="The device to use for generation.") args = parser.parse_args() if args.device is None: if torch.cuda.is_available(): device = torch.device("cuda") elif torch.backends.mps.is_available(): # MPS backend does not support torch.float64 that is required for owl models device = torch.device("cpu") elif torch.xpu.is_available(): device = torch.device("xpu") else: device = torch.device("cpu") else: device = torch.device(args.device) processor = AutoProcessor.from_pretrained(args.model) model = Owlv2ForObjectDetection.from_pretrained(args.model, low_cpu_mem_usage=True).to(device) weights_qtype = keyword_to_qtype(args.weights) if weights_qtype is not None: if args.exclude_heads: quantize(model.owlv2, weights=weights_qtype) else: quantize(model, weights=weights_qtype) freeze(model) memory = get_device_memory(device) if memory is not None: memory_gb = memory / 2**30 print(f"{device.type} device memory: {memory_gb:.2f} GB.") image_path = args.image if image_path.startswith("http"): image_path = requests.get(args.image, stream=True).raw image = Image.open(image_path) texts = [args.texts] detect(model, processor, image, texts) if __name__ == "__main__": main() ================================================ FILE: examples/vision/text-to-image/quantize_pixart_sigma.py ================================================ import argparse import gc import torch from diffusers import DiffusionPipeline from optimum.quanto import freeze, qfloat8, qint4, qint8, quantize NUM_INFERENCE_STEPS = 50 TORCH_DTYPES = {"fp16": torch.float16, "bf16": torch.bfloat16} QTYPES = { "fp8": qfloat8, "int8": qint8, "int4": qint4, "none": None, } def load_pipeline(model_id, 
torch_dtype, qtype=None, device="cpu"): pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype, use_safetensors=True).to(device) if qtype: quantize(pipe.transformer, weights=qtype) freeze(pipe.transformer) quantize(pipe.text_encoder, weights=qtype) freeze(pipe.text_encoder) pipe.set_progress_bar_config(disable=True) return pipe def get_device_memory(device): gc.collect() if device.type == "cuda": torch.cuda.empty_cache() return torch.cuda.memory_allocated() elif device.type == "mps": torch.mps.empty_cache() return torch.mps.current_allocated_memory() elif device.type == "xpu": torch.xpu.empty_cache() return torch.xpu.memory_allocated() return None if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model_id", type=str, default="PixArt-alpha/PixArt-Sigma-XL-2-1024-MS") parser.add_argument("--prompt", type=str, default="ghibli style, a fantasy landscape with castles") parser.add_argument("--torch_dtype", type=str, default="fp16", choices=list(TORCH_DTYPES.keys())) parser.add_argument("--qtype", type=str, default=None, choices=list(QTYPES.keys())) parser.add_argument("--device", type=str, default=None, help="The device to use for generation.") args = parser.parse_args() if args.device is None: if torch.cuda.is_available(): device = torch.device("cuda") elif torch.backends.mps.is_available(): device = torch.device("mps") elif torch.xpu.is_available(): device = torch.device("xpu") else: device = torch.device("cpu") else: device = torch.device(args.device) pipeline = load_pipeline( args.model_id, TORCH_DTYPES[args.torch_dtype], QTYPES[args.qtype] if args.qtype else None, device ) print(f"torch_dtype: {args.torch_dtype}, qtype: {args.qtype}.") memory = get_device_memory(device) if memory is not None: memory_gb = memory / 2**30 print(f"{device.type} device memory: {memory_gb:.2f} GB.") if args.qtype == "int4" and device.type == "CUDA": raise ValueError("This example does not work (yet) for int4 on CUDA") img_name = 
f"pixart-sigma-dtype@{args.torch_dtype}-qtype@{args.qtype}.png" image = pipeline( prompt=args.prompt, num_inference_steps=NUM_INFERENCE_STEPS, num_images_per_prompt=1, generator=torch.manual_seed(0), ).images[0] image.save(img_name) ================================================ FILE: external/awq/conftest.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import pytest import torch devices = ["cpu"] if torch.cuda.is_available(): devices += ["cuda"] elif torch.backends.mps.is_available(): devices += ["mps"] @pytest.fixture(scope="module", params=devices) def device(request): return torch.device(request.param) def pytest_configure(config): # register additional markers config.addinivalue_line("markers", "skip_device(type): mark test to be skipped for the specified device type") def pytest_runtest_call(item): fixture_name = "device" if fixture_name in item.fixturenames: # TODO: should be able to recover the fixture id instead of the actual value fixture_arg = item.funcargs[fixture_name].type skip_marks = {mark.args[0] for mark in item.iter_markers(name=f"skip_{fixture_name}")} if fixture_arg in skip_marks: pytest.skip(f"Test skipped for {fixture_name} {fixture_arg}") ================================================ FILE: external/awq/pack_intweight.py ================================================ # MIT License # # Copyright (c) 2023 MIT HAN Lab # # Permission is hereby 
granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. import torch def pack_intweight(unpacked_qweight, interleave, kstride): # unpacked_qweight: [N, K] N = unpacked_qweight.shape[0] K = unpacked_qweight.shape[1] Packed_Kernel = unpacked_qweight.cpu().numpy().reshape(N, K // 32, 32) # np.arange(32).reshape(4, 4, 2).transpose(1, 0, 2) => [0, 1, 8, 9, 16, 17, 24, 25, ...] 
Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 3, 2, 4) Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 32) # reorder each 8 weights for fast dequantization # [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7] Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 8) Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 2, 4, 3) Packed_Kernel = Packed_Kernel.reshape(N, K) # interleaving every four rows Packed_Kernel = Packed_Kernel.reshape( N // interleave, interleave, K // kstride, kstride ) # N // 4, K // 64, 4, 64 Packed_Kernel = Packed_Kernel.transpose(0, 2, 1, 3) Packed_Kernel = Packed_Kernel.reshape( N // interleave, K // kstride, kstride, interleave ) # Packing -> (N // 4, K // 64, 64) Packed_Kernel = ( Packed_Kernel[..., 0] | (Packed_Kernel[..., 1] << 4) | (Packed_Kernel[..., 2] << 8) | (Packed_Kernel[..., 3] << 12) ) # reshape to (N // 4, K), FP16 format Packed_Kernel = Packed_Kernel.reshape(N // interleave, K) qweight = ( torch.tensor(Packed_Kernel.astype("int16")) .to(unpacked_qweight.device) .contiguous() ) return qweight ================================================ FILE: external/awq/packing_utils.py ================================================ import torch AWQ_ORDER = [0, 2, 4, 6, 1, 3, 5, 7] AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7] def pack_awq(intweight: torch.Tensor, reorder=False): bits = 4 pack_num = 32 // bits qweight = torch.zeros(intweight.shape[0], intweight.shape[1] // pack_num, dtype=torch.int32, device=intweight.device) for col in range(intweight.shape[1] // pack_num): if reorder: order_map = [0, 2, 4, 6, 1, 3, 5, 7] else: order_map = [0, 1, 2, 3, 4, 5, 6, 7] for i in range(pack_num): qweight_col = intweight[:, col * pack_num + order_map[i]] qweight[:, col] |= qweight_col << (i * bits) return qweight def unpack_awq(qweight: torch.Tensor, bits: int): shifts = torch.arange(0, 32, bits, device=qweight.device) # unpacking columnwise iweights = torch.bitwise_right_shift(qweight[:, 
:, None], shifts[None, None, :]).to( torch.int8 # smallest dtype available ) iweights = iweights.view(iweights.shape[0], -1) return iweights def reverse_awq_order(iweights: torch.Tensor, bits: int): reverse_order_tensor = torch.arange( iweights.shape[-1], dtype=torch.int32, device=iweights.device, ) reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits) reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER] reverse_order_tensor = reverse_order_tensor.view(-1) iweights = iweights[:, reverse_order_tensor] return iweights def pack_exllama(iweights: torch.Tensor, izeros: torch.Tensor, bits: int): shifts = torch.arange(0, 32, bits, device=iweights.device) # packing rowwise iweights = iweights.view(iweights.shape[0] // (32 // bits), 32 // bits, -1) qweight = ( torch.bitwise_left_shift(iweights, shifts[None, :, None]) .sum(dim=1) .to(torch.int32) ) # packing columnwise izeros = izeros.view(-1, izeros.shape[1] // (32 // bits), 32 // bits) qzeros = ( torch.bitwise_left_shift(izeros, shifts[None, None, :]) .sum(dim=-1) .to(torch.int32) ) return qweight, qzeros def unpack_reorder_pack(qweight, qzeros, bits): # Unpack the qweight and qzeros tensors iweight, izeros = unpack_awq(qweight, qzeros, bits) # Reverse the order of the iweight and izeros tensors iweight, izeros = reverse_awq_order(iweight, izeros, bits) # overflow checks iweight = torch.bitwise_and(iweight, (2**bits) - 1) izeros = torch.bitwise_and(izeros, (2**bits) - 1) # Subtract 1 from the izeros tensor (exllama adds 1 during inference) # We can remove it if we remove the +1 in the exllama code izeros = izeros - 1 # Pack the qweight and qzeros tensors qweight, qzeros = pack_exllama(iweight, izeros, bits) return qweight, qzeros def dequantize_gemm(qweight, qzeros, scales, bits, group_size): # Unpack the qweight and qzeros tensors iweight, izeros = unpack_awq(qweight, qzeros, bits) # Reverse the order of the iweight and izeros tensors iweight, izeros = reverse_awq_order(iweight, izeros, bits) # 
overflow checks iweight = torch.bitwise_and(iweight, (2**bits) - 1) izeros = torch.bitwise_and(izeros, (2**bits) - 1) # fp16 weights scales = scales.repeat_interleave(group_size, dim=0) izeros = izeros.repeat_interleave(group_size, dim=0) iweight = (iweight - izeros) * scales return iweight ================================================ FILE: external/awq/test_awq_kernels.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import pytest import torch from pack import pack_awq from optimum.quanto import AffineQuantizer, MaxOptimizer, qint4, ungroup def assert_similar(a, b, atol=None, rtol=None): """Verify that the cosine similarity of the two inputs is close to 1.0 everywhere""" assert a.dtype == b.dtype assert a.shape == b.shape if atol is None: # We use torch finfo resolution atol = torch.finfo(a.dtype).resolution if rtol is None: # Please refer to that discussion for default rtol values based on the float type: # https://scicomp.stackexchange.com/questions/43111/float-equality-tolerance-for-single-and-half-precision rtol = {torch.float32: 1e-5, torch.float16: 1e-3, torch.bfloat16: 1e-1}[a.dtype] sim = torch.nn.functional.cosine_similarity(a.flatten(), b.flatten(), dim=0) if not torch.allclose(sim, torch.tensor(1.0, dtype=sim.dtype), atol=atol, rtol=rtol): max_deviation = torch.min(sim) raise ValueError(f"Alignment {max_deviation:.8f} deviates too much from 1.0 with atol={atol}, rtol={rtol}") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("in_features, out_features", [(256, 256), (512, 256)]) @pytest.mark.parametrize("kernel", ["gemv", "gemm"]) def test_standalone_kernel(in_features, out_features, kernel): """This test verifies that the GEMM operation is equivalent to torch.mm. """ bits = 4 group_size = 128 # Hard-coded in kernels interleave = 4 # Hard-coded in kernels kstride = 64 # Hard-coded in kernels device = torch.device('cuda') batch_size, tokens = (4, 1) if kernel =="gemv" else (10, 128) input_shape = (batch_size, tokens, in_features) # FIXME: does not work if inputs are negative !!?? 
    inputs = torch.rand(input_shape, dtype=torch.float16, device=device)
    qmax = 2**bits
    other_shape = (out_features, in_features)
    other_data = torch.randint(0, qmax, other_shape, dtype=torch.uint8, device=device)
    #packed_other_data = pack_intweight(other_data.to(torch.int32), interleave=interleave, kstride=kstride)
    packed_other_data = pack_awq(other_data.to(torch.int32), interleave=interleave, kstride=kstride)
    # The GEMM kernel works on transposed scales
    scales_shape = (in_features // group_size, out_features)
    other_scales = torch.rand(scales_shape, dtype=torch.float16, device=device) / qmax
    # The GEMM kernel works on transposed, negated and scaled zeropoints
    qmin = -2**(bits - 1)
    qmax = 2**(bits - 1)
    other_zeropoints = torch.randint(qmin, qmax, scales_shape, dtype=torch.int8, device=device)
    # Negate and scale
    other_scaled_zeropoints = -other_zeropoints * other_scales
    # Evaluate mm outputs using the GEMM kernel
    if kernel == "gemv":
        awq_outputs = torch.ops.quanto.gemv(
            inputs, packed_other_data, other_scales, other_scaled_zeropoints,
            rows=inputs.numel() // inputs.shape[-1], out_cols=out_features,
            in_cols=in_features, bits=4, group_size=group_size
        )
    else:
        awq_outputs = torch.ops.quanto.gemm(
            inputs, packed_other_data, other_scales, other_scaled_zeropoints,
            rows=inputs.numel() // inputs.shape[-1], out_cols=out_features,
            in_cols=in_features, bits=4, group_size=group_size
        )
    # Transpose other data and reshape it to align it with transposed scales and zeros
    other_data_t = other_data.t().reshape(group_size, in_features // group_size, out_features)
    # Dequantize transposed other
    other_t = (other_data_t - other_zeropoints) * other_scales
    # Reshape it as expected by the matmul
    other_t = other_t.reshape(in_features, out_features)
    # Evaluate the matrix multiplication using pytorch float16 mm
    pt_outputs = torch.matmul(inputs, other_t)
    # Verify the results are similar
    assert_similar(awq_outputs, pt_outputs, rtol=5e-3)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("in_features, out_features", [(256, 256), (512, 256)])
@pytest.mark.parametrize("kernel", ["gemm", "gemv"])
def test_integrated_kernel(in_features, out_features, kernel):
    # End-to-end check: quanto quantization + AWQ repacking must match a pure
    # quanto matmul on the same quantized weights.
    group_size = 128  # Hard-coded in kernels
    interleave = 4  # Hard-coded in kernels
    kstride = 64  # Hard-coded in kernels
    device = torch.device('cuda')
    batch_size, tokens = (4, 1) if kernel == "gemv" else (10, 128)
    input_shape = (batch_size, tokens, in_features)
    inputs = torch.rand(input_shape, dtype=torch.float16, device=device) * 2 - 1
    other_shape = (out_features, in_features)
    other = torch.rand(other_shape, dtype=torch.float16, device=device) * 2 - 1
    # Quantize using quanto
    scale, zeropoint = MaxOptimizer()(other, bits=4, axis=0, group_size=128)
    quanto_base = AffineQuantizer.apply(other, qint4, 0, group_size, scale, zeropoint)
    # Evaluate mm
    quanto_outputs = torch.matmul(inputs, quanto_base.t())
    # Extract quantized data, unpack and ungroup to recover original shape
    quanto_data = ungroup(quanto_base._data.unpack(), axis=0, orig_shape=other_shape)
    # Pack data for AWQ kernel
    awq_data = pack_awq(quanto_data.to(torch.int32), interleave=interleave, kstride=kstride)
    # Reshape and transpose scale as expected by AWQ kernel (! buffer must be contiguous)
    awq_scale = scale.reshape(out_features, in_features // group_size).t().contiguous()
    # Reshape and transpose zeropoint as expected by AWQ kernel (! buffer must be contiguous)
    awq_zeropoint = zeropoint.reshape(out_features, in_features // group_size).t().contiguous()
    # Negate and rescale
    awq_scaled_zeropoint = -awq_zeropoint * awq_scale
    # Evaluate mm outputs using the AWQ kernels
    if kernel == "gemv":
        awq_outputs = torch.ops.quanto.gemv(
            inputs, awq_data, awq_scale, awq_scaled_zeropoint,
            rows=inputs.numel() // inputs.shape[-1], out_cols=out_features,
            in_cols=in_features, bits=4, group_size=group_size
        )
    else:
        awq_outputs = torch.ops.quanto.gemm(
            inputs, awq_data, awq_scale, awq_scaled_zeropoint,
            rows=inputs.numel() // inputs.shape[-1], out_cols=out_features,
            in_cols=in_features, bits=4, group_size=group_size
        )
    # Verify the results are similar
    assert_similar(awq_outputs, quanto_outputs, rtol=5e-3)


================================================
FILE: external/awq/test_awq_packing.py
================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
import torch
from pack_intweight import pack_intweight
from packing_utils import pack_awq, reverse_awq_order, unpack_awq

from optimum.quanto import AWQPackedTensor, AWQPacking


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("in_features", [128, 256, 512, 1024])
@pytest.mark.parametrize("out_features", [128, 256, 512, 1024])
@pytest.mark.parametrize("reorder", [True, False])
@pytest.mark.parametrize("random", [True, False])
def test_awq_pack(in_features, out_features, reorder, random):
    """This test verifies two things:

    - that we are able to replicate awq packing,
    - that we can unpack awq packed tensors and recover the original tensor.
    """
    bits = 4
    interleave = 4
    kstride = 64
    qmax = 2**bits
    shape = (out_features, in_features)
    device = torch.device('cuda')
    if random:
        # Random 4-bit values in [0, qmax)
        t = torch.randint(0, qmax, shape, dtype=torch.uint8).to(device)
    else:
        # Deterministic ramp of 4-bit values, useful for debugging packing order
        numel = np.prod(shape)
        t = torch.tensor(range(numel), dtype=torch.int32)
        t = (t % qmax).reshape(shape).to(torch.uint8).to(device)
    # Reference packing from the AWQ helper
    packed = pack_awq(t.to(torch.int32), reorder=reorder)
    # Sanity check: verify we can recover the Tensor using AWQ unpacking
    unpacked = unpack_awq(packed, bits=4)
    if reorder:
        unpacked = reverse_awq_order(unpacked, bits=4)
    # Mask out the upper bits left over by unpacking
    unpacked = torch.bitwise_and(unpacked, qmax - 1)
    assert torch.equal(t, unpacked)
    # Compare with quanto packing
    repacked = AWQPackedTensor.pack(t, packing=AWQPacking.V1, reorder=reorder)
    assert torch.equal(packed, repacked._data)
    unpacked = repacked.unpack()
    assert torch.equal(unpacked, t)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("in_features", [128, 256, 512, 1024])
@pytest.mark.parametrize("out_features", [128, 256, 512, 1024])
@pytest.mark.parametrize("random", [True, False])
def test_awq_pack_v2(in_features, out_features, random):
    """This test verifies two things:

    - that we are able to replicate awq packing,
    - that we can unpack awq packed tensors and recover the original tensor.
    """
    bits = 4
    # interleave / kstride are the values hard-coded in the AWQ v2 kernels
    interleave = 4
    kstride = 64
    qmax = 2**bits
    shape = (out_features, in_features)
    device = torch.device('cuda')
    if random:
        t = torch.randint(0, qmax, shape, dtype=torch.uint8).to(device)
    else:
        numel = np.prod(shape)
        t = torch.tensor(range(numel), dtype=torch.int32)
        t = (t % qmax).reshape(shape).to(torch.uint8).to(device)
    # Reference packing from the AWQ v2 helper
    packed = pack_intweight(t.to(torch.int32), interleave=interleave, kstride=kstride)
    # Compare with quanto packing
    repacked = AWQPackedTensor.pack(t, packing=AWQPacking.V2)
    assert torch.equal(packed, repacked._data)
    unpacked = repacked.unpack()
    assert torch.equal(unpacked, t)

================================================
FILE: external/awq/test_awq_quantize.py
================================================
import pytest
import torch

from optimum.quanto import AffineQuantizer, MaxOptimizer, qint4, ungroup


def awq_quantize(base, scales, zeros, group_size):
    # Reference AWQ group-wise affine quantization, column by column.
    # From https://github.com/casper-hansen/AutoAWQ/blob/main/awq/modules/linear/gemv_fast.py#L165
    _, in_features = base.shape
    scale_zeros = scales * zeros
    intweight = []
    for idx in range(in_features):
        intweight.append(
            torch.round(
                (base[:, idx] + scale_zeros[:, idx // group_size]) / scales[:, idx // group_size]
            ).to(torch.uint8)[:, None]
        )
    intweight = torch.cat(intweight, dim=1)
    return intweight


@pytest.mark.parametrize("in_features, out_features", [(256, 512), (1024, 1024)])
def test_awq_quantize(in_features, out_features):
    """Verify that AWQ quantization is equivalent to quanto affine quantization"""
    shape = (out_features, in_features)
    base = torch.rand(shape, dtype=torch.float16)
    group_size = 128
    # Quantize using quanto
    scale, zeropoint = MaxOptimizer()(base, bits=4, axis=0, group_size=128)
    quanto_base = AffineQuantizer.apply(base, qint4, 0, group_size, scale, zeropoint)
    # Extract quantized data, unpack and ungroup to recover original shape
    quanto_data = ungroup(quanto_base._data.unpack(), axis=0, orig_shape=shape)
    # Reshape scale and zeropoint as expected by awq
    awq_shape = (out_features, in_features // group_size)
    scale = scale.reshape(awq_shape)
    zeropoint = zeropoint.reshape(awq_shape)
    # Compare with awq quantization
    awq_data = awq_quantize(base, scale, zeropoint, group_size)
    # FIX: AWQ does not clamp values before packing
    qmax = 2 ** 4 - 1
    awq_data = torch.clamp(awq_data, 0, qmax)
    mismatches = quanto_data != awq_data
    n = torch.sum(mismatches).numpy()
    rate = n / base.numel()
    print(f"Mismatches: {n}/{base.numel()} ({rate:.8f} %)")
    # Extract mismatches
    display = 10
    quanto_values = torch.masked_select(quanto_data, mismatches)[:display]
    awq_values = torch.masked_select(awq_data, mismatches)[:display]
    print(f"First {display} mismatches")
    print(list(quanto_values.numpy()))
    print(list(awq_values.numpy()))
    # Due to a slightly different order of operations (zero is multiplied by scale before subtracting it),
    # there are some mismatches
    assert rate < 5e-4

================================================
FILE: external/smoothquant/README.md
================================================
# SmoothQuant original conversion script

This converts an OPT or Bloom [🤗 transformers](https://github.com/huggingface/transformers) model to a "smoothed" version, as described in [SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models](https://arxiv.org/abs/2211.10438).

```bash
$ python smoothquant.py --model facebook/opt-1.3b --save-path smoothed-models/facebook/opt-1.3b
```

Note: due to hard-coded assumptions on model architecture in the script this only works for OPT models that apply the layer_norm before the attention (`do_layer_norm_before=true` in `config.json`). This means all models but `facebook/opt-350m`.
================================================
FILE: external/smoothquant/smoothquant.py
================================================
import argparse
import functools
import os
import torch
import torch.nn as nn
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.bloom.modeling_bloom import BloomBlock
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralRMSNorm
from transformers.models.opt.modeling_opt import OPTDecoderLayer


def get_act_scales(model, tokenizer, dataset, num_samples=512, seq_len=512):
    """Collect per-channel absolute-max activation scales for every nn.Linear input.

    Runs `num_samples` dataset entries through the model with forward hooks on
    every Linear module and records, per module name, the running per-channel
    maximum of the absolute input activations (stored on CPU as float).
    """
    model.eval()
    device = next(model.parameters()).device
    act_scales = {}

    def stat_tensor(name, tensor):
        # Flatten to (tokens, hidden_dim) and take the per-channel abs max.
        hidden_dim = tensor.shape[-1]
        tensor = tensor.view(-1, hidden_dim).abs().detach()
        comming_max = torch.max(tensor, dim=0)[0].float().cpu()
        if name in act_scales:
            # Keep the element-wise max across batches
            act_scales[name] = torch.max(act_scales[name], comming_max)
        else:
            act_scales[name] = comming_max

    def stat_input_hook(m, x, y, name):
        # Forward hooks receive inputs as a tuple; stats are on the first input.
        if isinstance(x, tuple):
            x = x[0]
        stat_tensor(name, x)

    hooks = []
    for name, m in model.named_modules():
        if isinstance(m, nn.Linear):
            hooks.append(m.register_forward_hook(functools.partial(stat_input_hook, name=name)))

    for i in tqdm(range(num_samples)):
        input_ids = tokenizer(
            dataset[i]["text"], return_tensors="pt", max_length=seq_len, truncation=True
        ).input_ids.to(device)
        model(input_ids)

    # Remove the hooks so the model is left unmodified.
    for h in hooks:
        h.remove()

    return act_scales


@torch.no_grad()
def smooth_ln_fcs(ln, fcs, act_scales, alpha=0.5):
    """Migrate activation outliers from a layer-norm into the following linear layers.

    Divides the layer-norm weight (and bias, when present) by a per-channel
    scale and multiplies the linear weights by the same scale, leaving the
    overall function unchanged while smoothing activation magnitudes.
    `alpha` balances activation vs weight ranges as in the SmoothQuant paper.
    Mutates `ln` and every module in `fcs` in place.
    """
    if not isinstance(fcs, list):
        fcs = [fcs]
    assert isinstance(ln, (nn.LayerNorm, LlamaRMSNorm, MistralRMSNorm))
    for fc in fcs:
        assert isinstance(fc, nn.Linear)
        # All consumers must share the layer-norm output dimension.
        assert ln.weight.numel() == fc.in_features == act_scales.numel()
    device, dtype = fcs[0].weight.device, fcs[0].weight.dtype
    act_scales = act_scales.to(device=device, dtype=dtype)
    # Per-input-channel abs max across all consuming linear layers.
    weight_scales = torch.cat([fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0)
    weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)
    # SmoothQuant scale: s = act^alpha / weight^(1-alpha), clamped for stability.
    scales = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5).to(device).to(dtype)
    ln.weight.div_(scales)
    # RMSNorm variants have no bias, hence the getattr guard.
    if getattr(ln, 'bias', None) is not None:
        ln.bias.div_(scales)
    for fc in fcs:
        fc.weight.mul_(scales.view(1, -1))


@torch.no_grad()
def smooth_lm(model, scales, alpha=0.5):
    """Apply SmoothQuant smoothing to every supported decoder layer of `model`.

    Supports OPT, Bloom, Llama and Mistral decoder layers; `scales` is the dict
    produced by `get_act_scales`, keyed by module name. Mutates the model in place.
    """
    for name, module in model.named_modules():
        if isinstance(module, OPTDecoderLayer):
            # Attention block: q/k/v share the same layer-norm input,
            # so the q_proj input scales are representative for all three.
            attn_ln = module.self_attn_layer_norm
            qkv = [module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj]
            qkv_input_scales = scales[name + ".self_attn.q_proj"]
            smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha)
            # Feed-forward block
            ffn_ln = module.final_layer_norm
            fc1 = module.fc1
            fc1_input_scales = scales[name + ".fc1"]
            smooth_ln_fcs(ffn_ln, fc1, fc1_input_scales, alpha)
        elif isinstance(module, BloomBlock):
            # Bloom fuses q/k/v into a single projection.
            attn_ln = module.input_layernorm
            qkv = module.self_attention.query_key_value
            qkv_input_scales = scales[name + ".self_attention.query_key_value"]
            smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha)
            ffn_ln = module.post_attention_layernorm
            fc1 = module.mlp.dense_h_to_4h
            fc1_input_scales = scales[name + ".mlp.dense_h_to_4h"]
            smooth_ln_fcs(ffn_ln, fc1, fc1_input_scales, alpha)
        elif isinstance(module, (LlamaDecoderLayer, MistralDecoderLayer)):
            attn_ln = module.input_layernorm
            qkv = [module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj]
            qkv_input_scales = scales[name + ".self_attn.q_proj"]
            smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha)
            # Gated MLP: gate_proj and up_proj both consume the post-attention norm output.
            ffn_ln = module.post_attention_layernorm
            fc = [module.mlp.gate_proj, module.mlp.up_proj]
            fc_input_scales = scales[name + ".mlp.gate_proj"]
            smooth_ln_fcs(ffn_ln, fc, fc_input_scales, alpha)


def main():
    """CLI entry point: load a model, calibrate activation scales, smooth and save it."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="facebook/opt-125m", help="model name")
    parser.add_argument("--save-path", type=str, default=None, help="smoothed model model save path")
    parser.add_argument("--num-samples", type=int, default=512)
    parser.add_argument("--seq-len", type=int, default=512)
    parser.add_argument("--device", type=str, default=None, help="The device to use for generation.")
    args = parser.parse_args()
    # Pick the best available device when none is specified.
    if args.device is None:
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")
    else:
        device = torch.device(args.device)
    # Calibration data: a shuffled slice of the lambada validation split.
    dataset = load_dataset("lambada", split=f"validation[:{args.num_samples}]").shuffle()
    tokenizer = AutoTokenizer.from_pretrained(args.model, model_max_length=args.seq_len)
    model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype="auto").to(device)
    act_scales = get_act_scales(model, tokenizer, dataset, args.num_samples, args.seq_len)
    smooth_lm(model, act_scales, 0.5)
    save_path = args.save_path
    if save_path is None:
        save_path = os.path.join("smoothed_models", args.model)
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)


if __name__ == "__main__":
    main()

================================================
FILE: optimum/quanto/__init__.py
================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.2.7dev"

from .calibrate import *
from .library import *
from .models import *
from .nn import *
from .quantize import *
from .tensor import *

================================================
FILE: optimum/quanto/calibrate.py
================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
from torch.nn.modules.module import (
    register_module_forward_hook,
    register_module_forward_pre_hook,
)
from torch.overrides import TorchFunctionMode

from .nn import QModuleMixin
from .tensor import ActivationQBytesTensor, QTensor, axis_to_dim, dtype_info, qint8, qtype


__all__ = ["Calibration", "absmax_scale"]


def _updated_scale(scale, new_scale, momentum):
    # A scale of all-ones is the initialization value: replace it outright
    # instead of blending, so the first observation is not diluted.
    if torch.all(scale == 1):
        return new_scale
    return momentum * scale + new_scale * (1.0 - momentum)


def absmax_scale(base: torch.Tensor, qtype: qtype = qint8, axis: Optional[int] = None) -> torch.Tensor:
    """Evaluate the quantization scale using the absmax algorithm.

    The Absolute Maximum quantization algorithm is a symmetrical quantization
    algorithm where the scale corresponds to the maximum absolute value of the
    base divided by the highest positive integer value for the target integer
    representation.

    Args:
        base (`torch.Tensor`): the base tensor on which the scale will be applied.
        qtype (`quanto.qtype`): the target qtype for quantization.
        axis (`int`): the index of the axis to preserve, or -1 for the last one.
            Defaults to None to reduce all axis.

    Returns:
        `torch.Tensor`: a scale tensor of the same dtype as the base.
    """
    base = torch.abs(base)
    if axis is None:
        # Per-tensor scale
        qranges = torch.max(base)
    else:
        # Per-axis scale: reduce all dimensions except the preserved axis.
        dim = axis_to_dim(base, axis)
        qranges = torch.amax(base, dim=dim, keepdim=True)
    info = dtype_info(qtype.dtype)
    return qranges / info.max


class Calibration(TorchFunctionMode):
    """A custom torch dispatch mode to calibrate quantized modules.

    In order to improve the accuracy of the quantized activations, the input and
    output scales of each quantized module are evaluated per-batch using the
    absmax algorithm and aggregated using a momentum.

    The dispatch mode also tracks the calls to each torch function down the model
    graph, and applies optional optimizations:

    - streamline: do not quantize activations that are immediately consumed by an
      incompatible function (like `add` or `silu`).

    Args:
        momentum (`float`): the momentum to use when updating scales.
        streamline (`bool`): if True, avoid quantizing activations when they are
            consumed by an incompatible function. Defaults to True.
        debug (`bool`): provide very verbose feedback on the console during calibration.
    """

    def __init__(self, *args, momentum: float = 0.9, streamline=True, debug=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.momentum = momentum
        self.streamline = streamline
        if streamline:
            # Maps each module to True when its quantized outputs are required downstream.
            self.modules_qactivations = {}
            # Per-module forward hooks used to tag outputs; removed on __exit__.
            self.streamline_hooks = {}
        self.debug = debug

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs is not None else {}
        # True when at least one argument is a quantized tensor.
        qinput = QTensor in types
        output = func(*args, **kwargs)
        if self.streamline and qinput:
            for i, arg in enumerate(args):
                # src_module is set by tag_outputs on tensors produced by a module.
                module = getattr(arg, "src_module", None)
                if module is not None:
                    if isinstance(output, ActivationQBytesTensor):
                        # Quantized activations are required for that module
                        self.modules_qactivations[module] = True
                    elif isinstance(output, torch.Tensor):
                        # Quantized activations are not required for that module unless another function requires them
                        qactivations_required = self.modules_qactivations.get(module, False)
                        self.modules_qactivations[module] = qactivations_required
        return output

    def __enter__(self):
        super().__enter__()
        # Global hooks fire for every module forward during calibration.
        self.pre_handle = register_module_forward_pre_hook(self.calibrate_input)
        self.post_handle = register_module_forward_hook(self.calibrate_output)

    def __exit__(self, exc_type, exc_val, exc_tb):
        super().__exit__(exc_type, exc_val, exc_tb)
        self.pre_handle.remove()
        self.post_handle.remove()
        if self.streamline:
            for handle in self.streamline_hooks.values():
                handle.remove()

    def calibrate_input(self, module: torch.nn.Module, input, momentum: float = 0.9):
        """Calibrate a module input scale

        This is registered as a global hook that is called before any module
        forward pre hook.
        """
        # NOTE(review): the update below uses the `momentum` parameter (always its
        # default 0.9, since the hook is registered without arguments), not
        # `self.momentum` as calibrate_output does — confirm this is intentional.
        if isinstance(module, QModuleMixin) and module.activation_qtype is not None:
            # Forward pre hooks receive inputs as a tuple.
            input = input[0]
            if isinstance(input, ActivationQBytesTensor):
                # Just adopt the maximum scale of the input
                module.input_scale = torch.max(input._scale)
            else:
                # Evaluate the best scale
                input_scale = absmax_scale(input, module.activation_qtype)
                module.input_scale = _updated_scale(module.input_scale, input_scale, momentum)
            if self.streamline and module not in self.streamline_hooks:
                # Add a hook to tag the module outputs (after the module quantization hook in QModuleMixin)
                self.streamline_hooks[module] = module.register_forward_hook(self.tag_outputs)
            return input

    def calibrate_output(
        self,
        module: torch.nn.Module,
        input: torch.Tensor,
        output: torch.Tensor,
    ):
        """Calibrate a module output scale

        This is registered as a global hook that is called before any module forward hook.

        When the module is a QModuleMixin, its outputs are not quantized yet because they
        are only quantized in the QModuleMixin.quantize_output forward hook.
        """
        if isinstance(module, (QModuleMixin)) and module.activation_qtype is not None:
            # Evaluate the optimal scale per-tensor and update output scale
            output_scale = absmax_scale(output, module.activation_qtype, axis=None)
            module.output_scale = _updated_scale(module.output_scale, output_scale, self.momentum)
            return output
        else:
            # Non-quantized parent module: use the information gathered by
            # __torch_function__ to decide whether each quantized child actually
            # needs quantized outputs.
            if self.streamline:
                for name, child in module.named_children():
                    if isinstance(child, QModuleMixin) and child.activation_qtype is not None:
                        qactivations_required = self.modules_qactivations.get(child, False)
                        if not qactivations_required:
                            # Disable output quantization for this child as its outputs are only consumed by incompatible functions.
                            child.disable_output_quantization()
            if self.debug:
                for name, child in module.named_children():
                    if isinstance(child, QModuleMixin):
                        classname = child.__class__.__name__
                        trace = f"{name}({classname}) activations are"
                        if child.activation_qtype is None:
                            trace += " not quantized."
                        else:
                            trace += f" quantized to {child.activation_qtype} with scale {child.output_scale}."
                        print(trace)

    def tag_outputs(
        self,
        module: torch.nn.Module,
        input: torch.Tensor,
        output: torch.Tensor,
    ):
        """Mark outputs as generated by a module

        This is called as a module forward hook that is called after the
        QModuleMixin.quantize_output forward hook.

        This is useful in streamline mode to identify the module that generated
        a specific QTensor.
        """
        output.src_module = module

================================================
FILE: optimum/quanto/library/README.md
================================================
# Quanto operations library

This contains the `quanto::` operations, available in python under `torch.ops.quanto`.

To add a new operation:

- add a definition for the operation in `library/ops.py`,
- provide a default implementation using pytorch operators only under `library/python`,
- provide optimized kernels for all devices under `library/ext`.

================================================
FILE: optimum/quanto/library/__init__.py
================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .extensions import *
from .qbytes_mm import *
from .quantize import *
from .unpack import *

================================================
FILE: optimum/quanto/library/extensions/README.md
================================================
# Quanto library extensions

This folder contains device-specific `quanto::` operations.

Implementations can be provided as part of:

- the generic C++ pytorch extension under `cpp`,
- the CUDA extension under `cuda`,
- the Metal Performance Shader extension under `mps`,
- the XPU SYCL extension under `xpu`.

To provide a device-specific implementation of an operation that already has a default implementation (such as unpack), use the following syntax:

```python
@torch.library.impl("quanto::unpack", ["CPU", "CUDA"])
def unpack(packed: torch.Tensor, bits: int) -> torch.Tensor:
    return ext.unpack(t, bits)
```

To declare a new device-specific operation, you need to add it to the library:

```python
torch.library.define(
    "quanto::gemm_f16i4",
    "(Tensor input,"
    " Tensor other,"
    " Tensor other_scale,"
    " Tensor other_shift,"
    " int group_size)"
    " -> Tensor",
)
```

Then you can provide its implementation:

```python
@torch.library.impl("quanto::gemm_f16i4", ["CUDA"])
def gemm_f16i4(
    input: torch.Tensor,
    other: torch.Tensor,
    scales: torch.Tensor,
    shift: torch.Tensor,
    group_size: int,
) -> torch.Tensor:
    ...
```

Please refer to each extension folder for examples.

================================================
FILE: optimum/quanto/library/extensions/__init__.py
================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import platform

import torch
from packaging import version

from .cpp import *
from .extension import *


# Device-specific extensions are only imported when the corresponding
# backend is actually available at import time.
if torch.cuda.is_available() and platform.system() == "Linux":
    if torch.version.cuda:
        from .cuda import *
    elif torch.version.hip:
        from .hip import *
if torch.backends.mps.is_available():
    from .mps import *


def _is_xpu_available():
    # SYCL extension support is added in torch>=2.7 on Linux
    if platform.system() != "Linux":
        return False
    if version.parse(torch.__version__).release < version.parse("2.7").release:
        return False
    return torch.xpu.is_available()


if _is_xpu_available():
    from .xpu import *

================================================
FILE: optimum/quanto/library/extensions/cpp/README.md
================================================
# Quanto generic C++ extension

Kernels in this extension must use only the C++ syntax.

They can use any pytorch operation defined under `aten::` or `c10::`.

To add a new implementation for an operation defined in `library./ops.py`:

- add the corresponding `.cpp` file to the list of sources in `__init__.py`,
- add a binding to `pybind_module.cpp`,
- provide an implementation calling the binding in `__init__.py`.

================================================
FILE: optimum/quanto/library/extensions/cpp/__init__.py
================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch

from ..extension import Extension, register_extension


__all__ = []


# JIT-compiled generic C++ extension providing the CPU kernels.
ext = Extension(
    "quanto_cpp",
    root_dir=os.path.dirname(__file__),
    sources=["unpack.cpp", "pybind_module.cpp"],
    extra_cflags=["-O3"],
)
register_extension(ext)


@torch.library.impl("quanto::unpack", ["CPU"])
def unpack_cpp(t: torch.Tensor, bits: int):
    # CPU implementation of quanto::unpack, dispatched to the C++ extension.
    return ext.lib.unpack(t, bits)

================================================
FILE: optimum/quanto/library/extensions/cpp/pybind_module.cpp
================================================
// Copyright 2024 The HuggingFace Team. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// NOTE(review): the header name of the first include was lost in extraction
// (likely <torch/extension.h>) — restore before building.
#include
#include "unpack.h"

// !IMPORTANT! Some python objects such as dtype, device, are not mapped to C++ types,
// and need to be explicitly converted using dedicated helpers before calling a C++ method.
// As a consequence, when an operation takes such an object as parameter, instead
// of creating a binding directly to the C++ method, you must create a binding to a
// lambda method that converts the unmapped types and calls the C++ method.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("unpack", &unpack, "unpack");
}

================================================
FILE: optimum/quanto/library/extensions/cpp/unpack.cpp
================================================
// Copyright 2024 The HuggingFace Team. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "unpack.h"
// NOTE(review): header name lost in extraction (likely <torch/extension.h>) — restore.
#include

// Unpack two 4-bit values per byte: low nibbles first, then high nibbles,
// concatenated along dim 0.
static torch::Tensor unpack_4bit(torch::Tensor &t) {
    return torch::cat({
        (t & 0x0F),
        (t & 0xF0).__rshift__(4)
    }, 0);
}

// Unpack four 2-bit values per byte, lowest bit-pair first,
// concatenated along dim 0.
static torch::Tensor unpack_2bit(torch::Tensor &t) {
    return torch::cat({
        (t & 0x03),
        (t & 0x0C).__rshift__(2),
        (t & 0x30).__rshift__(4),
        (t & 0xC0).__rshift__(6)
    }, 0);
}

// Unpack a uint8 tensor containing packed 2-bit or 4-bit values.
torch::Tensor unpack(torch::Tensor &t, int bits) {
    TORCH_CHECK(t.scalar_type() == torch::kUInt8, "Unsupported data type: ", t.scalar_type());
    switch(bits) {
      case 4:
        return unpack_4bit(t);
      case 2:
        return unpack_2bit(t);
      default:
        throw std::invalid_argument("Can only unpack 2-bit or 4-bit tensors.");
    }
}

================================================
FILE: optimum/quanto/library/extensions/cpp/unpack.h
================================================
// Copyright 2024 The HuggingFace Team. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// NOTE(review): header name lost in extraction (likely <torch/extension.h>) — restore.
#include

torch::Tensor unpack(torch::Tensor &t, int bits);

================================================
FILE: optimum/quanto/library/extensions/cuda/README.md
================================================
# Quanto generic CUDA extension

Kernels in this extension can use both the C++ and CUDA syntax.

They can use any pytorch operation defined under `aten::` or `c10::`.

To add a new implementation for an operation defined in `library./ops.py`:

- add the corresponding `.cpp` or `.cu` file to the list of sources in `__init__.py`,
- add a binding to `pybind_module.cpp`,
- provide an implementation calling the binding in `__init__.py`.

================================================
FILE: optimum/quanto/library/extensions/cuda/__init__.py
================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import torch

from ..extension import Extension, register_extension


__all__ = []


def get_max_cuda_arch():
    """Select the maximum CUDA arch supported

    This is a combination of the CUDA and pytorch version and all detected
    devices capabilities.
    """
    capability_list = []
    # Architectures pytorch was built for, e.g. "sm_80" -> 80.
    supported_sm = [int(arch.split("_")[1]) for arch in torch.cuda.get_arch_list() if "sm_" in arch]
    if supported_sm:
        max_supported_sm = max((sm // 10, sm % 10) for sm in supported_sm)
        for i in range(torch.cuda.device_count()):
            capability = torch.cuda.get_device_capability(i)
            # Capability of the device may be higher than what's supported by the user's
            # NVCC, causing compilation error. User's NVCC is expected to match the one
            # used to build pytorch, so we use the maximum supported capability of pytorch
            # to clamp the capability.
            capability = min(max_supported_sm, capability)
            if capability not in capability_list:
                capability_list.append(capability)
    max_capability = max(sorted(capability_list)) if len(capability_list) > 0 else (0, 0)
    # Returned as e.g. "800" for capability (8, 0).
    return f"{max_capability[0]}{max_capability[1]}0"


extra_cflags = ["-g", "-O3"]
extra_cuda_cflags = [
    "--expt-extended-lambda",
    "--use_fast_math",
]
# We need to know the minimum CUDA Arch to select only the relevant kernels
# but we cannot rely on __CUDA_ARCH__ as it is not set in host code (only on device code)
quanto_cuda_arch = get_max_cuda_arch()
extra_cuda_cflags += [f"-DQUANTO_CUDA_ARCH={quanto_cuda_arch}"]
module_path = os.path.dirname(__file__)
sources = [
    "unpack.cu",
    "awq/v2/gemm_cuda.cu",
    "awq/v2/gemv_cuda.cu",
    "marlin/fp8_marlin.cu",
    "marlin/gptq_marlin_repack.cu",
    "marlin/marlin_cuda.cpp",
    "marlin/marlin_cuda_kernel.cu",
    "pybind_module.cpp",
]
ext = Extension(
    "quanto_cuda",
    root_dir=os.path.dirname(__file__),
    sources=sources,
    extra_cflags=extra_cflags,
    extra_cuda_cflags=extra_cuda_cflags,
)
register_extension(ext)


@torch.library.impl("quanto::unpack", ["CUDA"])
def unpack_cuda(t: torch.Tensor, bits: int):
    # CUDA implementation of quanto::unpack, dispatched to the extension.
    return ext.lib.unpack(t, bits)


torch.library.define(
    "quanto::gemm_f16i4_awq",
    "(Tensor input,"
    " Tensor other,"
    " Tensor other_scale,"
    " Tensor other_shift,"
    " int rows,"
    " int out_cols,"
    " int in_cols,"
    " int bits,"
    " int group_size)"
    " -> Tensor",
)


@torch.library.impl("quanto::gemm_f16i4_awq", ["CUDA"])
def gemm_f16i4_awq(
    input: torch.Tensor,
    other: torch.Tensor,
    scales: torch.Tensor,
    shift: torch.Tensor,
    rows: int,
    out_cols: int,
    in_cols: int,
    bits: int,
    group_size: int,
):
    # Preconditions hard-coded in the AWQ v2 kernels.
    assert out_cols >= 128
    assert input.dtype == torch.float16
    assert input.numel() == rows * in_cols
    assert other.dtype == torch.int16
    assert scales.dtype == torch.float16
    assert scales.shape[-1] == out_cols
    assert shift.dtype == torch.float16
    assert shift.shape[-1] == out_cols
    assert bits == 4
    assert group_size == 128
    # Use the GEMV kernel for small row counts, the GEMM kernel otherwise.
    if rows < 8:
        return ext.lib.awq_v2_gemv_f16i4(input, other, scales, shift, rows, out_cols, in_cols, group_size)
    return ext.lib.awq_v2_gemm_f16i4(input, other, scales, shift)


torch.library.define(
    "quanto::gemm_f16f8_marlin",
    "(Tensor a,"
    "Tensor b_q_weight,"
    "Tensor b_scales,"
    "Tensor workspace,"
    "int num_bits,"
    "int size_m,"
    "int size_n,"
    "int size_k)"
    " -> Tensor",
)


@torch.library.impl("quanto::gemm_f16f8_marlin", ["CUDA"])
def fp8_marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    num_bits: int,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    assert b_scales.dtype == torch.float16 or b_scales.dtype == torch.bfloat16
    assert b_q_weight.dim() == 2
    assert b_q_weight.dtype == torch.int32
    return ext.lib.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k)


torch.library.define(
    "quanto::pack_fp8_marlin",
    "(Tensor b_q_weight, Tensor perm, int size_k, int size_n, int num_bits) -> Tensor",
)


@torch.library.impl("quanto::pack_fp8_marlin", ["CUDA"])
def gptq_marlin_repack(
    b_q_weight: torch.Tensor, perm: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    assert b_q_weight.dim() == 2
    assert b_q_weight.dtype == torch.int32
    return ext.lib.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits)


torch.library.define(
    "quanto::gemm_f16i4_marlin",
    "(Tensor input, Tensor other, Tensor other_scale, Tensor other_shift, Tensor workspace) -> Tensor",
)


@torch.library.impl("quanto::gemm_f16i4_marlin", ["CUDA"])
def gemm_f16i4_marlin(
    input: torch.Tensor, other: torch.Tensor, scales: torch.Tensor, zeropoint: torch.Tensor, workspace: torch.Tensor
) -> torch.Tensor:
    assert input.dtype == torch.float16
    assert other.dtype == torch.int32
    assert scales.dtype == torch.float16
    assert zeropoint.dtype == torch.float16
    assert workspace.dtype == torch.int32
    output = torch.empty(
        input.shape[:-1] + (scales.shape[1],),
        dtype=input.dtype,
        device=input.device,
    )
    # The kernel operates on 2D views; output is written in place.
    ext.lib.marlin_gemm_f16i4(
        input.reshape((-1, input.shape[-1])),
        other,
        output.reshape((-1, output.shape[-1])),
        scales,
        zeropoint,
        workspace,
        -1,
        -1,
        -1,
        16,
    )
    return output

================================================
FILE: optimum/quanto/library/extensions/cuda/awq/dequantize.cuh
================================================
/*
Modified from NVIDIA FasterTransformer:
https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h

@article{lin2023awq,
  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
  journal={arXiv},
  year={2023}
}
*/

// NOTE(review): header name lost in extraction (likely <cuda_fp16.h>) — restore.
#include

#pragma once

__inline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uint4 *result)
{
    // uint4 result;
    // NOTE(review): the reinterpret_cast template arguments were stripped by
    // extraction (likely <uint32_t *> and <uint32_t const &>) — restore.
    uint32_t *h = reinterpret_cast(result);
    uint32_t const i4s = reinterpret_cast(source);
    // First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; static constexpr uint32_t BOTTOM_MASK = 0x000f000f; static constexpr uint32_t TOP_MASK = 0x00f000f0; static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and // elt_67 to fp16 without having to shift them to the bottom bits before hand. // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue // immediately before required. const uint32_t top_i4s = i4s >> 8; // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[0]) : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[1]) : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[2]) : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[3]) : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the // half2 ctor. In this case, I chose performance reliability over code readability. // This is the half2 {1032, 1032} represented as an integer. 
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7] static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400; // This is the half2 {1 / 16, 1 / 16} represented as an integer. static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; // This is the half2 {-72, -72} represented as an integer. // static constexpr uint32_t NEG_72 = 0xd480d480; // Haotian: Let's use {-64, -64}. static constexpr uint32_t NEG_64 = 0xd400d400; // Finally, we construct the output numbers. // Convert elt_01 asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); // Convert elt_23 asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); // Convert elt_45 asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); // Convert elt_67 asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); // return result; } ================================================ FILE: optimum/quanto/library/extensions/cuda/awq/v2/gemm_cuda.cu ================================================ #include #include "semaphore.h" #include "gemm_cuda.h" #include "../dequantize.cuh" #include #include #if defined(QUANTO_CUDA_ARCH) and QUANTO_CUDA_ARCH >= 800 // The following GEMMs requires m16n8k16 which is only supported for CUDA arch after sm80 #define kInterleave 4 #define OP_M 16 #define OP_N 8 #define OP_K 16 #define INTRIN_M 16 #define INTRIN_N 16 #define INTRIN_K 16 #define WARP_SIZE 32 #define SMEM_PAD_A 0 #define SMEM_PAD_B 0 #define PACK_SIZE 8 #if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 4) #define L2_CACHEHINT(size) ".L2::" #size "B" #else #define L2_CACHEHINT(size) #endif #define KERNEL_LAUNCH_CODE \ int num_mn_tiles = (num_in_feats + CTA_M - 1) / CTA_M * (num_out_channels + CTA_N - 1) / CTA_N; \ torch::Tensor _semaphores = torch::empty({num_mn_tiles}, 
options_int); \
    auto semaphores = reinterpret_cast(_semaphores.data_ptr()); \
    constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N) * (CTA_K / WARP_K); \
    constexpr int SCALES_SMEM_SIZE = (G >= CTA_K) ? (CTA_N / (G / CTA_K) * STAGES * 2) : (CTA_N * (CTA_K / G) * STAGES * 2); \
    constexpr int kSmemByteSize = (CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / kInterleave + SCALES_SMEM_SIZE) * STAGES * sizeof(half); \
    if (kSmemByteSize >= 99 * 1024) \
    { \
        printf("This kernel requires %d Bytes of shared memory, which exceeds device limit.\n", kSmemByteSize); \
        return _out_feats; \
    } \
    int j_factors1 = num_out_channels / CTA_N / 1; \
    dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1 * SPLITK); \
    dim3 threads_per_block(WARP_SIZE, NUM_WARPS); \
    auto kernel_func = gemm_w4a16_T1; \
    cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize); \
    kernel_func<<>>( \
        in_feats, kernel, scales, zeros, out_feats, semaphores, num_in_feats, num_out_channels, num_in_channels);

// Returns the log2 swizzle-tile factor used to remap block indices for better
// L2 locality: 3, 2, 1 or 0 depending on how many column tiles are available.
// NOTE(review): the template parameter list (presumably <int N>) was stripped
// during extraction.
template
__inline__ __host__ __device__ int get_log_tile(int n)
{
    if (N >= 8 && n >= 6)
        return 3;
    else if (N >= 4 && n >= 3)
        return 2;
    else if (N >= 2 && n >= 2)
        return 1;
    else
        return 0;
}

// Maps a (blockIdx_x, blockIdx_y) pair onto a swizzled tile ordering using the
// log_tile factor computed by get_log_tile.
__inline__ __device__ uint2 get_block_idx_mapping(int blockIdx_x, int blockIdx_y, int log_tile)
{
    return make_uint2((blockIdx_x >> log_tile), (blockIdx_y << log_tile) + ((blockIdx_x) & ((1 << (log_tile)) - 1)));
}

// Barrier across the warps of one K-slice: a full __syncthreads() when there
// is a single slice, otherwise a named bar.sync per slice group.
template
__device__ void sync_slice(int slice_id)
{
    if constexpr (SLICES == 1)
    {
        __syncthreads();
    }
    else
    {
        constexpr int SLICE_GROUP = (SLICES + 7) / 8;
        constexpr uint32_t num_threads = NUM_WARPS_MN * WARP_SIZE;
        const uint32_t barrier_id = slice_id / SLICE_GROUP + 1;
        asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "n"(num_threads));
    }
}

// Converts a generic pointer into the 32-bit shared-memory address form that
// the ldmatrix / cp.async PTX instructions expect (cvta.to.shared + cvt).
__inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const *const ptr)
{
    uint32_t smem_int_ptr;
    asm("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
        : "=r"(smem_int_ptr)
        : "l"(ptr));
    return smem_int_ptr;
}

// ldmatrix.x4 load of four 8x8 b16 fragments from shared memory into the
// register fragment at shared_warp + ax0_0 * 8.
__inline__ __device__ void ldmatrix_m8n8_x4_b16(half *shared_warp, int ax0_0, uint32_t addr)
{
    __asm__ __volatile__(
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])
        : "r"(addr));
}

// Same as above but with the .trans (transposed fragment) variant.
__inline__ __device__ void ldmatrix_m8n8_x4_trans_b16(half *shared_warp, int ax0_0, uint32_t addr)
{
    __asm__ __volatile__(
        "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])
        : "r"(addr));
}

// Predicated 16-byte cp.async.cg copy from global to shared memory; the copy
// is skipped (predicate false) when mask is 0.
__inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const uint4 *__restrict__ src, bool mask)
{
    const int cp_size = 16;
    asm volatile("{"
                 " .reg .pred p;"
                 " setp.ne.b32 p, %0, 0;"
                 " @p cp.async.cg.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;"
                 "}" ::"r"((int)mask),
                 "r"(smem_int_ptr), "l"(src), "n"(cp_size));
}

// Single m16n8k16 tensor-core MMA: C += A * B with fp16 inputs and fp32
// accumulation, operating on the register fragments passed in.
__device__ __inline__ void mma_m16n8k16(float *C_warp, half *A_shared_warp, half *B_shared_warp)
{
    __asm__ __volatile__(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
        "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
        : "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3])
        : "r"(((unsigned *)A_shared_warp)[0]), "r"(((unsigned *)A_shared_warp)[1]), "r"(((unsigned *)A_shared_warp)[2]), "r"(((unsigned *)A_shared_warp)[3]), "r"(((unsigned *)B_shared_warp)[0]), "r"(((unsigned *)B_shared_warp)[1]), "f"(((float *)C_warp)[0]), "f"(((float *)C_warp)[1]), "f"(((float *)C_warp)[2]), "f"(((float *)C_warp)[3]));
}

// Copies one pipeline stage of the activation tile A from global memory into
// shared memory with a swizzled column layout; uses cp.async when STAGES > 1,
// a plain masked 16-byte store otherwise. Rows past global_nrows are masked.
template
__device__ __inline__ void global_to_share_one_stage_A(half *src, half *dst, int global_nrows, int global_ncols, int cta_offset_m, int cta_offset_n, int cta_offset_k, int global_iter_k, int shared_iter_k, bool mask)
{
    constexpr int threads_needed = (CTA_M * CTA_K) / PACK_SIZE / SHARED_K_ITERS;
    constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
    constexpr int total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / threads_used;
    constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
    constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K;
    constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
    constexpr int threads_per_row = CTA_K / PACK_SIZE;
    constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
    bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
    int ld_col = (threadIdx.x % threads_per_row);
#pragma unroll
    for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter)
    {
        int global_iter = shared_iter_k * partial_global_iters + _global_iter;
        int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row);
        // XOR swizzle of the destination column to avoid shared-memory bank conflicts.
        int ld_col_swizzled = (ld_col ^ (ld_row) & 7) * PACK_SIZE;
        void *dst_ptr = (void *)(dst + ld_row * kSmemCol + ld_col_swizzled);
        uint4 *src_ptr = (uint4 *)(src + (ld_row + cta_offset_m) * global_ncols + ld_col * PACK_SIZE + global_iter_k * CTA_K + cta_offset_k);
        // cta_offset_m * global_ncols + global_iter * cta_step_m_or_n * global_ncols + threadIdx.y * warp_step_m_or_n * global_ncols + (threadIdx.x / threads_per_row) * global_ncols + global_iter_k * CTA_K + (threadIdx.x % threads_per_row) * PACK_SIZE);
        if constexpr (STAGES > 1)
        {
            uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
            cp_async_cg_A(addr, src_ptr, local_mask & (ld_row + cta_offset_m < global_nrows));
        }
        else
        {
            if (local_mask & (ld_row + cta_offset_m < global_nrows))
                *(uint4 *)dst_ptr = *src_ptr;
        }
    }
}

// Same copy for the packed-weight tile B (N interleaved by kInterleave).
template
__device__ __inline__ void global_to_share_one_stage_B(half *src, half *dst, int global_ncols, int cta_offset_m, int
cta_offset_n, int cta_offset_k, int global_iter_k, int shared_iter_k, bool mask)
{
    constexpr int threads_needed = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / SHARED_K_ITERS;
    constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
    constexpr int total_global_iters = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / threads_used;
    constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
    constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K;
    constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
    constexpr int threads_per_row = CTA_K / PACK_SIZE;
    constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
    bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
#pragma unroll
    for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter)
    {
        int global_iter = shared_iter_k * partial_global_iters + _global_iter;
        int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row);
        int ld_col = (threadIdx.x % threads_per_row);
        // XOR swizzle (on the row parity) against shared-memory bank conflicts.
        int ld_col_swizzled = ld_col ^ (ld_row % 2) & 7;
        void *dst_ptr = (void *)(dst + (ld_row * kSmemCol + ld_col_swizzled * PACK_SIZE));
        uint4 *src_ptr = (uint4 *)(src + global_iter_k * CTA_K + cta_offset_n / kInterleave * global_ncols + ld_row * global_ncols + ld_col * PACK_SIZE + cta_offset_k);
        if constexpr (STAGES > 1)
        {
            uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
            cp_async_cg_A(addr, src_ptr, local_mask);
        }
        else
        {
            if (local_mask)
                *(uint4 *)dst_ptr = *src_ptr;
        }
    }
}

// Copies one stage of the per-group scales and zero-points (src / src_z) for
// the current CTA column range into shared memory, again via cp.async when
// the pipeline is multi-stage.
template
__device__ __inline__ void global_to_share_one_stage_scales(half *src, half *dst, half *src_z, half *dst_z, int global_ncols, int cta_offset_m, int cta_offset_n, int cta_offset_k, int global_iter_k, int shared_iter_k, bool mask)
{
    // When G < CTA_K several scale groups are touched per K-tile.
    constexpr int LD_AMOUNT = (G >= CTA_K) ? CTA_N : CTA_N * CTA_K / G;
    constexpr int threads_needed = LD_AMOUNT / PACK_SIZE / 1;
    constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
    constexpr int threads_per_row = CTA_N / PACK_SIZE;
    constexpr int kSmemCol = CTA_N;
    bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
    // Quantization-group row index for this K position.
    int g_idx = (cta_offset_k + global_iter_k * CTA_K) / G;

    void *dst_ptr = (void *)(dst + (threadIdx.x / threads_per_row) * kSmemCol + (threadIdx.x % threads_per_row) * PACK_SIZE);
    uint4 *src_ptr = (uint4 *)(src + g_idx * global_ncols + cta_offset_n + (threadIdx.x / threads_per_row) * global_ncols + (threadIdx.x % threads_per_row) * PACK_SIZE);
    void *dst_ptr_z = (void *)(dst_z + (threadIdx.x / threads_per_row) * kSmemCol + (threadIdx.x % threads_per_row) * PACK_SIZE);
    uint4 *src_ptr_z = (uint4 *)(src_z + g_idx * global_ncols + cta_offset_n + (threadIdx.x / threads_per_row) * global_ncols + (threadIdx.x % threads_per_row) * PACK_SIZE);
    if (STAGES > 1)
    {
        uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
        cp_async_cg_A(addr, src_ptr, local_mask);
        uint32_t addr_z = cast_smem_ptr_to_uint(dst_ptr_z);
        cp_async_cg_A(addr_z, src_ptr_z, local_mask);
    }
    else
    {
        if (local_mask)
        {
            *(uint4 *)dst_ptr = *src_ptr;
            *(uint4 *)dst_ptr_z = *src_ptr_z;
        }
    }
}

// Loads A fragments from shared memory into registers via ldmatrix, applying
// the same column swizzle used when the tile was written.
template
__device__ __inline__ void share_to_reg_one_stage_A(half *src, half *dst, int warp_offset_m, int warp_offset_n, int warp_offset_k, int k_0_1)
{
    constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
    for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
    {
        int ld_row = warp_offset_m + shared_iter * OP_M + (threadIdx.x % 16);
        int ld_col = k_0_1 * 16 + (threadIdx.x / 16) * 8 + warp_offset_k;
        int ld_col_swizzled = ((ld_col / PACK_SIZE) ^ (ld_row) & 7) * PACK_SIZE;
        void *addr_ptr = (void *)(src + ld_row * kSmemCol + ld_col_swizzled);
        uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
        ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
    }
}

// Loads packed B fragments (optionally via ldmatrix), then dequantizes them
// in registers: dequantize_s4_to_fp16x2 + fused multiply-add with the
// per-group scale/zero read from shared memory; fp16 results go to dst_fp16.
template
__device__ __inline__ void share_to_reg_one_stage_B(half *src, half *src_scales, half *src_zeros, half *dst, half *dst_fp16, int warp_offset_m, int warp_offset_n, int warp_offset_k, int k_0_1)
{
    constexpr int
kSmemCol = CTA_K + SMEM_PAD_B;
    int r0 = ((threadIdx.x / 8 / 2) * 8 + threadIdx.x % 8);
    int c0 = ((threadIdx.x / 8) % 2) * 8;
    int r = r0 / 4;
    int c = (r0 % 4) * 16 + c0;
    int c_swizzled = ((c / PACK_SIZE) ^ (r % 2) & 7) * PACK_SIZE;

    if constexpr (ldmatrix)
    {
#pragma unroll
        for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
        {
            void *addr_ptr = (void *)(src + warp_offset_n / kInterleave * kSmemCol + shared_iter * 16 / kInterleave * kSmemCol + k_0_1 * 16 + r * kSmemCol + c_swizzled + warp_offset_k);
            uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
            ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
        }
    }

    // Register-level dequantization: expand the packed 4-bit fragments and
    // apply scale/zero (loaded = loaded * scale + zero) per half2 pair.
#pragma unroll
    for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
    {
        half scale = src_scales[(warp_offset_k / G) * CTA_N + warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4];
        half zero = src_zeros[(warp_offset_k / G) * CTA_N + warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4];
        half2 scale2 = make_half2(scale, scale);
        half2 zero2 = make_half2(zero, zero);
        half2 loaded[4];
        dequantize_s4_to_fp16x2(*reinterpret_cast(dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + shared_iter * 8), reinterpret_cast(loaded));
#pragma unroll
        for (int i = 0; i < 4; i++)
        {
            loaded[i] = __hfma2(loaded[i], scale2, zero2);
        }
        *reinterpret_cast(dst_fp16 + shared_iter * 16 + 8 * (k_0_1 % 2)) = *reinterpret_cast(loaded);
    }
}

// W4A16 GEMM kernel, variant T1: multi-stage cp.async software pipeline,
// optional K-slicing (SLICES) with a shared-memory reduction, and optional
// split-K (SPLITK) with semaphore-ordered accumulation into C.
// NOTE(review): the template parameter list (CTA_M/CTA_N/CTA_K, WARP_M/N/K,
// STAGES, G, SPLITK, ...) was stripped during extraction.
template
__global__ void gemm_w4a16_T1(half *__restrict__ A, half *__restrict__ B, half *__restrict__ scales, half *__restrict__ zeros, half *__restrict__ C, int *__restrict__ semaphores, int M, int N, int K)
{
    constexpr int NUM_WARPS_MN = CTA_M / WARP_M * CTA_N / WARP_N;
    constexpr int NUM_WARPS = NUM_WARPS_MN * CTA_K / WARP_K;
    constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE;
    constexpr int CTA_SIZE_MN = NUM_WARPS_MN * WARP_SIZE;
    constexpr int SLICES = CTA_K / WARP_K;
    // Decompose the 1-D grid into (m, n, split-k) tile coordinates, then
    // swizzle (m, n) for locality.
    int num_blocks_n = (N + CTA_N - 1) / CTA_N;
    int num_blocks_m = (M + CTA_M - 1) / CTA_M;
    int blockIdx_y = blockIdx.x % (num_blocks_m * num_blocks_n);
    int blockIdx_z = blockIdx.x / (num_blocks_m * num_blocks_n);
    const int log_tile = get_log_tile<1>((N + CTA_N - 1) / CTA_N);
    int blockIdx_m = blockIdx_y / (num_blocks_n >> log_tile);
    int blockIdx_n = blockIdx_y % (num_blocks_n >> log_tile);
    const uint2 block_idx_mapping = get_block_idx_mapping(blockIdx_m, blockIdx_n, log_tile);
    blockIdx_m = block_idx_mapping.x;
    blockIdx_n = block_idx_mapping.y;

    // Per-thread fp32 accumulator fragments.
    float C_warp[CTA_M * CTA_N / CTA_SIZE_MN];
    // Dynamic shared memory layout: [A stages][B stages][scales][zeros];
    // the same buffer is reused as fp32 C_shared for the slice reduction.
    constexpr int kSmemPadKA = CTA_K + SMEM_PAD_A;
    constexpr int kSmemPadKB = CTA_K + SMEM_PAD_B;
    constexpr int kSmemSizeAPerStage = CTA_M * kSmemPadKA;
    constexpr int kSmemSizeBPerStage = CTA_N / kInterleave * kSmemPadKB;
    constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES;
    constexpr int kSmemSizeB = kSmemSizeBPerStage * STAGES;
    constexpr int scales_load_interval = G >= CTA_K ? G / CTA_K : 1;
    constexpr int scales_per_load = G < CTA_K ? CTA_K / G : 1;
    constexpr int kSmemSizeScales = CTA_N * STAGES / scales_load_interval * scales_per_load;
    extern __shared__ half mem_shared[];
    half *A_shared = mem_shared;
    half *B_shared = mem_shared + kSmemSizeA;
    half *scales_shared = mem_shared + kSmemSizeA + kSmemSizeB;
    half *zeros_shared = mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales;
    float *C_shared = reinterpret_cast(mem_shared);
    // Double-buffered register fragments.
    half A_shared_warp_[2][WARP_M * INTRIN_K / WARP_SIZE];
    half B_shared_warp_[2][WARP_N * 32 / WARP_SIZE];
    half B_shared_warp_tmp_[2][WARP_N * 16 / WARP_SIZE];
    int cta_offset_m = blockIdx_m * CTA_M;
    int cta_offset_n = blockIdx_n * CTA_N;
    int cta_offset_k = blockIdx_z * (K / SPLITK);
    int warp_mn = threadIdx.y % NUM_WARPS_MN;
    int slice_id = threadIdx.y / NUM_WARPS_MN;
    int warp_offset_n = (warp_mn % (CTA_N / WARP_N)) * WARP_N;
    int warp_offset_m = (warp_mn / (CTA_N / WARP_N)) * WARP_M;
    int warp_offset_k = slice_id * WARP_K;

    for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE_MN; i++)
        C_warp[i] = 0.0;

    int gemm_iters = (K + CTA_K - 1) / CTA_K / SPLITK;
    int k_0_0_ld = 0;
    int k_0_0 = 0;
    // Pipeline prologue: prefetch STAGES - 1 stages (1 when not pipelined).
    constexpr int prologue_stages = STAGES == 1 ? 1 : STAGES - 1;
#pragma unroll
    for (k_0_0_ld = 0; k_0_0_ld < prologue_stages; ++k_0_0_ld)
    {
        global_to_share_one_stage_A(A, A_shared + k_0_0_ld * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, 0, true);
        global_to_share_one_stage_B(B, B_shared + k_0_0_ld * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, 0, true);
        global_to_share_one_stage_scales(
            scales, scales_shared + (k_0_0_ld / scales_load_interval * scales_per_load) * CTA_N,
            zeros, zeros_shared + (k_0_0_ld / scales_load_interval * scales_per_load) * CTA_N,
            N, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, 0,
            k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
        if constexpr (STAGES > 1)
            __pipeline_commit();
    }
    if constexpr (STAGES > 1)
        __pipeline_wait_prior(STAGES - 2);
    __syncthreads();

    // Warm up the register fragments for the first K sub-iteration.
    share_to_reg_one_stage_A(A_shared, A_shared_warp_[0], warp_offset_m, warp_offset_n, warp_offset_k, 0);
    share_to_reg_one_stage_B(B_shared, scales_shared, zeros_shared, B_shared_warp_tmp_[0], B_shared_warp_[0], warp_offset_m, warp_offset_n, warp_offset_k, 0);
    constexpr int SHARED_K_ITERS = WARP_K / INTRIN_K;

    // Main loop: interleaves shared->register loads, MMAs, and
    // global->shared prefetch of the next pipeline stage.
    for (; k_0_0 < gemm_iters; ++k_0_0, ++k_0_0_ld)
    {
        int ld_stage = k_0_0_ld % STAGES;
        int compute_stage = k_0_0 % STAGES;
        half *A_shared_this_compute_stage;
        half *B_shared_this_compute_stage;
        half *scales_shared_this_compute_stage;
        half *zeros_shared_this_compute_stage;
#pragma unroll
        for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k)
        {
            A_shared_this_compute_stage = A_shared + compute_stage * kSmemSizeAPerStage;
            B_shared_this_compute_stage = B_shared + compute_stage * kSmemSizeBPerStage;
            scales_shared_this_compute_stage = scales_shared + (compute_stage / scales_load_interval * scales_per_load) * CTA_N;
            zeros_shared_this_compute_stage = zeros_shared + (compute_stage / scales_load_interval * scales_per_load) * CTA_N;
            share_to_reg_one_stage_A(A_shared_this_compute_stage, A_shared_warp_[(iter_k + 1) % 2], warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);
            // NOTE(review): both arms of each branch below look identical here
            // except for the B_shared_warp_tmp_ index — presumably the
            // stripped template arguments differed (ldmatrix on/off); confirm
            // against the upstream llm-awq source.
            if ((iter_k + 1) % kInterleave == 0)
            {
                if (compute_stage % 2 == 1)
                {
                    share_to_reg_one_stage_B(
                        B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
                        B_shared_warp_tmp_[1], B_shared_warp_[((iter_k + 1) / 2) % 2],
                        warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);
                }
                else
                {
                    share_to_reg_one_stage_B(
                        B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
                        B_shared_warp_tmp_[0], B_shared_warp_[((iter_k + 1) / 2) % 2],
                        warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);
                }
            }
            else
            {
                if (compute_stage % 2 == 1)
                {
                    share_to_reg_one_stage_B(
                        B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
                        B_shared_warp_tmp_[1], B_shared_warp_[((iter_k + 1) / 2) % 2],
                        warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);
                }
                else
                {
                    share_to_reg_one_stage_B(
                        B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
                        B_shared_warp_tmp_[0], B_shared_warp_[((iter_k + 1) / 2) % 2],
                        warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);
                }
            }
            half *A_shared_warp = A_shared_warp_[iter_k % 2];
            half *B_shared_warp = B_shared_warp_[(iter_k / 2) % 2];
            // Tensor-core MMAs over the warp tile.
            for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3)
            {
                for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4)
                {
                    mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4);
                    mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4 + 8);
                }
            }

            // Prefetch the next stage's A/B tiles while computing.
            if (iter_k < WARP_K / INTRIN_K - 1)
            {
                if constexpr (STAGES == 1)
                    __syncthreads();
                global_to_share_one_stage_A(A, A_shared + ld_stage * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters);
                global_to_share_one_stage_B(B, B_shared + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters);
            }

            // One sub-iteration before the end: issue the final prefetch
            // (including scales/zeros) and advance the pipeline.
            if (iter_k == WARP_K / INTRIN_K - 2)
            {
                if constexpr (STAGES == 1 && WARP_K / INTRIN_K > 2)
                {
                    __syncthreads();
                }
                global_to_share_one_stage_A(A, A_shared + ld_stage * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, iter_k + 1, k_0_0_ld < gemm_iters);
                global_to_share_one_stage_B(B, B_shared + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, iter_k + 1, k_0_0_ld < gemm_iters);
                global_to_share_one_stage_scales(
                    scales, scales_shared + (ld_stage / scales_load_interval * scales_per_load) * CTA_N,
                    zeros, zeros_shared + (ld_stage / scales_load_interval * scales_per_load) * CTA_N,
                    N, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, iter_k,
                    k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
                if constexpr (STAGES > 1)
                {
                    __pipeline_commit();
                    __pipeline_wait_prior(STAGES - 2);
                }
                compute_stage = (k_0_0 + 1) % STAGES;
                __syncthreads();
            }
        }
    }
    // Drain the async-copy pipeline before the epilogue.
    __pipeline_commit();
    __pipeline_wait_prior(0);
    __syncthreads();

    // K-slice reduction through shared memory: slice 0 accumulates the
    // partial sums written by the other slices.
    if constexpr (SLICES > 1)
    {
#pragma unroll
        for (int z = 0; z < SLICES; ++z)
        {
            if (slice_id == z)
            {
#pragma unroll
                for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
                {
#pragma unroll
                    for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
                    {
#pragma unroll
                        for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id)
                        {
                            if (z > 0)
                            {
                                C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] += C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2];
                            }
                            C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2] = C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id];
                        };
                    }
                }
            }
            __syncthreads();
        }
        if (slice_id == 0)
        {
#pragma unroll
            for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
            {
#pragma unroll
                for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
                {
#pragma unroll
                    for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id)
                    {
                        C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] = C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2];
                    };
                }
            }
        }
    }

    // Epilogue (slice 0 only): write or split-K-accumulate into C, with
    // semaphores ordering the split-K partial sums.
    if (slice_id == 0)
    {
        Semaphore semaphore(semaphores + blockIdx_y, threadIdx.x);
        if constexpr (SPLITK > 1)
        {
            semaphore.fetch();
        }
        if (blockIdx_z != 0)
        {
            // Later split-K chunks wait for their predecessor, then add their
            // partial sums onto what is already in C.
            semaphore.wait(blockIdx_z);
            for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
            {
                for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
                {
                    for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2)
                    {
                        int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4));
                        if (write_row < M)
                        {
                            half2 *existing_psum_ptr = reinterpret_cast(
                                C + write_row * N + cta_offset_n + warp_offset_n + ax1_0_1 * 16 + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2);
                            *existing_psum_ptr = __hadd2(*existing_psum_ptr, __float22half2_rn(*reinterpret_cast(C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id)));
                        }
                    };
                }
            }
        }
        else
        {
            // First (or only) split-K chunk: plain half2 stores.
            for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
            {
                for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
                {
                    for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2)
                    {
                        int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4));
                        if (write_row < M)
                        {
                            *reinterpret_cast(
                                C + write_row * N + cta_offset_n + warp_offset_n + ax1_0_1 * 16 + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2) = __float22half2_rn(*reinterpret_cast(C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id));
                        }
                    };
                }
            }
        }
        // Release the semaphore for the next split-K chunk (0 resets it when
        // this was the last chunk).
        if constexpr (SPLITK > 1)
        {
            int lock = 0;
            if (SPLITK == blockIdx_z + 1)
            {
                lock = 0;
            }
            else
            {
                lock = blockIdx_z + 1;
            }
            semaphore.release(lock);
        }
    }
}

// T2 variant of the A-tile global->shared copy: identical to
// global_to_share_one_stage_A except there is no cta_offset_k (no split-K).
template
__device__ __inline__ void global_to_share_one_stage_A_T2(half *src, half *dst, int global_nrows, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask)
{
    constexpr int threads_needed = (CTA_M * CTA_K) / PACK_SIZE / SHARED_K_ITERS;
    constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
    constexpr int total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / threads_used;
    constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
    constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K;
    constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
    constexpr int threads_per_row = CTA_K / PACK_SIZE;
    constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
    bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
    int ld_col = (threadIdx.x % threads_per_row);
#pragma unroll
    for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter)
    {
        int global_iter = shared_iter_k * partial_global_iters + _global_iter;
        int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row);
        int ld_col_swizzled = (ld_col ^ (ld_row) & 7) * PACK_SIZE;
        void *dst_ptr = (void *)(dst + ld_row * kSmemCol + ld_col_swizzled);
        uint4 *src_ptr = (uint4 *)(src + (ld_row + cta_offset_m) * global_ncols + ld_col * PACK_SIZE + global_iter_k * CTA_K);
        // cta_offset_m * global_ncols + global_iter * cta_step_m_or_n * global_ncols + threadIdx.y * warp_step_m_or_n * global_ncols + (threadIdx.x / threads_per_row) * global_ncols + global_iter_k * CTA_K + (threadIdx.x % threads_per_row) * PACK_SIZE);
        if constexpr (STAGES > 1)
        {
            uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
            cp_async_cg_A(addr, src_ptr,
local_mask & (ld_row + cta_offset_m < global_nrows));
        }
        else
        {
            if (local_mask & (ld_row + cta_offset_m < global_nrows))
                *(uint4 *)dst_ptr = *src_ptr;
        }
    }
}

// T2 variant of the B-tile global->shared copy (no cta_offset_k / split-K).
template
__device__ __inline__ void global_to_share_one_stage_B_T2(half *src, half *dst, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask)
{
    constexpr int threads_needed = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / SHARED_K_ITERS;
    constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
    constexpr int total_global_iters = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / threads_used;
    constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
    constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K;
    constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
    constexpr int threads_per_row = CTA_K / PACK_SIZE;
    constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
    bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
#pragma unroll
    for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter)
    {
        int global_iter = shared_iter_k * partial_global_iters + _global_iter;
        int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row);
        int ld_col = (threadIdx.x % threads_per_row);
        int ld_col_swizzled = ld_col ^ (ld_row % 2) & 7;
        void *dst_ptr = (void *)(dst + (ld_row * kSmemCol + ld_col_swizzled * PACK_SIZE));
        uint4 *src_ptr = (uint4 *)(src + global_iter_k * CTA_K + cta_offset_n / kInterleave * global_ncols + ld_row * global_ncols + ld_col * PACK_SIZE);
        if constexpr (STAGES > 1)
        {
            uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
            cp_async_cg_A(addr, src_ptr, local_mask);
        }
        else
        {
            if (local_mask)
                *(uint4 *)dst_ptr = *src_ptr;
        }
    }
}

// T2 variant of the scales/zeros global->shared copy; always loads exactly
// CTA_N scales (the T2 path does not handle G < CTA_K here).
template
__device__ __inline__ void global_to_share_one_stage_scales_T2(half *src, half *dst, half *src_z, half *dst_z, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask)
{
    constexpr int threads_needed = CTA_N / PACK_SIZE / 1;
    constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
    constexpr int threads_per_row = CTA_N / PACK_SIZE;
    bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
    // Quantization-group row index for this K-tile.
    int g_idx = global_iter_k * CTA_K / G;

    void *dst_ptr = (void *)(dst + (threadIdx.x % threads_per_row) * PACK_SIZE);
    uint4 *src_ptr = (uint4 *)(src + g_idx * global_ncols + cta_offset_n + (threadIdx.x % threads_per_row) * PACK_SIZE);
    void *dst_ptr_z = (void *)(dst_z + (threadIdx.x % threads_per_row) * PACK_SIZE);
    uint4 *src_ptr_z = (uint4 *)(src_z + g_idx * global_ncols + cta_offset_n + (threadIdx.x % threads_per_row) * PACK_SIZE);
    if (STAGES > 1)
    {
        uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
        cp_async_cg_A(addr, src_ptr, local_mask);
        uint32_t addr_z = cast_smem_ptr_to_uint(dst_ptr_z);
        cp_async_cg_A(addr_z, src_ptr_z, local_mask);
    }
    else
    {
        if (local_mask)
        {
            *(uint4 *)dst_ptr = *src_ptr;
            *(uint4 *)dst_ptr_z = *src_ptr_z;
        }
    }
}

// T2 variant of the A shared->register ldmatrix load (no warp_offset_k).
template
__device__ __inline__ void share_to_reg_one_stage_A_T2(half *src, half *dst, int warp_offset_m, int warp_offset_n, int k_0_1)
{
    constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
    for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
    {
        int ld_row = warp_offset_m + shared_iter * OP_M + (threadIdx.x % 16);
        int ld_col = k_0_1 * 16 + (threadIdx.x / 16) * 8;
        int ld_col_swizzled = ((ld_col / PACK_SIZE) ^ (ld_row) & 7) * PACK_SIZE;
        void *addr_ptr = (void *)(src + ld_row * kSmemCol + ld_col_swizzled);
        uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
        ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
    }
}

// T2 variant of the B shared->register load + in-register dequantization
// (scale/zero indexing has no warp_offset_k term).
template
__device__ __inline__ void share_to_reg_one_stage_B_T2(half *src, half *src_scales, half *src_zeros, half *dst, half *dst_fp16, int warp_offset_m, int warp_offset_n, int k_0_1)
{
    constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
    int r0 = ((threadIdx.x / 8 / 2) * 8 + threadIdx.x % 8);
    int c0 = ((threadIdx.x / 8) % 2) * 8;
    int r = r0 / 4;
    int c = (r0 % 4) * 16 + c0;
    int c_swizzled = ((c / PACK_SIZE) ^ (r % 2) & 7) * PACK_SIZE;

    if constexpr (ldmatrix)
    {
#pragma unroll
        for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
        {
            void *addr_ptr = (void *)(src + warp_offset_n / kInterleave * kSmemCol + shared_iter * 16 / kInterleave * kSmemCol + k_0_1 * 16 + r * kSmemCol + c_swizzled);
            uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
            ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
        }
    }

#pragma unroll
    for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
    {
        half scale = src_scales[warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4];
        half zero = src_zeros[warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4];
        half2 scale2 = make_half2(scale, scale);
        half2 zero2 = make_half2(zero, zero);
        half2 loaded[4];
        dequantize_s4_to_fp16x2(*reinterpret_cast(dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + shared_iter * 8), reinterpret_cast(loaded));
#pragma unroll
        for (int i = 0; i < 4; i++)
        {
            loaded[i] = __hfma2(loaded[i], scale2, zero2);
        }
        *reinterpret_cast(dst_fp16 + shared_iter * 16 + 8 * (k_0_1 % 2)) = *reinterpret_cast(loaded);
    }
}

// W4A16 GEMM kernel, variant T2: like T1 but without K-slicing or split-K
// (no semaphores parameter), one warp tile per (m, n) position.
// NOTE(review): the template parameter list was stripped during extraction;
// the body continues beyond this excerpt.
template
__global__ void gemm_w4a16_T2(half *__restrict__ A, half *__restrict__ B, half *__restrict__ scales, half *__restrict__ zeros, half *__restrict__ C, int M, int N, int K)
{
    constexpr int NUM_WARPS = CTA_M / WARP_M * CTA_N / WARP_N;
    constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE;
    int num_blocks_n = (N + CTA_N - 1) / CTA_N;
    int num_blocks_m = (M + CTA_M - 1) / CTA_M;
    int blockIdx_y = blockIdx.x % (num_blocks_m * num_blocks_n);
    const int log_tile = get_log_tile<1>((N + CTA_N - 1) / CTA_N);
    int blockIdx_m = blockIdx_y / (num_blocks_n >> log_tile);
    int blockIdx_n = blockIdx_y % (num_blocks_n >> log_tile);
    const uint2 block_idx_mapping = get_block_idx_mapping(blockIdx_m, blockIdx_n, log_tile);
    blockIdx_m = block_idx_mapping.x;
    blockIdx_n = block_idx_mapping.y;

    float C_warp[CTA_M * CTA_N / CTA_SIZE];
    // Dynamic shared memory layout: [A stages][B stages][scales][zeros].
    constexpr int kSmemPadKA = CTA_K + SMEM_PAD_A;
    constexpr int kSmemPadKB = CTA_K + SMEM_PAD_B;
    constexpr int kSmemSizeAPerStage = CTA_M * kSmemPadKA;
    constexpr int kSmemSizeBPerStage = CTA_N / kInterleave * kSmemPadKB;
    constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES;
    constexpr int kSmemSizeB = kSmemSizeBPerStage * STAGES;
    constexpr int kSmemSizeScales = CTA_N * STAGES / 2;
    constexpr int scales_load_interval = G / CTA_K;
    extern __shared__ half mem_shared[];
    half *A_shared = mem_shared;
    half *B_shared = mem_shared + kSmemSizeA;
    half *scales_shared = mem_shared + kSmemSizeA + kSmemSizeB;
    half *zeros_shared = mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales;
    half A_shared_warp_[2][WARP_M * INTRIN_K / WARP_SIZE];
    half B_shared_warp_[2][WARP_N * 32 / WARP_SIZE];
    half B_shared_warp_tmp_[2][WARP_N * 16 / WARP_SIZE];
    int cta_offset_m = blockIdx_m * CTA_M;
    int cta_offset_n = blockIdx_n * CTA_N;
    int warp_offset_m = (threadIdx.y % (CTA_M / WARP_M)) * WARP_M;
    int warp_offset_n = (threadIdx.y / (CTA_M / WARP_M)) * WARP_N;

    for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE; i++)
        C_warp[i] = 0.0;

    int gemm_iters = (K + CTA_K - 1) / CTA_K;
    int k_0_0_ld = 0;
    int k_0_0 = 0;
    constexpr int prologue_stages = STAGES == 1 ?
1 : STAGES - 1; #pragma unroll for (k_0_0_ld = 0; k_0_0_ld < prologue_stages; ++k_0_0_ld) { global_to_share_one_stage_A_T2(A, A_shared + k_0_0_ld * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true); global_to_share_one_stage_B_T2(B, B_shared + k_0_0_ld * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true); global_to_share_one_stage_scales_T2( scales, scales_shared + (k_0_0_ld / scales_load_interval) * CTA_N, zeros, zeros_shared + (k_0_0_ld / scales_load_interval) * CTA_N, N, cta_offset_m, cta_offset_n, k_0_0_ld, 0, k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0); if constexpr (STAGES > 1) __pipeline_commit(); } if constexpr (STAGES > 1) __pipeline_wait_prior(STAGES - 2); __syncthreads(); share_to_reg_one_stage_A_T2(A_shared, A_shared_warp_[0], warp_offset_m, warp_offset_n, 0); share_to_reg_one_stage_B_T2(B_shared, scales_shared, zeros_shared, B_shared_warp_tmp_[0], B_shared_warp_[0], warp_offset_m, warp_offset_n, 0); constexpr int SHARED_K_ITERS = WARP_K / INTRIN_K; for (; k_0_0 < gemm_iters; ++k_0_0, ++k_0_0_ld) { int ld_stage = k_0_0_ld % STAGES; int compute_stage = k_0_0 % STAGES; half *A_shared_this_compute_stage; half *B_shared_this_compute_stage; half *scales_shared_this_compute_stage; half *zeros_shared_this_compute_stage; for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k) { A_shared_this_compute_stage = A_shared + compute_stage * kSmemSizeAPerStage; B_shared_this_compute_stage = B_shared + compute_stage * kSmemSizeBPerStage; scales_shared_this_compute_stage = scales_shared + (compute_stage / scales_load_interval) * CTA_N; zeros_shared_this_compute_stage = zeros_shared + (compute_stage / scales_load_interval) * CTA_N; share_to_reg_one_stage_A_T2(A_shared_this_compute_stage, A_shared_warp_[(iter_k + 1) % 2], warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS); if ((iter_k + 1) % kInterleave == 0) { if (compute_stage % 2 == 1) { share_to_reg_one_stage_B_T2( B_shared_this_compute_stage, 
scales_shared_this_compute_stage, zeros_shared_this_compute_stage, B_shared_warp_tmp_[1], B_shared_warp_[((iter_k + 1) / 2) % 2], warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS); } else { share_to_reg_one_stage_B_T2( B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage, B_shared_warp_tmp_[0], B_shared_warp_[((iter_k + 1) / 2) % 2], warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS); } } else { if (compute_stage % 2 == 1) { share_to_reg_one_stage_B_T2( B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage, B_shared_warp_tmp_[1], B_shared_warp_[((iter_k + 1) / 2) % 2], warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS); } else { share_to_reg_one_stage_B_T2( B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage, B_shared_warp_tmp_[0], B_shared_warp_[((iter_k + 1) / 2) % 2], warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS); } } __syncthreads(); half *A_shared_warp = A_shared_warp_[iter_k % 2]; half *B_shared_warp = B_shared_warp_[(iter_k / 2) % 2]; for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3) { for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4) { mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4); mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4 + 8); } } if (iter_k < WARP_K / INTRIN_K - 1) { if constexpr (STAGES == 1) __syncthreads(); global_to_share_one_stage_A_T2(A, A_shared + ld_stage * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters); global_to_share_one_stage_B_T2(B, B_shared + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters); } if (iter_k == WARP_K / INTRIN_K - 2) { if constexpr (STAGES == 1 && WARP_K / 
INTRIN_K > 2) { __syncthreads(); } global_to_share_one_stage_A_T2(A, A_shared + ld_stage * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k + 1, k_0_0_ld < gemm_iters); global_to_share_one_stage_B_T2(B, B_shared + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k + 1, k_0_0_ld < gemm_iters); global_to_share_one_stage_scales_T2( scales, scales_shared + (ld_stage / scales_load_interval) * CTA_N, zeros, zeros_shared + (ld_stage / scales_load_interval) * CTA_N, N, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0); if constexpr (STAGES > 1) { __pipeline_commit(); __pipeline_wait_prior(STAGES - 2); } compute_stage = (k_0_0 + 1) % STAGES; __syncthreads(); } } } for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1) { for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1) { for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2) { int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)); if (write_row < M) { *reinterpret_cast( C + write_row * N + cta_offset_n + warp_offset_n + ax1_0_1 * 16 + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2) = __float22half2_rn(*reinterpret_cast(C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id)); } }; } } } torch::Tensor awq_v2_gemm_f16i4( torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scales, torch::Tensor _zeros) { std::vector output_shape = _in_feats.sizes().vec(); output_shape.back() = _kernel.size(0) * kInterleave; int num_in_feats = _in_feats.numel() / _in_feats.size(-1); int num_in_channels = _in_feats.size(-1); auto in_feats = reinterpret_cast(_in_feats.data_ptr()); auto kernel = reinterpret_cast(_kernel.data_ptr()); auto scales = reinterpret_cast(_scales.data_ptr()); auto zeros = reinterpret_cast(_zeros.data_ptr()); auto options = 
torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); auto options_int = torch::TensorOptions().dtype(torch::kInt32).device(_in_feats.device()); at::Tensor _out_feats = torch::empty(output_shape, options); int num_out_feats = _out_feats.numel() / _out_feats.size(-1); int num_out_channels = _out_feats.size(-1); auto out_feats = reinterpret_cast(_out_feats.data_ptr()); if (num_out_feats <= 32) { constexpr int G = 128; constexpr int CTA_M = 16; constexpr int CTA_N = 128; constexpr int CTA_K = 128; constexpr int WARP_M = 16; constexpr int WARP_N = 32; constexpr int WARP_K = 64; constexpr int SPLITK = 2; constexpr int STAGES = 4; KERNEL_LAUNCH_CODE } else if (num_out_feats <= 64) { constexpr int G = 128; constexpr int CTA_M = 16; constexpr int CTA_N = 128; constexpr int CTA_K = 128; constexpr int WARP_M = 16; constexpr int WARP_N = 32; constexpr int WARP_K = 64; constexpr int SPLITK = 1; constexpr int STAGES = 3; KERNEL_LAUNCH_CODE } else if (num_out_feats <= 128) { constexpr int G = 128; constexpr int CTA_M = 32; constexpr int CTA_N = 128; constexpr int CTA_K = 128; constexpr int WARP_M = 32; constexpr int WARP_N = 32; constexpr int WARP_K = 64; constexpr int SPLITK = 1; constexpr int STAGES = 4; KERNEL_LAUNCH_CODE } else if (num_out_feats <= 192) { constexpr int G = 128; constexpr int CTA_M = 64; constexpr int CTA_N = 128; constexpr int CTA_K = 64; constexpr int WARP_M = 64; constexpr int WARP_N = 32; constexpr int WARP_K = 64; constexpr int SPLITK = 1; constexpr int STAGES = 4; KERNEL_LAUNCH_CODE } else { constexpr int G = 128; constexpr int CTA_M = 64; constexpr int CTA_N = 128; constexpr int CTA_K = 64; constexpr int WARP_M = 64; constexpr int WARP_N = 32; constexpr int WARP_K = 64; constexpr int STAGES = 4; constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N); constexpr int kSmemByteSize = (CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / kInterleave + CTA_N) * STAGES * sizeof(half); if (kSmemByteSize >= 99 * 1024) { 
printf("This kernel requires %d Bytes of shared memory, which exceeds device limit.\n", kSmemByteSize); return _out_feats; } int j_factors1 = num_out_channels / CTA_N / 1; dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1); dim3 threads_per_block(WARP_SIZE, NUM_WARPS); auto kernel_func = gemm_w4a16_T2; cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize); kernel_func<<>>( in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels); } return _out_feats; } #else // if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 torch::Tensor awq_v2_gemm_f16i4( torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scales, torch::Tensor _zeros) { throw std::runtime_error("This GEMM requires a CUDA arch >= sm80.\n"); } #endif // if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 ================================================ FILE: optimum/quanto/library/extensions/cuda/awq/v2/gemm_cuda.h ================================================ #include torch::Tensor awq_v2_gemm_f16i4(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scales, torch::Tensor _zeros); ================================================ FILE: optimum/quanto/library/extensions/cuda/awq/v2/gemv_cuda.cu ================================================ /* * Modified from NVIDIA [TRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/d37b507f41a87457fe9f10f7459d08f5db235745/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv) * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* @article{lin2023awq, title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} } */ #include #include #include #include "gemv_cuda.h" #include "../dequantize.cuh" #define PACK_FACTOR 8 #define WARP_SIZE 32 #define MEM_ACCESS_SIZE 128 // Reduce sum within the warp using the tree reduction algorithm. template __device__ __forceinline__ static void warp_reduce(half* psum, float (*out_smem)[Num * 4]) { // kInterleave = 4 float fpsum[Num]; #pragma unroll for (int i = 0; i < Num; ++i) { fpsum[i] = __half2float(psum[i]); } #pragma unroll for (int i = 0; i < Num; ++i) { // T0 + T1 + T8 + T9 + T16 + T17 + T24 + T25 (kInterleave = 4) fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 16); fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 8); fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 1); } __syncthreads(); int warp = threadIdx.x / WarpSize, lane = threadIdx.x % WarpSize; if (lane == 0 || lane == 2 || lane == 4 || lane == 6) { #pragma unroll for (int i = 0; i < Num; ++i) { out_smem[warp][i * 4 + lane / 2] = fpsum[i]; } } __syncthreads(); }; __device__ __forceinline__ int make_divisible(int c, int divisor){ return (c + divisor - 1) / divisor; } template __global__ void gemv_kernel( const half* inputs, const uint32_t* weight, const half* scales, const half* zeros, half* outputs, const int IC, const int OC) { const int kStride = 64; const int kElemsPerThread = MEM_ACCESS_SIZE / 4; const int kThreadsNumPerTile = kStride / 
kElemsPerThread; // assert(MEM_ACCESS_SIZE == 128); static constexpr int kShuffleSize = 32; static constexpr int kShuffleBasicTile = 2; static constexpr int kShuffleContinous = 4; static constexpr int kShuffleStrided = 4; constexpr int Num = NPerBlock * Batch; constexpr int kInterleave = 4; half local_inputs[kElemsPerThread]; uint32_t local_qweights[MEM_ACCESS_SIZE / 32]; half half_weight_buffer[kElemsPerThread]; half dequantized_weight[kElemsPerThread * NPerBlock]; half local_scale[NPerBlock]; half local_scaled_zeros[NPerBlock]; half psum[Num]; for (int i = 0; i < Num; ++i) psum[i] = __float2half(0.f); extern __shared__ uint8_t shmem[]; float(*out_smem)[Num * kInterleave] = reinterpret_cast(shmem); const int blk_row_offset = blockIdx.x * NPerBlock * kInterleave; const int thd_row_offset = (threadIdx.x / kThreadsNumPerTile) % kInterleave; const int act_k_offset = threadIdx.x / (kThreadsNumPerTile * kInterleave) * kStride + (threadIdx.x % kThreadsNumPerTile) * kElemsPerThread; const int group_offset = act_k_offset / GroupSize; // TODO: use make_divisible const uint32_t* blk_weight_ptr = weight + blk_row_offset * IC / PACK_FACTOR; const half* scale_ptr = scales + blk_row_offset + thd_row_offset + group_offset * OC; const half* zeros_ptr = zeros + blk_row_offset + thd_row_offset + group_offset * OC; const half* inputs_ptr = inputs + act_k_offset; const int act_forward_step = BlockSize * kElemsPerThread / kInterleave; const int scale_forward_step = act_forward_step / GroupSize * OC; // Main loop iteration, each block completes the outputs for several OCs for (int kk = threadIdx.x * kElemsPerThread; kk < IC * kInterleave; kk += BlockSize * kElemsPerThread) { // Load qweight, scales and scaled_zeros #pragma unroll for (int idx = 0; idx < NPerBlock; ++idx) { // use float4 to load weights, each thread load 32 int4 numbers (1 x float4, 128 bit) *((float4*)(local_qweights)) = *((float4*)(blk_weight_ptr + (idx * kInterleave * IC + kk)/ PACK_FACTOR)); local_scale[idx] = 
*(scale_ptr + idx * kInterleave); local_scaled_zeros[idx] = *(zeros_ptr + idx * kInterleave); // Map int4 qweight to fp format #pragma unroll for (int i = 0; i < MEM_ACCESS_SIZE / 32; ++i) { // Converts 32 bits (8 x int4) to 8 fp16 dequantize_s4_to_fp16x2(*reinterpret_cast(local_qweights + i), reinterpret_cast(half_weight_buffer + i * PACK_FACTOR)); } // Dequantize (apply s/z) and shuffle elements to match the weight packing format #pragma unroll for (int i = 0; i < kShuffleContinous; ++i) { #pragma unroll for (int j = 0; j < kShuffleStrided; ++j) { half2 w = *reinterpret_cast( half_weight_buffer + (i + j * kShuffleContinous)* kShuffleBasicTile ); w = __hfma2(w, __half2half2(local_scale[idx]), __half2half2(local_scaled_zeros[idx])); dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 0) * NPerBlock + idx] = w.x; dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 1) * NPerBlock + idx] = w.y; } } } #pragma unroll for (int batch_idx = 0; batch_idx < Batch; ++batch_idx) { const half* local_inputs_ptr = inputs_ptr + batch_idx * IC; #pragma unroll for (int idx = 0; idx < kElemsPerThread / 8; ++idx) { // load activation, 8 halves (128 bits) / step. 
*((float4*)(local_inputs + idx * 8)) = *((float4*)(local_inputs_ptr + idx * 8)); } // Perform the MACs #pragma unroll for (int x = 0; x < NPerBlock / 2; ++x) { #pragma unroll for (int y = 0; y < kElemsPerThread; ++y) { *reinterpret_cast(psum + batch_idx * NPerBlock + x * 2) = __hfma2(*reinterpret_cast(dequantized_weight + y * NPerBlock + x * 2), __half2half2(local_inputs[y]), *reinterpret_cast(psum + batch_idx * NPerBlock + x * 2)); } } } inputs_ptr += act_forward_step; scale_ptr += scale_forward_step; zeros_ptr += scale_forward_step; } warp_reduce(psum, out_smem); // Num * Interleave = batch * NPerBlock * Interleave -> 1 thread_block write back num for (int i = threadIdx.x; i < Num * kInterleave; i += BlockSize) { int batch_idx = i / (NPerBlock * kInterleave); int oc_idx = i % (NPerBlock * kInterleave); float acc = 0.f; for (int j = 0; j < BlockSize / WARP_SIZE; ++j) { acc += out_smem[j][i]; } outputs[batch_idx * OC + blk_row_offset + oc_idx] = __float2half(acc); } } /* Computes GEMV (PyTorch interface). 
Args: _in_feats: tensor of shape [B, IC]; _kernel: int tensor of shape [OC, IC // 8]; _zeros: int tensor of shape [OC, IC // G // 8]; _scaling_factors: tensor of shape [OC, IC // G]; blockDim_x: size of thread block, dimension x, where blockDim_x * workload_per_thread = IC; blockDim_y: size of thread block, dimension y, where blockDim_y * gridDim_y = OC; Returns: out_feats: tensor of shape [B, OC]; */ torch::Tensor awq_v2_gemv_f16i4( torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros, int m, int n, int k, int group_size) { std::vector output_shape = _in_feats.sizes().vec(); output_shape.back() = n; auto in_feats = reinterpret_cast(_in_feats.data_ptr()); auto kernel = reinterpret_cast(_kernel.data_ptr()); auto zeros = reinterpret_cast(_zeros.data_ptr()); auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); at::Tensor _out_feats = torch::empty(output_shape, options); half * out_feats = reinterpret_cast(_out_feats.data_ptr()); static constexpr int N_PER_BLOCK = 2; static constexpr int K_INTERLEAVE = 4; static constexpr int BLOCK_SIZE = 256; dim3 num_blocks(n / N_PER_BLOCK / K_INTERLEAVE); dim3 num_threads(BLOCK_SIZE); // if (group_size == 64) // { // gemv_kernel_g64<<>>( // // pointers // in_feats, kernel, zeros, scaling_factors, out_feats, // // constants // num_in_channels, num_out_channels // ); // } if (group_size == 128) { switch (m) { case 1: gemv_kernel<<>>( in_feats, kernel, scaling_factors, zeros, out_feats, k, n ); break; case 2: gemv_kernel<<>>( in_feats, kernel, scaling_factors, zeros, out_feats, k, n ); break; case 3: gemv_kernel<<>>( in_feats, kernel, scaling_factors, zeros, out_feats, k, n ); break; case 4: gemv_kernel<<>>( in_feats, kernel, scaling_factors, zeros, out_feats, k, n ); break; case 5: gemv_kernel<<>>( in_feats, kernel, scaling_factors, zeros, out_feats, k, n ); break; case 6: 
gemv_kernel<<>>( in_feats, kernel, scaling_factors, zeros, out_feats, k, n ); break; case 7: gemv_kernel<<>>( in_feats, kernel, scaling_factors, zeros, out_feats, k, n ); break; default: throw std::runtime_error("Unsupported batch size for gemv kernel.\n"); } } else { throw std::runtime_error("Unsupported group size for gemv kernel.\n"); } return _out_feats; } ================================================ FILE: optimum/quanto/library/extensions/cuda/awq/v2/gemv_cuda.h ================================================ #pragma once #include torch::Tensor awq_v2_gemv_f16i4( torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros, int m, int n, int k, int group_size); ================================================ FILE: optimum/quanto/library/extensions/cuda/awq/v2/semaphore.h ================================================ /*************************************************************************************************** * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. 
* * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file \brief Implementation of a CTA-wide semaphore for inter-CTA synchronization. */ #pragma once ///////////////////////////////////////////////////////////////////////////////////////////////// // namespace cutlass { ///////////////////////////////////////////////////////////////////////////////////////////////// /// CTA-wide semaphore for inter-CTA synchronization. 
class Semaphore {
 public:
  int *lock;         // pointer to the global-memory flag polled/updated by this CTA
  bool wait_thread;  // true only for the single thread that performs the poll/update
  int state;         // last value fetched from *lock (-1 until first fetch())

 public:
  /// Implements a semaphore to wait for a flag to reach a given value
  __host__ __device__
  Semaphore(int *lock_, int thread_id)
      : lock(lock_),
        wait_thread(thread_id < 0 || thread_id == 0),
        state(-1) {}

  /// Permit fetching the synchronization mechanism early
  __device__
  void fetch() {
    if (wait_thread) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
      // sm70+: load with acquire semantics, scoped to the GPU
      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
#else
      // pre-sm70 fallback: plain .cg global load (no acquire qualifier available)
      asm volatile("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
#endif
    }
  }

  /// Gets the internal state
  __device__
  int get_state() const {
    return state;
  }

  /// Waits until the semaphore is equal to the given value
  __device__
  void wait(int status = 0) {
    // Spin: re-fetch while any thread in the CTA still sees state != status.
    while (__syncthreads_and(state != status)) {
      fetch();
    }
    __syncthreads();
  }

  /// Updates the lock with the given result
  __device__
  void release(int status = 0) {
    __syncthreads();
    if (wait_thread) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
      // sm70+: store with release semantics, scoped to the GPU
      asm volatile("st.global.release.gpu.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
#else
      asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
#endif
    }
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

// } // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////

================================================ FILE: optimum/quanto/library/extensions/cuda/marlin/COPYRIGHT ================================================ These kernels were vendored from VLLM. The Marlin kernels were developed by Elias Frantar and extended by Neural Magic. --- Copyright (C) Marlin.2024 Elias Frantar Modified by Neural Magic Copyright 2024 The vLLM team. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: optimum/quanto/library/extensions/cuda/marlin/fp8_marlin.cu ================================================ /* * Modified by Neural Magic * Copyright (C) Marlin.2024 Elias Frantar * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ /* * Adapted from https://github.com/IST-DASLab/marlin */ #include "gptq_marlin.cuh" #include "gptq_marlin_dtypes.cuh" #include "fp8_marlin.cuh" using namespace gptq_marlin; #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ static_assert(std::is_same::value || \ std::is_same::value, \ "only float16 and bfloat16 is supported"); template inline std::string str(T x) { return std::to_string(x); } namespace fp8_marlin { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 template shared // fetch pipeline const int group_blocks = -1 // number of consecutive 16x16 blocks // with a separate quantization scale > __global__ void Marlin( const int4* __restrict__ A, // fp16 input matrix of shape mxk const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn int4* __restrict__ C, // fp16 output buffer of shape mxn const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape // (k/groupsize)xn int num_groups, // number of scale groups per output channel int prob_m, // batch dimension m int prob_n, // output dimension n int prob_k, // reduction dimension k int* locks // extra global storage for barrier synchronization ) {} } // namespace fp8_marlin torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& workspace, int64_t num_bits, int64_t size_m, int64_t size_n, int64_t size_k) { TORCH_CHECK_NOT_IMPLEMENTED(false, "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); return torch::empty({1, 1}); } #else // m16n8k16 tensor core mma instruction with fp16 inputs and fp32 // output/accumulation. 
template __device__ inline void mma(const typename ScalarType::FragA& a_frag, const typename ScalarType::FragB& frag_b, typename ScalarType::FragC& frag_c) { const uint32_t* a = reinterpret_cast(&a_frag); const uint32_t* b = reinterpret_cast(&frag_b); float* c = reinterpret_cast(&frag_c); if constexpr (std::is_same::value) { asm volatile( "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); } else if constexpr (std::is_same::value) { asm volatile( "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); } else { STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); } } // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. 
// NOTE(review): the template parameter lists and template/cast arguments in
// this span were stripped by text extraction (`template` with no `<...>`,
// `ScalarType::FragB`, `reinterpret_cast(...)`); the code is kept byte-for-byte
// as extracted and only comments were added.
template __device__ inline void ldsm4(typename ScalarType::FragA& frag_a,
                                      const void* smem_ptr) {
  uint32_t* a = reinterpret_cast(&frag_a);
  // Convert the generic pointer to the 32-bit shared-state-space address that
  // the ldmatrix instruction expects.
  uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr));
  // Load four 8x8 b16 tiles (one full 16x16 operand-A fragment) at once.
  asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
               : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
               : "r"(smem));
}

// Fast FP8ToFp16/FP8ToBf16: Efficiently dequantize 8bit fp8_e4m3 values to fp16
// bf16 Reference:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
// Primary template: only the two specializations below are valid; any other
// scalar type trips the static assertion.
template __device__ inline typename ScalarType::FragB dequant_8bit(int q) {
  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
}

// fp16 specialization: converts four packed fp8_e4m3 values in `q` into two
// half2 registers via bit manipulation plus an exponent-bias multiply.
template <> __device__ inline typename ScalarType::FragB dequant_8bit(int q) {
  // Constants for FP8 (E4M3) and FP16 formats
  constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, FP16_EXPONENT = 5;
  constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;
  // Calculate MASK for extracting mantissa and exponent
  constexpr int MASK1 = 0x80000000;
  constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
  constexpr int MASK3 = MASK2 & 0x7fffffff;
  constexpr int MASK = MASK3 | (MASK3 >> 16);  // Final MASK value: 0x7F007F00
  // Extract and shift FP8 values to FP16 format
  int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
  int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);
  // Construct and apply exponent bias
  constexpr int BIAS_OFFSET =
      (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
  const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));
  // Convert to half2 and apply bias
  typename ScalarType::FragB frag_b;
  // Note: reverse indexing is intentional because weights are permuted
  frag_b[1] = __hmul2(*reinterpret_cast(&Out1), bias_reg);
  frag_b[0] = __hmul2(*reinterpret_cast(&Out2), bias_reg);
  return frag_b;
}

// bf16 specialization: same extraction scheme as fp16 above, but with the
// bf16 exponent width; the bias is built directly as a float bit pattern.
template <> __device__ inline typename ScalarType::FragB dequant_8bit(int q) {
  // Constants for FP8 (E4M3) and BF16 formats
  constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, BF16_EXPONENT = 8;
  constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
  // Calculate MASK for extracting mantissa and exponent
  constexpr int MASK1 = 0x80000000;
  constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
  constexpr int MASK3 = MASK2 & 0x7fffffff;
  constexpr int MASK = MASK3 | (MASK3 >> 16);  // Final MASK value: 0x7F007F00
  // Extract and shift FP8 values to BF16 format
  int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
  int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);
  // Construct and apply exponent bias
  constexpr int BIAS_OFFSET =
      (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
  // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
  // position
  constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
  const nv_bfloat162 bias_reg =
      __float2bfloat162_rn(*reinterpret_cast(&BIAS));
  // Convert to bfloat162 and apply bias
  typename ScalarType::FragB frag_b;
  // Note: reverse indexing is intentional because weights are permuted
  frag_b[1] = __hmul2(*reinterpret_cast(&Out1), bias_reg);
  frag_b[0] = __hmul2(*reinterpret_cast(&Out2), bias_reg);
  return frag_b;
}

// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
// Multiply both packed lanes of a dequantized fragment by the replicated
// scale at entry `i` of `frag_s`.
template __device__ inline void scale(typename ScalarType::FragB& frag_b,
                                      typename ScalarType::FragS& frag_s, int i) {
  using scalar_t2 = typename ScalarType::scalar_t2;
  scalar_t2 s = ScalarType::num2num2(reinterpret_cast(&frag_s)[i]);
  frag_b[0] = __hmul2(frag_b[0], s);
  frag_b[1] = __hmul2(frag_b[1], s);
}

// Given 2 floats multiply by 2 scales (halves)
template __device__ inline void scale_float(float* c, typename ScalarType::FragS& s) {
  scalar_t* s_ptr = reinterpret_cast(&s);
  c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0]));
  c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1]));
}

// Wait until barrier reaches `count`, then lock for current threadblock.
// Only thread 0 spins on the global lock; the trailing __syncthreads()
// propagates the acquire to the whole block.
__device__ inline void barrier_acquire(int* lock, int count) {
  if (threadIdx.x == 0) {
    int state = -1;
    do
      // Guarantee that subsequent writes by this threadblock will be visible
      // globally.
      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
                   : "=r"(state)
                   : "l"(lock));
    while (state != count);
  }
  __syncthreads();
}

// Release barrier and increment visitation count.
// With `reset` the counter is zeroed instead of incremented (used by the
// last participant).
__device__ inline void barrier_release(int* lock, bool reset = false) {
  __syncthreads();
  if (threadIdx.x == 0) {
    if (reset) {
      lock[0] = 0;
      return;
    }
    int val = 1;
    // Make sure that all writes since acquiring this barrier are visible
    // globally, while releasing the barrier.
    asm volatile("fence.acq_rel.gpu;\n");
    asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
                 : : "l"(lock), "r"(val));
  }
}

// Main Marlin GEMM kernel.
// NOTE(review): the template parameter list was mangled by extraction; only
// the tokens below survive. The body still references num_bits, threads,
// thread_m_blocks, thread_n_blocks, thread_k_blocks, stages and
// group_blocks, so the full parameter list must be restored from the
// upstream source before compiling.
template shared  // fetch pipeline
         const int group_blocks = -1  // number of consecutive 16x16 blocks
                                      // with a separate quantization scale
         >
__global__ void Marlin(
    const int4* __restrict__ A,  // fp16 input matrix of shape mxk
    const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn
    int4* __restrict__ C,        // fp16 output buffer of shape mxn
    const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape
                                          // (k/groupsize)xn
    int num_groups,  // number of scale groups per output channel
    int prob_m,      // batch dimension m
    int prob_n,      // output dimension n
    int prob_k,      // reduction dimension k
    int* locks       // extra global storage for barrier synchronization
) {
  // Each threadblock processes one "stripe" of the B matrix with (roughly) the
  // same size, which might involve multiple column "slices" (of width 16 *
  // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
  // example:
  //   0 1 3
  //   0 2 3
  //   1 2 4
  // While this kind of partitioning makes things somewhat more complicated, it
  // ensures good utilization of all SMs for many kinds of shape and GPU
  // configurations, while requiring as few slow global cross-threadblock
  // reductions as possible.
  // Type aliases for the scalar-type-specific tensor-core fragment types.
  using Dtype = ScalarType;
  using scalar_t2 = typename ScalarType::scalar_t2;
  using FragA = typename ScalarType::FragA;
  using FragB = typename ScalarType::FragB;
  using FragC = typename ScalarType::FragC;
  using FragS = typename ScalarType::FragS;

  // Number of quantized values packed into one 32-bit word.
  constexpr int pack_factor = 32 / num_bits;

  // For larger GEMMs we run multiple batchsize 64 versions in parallel for a
  // better partitioning with less reductions
  int parallel = 1;
  if (prob_m > 16 * thread_m_blocks) {
    parallel = prob_m / (16 * thread_m_blocks);
    prob_m = 16 * thread_m_blocks;
  }

  int k_tiles = prob_k / 16 / thread_k_blocks;
  int n_tiles = prob_n / 16 / thread_n_blocks;
  // Total tile iterations, divided evenly over the launched grid.
  int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x);

  int slice_row = (iters * blockIdx.x) % k_tiles;
  int slice_col_par = (iters * blockIdx.x) / k_tiles;
  int slice_col = slice_col_par;
  int slice_iters;      // number of threadblock tiles in the current slice
  int slice_count = 0;  // total number of active threadblocks in the current slice
  int slice_idx;        // index of threadblock in current slice; numbered bottom to
                        // top

  // We can easily implement parallel problem execution by just remapping
  // indices and advancing global pointers
  if (slice_col_par >= n_tiles) {
    A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
    C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
    locks += (slice_col_par / n_tiles) * n_tiles;
    slice_col = slice_col_par % n_tiles;
  }

  // Compute all information about the current slice which is required for
  // synchronization.
  auto init_slice = [&]() {
    slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
    if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
    if (slice_iters == 0) return;
    if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
    slice_count = 1;
    slice_idx = 0;
    // First tile index of the first block participating in this column.
    int col_first = iters * div_ceil(k_tiles * slice_col_par, iters);
    if (col_first <= k_tiles * (slice_col_par + 1)) {
      int col_off = col_first - k_tiles * slice_col_par;
      slice_count = div_ceil(k_tiles - col_off, iters);
      if (col_off > 0) slice_count++;
      int delta_first = iters * blockIdx.x - col_first;
      if (delta_first < 0 || (col_off == 0 && delta_first == 0))
        slice_idx = slice_count - 1;
      else {
        slice_idx = slice_count - 1 - delta_first / iters;
        if (col_off > 0) slice_idx--;
      }
    }
    // Wrapped past the last column: advance to the next parallel problem.
    if (slice_col == n_tiles) {
      A += 16 * thread_m_blocks * prob_k / 8;
      C += 16 * thread_m_blocks * prob_n / 8;
      locks += n_tiles;
      slice_col = 0;
    }
  };
  init_slice();

  // A sizes/strides

  // stride of the A matrix in global memory
  int a_gl_stride = prob_k / 8;
  // stride of an A matrix tile in shared memory
  constexpr int a_sh_stride = 16 * thread_k_blocks / 8;
  // delta between subsequent A tiles in global memory
  constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8;
  // between subsequent accesses within a tile
  int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);
  // between shared memory writes
  constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);
  // between shared memory tile reads
  constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));
  // within a shared memory tile
  constexpr int a_sh_rd_delta_i = a_sh_stride * 16;
  // overall size of a tile
  constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);
  // number of shared write iterations for a tile
  constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta);

  // B sizes/strides
  int b_gl_stride = 16 * prob_n / (pack_factor * 4);
  constexpr int b_sh_stride = ((thread_n_blocks * 16)
                               * 16 / pack_factor) / 4;
  // Each thread fetches one int4 vector per iteration for 4-bit weights,
  // two for 8-bit weights.
  constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2;
  constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;

  int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
  int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);
  constexpr int b_sh_wr_delta = threads * b_thread_vecs;
  constexpr int b_sh_rd_delta = threads * b_thread_vecs;
  constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
  constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;

  // Scale sizes/strides without act_order
  int s_gl_stride = prob_n / 8;
  constexpr int s_sh_stride = 16 * thread_n_blocks / 8;

  // Scale size/strides with act_order
  constexpr int tb_k = 16 * thread_k_blocks;
  // act_order is not used in this fp8 kernel, hence the zero g_idx stage.
  constexpr int g_idx_stage = 0;
  // constexpr int act_s_row_stride = 1;
  // int act_s_col_stride = act_s_row_stride * num_groups;
  int act_s_col_stride = 1;
  int act_s_col_warp_stride = act_s_col_stride * 8;
  int tb_n_warps = thread_n_blocks / 4;
  int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;

  // Global A read index of current thread.
  int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
                (threadIdx.x % a_gl_rd_delta_o);
  a_gl_rd += a_gl_rd_delta_o * slice_row;
  // Shared write index of current thread.
  int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
                (threadIdx.x % a_gl_rd_delta_o);
  // Shared read index.
  int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
  a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));

  int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +
                (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
  b_gl_rd += b_sh_stride * slice_col;
  b_gl_rd += b_gl_rd_delta_o * slice_row;
  int b_sh_wr = threadIdx.x * b_thread_vecs;
  int b_sh_rd = threadIdx.x * b_thread_vecs;

  // For act_order
  int slice_k_start = tb_k * slice_row;
  int slice_k_start_shared_fetch = slice_k_start;
  int slice_n_offset = act_s_col_tb_stride * slice_col;

  // No act_order
  int s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
  int s_sh_wr = threadIdx.x;
  bool s_sh_wr_pred = threadIdx.x < s_sh_stride;

  // We scale a `half2` tile in row-major layout for column-wise quantization.
  int s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
                (threadIdx.x % 32) % 4;

  // Precompute which thread should not read memory in which iterations; this is
  // needed if there are more threads than required for a certain tilesize or
  // when the batchsize is not a multiple of 16.
  bool a_sh_wr_pred[a_sh_wr_iters];
#pragma unroll
  for (int i = 0; i < a_sh_wr_iters; i++)
    a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;

  // To ensure that writing and reading A tiles to/from shared memory, the
  // latter in fragment format, is fully bank conflict free, we need to use a
  // rather fancy XOR-based layout. The key here is that neither reads nor
  // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
  // same shared memory banks. Further, it seems (based on NSight-Compute) that
  // each warp must also write a consecutive memory segment?
  // XOR-swizzle an A shared-memory offset to avoid bank conflicts (see the
  // comment above).
  auto transform_a = [&](int i) {
    int row = i / a_gl_rd_delta_o;
    return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
  };
  // Since the computation of this remapping is non-trivial and, due to our main
  // loop unrolls, all shared memory accesses are static, we simply precompute
  // both transformed reads and writes.
  int a_sh_wr_trans[a_sh_wr_iters];
#pragma unroll
  for (int i = 0; i < a_sh_wr_iters; i++)
    a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
  int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
#pragma unroll
  for (int i = 0; i < b_sh_wr_iters; i++) {
#pragma unroll
    for (int j = 0; j < thread_m_blocks; j++)
      a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
  }

  // Since B-accesses have non-constant stride they have to be computed at
  // runtime; we break dependencies between subsequent accesses with a tile by
  // maintining multiple pointers (we have enough registers), a tiny
  // optimization.
  const int4* B_ptr[b_sh_wr_iters];
#pragma unroll
  for (int i = 0; i < b_sh_wr_iters; i++)
    B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;

  extern __shared__ int4 sh[];
  // Shared memory storage for global fetch pipelines.
  int4* sh_a = sh;
  int4* sh_b = sh_a + (stages * a_sh_stage);
  int4* sh_g_idx = sh_b + (stages * b_sh_stage);
  int4* sh_s = sh_g_idx + (stages * g_idx_stage);

  // Register storage for double buffer of shared memory reads.
  FragA frag_a[2][thread_m_blocks];
  I4 frag_b_quant[2][b_thread_vecs];
  FragC frag_c[thread_m_blocks][4][2];
  FragS frag_s[2][4];

  // Zero accumulators.
  auto zero_accums = [&]() {
#pragma unroll
    for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
      reinterpret_cast(frag_c)[i] = 0;
  };

  int sh_first_group_id = -1;
  int sh_num_groups = -1;
  constexpr int sh_max_num_groups = 32;

  // Copy a window of scale groups [first_group_id, last_group_id] from global
  // to shared memory, either via cp.async or plain loads.
  auto fetch_scales_to_shared = [&](bool is_async, int first_group_id,
                                    int last_group_id) {
    sh_first_group_id = first_group_id;
    sh_num_groups = last_group_id - first_group_id + 1;

    if (sh_num_groups < sh_max_num_groups) {
      sh_num_groups = sh_max_num_groups;
    }

    if (sh_first_group_id + sh_num_groups > num_groups) {
      sh_num_groups = num_groups - sh_first_group_id;
    }

    int row_offset = first_group_id * s_gl_stride;

    if (is_async) {
      for (int i = 0; i < sh_num_groups; i++) {
        if (threadIdx.x < s_sh_stride) {
          cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x],
                         &scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]);
        }
      }
    } else {
      for (int i = 0; i < sh_num_groups; i++) {
        if (threadIdx.x < s_sh_stride) {
          sh_s[(i * s_sh_stride) + threadIdx.x] =
              scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x];
        }
      }
    }
  };

  // Asynchronously fetch the next A, B and s tile from global to the next
  // shared memory pipeline location.
  auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
    if (pred) {
      int4* sh_a_stage = sh_a + a_sh_stage * pipe;
#pragma unroll
      for (int i = 0; i < a_sh_wr_iters; i++) {
        cp_async4_pred(
            &sh_a_stage[a_sh_wr_trans[i]],
            &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
            a_sh_wr_pred[i]);
      }

      int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll
      for (int i = 0; i < b_sh_wr_iters; i++) {
#pragma unroll
        for (int j = 0; j < b_thread_vecs; j++) {
          cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
        }
        B_ptr[i] += b_gl_rd_delta_o;
      }
    }
    // Insert a fence even when we are winding down the pipeline to ensure that
    // waiting is also correct at this point.
    cp_async_fence();
  };

  // Wait until the next thread tile has been loaded to shared memory.
  auto wait_for_stage = [&]() {
    // We only have `stages - 2` active fetches since we are double buffering
    // and can only issue the next fetch when it is guaranteed that the previous
    // shared memory load is fully complete (as it may otherwise be
    // overwritten).
    // NOTE(review): the wait-group template argument ("stages - 2") was lost
    // in extraction.
    cp_async_wait();
    __syncthreads();
  };

  // Load the next sub-tile from the current location in the shared memory pipe
  // into the current register buffer.
  auto fetch_to_registers = [&](int k, int pipe) {
    int4* sh_a_stage = sh_a + a_sh_stage * pipe;
#pragma unroll
    for (int i = 0; i < thread_m_blocks; i++)
      ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);

    int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll
    for (int i = 0; i < b_thread_vecs; i++) {
      frag_b_quant[k % 2][i] = *reinterpret_cast(
          &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);
    }
  };

  bool is_same_group[stages];
  int same_group_id[stages];

  // act_order bookkeeping; unused in this fp8 kernel, hence the constant
  // initialization.
  auto init_same_group = [&](int pipe) {
    is_same_group[pipe] = false;
    same_group_id[pipe] = 0;
    return;
  };

  // Execute the actual tensor core matmul of a sub-tile.
  auto matmul = [&](int k) {
    // We have the m dimension as the inner loop in order to encourage overlapping
    // dequantization and matmul operations.
#pragma unroll
    for (int j = 0; j < 4; j++) {
      FragB frag_b0;
      FragB frag_b1;

      int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]);
      int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
      int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];

      frag_b0 = dequant_8bit(b_quant_0);
      frag_b1 = dequant_8bit(b_quant_1);

#pragma unroll
      for (int i = 0; i < thread_m_blocks; i++) {
        mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
        mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
      }
    }
  };

  // Since we slice across the k dimension of a tile in order to increase the
  // number of warps while keeping the n dimension of a tile reasonable, we have
  // multiple warps that accumulate their partial sums of the same output
  // location; which we have to reduce over in the end. We do in shared memory.
  // Reduce partial sums of warps that cover the same output tile, via a
  // logarithmic shared-memory reduction.
  auto thread_block_reduce = [&]() {
    constexpr int red_off = threads / b_sh_stride_threads / 2;
    if (red_off >= 1) {
      int red_idx = threadIdx.x / b_sh_stride_threads;
      constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
      constexpr int red_sh_delta = b_sh_stride_threads;
      int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
                      (threadIdx.x % b_sh_stride_threads);

      // Parallel logarithmic shared memory reduction. We make sure to avoid any
      // unnecessary read or write iterations, e.g., for two warps we write only
      // once by warp 1 and read only once by warp 0.
#pragma unroll
      for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
#pragma unroll
        for (int i = red_off; i > 0; i /= 2) {
          if (i <= red_idx && red_idx < 2 * i) {
#pragma unroll
            for (int j = 0; j < 4 * 2; j++) {
              int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
              if (i < red_off) {
                float* c_rd = reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]);
                float* c_wr = reinterpret_cast(&sh[red_sh_wr]);
#pragma unroll
                for (int k = 0; k < 4; k++)
                  reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] +=
                      c_rd[k] + c_wr[k];
              }
              sh[red_sh_wr] = reinterpret_cast(&frag_c)[4 * 2 * m_block + j];
            }
          }
          __syncthreads();
        }
        if (red_idx == 0) {
#pragma unroll
          for (int i = 0; i < 4 * 2; i++) {
            float* c_rd = reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]);
#pragma unroll
            for (int j = 0; j < 4; j++)
              reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j];
          }
        }
        __syncthreads();
      }
    }
  };

  // Since multiple threadblocks may process parts of the same column slice, we
  // finally have to globally reduce over the results. As the striped
  // partitioning minimizes the number of such reductions and our outputs are
  // usually rather small, we perform this reduction serially in L2 cache.
  auto global_reduce = [&](bool first = false, bool last = false) {
    // We are very careful here to reduce directly in the output buffer to
    // maximize L2 cache utilization in this step. To do this, we write out
    // results in FP16 (but still reduce with FP32 compute).
    constexpr int active_threads = 32 * thread_n_blocks / 4;
    if (threadIdx.x < active_threads) {
      int c_gl_stride = prob_n / 8;
      int c_gl_wr_delta_o = 8 * c_gl_stride;
      int c_gl_wr_delta_i = 4 * (active_threads / 32);
      int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
                    4 * (threadIdx.x / 32) + threadIdx.x % 4;
      c_gl_wr += (2 * thread_n_blocks) * slice_col;
      constexpr int c_sh_wr_delta = active_threads;
      int c_sh_wr = threadIdx.x;

      int row = (threadIdx.x % 32) / 4;

      if (!first) {
        // Interestingly, doing direct global accesses here really seems to mess up
        // the compiler and lead to slowdowns, hence we also use async-copies even
        // though these fetches are not actually asynchronous.
#pragma unroll
        for (int i = 0; i < thread_m_blocks * 4; i++) {
          cp_async4_pred(
              &sh[c_sh_wr + c_sh_wr_delta * i],
              &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)],
              i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
        }
        cp_async_fence();
        cp_async_wait<0>();
      }

#pragma unroll
      for (int i = 0; i < thread_m_blocks * 4; i++) {
        if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
          if (!first) {
            int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
#pragma unroll
            for (int j = 0; j < 2 * 4; j++) {
              reinterpret_cast(
                  &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
                  Dtype::num2float(reinterpret_cast(&c_red)[j]);
            }
          }
          if (!last) {
            int4 c;
#pragma unroll
            for (int j = 0; j < 2 * 4; j++) {
              reinterpret_cast(&c)[j] =
                  Dtype::float2num(reinterpret_cast(
                      &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
            }
            C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = c;
          }
        }
      }
    }
  };

  // Write out the reduce final result in the correct layout. We only actually
  // reshuffle matrix fragments in this step, the reduction above is performed
  // in fragment layout.
  auto write_result = [&]() {
    int c_gl_stride = prob_n / 8;
    // +1 padding column avoids shared-memory bank conflicts on the reshuffle.
    constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
    int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
    constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks));

    int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
                  (threadIdx.x % (2 * thread_n_blocks));
    c_gl_wr += (2 * thread_n_blocks) * slice_col;
    int c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) +
                  (threadIdx.x % 32) % 4;
    c_sh_wr += 32 * (threadIdx.x / 32);
    int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
                  (threadIdx.x % (2 * thread_n_blocks));

    int c_gl_wr_end = c_gl_stride * prob_m;

    // We first reorder in shared memory to guarantee the most efficient final
    // global write patterns
    // Note: `s` is accepted but unused here — scaling already happened in the
    // epilogue before the global reduction.
    auto write = [&](int idx, float c0, float c1, FragS& s) {
      scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));

      ((scalar_t2*)sh)[idx] = res;
    };

    if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
      for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
        for (int j = 0; j < 4; j++) {
          int wr = c_sh_wr + 8 * j;
          write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
                frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
          write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
                frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
          write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
                frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
          write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
                frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
        }
        c_sh_wr += 16 * (4 * c_sh_stride);
      }
    }
    __syncthreads();

#pragma unroll
    for (int i = 0;
         i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
         i++) {
      if (c_gl_wr < c_gl_wr_end) {
        C[c_gl_wr] = sh[c_sh_rd];
        c_gl_wr += c_gl_wr_delta;
        c_sh_rd += c_sh_rd_delta;
      }
    }
  };

  // Start global fetch and register load pipelines.
  auto start_pipes = [&]() {
    // Prime the async pipeline: issue the first `stages - 1` fetches before
    // any compute.
#pragma unroll
    for (int i = 0; i < stages - 1; i++) {
      fetch_to_shared(i, i, i < slice_iters);
    }

    zero_accums();
    wait_for_stage();
    init_same_group(0);
    fetch_to_registers(0, 0);
    a_gl_rd += a_gl_rd_delta_o * (stages - 1);
    slice_k_start_shared_fetch += tb_k * (stages - 1);
  };
  if (slice_iters) {
    start_pipes();
  }

  // Main loop.
  while (slice_iters) {
    // We unroll over both the global fetch and the register load pipeline to
    // ensure all shared memory accesses are static. Note that both pipelines
    // have even length meaning that the next iteration will always start at
    // index 0.
#pragma unroll
    for (int pipe = 0; pipe < stages;) {
#pragma unroll
      for (int k = 0; k < b_sh_wr_iters; k++) {
        fetch_to_registers(k + 1, pipe % stages);
        if (k == b_sh_wr_iters - 2) {
          fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages);
          pipe++;
          wait_for_stage();
          init_same_group(pipe % stages);
        }
        matmul(k);
      }
      slice_iters--;
      if (slice_iters == 0) {
        break;
      }
    }

    a_gl_rd += a_gl_rd_delta_o * stages;
    slice_k_start += tb_k * stages;
    slice_k_start_shared_fetch += tb_k * stages;

    // Process results and, if necessary, proceed to the next column slice.
    // While this pattern may not be the most readable, other ways of writing
    // the loop seemed to noticeably worse performance after compilation.
    if (slice_iters == 0) {
      cp_async_wait<0>();
      bool last = slice_idx == slice_count - 1;
      // For per-column scales, we only fetch them here in the final step before
      // write-out
      if (s_sh_wr_pred) {
        cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
      }
      cp_async_fence();

      thread_block_reduce();

      cp_async_wait<0>();
      __syncthreads();
      if (threadIdx.x / 32 < thread_n_blocks / 4) {
        reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0];
        reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4];
      }

      // For 8-bit channelwise, we apply the scale before the global reduction
      // that converts the fp32 results to fp16 (so that we avoid possible
      // overflow in fp16)
      if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
        for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
          for (int j = 0; j < 4; j++) {
            scale_float(reinterpret_cast(&frag_c[i][j][0][0]),
                        frag_s[j / 2][2 * (j % 2) + 0]);
            scale_float(reinterpret_cast(&frag_c[i][j][0][2]),
                        frag_s[j / 2][2 * (j % 2) + 0]);

            scale_float(reinterpret_cast(&frag_c[i][j][1][0]),
                        frag_s[j / 2][2 * (j % 2) + 1]);
            scale_float(reinterpret_cast(&frag_c[i][j][1][2]),
                        frag_s[j / 2][2 * (j % 2) + 1]);
          }
        }
      }

      if (slice_count > 1) {  // only globally reduce if there is more than one
                              // block in a slice
        barrier_acquire(&locks[slice_col], slice_idx);
        global_reduce(slice_idx == 0, last);
        barrier_release(&locks[slice_col], last);
      }
      if (last)  // only the last block in a slice actually writes the result
        write_result();
      slice_row = 0;
      slice_col_par++;
      slice_col++;
      init_slice();
      if (slice_iters) {
        a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
                  (threadIdx.x % a_gl_rd_delta_o);
#pragma unroll
        for (int i = 0; i < b_sh_wr_iters; i++)
          B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
        if (slice_col == 0) {
#pragma unroll
          for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
        }

        // Update slice k/n for scales loading
        s_gl_rd = s_sh_stride * slice_col + threadIdx.x;

        start_pipes();
      }
    }
  }
}

// Dispatch helper: selects and launches the matching Marlin template
// instantiation (continued on the next lines).
#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
                  THREAD_K_BLOCKS, GROUP_BLOCKS, NUM_THREADS) \
  else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
           thread_n_blocks == THREAD_N_BLOCKS && \
           thread_k_blocks == THREAD_K_BLOCKS && \
           group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
    cudaFuncSetAttribute( \
        Marlin, \
        cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
    Marlin \
        <<>>( \
            A_ptr, B_ptr, C_ptr, s_ptr, num_groups, prob_m, prob_n, prob_k, \
            locks); \
  }
  // NOTE(review): the kernel template arguments and the launch configuration
  // inside `<<>>` were stripped by extraction; restore from upstream.

// One candidate tile/threads configuration for the kernel launch.
typedef struct {
  int thread_k;
  int thread_n;
  int num_threads;
} thread_config_t;

// A chosen configuration plus the max number of m blocks per invocation.
typedef struct {
  int max_m_blocks;
  thread_config_t tb_cfg;
} exec_config_t;

thread_config_t small_batch_thread_configs[] = {
    // Ordered by priority
    // thread_k, thread_n, num_threads
    {128, 128, 256},
    {64, 128, 128},
    {128, 64, 128},
};

thread_config_t large_batch_thread_configs[] = {
    // Ordered by priority
    // thread_k, thread_n, num_threads
    {64, 256, 256},
    {64, 128, 128},
    {128, 64, 128},
};

// Bytes of shared memory needed to cache the scales for one thread-block
// across all pipeline stages.
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
                          int prob_n, int prob_k, int num_bits, int group_size) {
  int tb_n = th_config.thread_n;

  // Get max scale groups per thread-block
  // Fixed for channelwise
  int tb_groups = 1;
  int tb_scales = tb_groups * tb_n * 2;

  return tb_scales * pipe_stages;
}

// Check whether the A+B fetch pipeline (plus the scales cache) fits in the
// available shared memory for the given configuration.
bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
                         int prob_m, int prob_n, int prob_k, int num_bits,
                         int scales_cache_size, int max_shared_mem) {
  int pack_factor = 32 / num_bits;

  // Get B size
  int tb_k = th_config.thread_k;
  int tb_n = th_config.thread_n;

  int b_size = (tb_k * tb_n / pack_factor) * 4;

  // Get A size
  int m_blocks = div_ceil(prob_m, 16);
  int tb_max_m = 16;

  while (true) {
    if (m_blocks >= max_m_blocks) {
      tb_max_m *= max_m_blocks;
      break;
    }

    max_m_blocks--;
    if (max_m_blocks == 0) {
      TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
    }
  }

  int a_size = (tb_max_m * tb_k) * 2;

  float pipe_size = (a_size + b_size) * pipe_stages;

  TORCH_CHECK(max_shared_mem / 2 > scales_cache_size);  // Sanity

  // Keep 5% headroom below the shared-memory limit.
  return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
}

// Validate a candidate config against problem shape, minimum tile sizes and
// the shared-memory budget.
bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
                     int prob_m, int prob_n, int prob_k, int num_bits,
                     int group_size, int max_shared_mem) {
  // Sanity
  if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
      th_config.num_threads == -1) {
    return false;
  }

  // Verify K/N are divisible by thread K/N
  if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
    return false;
  }

  // Verify min for thread K/N
  if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
    return false;
  }

  // num_threads must be at least 128 (= 4 warps)
  if (th_config.num_threads < 128) {
    return false;
  }

  // Determine cache for scales
  int scales_cache_size = get_scales_cache_size(th_config, prob_m, prob_n,
                                                prob_k, num_bits, group_size);

  // Check that pipeline fits into cache
  if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,
                           num_bits, scales_cache_size, max_shared_mem)) {
    return false;
  }

  return true;
}

// Pick the first valid configuration from the priority-ordered candidate
// tables, shrinking max_m_blocks until something fits.
exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
                                      int num_bits, int group_size,
                                      int max_shared_mem) {
  int max_m_blocks = 4;
  while (max_m_blocks > 0) {
    if (prob_m <= 16) {
      for (auto th_config : small_batch_thread_configs) {
        if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
                            num_bits, group_size, max_shared_mem)) {
          return exec_config_t{max_m_blocks, th_config};
        }
      }
    } else {
      for (auto th_config : large_batch_thread_configs) {
        if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
                            num_bits, group_size, max_shared_mem)) {
          return exec_config_t{max_m_blocks, th_config};
        }
      }
    }

    max_m_blocks--;  // Process less M blocks per invocation to reduce cache
                     // usage
  }

  return exec_config_t{0, {-1, -1, -1}};
}

// Expand __CALL_IF over all supported thread_m_blocks (1..4), channelwise
// scales only (GROUP_BLOCKS = -1).
#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
  __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
  __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, -1,
            NUM_THREADS) \
  __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS)

// Host-side driver: validates shapes, picks a launch configuration and
// iterates over the m dimension, launching the Marlin kernel per chunk.
// NOTE(review): the template parameter list (the scalar type) was stripped
// by extraction.
template void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
                              int prob_m, int prob_n, int prob_k,
                              void* workspace, int num_bits, int num_groups,
                              int group_size, int dev, cudaStream_t stream,
                              int thread_k, int thread_n, int sms, int max_par) {
  TORCH_CHECK(num_bits == 8, "num_bits must be 8. Got = ", num_bits);
  TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
              ", ", prob_n, ", ", prob_k, "]");

  int tot_m = prob_m;
  int tot_m_blocks = div_ceil(tot_m, 16);
  int pad = 16 * tot_m_blocks - tot_m;

  if (sms == -1) {
    cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
  }

  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);

  // Set thread config
  exec_config_t exec_cfg;
  if (thread_k != -1 && thread_n != -1) {
    // User-defined config
    exec_cfg = exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}};
  } else {
    // Auto config
    exec_cfg = determine_thread_config(prob_m, prob_n, prob_k, num_bits,
                                       group_size, max_shared_mem);
  }

  TORCH_CHECK(
      exec_cfg.max_m_blocks > 0 &&
          is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, prob_m,
                          prob_n, prob_k, num_bits, group_size, max_shared_mem),
      "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks,
      ", thread_k = ", exec_cfg.tb_cfg.thread_k,
      ", thread_n = ", exec_cfg.tb_cfg.thread_n,
      ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", prob_m,
      ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
      ", group_size = ", group_size, ", max_shared_mem = ", max_shared_mem);

  int num_threads = exec_cfg.tb_cfg.num_threads;
  thread_k = exec_cfg.tb_cfg.thread_k;
  thread_n = exec_cfg.tb_cfg.thread_n;

  int thread_k_blocks = thread_k / 16;
  int thread_n_blocks = thread_n / 16;

  int blocks = sms;

  TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
              " is not divisible by thread_n = ", thread_n);
  TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
              " is not divisible by thread_k = ", thread_k);

  // Channelwise only: no grouped scales.
  int group_blocks = -1;

  const int4* A_ptr = (const int4*)A;
  const int4* B_ptr = (const int4*)B;
  int4* C_ptr = (int4*)C;
  const int4* s_ptr = (const int4*)s;

  int* locks = (int*)workspace;

  // Main loop
  for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) {
    int thread_m_blocks = tot_m_blocks - i;
    prob_m = tot_m - 16 * i;
    int par = 1;
    if (thread_m_blocks > exec_cfg.max_m_blocks) {
      // Note that parallel > 1 currently only works for inputs without any
      // padding
      par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks);
      if (par > max_par) par = max_par;
      prob_m = (16 * exec_cfg.max_m_blocks) * par;
      i += exec_cfg.max_m_blocks * (par - 1);
      thread_m_blocks = exec_cfg.max_m_blocks;
    }

    // Define kernel configurations
    if (false) {
    }
    CALL_IF(8, 32, 2, 256)
    CALL_IF(8, 16, 4, 256)
    CALL_IF(8, 8, 8, 256)
    CALL_IF(8, 8, 4, 128)
    CALL_IF(8, 4, 8, 128)
    else {
      TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
                             str(prob_n) + ", " + str(prob_k) + "]" +
                             ", num_groups = " + str(num_groups) +
                             ", group_size = " + str(group_size) +
                             ", thread_m_blocks = " + str(thread_m_blocks) +
                             ", thread_n_blocks = " + str(thread_n_blocks) +
                             ", thread_k_blocks = " + str(thread_k_blocks));
    }

    A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
    C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
  }
}

}  // namespace fp8_marlin

// PyTorch entry point: fp8 (e4m3) weight x fp16/bf16 activation GEMM.
// Validates tensor shapes/devices, allocates the output and dispatches on
// the activation dtype.
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                              torch::Tensor& b_scales, torch::Tensor& workspace,
                              int64_t num_bits, int64_t size_m, int64_t size_n,
                              int64_t size_k) {
  // Verify num_bits
  TORCH_CHECK(num_bits == 8, "num_bits must be 8. Got = ", num_bits);
  int pack_factor = 32 / num_bits;

  // Verify A
  TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
              ", size_m = ", size_m);
  TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1),
              ", size_k = ", size_k);

  // Verify B
  TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k,
              " is not divisible by tile_size = ", gptq_marlin::tile_size);
  TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0),
              "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
              ", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size);
  TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0,
              "b_q_weight.size(1) = ", b_q_weight.size(1),
              " is not divisible by tile_size = ", gptq_marlin::tile_size);
  int actual_size_n = (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor;
  TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
              ", actual_size_n = ", actual_size_n);

  // Verify device and strides
  TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
  TORCH_CHECK(a.is_contiguous(), "A is not contiguous");

  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");

  TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
  TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");

  // Alloc buffers
  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
  auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
  torch::Tensor c = torch::empty({size_m, size_n}, options);

  // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
  // auto -1)
  int thread_k = -1;
  // thread_n: `n` size of a thread_tile in `weights` (can usually be left as
  // auto -1)
  int thread_n = -1;
  // sms: number of SMs to use for the kernel (can usually be left as auto -1)
  int sms = -1;

  // Detect groupsize and act_order
  int num_groups = -1;
  int group_size = -1;

  int b_rank = b_scales.sizes().size();
  TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2");
  TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1),
              " is not size_n = ", size_n);
  // Channelwise only for FP8
  TORCH_CHECK(b_scales.size(0) == 1)
  num_groups = b_scales.size(0);

  // Verify workspace size
  TORCH_CHECK(
      size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n,
      ", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n);
  int min_workspace_size = (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par;
  TORCH_CHECK(workspace.numel() >= min_workspace_size,
              "workspace.numel = ", workspace.numel(),
              " is below min_workspace_size = ", min_workspace_size);

  int dev = a.get_device();
  if (a.scalar_type() == at::ScalarType::Half) {
    fp8_marlin::marlin_mm_f16i4(
        a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(),
        size_m, size_n, size_k, workspace.data_ptr(), num_bits, num_groups,
        group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,
        thread_n, sms, gptq_marlin::max_par);
  } else if (a.scalar_type() == at::ScalarType::BFloat16) {
    fp8_marlin::marlin_mm_f16i4(
        a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(),
        size_m, size_n, size_k, workspace.data_ptr(), num_bits, num_groups,
        group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,
        thread_n, sms, gptq_marlin::max_par);
  } else {
    TORCH_CHECK(false, "fp8_marlin_gemm only supports bfloat16 and float16");
  }

  return c;
}

#endif

================================================
FILE: optimum/quanto/library/extensions/cuda/marlin/fp8_marlin.cuh
================================================
// #pragma once
#include
#include
// #ifndef _fp8_marlin_cuh
// #define _fp8_marlin_cuh
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// assert(0);
// #else
// Declaration of the fp8 GEMM entry point defined in fp8_marlin.cu above.
// NOTE(review): the #include targets were stripped by extraction.
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                              torch::Tensor& b_scales, torch::Tensor& workspace,
                              int64_t num_bits, int64_t size_m, int64_t size_n,
                              int64_t size_k);
// #endif // #endif ================================================ FILE: optimum/quanto/library/extensions/cuda/marlin/gptq_marlin.cuh ================================================ #pragma once #include #include #include #include #include #include #include namespace gptq_marlin { // 8 warps are a good choice since every SM has 4 schedulers and having more // than 1 warp per schedule allows some more latency hiding. At the same time, // we want relatively few warps to have many registers per warp and small tiles. static constexpr int default_threads = 256; static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory static constexpr int min_thread_n = 64; static constexpr int min_thread_k = 64; static constexpr int tile_size = 16; static constexpr int max_par = 16; template struct Vec { T elems[n]; __device__ T& operator[](int i) { return elems[i]; } }; using I4 = Vec; constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 // No support for async #else __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { const int BYTES = 16; uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile( "{\n" " .reg .pred p;\n" " setp.ne.b32 p, %0, 0;\n" " @p cp.async.cg.shared.global [%1], [%2], %3;\n" "}\n" ::"r"((int)pred), "r"(smem), "l"(glob_ptr), "n"(BYTES)); } __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { const int BYTES = 16; uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile( "{\n" " cp.async.cg.shared.global [%0], [%1], %2;\n" "}\n" ::"r"(smem), "l"(glob_ptr), "n"(BYTES)); } __device__ inline void cp_async_fence() { asm volatile("cp.async.commit_group;\n" ::); } template __device__ inline void cp_async_wait() { asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); } #endif } // namespace gptq_marlin ================================================ FILE: 
optimum/quanto/library/extensions/cuda/marlin/gptq_marlin_dtypes.cuh ================================================ #ifndef _data_types_cuh #define _data_types_cuh #include "gptq_marlin.cuh" #include #include namespace gptq_marlin { template class ScalarType {}; template <> class ScalarType { public: using scalar_t = half; using scalar_t2 = half2; // Matrix fragments for tensor core instructions; their precise layout is // documented here: // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type using FragA = Vec; using FragB = Vec; using FragC = Vec; using FragS = Vec; static __device__ float inline num2float(const half x) { return __half2float(x); } static __device__ half2 inline num2num2(const half x) { return __half2half2(x); } static __device__ half2 inline nums2num2(const half x1, const half x2) { return __halves2half2(x1, x2); } static __host__ __device__ half inline float2num(const float x) { return __float2half(x); } }; template <> class ScalarType { public: using scalar_t = nv_bfloat16; using scalar_t2 = nv_bfloat162; using FragA = Vec; using FragB = Vec; using FragC = Vec; using FragS = Vec; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 static __device__ float inline num2float(const nv_bfloat16 x) { return __bfloat162float(x); } static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { return __bfloat162bfloat162(x); } static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, const nv_bfloat16 x2) { return __halves2bfloat162(x1, x2); } static __host__ __device__ nv_bfloat16 inline float2num(const float x) { return __float2bfloat16(x); } #endif }; } // namespace gptq_marlin #endif ================================================ FILE: optimum/quanto/library/extensions/cuda/marlin/gptq_marlin_repack.cu ================================================ #include "gptq_marlin.cuh" namespace gptq_marlin { static constexpr int repack_stages = 8; static constexpr 
int repack_threads = 256; static constexpr int tile_k_size = tile_size; static constexpr int tile_n_size = tile_k_size * 4; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 template __global__ void marlin_repack_kernel( uint32_t const* __restrict__ b_q_weight_ptr, uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) {} } // namespace gptq_marlin torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits) { TORCH_CHECK_NOT_IMPLEMENTED( false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0"); return torch::empty({1, 1}); } #else template __global__ void marlin_repack_kernel( uint32_t const* __restrict__ b_q_weight_ptr, uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) { constexpr int pack_factor = 32 / num_bits; int k_tiles = size_k / tile_k_size; int n_tiles = size_n / tile_n_size; int block_k_tiles = div_ceil(k_tiles, gridDim.x); int start_k_tile = blockIdx.x * block_k_tiles; if (start_k_tile >= k_tiles) { return; } int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles); // Wait until the next thread tile has been loaded to shared memory. auto wait_for_stage = [&]() { // We only have `stages - 2` active fetches since we are double buffering // and can only issue the next fetch when it is guaranteed that the previous // shared memory load is fully complete (as it may otherwise be // overwritten). cp_async_wait(); __syncthreads(); }; extern __shared__ int4 sh[]; constexpr int perm_size = tile_k_size / 4; int4* sh_perm_ptr = sh; int4* sh_pipe_ptr = sh_perm_ptr; if constexpr (has_perm) { sh_pipe_ptr += perm_size; } constexpr int tile_ints = tile_k_size / pack_factor; constexpr int stage_n_threads = tile_n_size / 4; constexpr int stage_k_threads = has_perm ? 
tile_k_size : tile_ints; constexpr int stage_size = stage_k_threads * stage_n_threads; auto load_perm_to_shared = [&](int k_tile_id) { int first_k_int4 = (k_tile_id * tile_k_size) / 4; int4 const* perm_int4_ptr = reinterpret_cast(perm_ptr); if (threadIdx.x < perm_size) { sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x]; } __syncthreads(); }; auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { if (n_tile_id >= n_tiles) { cp_async_fence(); return; } int first_n = n_tile_id * tile_n_size; int4* sh_ptr = sh_pipe_ptr + stage_size * pipe; if constexpr (has_perm) { if (threadIdx.x < stage_size) { int k_id = threadIdx.x / stage_n_threads; int n_id = threadIdx.x % stage_n_threads; uint32_t const* sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); int src_k = sh_perm_int_ptr[k_id]; int src_k_packed = src_k / pack_factor; cp_async4( &sh_ptr[k_id * stage_n_threads + n_id], reinterpret_cast(&( b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)]))); } } else { if (threadIdx.x < stage_size) { int k_id = threadIdx.x / stage_n_threads; int n_id = threadIdx.x % stage_n_threads; int first_k = k_tile_id * tile_k_size; int first_k_packed = first_k / pack_factor; cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], reinterpret_cast( &(b_q_weight_ptr[(first_k_packed + k_id) * size_n + first_n + (n_id * 4)]))); } } cp_async_fence(); }; auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) { if (n_tile_id >= n_tiles) { return; } int warp_id = threadIdx.x / 32; int th_id = threadIdx.x % 32; if (warp_id >= 4) { return; } int tc_col = th_id / 4; int tc_row = (th_id % 4) * 2; constexpr int tc_offsets[4] = {0, 1, 8, 9}; int cur_n = warp_id * 16 + tc_col; constexpr int sh_stride = 64; constexpr uint32_t mask = (1 << num_bits) - 1; int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; uint32_t* sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); uint32_t* sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); uint32_t vals[8]; if constexpr (has_perm) { 
for (int i = 0; i < 4; i++) { int k_idx = tc_row + tc_offsets[i]; uint32_t src_k = sh_perm_int_ptr[k_idx]; uint32_t src_k_pos = src_k % pack_factor; uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n]; uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask; uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8]; uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask; vals[i] = b1_cur_val; vals[4 + i] = b2_cur_val; } } else { uint32_t b1_vals[tile_ints]; uint32_t b2_vals[tile_ints]; #pragma unroll for (int i = 0; i < tile_ints; i++) { b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; } #pragma unroll for (int i = 0; i < 4; i++) { int cur_elem = tc_row + tc_offsets[i]; int cur_int = cur_elem / pack_factor; int cur_pos = cur_elem % pack_factor; vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask; vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask; } } constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; // Result of: // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h if constexpr (num_bits == 4) { constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; uint32_t res = 0; #pragma unroll for (int i = 0; i < 8; i++) { res |= vals[pack_idx[i]] << (i * 4); } out_ptr[out_offset + th_id * 4 + warp_id] = res; } else { constexpr int pack_idx[4] = {0, 2, 1, 3}; uint32_t res1 = 0; uint32_t res2 = 0; #pragma unroll for (int i = 0; i < 4; i++) { res1 |= vals[pack_idx[i]] << (i * 8); res2 |= vals[4 + pack_idx[i]] << (i * 8); } out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; } }; auto start_pipes = [&](int k_tile_id, int n_tile_id) { #pragma unroll for (int pipe = 0; pipe < repack_stages - 1; pipe++) { 
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); } wait_for_stage(); }; #pragma unroll for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { int n_tile_id = 0; if constexpr (has_perm) { load_perm_to_shared(k_tile_id); } start_pipes(k_tile_id, n_tile_id); while (n_tile_id < n_tiles) { #pragma unroll for (int pipe = 0; pipe < repack_stages; pipe++) { fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1); repack_tile(pipe, k_tile_id, n_tile_id + pipe); wait_for_stage(); } n_tile_id += repack_stages; } } } } // namespace gptq_marlin #define CALL_IF(NUM_BITS, HAS_PERM) \ else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ cudaFuncSetAttribute( \ gptq_marlin::marlin_repack_kernel, \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ gptq_marlin::marlin_repack_kernel \ <<>>( \ b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ } torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits) { // Verify compatibility with marlin tile of 16x64 TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", gptq_marlin::tile_k_size); TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n, " is not divisible by tile_n_size = ", gptq_marlin::tile_n_size); TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. 
Got = ", num_bits); int const pack_factor = 32 / num_bits; // Verify B TORCH_CHECK((size_k / pack_factor) == b_q_weight.size(0), "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), ", size_k = ", size_k, ", pack_factor = ", pack_factor); TORCH_CHECK(b_q_weight.size(1) == size_n, "b_q_weight.size(1) = ", b_q_weight.size(1), " is not size_n = ", size_n); // Verify device and strides TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt"); TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU"); TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous"); TORCH_CHECK(perm.dtype() == at::kInt, "perm type is not at::kInt"); // Alloc buffers const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight)); auto options = torch::TensorOptions() .dtype(b_q_weight.dtype()) .device(b_q_weight.device()); torch::Tensor out = torch::empty({size_k / gptq_marlin::tile_size, size_n * gptq_marlin::tile_size / pack_factor}, options); // Detect if there is act_order bool has_perm = perm.size(0) != 0; // Get ptrs uint32_t const* b_q_weight_ptr = reinterpret_cast(b_q_weight.data_ptr()); uint32_t const* perm_ptr = reinterpret_cast(perm.data_ptr()); uint32_t* out_ptr = reinterpret_cast(out.data_ptr()); // Get dev info int dev = b_q_weight.get_device(); cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); int blocks; cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); int max_shared_mem = 0; cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); TORCH_CHECK(max_shared_mem > 0); if (false) { } CALL_IF(4, false) CALL_IF(4, true) CALL_IF(8, false) CALL_IF(8, true) else { TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits, ", has_perm = ", has_perm); } return out; } #endif ================================================ FILE: 
optimum/quanto/library/extensions/cuda/marlin/gptq_marlin_repack.cuh ================================================ #include #include #include #ifndef _gptq_marlin_repack_cuh #define _gptq_marlin_repack_cuh torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits); #endif ================================================ FILE: optimum/quanto/library/extensions/cuda/marlin/marlin_cuda.cpp ================================================ /* * Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at) * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "marlin_cuda.h" #include #include #include #include #include "marlin_cuda_kernel.cuh" const int ERR_PROB_SHAPE = 1; const int ERR_KERN_SHAPE = 2; void mul( const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& C, const torch::Tensor& s, const torch::Tensor& sz, // ADDED: add scaled zero point torch::Tensor& workspace, int thread_k, int thread_n, int sms, int max_par ) { int prob_m = A.size(0); int prob_n = C.size(1); int prob_k = A.size(1); int groupsize = (s.size(0) == 1) ? 
-1 : prob_k / s.size(0); if (groupsize != -1 && groupsize * s.size(0) != prob_k) AT_ERROR("k=", prob_k, " not compatible with ", s.size(0), " groups."); if (workspace.numel() < prob_n / 128 * max_par) AT_ERROR("workspace must be of size at least ", prob_n / 128 * max_par, "."); int dev = A.get_device(); int err = marlin_cuda( A.data_ptr(), B.data_ptr(), C.data_ptr(), s.data_ptr(), sz.data_ptr(), // ADDED: add scaled zero point prob_m, prob_n, prob_k, workspace.data_ptr(), groupsize, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par ); if (err == ERR_PROB_SHAPE) { AT_ERROR( "Problem (m=", prob_m, ", n=", prob_n, ", k=", prob_k, ")", " not compatible with thread_k=", thread_k, ", thread_n=", thread_n, "." ); } else if (err == ERR_KERN_SHAPE) { AT_ERROR( "No kernel implementation for thread_k=", thread_k, ", thread_n=", thread_n, ", groupsize=", groupsize, "." ); } } ================================================ FILE: optimum/quanto/library/extensions/cuda/marlin/marlin_cuda.h ================================================ /* * Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at) * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #include void mul( const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& C, const torch::Tensor& s, const torch::Tensor& sz, torch::Tensor& workspace, int thread_k = -1, int thread_n = -1, int sms = -1, int max_par = 8 ); ================================================ FILE: optimum/quanto/library/extensions/cuda/marlin/marlin_cuda_kernel.cu ================================================ /* * Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at) * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MARLIN_CUDA_KERNEL_CUH #define MARLIN_CUDA_KERNEL_CUH #include #include #include #include constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } // Instances of `Vec` are used to organize groups of >>registers<<, as needed for instance as inputs to tensor core // operations. Consequently, all corresponding index accesses must be compile-time constants, which is why we // extensively use `#pragma unroll` throughout the kernel code to guarantee this. 
template struct Vec { T elems[n]; __device__ T& operator[](int i) { return elems[i]; } }; using I4 = Vec; // Matrix fragments for tensor core instructions; their precise layout is documented here: // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type using FragA = Vec; using FragB = Vec; using FragC = Vec; using FragS = Vec; // quantization scales // Predicated asynchronous global->shared copy; used for inputs A where we apply predication to handle batchsizes that // are not multiples of 16. __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { const int BYTES = 16; uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile( "{\n" " .reg .pred p;\n" " setp.ne.b32 p, %0, 0;\n" " @p cp.async.cg.shared.global [%1], [%2], %3;\n" "}\n" :: "r"((int) pred), "r"(smem), "l"(glob_ptr), "n"(BYTES) ); } // Asynchronous global->shared copy __device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) { const int BYTES = 16; uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("{\n" " cp.async.cg.shared.global [%0], [%1], %2;\n" "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES)); } // Async copy fence. __device__ inline void cp_async_fence() { asm volatile("cp.async.commit_group;\n" ::); } // Wait until at most `n` async copy stages are still pending. template __device__ inline void cp_async_wait() { asm volatile("cp.async.wait_group %0;\n" :: "n"(n)); } // m16n8k16 tensor core mma instruction with fp16 inputs and fp32 output/accumulation. 
__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, FragC& frag_c) {
  const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
  float* c = reinterpret_cast<float*>(&frag_c);
  asm volatile(
    "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
    "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
    : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
    :  "r"(a[0]),  "r"(a[1]),  "r"(a[2]),  "r"(a[3]),
       "r"(b[0]),  "r"(b[1]),
       "f"(c[0]),  "f"(c[1]),  "f"(c[2]),  "f"(c[3])
  );
}

// Instruction for loading a full 16x16 matrix fragment of operand A from shared memory, directly in tensor core layout.
__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
    "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
    : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) : "r"(smem)
  );
}

// Lookup-table based 3-input logical operation; explicitly used for dequantization as the compiler does not seem to
// automatically recognize it in all cases.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
  int res;
  asm volatile(
    "lop3.b32 %0, %1, %2, %3, %4;\n"
    : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut)
  );
  return res;
}

// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 values.
// We mostly follow the strategy in the link below, with some small changes:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__device__ inline FragB dequant(int q) {
  const int LO = 0x000f000f;
  const int HI = 0x00f000f0;
  const int EX = 0x64006400;
  // Guarantee that the `(a & b) | c` operations are LOP3s.
  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
  // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point directly into `SUB` and `ADD`.
  // const int SUB = 0x64086408;
  // const int MUL = 0x2c002c00;
  // const int ADD = 0xd480d480;
  // MODIFIED: use scaled zero point so do not need to map to [-8, 7]
  const int SUB = 0x64006400;
  const int MUL = 0x2c002c00;
  const int ADD = 0xd400d400;
  FragB frag_b;
  frag_b[0] = __hsub2(
    *reinterpret_cast<half2*>(&lo),
    *reinterpret_cast<const half2*>(&SUB)
  );
  frag_b[1] = __hfma2(
    *reinterpret_cast<half2*>(&hi),
    *reinterpret_cast<const half2*>(&MUL),
    *reinterpret_cast<const half2*>(&ADD)
  );
  return frag_b;
}

// Multiply dequantized values by the corresponding quantization scale; used only for grouped quantization.
// MODIFIED: add scaled zero point
__device__ inline void scale(FragB& frag_b, FragS& frag_s, FragS& frag_sz, int i) {
  half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
  half2 sz = __half2half2(reinterpret_cast<__half*>(&frag_sz)[i]);
  // frag_b[0] = __hmul2(frag_b[0], s);
  // frag_b[1] = __hmul2(frag_b[1], s);
  frag_b[0] = __hfma2(frag_b[0], s, sz);
  frag_b[1] = __hfma2(frag_b[1], s, sz);
}

// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int* lock, int count) {
  if (threadIdx.x == 0) {
    int state = -1;
    do
      // Guarantee that subsequent writes by this threadblock will be visible globally.
      asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
    while (state != count);
  }
  __syncthreads();
}

// Release barrier and increment visitation count.
__device__ inline void barrier_release(int* lock, bool reset = false) {
  __syncthreads();
  if (threadIdx.x == 0) {
    if (reset) {
      lock[0] = 0;
      return;
    }
    int val = 1;
    // Make sure that all writes since acquiring this barrier are visible globally, while releasing the barrier.
asm volatile ("fence.acq_rel.gpu;\n"); asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val)); } } template < const int threads, // number of threads in a threadblock const int thread_m_blocks, // number of 16x16 blocks in the m dimension (batchsize) of the threadblock const int thread_n_blocks, // same for n dimension (output) const int thread_k_blocks, // same for k dimension (reduction) const int stages, // number of stages for the async global->shared fetch pipeline const int group_blocks = -1 // number of consecutive 16x16 blocks with a separate quantization scale > __global__ void Marlin( const int4* __restrict__ A, // fp16 input matrix of shape mxk const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn int4* __restrict__ C, // fp16 output buffer of shape mxn const int4* __restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn // ADDED: add scaled zero point const int4* __restrict__ sz, // fp16 quantization scaled zero points of shape (k/groupsize)xn int prob_m, // batch dimension m int prob_n, // output dimension n int prob_k, // reduction dimension k int* locks // extra global storage for barrier synchronization ) { // Each threadblock processes one "stripe" of the B matrix with (roughly) the same size, which might involve multiple // column "slices" (of width 16 * `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM example: // 0 1 3 // 0 2 3 // 1 2 4 // While this kind of partitioning makes things somewhat more complicated, it ensures good utilization of all SMs // for many kinds of shape and GPU configurations, while requiring as few slow global cross-threadblock reductions as // possible. 
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a better partitioning with less reductions int parallel = 1; if (prob_m > 16 * thread_m_blocks) { parallel = prob_m / (16 * thread_m_blocks); prob_m = 16 * thread_m_blocks; } int k_tiles = prob_k / 16 / thread_k_blocks; int n_tiles = prob_n / 16 / thread_n_blocks; int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); // Ensure that the number of tiles in each stripe is a multiple of the groupsize; this avoids an annoying special case // where a stripe starts in the middle of group. if (group_blocks != -1) iters = (group_blocks / thread_k_blocks) * ceildiv(iters, (group_blocks / thread_k_blocks)); int slice_row = (iters * blockIdx.x) % k_tiles; int slice_col_par = (iters * blockIdx.x) / k_tiles; int slice_col = slice_col_par; int slice_iters; // number of threadblock tiles in the current slice int slice_count = 0; // total number of active threadblocks in the current slice int slice_idx; // index of threadblock in current slice; numbered bottom to top // We can easily implement parallel problem execution by just remapping indices and advancing global pointers if (slice_col_par >= n_tiles) { A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; locks += (slice_col_par / n_tiles) * n_tiles; slice_col = slice_col_par % n_tiles; } // Compute all information about the current slice which is required for synchronization. 
auto init_slice = [&] () { slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; if (slice_iters == 0) return; if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; slice_count = 1; slice_idx = 0; int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); if (col_first <= k_tiles * (slice_col_par + 1)) { int col_off = col_first - k_tiles * slice_col_par; slice_count = ceildiv(k_tiles - col_off, iters); if (col_off > 0) slice_count++; int delta_first = iters * blockIdx.x - col_first; if (delta_first < 0 || (col_off == 0 && delta_first == 0)) slice_idx = slice_count - 1; else { slice_idx = slice_count - 1 - delta_first / iters; if (col_off > 0) slice_idx--; } } if (slice_col == n_tiles) { A += 16 * thread_m_blocks * prob_k / 8; C += 16 * thread_m_blocks * prob_n / 8; locks += n_tiles; slice_col = 0; } }; init_slice(); int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory // We typically use `constexpr` to indicate that this value is a compile-time constant constexpr int a_sh_stride = 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; // delta between subsequent A tiles in global memory int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); // between shared memory tile reads constexpr int a_sh_rd_delta_i = a_sh_stride * 16; // within a shared memory tile constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); // overall size of a tile constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); // number of shared write iterations for a tile int b_gl_stride = 16 * prob_n / 32; constexpr int 
b_sh_stride = 32 * thread_n_blocks / 4; int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); constexpr int b_sh_wr_delta = threads; constexpr int b_sh_rd_delta = threads; constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; int s_gl_stride = prob_n / 8; constexpr int s_sh_stride = 16 * thread_n_blocks / 8; constexpr int s_sh_stage = s_sh_stride; int s_gl_rd_delta = s_gl_stride; // Global A read index of current thread. int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); a_gl_rd += a_gl_rd_delta_o * slice_row; // Shared write index of current thread. int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); // Shared read index. int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); b_gl_rd += b_sh_stride * slice_col; b_gl_rd += b_gl_rd_delta_o * slice_row; int b_sh_wr = threadIdx.x; int b_sh_rd = threadIdx.x; int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; int s_sh_wr = threadIdx.x; int s_sh_rd; // We use a different scale layout for grouped and column-wise quantization as we scale a `half2` tile in column-major // layout in the former and in row-major in the latter case. if (group_blocks != -1) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; else s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; // Precompute which thread should not read memory in which iterations; this is needed if there are more threads than // required for a certain tilesize or when the batchsize is not a multiple of 16. 
bool a_sh_wr_pred[a_sh_wr_iters]; #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; bool s_sh_wr_pred = threadIdx.x < s_sh_stride; // To ensure that writing and reading A tiles to/from shared memory, the latter in fragment format, is fully bank // conflict free, we need to use a rather fancy XOR-based layout. The key here is that neither reads nor writes of // the 16-byte `int4` blocks of 8 consecutive threads involve the same shared memory banks. Further, it seems (based // on NSight-Compute) that each warp must also write a consecutive memory segment? auto transform_a = [&] (int i) { int row = i / a_gl_rd_delta_o; return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; }; // Since the computation of this remapping is non-trivial and, due to our main loop unrolls, all shared memory // accesses are static, we simply precompute both transformed reads and writes. int a_sh_wr_trans[a_sh_wr_iters]; #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { #pragma unroll for (int j = 0; j < thread_m_blocks; j++) a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); } // Since B-accesses have non-constant stride they have to be computed at runtime; we break dependicies between // subsequent accesses with a tile by maintining multiple pointers (we have enough registers), a tiny optimization. const int4* B_ptr[b_sh_wr_iters]; #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; extern __shared__ int4 sh[]; // Shared memory storage for global fetch pipelines. 
int4* sh_a = sh; int4* sh_b = sh_a + (stages * a_sh_stage); int4* sh_s = sh_b + (stages * b_sh_stage); // ADDED: shared memory storage for scaled zero points int4* sh_sz = sh_s + (stages * s_sh_stage); // Register storage for double buffer of shared memory reads. FragA frag_a[2][thread_m_blocks]; I4 frag_b_quant[2]; FragC frag_c[thread_m_blocks][4][2]; FragS frag_s[2][4]; // ADDED: register storage for scaled zero points FragS frag_sz[2][4]; // Zero accumulators. auto zero_accums = [&] () { #pragma unroll for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) reinterpret_cast(frag_c)[i] = 0; }; // Asynchronously fetch the next A, B and s tile from global to the next shared memory pipeline location. auto fetch_to_shared = [&] (int pipe, int a_off, bool pred = true) { if (pred) { int4* sh_a_stage = sh_a + a_sh_stage * pipe; #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) { cp_async4_pred( &sh_a_stage[a_sh_wr_trans[i]], &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], a_sh_wr_pred[i] ); } int4* sh_b_stage = sh_b + b_sh_stage * pipe; #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); B_ptr[i] += b_gl_rd_delta_o; } // Only fetch scales if this tile starts a new group if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { // ADDED: fetch scaled zero pointers too int4* sh_s_stage = sh_s + s_sh_stage * pipe; int4* sh_sz_stage = sh_sz + s_sh_stage * pipe; if (s_sh_wr_pred) { cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); cp_async4(&sh_sz_stage[s_sh_wr], &sz[s_gl_rd]); } s_gl_rd += s_gl_rd_delta; } } // Insert a fence even when we are winding down the pipeline to ensure that waiting is also correct at this point. cp_async_fence(); }; // Wait until the next thread tile has been loaded to shared memory. 
auto wait_for_stage = [&] () { // We only have `stages - 2` active fetches since we are double buffering and can only issue the next fetch when // it is guaranteed that the previous shared memory load is fully complete (as it may otherwise be overwritten). cp_async_wait(); __syncthreads(); }; // Load the next sub-tile from the current location in the shared memory pipe into the current register buffer. auto fetch_to_registers = [&] (int k, int pipe) { // It may seem inefficient that we reload the groups for every sub-tile; however, this does not seem to be a // significant bottleneck, while some theoretically better attempts have lead to bad instruction ordering by the // compiler and correspondingly a noticable drop in performance. if (group_blocks != -1) { // ADDED: load scaled zero pointers too int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); int4* sh_sz_stage = sh_sz + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; reinterpret_cast(&frag_sz[k % 2])[0] = sh_sz_stage[s_sh_rd]; } int4* sh_a_stage = sh_a + a_sh_stage * pipe; #pragma unroll for (int i = 0; i < thread_m_blocks; i++) ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); int4* sh_b_stage = sh_b + b_sh_stage * pipe; frag_b_quant[k % 2] = *reinterpret_cast(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); }; // Execute the actual tensor core matmul of a sub-tile. auto matmul = [&] (int k) { // We have the m dimension as the inner loop in order to encourage overlapping dequantization and matmul operations. #pragma unroll for (int j = 0; j < 4; j++) { int b_quant = frag_b_quant[k % 2][j]; int b_quant_shift = b_quant >> 8; FragB frag_b0 = dequant(b_quant); // If there are no groups, we can just scale the final output once and can avoid doing so for each weight. 
// MODIFIED: add scaled zero point if (group_blocks != -1) scale(frag_b0, frag_s[k % 2][j], frag_sz[k % 2][j], 0); FragB frag_b1 = dequant(b_quant_shift); if (group_blocks != -1) scale(frag_b1, frag_s[k % 2][j], frag_sz[k % 2][j], 1); #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); } } }; // Since we slice across the k dimension of a tile in order to increase the number of warps while keeping the n // dimension of a tile reasonable, we have multiple warps that accumulate their partial sums of the same output // location; which we have to reduce over in the end. We do in shared memory. auto thread_block_reduce = [&] () { constexpr int red_off = threads / b_sh_stride / 2; if (red_off >= 1) { int red_idx = threadIdx.x / b_sh_stride; constexpr int red_sh_stride = b_sh_stride * 4 * 2; constexpr int red_sh_delta = b_sh_stride; int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); // Parallel logarithmic shared memory reduction. We make sure to avoid any unnecessary read or write iterations, // e.g., for two warps we write only once by warp 1 and read only once by warp 0. 
#pragma unroll for (int m_block = 0; m_block < thread_m_blocks; m_block++) { #pragma unroll for (int i = red_off; i > 0; i /= 2) { if (i <= red_idx && red_idx < 2 * i) { #pragma unroll for (int j = 0; j < 4 * 2; j++) { int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); if (i < red_off) { float* c_rd = reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); float* c_wr = reinterpret_cast(&sh[red_sh_wr]); #pragma unroll for (int k = 0; k < 4; k++) reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; } sh[red_sh_wr] = reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; } } __syncthreads(); } if (red_idx == 0) { #pragma unroll for (int i = 0; i < 4 * 2; i++) { float* c_rd = reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); #pragma unroll for (int j = 0; j < 4; j++) reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; } } __syncthreads(); } } }; // Since multiple threadblocks may process parts of the same column slice, we finally have to globally reduce over // the results. As the striped partioning minimizes the number of such reductions and our outputs are usually rather // small, we perform this reduction serially in L2 cache. auto global_reduce = [&] (bool first = false, bool last = false) { // We are very careful here to reduce directly in the output buffer to maximize L2 cache utilization in this step. // To do this, we write out results in FP16 (but still reduce with FP32 compute). 
constexpr int active_threads = 32 * thread_n_blocks / 4; if (threadIdx.x < active_threads) { int c_gl_stride = prob_n / 8; int c_gl_wr_delta_o = 8 * c_gl_stride; int c_gl_wr_delta_i = 4 * (active_threads / 32); int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4; c_gl_wr += (2 * thread_n_blocks) * slice_col; constexpr int c_sh_wr_delta = active_threads; int c_sh_wr = threadIdx.x; int row = (threadIdx.x % 32) / 4; if (!first) { // Interestingly, doing direct global accesses here really seems to mess up the compiler and lead to slowdowns, // hence we also use async-copies even though these fetches are not actually asynchronous. #pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { cp_async4_pred( &sh[c_sh_wr + c_sh_wr_delta * i], &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)], i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m ); } cp_async_fence(); cp_async_wait<0>(); } #pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { if (!first) { int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; #pragma unroll for (int j = 0; j < 2 * 4; j++) { reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += __half2float( reinterpret_cast<__half*>(&c_red)[j] ); } } if (!last) { int4 c; #pragma unroll for (int j = 0; j < 2 * 4; j++) { reinterpret_cast<__half*>(&c)[j] = __float2half( reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] ); } C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = c; } } } } }; // Write out the reduce final result in the correct layout. We only actually reshuffle matrix fragments in this step, // the reduction above is performed in fragment layout. 
auto write_result = [&] () { int c_gl_stride = prob_n / 8; constexpr int c_sh_stride = 2 * thread_n_blocks + 1; int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks)); int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); c_gl_wr += (2 * thread_n_blocks) * slice_col; int c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; c_sh_wr += 32 * (threadIdx.x / 32); int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); int c_gl_wr_end = c_gl_stride * prob_m; // We first reorder in shared memory to guarantee the most efficient final global write patterns auto write = [&] (int idx, float c0, float c1, FragS& s) { half2 res = __halves2half2(__float2half(c0), __float2half(c1)); if (group_blocks == -1) // for per-column quantization we finally apply the scale here res = __hmul2(res, s[0]); ((half2*) sh)[idx] = res; }; if (threadIdx.x / 32 < thread_n_blocks / 4) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { #pragma unroll for (int j = 0; j < 4; j++) { int wr = c_sh_wr + 8 * j; write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); } c_sh_wr += 16 * (4 * c_sh_stride); } } __syncthreads(); #pragma unroll for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { if (c_gl_wr < c_gl_wr_end) { C[c_gl_wr] = sh[c_sh_rd]; c_gl_wr += c_gl_wr_delta; c_sh_rd += c_sh_rd_delta; } } }; // Start global fetch and register load pipelines. 
auto start_pipes = [&] () { #pragma unroll for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); zero_accums(); wait_for_stage(); fetch_to_registers(0, 0); a_gl_rd += a_gl_rd_delta_o * (stages - 1); }; start_pipes(); // Main loop. while (slice_iters) { // We unroll over both the global fetch and the register load pipeline to ensure all shared memory accesses are // static. Note that both pipelines have even length meaning that the next iteration will always start at index 0. #pragma unroll for (int pipe = 0; pipe < stages;) { #pragma unroll for (int k = 0; k < b_sh_wr_iters; k++) { fetch_to_registers(k + 1, pipe % stages); if (k == b_sh_wr_iters - 2) { fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); pipe++; wait_for_stage(); } matmul(k); } slice_iters--; if (slice_iters == 0) break; } a_gl_rd += a_gl_rd_delta_o * stages; // Process results and, if necessary, proceed to the next column slice. While this pattern may not be the most // readable, other ways of writing the loop seemed to noticeably worse performance after compliation. 
if (slice_iters == 0) { cp_async_wait<0>(); bool last = slice_idx == slice_count - 1; // For per-column scales, we only fetch them here in the final step before write-out if (group_blocks == -1 && last) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); // ADDED: fetch scaled zero pointers too cp_async4(&sh_sz[s_sh_wr], &sz[s_gl_rd]); } cp_async_fence(); } thread_block_reduce(); if (group_blocks == -1 && last) { cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; // ADDED: load scaled zero pointers too reinterpret_cast(&frag_sz)[0] = sh_sz[s_sh_rd + 0]; reinterpret_cast(&frag_sz)[1] = sh_sz[s_sh_rd + 4]; } } if (slice_count > 1) { // only globally reduce if there is more than one block in a slice barrier_acquire(&locks[slice_col], slice_idx); global_reduce(slice_idx == 0, last); barrier_release(&locks[slice_col], last); } if (last) // only the last block in a slice actually writes the result write_result(); slice_row = 0; slice_col_par++; slice_col++; init_slice(); if (slice_iters) { a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; if (slice_col == 0) { #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; } s_gl_rd = s_sh_stride * slice_col + threadIdx.x; start_pipes(); } } } } // 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per schedule allows some more // latency hiding. At the same time, we want relatively few warps to have many registers per warp and small tiles. 
// ---------------------------------------------------------------------------
// Host-side launch configuration for the Marlin W4A16 GEMM kernel.
// ---------------------------------------------------------------------------
const int THREADS = 256;
const int STAGES = 4; // 4 pipeline stages fit into shared memory
const int SHARED_MEM = 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0)

// ADDED: add scaled zero pointer
// Dispatch helper: for one (m, n, k, group) tile configuration, raises the kernel's
// dynamic shared-memory limit and launches the matching Marlin template instantiation.
// NOTE(review): the template-argument list of `Marlin` in the cudaFuncSetAttribute call
// and the launch configuration between `<<<` and `>>>` appear to have been stripped by
// the text extraction (angle-bracket contents lost) -- confirm against the original
// marlin_cuda.cu before compiling.
#define CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS) \
  else if ( \
    thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && thread_k_blocks == THREAD_K_BLOCKS && \
    group_blocks == GROUP_BLOCKS \
  ) { \
    cudaFuncSetAttribute( \
      Marlin, \
      cudaFuncAttributeMaxDynamicSharedMemorySize, \
      SHARED_MEM \
    ); \
    Marlin< \
      THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS \
    ><<>>( \
      A_ptr, B_ptr, C_ptr, s_ptr, sz_ptr,\
      prob_m, prob_n, prob_k, \
      locks \
    ); \
  }

// Error codes returned by marlin_cuda().
const int ERR_PROB_SHAPE = 1; // problem shape is incompatible with the tile configuration
const int ERR_KERN_SHAPE = 2; // no kernel instantiation matches the requested configuration

// ADDED: add scaled zero pointer
// Host entry point for the 4-bit Marlin GEMM.
//   A, B, C    : device pointers; reinterpreted below as int4 (16-byte) words
//   s, sz      : quantization scales and (ADDED) scaled zero points, also int4 words
//   workspace  : device buffer reinterpreted as the per-column-slice lock array
//   groupsize  : quantization group size, or -1 for per-column quantization
//   dev/stream : CUDA device ordinal and stream for the launches
//   thread_k/n : tile-shape overrides; -1 picks a shape based on prob_m below
//   sms        : number of thread blocks to launch; -1 queries the SM count
//   max_par    : cap on how many 64-row sub-problems one launch may process
// Returns 0 on success, or one of the ERR_* codes above.
int marlin_cuda(
  const void* A,
  const void* B,
  void* C,
  void* s,
  void* sz,
  int prob_m,
  int prob_n,
  int prob_k,
  void* workspace,
  int groupsize = -1,
  int dev = 0,
  cudaStream_t stream = 0,
  int thread_k = -1,
  int thread_n = -1,
  int sms = -1,
  int max_par = 16
) {
  int tot_m = prob_m;
  int tot_m_blocks = ceildiv(tot_m, 16);
  // Rows of padding needed to fill the last 16-row block.
  int pad = 16 * tot_m_blocks - tot_m;
  if (sms == -1)
    cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
  if (thread_k == -1 || thread_n == -1) {
    if (prob_m <= 16) {
      // For small batchsizes, better partitioning is slightly more important than better compute utilization
      thread_k = 128;
      thread_n = 128;
    } else {
      thread_k = 64;
      thread_n = 256;
    }
  }
  int thread_k_blocks = thread_k / 16;
  int thread_n_blocks = thread_n / 16;
  int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
  int blocks = sms;
  // Reject shapes that do not divide evenly into the chosen tiles / groups.
  if (prob_n % thread_n != 0 || prob_k % thread_k != 0 || (group_blocks != -1 && prob_k % group_blocks != 0))
    return ERR_PROB_SHAPE;
  // Empty problems are a trivial success.
  if (prob_m == 0 || prob_n == 0 || prob_k == 0)
    return 0;
  const int4* A_ptr = (const int4*) A;
  const int4* B_ptr = (const int4*) B;
  int4* C_ptr = (int4*) C;
  const int4* s_ptr = (const int4*) s;
  // ADDED: add scaled zero pointer
  const int4* sz_ptr = (const int4*) sz;
  int cols = prob_n / thread_n;
  int* locks = (int*) workspace;
  int ret = 0;
  // Walk the batch dimension in chunks of at most 4 row-blocks (64 rows) per launch.
  for (int i = 0; i < tot_m_blocks; i += 4) {
    int thread_m_blocks = tot_m_blocks - i;
    prob_m = tot_m - 16 * i;
    int par = 1;
    if (thread_m_blocks > 4) {
      // Note that parallel > 1 currently only works for inputs without any padding
      par = (16 * thread_m_blocks - pad) / 64;
      if (par > max_par)
        par = max_par;
      prob_m = 64 * par;
      i += 4 * (par - 1);
      thread_m_blocks = 4;
    }
    // For compilation speed, we only define the kernel configurations that have seemed useful (in terms of performance)
    // in our testing, however many more are, in principle, possible.
    if (false) {}
    CALL_IF(1, 8, 8, -1)
    CALL_IF(1, 8, 8, 8)
    CALL_IF(1, 16, 4, -1)
    CALL_IF(1, 16, 4, 8)
    CALL_IF(2, 16, 4, -1)
    CALL_IF(2, 16, 4, 8)
    CALL_IF(3, 16, 4, -1)
    CALL_IF(3, 16, 4, 8)
    CALL_IF(4, 16, 4, -1)
    CALL_IF(4, 16, 4, 8)
    else
      ret = ERR_KERN_SHAPE;
    // Advance the A/C pointers past the rows consumed by this launch.
    A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
    C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
  }
  return ret;
}

#endif


================================================
FILE: optimum/quanto/library/extensions/cuda/marlin/marlin_cuda_kernel.cuh
================================================
/*
 * Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include int marlin_cuda( const void* A, const void* B, void* C, void* s, void* sz, // ADDED: add scaled zero point int prob_m, int prob_n, int prob_k, void* workspace, int groupsize = -1, int dev = 0, cudaStream_t stream = 0, int thread_k = -1, int thread_n = -1, int sms = -1, int max_par = 16 ); ================================================ FILE: optimum/quanto/library/extensions/cuda/pybind_module.cpp ================================================ // Copyright 2024 The HuggingFace Team. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "awq/v2/gemm_cuda.h" #include "awq/v2/gemv_cuda.h" #include "unpack.h" #include "marlin/fp8_marlin.cuh" #include "marlin/gptq_marlin_repack.cuh" #include "marlin/marlin_cuda.h" // !IMPORTANT! Some python objects such as dtype, device, are not mapped to C++ types, // and need to be explicitly converted using dedicated helpers before calling a C++ method. 
// As a consequence, when an operation takes such an object as parameter, instead
// of creating a binding directly to the C++ method, you must create a binding to a
// lambda method that converts the unmapped types and calls the C++ method.
// See the binding of quantize_symmetric for instance.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("awq_v2_gemm_f16i4", &awq_v2_gemm_f16i4, "awq_v2_gemm_f16i4");
  m.def("awq_v2_gemv_f16i4", &awq_v2_gemv_f16i4, "awq_v2_gemv_f16i4");
  m.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin_repack");
  m.def("fp8_marlin_gemm", &fp8_marlin_gemm, "fp8_marlin_gemm");
  // NOTE(review): `mul` is presumably the Marlin GEMM entry point declared in
  // marlin/marlin_cuda.h -- verify, as only headers (not the definition) are visible here.
  m.def("marlin_gemm_f16i4", &mul, "marlin_gemm_f16i4");
  m.def("unpack", &unpack, "unpack");
}



================================================
FILE: optimum/quanto/library/extensions/cuda/unpack.cu
================================================
// Copyright 2024 The HuggingFace Team. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// NOTE(review): the four `#include` targets were lost during text extraction
// (angle-bracket contents stripped) -- confirm against the repo.
#include
#include
#include
#include

// Integer ceiling division: number of size-b blocks needed to cover a elements.
inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a + b - 1) / b;}
#define BLOCK_SIZE 256

using namespace at;

// Allocate the unpacked output tensor: same options as the input, with the first
// dimension multiplied by how many values are packed per byte (8 / bits).
static torch::Tensor allocate_output(const torch::Tensor& input, int bits) {
  int n_packed = 8 / bits;
  auto output_shape = input.sizes().vec();
  output_shape[0] = output_shape[0] * n_packed;
  return torch::empty(output_shape, input.options());
}

// One thread per packed byte: low nibble goes to output[i], high nibble to output[i + n].
__global__ void unpack_4bit_kernel(unsigned char* input, unsigned char* output, int n) {
  int i = blockIdx.x*blockDim.x + threadIdx.x;
  if(i>=n) return;
  output[i]     = (input[i] & 0x0F);
  output[i + n] = (input[i] & 0xF0) >> 4;
}

// Launch unpack_4bit_kernel over all packed bytes of `input`.
// NOTE(review): the launch configuration between `<<<` and `>>>` and the template
// arguments of `data_ptr` were stripped by the text extraction -- confirm against the repo.
static torch::Tensor unpack_4bit(const torch::Tensor& input){
  auto output = allocate_output(input, 4);
  const auto numel = input.numel();
  int blocks = cdiv(numel, BLOCK_SIZE);
  unpack_4bit_kernel<<>>(
    input.data_ptr(),
    output.data_ptr(),
    numel
  );
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return output;
}

// One thread per packed byte: the four 2-bit fields go to output[i + k*n], k = 0..3,
// from least to most significant pair of bits.
__global__ void unpack_2bit_kernel(unsigned char* input, unsigned char* output, int n) {
  int i = blockIdx.x*blockDim.x + threadIdx.x;
  if(i>=n) return;
  output[i]       = (input[i] & 0x03);
  output[i + n]   = (input[i] & 0x0C) >> 2;
  output[i + n*2] = (input[i] & 0x30) >> 4;
  output[i + n*3] = (input[i] & 0xC0) >> 6;
}

// Launch unpack_2bit_kernel over all packed bytes of `input`.
// NOTE(review): launch configuration / data_ptr template arguments stripped by
// extraction, as above -- confirm against the repo.
static torch::Tensor unpack_2bit(const torch::Tensor& input){
  auto output = allocate_output(input, 2);
  const auto numel = input.numel();
  int blocks = cdiv(numel, BLOCK_SIZE);
  unpack_2bit_kernel<<>>(
    input.data_ptr(),
    output.data_ptr(),
    numel
  );
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return output;
}

// Public entry point: unpack a contiguous uint8 CUDA tensor of 2-bit or 4-bit
// packed values into one uint8 value per element.
torch::Tensor unpack(torch::Tensor &t, int bits) {
  TORCH_CHECK(t.scalar_type() == torch::kUInt8, "Unsupported data type: ", t.scalar_type());
  TORCH_CHECK(t.device().is_cuda(), "t must be a CUDA tensor.");
  TORCH_CHECK(t.is_contiguous(), "t must be contiguous.");
  switch(bits) {
    case 4:
      return unpack_4bit(t);
    case 2:
      return unpack_2bit(t);
    default:
      throw std::invalid_argument("Can only unpack 2-bit or 4-bit tensors.");
  }
}



================================================
FILE: optimum/quanto/library/extensions/cuda/unpack.h
================================================
// Copyright 2024 The HuggingFace Team. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// NOTE(review): the `#include` target was lost during text extraction
// (presumably <torch/extension.h>) -- confirm against the repo.
#include

// Declared here for the pybind module; implemented in unpack.cu.
torch::Tensor unpack(torch::Tensor &t, int bits);



================================================
FILE: optimum/quanto/library/extensions/extension.py
================================================
import os
import shutil
import warnings
from typing import List

import torch
from torch.utils.cpp_extension import load


__all__ = ["is_extension_available", "get_extension"]


class Extension(object):
    """A lazily-built torch C++/CUDA extension.

    Wraps `torch.utils.cpp_extension.load`: the sources are only compiled the
    first time the `lib` property is accessed, and the build is cached in a
    `build` directory next to the sources.
    """

    def __init__(
        self,
        name: str,
        root_dir: str,
        sources: List[str],
        extra_cflags: List[str] = None,
        extra_cuda_cflags: List[str] = None,
    ):
        """Record the build configuration without compiling anything.

        Args:
            name: unique extension name, also used as the compiled module name.
            root_dir: directory containing the source files.
            sources: source file names, relative to `root_dir`.
            extra_cflags: extra C++ compiler flags forwarded to `load`.
            extra_cuda_cflags: extra CUDA compiler flags forwarded to `load`.
        """
        self.name = name
        self.sources = [f"{root_dir}/{source}" for source in sources]
        self.extra_cflags = extra_cflags
        self.extra_cuda_cflags = extra_cuda_cflags
        # Compiled artifacts (and the pytorch version marker) live here.
        self.build_directory = os.path.join(root_dir, "build")
        # Lazily-populated handle on the compiled module.
        self._lib = None

    @property
    def lib(self):
        """The compiled extension module, building it on first access.

        A `pytorch_version.txt` marker is kept alongside the build: if the
        installed torch version differs from the one the extension was built
        with, the stale build directory is removed and the extension is
        recompiled.
        """
        if self._lib is None:
            # We only load the extension when the lib is required
            version_file = os.path.join(self.build_directory, "pytorch_version.txt")
            if os.path.exists(version_file):
                # The extension has already been built: check the torch version for which it was built
                with open(version_file, "r") as f:
                    pytorch_build_version = f.read().rstrip()
                if pytorch_build_version != torch.__version__:
                    shutil.rmtree(self.build_directory)
                    warnings.warn(
                        f"{self.name} was compiled with pytorch {pytorch_build_version}, but {torch.__version__} is installed: it will be recompiled."
                    )
            os.makedirs(self.build_directory, exist_ok=True)
            self._lib = load(
                name=self.name,
                sources=self.sources,
                extra_cflags=self.extra_cflags,
                extra_cuda_cflags=self.extra_cuda_cflags,
                build_directory=self.build_directory,
            )
            if not os.path.exists(version_file):
                # Record the torch version this build corresponds to.
                with open(version_file, "w") as f:
                    f.write(torch.__version__)
        return self._lib


# Registry of available extensions, keyed by extension name.
_extensions = {}


def register_extension(extension: Extension):
    """Register an extension under its name; names must be unique.

    Note: uniqueness is enforced with `assert`, which is stripped under -O.
    """
    assert extension.name not in _extensions
    _extensions[extension.name] = extension


def get_extension(extension_type: str):
    """Get an extension

    Args:
        extension_type (`str`): The extension type.

    Returns:
        The corresponding extension.
    """
    return _extensions[extension_type]


def is_extension_available(extension_type: str):
    """Check if an extension is available

    Args:
        extension_type (`str`): The extension type.

    Returns:
        True if the extension is available.
    """
    return extension_type in _extensions



================================================
FILE: optimum/quanto/library/extensions/hip/__init__.py
================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import torch from ..extension import Extension, register_extension __all__ = [] ext = Extension( "quanto_hip", root_dir=os.path.dirname(__file__), sources=["unpack.cu", "pybind_module.cpp"], extra_cflags=["-std=c++17"], ) register_extension(ext) @torch.library.impl("quanto::unpack", ["CUDA"]) def unpack_hip(t: torch.Tensor, bits: int): return ext.lib.unpack(t, bits) ================================================ FILE: optimum/quanto/library/extensions/hip/pybind_module.cpp ================================================ // Copyright 2024 The HuggingFace Team. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "unpack.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("unpack", &unpack, "unpack"); } ================================================ FILE: optimum/quanto/library/extensions/hip/unpack.cu ================================================ // Copyright 2024 The HuggingFace Team. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and
// limitations under the License.

// NOTE(review): the four `#include` targets were lost during text extraction
// (angle-bracket contents stripped) -- confirm against the repo. This file
// mirrors the CUDA backend's unpack.cu and is hipified at build time.
#include
#include
#include
#include

// Integer ceiling division: number of size-b blocks needed to cover a elements.
inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a + b - 1) / b;}
#define BLOCK_SIZE 256

using namespace at;

// Allocate the unpacked output tensor: same options as the input, with the first
// dimension multiplied by how many values are packed per byte (8 / bits).
static torch::Tensor allocate_output(const torch::Tensor& input, int bits) {
  int n_packed = 8 / bits;
  auto output_shape = input.sizes().vec();
  output_shape[0] = output_shape[0] * n_packed;
  return torch::empty(output_shape, input.options());
}

// One thread per packed byte: low nibble goes to output[i], high nibble to output[i + n].
__global__ void unpack_4bit_kernel(unsigned char* input, unsigned char* output, int n) {
  int i = blockIdx.x*blockDim.x + threadIdx.x;
  if(i>=n) return;
  output[i]     = (input[i] & 0x0F);
  output[i + n] = (input[i] & 0xF0) >> 4;
}

// Launch unpack_4bit_kernel over all packed bytes of `input`.
// NOTE(review): the launch configuration between `<<<` and `>>>` and the template
// arguments of `data_ptr` were stripped by the text extraction -- confirm against the repo.
static torch::Tensor unpack_4bit(const torch::Tensor& input){
  auto output = allocate_output(input, 4);
  const auto numel = input.numel();
  int blocks = cdiv(numel, BLOCK_SIZE);
  unpack_4bit_kernel<<>>(
    input.data_ptr(),
    output.data_ptr(),
    numel
  );
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return output;
}

// One thread per packed byte: the four 2-bit fields go to output[i + k*n], k = 0..3,
// from least to most significant pair of bits.
__global__ void unpack_2bit_kernel(unsigned char* input, unsigned char* output, int n) {
  int i = blockIdx.x*blockDim.x + threadIdx.x;
  if(i>=n) return;
  output[i]       = (input[i] & 0x03);
  output[i + n]   = (input[i] & 0x0C) >> 2;
  output[i + n*2] = (input[i] & 0x30) >> 4;
  output[i + n*3] = (input[i] & 0xC0) >> 6;
}

// Launch unpack_2bit_kernel over all packed bytes of `input`.
// NOTE(review): launch configuration / data_ptr template arguments stripped by
// extraction, as above -- confirm against the repo.
static torch::Tensor unpack_2bit(const torch::Tensor& input){
  auto output = allocate_output(input, 2);
  const auto numel = input.numel();
  int blocks = cdiv(numel, BLOCK_SIZE);
  unpack_2bit_kernel<<>>(
    input.data_ptr(),
    output.data_ptr(),
    numel
  );
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return output;
}

// Public entry point: unpack a contiguous uint8 device tensor of 2-bit or 4-bit
// packed values into one uint8 value per element.
torch::Tensor unpack(torch::Tensor &t, int bits) {
  TORCH_CHECK(t.scalar_type() == torch::kUInt8, "Unsupported data type: ", t.scalar_type());
  TORCH_CHECK(t.device().is_cuda(), "t must be a CUDA tensor.");
  TORCH_CHECK(t.is_contiguous(), "t must be contiguous.");
  switch(bits) {
    case 4:
      return unpack_4bit(t);
    case 2:
      return unpack_2bit(t);
    default:
      throw std::invalid_argument("Can only unpack 2-bit or 4-bit tensors.");
  }
}



================================================
FILE: optimum/quanto/library/extensions/hip/unpack.h
================================================
// Copyright 2024 The HuggingFace Team. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// NOTE(review): the `#include` target was lost during text extraction
// (presumably <torch/extension.h>) -- confirm against the repo.
#include

// Declared here for the pybind module; implemented in unpack.cu.
torch::Tensor unpack(torch::Tensor &t, int bits);



================================================
FILE: optimum/quanto/library/extensions/mps/README.md
================================================
# Quanto Metal Performance Shaders extension

To add a new implementation for an operation defined in `library./ops.py`:

- add the corresponding `.mm` file to the list of sources in `__init__.py`,
- add a binding to `pybind_module.cpp`,
- provide an implementation calling the binding in `__init__.py`.

Note: torch JIT extensions for MPS requires the xcode command-line tools.



================================================
FILE: optimum/quanto/library/extensions/mps/__init__.py
================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import torch from ..extension import Extension, register_extension __all__ = [] ext = Extension( "quanto_mps", root_dir=os.path.dirname(__file__), sources=["unpack.mm", "pybind_module.cpp"], extra_cflags=["-std=c++17"], ) register_extension(ext) @torch.library.impl("quanto::unpack", "MPS") def unpack_mps(t: torch.Tensor, bits: int): return ext.lib.unpack(t, bits) ================================================ FILE: optimum/quanto/library/extensions/mps/pybind_module.cpp ================================================ // Copyright 2024 The HuggingFace Team. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "unpack.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("unpack", &unpack, "unpack"); } ================================================ FILE: optimum/quanto/library/extensions/mps/unpack.h ================================================ // Copyright 2024 The HuggingFace Team. All rights reserved. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include torch::Tensor unpack(const torch::Tensor &input, int bits); ================================================ FILE: optimum/quanto/library/extensions/mps/unpack.mm ================================================ // Copyright 2024 The HuggingFace Team. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "unpack.h" #include #import #import // Defines a Metal custom kernel to mask and shift a buffer element-wise. static char *MASK_AND_SHIFT = R"MPS_MASK&SHIFT( #include using namespace metal; [[host_name("mask_and_rshift")]] kernel void mask_and_rshift(constant uint8_t* input [[buffer(0)]], device uint8_t* output [[buffer(1)]], constant uint8_t& mask [[buffer(2)]], constant int& shift [[buffer(3)]], uint index [[thread_position_in_grid]]) { output[index] = (input[index] & mask) >> shift; } )MPS_MASK&SHIFT"; // Helper function to retrieve the `MTLBuffer` from a `torch::Tensor`. 
static inline id getMTLBufferStorage(const torch::Tensor& tensor) { return __builtin_bit_cast(id, tensor.storage().data()); } torch::Tensor& mask_and_shift(const torch::Tensor& input, torch::Tensor& output, uint8_t mask, int shift) { @autoreleasepool { id device = MTLCreateSystemDefaultDevice(); NSError *error = nil; // Set the number of threads equal to the number of elements within the input tensor. int num_threads = input.numel(); // Load the custom mask and shift shader. id library = [device newLibraryWithSource:[NSString stringWithUTF8String:MASK_AND_SHIFT] options:nil error:&error]; TORCH_CHECK(library, "Failed to to create custom kernel library, error: ", error.localizedDescription.UTF8String); id kernel = [library newFunctionWithName:[NSString stringWithUTF8String:"mask_and_rshift"]]; TORCH_CHECK(kernel, "Failed to create function state object for mask_and_rshift"); // Create a compute pipeline state object for the soft shrink kernel. id pso = [device newComputePipelineStateWithFunction:kernel error:&error]; TORCH_CHECK(pso, error.localizedDescription.UTF8String); // This is required if torch already encoded something in the command buffer torch::mps::synchronize(); // Get a reference to the command buffer for the MPS stream. id command_buffer = torch::mps::get_command_buffer(); TORCH_CHECK(command_buffer, "Failed to retrieve command buffer reference"); // Get a reference to the dispatch queue for the MPS stream, which encodes the synchronization with the CPU. dispatch_queue_t serial_queue = torch::mps::get_dispatch_queue(); dispatch_sync(serial_queue, ^(){ // Start a compute pass. id compute_encoder = [command_buffer computeCommandEncoder]; TORCH_CHECK(compute_encoder, "Failed to create compute command encoder"); // Encode the pipeline state object and its parameters. 
[compute_encoder setComputePipelineState:pso]; [compute_encoder setBuffer:getMTLBufferStorage(input) offset:input.storage_offset() * input.element_size() atIndex:0]; [compute_encoder setBuffer:getMTLBufferStorage(output) offset:output.storage_offset() * output.element_size() atIndex:1]; [compute_encoder setBytes:&mask length:sizeof(uint8_t) atIndex:2]; [compute_encoder setBytes:&shift length:sizeof(int) atIndex:3]; MTLSize grid_size = MTLSizeMake(num_threads, 1, 1); // Calculate a thread group size. NSUInteger thread_group_size = pso.maxTotalThreadsPerThreadgroup; if (thread_group_size > num_threads) { thread_group_size = num_threads; } MTLSize mtl_size = MTLSizeMake(thread_group_size, 1, 1); // Encode the compute command. [compute_encoder dispatchThreads:grid_size threadsPerThreadgroup:mtl_size]; [compute_encoder endEncoding]; // Commit the work. torch::mps::commit(); }); torch::mps::synchronize(); } return output; } torch::Tensor unpack_4bit(const torch::Tensor &input) { torch::Tensor output = torch::empty_like(input); mask_and_shift(input, output, 0x0F, 0); torch::Tensor output1 = torch::empty_like(input); mask_and_shift(input, output1, 0xF0, 4); return torch::cat({output, output1}, 0); } torch::Tensor unpack_2bit(const torch::Tensor &input) { torch::Tensor output = torch::empty_like(input); mask_and_shift(input, output, 0x03, 0); torch::Tensor output1 = torch::empty_like(input); mask_and_shift(input, output1, 0x0C, 2); torch::Tensor output2 = torch::empty_like(input); mask_and_shift(input, output2, 0x30, 4); torch::Tensor output3 = torch::empty_like(input); mask_and_shift(input, output3, 0xC0, 6); return torch::cat({output, output1, output2, output3}, 0); } // C++ op dispatching the Metal unpack operation. torch::Tensor unpack(const torch::Tensor &input, int bits) { // Check whether the input tensor resides on the MPS device and whether it's contiguous. 
TORCH_CHECK(input.device().is_mps(), "input must be a MPS tensor"); TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); // Check the supported data types for soft shrink. TORCH_CHECK(input.scalar_type() == torch::kUInt8, "Unsupported data type: ", input.scalar_type()); switch(bits) { case 4: return unpack_4bit(input); case 2: return unpack_2bit(input); default: throw std::invalid_argument("Can only unpack 2-bit or 4-bit tensors."); } } ================================================ FILE: optimum/quanto/library/extensions/xpu/__init__.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # Copyright 2024 Intel Corporation. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os import torch from packaging import version from ..extension import Extension, register_extension __all__ = [] module_path = os.path.dirname(__file__) sources = [ "unpack.sycl", "pybind_module.cpp", ] ext = Extension( "quanto_xpu", root_dir=os.path.dirname(__file__), sources=sources, ) register_extension(ext) @torch.library.impl("quanto::unpack", "XPU") def unpack_xpu(t: torch.Tensor, bits: int): return ext.lib.unpack(t, bits) if version.parse(torch.__version__).release >= version.parse("2.8.0").release: torch.library.define( "quanto::gemm_f16i4_awq", "(Tensor input," " Tensor other," " Tensor other_scale," " Tensor other_shift," " int rows," " int out_cols," " int in_cols," " int bits," " int group_size)" " -> Tensor", ) @torch.library.impl("quanto::gemm_f16i4_awq", "XPU") def gemm_f16i4_awq( input: torch.Tensor, other: torch.Tensor, scales: torch.Tensor, shift: torch.Tensor, rows: int, out_cols: int, in_cols: int, bits: int, group_size: int, ): orig_act_size = input.size() orig_dtype = input.dtype input = input.reshape(-1, input.shape[-1]) # XPU does not support float32 for now. if input.dtype == torch.float32: input = input.to(torch.bfloat16) if scales.dtype != input.dtype: scales = scales.to(input.dtype) y = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(input, other, group_size, scales, shift) # remove out_feature padding y = y[:, :out_cols] y = y.reshape(*orig_act_size[:-1], out_cols) return y.to(orig_dtype) ================================================ FILE: optimum/quanto/library/extensions/xpu/pybind_module.cpp ================================================ // Copyright 2024 The HuggingFace Team. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// NOTE(review): bare #include targets throughout this span appear to have
// been lost in extraction — verify against the repository sources.
#include
#include "unpack.h"

// !IMPORTANT! Some python objects such as dtype, device, are not mapped to C++ types,
// and need to be explicitly converted using dedicated helpers before calling a C++ method.
// As a consequence, when an operation takes such an object as parameter, instead
// of creating a binding directly to the C++ method, you must create a binding to a
// lambda method that converts the unmapped types and calls the C++ method.
// See the binding of quantize_symmetric for instance.

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Expose the SYCL unpack entry point to Python.
  m.def("unpack", &unpack, "unpack");
}
================================================
FILE: optimum/quanto/library/extensions/xpu/unpack.h
================================================
// Copyright 2024 The HuggingFace Team. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include

torch::Tensor unpack(torch::Tensor &t, int bits);
================================================
FILE: optimum/quanto/library/extensions/xpu/unpack.sycl
================================================
// Copyright 2024 The HuggingFace Team. All rights reserved.
// Copyright 2024 Intel Corporation. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include
#include
#include

// Ceiling division, used to size the kernel grid.
inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a + b - 1) / b;}

#define BLOCK_SIZE 256

using namespace at;

// Allocate the output tensor: same options as the input, with the first
// dimension grown by the number of values packed per byte (8 / bits).
static torch::Tensor allocate_output(const torch::Tensor& input, int bits) {
  int n_packed = 8 / bits;
  auto output_shape = input.sizes().vec();
  output_shape[0] = output_shape[0] * n_packed;
  return torch::empty(output_shape, input.options());
}

// One work-item per packed byte: write the low nibble at i and the high
// nibble at i + n.
void unpack_4bit_kernel(unsigned char* input, unsigned char* output, int n,
                        const sycl::nd_item<3> &item_ct1) {
  int i = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2);
  if (i>=n) return;
  output[i] = (input[i] & 0x0F);
  output[i + n] = (input[i] & 0xF0) >> 4;
}

// Functor wrapper (dpct-style) so the kernel can be submitted via parallel_for.
class Unpack4BitKrn {
public:
  void operator()(sycl::nd_item<3> item_ct1) const {
    unpack_4bit_kernel(ct0, ct1, numel, item_ct1);
  }
  Unpack4BitKrn(unsigned char* _ct0, unsigned char* _ct1, int64_t _numel): ct0(_ct0), ct1(_ct1), numel(_numel) {}
private:
  unsigned char* ct0;
  unsigned char* ct1;
  int64_t numel;
};

// Host-side launcher for the 4-bit unpack kernel on the current XPU stream.
static torch::Tensor unpack_4bit(const torch::Tensor& input){
  auto output = allocate_output(input, 4);
  const auto numel = input.numel();
  int blocks = cdiv(numel, BLOCK_SIZE);
  sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
  auto krn = [&](sycl::handler &cgh) {
    // NOTE(review): data_ptr() returns void*; the template argument
    // (data_ptr<unsigned char>()) appears lost in extraction — verify.
    auto input_data_ptr_unsigned_char_ct0 = input.data_ptr();
    auto output_data_ptr_unsigned_char_ct1 = output.data_ptr();
    Unpack4BitKrn krn2(input_data_ptr_unsigned_char_ct0, output_data_ptr_unsigned_char_ct1, numel);
    cgh.parallel_for(
      sycl::nd_range<3>(
        sycl::range<3>(1, 1, blocks) * sycl::range<3>(1, 1, BLOCK_SIZE),
        sycl::range<3>(1, 1, BLOCK_SIZE)),
      krn2);
  };
  queue.submit(krn);
  return output;
}

// One work-item per packed byte: write the four 2-bit fields at strides of n,
// least-significant pair first.
void unpack_2bit_kernel(unsigned char* input, unsigned char* output, int n,
                        const sycl::nd_item<3> &item_ct1) {
  int i = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2);
  if (i>=n) return;
  output[i] = (input[i] & 0x03);
  output[i + n] = (input[i] & 0x0C) >> 2;
  output[i + n*2] = (input[i] & 0x30) >> 4;
  output[i + n*3] = (input[i] & 0xC0) >> 6;
}

// Functor wrapper (dpct-style) for the 2-bit kernel.
class Unpack2BitKrn {
public:
  void operator()(sycl::nd_item<3> item_ct1) const {
    unpack_2bit_kernel(ct0, ct1, numel, item_ct1);
  }
  Unpack2BitKrn(unsigned char* _ct0, unsigned char* _ct1, int64_t _numel): ct0(_ct0), ct1(_ct1), numel(_numel) {}
private:
  unsigned char* ct0;
  unsigned char* ct1;
  int64_t numel;
};

// Host-side launcher for the 2-bit unpack kernel on the current XPU stream.
static torch::Tensor unpack_2bit(const torch::Tensor& input){
  auto output = allocate_output(input, 2);
  const auto numel = input.numel();
  int blocks = cdiv(numel, BLOCK_SIZE);
  sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
  auto krn = [&](sycl::handler &cgh) {
    auto input_data_ptr_unsigned_char_ct0 = input.data_ptr();
    auto output_data_ptr_unsigned_char_ct1 = output.data_ptr();
    Unpack2BitKrn krn2(input_data_ptr_unsigned_char_ct0, output_data_ptr_unsigned_char_ct1, numel);
    cgh.parallel_for(
      sycl::nd_range<3>(
        sycl::range<3>(1, 1, blocks) * sycl::range<3>(1, 1, BLOCK_SIZE),
        sycl::range<3>(1, 1, BLOCK_SIZE)),
      krn2);
  };
  queue.submit(krn);
  return output;
}

// Entry point: validate the packed tensor and dispatch on the bit width.
torch::Tensor unpack(torch::Tensor &t, int bits) {
  TORCH_CHECK(t.scalar_type() == torch::kUInt8, "Unsupported data type: ", t.scalar_type());
  TORCH_CHECK(t.device().is_xpu(), "t must be a XPU tensor.");
  TORCH_CHECK(t.is_contiguous(), "t must be contiguous.");
  switch(bits) {
    case 4:
      return unpack_4bit(t);
    case 2:
      return unpack_2bit(t);
    default:
      throw std::invalid_argument("Can only unpack 2-bit or 4-bit tensors.");
  }
}
================================================
FILE: optimum/quanto/library/qbytes_mm.py
================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch from packaging import version __all__ = [] torch.library.define("quanto::qbytes_mm", "(Tensor A, Tensor B, Tensor scales) -> Tensor") def qbytes_mm(activations: torch.Tensor, weights: torch.Tensor, output_scales: torch.Tensor) -> torch.Tensor: activations = activations.to(output_scales.dtype) if weights.dtype.is_floating_point: # Float8 requires an explicit promotion weights = weights.to(output_scales.dtype) # Apply the scale to the weights before the matrix multiplication to put them back # into their initial numerical range and avoid overflows weights = output_scales * weights return torch.matmul(activations, weights.t()) def qbytes_int_mm(activations: torch.Tensor, weights: torch.Tensor, output_scales: torch.Tensor) -> torch.Tensor: in_features = activations.shape[-1] out_features = weights.shape[0] # torch._int_mm works on transposed weights, i.e (in_features, out_features) weights = weights.t() if activations.ndim == 2: out_data = torch._int_mm(activations, weights) else: output_shape = activations.shape[:-1] + (out_features,) out_data = torch._int_mm(activations.reshape(-1, in_features), weights) out_data = out_data.reshape(output_shape) # We must evaluate the output as float32 because the multiplication # of the int32 data by the scales might overflow fp32_output = out_data.to(torch.float32) * output_scales.t() return fp32_output.to(output_scales.dtype) def qbytes_int8pack_mm(activations: torch.Tensor, weights: torch.Tensor, output_scales: torch.Tensor) -> torch.Tensor: # torch._weight_int8pack_mm expects a vector of scales output_scales = output_scales.flatten() if activations.ndim == 2: return torch._weight_int8pack_mm(activations, weights, output_scales) else: in_features = activations.shape[-1] out_features = weights.shape[0] output_shape = activations.shape[:-1] + (out_features,) out_data = torch._weight_int8pack_mm(activations.reshape(-1, in_features), weights, output_scales) return out_data.reshape(output_shape) 
@torch.library.impl("quanto::qbytes_mm", "default") def qbytes_mm_impl_default( activations: torch.Tensor, weights: torch.Tensor, output_scales: torch.Tensor ) -> torch.Tensor: return qbytes_mm(activations, weights, output_scales) @torch.library.impl("quanto::qbytes_mm", "CUDA") def qbytes_mm_impl_cuda(activations: torch.Tensor, weights: torch.Tensor, output_scales: torch.Tensor) -> torch.Tensor: assert activations.ndim in (2, 3) in_features = activations.shape[-1] tokens = activations.shape[0] if activations.ndim == 2 else activations.shape[0] * activations.shape[1] out_features = weights.shape[0] if ( activations.dtype == torch.int8 and weights.dtype == torch.int8 and tokens > 16 and tokens % 8 == 0 and in_features % 8 == 0 and out_features % 8 == 0 ): return qbytes_int_mm(activations, weights, output_scales) return qbytes_mm(activations, weights, output_scales) @torch.library.impl("quanto::qbytes_mm", "CPU") def qbytes_mm_impl_cpu(activations: torch.Tensor, weights: torch.Tensor, output_scales: torch.Tensor) -> torch.Tensor: if ( # FIXME: accuracy issues with 2.4.x version.parse(torch.__version__).release >= version.parse("2.6.0").release and activations.dtype == torch.int8 and weights.dtype == torch.int8 ): return qbytes_int_mm(activations, weights, output_scales) in_features = activations.shape[-1] if activations.dtype == torch.bfloat16 and weights.dtype == torch.int8 and in_features % 4 == 0: if type(activations) is not torch.Tensor: activations = activations.dequantize() return qbytes_int8pack_mm(activations, weights, output_scales) return qbytes_mm(activations, weights, output_scales) @torch.library.impl("quanto_py::qbytes_mm", "MPS") def qbytes_mm_impl_mps(activations: torch.Tensor, weights: torch.Tensor, output_scales: torch.Tensor) -> torch.Tensor: in_features = activations.shape[-1] out_features = weights.shape[0] if ( version.parse(torch.__version__).release >= version.parse("2.4.0").release and activations.dtype == torch.bfloat16 and weights.dtype == 
torch.int8 and in_features % 32 == 0 and out_features % 32 == 0 ): if type(activations) is not torch.Tensor: activations = activations.dequantize() return qbytes_int8pack_mm(activations, weights, output_scales) return qbytes_mm(activations, weights, output_scales) ================================================ FILE: optimum/quanto/library/quantize.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Union import torch from ..tensor import dtype_info, group torch.library.define( "quanto::quantize_symmetric", "(Tensor base, ScalarType dtype, int? 
axis, Tensor scale) -> Tensor" ) @torch.library.impl("quanto::quantize_symmetric", "default") def quantize_symmetric( base: torch.Tensor, dtype: torch.dtype, axis: Union[int, None], scale: torch.Tensor ) -> torch.Tensor: # Sanity checks if axis is None: if scale.ndim > 0: raise ValueError("Scale must be a scalar when quantizing per-tensor") else: if base.ndim == 1: raise ValueError("1D Tensors cannot be quantized per-axis") if axis == base.ndim - 1: # Align on the general convention to index the last dimension axis = -1 if axis not in (0, -1): raise ValueError("Quantization is only supported along the first or last axis.") if base.shape[axis] == 1: raise ValueError(f"Cannot quantize Tensor of shape {base.shape} along axis {axis} of size 1") if torch.squeeze(scale).ndim > 1: raise ValueError("Quantizing along multiple axis is not supported") if scale.ndim != base.ndim: raise ValueError( "When quantizing per-axis, the scale must be broadcastable to the base (Tip: try to add missing dims of length zero)." ) data = base / scale if not dtype.is_floating_point: data = torch.round(data) info = dtype_info(dtype) return torch.clamp(data, min=info.min, max=info.max).to(dtype) torch.library.define( "quanto::quantize_affine", "(Tensor base, int bits, int axis, int? group_size, Tensor scale, Tensor shift) -> Tensor", ) @torch.library.impl("quanto::quantize_affine", "default") def quantize_affine( base: torch.Tensor, bits: int, axis: int, group_size: Union[int, None], scale: torch.Tensor, shift: torch.Tensor ) -> torch.Tensor: if axis not in (0, -1): raise ValueError("axis parameter must be 0 (first axis) or -1 (last axis)") if group_size is not None: base = group(base, axis=axis, group_size=group_size) if shift.dtype.is_floating_point: data = torch.round((base + shift) / scale) else: # Shift is an integer representing zero (i.e. 
zero-point) data = torch.round(base / scale) + shift return torch.clamp(data, min=0, max=2**bits - 1).to(torch.uint8) ================================================ FILE: optimum/quanto/library/unpack.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch torch.library.define("quanto::unpack", "(Tensor self, int bits) -> Tensor") @torch.library.impl("quanto::unpack", "default") def unpack(packed: torch.Tensor, bits: int) -> torch.Tensor: """ Un-Pack int4 / int2 weights (packed in a uint8) into a torch.uint8 tensor What un-packing means? Assume we have packed 4 2-bit values in 8-bit (because torch does not have native support for 2-bit datatypes) > 1110 0100 Unpacking them means retrieving the original 4 2-bit values: > 0000 0011 | 0000 0010 | 0000 0001 | 0000 0000 Args: packed (`torch.Tensor`): The packed tensor in `torch.uint8` precision bits (`int`): The number of bits per encoded value. Can be 2 or 4. 
""" unpacked = [] values_per_item = 8 // bits def rshift(t: torch.Tensor, bits: int): if t.device.type == "mps": # rshift is not supported on MPS device return t // (2**bits) return t >> bits # Unpack each set of values independently for i in range(values_per_item): mask = 2 ** (bits * (i + 1)) - 1 unpacked.append(rshift(packed & mask, bits * i)) # Return the concatenated unpacked tensors return torch.cat(unpacked).to(torch.uint8) ================================================ FILE: optimum/quanto/models/__init__.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import importlib import os from collections.abc import Mapping from typing import Any, Dict, List, Optional, Union def is_transformers_available() -> bool: return importlib.util.find_spec("transformers") is not None def is_diffusers_available() -> bool: return importlib.util.find_spec("diffusers") is not None if is_transformers_available(): from .transformers_models import * if is_diffusers_available(): from .diffusers_models import * ================================================ FILE: optimum/quanto/models/diffusers_models.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from pathlib import Path
from typing import Any, List, Optional, Union

from huggingface_hub import ModelHubMixin, snapshot_download

from ..quantize import Optimizer, freeze, qtype, quantization_map, quantize, requantize
from . import is_diffusers_available


__all__ = ["QuantizedDiffusersModel", "QuantizedPixArtTransformer2DModel"]

if not is_diffusers_available():
    raise ImportError(f"{__all__} require the diffusers library")

from diffusers import PixArtTransformer2DModel
from diffusers.models.model_loading_utils import load_state_dict
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import (
    CONFIG_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFETENSORS_WEIGHTS_NAME,
    _get_checkpoint_shard_files,
    is_accelerate_available,
)

from .shared_dict import ShardedStateDict


class QuantizedDiffusersModel(ModelHubMixin):
    """A Hub-serializable wrapper around a quantized diffusers model.

    Wraps an already-quantized ModelMixin, delegating attribute access and
    calls to it, and persists/restores the quanto quantization map alongside
    the regular diffusers checkpoint files.
    """

    # Prefix for quanto-specific artifacts stored next to the model weights.
    BASE_NAME = "quanto"
    # Concrete diffusers class; must be set by subclasses before from_pretrained.
    base_class = None

    def __init__(self, model: ModelMixin):
        # Only accept models that have actually been quantized (non-empty map).
        if not isinstance(model, ModelMixin) or len(quantization_map(model)) == 0:
            raise ValueError("The source model must be a quantized diffusers model.")
        self._wrapped = model

    def __getattr__(self, name: str) -> Any:
        """If an attribute is not found in this class, look in the wrapped module."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            wrapped = self.__dict__["_wrapped"]
            return getattr(wrapped, name)

    def forward(self, *args, **kwargs):
        # Delegate inference to the wrapped model.
        return self._wrapped.forward(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        # Calling the wrapper behaves like calling the wrapped model's forward.
        return self._wrapped.forward(*args, **kwargs)

    @staticmethod
    def _qmap_name():
        # Filename of the serialized quantization map.
        # NOTE(review): reads QuantizedDiffusersModel.BASE_NAME rather than the
        # subclass attribute — a subclass overriding BASE_NAME would be ignored.
        # Confirm this is intended.
        return f"{QuantizedDiffusersModel.BASE_NAME}_qmap.json"

    @classmethod
    def quantize(
        cls,
        model: ModelMixin,
        weights: Optional[Union[str, qtype]] = None,
        activations: Optional[Union[str, qtype]] = None,
        optimizer: Optional[Optimizer] = None,
        include: Optional[Union[str, List[str]]] = None,
        exclude: Optional[Union[str, List[str]]] = None,
    ):
        """Quantize the specified model

        By default, each layer of the model will be quantized if is quantizable.

        If include patterns are specified, the layer name must match one of them.

        If exclude patterns are specified, the layer must not match one of them.

        Include or exclude patterns are Unix shell-style wildcards which are NOT regular expressions. See
        https://docs.python.org/3/library/fnmatch.html for more details.

        Note: quantization happens in-place and modifies the original model.

        Note that the resulting quantized model will be frozen: if you wish to do
        quantization-aware training then you should use `optimum.quanto.quantize` instead,
        and call `optimum.quanto.freeze` only after the training.

        Args:
            model (`PreTrainedModel`): the model to quantize.
            weights (`Optional[Union[str, qtype]]`): the qtype for weights quantization.
            activations (`Optional[Union[str, qtype]]`): the qtype for activations quantization.
            include (`Optional[Union[str, List[str]]]`):
                Patterns constituting the allowlist. If provided, layer names must match at
                least one pattern from the allowlist.
            exclude (`Optional[Union[str, List[str]]]`):
                Patterns constituting the denylist. If provided, layer names must not match
                any patterns from the denylist.
        """
        if not isinstance(model, ModelMixin):
            raise ValueError("The source model must be a diffusers model.")

        quantize(
            model, weights=weights, activations=activations, optimizer=optimizer, include=include, exclude=exclude
        )
        freeze(model)
        return cls(model)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs):
        """Reload a quantized model saved by `_save_pretrained`, re-applying its quantization map."""
        if cls.base_class is None:
            raise ValueError("The `base_class` attribute needs to be configured.")
        if not is_accelerate_available():
            raise ValueError("Reloading a quantized diffusers model requires the accelerate library.")
        from accelerate import init_empty_weights

        if os.path.isdir(pretrained_model_name_or_path):
            working_dir = pretrained_model_name_or_path
        else:
            # Fetch the checkpoint from the Hub into a local cache directory.
            working_dir = snapshot_download(pretrained_model_name_or_path, **kwargs)

        # Look for a quantization map
        qmap_path = os.path.join(working_dir, cls._qmap_name())
        if not os.path.exists(qmap_path):
            raise ValueError(
                f"No quantization map found in {pretrained_model_name_or_path}: is this a quantized model ?"
            )

        # Look for original model config file.
        model_config_path = os.path.join(working_dir, CONFIG_NAME)
        if not os.path.exists(model_config_path):
            raise ValueError(f"{CONFIG_NAME} not found in {pretrained_model_name_or_path}.")

        with open(qmap_path, "r", encoding="utf-8") as f:
            qmap = json.load(f)

        with open(model_config_path, "r", encoding="utf-8") as f:
            original_model_cls_name = json.load(f)["_class_name"]
        configured_cls_name = cls.base_class.__name__
        # Guard against loading a checkpoint produced by a different model class.
        if configured_cls_name != original_model_cls_name:
            raise ValueError(
                f"Configured base class ({configured_cls_name}) differs from what was derived from the provided configuration ({original_model_cls_name})."
            )

        # Create an empty model
        config = cls.base_class.load_config(pretrained_model_name_or_path, **kwargs)
        with init_empty_weights():
            model = cls.base_class.from_config(config)

        # Look for the index of a sharded checkpoint
        checkpoint_file = os.path.join(working_dir, SAFE_WEIGHTS_INDEX_NAME)
        if os.path.exists(checkpoint_file):
            # Convert the checkpoint path to a list of shards
            _, sharded_metadata = _get_checkpoint_shard_files(working_dir, checkpoint_file)
            # Create a mapping for the sharded safetensor files
            state_dict = ShardedStateDict(working_dir, sharded_metadata["weight_map"])
        else:
            # Look for a single checkpoint file
            checkpoint_file = os.path.join(working_dir, SAFETENSORS_WEIGHTS_NAME)
            if not os.path.exists(checkpoint_file):
                raise ValueError(f"No safetensor weights found in {pretrained_model_name_or_path}.")
            # Get state_dict from model checkpoint
            state_dict = load_state_dict(checkpoint_file)

        # Requantize and load quantized weights from state_dict
        requantize(model, state_dict=state_dict, quantization_map=qmap)
        model.eval()
        return cls(model)

    def _save_pretrained(self, save_directory: Path) -> None:
        # Save the wrapped model as a regular diffusers checkpoint.
        self._wrapped.save_pretrained(save_directory)
        # Save quantization map to be able to reload the model
        qmap_name = os.path.join(save_directory, self._qmap_name())
        qmap = quantization_map(self._wrapped)
        with open(qmap_name, "w", encoding="utf8") as f:
            json.dump(qmap, f, indent=4)


class QuantizedPixArtTransformer2DModel(QuantizedDiffusersModel):
    # Concrete wrapper for PixArt Sigma transformers.
    base_class = PixArtTransformer2DModel
================================================
FILE: optimum/quanto/models/shared_dict.py
================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
class ShardedStateDict(Mapping):
    """A pytorch state_dict spread over several safetensors shard files.

    Implements the read-only `collections.abc.Mapping` protocol, so it can be
    handed directly to `torch.nn.Module.load_state_dict()`: tensors are read
    lazily from their shard only when looked up by key.
    """

    def __init__(self, base_dir: str, tensor_index: Dict[str, str]):
        # tensor_index maps each tensor name to the shard file that stores it
        self._base_dir = base_dir
        self._index = tensor_index
        # Shard file handles, opened on first access and kept for reuse
        self._handles = {}

    def __iter__(self):
        return iter(self._index)

    def __len__(self):
        return len(self._index)

    def __getitem__(self, key: Any) -> Any:
        shard = self._index[key]
        handle = self._handles.get(shard)
        if handle is None:
            # Open the shard lazily and memoize the handle for later lookups
            handle = safe_open(os.path.join(self._base_dir, shard), framework="pytorch")
            self._handles[shard] = handle
        return handle.get_tensor(key)

    def __contains__(self, key: object) -> bool:
        return key in self._index

    def keys(self):
        return self._index.keys()
class QuantizedTransformersModel(ModelHubMixin):
    """Wrapper exposing save/load of quantized transformers models.

    Wraps a quantized `PreTrainedModel` and delegates unknown attributes and
    calls to it, while adding hub-compatible serialization of the quantization
    map alongside the weights.
    """

    BASE_NAME = "quanto"
    # Subclasses must set this to the transformers Auto class used for instantiation
    auto_class = None

    def __init__(self, model: PreTrainedModel):
        if not isinstance(model, PreTrainedModel) or len(quantization_map(model)) == 0:
            raise ValueError("The source model must be a quantized transformers model.")
        self._wrapped = model

    def __getattr__(self, name: str) -> Any:
        """If an attribute is not found in this class, look in the wrapped module."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            wrapped = self.__dict__["_wrapped"]
            return getattr(wrapped, name)

    def forward(self, *args, **kwargs):
        return self._wrapped.forward(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        return self._wrapped.forward(*args, **kwargs)

    def __repr__(self):
        return self._wrapped.__repr__()

    @staticmethod
    def _qmap_name():
        # File name under which the quantization map is (de)serialized
        return f"{QuantizedTransformersModel.BASE_NAME}_qmap.json"

    @classmethod
    def quantize(
        cls,
        model: PreTrainedModel,
        weights: Optional[Union[str, qtype]] = None,
        activations: Optional[Union[str, qtype]] = None,
        optimizer: Optional[Optimizer] = None,
        include: Optional[Union[str, List[str]]] = None,
        exclude: Optional[Union[str, List[str]]] = None,
    ):
        """Quantize the specified model

        By default, each layer of the model will be quantized if is quantizable.

        If include patterns are specified, the layer name must match one of them.

        If exclude patterns are specified, the layer must not match one of them.

        Include or exclude patterns are Unix shell-style wildcards which are NOT regular expressions. See
        https://docs.python.org/3/library/fnmatch.html for more details.

        Note: quantization happens in-place and modifies the original model.

        Note that the resulting quantized model will be frozen: if you wish to do
        quantization-aware training then you should use `optimum.quanto.quantize` instead,
        and call `optimum.quanto.freeze` only after the training.

        Args:
            model (`PreTrainedModel`): the model to quantize.
            weights (`Optional[Union[str, qtype]]`): the qtype for weights quantization.
            activations (`Optional[Union[str, qtype]]`): the qtype for activations quantization.
            optimizer (`Optional[Optimizer]`): an optional optimizer to evaluate quantization parameters.
            include (`Optional[Union[str, List[str]]]`):
                Patterns constituting the allowlist. If provided, layer names must match at
                least one pattern from the allowlist.
            exclude (`Optional[Union[str, List[str]]]`):
                Patterns constituting the denylist. If provided, layer names must not match
                any patterns from the denylist.

        Returns:
            an instance of `cls` wrapping the quantized and frozen model.
        """
        if not isinstance(model, PreTrainedModel):
            raise ValueError("The source model must be a transformers model.")
        quantize(
            model, weights=weights, activations=activations, optimizer=optimizer, include=include, exclude=exclude
        )
        freeze(model)
        return cls(model)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs):
        """Reload a quantized transformers model from a local directory or the Hub.

        Raises:
            ValueError: if called on a class with no `auto_class`, accelerate is missing,
                or the quantization map / safetensors weights cannot be found.
        """
        if cls.auto_class is None:
            # Fixed: this message was previously a plain string, so {cls} was never interpolated
            raise ValueError(
                f"Quantized models cannot be reloaded using {cls}: use a specialized quantized class such as QuantizedModelForCausalLM instead."
            )
        if not is_accelerate_available():
            raise ValueError("Reloading a quantized transformers model requires the accelerate library.")
        from accelerate import init_empty_weights

        if os.path.isdir(pretrained_model_name_or_path):
            working_dir = pretrained_model_name_or_path
        else:
            # Fetch a full local snapshot so config, weights and qmap are all available
            working_dir = snapshot_download(pretrained_model_name_or_path, **kwargs)
        # Look for a quantization map
        qmap_path = os.path.join(working_dir, cls._qmap_name())
        if not os.path.exists(qmap_path):
            raise ValueError(
                f"No quantization map found in {pretrained_model_name_or_path}: is this a quantized model ?"
            )
        with open(qmap_path, "r", encoding="utf-8") as f:
            qmap = json.load(f)
        # Create an empty model
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        with init_empty_weights():
            # Weights stay on the meta device; they are materialized by requantize below
            model = cls.auto_class.from_config(config)
        # Look for the index of a sharded checkpoint
        checkpoint_file = os.path.join(working_dir, SAFE_WEIGHTS_INDEX_NAME)
        if os.path.exists(checkpoint_file):
            # Convert the checkpoint path to a list of shards
            checkpoint_file, sharded_metadata = get_checkpoint_shard_files(working_dir, checkpoint_file)
            # Create a mapping for the sharded safetensor files
            state_dict = ShardedStateDict(working_dir, sharded_metadata["weight_map"])
        else:
            # Look for a single checkpoint file
            checkpoint_file = os.path.join(working_dir, SAFE_WEIGHTS_NAME)
            if not os.path.exists(checkpoint_file):
                raise ValueError(f"No safetensor weights found in {pretrained_model_name_or_path}.")
            # Get state_dict from model checkpoint
            state_dict = load_state_dict(checkpoint_file)
        # Requantize and load quantized weights from state_dict
        requantize(model, state_dict=state_dict, quantization_map=qmap)
        if getattr(model.config, "tie_word_embeddings", True):
            # Tie output weight embeddings to input weight embeddings
            # Note that if they were quantized they would NOT be tied
            model.tie_weights()
        # Set model in evaluation mode as it is done in transformers
        model.eval()
        return cls(model)

    def _save_pretrained(self, save_directory: Path) -> None:
        """Serialize the wrapped model plus the quantization map needed to reload it."""
        model = self._wrapped
        if getattr(model.config, "tie_word_embeddings", True):
            # The original model had tied embedding inputs and outputs
            if isinstance(model.get_input_embeddings(), QModuleMixin) or isinstance(
                model.get_output_embeddings(), QModuleMixin
            ):
                # At least one of the two is quantized, so they are not tied anymore
                model.config.tie_word_embeddings = False
        self._wrapped.save_pretrained(save_directory, safe_serialization=True)
        # Save quantization map to be able to reload the model
        qmap_name = os.path.join(save_directory, self._qmap_name())
        qmap = quantization_map(self._wrapped)
        with open(qmap_name, "w", encoding="utf8") as f:
            json.dump(qmap, f, indent=4)


class QuantizedModelForCausalLM(QuantizedTransformersModel):
    # Specialization reloading quantized causal language models
    auto_class = AutoModelForCausalLM
@register_qmodule(torch.nn.Conv2d)
class QConv2d(QModuleMixin, torch.nn.Conv2d):
    """Quantized drop-in replacement for `torch.nn.Conv2d`."""

    @classmethod
    def qcreate(
        cls,
        module,
        weights: qtype,
        activations: Optional[qtype] = None,
        optimizer: Optional[Optimizer] = None,
        device: Optional[torch.device] = None,
    ):
        """Build an empty QConv2d mirroring the geometry of `module`.

        Only hyper-parameters are copied here; the actual weight/bias tensors
        are transferred later by `QModuleMixin.from_module`.
        """
        conv_geometry = dict(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            groups=module.groups,
            bias=module.bias is not None,
            padding_mode=module.padding_mode,
        )
        return cls(
            **conv_geometry,
            dtype=module.weight.dtype,
            device=device,
            weights=weights,
            activations=activations,
            optimizer=optimizer,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Convolve with the (possibly dynamically quantized) weight
        return self._conv_forward(input, self.qweight, self.bias)
@register_qmodule(torch.nn.LayerNorm)
class QLayerNorm(QModuleMixin, torch.nn.LayerNorm):
    """LayerNorm variant that quantizes its output activations only."""

    @classmethod
    def qcreate(
        cls,
        module,
        weights: Optional[qtype] = None,
        activations: Optional[qtype] = None,
        optimizer: Optional[Optimizer] = None,
        device: Optional[torch.device] = None,
    ):
        """Build an empty QLayerNorm mirroring `module`, or None when there is nothing to quantize."""
        if activations is None:
            # Without activation quantization this module brings no benefit
            return None
        affine_weight = module.weight
        return cls(
            module.normalized_shape,
            module.eps,
            module.elementwise_affine,
            module.bias is not None,
            dtype=affine_weight.dtype if affine_weight is not None else None,
            device=device,
            weights=None,  # We never quantize QLayerNorm weights
            activations=activations,
            optimizer=None,  # No weight quantization, hence no optimizer either
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Plain float layer_norm; the output hook installed by QModuleMixin quantizes the result
        return torch.nn.functional.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
@register_qmodule(torch.nn.Linear)
class QLinear(QModuleMixin, torch.nn.Linear):
    """Quantized drop-in replacement for `torch.nn.Linear`."""

    @classmethod
    def qcreate(
        cls,
        module,
        weights: qtype,
        activations: Optional[qtype] = None,
        optimizer: Optional[Optimizer] = None,
        device: Optional[torch.device] = None,
    ):
        """Build an empty QLinear mirroring the geometry of `module`.

        Weights/bias are transferred later by `QModuleMixin.from_module`.
        """
        has_bias = module.bias is not None
        return cls(
            module.in_features,
            module.out_features,
            has_bias,
            dtype=module.weight.dtype,
            device=device,
            weights=weights,
            activations=activations,
            optimizer=optimizer,
            quantize_input=True,  # linear inputs are quantized on-the-fly by a forward pre-hook
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Matmul against the (possibly dynamically quantized) weight
        return torch.nn.functional.linear(input, self.qweight, bias=self.bias)
def quantize_module(
    module,
    weights: Optional[Union[qtype, str]] = None,
    activations: Optional[Union[qtype, str]] = None,
    optimizer: Optional[Optimizer] = None,
):
    """Return a quantized counterpart of `module`, or None when no registered class matches.

    Iterates the registration table and uses the first entry whose key class
    `module` is an instance of.
    """
    for cls in _QMODULE_TABLE:
        if isinstance(module, cls):
            qcls = _QMODULE_TABLE[cls]
            return qcls.from_module(module, weights=weights, activations=activations, optimizer=optimizer)
    return None


class QModuleMixin(ABC):
    """Mixin adding weight and activation quantization to a torch.nn.Module subclass.

    Subclasses must also inherit from a torch.nn.Module class, with this mixin
    placed FIRST in the MRO, and implement `qcreate` and `forward`.
    """

    def __init__(
        self,
        *args,
        weights: Optional[Union[qtype, str]] = None,
        activations: Optional[Union[qtype, str]] = None,
        optimizer: Optional[Optimizer] = None,
        quantize_input: Optional[bool] = False,
        device: Optional[torch.device] = None,
        **kwargs,
    ):
        # The tests below are meant to help people writing their own quantized Module class
        mro = self.__class__.__mro__
        if torch.nn.Module not in mro:
            raise TypeError("Quantized modules must inherit from a torch.nn.Module class")
        if mro.index(__class__) > mro.index(torch.nn.Module):
            raise TypeError(
                "QModuleMixin must be placed before any torch.nn.Module class in quantized module inheritance."
            )
        # This will setup the torch.nn.Module
        super().__init__(*args, device=device, **kwargs)
        # Accept qtype names as strings and resolve them through the qtypes registry
        if weights is not None and not isinstance(weights, qtype):
            weights = qtypes[weights]
        if activations is not None and not isinstance(activations, qtype):
            activations = qtypes[activations]
        self.weight_qtype = weights
        self.weight_group_size = None
        if self.weight_qtype in (qint2, qint4):
            # Sub-byte weights are quantized group-wise: pick the largest group size
            # in {128, 96, 64, 32} that evenly divides in_features, if any
            out_features = self.weight.shape[0]
            in_features = self.weight.numel() // out_features
            group_size = 128
            if in_features > group_size:
                while in_features % group_size != 0 and group_size > 32:
                    group_size -= 32
                if in_features % group_size == 0:
                    self.weight_group_size = group_size
        self.activation_qtype = activations
        self._quantize_hooks = {}
        if activations is not None:
            # Quantize activations through forward hooks so subclasses' forward stays simple
            if quantize_input:
                self._quantize_hooks["input"] = self.register_forward_pre_hook(self.quantize_input)
            self._quantize_hooks["output"] = self.register_forward_hook(self.quantize_output)
        if optimizer is None and self.weight_qtype is not None:
            # Default optimizers: absmax for 8-bit weights, max for sub-byte weights
            optimizer = AbsmaxOptimizer() if self.weight_qtype.bits == 8 else MaxOptimizer()
        self.optimizer = optimizer
        # Activation scales, updated during calibration; stored as buffers so they serialize
        scale_dtype = torch.float32 if self.weight is None else self.weight.dtype
        self.register_buffer("input_scale", torch.ones((), dtype=scale_dtype, device=device))
        self.register_buffer("output_scale", torch.ones((), dtype=scale_dtype, device=device))

    def disable_output_quantization(self):
        # Remove the output quantization hook installed in __init__, if any
        if "output" in self._quantize_hooks:
            self._quantize_hooks["output"].remove()

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """Serialize the module, flattening a frozen QTensor weight into multiple entries."""
        if self.weight_qtype is None or not self.frozen:
            # Save standard weight Tensor
            destination[prefix + "weight"] = (
                self.weight if (self.weight is None or keep_vars) else self.weight.detach()
            )
        else:
            # Save QTensor using dedicated method
            self.weight.save_to_state_dict(destination, prefix + "weight.", keep_vars)
        if self.bias is not None:
            destination[prefix + "bias"] = self.bias if keep_vars else self.bias.detach()
        destination[prefix + "input_scale"] = self.input_scale if keep_vars else self.input_scale.detach()
        destination[prefix + "output_scale"] = self.output_scale if keep_vars else self.output_scale.detach()

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """Deserialize the module, rebuilding a QTensor weight from its flattened entries."""
        weight_name = prefix + "weight"
        if self.weight_qtype is not None and weight_name not in state_dict:
            # The weight Tensor is not present because it is a flattened QTensor
            weight_prefix = weight_name + "."
            # note: deserialized_weight can be None if a key is missing in the state_dict
            if self.weight_qtype.bits == 8:
                deserialized_weight = WeightQBytesTensor.load_from_state_dict(
                    state_dict,
                    weight_prefix,
                    qtype=self.weight_qtype,
                    axis=0,
                    size=self.weight.size(),
                    stride=self.weight.stride(),
                    activation_qtype=self.activation_qtype,
                    missing_keys=missing_keys,
                )
            else:
                deserialized_weight = WeightQBitsTensor.load_from_state_dict(
                    state_dict,
                    weight_prefix,
                    qtype=self.weight_qtype,
                    axis=0,
                    group_size=self.weight_group_size,
                    size=self.weight.size(),
                    stride=self.weight.stride(),
                    missing_keys=missing_keys,
                )
            if deserialized_weight is not None:
                # Convert to the most efficient backend-specific layout, if any
                deserialized_weight = deserialized_weight.optimize()

            assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
            if assign_to_params_buffers and (deserialized_weight is not None):
                self.weight = torch.nn.Parameter(deserialized_weight)
            elif deserialized_weight is not None:
                if type(self.weight.data) is not type(deserialized_weight):
                    # Reloading frozen weights into unfrozen module: move to the correct device and force assignment
                    self.weight = torch.nn.Parameter(deserialized_weight.to(self.weight.device))
                else:
                    # FIXME: here we should copy frozen weights into frozen module, but this leads to grad error
                    self.weight = torch.nn.Parameter(deserialized_weight.to(self.weight.device))

        # strict is forced to False so the (already consumed) flattened weight keys
        # are not reported as unexpected by the base implementation
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, False, missing_keys, unexpected_keys, error_msgs
        )

    @classmethod
    def from_module(
        cls,
        module: torch.nn.Module,
        weights: Optional[qtype] = None,
        activations: Optional[qtype] = None,
        optimizer: Optional[Optimizer] = None,
    ):
        """Create a quantized module sharing the weights of `module`, or None if not applicable."""
        # Create the quantized module on the meta device to prevent weights initialization
        qmodule = cls.qcreate(module, weights, activations, optimizer, device="meta")
        if qmodule is None:
            return None
        # Move the quantized module to the target device, but with empty weights
        device = torch.device("cpu") if module.weight is None else module.weight.device
        qmodule = qmodule.to_empty(device=device)
        # Set scales that were initialized to empty values
        qmodule.input_scale = torch.ones_like(qmodule.input_scale)
        qmodule.output_scale = torch.ones_like(qmodule.output_scale)
        with torch.no_grad():
            qmodule.weight = module.weight
            if module.bias is not None:
                qmodule.bias = module.bias

        return qmodule.to(device)

    @classmethod
    def qcreate(
        cls,
        module: torch.nn.Module,
        weights: Optional[qtype],
        activations: Optional[qtype] = None,
        optimizer: Optional[Optimizer] = None,
        device: Optional[torch.device] = None,
    ):
        # Subclasses must build an (empty) quantized clone of `module` here
        raise NotImplementedError

    @property
    def qweight(self):
        """Return the module quantized weight

        When the module is frozen or does not quantize its weight parameter, it simply
        returns the weight.
        When the module is not frozen, this property is required to add the dynamic quantization
        of the weight parameter to the graph and allow gradients to be propagated to the
        underlying weight float values.
        """
        if self.weight_qtype is None:
            # QModule that does not quantize its weights
            return None
        if isinstance(self.weight, QTensor):
            # Frozen QModule
            return self.weight
        # Quantize dynamically the weights per-axis
        if isinstance(self.optimizer, SymmetricOptimizer):
            scale = self.optimizer(self.weight, qtype=self.weight_qtype, axis=0)
            shift = None
        else:
            optimizer_kwargs = {"qtype": self.weight_qtype, "axis": 0, "group_size": self.weight_group_size}
            if self.weight.device.type == "xpu":
                # NOTE(review): XPU path requests an integer zeropoint shift — presumably a
                # backend requirement; confirm against the optimizer implementations
                optimizer_kwargs.update({"zeropoint": True})
            scale, shift = self.optimizer(self.weight, **optimizer_kwargs)
        return quantize_weight(
            self.weight,
            qtype=self.weight_qtype,
            axis=0,
            scale=scale,
            shift=shift,
            group_size=self.weight_group_size,
            activation_qtype=self.activation_qtype,
        )

    def qforward(self, input: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def quantize_input(self, module: torch.nn.Module, input: torch.Tensor) -> torch.Tensor:
        """Forward pre-hook quantizing the (first) input using the calibrated input scale."""
        input = input[0]
        if isinstance(input, ActivationQBytesTensor):
            if input.qtype != self.activation_qtype:
                raise ValueError(
                    "Models with heterogeneous quantized activations are not supported:"
                    f" expected {self.activation_qtype.name} input but got {input.qtype.name} instead."
                )
        else:
            input = quantize_activation(input, qtype=self.activation_qtype, scale=self.input_scale)
        return input

    def quantize_output(
        self,
        module: torch.nn.Module,
        input: torch.Tensor,
        output: torch.Tensor,
    ) -> torch.Tensor:
        """Forward hook quantizing the output using the calibrated output scale."""
        return quantize_activation(output, qtype=self.activation_qtype, scale=self.output_scale)

    def freeze(self):
        """Replace the float weight by its quantized version, making quantization static."""
        qweight = self.qweight
        if qweight is not None:
            # Replace float weights by quantized weights
            self.weight = torch.nn.Parameter(qweight)

    @property
    def frozen(self) -> bool:
        # A module is frozen once its weight Parameter holds a QTensor
        return isinstance(self.weight, QTensor)
def set_module_by_name(parent_module, name, child_module):
    """Replace the submodule at dotted path `name` under `parent_module` with `child_module`."""
    module_names = name.split(".")
    if len(module_names) == 1:
        setattr(parent_module, name, child_module)
    else:
        # Resolve the parent of the target, then rebind its attribute
        parent_module_name = name[: name.rindex(".")]
        parent_module = parent_module.get_submodule(parent_module_name)
        setattr(parent_module, module_names[-1], child_module)


def _quantize_submodule(
    model: torch.nn.Module,
    name: str,
    module: torch.nn.Module,
    weights: Optional[Union[str, qtype]] = None,
    activations: Optional[Union[str, qtype]] = None,
    optimizer: Optional[Optimizer] = None,
):
    """Swap `module` for its quantized counterpart inside `model`, if one exists."""
    qmodule = quantize_module(module, weights=weights, activations=activations, optimizer=optimizer)
    if qmodule is not None:
        set_module_by_name(model, name, qmodule)
        qmodule.name = name
        for name, param in module.named_parameters():
            # Save device memory by clearing parameters
            setattr(module, name, None)
            del param


def quantize(
    model: torch.nn.Module,
    weights: Optional[Union[str, qtype]] = None,
    activations: Optional[Union[str, qtype]] = None,
    optimizer: Optional[Optimizer] = None,
    include: Optional[Union[str, List[str]]] = None,
    exclude: Optional[Union[str, List[str]]] = None,
):
    """Quantize the specified model submodules

    Recursively quantize the submodules of the specified parent model.

    Only modules that have quantized counterparts will be quantized.

    If include patterns are specified, the submodule name must match one of them.

    If exclude patterns are specified, the submodule must not match one of them.

    Include or exclude patterns are Unix shell-style wildcards which are NOT regular expressions. See
    https://docs.python.org/3/library/fnmatch.html for more details.

    Note: quantization happens in-place and modifies the original model and its descendants.

    Args:
        model (`torch.nn.Module`): the model whose submodules will be quantized.
        weights (`Optional[Union[str, qtype]]`): the qtype for weights quantization.
        activations (`Optional[Union[str, qtype]]`): the qtype for activations quantization.
        optimizer (`Optional[Optimizer]`): an optional optimizer to evaluate quantization parameters.
        include (`Optional[Union[str, List[str]]]`):
            Patterns constituting the allowlist. If provided, module names must match at
            least one pattern from the allowlist.
        exclude (`Optional[Union[str, List[str]]]`):
            Patterns constituting the denylist. If provided, module names must not match
            any patterns from the denylist.
    """
    if include is not None:
        include = [include] if isinstance(include, str) else include
    if exclude is not None:
        exclude = [exclude] if isinstance(exclude, str) else exclude
    for name, m in model.named_modules():
        if include is not None and not any(fnmatch(name, pattern) for pattern in include):
            continue
        if exclude is not None and any(fnmatch(name, pattern) for pattern in exclude):
            continue
        _quantize_submodule(model, name, m, weights=weights, activations=activations, optimizer=optimizer)


def requantize(
    model: torch.nn.Module,
    state_dict: Dict[str, Any],
    quantization_map: Dict[str, Dict[str, str]],
    device: Optional[torch.device] = None,
):
    """Restore quantized modules from a serialized quantization map, then load weights.

    Args:
        model (`torch.nn.Module`): the (possibly meta-device) model to requantize in-place.
        state_dict (`Dict[str, Any]`): the serialized weights (may be a lazy ShardedStateDict).
        quantization_map (`Dict[str, Dict[str, str]]`): per-module qtype names, as produced
            by `quantization_map()` ("none" marking an unquantized side).
        device (`Optional[torch.device]`): the target device; defaults to the model device
            (or CPU when the model is on the meta device).
    """
    if device is None:
        device = next(model.parameters()).device
        if device.type == "meta":
            device = torch.device("cpu")

    # Quantize the model with parameters from the quantization map
    for name, m in model.named_modules():
        qconfig = quantization_map.get(name, None)
        if qconfig is not None:
            weights = qconfig["weights"]
            if weights == "none":
                weights = None
            activations = qconfig["activations"]
            if activations == "none":
                activations = None
            _quantize_submodule(model, name, m, weights=weights, activations=activations)

    # Move model parameters and buffers to CPU before materializing quantized weights
    for name, m in model.named_modules():

        def move_tensor(t, device):
            # Meta tensors have no data: allocate a fresh empty tensor instead of copying
            if t.device.type == "meta":
                return torch.empty_like(t, device=device)
            return t.to(device)

        for name, param in m.named_parameters(recurse=False):
            setattr(m, name, torch.nn.Parameter(move_tensor(param, "cpu")))
        for name, param in m.named_buffers(recurse=False):
            setattr(m, name, move_tensor(param, "cpu"))
    # Move to target device
    model.to(device)
    # Load the quantized model weights
    model.load_state_dict(state_dict, strict=False)


def freeze(model):
    """Freeze every quantized module of `model`, replacing float weights by quantized ones."""
    for name, m in model.named_modules():
        if isinstance(m, QModuleMixin):
            m.freeze()


def quantization_map(model: torch.nn.Module) -> Dict[str, Dict[str, str]]:
    """Returns the quantization map of a module

    The quantization map is a dictionary of quantization parameters indexed
    by the module submodule names (including prefix).

    This is mainly used for serialization.

    Args:
        model (`torch.nn.Module`): the root module to map.

    Returns:
        a dictionary of quantization parameters indexed by layer names.
    """
    config = {}
    for name, m in model.named_modules():
        if isinstance(m, QModuleMixin):
            config[name] = {
                "weights": "none" if m.weight_qtype is None else m.weight_qtype.name,
                "activations": "none" if m.activation_qtype is None else m.activation_qtype.name,
            }
    return config
# See the License for the specific language governing permissions and # limitations under the License. from .base import * ================================================ FILE: optimum/quanto/subpackage/commands/base.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from optimum.commands.base import BaseOptimumCLICommand, CommandInfo from optimum.commands.optimum_cli import optimum_cli_subcommand from .quantize import QuantizeCommand __all__ = ["QuantoCommand"] @optimum_cli_subcommand() class QuantoCommand(BaseOptimumCLICommand): COMMAND = CommandInfo(name="quanto", help="Hugging Face models quantization tools") SUBCOMMANDS = ( CommandInfo( name="quantize", help="Quantize Hugging Face models.", subcommand_class=QuantizeCommand, ), ) ================================================ FILE: optimum/quanto/subpackage/commands/quantize.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. """Hugging Face models quantization command-line interface class.""" from typing import TYPE_CHECKING import torch from optimum.commands.base import BaseOptimumCLICommand from optimum.exporters.tasks import TasksManager from ...models import QuantizedTransformersModel if TYPE_CHECKING: from argparse import ArgumentParser SUPPORTED_LIBRARIES = ["transformers"] def parse_quantize_args(parser: "ArgumentParser"): required_group = parser.add_argument_group("Required arguments") required_group.add_argument( "output", type=str, help="The path to save the quantized model.", ) required_group.add_argument( "-m", "--model", type=str, required=True, help="Hugging Face Hub model id or path to a local model.", ) required_group.add_argument( "--weights", type=str, default="int8", choices=["int2", "int4", "int8", "float8"], help="The Hugging Face library to use to load the model.", ) optional_group = parser.add_argument_group("Optional arguments") optional_group.add_argument( "--revision", type=str, default=None, help="The Hugging Face model revision.", ) optional_group.add_argument( "--trust_remote_code", action="store_true", default=False, help="Trust remote code when loading the model.", ) optional_group.add_argument( "--library", type=str, default=None, choices=SUPPORTED_LIBRARIES, help="The Hugging Face library to use to load the model.", ) optional_group.add_argument( "--task", type=str, default=None, help="The model task (useful for models supporting multiple tasks).", ) optional_group.add_argument( "--torch_dtype", type=str, default="auto", choices=["auto", "fp16", "bf16"], help="The torch dtype to use when loading the model weights.", ) optional_group.add_argument( "--device", type=str, default="cpu", help="The device to use when loading the model.", ) class QuantizeCommand(BaseOptimumCLICommand): @staticmethod def parse_args(parser: "ArgumentParser"): return 
parse_quantize_args(parser) def run(self): model_name_or_path = self.args.model library_name = self.args.library if library_name is None: library_name = TasksManager.infer_library_from_model(model_name_or_path) if library_name not in SUPPORTED_LIBRARIES: raise ValueError( f"{library_name} models are not supported by this CLI, but can be quantized using the python API directly." ) task = self.args.task if task is None: task = TasksManager.infer_task_from_model(model_name_or_path) torch_dtype = self.args.torch_dtype if torch_dtype != "auto": torch_dtype = torch.float16 if self.args.torch_dtype == "fp16" else torch.bfloat16 model = TasksManager.get_model_from_task( task, model_name_or_path, revision=self.args.revision, trust_remote_code=self.args.trust_remote_code, framework="pt", torch_dtype=torch_dtype, device=torch.device(self.args.device), library_name=library_name, low_cpu_mem_usage=True, ) weights = f"q{self.args.weights}" qmodel = QuantizedTransformersModel.quantize(model, weights=weights) qmodel.save_pretrained(self.args.output) ================================================ FILE: optimum/quanto/tensor/__init__.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
class ActivationQBytesQuantizer(Function):
    """Autograd Function quantizing an activation Tensor to an 8-bit qtype.

    Activations are always quantized per-tensor (axis=None) with a scalar scale.
    """

    @staticmethod
    def forward(ctx, base: torch.Tensor, qtype: qtype, scale: torch.Tensor) -> torch.Tensor:
        if qtype.bits != 8:
            raise ValueError("QBytesTensor can only be of 8-bit qtype")
        size = base.size()
        stride = base.stride()
        # Symmetric per-tensor quantization of the raw data
        data = torch.ops.quanto.quantize_symmetric(base, dtype=qtype.dtype, axis=None, scale=scale)
        # The instantiation of the quantized tensor must happen within the context of the Function
        # for the autograd magic to work.
        return ActivationQBytesTensor(qtype, size, stride, data, scale)

    @staticmethod
    def backward(ctx, gO):
        # For autograd, quantization is a no-op: the gradient passes straight through.
        # NOTE(review): six gradients are returned for three forward inputs — confirm
        # PyTorch tolerates the extra trailing Nones here.
        return gO, None, None, None, None, None


class ActivationQBytesTensor(QBytesTensor):
    """Per-tensor 8-bit quantized activation tensor (wrapper subclass).

    Stores the quantized `_data` alongside its scalar `_scale`; the public dtype
    of the wrapper is the scale's dtype so the tensor integrates with fp graphs.
    """

    @staticmethod
    def __new__(cls, qtype, size, stride, data, scale, requires_grad=False):
        # data and scale must live on the same device for dispatched ops to work
        assert data.device == scale.device
        return torch.Tensor._make_wrapper_subclass(
            cls, size, strides=stride, dtype=scale.dtype, device=data.device, requires_grad=requires_grad
        )

    def __init__(self, qtype, size, stride, data, scale, requires_grad=False):
        # axis is None: activations are always quantized per-tensor
        super().__init__(qtype, None, size, stride, data, scale, requires_grad)

    @classmethod
    def quantize(cls, base: torch.Tensor, qtype: qtype, scale: torch.Tensor) -> torch.Tensor:
        """Quantize *base* through the autograd Function so gradients flow."""
        return ActivationQBytesQuantizer.apply(base, qtype, scale)

    def __tensor_flatten__(self):
        # Serialization protocol: inner tensors plus string-only metadata
        inner_tensors = ["_data", "_scale"]
        meta = {
            "qtype": self._qtype.name,
            "size": str(list(self.size())),
            "stride": str(list(self.stride())),
        }
        return inner_tensors, meta

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        assert len(inner_tensors) == 2
        assert len(meta) == 3
        data, scale = inner_tensors["_data"], inner_tensors["_scale"]
        # Meta should only contain strings, AST compatible except qtype
        qtype = qtypes[meta["qtype"]]
        size = ast.literal_eval(meta["size"])
        stride = ast.literal_eval(meta["stride"])
        return ActivationQBytesTensor(qtype, size, stride, data, scale)

    @classmethod
    def __torch_dispatch__(cls, op, types, args, kwargs=None):
        from .qbytes_ops import get_qbytestensor_op_dispatch

        kwargs = kwargs or {}
        # Do not use directly op, but rather its overload
        op = op.overloadpacket
        qdispatch = get_qbytestensor_op_dispatch(op)
        if qdispatch is not None:
            return qdispatch(*args, **kwargs)
        # No dispatch available: qfallback (dequantize, apply op, leave dequantized)
        return qfallback(op, *args, **kwargs)
# Table mapping an aten overload packet to its quantized implementation
_QBYTESTENSOR_OP_TABLE = {}


def register_qbytestensor_op(aten_ops: List[Callable]):
    """
    Used for registering a new __torch_dispatch__ aten operation to QBytesTensor.

    The code to register a new operation looks like:

        @register_qbytestensor_op(list_of_ops)
        def foo(op, *args, **kwargs):
            ...

    The registered callable receives the aten op as its first argument.
    """

    def wrapper(op):
        for aten_op in aten_ops:
            # Bind the aten op as first argument of the handler
            _QBYTESTENSOR_OP_TABLE[aten_op] = partial(op, aten_op)

    return wrapper


def get_qbytestensor_op_dispatch(aten_op):
    # Returns None when no quantized implementation was registered
    return _QBYTESTENSOR_OP_TABLE.get(aten_op, None)


def is_scalar(t):
    # A python number or a zero-dim tensor both count as scalars
    return isinstance(t, numbers.Number) or type(t) is torch.Tensor and len(t.shape) == 0


@register_qbytestensor_op([torch.ops.aten._to_copy, torch.ops.aten.to])
def _to_copy(op, t, dtype=None, **kwargs):
    # For data, ignore dtype and use the inner type instead
    out_data = op(t._data, dtype=t._data.dtype, **kwargs)
    # Apply the new dtype on the scale only
    out_scale = op(t._scale, dtype=dtype, **kwargs)
    return ActivationQBytesTensor(t.qtype, t.size(), t.stride(), out_data, out_scale)


@register_qbytestensor_op([torch.ops.aten.detach])
def detach(op, t):
    # Detach both data and scale
    out_data = op(t._data)
    out_scale = op(t._scale)
    return ActivationQBytesTensor(t.qtype, t.size(), t.stride(), out_data, out_scale)


@register_qbytestensor_op([torch.ops.aten.cat])
def cat(op, inputs, dim=0):
    if len(inputs) == 2:
        t1, t2 = inputs
        # Only quantized tensors with identical scalar scales can be concatenated
        if (
            isinstance(t1, ActivationQBytesTensor)
            and isinstance(t2, ActivationQBytesTensor)
            and torch.equal(t1._scale, t2._scale)
            and t1.qtype == t2.qtype
        ):
            if t1.qtype.is_floating_point or t2.qtype.is_floating_point:
                # Cat is not supported for float8
                return qfallback(op, inputs, dim)
            # Identical scales: concatenating the raw data is exact
            out_data = op([t1._data, t2._data], dim)
            return ActivationQBytesTensor(t1.qtype, out_data.size(), out_data.stride(), out_data, t1._scale)
    return qfallback(op, inputs, dim)
@register_qbytestensor_op([torch.ops.aten.clone])
def clone(op, t, memory_format=torch.preserve_format):
    """Clone an ActivationQBytesTensor, preserving its quantized representation.

    The inner data is reshaped to the tensor's outer shape before cloning so
    that the clone produces the correct strides, then restored to its original
    inner shape afterwards.
    """
    # We need to restore the data original shape before cloning to get the correct strides
    data_shape = t._data.shape
    out_data = t._data.reshape(t.shape)
    # Fixed: clone the reshaped tensor — the reshape result was previously
    # discarded because the original data was cloned instead
    out_data = op(out_data, memory_format=memory_format)
    out_stride = out_data.stride()
    out_data = out_data.reshape(data_shape)
    out_scale = op(t._scale, memory_format=memory_format)
    return ActivationQBytesTensor(t.qtype, t.size(), out_stride, out_data, out_scale)
def cannot_mm(t: QTensor):
    """True if the QTensor data cannot be passed to an mm op"""
    # Per-axis tensors whose inner data shape differs from the outer shape
    # (e.g. transposed views) cannot be fed directly to matmul kernels.
    return t.axis is not None and t.size() != t._data.size()


@register_qbytestensor_op([torch.ops.aten.bmm])
def bmm(op, input, other):
    """Batched matrix multiplication involving a quantized activation.

    Only runs on quantized data when both operands are qint8 and the input is
    per-tensor; every other combination dequantizes or falls back.
    """
    if not isinstance(input, ActivationQBytesTensor):
        # Quantized rhs only: dequantize it and run the float op
        return op(input, other.dequantize())
    if not isinstance(other, QTensor) or input.axis is not None:
        return op(input.dequantize(), other)
    if input.qtype != qint8 or other.qtype != qint8 or cannot_mm(other):
        return qfallback(op, input, other)
    # Cast data to float32 and do the operation
    out_data = op(input._data.to(torch.float32), other._data.to(torch.float32))
    # Combined scale of the two int8 operands, applied to the float result
    out_scale = (input._scale * other._scale).to(torch.float32)
    return (out_data * out_scale).to(input._scale.dtype)
@register_qbytestensor_op([torch.ops.aten.stack])
def stack(op, inputs, dim=0):
    """Stack two quantized tensors along *dim*.

    Only two per-tensor quantized tensors sharing the same scale and qtype can
    be stacked on their raw data; any other combination falls back to a
    dequantized stack.
    """
    if len(inputs) == 2:
        t1, t2 = inputs
        # Only quantized tensors with identical scales can be stacked
        if (
            isinstance(t1, ActivationQBytesTensor)
            and isinstance(t2, ActivationQBytesTensor)
            and t1.axis is None
            and t2.axis is None
            and torch.equal(t1._scale, t2._scale)
            and t1.qtype == t2.qtype
        ):
            out_data = op([t1._data, t2._data], dim)
            return ActivationQBytesTensor(t1.qtype, out_data.size(), out_data.stride(), out_data, t1._scale)
    # Fixed: the aten op was missing from the fallback call, which raised a
    # TypeError instead of dequantizing and stacking
    return qfallback(op, inputs, dim)
@register_qbytestensor_op([torch.ops.aten.where])
def where(op, condition, input, other):
    """Elementwise select between a quantized *input* and a float *other*."""
    # Quantized conditions or alternatives are not supported
    if isinstance(condition, QTensor) or isinstance(other, QTensor):
        raise NotImplementedError
    selected = op(condition, input.dequantize(), other)
    if input.axis is not None:
        # Per-axis input: return the dequantized result as-is
        return selected
    # Per-tensor input: requantize the result with the input scale
    return quantize_activation(selected, qtype=input.qtype, scale=input._scale)
""" if scale.numel() != 1: raise ValueError("Parameter scale must be a scalar because activations can only be quantized per-tensor") return ActivationQBytesTensor.quantize(t, qtype, scale) ================================================ FILE: optimum/quanto/tensor/core.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch __all__ = ["axis_to_dim", "dtype_info"] def dtype_info(dtype): info = torch.finfo if dtype.is_floating_point else torch.iinfo return info(dtype) def axis_to_dim(t, axis): dim = list(range(t.ndim)) if axis == -1: dim = dim[:-1] else: dim.remove(axis) return dim ================================================ FILE: optimum/quanto/tensor/function.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
class QuantizedLinearFunction(torch.autograd.Function):
    """Quantized linear function.

    This is a quantized implementation of torch.nn.functional.linear.

    It defines explicitly the backward pass instead of letting pytorch
    build it by combining the gradients of the underlying quantized operations.

    This has two main benefits:

    - this saves computations,
    - this allows to use operations that do not have a registered backward method,
    such as quanto custom operations.

    The drawback is that the extra tensors involved in the quantization graph, such as
    the scales and shift, cannot be trained.
    This is however consistent with the quanto quantizers backward pass, that returns
    a zero gradient for these tensors.
    """

    @staticmethod
    def forward(ctx, input, other, bias=None):
        # Save activations and weights for the backward pass
        ctx.save_for_backward(input, other)
        output = torch.matmul(input, other.t())
        if bias is not None:
            output = output + bias
        return output

    # Fixed: backward was missing the @staticmethod decorator, inconsistent with
    # forward and with the torch.autograd.Function contract
    @staticmethod
    def backward(ctx, gO):
        input_gO = other_gO = bias_gO = None
        input, other = ctx.saved_tensors
        out_features, in_features = other.shape
        if ctx.needs_input_grad[0]:
            # grad(A@(B.t()) = gO => grad(A) = gO@(B.t().t()) = gO@B
            input_gO = torch.matmul(gO, other)
        if ctx.needs_input_grad[1]:
            # grad(B@A.t()) = gO.t() => grad(B) = gO.t()@(A.t().t()) = gO.t()@A
            other_gO = torch.matmul(gO.view(-1, out_features).t(), input.view(-1, in_features))
        if ctx.needs_input_grad[2]:
            # Bias gradient is the sum on all dimensions but the last one
            dim = tuple(range(gO.ndim - 1))
            bias_gO = gO.sum(dim)
        return input_gO, other_gO, bias_gO
def group(base: torch.Tensor, axis: int, group_size: int):
    """Reshape *base* so that each quantization group is contiguous.

    For axis=0 the result is (n_groups, group_size); for axis=-1 it is
    (group_size, n_groups), where groups of the same feature stay adjacent.
    """
    if axis not in (0, -1):
        raise ValueError("Axis must be 0 or -1 for group-wise quantization")
    # In standard per-axis quantization, we have one scale per feature on axis
    features = base.shape[axis]
    # Each per-axis scale covers this many items per feature
    per_feature = base.numel() // features
    if group_size > per_feature or per_feature % group_size != 0:
        raise ValueError(f"Group size ({group_size}) must be a divisor of ({per_feature})")
    # Group-wise quantization further splits each feature into several groups
    groups_per_feature = per_feature // group_size
    if axis == 0:
        # Simple case: flatten directly into rows of one group each
        return base.reshape([-1, group_size])
    # axis == -1: split by groups while preserving the feature dimension,
    # then move group_size to the front
    staged = base.reshape((groups_per_feature, group_size, features))
    staged = staged.permute(1, 2, 0)
    return staged.reshape(group_size, features * groups_per_feature)


def ungroup(grouped: torch.Tensor, axis: int, orig_shape: torch.Size):
    """Inverse of `group`: restore a grouped tensor to *orig_shape*."""
    if grouped.shape == orig_shape:
        # Nothing to undo
        return grouped
    if axis == 0:
        # No transposition required, just reshape
        return grouped.reshape(orig_shape)
    group_size = grouped.shape[0] if axis == -1 else grouped.shape[-1]
    features = orig_shape[axis]
    groups_per_feature = grouped.numel() // features // group_size
    staged = grouped.reshape(group_size, features, groups_per_feature)
    # Back to (groups, group_size, features) before the final reshape
    staged = staged.permute(2, 0, 1)
    return staged.reshape(orig_shape)
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from .absmax_optimizer import * from .affine_optimizer import * from .hqq_optimizer import * from .max_optimizer import * from .optimizer import * from .symmetric_optimizer import * ================================================ FILE: optimum/quanto/tensor/optimizers/absmax_optimizer.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
class AbsmaxOptimizer(SymmetricOptimizer):
    """Symmetric scale optimizer using the maximum absolute value of the weights."""

    def optimize(
        self, base: torch.Tensor, qtype: qtype, axis: Optional[int] = None
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Return the symmetric quantization scale(s) for *base*.

        Args:
            base (`torch.Tensor`): the weight Tensor to evaluate.
            qtype (`quanto.qtype`): the target quantization type.
            axis (`Optional[int]`): None for a per-tensor scale, 0 or -1 for per-axis.

        Returns:
            the scale Tensor (scalar or per-axis with keepdim).
        """
        base = torch.abs(base)
        if axis is None:
            rmax = torch.max(base)
        else:
            # Reduce over every dimension except the quantization axis
            dim = list(range(1, base.ndim)) if (axis == 0) else list(range(0, base.ndim - 1))
            # base already holds absolute values: the previous code redundantly
            # recomputed torch.abs(base) here
            rmax = torch.amax(base, dim=dim, keepdim=True)
        return rmax / qtype.qmax
If True, the shifts are stored as integer instead of float, which results in a slightly smaller model, but might also reduce the model performance. Defaults to False. Returns: A tuple of scale, shift Tensor. """ if axis not in [0, -1]: raise ValueError("axis parameter must be 0 (first axis) or -1 (last axis)") if group_size is not None: base = group(base, axis, group_size) if axis is not None and base.shape[axis] == 1: axis = None scale, shift = self.optimize(base, qtype, axis) assert scale.dtype == base.dtype assert shift.dtype == base.dtype if zeropoint: # Round shift to make sure zero can be represented exactly using 'shift' as quantized value shift = torch.clamp(torch.round(shift / scale), 0, 2**qtype.bits - 1) shift = shift.to(torch.int8) if base.device.type == "xpu" else shift.to(torch.uint8) return scale, shift def optimize(self, base: torch.Tensor, qtype: qtype, axis: int) -> Tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError ================================================ FILE: optimum/quanto/tensor/optimizers/hqq_optimizer.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Shrinking operator
def shrink_lp_op(x: torch.Tensor, beta: float, lp_norm: float) -> torch.Tensor:
    """Soft-thresholding (shrinkage) operator used in the HQQ proximal step.

    For lp_norm == 1 this is the classic soft-threshold with threshold 1/beta;
    otherwise the threshold is scaled by |x|^(lp_norm - 1).
    """
    sign = torch.sign(x)
    magnitude = torch.abs(x)
    if lp_norm == 1:
        threshold = 1.0 / beta
    else:
        threshold = (1.0 / beta) * torch.pow(magnitude, lp_norm - 1)
    return sign * torch.nn.functional.relu(magnitude - threshold)
class MaxOptimizer(AffineOptimizer):
    """Affine optimizer deriving scale and shift from the weight min/max range."""

    def optimize(
        self, base: torch.Tensor, qtype: qtype, axis: int
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Return (scale, shift) mapping the value range of *base* onto the qtype range."""
        # Reduce over every dimension except the quantization axis
        reduce_dims = list(range(1, base.ndim)) if (axis == 0) else list(range(0, base.ndim - 1))
        vmin = torch.amin(base, dim=reduce_dims, keepdim=True)
        vmax = torch.amax(base, dim=reduce_dims, keepdim=True)
        # Signed integer range for the target bitwidth
        qmin = -(2 ** (qtype.bits - 1))
        qmax = 2 ** (qtype.bits - 1) - 1
        scale = (vmax - vmin) / (qmax - qmin)
        # Shift moves the range minimum to zero
        shift = -vmin
        return scale, shift
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC from typing import Optional, Tuple, Union import torch __all__ = ["Optimizer"] class Optimizer(ABC): def __call__( self, base: torch.Tensor, bits: int, axis: int, group_size: Optional[int] = None ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: raise NotImplementedError ================================================ FILE: optimum/quanto/tensor/optimizers/symmetric_optimizer.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
class SymmetricOptimizer(Optimizer):
    """Base class for optimizers evaluating a symmetric (scale-only) quantization.

    Validates the requested axis, degenerates per-axis to per-tensor
    quantization when the axis has a single element, then delegates the
    actual evaluation to `optimize`.
    """

    def __call__(self, base: torch.Tensor, qtype: qtype, axis: Optional[int] = None) -> torch.Tensor:
        """Evaluate the quantization scale for a tensor.

        Args:
            base (`torch.Tensor`): the tensor to quantize.
            qtype (`qtype`): the target quantized type.
            axis (`Optional[int]`): None for per-tensor, 0 for the first axis,
                -1 for the last axis.

        Returns:
            A scale `torch.Tensor` with the same dtype as `base`.
        """
        if axis not in [None, 0, -1]:
            raise ValueError("axis parameter must be None, 0 (first axis) or -1 (last axis)")
        # Per-axis quantization degenerates to per-tensor when the axis is a singleton
        if axis is not None and base.shape[axis] == 1:
            axis = None
        scale = self.optimize(base, qtype, axis)
        assert scale.dtype == base.dtype
        return scale

    def optimize(self, base: torch.Tensor, qtype: qtype, axis: Optional[int] = None) -> torch.Tensor:
        # Fixed: this was declared as `optimize(self, base, qmax: float, axis)`,
        # but __call__ always passes a qtype as second argument (as do the
        # affine optimizers). The signature now matches the actual call.
        raise NotImplementedError
def pack_weights(intweights: torch.Tensor, bits: int) -> torch.Tensor:
    """Pack int4 / int2 values held in a uint8 tensor into fewer uint8 bytes.

    Several narrow values are stacked along the first dimension into a single
    uint8 element (8 // bits values per byte). E.g. with ``bits=2``, the four
    2-bit values 3, 2, 1, 0 become the single byte ``0b11100100``, saving
    24 bits.

    Args:
        intweights (`torch.Tensor`): The un-packed `torch.uint8` tensor
        bits (`int`): The actual `bits` - can be 2, 4

    Returns:
        A `torch.uint8` tensor whose first dimension is ceil(rows / (8 // bits)).
    """
    values_per_item = 8 // bits
    rows = intweights.shape[0]
    # Ceiling division: the trailing byte may hold fewer than values_per_item values
    row_dim = (rows + values_per_item - 1) // values_per_item
    if intweights.ndim == 1:
        packed_shape = (row_dim,)
    else:
        packed_shape = (row_dim, *intweights.shape[1:])
    packed = torch.zeros(packed_shape, device=intweights.device, dtype=torch.uint8)
    unpacked = intweights.to(torch.uint8)

    def lshift(t: torch.Tensor, amount: int) -> torch.Tensor:
        # lshift is not supported on the MPS device: emulate with a multiply
        return t * (2**amount) if t.device.type == "mps" else t << amount

    # Only iterate over the slices that actually exist in the input
    iterations = min(values_per_item, rows // row_dim + 1)
    for i in range(iterations):
        start = i * row_dim
        end = min(start + row_dim, rows)
        packed[: end - start] |= lshift(unpacked[start:end], bits * i)
    return packed
class PackedTensor(torch.Tensor):
    """A torch.Tensor subclass storing int4 / int2 values packed into uint8 bytes.

    The wrapper advertises the logical (unpacked) size and stride while
    internally holding the smaller packed buffer produced by `pack_weights`.
    """

    @staticmethod
    def __new__(cls, data, bits, size, stride, requires_grad=False):
        # PackedTensor represents uint8 data and can therefore NEVER require gradient
        assert data.dtype == torch.uint8
        assert requires_grad is False
        return torch.Tensor._make_wrapper_subclass(
            cls, size, strides=stride, dtype=torch.uint8, device=data.device, requires_grad=requires_grad
        )

    def __init__(self, data, bits, size, stride, requires_grad=False):
        # Number of bits per packed value (2 or 4)
        self._bits = bits
        # The packed uint8 buffer
        self._data = data

    def __repr__(self):
        autograd_info = (
            f", grad_fn={self.grad_fn}" if self.grad_fn else ", requires_grad=True" if self.requires_grad else ""
        )
        return f"PackedTensor({self._data}, bits={self._bits}, public_dtype={self.dtype}{autograd_info})"

    @classmethod
    def pack(cls, t, bits=4):
        """Pack an integer tensor into a PackedTensor.

        Args:
            t (`torch.Tensor`): a `torch.uint8` or `torch.int8` tensor.
            bits (`int`): the number of bits per value - can be 2, 4.
        """
        assert bits in (2, 4)
        # XPU use int8 dtype
        assert t.dtype in (torch.uint8, torch.int8)
        data = pack_weights(t, bits)
        # We need to store size and stride to make sure the unpacked data has the correct shape
        return PackedTensor(data, bits, t.size(), t.stride())

    def unpack(self):
        """Return the unpacked uint8 tensor, restoring the original first dimension."""
        unpacked_data = torch.ops.quanto.unpack(self._data, self._bits)
        # Adjust the first dimension, as unpacked data may have extra rows if the
        # original shape is not a multiple of 8 // bits
        return unpacked_data[: self.shape[0]]

    @property
    def bits(self):
        return self._bits

    @property
    def dtype(self):
        return torch.uint8

    @staticmethod
    def load_from_state_dict(state_dict, prefix, bits, size, stride, missing_keys):
        if prefix + "_data" not in state_dict:
            missing_keys.append(prefix + "_data")
            return
        inner_tensors_dict = {"_data": state_dict.pop(prefix + "_data")}
        # Fixed: a list of the remaining prefixed keys was previously built here
        # and immediately discarded; the metadata is always reconstructed from
        # the explicit arguments.
        meta = {"bits": str(bits), "size": str(list(size)), "stride": str(stride)}
        return PackedTensor.__tensor_unflatten__(inner_tensors_dict, meta, None, None)

    def __tensor_flatten__(self):
        inner_tensors = ["_data"]
        # Since meta can be used for serialization, use only AST compatible strings
        meta = {"bits": str(self._bits), "size": str(list(self.size())), "stride": str(self.stride())}
        return inner_tensors, meta

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        assert len(inner_tensors) == 1
        assert len(meta) == 3
        data = inner_tensors["_data"]
        # Meta should contain only AST compatible strings
        bits = ast.literal_eval(meta["bits"])
        size = ast.literal_eval(meta["size"])
        stride = ast.literal_eval(meta["stride"])
        return PackedTensor(data, bits, size, stride)

    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, op, types, args, kwargs=None):
        # Convert back to tensor before calling any operation except detach
        if op.overloadpacket is torch.ops.aten.detach:
            t = args[0]
            data = op(t._data)
            return PackedTensor(data, t._bits, t.size(), t.stride())
        elif op.overloadpacket in (torch.ops.aten._to_copy, torch.ops.aten.to):
            t = args[0]
            dtype = kwargs.get("dtype", torch.uint8)
            if dtype != torch.uint8:
                raise ValueError(f"PackedTensor are torch.uint8 only and cannot be moved to {dtype}.")
            # Move data
            data = op(t._data, **kwargs)
            return PackedTensor(data, t._bits, t.size(), t.stride())
        # Any other op falls back to the unpacked representation
        args, kwargs = pytree.tree_map_only(PackedTensor, lambda x: x.unpack(), (args, kwargs or {}))
        return op(*args, **kwargs)

    def numpy(self):
        return self.unpack().cpu().numpy()
class QBitsDequantizer(Function):
    """Autograd function dequantizing a QBitsTensor back to a full-precision tensor.

    The forward path handles both integer shifts (applied before scaling, in the
    int8 domain) and floating-point shifts (subtracted after scaling).
    The backward path is a straight-through no-op.
    """

    @staticmethod
    def forward(ctx, t):
        # The data may be stored packed (2/4-bit values in uint8 bytes)
        if isinstance(t._data, PackedTensor):
            data = t._data.unpack()
        else:
            data = t._data
        shift = t._shift
        if not shift.dtype.is_floating_point:
            # Remove shift before multiplying by the scale
            data = data.to(torch.int8) - shift.to(torch.int8)
        if t.qtype.is_floating_point:
            # Upcast explicitly to the scale dtype
            dqt = t._scale * data.to(t._scale.dtype)
        else:
            dqt = t._scale * data
        if shift.dtype.is_floating_point:
            # Remove scaled shift
            dqt -= shift
        if t.axis is None:
            return dqt
        # Restore the original shape (if needed)
        return ungroup(dqt, axis=t.axis, orig_shape=t.shape)

    @staticmethod
    def backward(ctx, gO):
        # For autograd, dequantization is a no-op: pass the gradient through
        return gO
class QBytesDequantizer(Function):
    """Autograd function dequantizing a QBytesTensor back to a full-precision tensor."""

    @staticmethod
    def forward(ctx, t):
        if t.qtype.is_floating_point:
            # Upcast explicitly to the scale dtype
            dqt = t._scale * t._data.to(t._scale.dtype)
        else:
            dqt = t._scale * t._data
        return dqt

    @staticmethod
    def backward(ctx, gO):
        # For autograd, dequantization is a no-op
        return gO


class QBytesTensor(QTensor):
    """A QTensor holding 8-bit quantized data and its quantization scale."""

    def __init__(self, qtype, axis, size, stride, data, scale, requires_grad=False):
        super().__init__(qtype, axis)
        self._data = data
        self._scale = scale

    def __repr__(self):
        # Fixed: previously interpolated `self.__class__` (rendering as
        # "<class '...'>"); use the class name for consistency with QBitsTensor.
        return f"{type(self).__name__}({self._data}, scale={self._scale}, dtype={self.dtype})"

    def dequantize(self):
        """Differentiable dequantization function"""
        return QBytesDequantizer.apply(self)
def qfallback(callable, *args, **kwargs):
    """Fallback implementation for operations with QTensor inputs.

    Whenever a torch function or aten operation has no quantized
    implementation for the given QTensor arguments, every QTensor found in
    `args` / `kwargs` is dequantized to a plain `torch.Tensor` before the
    target callable is invoked.
    """
    dq_args, dq_kwargs = pytree.tree_map_only(QTensor, lambda qt: qt.dequantize(), (args, kwargs or {}))
    return callable(*dq_args, **dq_kwargs)
    def equal(self, other):
        """Deep equality between two quantized tensors.

        Two QTensor are equal when they have the same concrete type, identical
        flattening metadata, and element-wise equal inner tensors.
        """
        if type(self) is not type(other):
            return False
        self_tensors, self_meta = self.__tensor_flatten__()
        _, other_meta = other.__tensor_flatten__()
        # Metadata entries are serialized strings: compare them verbatim
        for name, value in self_meta.items():
            if other_meta[name] != value:
                return False
        for name in self_tensors:
            self_t = getattr(self, name)
            other_t = getattr(other, name)
            if self_t.device.type == "cpu" and self_t.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
                # torch.equal is not implemented on CPU for float8 types:
                # compare dtypes explicitly, then compare values upcast to float32
                if self_t.dtype != other_t.dtype:
                    return False
                if not torch.equal(self_t.to(torch.float32), other_t.to(torch.float32)):
                    return False
            elif not torch.equal(self_t, other_t):
                return False
        return True
from dataclasses import dataclass import torch @dataclass class qtype: """A quantized type class mimicking torch dtype""" name: str is_floating_point: bool bits: int # This defines the storage dtype dtype: torch.dtype qmin: float qmax: float def __str__(self): return f"quanto.{self.name}" def __hash__(self): return hash(str(self)) # Integer qtypes def qint(bits): qmin = -(2 ** (bits - 1)) qmax = 2 ** (bits - 1) - 1 return qtype(f"qint{bits}", is_floating_point=False, bits=bits, dtype=torch.int8, qmin=qmin, qmax=qmax) qint2 = qint(2) qint4 = qint(4) qint8 = qint(8) # Float qtypes def qfloat(dtype: torch.dtype): finfo = torch.finfo(dtype) qmin = finfo.min qmax = finfo.max return qtype(f"q{finfo.dtype}", is_floating_point=True, bits=8, dtype=dtype, qmin=qmin, qmax=qmax) qfloat8_e4m3fn = qfloat(torch.float8_e4m3fn) qfloat8_e4m3fnuz = qfloat(torch.float8_e4m3fnuz) qfloat8_e5m2 = qfloat(torch.float8_e5m2) # Alias the float8 representation that has the better support and inference efficiency qfloat8 = qfloat8_e4m3fn # Convenience dict to get a dtype from its name qtypes = {name: q for (name, q) in locals().items() if isinstance(q, qtype)} __all__ = ["qtype", "qtypes"] + [str(name) for name in qtypes.keys()] ================================================ FILE: optimum/quanto/tensor/weights/__init__.py ================================================ from .qbits import * from .qbytes import * from .quantization import * ================================================ FILE: optimum/quanto/tensor/weights/awq/__init__.py ================================================ from .packed import * from .qbits import * ================================================ FILE: optimum/quanto/tensor/weights/awq/packed.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# Column permutation applied within each group of 8 values for fast dequantization
AWQ_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
# Inverse of AWQ_ORDER, used to restore the natural column order
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]


def pack(unpacked: torch.Tensor, reorder=False):
    """Pack uint4 weights into an int32 tensor as expected by the AWQ mixed mm kernel.

    Eight 4-bit values are fused little-endian into each int32 element. On top
    of this standard packing, an optional permutation of the columns (AWQ_ORDER)
    speeds up dequantization, as explained in "Who Says Elephants Can't Run:
    Bringing Large Scale MoE Models into Cloud Scale Production",
    https://arxiv.org/pdf/2211.10017.

    Args:
        unpacked (`torch.Tensor`): The un-packed `torch.uint8` tensor
        reorder (`bool`): Whether columns should be reordered or not before packing.

    Returns:
        A int32 `torch.Tensor`.
    """
    bits = 4
    pack_num = 32 // bits
    # Source column (within each group of 8) for each packed nibble position
    order_map = AWQ_ORDER if reorder else list(range(pack_num))
    n_cols = unpacked.shape[1] // pack_num
    packed = torch.zeros(unpacked.shape[0], n_cols, dtype=torch.int32, device=unpacked.device)
    for col in range(n_cols):
        for pos, src in enumerate(order_map):
            column = unpacked[:, col * pack_num + src].to(torch.int32)
            packed[:, col] |= column << (pos * bits)
    return packed


def reverse_awq_order(t: torch.Tensor):
    """Undo the AWQ fast-dequantization column permutation on an unpacked tensor."""
    bits = 4
    group = 32 // bits
    # Build the inverse permutation group by group, then gather the columns
    inverse = torch.arange(t.shape[-1], dtype=torch.int32, device=t.device)
    inverse = inverse.reshape(-1, group)[:, AWQ_REVERSE_ORDER].reshape(-1)
    return t[:, inverse]


def unpack(packed: torch.Tensor, reorder=False):
    """Unpack a packed int32 tensor to a larger uint8 tensor.

    Applies pack operations in reverse order (see pack method for details).

    Args:
        packed (`torch.Tensor`): The packed `torch.int32` tensor
        reorder (`bool`): Whether columns should be reordered or not.

    Returns:
        An unpacked uint8 `torch.Tensor` expanded along the second dimension.
    """
    result = unpack_int32_to_uint8(packed, bits=4)
    return reverse_awq_order(result) if reorder else result
""" assert unpacked.device.type in ["cuda", "xpu"] assert unpacked.ndim == 2 N, K = unpacked.shape # These two values are hard-coded in the optimized kernels: # - I represents the 'interleave', i.e. the number of values packed at a single coordinate (16 bits / 4 bits), # - S represents the 'kernel stride', and is related to the group size (TBC). I = 4 S = 64 # 1. For faster dequantization, the tensor rows must be permuted as explained in: # https://github.com/NVIDIA/TensorRT-LLM/blob/035b99e0d09d4f2dfdb949810cf7245112aa4165/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp#L161 # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ...] => [0, 1, 8, 9, 16, 17, 24, 25, ...] packed = unpacked.reshape(N, K // 32, 4, 4, 2).permute(0, 1, 3, 2, 4) # Reorder each 8 weights for fast dequantization # From: "Who Says Elephants Can’t Run: Bringing Large Scale MoE Models into Cloud Scale Production" # https://arxiv.org/pdf/2211.10017 # [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7] packed = packed.permute(0, 1, 2, 4, 3) packed = packed.reshape(N, K) # 2. 
For efficient parallelization, the rows are grouped and interleaved by blocks of kstride into a single row, as explained in: # https://github.com/NVIDIA/TensorRT-LLM/blob/d37b507f41a87457fe9f10f7459d08f5db235745/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h#L69 # interleaving (N, K) -> (N // I, I, K // S, S) packed = packed.reshape(N // I, I, K // S, S) # transpose (N // I, I, K // S, S) -> (N // I, K // S, I, S) packed = packed.permute(0, 2, 1, 3) # reshape (N // I, K // S, I, S) -> (N // I, K // S, S, I) packed = packed.reshape(N // I, K // S, S, I) # Packing (N // I, K // S, S, I) -> (N // I, K // S, S) packed = packed.to(torch.int32) packed = packed[..., 0] | (packed[..., 1] << 4) | (packed[..., 2] << 8) | (packed[..., 3] << 12) # Reshape to (N // I, K // S, S) -> (N // I, K) packed = packed.reshape(N // I, K) return packed.to(torch.int16).contiguous() def unpack_v2(packed): """Unpack a packed int16 tensor to a larger uint8 tensor Applies pack operations in reverse order (see pack_v2 method for details). Warning: very slow, to be used for debug only. Args: packed (`torch.Tensor`): The packed `torch.int16` tensor Returns: An unpacked uint8 `torch.Tensor` expanded along the first dimension. 
""" assert packed.device.type in ["cuda", "xpu"] assert packed.ndim == 2 I = 4 S = 64 N_div_I, K = packed.shape N = N_div_I * I # Reshape (N // I, K) -> (N // I, K // S, S, 1) unpacked = packed.reshape(N // I, K // S, S, 1) # Convert to uint16 (through numpy because not supported by pytorch) unpacked = unpacked.cpu().numpy().astype(np.uint16) # Unpack (N // I, K, S) -> (N // I, K // S, S, I) unpacked = torch.cat( [ torch.tensor((unpacked & 0xF).astype(np.uint8)).to(packed.device), torch.tensor(((unpacked & 0xF0) >> 4).astype(np.uint8)).to(packed.device), torch.tensor(((unpacked & 0xF00) >> 8).astype(np.uint8)).to(packed.device), torch.tensor(((unpacked & 0xF000) >> 12).astype(np.uint8)).to(packed.device), ], axis=-1, ) # reshape (N // I, K // S, S, I) -> (N // I, K // S, I, S) unpacked = unpacked.reshape(N // I, K // S, I, S) # transpose (N // I, K // S, I, S) -> (N // I, I, K // S, S) unpacked = unpacked.permute(0, 2, 1, 3) # deinterleaving (N // I, I, K // S, S) -> (N, K) unpacked = unpacked.reshape(N, K) # Final steps to reorder (see packing code for explaination) unpacked = unpacked.reshape(N, K // 32, 4, 2, 4).permute(0, 1, 2, 4, 3) unpacked = unpacked.permute(0, 1, 3, 2, 4) unpacked = unpacked.reshape(N, K) return unpacked class AWQPacking(Enum): V1 = 1 V2 = 2 class AWQPackedTensor(torch.Tensor): @staticmethod def __new__(cls, data, packing, reorder, size, stride, requires_grad=False): # AWQPackedTensor represents uint8 data and can therefore NEVER require gradient assert data.device.type in ["cuda", "xpu"] assert data.dtype == torch.int32 if packing == AWQPacking.V1 else torch.int16 assert requires_grad is False return torch.Tensor._make_wrapper_subclass( cls, size, strides=stride, dtype=torch.uint8, device=data.device, requires_grad=requires_grad ) def __init__(self, data, packing, reorder, size, stride, requires_grad=False): self._data = data self._packing = packing self._reorder = reorder def __repr__(self): return f"AWQPackedTensor({self._data}, 
class AWQPackedTensor(torch.Tensor):
    """Tensor subclass holding uint4 weights packed in one of the AWQ kernel layouts.

    The wrapper advertises the logical uint8 size/stride while internally
    holding the packed int32 (V1) or int16 (V2) buffer.
    """

    @staticmethod
    def __new__(cls, data, packing, reorder, size, stride, requires_grad=False):
        # AWQPackedTensor represents uint8 data and can therefore NEVER require gradient
        assert data.device.type in ["cuda", "xpu"]
        # Fixed: this previously parsed as
        # `(data.dtype == torch.int32) if packing == AWQPacking.V1 else torch.int16`,
        # so for V2 packing it asserted on the (always truthy) torch.int16 dtype
        # object and never validated the data dtype.
        assert data.dtype == (torch.int32 if packing == AWQPacking.V1 else torch.int16)
        assert requires_grad is False
        return torch.Tensor._make_wrapper_subclass(
            cls, size, strides=stride, dtype=torch.uint8, device=data.device, requires_grad=requires_grad
        )

    def __init__(self, data, packing, reorder, size, stride, requires_grad=False):
        self._data = data
        self._packing = packing
        self._reorder = reorder

    def __repr__(self):
        return f"AWQPackedTensor({self._data}, packing={self._packing}, reorder={self._reorder})"

    @classmethod
    def pack(cls, t, packing=AWQPacking.V1, reorder=False):
        """Pack an unpacked uint8 tensor using the requested AWQ layout."""
        if packing == AWQPacking.V1:
            data = pack(t, reorder=reorder)
        else:
            data = pack_v2(t)
        # We need to store size and stride to make sure the unpacked data has the correct shape
        return AWQPackedTensor(data, packing, reorder, t.size(), t.stride())

    def unpack(self):
        """Return the unpacked uint8 tensor (slow path, see unpack/unpack_v2)."""
        if self._packing == AWQPacking.V1:
            return unpack(self._data, self._reorder)
        return unpack_v2(self._data)

    @property
    def dtype(self):
        return torch.uint8

    def __tensor_flatten__(self):
        inner_tensors = ["_data"]
        # Since meta can be used for serialization, use only AST compatible strings.
        # Fixed: the packing member was previously serialized as str(AWQPacking.V1)
        # (i.e. "AWQPacking.V1"), which ast.literal_eval cannot parse in
        # __tensor_unflatten__; serialize the integer enum value instead.
        meta = {
            "packing": str(self._packing.value),
            "reorder": str(self._reorder),
            "size": str(list(self.size())),
            "stride": str(self.stride()),
        }
        return inner_tensors, meta

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        assert len(inner_tensors) == 1
        assert len(meta) == 4
        data = inner_tensors["_data"]
        # Meta should contain only AST compatible strings
        packing = AWQPacking(ast.literal_eval(meta["packing"]))
        reorder = ast.literal_eval(meta["reorder"])
        size = ast.literal_eval(meta["size"])
        stride = ast.literal_eval(meta["stride"])
        return AWQPackedTensor(data, packing, reorder, size, stride)

    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, op, types, args, kwargs=None):
        # Convert back to tensor before calling any operation except detach and move
        if op.overloadpacket is torch.ops.aten.detach:
            t = args[0]
            data = op(t._data)
            return AWQPackedTensor(data, t._packing, t._reorder, t.size(), t.stride())
        elif op.overloadpacket in (torch.ops.aten._to_copy, torch.ops.aten.to):
            t = args[0]
            dtype = kwargs.get("dtype", torch.uint8)
            if dtype != torch.uint8:
                raise ValueError(f"AWQPackedTensor are torch.uint8 only and cannot be moved to {dtype}.")
            device = kwargs.get("device", t.device)
            # AWQPackedTensor can only be moved to CUDA devices.
            # NOTE(review): a move to any other device type (including xpu) falls
            # through to the unpack path below and returns a plain tensor — confirm
            # this is the intended behavior for xpu-resident data.
            if device.type == "cuda":
                data_kwargs = copy(kwargs)
                data_kwargs["dtype"] = t._data.dtype
                data = op(t._data, **data_kwargs)
                return AWQPackedTensor(data, t._packing, t._reorder, t.size(), t.stride())
        args, kwargs = pytree.tree_map_only(AWQPackedTensor, lambda x: x.unpack(), (args, kwargs or {}))
        return op(*args, **kwargs)

    def numpy(self):
        return self.unpack().cpu().numpy()
class AWQWeightQBitsDequantizer(Function):
    """Autograd function dequantizing an AWQWeightQBitsTensor to full precision.

    Reverses the AWQ-specific formatting of data, scale and shift before
    applying the affine dequantization. Backward is a straight-through no-op.
    """

    @staticmethod
    def forward(ctx, t):
        unpacked = t._data.unpack()
        scale = t._scale
        shift = t._shift
        # Restore the grouped layout used by the generic WeightQBitsTensor
        unpacked = group(unpacked, axis=0, group_size=t._group_size)
        n_scales = scale.numel()
        scale = scale.t().reshape((n_scales, 1))
        shift = shift.t().reshape((n_scales, 1))
        if shift.dtype.is_floating_point:
            # Shift is already scaled and negated on CUDA
            dqt = scale * unpacked + shift
        else:
            # Shift is int type on XPU to support pytorch fused op
            dqt = (unpacked - shift) * scale
        return ungroup(dqt, axis=t.axis, orig_shape=t.shape)

    @staticmethod
    def backward(ctx, gO):
        # Dequantization is a no-op for autograd: pass the gradient through
        return gO


class AWQWeightQBitsLinearFunction(QuantizedLinearFunction):
    """Quantized linear function calling the dedicated AWQ 4-bit gemm kernel."""

    @staticmethod
    def forward(ctx, input, other, bias):
        ctx.save_for_backward(input, other)
        # The optimized kernel expects a full-precision activation
        if type(input) is not torch.Tensor:
            input = input.dequantize()
        out_features, in_features = other.shape
        # Collapse all leading activation dimensions into rows
        rows = input.numel() // in_features
        output = torch.ops.quanto.gemm_f16i4_awq(
            input,
            other._data._data,
            other._scale,
            other._shift,
            rows=rows,
            out_cols=out_features,
            in_cols=in_features,
            bits=4,
            group_size=other._group_size,
        )
        if bias is not None:
            output = output + bias
        return output
class AWQWeightQBitsTensor(WeightQBitsTensor):
    """Group-wise 4-bit quantized weights stored in the AWQ optimized layout.

    Data, scale and shift are reformatted on construction so they can be fed
    directly to the optimized CUDA/XPU gemm kernels.
    """

    @staticmethod
    def __new__(cls, qtype, axis, group_size, size, stride, data, scale, shift, requires_grad=False):
        assert data.device.type in ["cuda", "xpu"]
        assert data.device == scale.device
        assert data.device == shift.device
        return torch.Tensor._make_wrapper_subclass(
            cls, size, strides=stride, dtype=scale.dtype, device=data.device, requires_grad=requires_grad
        )

    def __init__(self, qtype, axis, group_size, size, stride, data, scale, shift, requires_grad=False):
        # XPU requires awq v1 to support pytorch fused op
        self.packing_type = AWQPacking.V1 if data.device.type == "xpu" else AWQPacking.V2
        assert axis == 0
        if not isinstance(data, AWQPackedTensor):
            assert type(data) is torch.Tensor
            # Format data, scale and shift for optimized CUDA/XPU gemm
            ungrouped = ungroup(data, axis=0, orig_shape=size)
            data = AWQPackedTensor.pack(ungrouped, packing=self.packing_type)
            out_features, in_features = size
            # Kernels expect transposed (in_features // group_size, out_features) parameters
            scale = scale.reshape(out_features, in_features // group_size).t().contiguous()
            shift = shift.reshape(out_features, in_features // group_size).t()
            if not shift.dtype.is_floating_point and data.device.type != "xpu":
                # Integer shift must be scaled
                shift = scale * shift
            # Shift must be negated
            shift = shift.contiguous() if data.device.type == "xpu" else -shift.contiguous()
        super().__init__(qtype, axis, group_size, size, stride, data, scale, shift)

    def dequantize(self):
        return AWQWeightQBitsDequantizer.apply(self)

    def weight_qbits_tensor(self):
        """Convert back to a WeightQBitsTensor

        This is required to make sure only standard packing is used when serializing.
        """
        data = group(self._data.unpack(), axis=self.axis, group_size=self._group_size)
        n_scales = self._scale.numel()
        scale = self._scale.t().reshape((n_scales, 1))
        # Undo the shift negation applied at construction on non-XPU devices
        shift = self._shift if self._shift.device.type == "xpu" else -self._shift
        shift = shift.t().reshape((n_scales, 1))
        return WeightQBitsTensor(
            self._qtype, self._axis, self._group_size, self.size(), self.stride(), data, scale, shift
        )

    def __tensor_flatten__(self):
        inner_tensors = ["_data", "_scale", "_shift"]
        # Since meta can be used for serialization, use only strings
        meta = {
            "qtype": self._qtype.name,
            "axis": str(self._axis),
            "group_size": str(self._group_size),
            "size": str(list(self.size())),
            "stride": str(list(self.stride())),
        }
        return inner_tensors, meta

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        assert len(inner_tensors) == 3
        assert len(meta) == 5
        data, scale, shift = inner_tensors["_data"], inner_tensors["_scale"], inner_tensors["_shift"]
        # Meta should only contain strings, AST compatible except qtype
        qtype = qtypes[meta["qtype"]]
        axis = ast.literal_eval(meta["axis"])
        group_size = ast.literal_eval(meta["group_size"])
        size = ast.literal_eval(meta["size"])
        stride = ast.literal_eval(meta["stride"])
        return AWQWeightQBitsTensor(qtype, axis, group_size, size, stride, data, scale, shift)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Dispatch torch functions applied on this subtensor

        This method is called whenever a torch function (such as `torch.nn.functional.linear`)
        is called with at least one parameter coresponding to this subtensor:

        - if a quantized implementation exists for the selected function, it is called,
        - otherwise, the original implementation is called, deactivating further functional dispatch.

        During the execution of the standard torch function, a second-level of dispatch will
        happen, but this time directly on individual torch Tensor operations (mainly ATEN).
        """
        kwargs = kwargs or {}
        if func is torch.nn.functional.linear:

            def qlinear(input, other, bias=None):
                return AWQWeightQBitsLinearFunction.apply(input, other, bias)

            return qlinear(*args, **kwargs)
        # Defer to operations dispatcher
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)
""" kwargs = kwargs or {} if func is torch.nn.functional.linear: def qlinear(input, other, bias=None): return AWQWeightQBitsLinearFunction.apply(input, other, bias) return qlinear(*args, **kwargs) # Defer to operations dispatcher with torch._C.DisableTorchFunctionSubclass(): return func(*args, **kwargs) ================================================ FILE: optimum/quanto/tensor/weights/marlin/__init__.py ================================================ from .fp8 import * from .int4 import * from .permutations import * ================================================ FILE: optimum/quanto/tensor/weights/marlin/fp8/__init__.py ================================================ from .packed import * from .qbits import * ================================================ FILE: optimum/quanto/tensor/weights/marlin/fp8/packed.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import ast from copy import copy import torch from torch.utils import _pytree as pytree def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: """ Repack FP8 weights to gptq format (packed int32 elements). 
""" assert fp8_tensor.dtype == torch.float8_e4m3fn if fp8_tensor.shape[0] % 4 != 0: raise ValueError(f"Leading tensor dimension is not divisable by 4: {fp8_tensor.shape[0]}") # Reshape to prepare for packing reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) # Convert fp8 to uint8 (byte) representation byte_tensor = reshaped.view(torch.uint8) # Pack 4 uint8 values into one int32 packed = torch.zeros( fp8_tensor.shape[0] // 4, fp8_tensor.shape[1], dtype=torch.int32, device=fp8_tensor.device, ) for i in range(4): packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8) return packed def unpack_int32_to_fp8(int32_tensor: torch.Tensor) -> torch.Tensor: """ Reinterpret a tensor (a, b) of type int32 to a tensor (a * 4, b) of type float8_e4m3fn. """ bits = 8 unpacked = [] # Unpack each set of values independently for i in range(4): mask = 2 ** (bits * (i + 1)) - 1 tmp = (int32_tensor & mask) >> bits * i tmp = tmp.to(torch.uint8) unpacked.append(tmp) # Return the concatenated unpacked tensors unpacked = torch.cat(unpacked).view(torch.float8_e4m3fn) return unpacked def get_scale_perms() -> torch.Tensor: scale_perm_single = [] for i in range(4): scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) return torch.tensor(scale_perm_single, dtype=torch.int64) def get_row_permutation(n_rows: int) -> torch.Tensor: """ Generates a tensor of shape (4 * n_rows,) giving the rows mapping to map from marlin-repacked weights to natural order. Example: if n_rows = 8, the row mapping from natural to marlin format is rows_idx = [0, 2, 4, 6, 16, 18, 20, 22, 8, 10, 12, 14, 24, 26, 28, 30, 1, 3, 5, 7, 17, 19, 21, 23, 9, 11, 13, 15, 25, 27, 29, 31]. """ modulo = n_rows // 4 * 16 - 8 b = n_rows // 2 # Group by 16*k, then by 8 + 16*k rows_idx = [(i * 16) % modulo for i in range(b)] rows_idx[-1] = rows_idx[-2] + 16 if b > 2 else 8 rows_idx = torch.tensor(rows_idx) # All even indexes, and then all odd indexes. 
def get_row_permutation(n_rows: int) -> torch.Tensor:
    """
    Gets a tensor of shape (4 * n_rows,) mapping marlin-repacked weight rows back
    to natural order.

    Example: if n_rows = 8, the row mapping from natural to marlin format is
    rows_idx = [0, 2, 4, 6, 16, 18, 20, 22, 8, 10, 12, 14, 24, 26, 28, 30,
                1, 3, 5, 7, 17, 19, 21, 23, 9, 11, 13, 15, 25, 27, 29, 31],
    and the returned tensor is its inverse.
    """
    wrap = n_rows // 4 * 16 - 8
    half = n_rows // 2
    # Even bases: groups at 16*k, then 8 + 16*k, wrapping modulo `wrap`.
    bases = [(i * 16) % wrap for i in range(half)]
    bases[-1] = 8 if half <= 2 else bases[-2] + 16
    even_rows = torch.tensor(bases)
    # All even indexes, followed by all odd indexes.
    interleaved = torch.cat((even_rows, even_rows + 1))
    # Each base expands to a group of four rows spaced by 2.
    natural_to_marlin = (interleaved[:, None] + torch.tensor([[0, 2, 4, 6]])).reshape(-1)
    # `natural_to_marlin` maps natural rows to marlin rows; invert it.
    marlin_to_natural = torch.empty_like(natural_to_marlin)
    marlin_to_natural[natural_to_marlin] = torch.arange(natural_to_marlin.numel())
    return marlin_to_natural


def get_column_permutation(n_col: int) -> torch.Tensor:
    """
    Gets the column mapping from marlin-repacked weights back to natural order.

    The natural-to-marlin mapping is `8 * rest + frac` -> `rest + 32 * frac`,
    applied independently within each block of 256 columns.
    """
    tile_size = 256
    n_blocks = n_col // tile_size
    lane = torch.arange(tile_size)
    # Within one 256-wide tile: column 8*rest + frac lives at 32*rest + frac.
    tile_index = 32 * (lane % 8) + lane // 8
    full_index = (torch.arange(n_blocks)[:, None] * 256 + tile_index).reshape(-1)
    # The per-column layout interleaves the blocks in runs of 64:
    #
    #   64     64     64     64    64     64   ...
    #  ---------------------------------------------
    #  | 0      1      2      3  |  0      1   ...
    #  ---------------------------------------------
    #
    # so shuffle the 64-wide runs to bring each block's columns back together.
    full_index = full_index.reshape(4 * n_blocks, 64)
    # e.g. for 3 blocks: [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11]
    run_order = torch.arange(4).repeat_interleave(n_blocks) + (torch.arange(n_blocks) * 4).repeat(4)
    return full_index[run_order].reshape(-1)
class MarlinF8PackedTensor(torch.Tensor):
    # Wrapper tensor subclass holding fp8 weights packed into int32 in the
    # marlin kernel order. The wrapper advertises dtype int32 and the caller's
    # (unpacked) size/stride, while `_data` holds the actual packed storage.

    def __new__(cls, data, size, stride, requires_grad=False):
        # Packed weights only make sense for the CUDA marlin kernels and are
        # never trainable.
        assert data.device.type == "cuda"
        assert data.dtype == torch.int32
        assert requires_grad is False
        return torch.Tensor._make_wrapper_subclass(
            cls, size, strides=stride, dtype=torch.int32, device=data.device, requires_grad=requires_grad
        )

    def __init__(self, data, size, stride, requires_grad=False):
        # The packed int32 payload produced by `pack`.
        self._data = data

    def __repr__(self):
        return f"MarlinF8PackedTensor({self._data})"

    @classmethod
    def pack(cls, tensor: torch.Tensor):
        # Pack an fp8 weight tensor (out_features, in_features) into the marlin
        # int32 layout using the quanto CUDA extension op.
        out_features, in_features = tensor.shape
        data_int32 = pack_fp8_as_int32(tensor.T)  # pack fp8 data to int32.
        # No activation permutation is used: pass an empty perm tensor.
        perm = torch.empty(0, dtype=torch.int, device=tensor.device)
        data_int32 = torch.ops.quanto.pack_fp8_marlin(
            b_q_weight=data_int32, perm=perm, size_k=in_features, size_n=out_features, num_bits=8
        )
        return cls(data_int32, size=tensor.size(), stride=tensor.stride())

    def unpack(self) -> torch.Tensor:
        """
        Reinterprets the packed tensor (a, b) of type int32 and in the marlin order, to a tensor (a * 4, b) of
        type float8_e4m3fn, in the natural order.
        """
        float8_data = unpack_int32_to_fp8(self._data)
        # complex indexing is not implemented for 'Float8_e4m3fn', so do the
        # permutations on a byte (uint8) view instead.
        uint8_data = float8_data.view(torch.uint8)
        n_rows, n_col = uint8_data.shape
        # swap columns back to natural order
        column_map = get_column_permutation(n_col=n_col)
        uint8_data = uint8_data.T.contiguous()
        uint8_data = uint8_data[column_map]
        uint8_data = uint8_data.T.contiguous()
        uint8_data = uint8_data.reshape(uint8_data.shape[0] * 4, -1)
        # swap rows back to natural order
        row_map = get_row_permutation(n_rows=n_rows)
        uint8_data = uint8_data[row_map]
        float8_data = uint8_data.view(torch.float8_e4m3fn)
        float8_data = float8_data.T  # As we originally transposed in `pack_fp8_as_int32`
        return float8_data

    @property
    def dtype(self):
        return torch.int32

    def __tensor_flatten__(self):
        # Serialization support: one inner tensor plus string-only metadata.
        inner_tensors = ["_data"]
        # Since meta can be used for serialization, use only AST compatible strings
        meta = {
            "size": str(list(self.size())),
            "stride": str(self.stride()),
        }
        return inner_tensors, meta

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        # Inverse of `__tensor_flatten__`.
        assert len(inner_tensors) == 1
        assert len(meta) == 2
        data = inner_tensors["_data"]
        # Meta should contain only AST compatible strings
        size = ast.literal_eval(meta["size"])
        stride = ast.literal_eval(meta["stride"])
        return MarlinF8PackedTensor(data, size, stride)

    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, op, types, args, kwargs=None):
        # Convert back to tensor before calling any operation except detach and move
        if op.overloadpacket is torch.ops.aten.detach:
            t = args[0]
            data = op(t._data)
            return cls(data, t.size(), t.stride())
        elif op.overloadpacket in (torch.ops.aten._to_copy, torch.ops.aten.to):
            t = args[0]
            # Only device moves are supported: the packed payload is int32-only.
            dtype = kwargs.get("dtype", torch.int32)
            if dtype != torch.int32:
                raise ValueError(f"MarlinF8PackedTensor are torch.int32 only and cannot be moved to {dtype}.")
            device = kwargs.get("device", t.device)
            if device.type == "cuda":
                # Move the packed payload (keep its own dtype) and re-wrap.
                data_kwargs = copy(kwargs)
                data_kwargs["dtype"] = t._data.dtype
                data = op(t._data, **data_kwargs)
                return cls(data, t.size(), t.stride())
            else:
                # The packed layout is CUDA-only: unpack before moving away.
                return t.unpack().to(device)
        else:
            # Fallback: unpack all MarlinF8PackedTensor arguments and run the op
            # on plain tensors.
            args, kwargs = pytree.tree_map_only(cls, lambda x: x.unpack(), (args, kwargs or {}))
            return op(*args, **kwargs)
    @staticmethod
    def __new__(cls, qtype, axis, size, stride, data, scale, requires_grad=False):
        # Marlin fp8 weights only exist on CUDA; scale lives on the same device.
        assert data.device.type == "cuda"
        assert data.device == scale.device
        return torch.Tensor._make_wrapper_subclass(
            cls, size, strides=stride, dtype=scale.dtype, device=data.device, requires_grad=requires_grad
        )

    def __init__(self, qtype, axis, size, stride, data, scale, requires_grad=False):
        # Per-output-channel quantization only (axis 0), 2D weights.
        assert axis == 0
        assert data.ndim == 2
        out_features = size[0]
        # Scratch buffer required by the marlin GEMM kernel.
        self._workspace = torch.zeros(out_features // 64 * 16, dtype=torch.int, device=data.device)

        # TODO: Here we should use `not isinstance(data, MarlinF8PackedTensor)`, but `torch.compile` is bugged when using that.
        # Somewhere in the internals of torch.compile, `data` gets converted to a `torch._subclasses.fake_tensor.FakeTensor`
        # not inheriting from `MarlinF8PackedTensor` and torch then goes into the wrong controlflow.
        # Reference: https://pytorch.slack.com/archives/C033H6DJSJU/p1721837684035049
        if data.dtype != torch.int32:
            # Fresh fp8 data: repack it and permute the scales to marlin order.
            assert scale.shape == (out_features, 1)
            scale_perm_single = get_scale_perms()
            scale = scale.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
            scale = scale.reshape(-1, out_features).contiguous()
            data_packed = MarlinF8PackedTensor.pack(data)  # pack fp8 data to int32, and apply marlin re-ordering.
        else:
            # When freezing (`model.freeze()`), the data is already a MarlinF8PackedTensor and scale is already repacked.
            data_packed = data
        super().__init__(
            qtype, axis, size, stride, data_packed, scale, activation_qtype=qfloat8_e4m3fn, requires_grad=requires_grad
        )

    def dequantize(self):
        # Unpack the fp8 data, undo the marlin scale permutation, then rescale.
        float8_data = self._data.unpack()
        scale_perm_single = get_scale_perms()
        # `scale_perm_single` holds the mapping of natural to marlin, so inverse it here.
        scale_perm_single_rev = torch.empty_like(scale_perm_single)
        scale_perm_single_rev[scale_perm_single] = torch.arange(len(scale_perm_single))
        scale_reordered = self._scale.reshape((-1, len(scale_perm_single_rev)))[:, scale_perm_single_rev]
        scale_reordered = scale_reordered.reshape(-1, self._scale.shape[1]).contiguous()
        return float8_data.to(scale_reordered.dtype) * scale_reordered.T

    def __repr__(self):
        return f"MarlinF8QBytesTensor({self._data}, scale={self._scale}, dtype={self.dtype})"

    def weight_qbytes_tensor(self):
        # Convert back to a plain WeightQBytesTensor (natural data/scale order),
        # e.g. so that only standard packing is used when serializing.
        data = self._data.unpack()
        scale_perm_single = get_scale_perms()
        # `scale_perm_single` holds the mapping of natural to marlin, so inverse it here.
        scale_perm_single_rev = torch.empty_like(scale_perm_single)
        scale_perm_single_rev[scale_perm_single] = torch.arange(len(scale_perm_single))
        scale_reordered = self._scale.reshape((-1, len(scale_perm_single_rev)))[:, scale_perm_single_rev]
        scale_reordered = scale_reordered.reshape(-1, self._scale.shape[1]).t().contiguous()
        return WeightQBytesTensor(
            self._qtype, self._axis, self.size(), self.stride(), data, scale_reordered, self.activation_qtype
        )

    def __tensor_flatten__(self):
        # Serialization support: inner tensors plus string-only metadata.
        inner_tensors = ["_data", "_scale"]
        meta = {
            "qtype": self._qtype.name,
            "axis": str(self._axis),
            "size": str(list(self.size())),
            "stride": str(list(self.stride())),
        }
        return inner_tensors, meta

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        # Inverse of `__tensor_flatten__`.
        assert len(inner_tensors) == 2
        assert len(meta) == 4
        data, scale = inner_tensors["_data"], inner_tensors["_scale"]
        # Meta should only contain strings, AST compatible except qtype
        qtype = qtypes[meta["qtype"]]
        axis = ast.literal_eval(meta["axis"])
        size = ast.literal_eval(meta["size"])
        stride = ast.literal_eval(meta["stride"])
        return MarlinF8QBytesTensor(qtype, axis, size, stride, data, scale)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Dispatch torch functions applied on this subtensor

        This method is called whenever a torch function (such as `torch.nn.functional.linear`)
        is called with at least one parameter corresponding to this subtensor:

        - if a quantized implementation exists for the selected function, it is called,
        - otherwise, the original implementation is called, deactivating further functional dispatch.

        During the execution of the standard torch function, a second-level of dispatch will
        happen, but this time directly on individual torch Tensor operations (mainly ATEN).
        """
        kwargs = kwargs or {}
        if func is torch.nn.functional.linear:

            # Route linear through the marlin fp8 GEMM kernel.
            def qlinear(input, other, bias=None):
                return MarlinF8QBytesLinearFunction.apply(input, other, bias)

            return qlinear(*args, **kwargs)
        elif func is torch.equal:
            input, other = args
            return input.equal(other)
        # Defer to operations dispatcher
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)
# From: https://github.com/IST-DASLab/marlin/blob/master/marlin/__init__.py#L40
# this func does 2 things
# 1. 1 thread can load 32 4bit == 128bit weights used for multiple mma instructions at once
# 2. faster dequant via parallel half2 mul
def _get_perm():
    """Build the 1024-element marlin weight permutation as an int64 tensor."""
    perm = []
    # 32 == # of threads in 1 warp
    for i in range(32):
        perm1 = []
        # column id in 16x8 weight block
        # check https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
        col = i // 4
        # 1 32bit (int32) == 8 4bit, 1 thread has 4 weights per 16x8 & 4bit weights are packed in int32,
        # so needs 2 16x8 == 1 16x16 blocks
        for block in [0, 1]:
            # row id in 16x8 weight block
            # check https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
            for row in [
                2 * (i % 4),
                2 * (i % 4) + 1,
                2 * (i % 4 + 4),
                2 * (i % 4 + 4) + 1,
            ]:
                # 8 weights used for 1 thread (16x16 block) are contiguous in memory via interleaving
                # e.g. T0 uses (0, 16, 128, 144, 8, 24, 136, 152)
                perm1.append(16 * row + col + 8 * block)
        # 1 128bit (int4) == 4 32bit, 1 thread loads 128bit at once, so needs 4 16x16 == 1 16x64 blocks
        for j in range(4):
            # 32 weights loaded by 1 thread (16x64 block) are contiguous in memory via interleaving
            # e.g. T0 uses ((0 ~ 152) + 0 * 256, (0 ~ 152) + 1 * 256, ..., (0 ~ 152) + 3 * 256)
            perm.extend([p + 256 * j for p in perm1])
    perm = np.array(perm)
    # for faster dequant
    # check https://arxiv.org/pdf/2211.10017
    interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
    perm = perm.reshape((-1, 8))[:, interleave].ravel()
    perm = torch.from_numpy(perm)
    return perm


# Module-level cache of the forward permutation and its inverse.
_perm = _get_perm()
_rev_perm = reverse(_perm)


# From: https://github.com/IST-DASLab/marlin/blob/master/marlin/__init__.py#L102
def pack(unpacked: torch.Tensor):
    """Pack a uint8 tensor of int4 values (N, K) into marlin int32 layout.

    Tiles the transposed weights into 16x16 blocks, applies the marlin
    permutation, then packs 8 4-bit values per int32 column-interleaved.
    """
    w = unpacked
    N, K = w.shape
    w = unpacked.t()
    # 16 == tile size, marlin uses 16x16 tile, so 16x16 grouping via interleaving
    w = w.reshape((K // 16, 16, N // 16, 16))
    w = w.permute((0, 2, 1, 3))
    w = w.reshape((K // 16, N * 16))
    res = w
    # _perm.numel() == 1024 == 4 16x16, permute weights with 4 16x16 unit for efficient mma + dequant
    res = res.reshape((-1, _perm.numel()))[:, _perm].reshape(res.shape)
    # Pack 8 4-bit values per uint32 (done in numpy, then converted back).
    p = np.zeros((res.shape[0], res.shape[1] // 8), dtype=np.uint32)
    res = res.cpu().numpy().astype(np.uint32)
    for i in range(8):
        p |= res[:, i::8] << 4 * i
    p = torch.from_numpy(p.astype(np.int32)).to(w.device)
    return p


def unpack(packed, orig_shape):
    """Inverse of `pack`: recover the uint8 tensor of int4 values in natural (N, K) order."""
    N, K = orig_shape
    # Unpack to recover individual values
    unpacked = unpack_int32_to_uint8(packed, bits=4).to(torch.uint8)
    # Recover the original ordering
    unpacked = reorder(unpacked, _rev_perm)
    # Apply block permutations in the reverse order
    unpacked = unpacked.reshape(K // 16, N // 16, 16, 16)
    unpacked = unpacked.permute((0, 2, 1, 3))
    unpacked = unpacked.reshape(K, N)
    return unpacked.t()
return f"MarlinInt4PackedTensor({self._data})" @classmethod def pack(cls, t): data = pack(t) return MarlinInt4PackedTensor(data, t.size(), t.stride()) def unpack(self): return unpack(self._data, self.size()) @property def dtype(self): return torch.uint8 def __tensor_flatten__(self): inner_tensors = ["_data"] meta = { "size": str(list(self.size())), "stride": str(self.stride()), } return inner_tensors, meta @staticmethod def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): assert len(inner_tensors) == 1 assert len(meta) == 2 data = inner_tensors["_data"] size = ast.literal_eval(meta["size"]) stride = ast.literal_eval(meta["stride"]) return MarlinInt4PackedTensor(data, size, stride) __torch_function__ = torch._C._disabled_torch_function_impl @classmethod def __torch_dispatch__(cls, op, types, args, kwargs=None): if op.overloadpacket is torch.ops.aten.detach: t = args[0] data = op(t._data) return MarlinInt4PackedTensor(data, t.size(), t.stride()) elif op.overloadpacket in (torch.ops.aten._to_copy, torch.ops.aten.to): t = args[0] dtype = kwargs.get("dtype", torch.uint8) if dtype != torch.uint8: raise ValueError(f"MarlinInt4PackedTensor are torch.uint8 only and cannot be moved to {dtype}.") device = kwargs.get("device", t.device) if device.type == "cuda": data_kwargs = copy(kwargs) data_kwargs["dtype"] = t._data.dtype data = op(t._data, **data_kwargs) return MarlinInt4PackedTensor(data, t.size(), t.stride()) return t.unpack() args, kwargs = pytree.tree_map_only(MarlinInt4PackedTensor, lambda x: x.unpack(), (args, kwargs or {})) return op(*args, **kwargs) def numpy(self): return self.unpack().cpu().numpy() ================================================ FILE: optimum/quanto/tensor/weights/marlin/int4/qbits.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import ast import torch from torch.autograd import Function from ....function import QuantizedLinearFunction from ....grouped import group, ungroup from ....qtype import qtypes from ...qbits import WeightQBitsTensor from ..permutations import marlin_permute from .packed import MarlinInt4PackedTensor __all__ = ["MarlinInt4WeightQBitsTensor"] class MarlinQBitsDequantizer(Function): @staticmethod def forward(ctx, t): unpacked = t._data.unpack() scale = t._scale shift = t._shift unpacked = group(unpacked, axis=0, group_size=t._group_size) # Apply inverted permutations scale = marlin_permute(scale, reverse=True) shift = marlin_permute(shift, reverse=True) n_scales = scale.numel() scale = scale.t().reshape((n_scales, 1)) shift = shift.t().reshape((n_scales, 1)) # Shift is already scaled and negated dqt = scale * unpacked + shift return ungroup(dqt, axis=t.axis, orig_shape=t.shape) @staticmethod def backward(ctx, gO): return gO class MarlinQBitsLinearFunction(QuantizedLinearFunction): @staticmethod def forward(ctx, input, other, bias): ctx.save_for_backward(input, other) if type(input) is not torch.Tensor: input = input.dequantize() out_features, in_features = other.shape output = torch.ops.quanto.gemm_f16i4_marlin( input, other._data._data, other._scale, other._shift, other._workspace, ) if bias is not None: output = output + bias return output class MarlinInt4WeightQBitsTensor(WeightQBitsTensor): @staticmethod def __new__(cls, qtype, axis, group_size, size, stride, data, scale, shift, requires_grad=False): assert data.device.type == "cuda" assert data.device == 
    @staticmethod
    def __new__(cls, qtype, axis, group_size, size, stride, data, scale, shift, requires_grad=False):
        # Marlin int4 weights only exist on CUDA; scale and shift live on the same device.
        assert data.device.type == "cuda"
        assert data.device == scale.device
        assert data.device == shift.device
        return torch.Tensor._make_wrapper_subclass(
            cls, size, strides=stride, dtype=scale.dtype, device=data.device, requires_grad=requires_grad
        )

    def __init__(self, qtype, axis, group_size, size, stride, data, scale, shift, requires_grad=False):
        # Per-output-channel quantization only (axis 0).
        assert axis == 0
        out_features, in_features = size
        if not isinstance(data, MarlinInt4PackedTensor):
            # Fresh (grouped, unpacked) quantized data: repack it for the kernel.
            assert type(data) is torch.Tensor
            # Format data, scale and shift for optimized CUDA gemm
            ungrouped = ungroup(data, axis=0, orig_shape=size)
            data = MarlinInt4PackedTensor.pack(ungrouped)
            scale = scale.reshape(out_features, in_features // group_size).t().contiguous()
            shift = shift.reshape(out_features, in_features // group_size).t()
            if not shift.dtype.is_floating_point:
                # Integer shift must be scaled
                shift = scale * shift
            # Shift must be negated (the kernel applies scale * q + shift)
            shift = -shift.contiguous()
            # Finally, apply scale and shift permutations to the marlin order
            scale = marlin_permute(scale)
            shift = marlin_permute(shift)
        super().__init__(qtype, axis, group_size, size, stride, data, scale, shift)
        # Scratch buffer required by the marlin GEMM kernel.
        self._workspace = torch.zeros(out_features // 128 * 16, dtype=torch.int, device=data.device)

    def dequantize(self):
        return MarlinQBitsDequantizer.apply(self)

    def weight_qbits_tensor(self):
        """Convert back to a WeightQBitsTensor

        This is required to make sure only standard packing is used when serializing.
        """
        data = group(self._data.unpack(), axis=self.axis, group_size=self._group_size)
        # Undo the marlin scale/shift permutations applied in __init__.
        scale = marlin_permute(self._scale, reverse=True)
        shift = marlin_permute(self._shift, reverse=True)
        n_scales = scale.numel()
        scale = scale.t().reshape((n_scales, 1))
        # Shift was negated at packing time: negate it back.
        shift = -shift.t().reshape((n_scales, 1))
        return WeightQBitsTensor(
            self._qtype, self._axis, self._group_size, self.size(), self.stride(), data, scale, shift
        )

    def __tensor_flatten__(self):
        # Serialization support: inner tensors plus string-only metadata.
        inner_tensors = ["_data", "_scale", "_shift"]
        # Since meta can be used for serialization, use only strings
        meta = {
            "qtype": self._qtype.name,
            "axis": str(self._axis),
            "group_size": str(self._group_size),
            "size": str(list(self.size())),
            "stride": str(list(self.stride())),
        }
        return inner_tensors, meta

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        # Inverse of `__tensor_flatten__`.
        assert len(inner_tensors) == 3
        assert len(meta) == 5
        data, scale, shift = inner_tensors["_data"], inner_tensors["_scale"], inner_tensors["_shift"]
        # Meta should only contain strings, AST compatible except qtype
        qtype = qtypes[meta["qtype"]]
        axis = ast.literal_eval(meta["axis"])
        group_size = ast.literal_eval(meta["group_size"])
        size = ast.literal_eval(meta["size"])
        stride = ast.literal_eval(meta["stride"])
        return MarlinInt4WeightQBitsTensor(qtype, axis, group_size, size, stride, data, scale, shift)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Dispatch torch functions applied on this subtensor

        This method is called whenever a torch function (such as `torch.nn.functional.linear`)
        is called with at least one parameter corresponding to this subtensor:

        - if a quantized implementation exists for the selected function, it is called,
        - otherwise, the original implementation is called, deactivating further functional dispatch.

        During the execution of the standard torch function, a second-level of dispatch will
        happen, but this time directly on individual torch Tensor operations (mainly ATEN).
        """
        kwargs = kwargs or {}
        if func is torch.nn.functional.linear:

            # Route linear through the marlin int4 GEMM kernel.
            def qlinear(input, other, bias=None):
                return MarlinQBitsLinearFunction.apply(input, other, bias)

            return qlinear(*args, **kwargs)
        # Defer to operations dispatcher
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)
import functools from typing import List, Tuple import torch from ..reordering import reorder, reverse __all__ = ["marlin_permute"] # https://github.com/IST-DASLab/marlin/blob/2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c/marlin/__init__.py#L40C1-L68C54 @functools.cache def _get_perms() -> Tuple[List[int], List[int]]: perm = [] for i in range(8): perm.extend([i + 8 * j for j in range(8)]) perm_single = [] for i in range(4): perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) return perm, perm_single @functools.cache def _get_inverted_perms() -> Tuple[List[int], List[int]]: perm, perm_single = _get_perms() return reverse(perm), reverse(perm_single) def marlin_permute(t: torch.Tensor, reverse=False): perm, perm_single = _get_inverted_perms() if reverse else _get_perms() out_features = t.shape[1] if t.shape[0] == 1: reordered = reorder(t, perm_single) else: reordered = reorder(t, perm) return reordered.reshape((-1, out_features)).contiguous() ================================================ FILE: optimum/quanto/tensor/weights/packing.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch def unpack_int32_to_uint8(packed: torch.Tensor, bits: int): """Unpack a packed int32 tensor to a larger uint8 tensor Args: packed (`torch.Tensor`): The packed integer tensor bits: (`int`): The number of bits of each packed value. 
def unpack_int32_to_uint8(packed: torch.Tensor, bits: int):
    """Unpack a packed int32 tensor to a larger uint8 tensor

    Args:
        packed (`torch.Tensor`): The packed integer tensor
        bits: (`int`): The number of bits of each packed value.

    Returns:
        An unpacked uint8 `torch.Tensor` expanded along the last dimension.
    """
    # One shift offset per packed field inside a 32-bit word.
    shift_amounts = torch.arange(0, 32, bits, device=packed.device)
    # Broadcast-shift every int32 by every offset at once, expanding a new last dimension,
    # then drop to the smallest available dtype.
    fields = torch.bitwise_right_shift(packed[:, :, None], shift_amounts[None, None, :]).to(
        torch.int8  # smallest dtype available
    )
    fields = fields.reshape(fields.shape[0], -1)
    # Keep only the low `bits` bits of each field to make values unsigned.
    fields = torch.bitwise_and(fields, (2**bits) - 1)
    # XPU consumes the int8 result directly; everywhere else convert to uint8.
    if packed.device.type == "xpu":
        return fields
    return fields.to(torch.uint8)
import ast
from typing import Optional

import torch
from packaging import version
from torch.autograd import Function

from ...library import is_extension_available
from ..function import QuantizedLinearFunction
from ..grouped import grouped_shape
from ..packed import PackedTensor
from ..qbits import QBitsTensor
from ..qtensor import qfallback
from ..qtype import qint2, qint4, qtype, qtypes


__all__ = ["WeightQBitsTensor"]


class WeightsQBitsQuantizer(Function):
    """Autograd Function quantizing a weight Tensor to a WeightQBitsTensor.

    The quantized tensor must be instantiated inside the Function for autograd
    to track the operation.
    """

    @staticmethod
    def forward(
        ctx,
        base: torch.Tensor,
        qtype: qtype,
        axis: int,
        group_size: int,
        scale: torch.Tensor,
        shift: torch.Tensor,
        optimized: bool,
    ):
        # Only sub-byte integer qtypes are supported by this quantizer
        if qtype not in (qint2, qint4):
            raise ValueError("WeightQBitsTensor can only be of qint2 or qint4 qtype")
        if axis not in (0, -1):
            raise ValueError("WeightQBitsTensor axis parameter must be 0 (first axis) or -1 (last axis)")
        # Capture the original size/stride before quantization, since the
        # quantized data may be grouped and thus differently shaped.
        size = base.size()
        stride = base.stride()
        data = torch.ops.quanto.quantize_affine(
            base, bits=qtype.bits, axis=axis, group_size=group_size, scale=scale, shift=shift
        )
        if optimized:
            # Let the factory pick the best kernel-specific subclass for this device
            return WeightQBitsTensor.create(qtype, axis, group_size, size, stride, data, scale, shift)
        return WeightQBitsTensor(qtype, axis, group_size, size, stride, data, scale, shift)

    @staticmethod
    def backward(ctx, gO):
        # For autograd, quantization is a no-op
        # One gradient per forward input: base gets gO, the 6 config inputs get None.
        return gO, None, None, None, None, None, None


class WeightQBitsTensor(QBitsTensor):
    """Kernel-agnostic quantized weight Tensor for sub-byte (qint2/qint4) qtypes.

    Data is stored as a PackedTensor; subclasses (AWQ, TinyGemm) use
    kernel-specific packings and are selected through `create`/`optimize`.
    """

    @staticmethod
    def create(qtype, axis, group_size, size, stride, data, scale, shift, requires_grad=False):
        """Factory method to create a WeightQBitsTensor

        This selects the most appropriate WeightQBitsTensor based on the configuration.

        Args:
            axis (`int`):
                The axis that is preserved by quantization (usually zero for linear weights).
            group_size (`int`):
                The group size that further splits the data elements for each index along the quantization axis.
            size ():
                The Tensor size.
            stride():
                The Tensor stride.
            data (`torch.Tensor`):
                The tensor data, either as a raw uint8 torch.Tensor or as a PackedTensor.
            scale (`torch.Tensor`):
                The floating point scale expressed as a torch.Tensor.
            shift (`torch.Tensor`):
                The shift expressed as a torch.Tensor. It can be either an integer representing zero
                (i.e. zero-point) or a float value.
            requires_grad (`bool`):
                If the Tensor must be receive a gradient or not.

        Returns:
            a `WeightQBitsTensor` (can be a subclass).
        """
        # Imported lazily to avoid a circular import with the subclass modules
        from .awq import AWQWeightQBitsTensor
        from .tinygemm import TinyGemmWeightQBitsTensor

        # AWQ path: either CUDA (capability >= 8, quanto_cuda extension) or XPU (torch >= 2.8)
        if (
            qtype == qint4
            and size[0] >= 128  # FIXME Workaround AWQ GEMM crash (GEMV might work for short inputs)
            and scale.dtype == torch.float16
            and axis == 0
            and group_size == 128
            and len(size) == 2
            and (data.device.type == "cuda" and torch.version.cuda)
            and torch.cuda.get_device_capability(data.device)[0] >= 8
            and is_extension_available("quanto_cuda")
        ) or (
            qtype == qint4
            and axis == 0
            and group_size == 128
            and len(size) == 2
            and data.device.type == "xpu"
            and shift.dtype == torch.int8
            and version.parse(torch.__version__).release >= version.parse("2.8.0").release
        ):
            if type(data) is PackedTensor:
                # The subclass uses its own kernel-specific packing
                data = data.unpack()
            return AWQWeightQBitsTensor(qtype, axis, group_size, size, stride, data, scale, shift, requires_grad)
        # TinyGemm path: bfloat16 scales/shifts, on CPU or on CUDA >= 12.1 with capability >= 8
        if (
            qtype == qint4
            and scale.dtype == torch.bfloat16
            and shift.dtype == torch.bfloat16
            and axis == 0
            and group_size == 128
            and len(size) == 2
        ):
            if data.device.type == "cpu" or (
                (data.device.type == "cuda" and torch.version.cuda)
                and version.parse(torch.version.cuda).release >= (12, 1)
                and torch.cuda.get_device_capability(data.device)[0] >= 8
            ):
                if type(data) is PackedTensor:
                    data = data.unpack()
                return TinyGemmWeightQBitsTensor(
                    qtype, axis, group_size, size, stride, data, (scale, shift), requires_grad
                )
        # Fallback: generic, kernel-agnostic packing
        return WeightQBitsTensor(qtype, axis, group_size, size, stride, data, scale, shift, requires_grad)

    @staticmethod
    def __new__(cls, qtype, axis, group_size, size, stride, data, scale, shift, requires_grad=False):
        # All inner tensors must live on the same device
        assert data.device == scale.device
        assert data.device == shift.device
        # The wrapper advertises the scale dtype as its own (dequantized) dtype
        return torch.Tensor._make_wrapper_subclass(
            cls, size, strides=stride, dtype=scale.dtype, device=data.device, requires_grad=requires_grad
        )

    def __init__(self, qtype, axis, group_size, size, stride, data, scale, shift, requires_grad=False):
        # Raw uint8 data is packed on-the-fly; a PackedTensor is stored as-is
        if type(data) is torch.Tensor:
            data = PackedTensor.pack(data, qtype.bits)
        super().__init__(qtype, axis, group_size, size, stride, data, scale, shift)

    @classmethod
    def quantize(
        cls,
        base: torch.Tensor,
        qtype: qtype,
        axis: int,
        group_size: int,
        scale: torch.Tensor,
        shift: torch.Tensor,
        optimized: Optional[bool] = True,
    ):
        """Quantize `base` through the autograd-aware quantizer Function."""
        return WeightsQBitsQuantizer.apply(base, qtype, axis, group_size, scale, shift, optimized)

    @staticmethod
    def load_from_state_dict(state_dict, prefix, qtype, axis, group_size, size, stride, missing_keys):
        """Rebuild a WeightQBitsTensor from serialized inner tensors.

        Returns None (and records `missing_keys`) when any inner tensor is absent.
        """
        if group_size is None:
            data_size = size
            data_stride = stride
        else:
            # Grouped data has its own (2D) shape, distinct from the tensor shape
            data_size = grouped_shape(size, axis, group_size)
            assert len(data_size) == 2
            # In row major, inner dimension (stride 1) is the last one
            data_stride = (data_size[1], 1)
        inner_tensors_dict = {
            "_data": PackedTensor.load_from_state_dict(
                state_dict, prefix + "_data.", qtype.bits, data_size, data_stride, missing_keys=missing_keys
            )
        }
        missing = inner_tensors_dict["_data"] is None
        for name in ["_scale", "_shift"]:
            if prefix + name not in state_dict:
                missing_keys.append(prefix + name)
                missing = True
            else:
                inner_tensors_dict[name] = state_dict.pop(prefix + name)
        if missing:  # could not deserialize because of missing keys
            return None
        meta = {
            "qtype": qtype.name,
            "axis": str(axis),
            "group_size": str(group_size),
            "size": str(list(size)),
            "stride": str(list(stride)),
        }
        return WeightQBitsTensor.__tensor_unflatten__(inner_tensors_dict, meta, None, None)

    def optimize(self):
        """Allows to convert an existing WeightQBitsTensor to an optimized subclass

        This is used in particular after reloading a serialized WeightQBitsTensor
        (which is always saved using the kernel-agnostic packing).
        """
        if type(self) is not WeightQBitsTensor:
            # Already an optimized subclass
            return self
        data = self._data.unpack()
        # Call dedicated helper to select the best subclass for this device
        return WeightQBitsTensor.create(
            self.qtype,
            self.axis,
            self._group_size,
            self.size(),
            self.stride(),
            data,
            self._scale,
            self._shift,
            self.requires_grad,
        )

    def save_to_state_dict(self, destination, prefix, keep_vars):
        if type(self) is WeightQBitsTensor:
            super().save_to_state_dict(destination, prefix, keep_vars)
        else:
            # Convert back subclass before serializing
            self.weight_qbits_tensor().save_to_state_dict(destination, prefix, keep_vars)

    def weight_qbits_tensor(self):
        """Convert back a subclass to a WeightQBitsTensor

        This is required to make sure only standard packing is used when serializing.
        """
        raise NotImplementedError

    def __tensor_flatten__(self):
        inner_tensors = ["_data", "_scale", "_shift"]
        # Since meta can be used for serialization, use only strings
        meta = {
            "qtype": self._qtype.name,
            "axis": str(self._axis),
            "group_size": str(self._group_size),
            "size": str(list(self.size())),
            "stride": str(list(self.stride())),
        }
        return inner_tensors, meta

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        assert len(inner_tensors) == 3
        assert len(meta) == 5
        data, scale, shift = inner_tensors["_data"], inner_tensors["_scale"], inner_tensors["_shift"]
        # Meta should only contain strings, AST compatible except qtype
        qtype = qtypes[meta["qtype"]]
        axis = ast.literal_eval(meta["axis"])
        group_size = ast.literal_eval(meta["group_size"])
        size = ast.literal_eval(meta["size"])
        stride = ast.literal_eval(meta["stride"])
        return WeightQBitsTensor(qtype, axis, group_size, size, stride, data, scale, shift)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Dispatch torch functions applied on this subtensor

        This method is called whenever a torch function (such as `torch.nn.functional.linear`)
        is called with at least one parameter corresponding to this subtensor:

        - if a quantized implementation exists for the selected function, it is called,
        - otherwise, the original implementation is called, deactivating further functional dispatch.

        During the execution of the standard torch function, a second-level of dispatch will
        happen, but this time directly on individual torch Tensor operations (mainly ATEN).
        """
        kwargs = kwargs or {}
        if func is torch.nn.functional.linear:

            def qlinear(input, other, bias=None):
                return QuantizedLinearFunction.apply(input, other, bias)

            return qlinear(*args, **kwargs)
        elif func is torch.equal:
            input, other = args
            return input.equal(other)
        # Defer to operations dispatcher
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

    @classmethod
    def __torch_dispatch__(cls, op, types, args, kwargs=None):
        """ATEN-level dispatch: handles detach and device moves, falls back otherwise."""
        # Do not use directly op, but rather its overload
        op = op.overloadpacket
        if op is torch.ops.aten.detach:
            t = args[0]
            # Detach is required when copying and deserializing
            inner_tensor_names, meta = t.__tensor_flatten__()
            # Detach inner tensors
            detached_tensors = {}
            for inner_name in inner_tensor_names:
                detached_tensors[inner_name] = op(getattr(t, inner_name))
            return cls.__tensor_unflatten__(detached_tensors, meta, t.size(), t.stride())
        elif op in [torch.ops.aten._to_copy, torch.ops.aten.to]:
            t = args[0]
            dtype = kwargs.pop("dtype", t.dtype)
            device = kwargs.pop("device", t.device)
            if dtype is not None and dtype != t.dtype:
                raise ValueError("The dtype of a WeightQBitsTensor cannot be changed")
            if type(t) is not WeightQBitsTensor and t.device.type != device.type:
                # Before moving to another device type, convert back to a WeightQBitsTensor
                t = t.weight_qbits_tensor()
            scale = op(t._scale, dtype=dtype, device=device, **kwargs)
            data = op(t._data, device=device, **kwargs)
            shift = op(t._shift, device=device, **kwargs)
            # Re-run the factory: the target device may support a different optimized subclass
            return WeightQBitsTensor.create(t._qtype, t._axis, t._group_size, t.size(), t.stride(), data, scale, shift)
        # No dispatch available: qfallback
        kwargs = kwargs or {}
        return qfallback(op, *args, **kwargs)
================================================ FILE: optimum/quanto/tensor/weights/qbytes.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import ast from typing import Optional import torch from torch.autograd import Function from ...library import is_extension_available from ..function import QuantizedLinearFunction from ..qbytes import QBytesTensor from ..qtensor import qfallback from ..qtype import qtype, qtypes __all__ = ["WeightQBytesTensor"] class WeightQBytesQuantizer(Function): @staticmethod def forward( ctx, base: torch.Tensor, qtype: qtype, axis: int, scale: torch.Tensor, activation_qtype: qtype, optimized: bool ) -> torch.Tensor: if qtype.bits != 8: raise ValueError("QBytesTensor can only be of 8-bit qtype") data = torch.ops.quanto.quantize_symmetric(base, dtype=qtype.dtype, axis=axis, scale=scale) # The instantiation of the quantized tensor must happen within the context of the Function # for the autograd magic to work. 
if optimized: return WeightQBytesTensor.create( qtype, axis, size=base.size(), stride=base.stride(), data=data, scale=scale, activation_qtype=activation_qtype, ) return WeightQBytesTensor( qtype, axis, size=base.size(), stride=base.stride(), data=data, scale=scale, activation_qtype=activation_qtype, ) @staticmethod def backward(ctx, gO): # For autograd, quantization is a no-op return gO, None, None, None, None, None, None class WeightQBytesLinearFunction(QuantizedLinearFunction): @staticmethod def forward(ctx, input, other, bias=None): ctx.save_for_backward(input, other) if isinstance(input, QBytesTensor): output = torch.ops.quanto.qbytes_mm(input._data, other._data, input._scale * other._scale) else: in_features = input.shape[-1] out_features = other.shape[0] output_shape = input.shape[:-1] + (out_features,) output = torch.ops.quanto.qbytes_mm(input.reshape(-1, in_features), other._data, other._scale) output = output.reshape(output_shape) if bias is not None: output = output + bias return output class WeightQBytesTensor(QBytesTensor): @staticmethod def create( qtype, axis, size, stride, data, scale, activation_qtype: Optional[qtype] = None, requires_grad=False, ): """Factory method to create a QBytesTensor This selects the most appropriate QBytesTensor based on the configuration. Args: axis (`int`): The axis that is preserved by quantization (usually zero for linear weights). size (): The Tensor size. stride(): The Tensor stride. data (`torch.Tensor`): The tensor data, either as a raw uint8 torch.Tensor or as a PackedTensor. scale (`torch.Tensor`): The floating point scale expressed as a torch.Tensor. activation_qtype (`qtype`, defaults to `None`): The qtype used for the activations. If one needs to use a different tensor subclass e.g. for weights depending on the activations qtype, this argument must be specified accordingly when calling `QBytesTensor.create`. requires_grad (`bool`): If the Tensor must be receive a gradient or not. 
Returns: a `QBytesTensor` (can be a subclass). """ from .marlin import MarlinF8QBytesTensor if ( qtype == qtypes["qfloat8_e4m3fn"] and activation_qtype is None and scale.dtype in [torch.float16, torch.bfloat16] and len(size) == 2 and (data.device.type == "cuda" and torch.version.cuda) and axis == 0 and torch.cuda.get_device_capability(data.device)[0] >= 8 and is_extension_available("quanto_cuda") ): out_features, in_features = size if ( in_features >= 64 and out_features >= 64 and ( (in_features % 64 == 0 and out_features % 128 == 0) or (in_features % 128 == 0 and out_features % 64 == 0) ) ): return MarlinF8QBytesTensor(qtype, axis, size, stride, data, scale, requires_grad) return WeightQBytesTensor(qtype, axis, size, stride, data, scale, activation_qtype, requires_grad) @staticmethod def __new__(cls, qtype, axis, size, stride, data, scale, activation_qtype, requires_grad=False): assert data.device == scale.device return torch.Tensor._make_wrapper_subclass( cls, size, strides=stride, dtype=scale.dtype, device=data.device, requires_grad=requires_grad ) def __init__(self, qtype, axis, size, stride, data, scale, activation_qtype, requires_grad=False): super().__init__(qtype, axis, size, stride, data, scale, requires_grad=requires_grad) self.activation_qtype = activation_qtype @classmethod def quantize( cls, base: torch.Tensor, qtype: qtype, axis: int, scale: torch.Tensor, activation_qtype: Optional[qtype] = None, optimized: Optional[bool] = True, ) -> torch.Tensor: return WeightQBytesQuantizer.apply(base, qtype, axis, scale, activation_qtype, optimized) @staticmethod def load_from_state_dict(state_dict, prefix, qtype, axis, size, stride, activation_qtype, missing_keys): inner_tensors_dict = {} missing = False for name in ["_data", "_scale"]: if prefix + name not in state_dict: missing_keys.append(prefix + name) missing = True else: inner_tensors_dict[name] = state_dict.pop(prefix + name) if missing: # could not deserialize because of missing keys return None meta = { 
"qtype": qtype.name, "axis": str(axis), "size": str(list(size)), "stride": str(list(stride)), "activation_qtype": "none" if activation_qtype is None else activation_qtype.name, } return WeightQBytesTensor.__tensor_unflatten__(inner_tensors_dict, meta, None, None) def optimize(self): """Allows to convert an existing WeightQBytesTensor to an optimized subclass This is used in particular after reloading a serialized WeightQBytesTensor (which is always saved using the kernel-agnostic packing). """ if type(self) is not WeightQBytesTensor: return self # Call dedicated helper to select the best subclass for this device return WeightQBytesTensor.create( self.qtype, self.axis, self.size(), self.stride(), self._data, self._scale, self.activation_qtype, self.requires_grad, ) def save_to_state_dict(self, destination, prefix, keep_vars): if type(self) is WeightQBytesTensor: super().save_to_state_dict(destination, prefix, keep_vars) else: # Convert back subclass before serializing self.weight_qbytes_tensor().save_to_state_dict(destination, prefix, keep_vars) def weight_qbytes_tensor(self): """Convert back a subclass to a WeightQBytesTensor This is required to make sure only standard packing is used when serializing. 
""" raise NotImplementedError def __tensor_flatten__(self): inner_tensors = ["_data", "_scale"] meta = { "qtype": self._qtype.name, "axis": str(self._axis), "size": str(list(self.size())), "stride": str(list(self.stride())), "activation_qtype": "none" if self.activation_qtype is None else self.activation_qtype.name, } return inner_tensors, meta @staticmethod def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): assert len(inner_tensors) == 2 assert len(meta) == 5 data, scale = inner_tensors["_data"], inner_tensors["_scale"] # Meta should only contain strings, AST compatible except qtype qtype = qtypes[meta["qtype"]] axis = ast.literal_eval(meta["axis"]) size = ast.literal_eval(meta["size"]) stride = ast.literal_eval(meta["stride"]) activation_qtype = None if meta["activation_qtype"] == "none" else qtypes[meta["activation_qtype"]] return WeightQBytesTensor(qtype, axis, size, stride, data, scale, activation_qtype) @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): """Dispatch torch functions applied on this subtensor This method is called whenever a torch function (such as `torch.nn.functional.linear`) is called with at least one parameter coresponding to this subtensor: - if a quantized implementation exists for the selected function, it is called, - otherwise, the original implementation is called, deactivating further functional dispatch. During the execution of the standard torch function, a second-level of dispatch will happen, but this time directly on individual torch Tensor operations (mainly ATEN). 
""" kwargs = kwargs or {} if func is torch.nn.functional.linear: def qlinear(input, other, bias=None): return WeightQBytesLinearFunction.apply(input, other, bias) return qlinear(*args, **kwargs) elif func is torch.equal: input, other = args return input.equal(other) # Defer to operations dispatcher with torch._C.DisableTorchFunctionSubclass(): return func(*args, **kwargs) @classmethod def __torch_dispatch__(cls, op, types, args, kwargs=None): # Do not use directly op, but rather its overload op = op.overloadpacket if op is torch.ops.aten.detach: t = args[0] # Detach is required when copying and deserializing inner_tensor_names, meta = t.__tensor_flatten__() # Detach inner tensors detached_tensors = {} for inner_name in inner_tensor_names: detached_tensors[inner_name] = op(getattr(t, inner_name)) return cls.__tensor_unflatten__(detached_tensors, meta, t.size(), t.stride()) elif op in [torch.ops.aten._to_copy, torch.ops.aten.to]: t = args[0] dtype = kwargs.pop("dtype", t.dtype) device = kwargs.pop("device", t.device) if dtype != t.dtype: raise ValueError("The dtype of a weights Tensor cannot be changed") if type(t) is not WeightQBytesTensor and t.device.type != device.type: # Before moving to another device type, convert back to a WeightQBytesTensor t = t.weight_qbytes_tensor() out_data = op(t._data, device=device, **kwargs) out_scale = op(t._scale, device=device, **kwargs) return WeightQBytesTensor.create( t.qtype, t.axis, t.size(), t.stride(), out_data, out_scale, activation_qtype=t.activation_qtype, requires_grad=t.requires_grad, ) elif op is torch.ops.aten.t and cls is WeightQBytesTensor: t = args[0] out_data = op(t._data) out_scale = t._scale out_axis = t.axis # Manually reverse size and stride because we cannot trust the out_data shape dim0, dim1 = t.size() out_size = torch.Size([dim1, dim0]) out_stride = t.stride()[::-1] if t.axis is not None: # We need to transpose also the scale out_scale = op(out_scale) out_axis = 0 if out_axis == -1 else -1 return 
WeightQBytesTensor(t.qtype, out_axis, out_size, out_stride, out_data, out_scale, t.activation_qtype) # No dispatch available: qfallback kwargs = kwargs or {} return qfallback(op, *args, **kwargs) ================================================ FILE: optimum/quanto/tensor/weights/quantization.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Optional import torch from ..qtype import qtype from .qbits import WeightQBitsTensor from .qbytes import WeightQBytesTensor __all__ = ["quantize_weight"] def quantize_weight( t: torch.Tensor, qtype: qtype, axis: int, scale: torch.Tensor, shift: Optional[torch.Tensor] = None, group_size: Optional[int] = None, activation_qtype: Optional[qtype] = None, optimized: Optional[bool] = True, ): """Quantize a weight Tensor. Weights are always quantized per-axis. Args: t (`torch.Tensor`): the weight Tensor to quantize qtype (`quanto.qtype`): The target quantization type axis ('int`): The quantization axis (0 or -1) scale (`torch.Tensor`): the quantization scale shift (`Optional[torch.Tensor]`): optional shift to apply group_size (`Optional[int]`): The quantization group size activation_qtype (`Optional[qtype]`, defaults to `None`): Which quantization type is being used for the activations. The function `quantize_weight` initializes `torch.Tensor` subclasses that may depend on the activation dtype. 
`None` corresponds to no quantization. optimized (`Optional[bool]`, defaults to True): If True, the quantization algorithm will select the most efficient kernel for the weights and format the resulting Tensor accordingly. If False, a kernel-agnostic Tensor will be returned (but it can be optimized later explicitly by calling QTensor.optimize() or implicitly by moving it to a specific device). Returns: A quantized Tensor. """ if axis not in (0, -1): raise ValueError("axis parameter must be 0 (first axis) or -1 (last axis)") if qtype.bits == 8: if shift is not None: raise ValueError("shift cannot be specified for 8-bit qtypes") if group_size is not None: raise ValueError("group_size cannot be specified for 8-bit qtypes.") if axis is not None and t.shape[axis] == 1: # Quantizing along an axis of dimension 1 means quantizing per-tensor axis = None return WeightQBytesTensor.quantize(t, qtype, axis, scale, activation_qtype, optimized) if shift is None: raise ValueError("shift must be specified for qtypes lower than 8-bit") return WeightQBitsTensor.quantize(t, qtype, axis, group_size, scale, shift, optimized) ================================================ FILE: optimum/quanto/tensor/weights/reordering.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from typing import List, Union import torch __all__ = ["reorder", "reverse"] def reorder(t: torch.Tensor, permutation: Union[torch.Tensor, List[int]]): """Reorder a Tensor using a permutation Args: t (`torch.Tensor`): the Tensor to reorder permutation (`Union[torch.Tensor, List[int]]`): the permutation to apply Returns: The reordered torch.Tensor """ block_size = permutation.numel() if isinstance(permutation, torch.Tensor) else len(permutation) reordered = t.reshape((-1, block_size))[:, permutation].reshape(t.shape) return reordered.contiguous() def reverse(permutation: Union[torch.Tensor, List[int]]): """Reverse a permutation The reversed permutation can be used to revert a reordered Tensor to its original ordering. Args: permutation (`Union[torch.Tensor, List[int]]`): the permutation to reverse Returns: The reversed permutation """ block_size = permutation.numel() if isinstance(permutation, torch.Tensor) else len(permutation) reversed = torch.empty((block_size,), dtype=torch.int64) reversed[permutation] = torch.arange(block_size) return reversed ================================================ FILE: optimum/quanto/tensor/weights/tinygemm/__init__.py ================================================ from .packed import * from .qbits import * ================================================ FILE: optimum/quanto/tensor/weights/tinygemm/packed.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.

import ast
from copy import copy

import torch
from torch.utils import _pytree as pytree


__all__ = ["TinyGemmPackedTensor"]


class TinyGemmPackedTensor(torch.Tensor):
    """Tensor subclass storing uint4 weights packed for the torch tinygemm kernels.

    The packed layout is device-specific and produced by undocumented torch helpers,
    so packing/unpacking always goes through those kernels.
    """

    @staticmethod
    def __new__(cls, data, size, stride, requires_grad=False):
        # TinyGemmPackedTensor represents uint8 data and can therefore NEVER require gradient
        assert requires_grad is False
        return torch.Tensor._make_wrapper_subclass(
            cls, size, strides=stride, dtype=torch.uint8, device=data.device, requires_grad=requires_grad
        )

    def __init__(self, data, size, stride, requires_grad=False):
        # The packed int32 payload; shape differs from the advertised (unpacked) size
        self._data = data

    def __repr__(self):
        return f"TinyGemmPackedTensor({self._data})"

    @classmethod
    def pack(cls, t):
        """Pack a torch.Tensor for tinygemm kernel

        This packs uint4 weights in an int32 tensor as expected by the torch tinygemm mixed mm kernel

        Args:
            t (`torch.Tensor`):
                The un-packed `torch.uint8` tensor

        Returns:
            A `TinyGemmPackedTensor`.
        """
        inner_ktiles = 2
        t = t.to(torch.int32).contiguous()
        if t.device.type == "cpu":
            data = torch._convert_weight_to_int4pack_for_cpu(t, innerKTiles=inner_ktiles)
        elif t.device.type == "xpu":
            # xpu expects pairs of nibbles packed with the EVEN value in the low nibble
            t_uint8 = (t[::, 1::2] << 4 | t[::, ::2]).to(torch.uint8)
            data = torch._convert_weight_to_int4pack(t_uint8, innerKTiles=inner_ktiles)
        else:
            # cuda expects pairs of nibbles packed with the ODD value in the low nibble
            t_uint8 = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8)
            data = torch._convert_weight_to_int4pack(t_uint8, innerKTiles=inner_ktiles)
        # We need to store size and stride to make sure the unpacked data has the correct shape
        return TinyGemmPackedTensor(data, t.size(), t.stride())

    def unpack(self):
        """Unpack the packed tensor to a torch.Tensor

        Packing is device specific and implemented in undocumented dedicated kernels
        that are synchronized with the corresponding matrix multiplication operation.

        Instead of implementing a dedicated unpacking code, we pass an identity matrix
        to the mm operation with identity scale and shifts to produce the unpacked uint8 weights.

        Returns:
            An unpacked uint8 `torch.Tensor` expanded along the second dimension.
        """
        out_features, in_features = self.size()
        # We need to pass a group_size to the mm and format the scale and shift accordingly,
        # although it does not modify the calculation since we use identity scales and shifts.
        # We arbitrarily choose the smallest group_size to be sure it divides in_features
        group_size = 32
        scale_and_shift_shape = (in_features // group_size, out_features, 2)
        # Initialize identity scale
        id_scale_and_shift = torch.ones(scale_and_shift_shape, dtype=torch.bfloat16, device=self.device)
        # Set shift to mid-point, i.e. 2 **(bits - 1)
        id_scale_and_shift[:, :, 1] = 8
        identity = torch.eye(in_features, dtype=torch.bfloat16, device=self.device)
        if self._data.device.type == "cpu":
            unpacked_data = torch._weight_int4pack_mm_for_cpu(identity, self._data, group_size, id_scale_and_shift)
        else:
            unpacked_data = torch._weight_int4pack_mm(identity, self._data, group_size, id_scale_and_shift)
        # The mm produces the transposed weights: transpose back and narrow to uint8
        return unpacked_data.t().to(torch.uint8)

    @property
    def dtype(self):
        # The advertised dtype is always uint8, regardless of the packed payload dtype
        return torch.uint8

    def __tensor_flatten__(self):
        inner_tensors = ["_data"]
        # Since meta can be used for serialization, use only AST compatible strings
        meta = {
            "size": str(list(self.size())),
            "stride": str(self.stride()),
        }
        return inner_tensors, meta

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        assert len(inner_tensors) == 1
        assert len(meta) == 2
        data = inner_tensors["_data"]
        # Meta should contain only AST compatible strings
        size = ast.literal_eval(meta["size"])
        stride = ast.literal_eval(meta["stride"])
        return TinyGemmPackedTensor(data, size, stride)

    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, op, types, args, kwargs=None):
        # Convert back to tensor before calling any operation except detach and move
        if op.overloadpacket is torch.ops.aten.detach:
            t = args[0]
            data = op(t._data)
            return TinyGemmPackedTensor(data, t.size(), t.stride())
        elif op.overloadpacket in (torch.ops.aten._to_copy, torch.ops.aten.to):
            t = args[0]
            dtype = kwargs.get("dtype", torch.uint8)
            if dtype != torch.uint8:
                raise ValueError(f"TinyGemmPackedTensor are torch.uint8 only and cannot be moved to {dtype}.")
            data_kwargs = copy(kwargs)
            # The inner payload keeps its own dtype (int32), not the advertised uint8
            data_kwargs["dtype"] = t._data.dtype
            if kwargs.get("device", t.device).type != t.device.type:
                # Packing is device specific, so we need to unpack before moving
                unpacked = t.unpack()
                unpacked = op(unpacked, **data_kwargs)
                return TinyGemmPackedTensor.pack(unpacked)
            # If we stay on the same device type, just copy/move packed data
            data = op(t._data, **data_kwargs)
            return TinyGemmPackedTensor(data, t.size(), t.stride())
        # Any other operation works on the unpacked representation
        args, kwargs = pytree.tree_map_only(TinyGemmPackedTensor, lambda x: x.unpack(), (args, kwargs or {}))
        return op(*args, **kwargs)

    def numpy(self):
        return self.unpack().cpu().numpy()


================================================
FILE: optimum/quanto/tensor/weights/tinygemm/qbits.py
================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast

import torch
from torch.autograd import Function

from ...function import QuantizedLinearFunction
from ...grouped import group, ungroup
from ...qtype import qtypes
from ..qbits import WeightQBitsTensor
from .packed import TinyGemmPackedTensor


__all__ = ["TinyGemmWeightQBitsTensor"]


class TinyGemmQBitsDequantizer(Function):
    """Autograd function dequantizing a TinyGemmWeightQBitsTensor.

    There is no dedicated dequantize kernel: the tensor is first converted back to a
    standard WeightQBitsTensor, then dequantized. The gradient is passed straight through.
    """

    @staticmethod
    def forward(ctx, t):
        # There is no custom dequantize kernel available, so we need to convert back to a QBitsTensor
        qbt = t.weight_qbits_tensor()
        return qbt.dequantize()

    @staticmethod
    def backward(ctx, gO):
        # Straight-through estimator: pass the gradient unchanged
        return gO


class TinyGemmQBitsLinearFunction(QuantizedLinearFunction):
    """Quantized linear function backed by the torch tinygemm int4 mm kernels."""

    @staticmethod
    def forward(ctx, input, other, bias):
        ctx.save_for_backward(input, other)
        if type(input) is not torch.Tensor:
            input = input.dequantize()
        in_features = input.shape[-1]
        out_features = other.shape[0]
        output_shape = input.shape[:-1] + (out_features,)
        # The tinygemm kernels expect a 2D input: flatten leading dims before the mm
        if input.device.type == "cpu":
            output = torch._weight_int4pack_mm_for_cpu(
                input.reshape(-1, in_features), other._data._data, other._group_size, other._scale_shift
            )
        else:
            output = torch._weight_int4pack_mm(
                input.reshape(-1, in_features), other._data._data, other._group_size, other._scale_shift
            )
        output = output.reshape(output_shape)
        if bias is not None:
            output = output + bias
        return output


class TinyGemmWeightQBitsTensor(WeightQBitsTensor):
    @staticmethod
    def __new__(cls, qtype, axis, group_size, size, stride, data, scale_shift, requires_grad=False):
        # scale_shift is either the combined tinygemm tensor or a (scale, shift) pair
        if isinstance(scale_shift, torch.Tensor):
            dtype = scale_shift.dtype
            assert data.device == scale_shift.device
        else:
            assert isinstance(scale_shift, (tuple, list))
            scale, shift = scale_shift
            dtype = scale.dtype
            assert shift.dtype == dtype
            assert data.device == scale.device
            assert data.device == shift.device
        return torch.Tensor._make_wrapper_subclass(
            cls, size, strides=stride, dtype=dtype, device=data.device, requires_grad=requires_grad
        )

    def __init__(self, qtype, axis, group_size, size, stride, data, scale_shift, requires_grad=False):
        # tinygemm only supports quantization along the first axis
        assert axis == 0
        if not isinstance(data, TinyGemmPackedTensor):
            assert type(data) is torch.Tensor
            assert isinstance(scale_shift, (tuple, list))
            # Format data, scale and shift for tinygemm
            ungrouped = ungroup(data, axis=0, orig_shape=size)
            self._data = TinyGemmPackedTensor.pack(ungrouped)
            out_features, in_features = size
            scale, shift = scale_shift
            scale = scale.reshape(out_features, in_features // group_size, 1)
            shift = shift.reshape(out_features, in_features // group_size, 1)
            if not shift.dtype.is_floating_point:
                # Integer shift must be scaled
                shift = scale * shift
            # The tinygemm kernel actually uses the mid-point of the quantization range as shift
            min_range = -shift
            half_qrange = 2 ** (qtype.bits - 1) * scale
            # This operation is lossy for bfloat16, and the actual value of shift will be lost
            shift = min_range + half_qrange
            # Scale and shift are actually stored in the same tensor
            self._scale_shift = torch.cat([scale, shift], 2).transpose(0, 1).contiguous()
        else:
            # Data is already packed and scale_shift already in the tinygemm layout
            self._data = data
            self._scale_shift = scale_shift
        self._qtype = qtype
        self._axis = axis
        self._group_size = group_size

    def dequantize(self):
        return TinyGemmQBitsDequantizer.apply(self)

    def weight_qbits_tensor(self):
        """Convert back to a WeightQBitsTensor

        This is required to make sure only standard packing is used when serializing.
""" data = group(self._data.unpack(), axis=self.axis, group_size=self._group_size) n_scales = self._scale_shift.numel() // 2 scale = self._scale_shift[:, :, 0].t().reshape((n_scales, 1)) shift = self._scale_shift[:, :, 1].t().reshape((n_scales, 1)) half_qrange = 2 ** (self.qtype.bits - 1) * scale # This operation is lossy for bfloat16, and the actual value of shift will not be recovered shift = half_qrange - shift return WeightQBitsTensor( self._qtype, self._axis, self._group_size, self.size(), self.stride(), data, scale, shift ) def __tensor_flatten__(self): inner_tensors = ["_data", "_scale_shift"] # Since meta can be used for serialization, use only strings meta = { "qtype": self._qtype.name, "axis": str(self._axis), "group_size": str(self._group_size), "size": str(list(self.size())), "stride": str(list(self.stride())), } return inner_tensors, meta @staticmethod def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): assert len(inner_tensors) == 2 assert len(meta) == 5 data, scale_shift = inner_tensors["_data"], inner_tensors["_scale_shift"] # Meta should only contain strings, AST compatible except qtype qtype = qtypes[meta["qtype"]] axis = ast.literal_eval(meta["axis"]) group_size = ast.literal_eval(meta["group_size"]) size = ast.literal_eval(meta["size"]) stride = ast.literal_eval(meta["stride"]) return TinyGemmWeightQBitsTensor(qtype, axis, group_size, size, stride, data, scale_shift) @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): """Dispatch torch functions applied on this subtensor This method is called whenever a torch function (such as `torch.nn.functional.linear`) is called with at least one parameter coresponding to this subtensor: - if a quantized implementation exists for the selected function, it is called, - otherwise, the original implementation is called, deactivating further functional dispatch. 
During the execution of the standard torch function, a second-level of dispatch will happen, but this time directly on individual torch Tensor operations (mainly ATEN). """ kwargs = kwargs or {} if func is torch.nn.functional.linear: def qlinear(input, other, bias=None): return TinyGemmQBitsLinearFunction.apply(input, other, bias) return qlinear(*args, **kwargs) # Defer to operations dispatcher with torch._C.DisableTorchFunctionSubclass(): return func(*args, **kwargs) ================================================ FILE: pyproject.toml ================================================ [project] name = 'optimum-quanto' description = 'A pytorch quantization backend for optimum.' classifiers = [ 'Development Status :: 2 - Pre-Alpha', 'License :: OSI Approved :: Apache Software License', 'Intended Audience :: Developers', 'Intended Audience :: Education', 'Intended Audience :: Science/Research', 'Operating System :: OS Independent', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', 'Topic :: Scientific/Engineering :: Artificial Intelligence' ] keywords = ['torch', 'quantization'] requires-python = '>=3.9.0' authors = [{ name = 'David Corvoysier' }] maintainers = [ {name = "HuggingFace Inc. 
Special Ops Team", email="hardware@huggingface.co"}, ] dependencies = ['torch>=2.6.0', 'ninja', 'numpy', 'safetensors', 'huggingface_hub'] license = { text = 'Apache-2.0' } readme = 'README.md' dynamic = ['version'] [project.urls] homepage = 'https://github.com/huggingface/optimum-quanto' [project.optional-dependencies] dev = ['pytest', 'ruff'] examples = [ 'torchvision', 'transformers', 'diffusers', 'datasets', 'accelerate', 'sentencepiece', 'scipy' ] [tool.setuptools.packages.find] where = ["."] include = ["optimum*"] [tool.setuptools.dynamic] version = {attr = 'optimum.quanto.__version__'} [build-system] requires = ['setuptools>65.5.1', 'setuptools_scm'] build-backend = 'setuptools.build_meta' [tool.ruff] # Configuration for Ruff line-length = 119 # Same line-length as Black had # Linting rules: # Never enforce `E501` (line length violations) and other specific rules. lint.ignore = ['C901', 'E501', 'E741'] lint.select = ['C', 'E', 'F', 'I', 'W'] # Ignore import violations in all `__init__.py` files. [tool.ruff.lint.per-file-ignores] '__init__.py' = ['E402', 'F401', 'F403', 'F811'] # isort configuration (to sort imports) [tool.ruff.lint.isort] lines-after-imports = 2 known-first-party = ['optimum.quanto'] ================================================ FILE: setup.sh ================================================ #!/bin/bash NIGHTLY=${1:-0} VENV=".venv" if [ ! -d "${VENV}" ]; then python3 -m venv ${VENV} fi . ${VENV}/bin/activate if [ "$NIGHTLY" -eq "0" ]; then pip install --upgrade torch torchvision torchaudio else pip install --upgrade --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118 fi # Build tools pip install ruff pytest build # For examples pip install accelerate transformers datasets ================================================ FILE: tests/cli/cli_helpers.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import importlib import pytest requires_optimum_cli = pytest.mark.skipif( importlib.util.find_spec("optimum.commands") is None, reason="optimum-cli is required" ) ================================================ FILE: tests/cli/test_quantize_cli.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import subprocess
from tempfile import TemporaryDirectory

import pytest
from cli_helpers import requires_optimum_cli

from optimum.quanto import quantization_map


@requires_optimum_cli
@pytest.mark.parametrize("weights", ["int4", "int8"])
def test_export_decoder_cli(weights):
    """Quantize a causal LM through the optimum CLI and verify it can be reloaded."""
    from optimum.quanto import QuantizedModelForCausalLM

    model_id = "facebook/opt-125m"
    with TemporaryDirectory() as tempdir:
        subprocess.run(
            [
                "optimum-cli",
                "quanto",
                "quantize",
                "--model",
                model_id,
                "--weights",
                f"{weights}",
                tempdir,
            ],
            shell=False,
            check=True,
        )
        # Verify we can reload the quantized model
        qmodel = QuantizedModelForCausalLM.from_pretrained(tempdir)
        qmap = quantization_map(qmodel)
        # Every quantized layer must carry the requested weight qtype
        for layer_qconfig in qmap.values():
            assert layer_qconfig["weights"] == f"q{weights}"

================================================
FILE: tests/conftest.py
================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch


# NOTE(review): the elif chain means at most one accelerator (cuda, then mps, then xpu)
# is added alongside cpu — presumably intentional since a host exposes a single backend.
devices = ["cpu"]
if torch.cuda.is_available():
    devices += ["cuda"]
elif torch.backends.mps.is_available():
    devices += ["mps"]
elif torch.xpu.is_available():
    devices += ["xpu"]


@pytest.fixture(scope="module", params=devices)
def device(request):
    # Parametrized fixture: each test using it runs once per available device
    return torch.device(request.param)


def pytest_configure(config):
    # register additional markers
    config.addinivalue_line("markers", "skip_device(type): mark test to be skipped for the specified device type")


def pytest_runtest_call(item):
    # Honor @pytest.mark.skip_device("type") for tests using the `device` fixture
    fixture_name = "device"
    if fixture_name in item.fixturenames:
        # TODO: should be able to recover the fixture id instead of the actual value
        fixture_arg = item.funcargs[fixture_name].type
        skip_marks = {mark.args[0] for mark in item.iter_markers(name=f"skip_{fixture_name}")}
        if fixture_arg in skip_marks:
            pytest.skip(f"Test skipped for {fixture_name} {fixture_arg}")

================================================
FILE: tests/helpers.py
================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import gc
import os

import pytest
import torch
from packaging import version

from optimum.quanto import (
    AbsmaxOptimizer,
    MaxOptimizer,
    absmax_scale,
    qint8,
    quantize_activation,
    quantize_weight,
)


def torch_min_version(v):
    """Decorator skipping a test when the installed pytorch is older than `v`."""

    def torch_min_version_decorator(test):
        @functools.wraps(test)
        def test_wrapper(*args, **kwargs):
            if version.parse(torch.__version__) < version.parse(v):
                pytest.skip(f"Requires pytorch >= {v}")
            test(*args, **kwargs)

        return test_wrapper

    return torch_min_version_decorator


def device_eq(a, b):
    """Return True if two torch devices have the same type and (defaulted) index."""
    if a.type != b.type:
        return False
    # A missing index is equivalent to index 0
    a_index = a.index if a.index is not None else 0
    b_index = b.index if b.index is not None else 0
    return a_index == b_index


def random_tensor(shape, dtype=torch.float32, device="cpu"):
    """Return a random tensor: floats in [-1, 1), or int8 values in [-127, 127)."""
    if dtype.is_floating_point:
        # torch.rand does not support sub-byte float types: generate in float16 first
        rand_dtype = dtype if dtype.itemsize > 1 else torch.float16
        # Generate a random tensor between -1. and 1.
        t = torch.rand(shape, dtype=rand_dtype, device=device) * 2 - 1
        return t.to(dtype)
    else:
        assert dtype == torch.int8
        return torch.randint(-127, 127, shape, dtype=torch.int8, device=device)


def random_qactivation(shape, qtype=qint8, dtype=torch.float32, device="cpu"):
    """Return a random activation quantized per-tensor with an absmax scale."""
    t = random_tensor(shape, dtype, device=device)
    scale = absmax_scale(t, qtype=qtype)
    return quantize_activation(t, qtype=qtype, scale=scale)


def random_qweight(shape, qtype, dtype=torch.float32, axis=0, group_size=None, device="cpu"):
    """Return a random weight quantized with the optimizer matching the qtype bitwidth."""
    device = device.type if isinstance(device, torch.device) else device
    t = random_tensor(shape, dtype, device=device)
    if qtype.bits == 8:
        # 8-bit weights are symmetric: scale only, no shift
        scale = AbsmaxOptimizer()(t, qtype=qtype, axis=axis)
        shift = None
    else:
        optimizer_kwargs = {"qtype": qtype, "axis": axis, "group_size": group_size}
        if device == "xpu":
            # xpu kernels expect integer zeropoint shifts
            optimizer_kwargs.update({"zeropoint": True})
        scale, shift = MaxOptimizer()(t, **optimizer_kwargs)
    return quantize_weight(t, qtype=qtype, axis=axis, scale=scale, shift=shift, group_size=group_size, optimized=False)


def assert_similar(a, b, atol=None, rtol=None):
    """Verify that the cosine similarity of the two inputs is close to 1.0 everywhere"""
    assert a.dtype == b.dtype
    assert a.shape == b.shape
    if atol is None:
        # We use torch finfo resolution
        atol = torch.finfo(a.dtype).resolution
    if rtol is None:
        # Please refer to that discussion for default rtol values based on the float type:
        # https://scicomp.stackexchange.com/questions/43111/float-equality-tolerance-for-single-and-half-precision
        rtol = {torch.float32: 1e-5, torch.float16: 1e-3, torch.bfloat16: 1e-1}[a.dtype]
    sim = torch.nn.functional.cosine_similarity(a.flatten(), b.flatten(), dim=0)
    if not torch.allclose(sim, torch.tensor(1.0, dtype=sim.dtype), atol=atol, rtol=rtol):
        max_deviation = torch.min(sim)
        raise ValueError(f"Alignment {max_deviation:.8f} deviates too much from 1.0 with atol={atol}, rtol={rtol}")


def get_device_memory(device):
    """Return the currently allocated memory on the device, or None if not tracked."""
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()
        return torch.cuda.memory_allocated()
    elif device.type == "mps":
        torch.mps.empty_cache()
        return torch.mps.current_allocated_memory()
    elif device.type == "xpu":
        torch.xpu.empty_cache()
        return torch.xpu.memory_allocated()
    return None


# True when running against the huggingface.co staging endpoint
_run_staging = os.getenv("HUGGINGFACE_CO_STAGING", False)

================================================
FILE: tests/library/test_extensions.py
================================================
import platform

import pytest
import torch
from packaging import version

from optimum.quanto.library.extensions import get_extension, is_extension_available


def _is_xpu_available():
    # SYCL extension support is added in torch>=2.7 on Linux
    if platform.system() != "Linux":
        return False
    if version.parse(torch.__version__).release < version.parse("2.7").release:
        return False
    return torch.xpu.is_available()


# Build the list of extensions expected to be available on this host
extension_names = ["quanto_cpp"]
if torch.cuda.is_available():
    if torch.version.cuda:
        extension_names.append("quanto_cuda")
    if torch.version.hip:
        extension_names.append("quanto_hip")
if torch.backends.mps.is_available():
    extension_names.append("quanto_mps")
if _is_xpu_available():
    extension_names.append("quanto_xpu")


@pytest.mark.parametrize("extension_name", extension_names)
def test_extension_available(extension_name):
    # The extension must be reported as available on this host
    assert is_extension_available(extension_name)


@pytest.mark.parametrize("extension_name", extension_names)
def test_extension_compilation(extension_name):
    # Getting the extension triggers (lazy) compilation: the lib must load
    extension = get_extension(extension_name)
    assert extension.lib is not None

================================================
FILE: tests/library/test_mm.py
================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from helpers import assert_similar, random_tensor

from optimum.quanto.library.extensions import is_extension_available
from optimum.quanto.tensor.weights.awq import AWQPackedTensor, AWQPacking
from optimum.quanto.tensor.weights.marlin import marlin_permute
from optimum.quanto.tensor.weights.marlin.fp8.packed import get_scale_perms, pack_fp8_as_int32
from optimum.quanto.tensor.weights.marlin.int4.packed import MarlinInt4PackedTensor


@pytest.mark.parametrize("batch_size", [1, 10, None], ids=["single", "batched", "static"])
@pytest.mark.parametrize("input_features", [32, 50])
@pytest.mark.parametrize("output_features", [48, 50, 64])
@pytest.mark.parametrize("input_dtype", [None, torch.int8], ids=["i-as-out", "i-int8"])
@pytest.mark.parametrize(
    "weight_dtype", [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.int8], ids=["w-float8", "w-float8-uz", "w-int8"]
)
@pytest.mark.parametrize("output_dtype", [torch.float16, torch.bfloat16], ids=["o-fp16", "o-bf16"])
def test_qbytes_mm(batch_size, input_features, input_dtype, weight_dtype, output_features, output_dtype, device):
    """Verify qbytes_mm matches a dequantize-then-matmul reference for 8-bit weights."""
    if device.type in ["mps"] and weight_dtype.is_floating_point:
        pytest.skip(f"Float8 types are not supported on {device.type} device")
    input_shape = (32, input_features)
    if batch_size is not None:
        input_shape = (batch_size,) + input_shape
    if input_dtype is None:
        # "i-as-out": the input shares the output dtype
        input_dtype = output_dtype
    input = random_tensor(input_shape, dtype=input_dtype, device=device)
    weight = random_tensor((output_features, input_features), dtype=weight_dtype, device=device)
    # Use a scale small enough to prevent overflows
    scale = random_tensor((output_features, 1), dtype=output_dtype, device=device) / 1e3
    output = torch.ops.quanto.qbytes_mm(input, weight, scale)
    expected = torch.matmul(input.to(scale.dtype), (weight.to(scale.dtype) * scale).t())
    assert_similar(expected, output)


@pytest.mark.skipif(
    (not is_extension_available("quanto_cuda") or torch.cuda.get_device_capability()[0] < 8)
    and not
    torch.xpu.is_available(),
    reason="The test requires CUDA device >= sm80 or Intel XPU",
)
@pytest.mark.parametrize("in_features, out_features", [(256, 256), (512, 256)])
@pytest.mark.parametrize("batch_size, tokens", [(4, 1), (10, 128)], ids=["gemv", "gemm"])
def test_gemm_fp16_int4(batch_size, tokens, in_features, out_features):
    """This test verifies that the GEMM operation is equivalent to torch.mm."""
    bits = 4
    group_size = 128  # Hard-coded in kernels
    device = torch.device(0)  # XPU can also share this setting.
    input_shape = (batch_size, tokens, in_features)
    # FIXME: does not work if inputs are negative !!??
    inputs = torch.rand(input_shape, dtype=torch.float16, device=device)
    qmax = 2**bits
    other_shape = (out_features, in_features)
    other_data = torch.randint(0, qmax, other_shape, dtype=torch.uint8, device=device)
    # Packing layout differs between the xpu (V1) and cuda (V2) kernels
    pack_type = AWQPacking.V1 if device.type == "xpu" else AWQPacking.V2
    packed_other_data = AWQPackedTensor.pack(other_data, packing=pack_type)._data
    # The GEMM kernel works on transposed scales
    scales_shape = (in_features // group_size, out_features)
    other_scales = torch.rand(scales_shape, dtype=torch.float16, device=device) / qmax
    # The GEMM kernel works on transposed, negated and scaled shifts
    qmin = -(2 ** (bits - 1))
    qmax = 2 ** (bits - 1)
    other_shifts = torch.randint(qmin, qmax, scales_shape, dtype=torch.int8, device=device)
    # Negate and scale, xpu should keep the original int8 shifts
    other_scaled_shifts = other_shifts if device.type == "xpu" else -other_shifts * other_scales
    # Evaluate mm outputs using the GEMM kernel
    lib_outputs = torch.ops.quanto.gemm_f16i4_awq(
        inputs,
        packed_other_data,
        other_scales,
        other_scaled_shifts,
        rows=inputs.numel() // inputs.shape[-1],
        out_cols=out_features,
        in_cols=in_features,
        bits=4,
        group_size=group_size,
    )
    # Transpose other data and reshape it to align it with transposed scales and zeros
    other_data_t = other_data.t().reshape(group_size, in_features // group_size, out_features)
    # Dequantize transposed other
    other_t = (other_data_t - other_shifts) * other_scales
    # Reshape it as expected by the matmul
    other_t = other_t.reshape(in_features, out_features)
    # Evaluate the matrix multiplication using pytorch float16 mm
    pt_outputs = torch.matmul(inputs, other_t)
    # Verify the results are similar
    assert_similar(lib_outputs, pt_outputs, rtol=5e-3)


@pytest.mark.skipif(
    not is_extension_available("quanto_cuda") or torch.cuda.get_device_capability()[0] < 8,
    reason="CUDA device >= sm80 not available",
)
@pytest.mark.parametrize("tokens", [1, 10, 128])
@pytest.mark.parametrize("in_features, out_features", [(256, 1024), (512, 2048)])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=["bf16", "fp16"])
def test_fp8_marlin(tokens, in_features, out_features, dtype):
    """Verify the Marlin float8 GEMM kernel matches a dequantized torch.matmul."""
    device = torch.device("cuda")
    input_shape = (tokens, in_features)
    inputs = torch.rand(input_shape, dtype=dtype, device=device)
    other_shape = (in_features, out_features)
    other_data = torch.rand(other_shape, dtype=dtype, device=device).to(torch.float8_e4m3fn)
    # Repack the fp8 weights into the Marlin layout (empty perm == no activation reordering)
    other_data_int32 = pack_fp8_as_int32(other_data)
    perm = torch.empty(0, dtype=torch.int, device=device)
    other_data_repack = torch.ops.quanto.pack_fp8_marlin(
        b_q_weight=other_data_int32, perm=perm, size_k=in_features, size_n=out_features, num_bits=8
    )
    other_scale = torch.rand(1, out_features, dtype=dtype, device=device)
    # Keep the unpermuted scales to build the reference result
    other_scale_original = other_scale.clone()
    scale_perm_single = get_scale_perms()
    other_scale = other_scale.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
    other_scale = other_scale.reshape(-1, out_features).contiguous()
    workspace = torch.zeros(out_features // 64 * 16, dtype=torch.int, device=device)
    lib_outputs = torch.ops.quanto.gemm_f16f8_marlin(
        a=inputs,
        b_q_weight=other_data_repack,
        b_scales=other_scale,
        workspace=workspace,
        num_bits=8,
        size_m=tokens,
        size_n=out_features,
        size_k=in_features,
    )
    # Evaluate the matrix multiplication using pytorch mm
    other = other_data.to(dtype) * other_scale_original
    pt_outputs = torch.matmul(inputs.to(dtype), other)
    # Verify the results are similar
    assert_similar(lib_outputs, pt_outputs)


@pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 8,
    reason="CUDA device >= sm80 not available",
)
@pytest.mark.parametrize("in_features, out_features", [(256, 256), (512, 256)])
@pytest.mark.parametrize("batch_size, tokens", [(1, 16), (10, 128)], ids=["small", "medium"])
def test_gemm_marlin_fp16_int4(batch_size, tokens, in_features, out_features):
    """Verify the Marlin int4 GEMM kernel matches a dequantized torch.matmul."""
    bits = 4
    group_size = 128  # Hard-coded in kernels
    device = torch.device("cuda")
    input_shape = (batch_size, tokens, in_features)
    # FIXME: does not work if inputs are negative !!??
    inputs = torch.rand(input_shape, dtype=torch.float16, device=device)
    qmax = 2**bits
    other_shape = (out_features, in_features)
    other_data = torch.randint(0, qmax, other_shape, dtype=torch.uint8, device=device)
    # The GEMM kernel works on transposed scales
    scales_shape = (in_features // group_size, out_features)
    other_scales = torch.rand(scales_shape, dtype=torch.float16, device=device) / qmax
    # This kernel works on transposed, negated and scaled zeropoints
    qmin = -(2 ** (bits - 1))
    qmax = 2 ** (bits - 1)
    other_shifts = torch.randint(qmin, qmax, scales_shape, dtype=torch.int8, device=device)
    # Negate and scale
    other_scaled_shifts = -other_shifts * other_scales
    workspace = torch.zeros(out_features // 128 * 16, dtype=torch.int, device=inputs.device)
    packed_other_data_marlin = MarlinInt4PackedTensor.pack(other_data)._data
    # Apply scale and shift permutations
    other_scales_marlin = marlin_permute(other_scales)
    other_scaled_shifts_marlin = marlin_permute(other_scaled_shifts)
    lib_outputs = torch.ops.quanto.gemm_f16i4_marlin(
        inputs, packed_other_data_marlin, other_scales_marlin, other_scaled_shifts_marlin, workspace
    )
    # Transpose other data and reshape it to align it with transposed scales and zeros
    other_data_t = other_data.t().reshape(group_size, in_features // group_size, out_features)
    # Dequantize transposed other
    other_t = other_data_t * other_scales + other_scaled_shifts
    # Reshape it as expected by the matmul
    other_t = other_t.reshape(in_features, out_features)
    # Evaluate the matrix multiplication using pytorch float16 mm
    pt_outputs = torch.matmul(inputs, other_t)
    # Verify the results are similar
    assert_similar(lib_outputs, pt_outputs, rtol=1e-3)

================================================
FILE: tests/library/test_quantize.py
================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest import torch from helpers import assert_similar, device_eq, random_tensor from optimum.quanto import ( MaxOptimizer, absmax_scale, qfloat8, qfloat8_e4m3fn, qfloat8_e4m3fnuz, qfloat8_e5m2, qint2, qint4, qint8, ) from optimum.quanto.tensor.grouped import ungroup @pytest.mark.parametrize("input_shape", [(32, 32), (32, 10, 32)]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"]) @pytest.mark.parametrize("qtype", [qint8], ids=["qint8"]) @pytest.mark.parametrize( "axis", [None, 0, -1], ids=["per-tensor", "first-axis", "last-axis"], ) def test_symmetric_quantize_int(input_shape, dtype, qtype, axis, device): a = random_tensor(input_shape, dtype=dtype).to(device) scale = absmax_scale(a, qtype=qtype, axis=axis) data = torch.ops.quanto.quantize_symmetric(a, dtype=qtype.dtype, axis=axis, scale=scale) assert data.dtype == qtype.dtype assert device_eq(data.device, device) assert_similar(a, data * scale) @pytest.mark.skip_device("mps") @pytest.mark.parametrize("input_shape", [(32, 32), (32, 10, 32)]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"]) @pytest.mark.parametrize( "qtype", [qfloat8, qfloat8_e4m3fn, qfloat8_e4m3fnuz, qfloat8_e5m2], ids=["qfloat8", "qfloat8_e4m3fn", "qfloat8_e4m3fnuz", "qfloat8_e5m2"], ) @pytest.mark.parametrize( "axis", [None, 0, -1], ids=["per-tensor", "first-axis", "last-axis"], ) def test_symmetric_quantize_float8(input_shape, dtype, qtype, axis, device): a = random_tensor(input_shape, dtype=dtype).to(device) scale = absmax_scale(a, qtype=qtype, axis=axis) data = torch.ops.quanto.quantize_symmetric(a, dtype=qtype.dtype, axis=axis, scale=scale) assert data.dtype == qtype.dtype assert device_eq(data.device, device) assert_similar(a, data.to(dtype) * scale, atol=5e-3) @pytest.mark.parametrize("input_shape", [(32, 32), (32, 10, 32)]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"]) @pytest.mark.parametrize("qtype", [qint2, 
qint4], ids=["qint2", "qint4"]) @pytest.mark.parametrize("axis", [0, -1], ids=["first-axis", "last-axis"]) @pytest.mark.parametrize("group_size", [None, 8], ids=["channel-wise", "group-wise"]) @pytest.mark.parametrize("shift_mode", ["zeropoint", "float"]) def test_affine_quantize(input_shape, dtype, qtype, axis, group_size, shift_mode, device): a = random_tensor(input_shape, dtype=dtype).to(device) scale, shift = MaxOptimizer()(a, qtype=qtype, axis=axis, group_size=group_size) if shift_mode == "zeropoint": shift = torch.round(shift / scale).to(torch.int8) data = torch.ops.quanto.quantize_affine(a, qtype.bits, axis, group_size, scale, shift) assert data.dtype == torch.uint8 assert device_eq(data.device, device) if shift_mode == "zeropoint": qa = (data - shift) * scale else: qa = data * scale - shift atol = { qint4: { "zeropoint": 4e-3, "float": 3e-3, }, qint2: { "zeropoint": 6e-2, "float": 5e-2, }, }[qtype][shift_mode] if group_size is not None: qa = ungroup(qa, axis=axis, orig_shape=a.shape) assert_similar(a, qa, atol=atol) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"]) @pytest.mark.parametrize("qtype", [qint2, qint4], ids=["qint2", "qint4"]) def test_affine_quantize_integer_tensor(dtype, qtype, device): """This test verifies that an integer tensor in the correct range is preserved.""" bits = qtype.bits qmin = -(2 ** (bits - 1)) qmax = 2 ** (bits - 1) - 1 a = torch.tensor(range(qmin, qmax + 1), dtype=dtype).to(device) scale, shift = MaxOptimizer()(a, qtype=qtype, axis=0, group_size=None) zeropoint = torch.round(shift / scale) data = torch.ops.quanto.quantize_affine(a, bits, 0, None, scale, zeropoint) assert data.dtype == torch.uint8 assert device_eq(data.device, device) assert torch.equal(a, data - zeropoint) ================================================ FILE: tests/library/test_unpack.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from optimum.quanto.tensor.packed import pack_weights


@pytest.mark.parametrize("bits", [2, 4], ids=["int2", "int4"])
@pytest.mark.parametrize("shape", [(12,), (32, 32)], ids=["vector", "matrix"])
def test_unpack(bits, shape, device):
    """Packing then unpacking int2/int4 weights must be lossless."""
    qmax = 2**bits
    a = torch.randint(0, qmax, shape, dtype=torch.uint8).to(device)
    packed_a = pack_weights(a, bits)
    unpacked_a = torch.ops.quanto.unpack(packed_a, bits)
    assert unpacked_a.dtype == torch.uint8
    assert torch.equal(unpacked_a, a)


================================================ FILE: tests/models/conftest.py ================================================
import pytest
from huggingface_hub.constants import _staging_mode


@pytest.fixture
def staging():
    """A pytest fixture only available in huggingface_hub staging mode

    If the huggingface_hub is not operating in staging mode, tests using
    that fixture are automatically skipped.

    Returns:
        a Dict containing a valid staging user and token.
    """
    if not _staging_mode:
        pytest.skip("requires huggingface_hub staging mode")
    return {
        "user": "__DUMMY_TRANSFORMERS_USER__",
        # Not critical, only usable on the sandboxed CI instance.
        "token": "hf_94wBhPGp6KrrTH3KDchhKpRxZwd6dmHWLL",
    }


@pytest.fixture(autouse=True)
def skip_if_staging(request):
    # Inverse guard: in staging mode, skip every test that does not use the staging fixture
    if _staging_mode:
        if "staging" not in request.fixturenames:
            pytest.skip("requires huggingface_hub standard mode")


================================================ FILE: tests/models/test_quantized_model_for_causal_lm.py ================================================
import uuid
from tempfile import TemporaryDirectory

import pytest
import torch
from huggingface_hub import delete_repo

from optimum.quanto import QModuleMixin, is_transformers_available, qint4, qint8


def quantized_model_for_causal_lm(model_id, qtype, exclude, from_config=False):
    """Build a quantized OPT causal-lm model, either from a tiny config or from a hub checkpoint."""
    from transformers import AutoModelForCausalLM, OPTConfig

    from optimum.quanto import QuantizedModelForCausalLM

    if from_config:
        # Tiny OPT configuration so config-based tests do not need to download weights
        config = OPTConfig(
            **{
                "activation_dropout": 0.0,
                "activation_function": "relu",
                "architectures": ["OPTForCausalLM"],
                "attention_dropout": 0.0,
                "bos_token_id": 2,
                "do_layer_norm_before": True,
                "dropout": 0.1,
                "eos_token_id": 2,
                "ffn_dim": 32,
                "hidden_size": 8,
                "init_std": 0.02,
                "layerdrop": 0.0,
                "max_position_embeddings": 16,
                "model_type": "opt",
                "num_attention_heads": 2,
                "num_hidden_layers": 2,
                "pad_token_id": 1,
                "prefix": "",
                "torch_dtype": "float16",
                "use_cache": True,
                "vocab_size": 64,
                "word_embed_proj_dim": 8,
            }
        )
        model = AutoModelForCausalLM.from_config(config).eval()
    else:
        model = AutoModelForCausalLM.from_pretrained(model_id)
    return QuantizedModelForCausalLM.quantize(model, weights=qtype, exclude=exclude)


def compare_models(a_model, b_model):
    """Check that two quantized models have identical modules, tensors and outputs."""
    # Compare tensors
    for (a_name, a_m), (b_name, b_m) in zip(a_model.named_modules(), b_model.named_modules()):
        assert a_name == b_name
        if isinstance(a_m, QModuleMixin):
            assert isinstance(b_m, QModuleMixin)
        if isinstance(b_m, QModuleMixin):
            assert isinstance(a_m, QModuleMixin)
        if isinstance(a_m, QModuleMixin):
            assert torch.equal(a_m.weight, b_m.weight)
        for (a_p_name, a_p), (b_p_name, b_p) in zip(a_m.named_parameters(), b_m.named_parameters()):
            # (continuation of compare_models: per-parameter comparison)
            assert a_p_name == b_p_name
            assert isinstance(a_p, torch.Tensor)
            assert torch.equal(a_p, b_p)
    # Compare model outputs
    inputs = torch.ones((1, 1), dtype=torch.int64)
    with torch.no_grad():
        output_a = a_model.forward(inputs)
        output_b = b_model.forward(inputs)
    assert torch.equal(output_a.logits, output_b.logits)
    for i, a_key_value in enumerate(output_a.past_key_values):
        b_key_value = output_b.past_key_values[i]
        for j, a_value in enumerate(a_key_value):
            assert torch.equal(a_value, b_key_value[j])


@pytest.mark.skipif(not is_transformers_available(), reason="requires transformers")
@pytest.mark.parametrize("model_id", ["facebook/opt-125m"])
@pytest.mark.parametrize("qtype", [qint4, qint8], ids=["qint4", "qint8"])
@pytest.mark.parametrize("exclude_lm_head", [True, False], ids=["full", "no_lm_head"])
def test_quantized_model_for_causal_lm_base(model_id, qtype, exclude_lm_head):
    """Serialization round-trip: save_pretrained then from_pretrained must preserve the model."""
    from optimum.quanto import QuantizedModelForCausalLM

    exclude = "lm_head" if exclude_lm_head else None
    quantized = quantized_model_for_causal_lm(model_id, qtype, exclude)
    with TemporaryDirectory() as tmpdir:
        quantized.save_pretrained(tmpdir)
        requantized = QuantizedModelForCausalLM.from_pretrained(tmpdir)
        compare_models(quantized, requantized)


@pytest.mark.skipif(not is_transformers_available(), reason="requires transformers")
def test_quantized_model_for_causal_lm_sharded():
    """Round-trip with sharded serialization (max_shard_size forces several shards)."""
    from optimum.quanto import QuantizedModelForCausalLM

    model_id = "facebook/opt-125m"
    qtype = qint4
    quantized = quantized_model_for_causal_lm(model_id, qtype, exclude=None)
    with TemporaryDirectory() as tmpdir:
        quantized.save_pretrained(tmpdir, max_shard_size="100MB")
        requantized = QuantizedModelForCausalLM.from_pretrained(tmpdir)
        compare_models(quantized, requantized)


@pytest.mark.skipif(not is_transformers_available(), reason="requires transformers")
@pytest.mark.parametrize("in_org", [True, False], ids=["org", "user"])
def test_causal_lm_base_push_to_hub(staging, in_org):
    """Round-trip through the (staging) hub, both under the user and a valid org namespace."""
    from optimum.quanto import QuantizedModelForCausalLM

    identifier = uuid.uuid4()
    qtype = qint4
    exclude = None
    quantized = quantized_model_for_causal_lm(None, qtype, exclude, from_config=True)
    repo_id = f"test-model-{identifier}"
    # NOTE(review): the "org"/"user" parametrize ids look swapped relative to the
    # branches below (in_org=True pushes under the user namespace) — confirm intent.
    if in_org:
        quantized.push_to_hub(repo_id, token=staging["token"])
        hub_repo_id = f"{staging['user']}/{repo_id}"
    else:
        hub_repo_id = f"valid_org/{repo_id}-org"
        quantized.push_to_hub(hub_repo_id, token=staging["token"])
    requantized = QuantizedModelForCausalLM.from_pretrained(hub_repo_id, token=staging["token"])
    compare_models(quantized, requantized)
    delete_repo(hub_repo_id, token=staging["token"])


@pytest.mark.skipif(not is_transformers_available(), reason="requires transformers")
@pytest.mark.parametrize("model_id", ["facebook/opt-125m"])
@pytest.mark.parametrize("qtype", [qint4, qint8], ids=["qint4", "qint8"])
def test_quantized_model_load_state_dict_non_strict(model_id, qtype):
    # see issue #278
    quantized = quantized_model_for_causal_lm(model_id, qtype, exclude=None)
    sd = quantized.state_dict()
    # delete a key used by both qint4 and qint8 from the state dict
    key = "model.decoder.layers.0.self_attn.k_proj.weight._scale"
    del sd[key]
    # strict loading should raise a RuntimeError, which is what PyTorch does in this case
    with pytest.raises(RuntimeError, match=key):
        quantized.load_state_dict(sd)
    # non-strict loading should not raise an error
    result = quantized.load_state_dict(sd, strict=False)
    assert result.missing_keys == [key]


================================================ FILE: tests/models/test_quantized_model_for_pixart.py ================================================
import uuid
from tempfile import TemporaryDirectory

import pytest
import torch
from huggingface_hub import delete_repo

from optimum.quanto import QModuleMixin, is_diffusers_available, qint4, qint8


def quantized_model_for_pixart(qtype, exclude):
    """Build a tiny quantized PixArt transformer for serialization tests."""
    from diffusers import PixArtTransformer2DModel

    from optimum.quanto import QuantizedPixArtTransformer2DModel

    # Minimal configuration to keep the test model tiny
    init_dict = {
        "sample_size": 8,
        "num_layers": 1,
        "patch_size": 2,
"attention_head_dim": 2, "num_attention_heads": 2, "in_channels": 4, "cross_attention_dim": 8, "out_channels": 8, "attention_bias": True, "activation_fn": "gelu-approximate", "num_embeds_ada_norm": 8, "norm_type": "ada_norm_single", "norm_elementwise_affine": False, "norm_eps": 1e-6, "use_additional_conditions": False, "caption_channels": None, } torch.manual_seed(0) model = PixArtTransformer2DModel(**init_dict).eval() return QuantizedPixArtTransformer2DModel.quantize(model, weights=qtype, exclude=exclude) def compare_models(a_model, b_model): # Compare tensors for (a_name, a_m), (b_name, b_m) in zip(a_model.named_modules(), b_model.named_modules()): assert a_name == b_name if isinstance(a_m, QModuleMixin): assert isinstance(b_m, QModuleMixin) if isinstance(b_m, QModuleMixin): assert isinstance(a_m, QModuleMixin) if isinstance(a_m, QModuleMixin): assert torch.equal(a_m.weight, b_m.weight) for (a_p_name, a_p), (b_p_name, b_p) in zip(a_m.named_parameters(), b_m.named_parameters()): assert a_p_name == b_p_name assert isinstance(a_p, torch.Tensor) assert torch.equal(a_p, b_p) for (a_b_name, a_b), (b_b_name, b_b) in zip(a_m.named_buffers(), b_m.named_buffers()): assert a_b_name == b_b_name assert isinstance(a_b, torch.Tensor) assert torch.equal(a_b, b_b) # Compare model outputs hidden_states = torch.randn((1, 4, 8, 8)) timesteps = torch.tensor([1.0]) encoder_hidden_states = torch.randn((1, 8, 8)) model_inputs = { "hidden_states": hidden_states, "timestep": timesteps, "encoder_hidden_states": encoder_hidden_states, "added_cond_kwargs": {"aspect_ratio": None, "resolution": None}, "return_dict": False, } with torch.no_grad(): output_a = a_model(**model_inputs)[0] output_b = b_model(**model_inputs)[0] assert torch.allclose(output_a, output_b, atol=1e-3, rtol=1e-3) @pytest.mark.skipif(not is_diffusers_available(), reason="requires diffusers") @pytest.mark.parametrize("qtype", [qint4, qint8], ids=["qint4", "qint8"]) @pytest.mark.parametrize("exclude_proj_out", [True, False], 
ids=["without_proj_out", "with_proj_out"]) def test_quantized_model_for_pixart(qtype, exclude_proj_out): from optimum.quanto import QuantizedPixArtTransformer2DModel exclude = "proj_out" if exclude_proj_out else None quantized = quantized_model_for_pixart(qtype, exclude) with TemporaryDirectory() as tmpdir: quantized.save_pretrained(tmpdir) requantized = QuantizedPixArtTransformer2DModel.from_pretrained(tmpdir) compare_models(quantized, requantized) @pytest.mark.skipif(not is_diffusers_available(), reason="requires diffusers") @pytest.mark.parametrize("in_org", [True, False], ids=["org", "user"]) def test_push_to_hub(staging, in_org): from optimum.quanto import QuantizedPixArtTransformer2DModel identifier = uuid.uuid4() exclude = None quantized = quantized_model_for_pixart("qint8", exclude) repo_id = f"test-model-{identifier}" if in_org: quantized.push_to_hub(repo_id, token=staging["token"]) hub_repo_id = f"{staging['user']}/{repo_id}" else: hub_repo_id = f"valid_org/{repo_id}-org" quantized.push_to_hub(hub_repo_id, token=staging["token"]) requantized = QuantizedPixArtTransformer2DModel.from_pretrained(hub_repo_id, token=staging["token"]) compare_models(quantized, requantized) delete_repo(hub_repo_id, token=staging["token"]) ================================================ FILE: tests/nn/test_calibrate.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import pytest
import torch
from helpers import random_qactivation

from optimum.quanto import Calibration, qfloat8_e4m3fn, qfloat8_e4m3fnuz, qfloat8_e5m2, qint8
from optimum.quanto.nn import QLinear


def _test_calibrate_qlinear(batch_size, tokens, embeddings, use_bias, activations, device):
    """Shared driver: calibration must move a QLinear's input/output scales away from 1."""
    linear = torch.nn.Linear(embeddings, embeddings, bias=use_bias).to(device)
    qlinear = QLinear.from_module(linear, weights=qint8, activations=activations)
    qinputs = random_qactivation(
        (batch_size, tokens, embeddings), qtype=activations, dtype=torch.float32, device=device
    )
    # Run a first inference without Calibration
    with torch.no_grad():
        qout = qlinear(qinputs)
    assert torch.all(qlinear.input_scale == 1)
    assert torch.all(qlinear.output_scale == 1)
    # Calibrate to adjust input and output scales and set the correct dtype
    with torch.no_grad(), Calibration():
        qout = qlinear(qinputs)
    assert qout.qtype == activations
    assert torch.any(qlinear.input_scale != 1)
    assert torch.any(qlinear.output_scale != 1)


@pytest.mark.parametrize("batch_size", [1, 10])
@pytest.mark.parametrize("tokens, embeddings", [(32, 32), (10, 32)])
@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
def test_calibrate_qlinear_activations_int8(batch_size, tokens, embeddings, use_bias, device):
    _test_calibrate_qlinear(batch_size, tokens, embeddings, use_bias, qint8, device)


@pytest.mark.parametrize("batch_size", [1, 10])
@pytest.mark.parametrize("tokens, embeddings", [(32, 32), (10, 32)])
@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
@pytest.mark.parametrize(
    "activations",
    [qfloat8_e5m2, qfloat8_e4m3fn, qfloat8_e4m3fnuz],
    ids=["a-qfloat8-e5m2", "a-qfloat8-e4m3", "a-qfloat8-e4m3-uz"],
)
@pytest.mark.skip_device("mps")
def test_calibrate_qlinear_activations_float8(batch_size, tokens, embeddings, use_bias, activations, device):
    _test_calibrate_qlinear(batch_size, tokens, embeddings, use_bias, activations, device)


def _test_calibrate_custom_module(activations, device):
    """Shared driver: calibration must propagate through a two-layer custom module."""
    tokens = 10
    embeddings = 32

    class TwoLinearModel(torch.nn.Module):
        def __init__(self, embeddings):
            super().__init__()
            self.linear1 = torch.nn.Linear(embeddings, embeddings)
            self.linear2 = torch.nn.Linear(embeddings, embeddings)

        def forward(self, input):
            return self.linear2(self.linear1(input))

    model = TwoLinearModel(embeddings).to(device)
    model.linear1 = QLinear.from_module(model.linear1, weights=qint8, activations=activations)
    model.linear2 = QLinear.from_module(model.linear2, weights=qint8, activations=activations)
    qinputs = random_qactivation((1, tokens, embeddings), qtype=activations, dtype=torch.float32, device=device)
    with torch.no_grad(), Calibration():
        qout = model(qinputs)
    # Both layers must have been calibrated by the single forward pass
    assert torch.any(model.linear1.input_scale != 1)
    assert torch.any(model.linear1.output_scale != 1)
    assert torch.any(model.linear2.input_scale != 1)
    assert torch.any(model.linear2.output_scale != 1)
    assert qout.qtype == activations


def test_calibrate_custom_module_activations_int8(device):
    _test_calibrate_custom_module(qint8, device)


@pytest.mark.parametrize(
    "activations",
    [qfloat8_e5m2, qfloat8_e4m3fn, qfloat8_e4m3fnuz],
    ids=["a-qfloat8-e5m2", "a-qfloat8-e4m3", "a-qfloat8-e4m3-uz"],
)
@pytest.mark.skip_device("mps")
def test_calibrate_custom_module_activations_float8(activations, device):
    _test_calibrate_custom_module(activations, device)


================================================ FILE: tests/nn/test_qattention.py ================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Optional

import pytest
import torch
import torch.utils.checkpoint
from helpers import assert_similar, random_tensor
from torch import nn

from optimum.quanto import Calibration, qfloat8_e4m3fn, qfloat8_e4m3fnuz, qfloat8_e5m2, qint8, quantize


class RotaryEmbedding(nn.Module):
    """Rotary position embedding with a cached cos/sin table (test helper)."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        # (Re)build the cos/sin cache for sequences up to seq_len
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, hidden_size=128, num_heads=4, max_position_embeddings=1024, bias=False):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = 10000.0
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=bias)
        self.rotary_emb = RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # Reshape a flat projection to (batch, heads, seq, head_dim)
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        bsz, q_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len,
        self.num_heads, self.head_dim).transpose(1, 2)
        kv_seq_len = key_states.shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask
        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        return self.o_proj(attn_output)


def _test_quantize_attention(device, dtype=torch.float32, weights=qint8, activations=None, atol=None):
    """Shared driver: quantizing the test Attention module must keep outputs close to float."""
    att = Attention().to(dtype).to(device)
    batch_size = 10
    seq_len = 64
    input_shape = (batch_size, seq_len, att.hidden_size)
    inputs = random_tensor(input_shape).to(device)
    with torch.no_grad():
        outputs = att(inputs)
    quantize(att, weights=weights, activations=activations)
    if activations is None:
        with torch.no_grad():
            qoutputs = att(inputs)
    else:
        # Activation quantization requires a calibration pass to set scales
        with torch.no_grad(), Calibration():
            qoutputs = att(inputs)
    assert_similar(outputs, qoutputs, atol=atol)


@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
def test_quantize_attention_weights_only(weights, device):
    _test_quantize_attention(device, weights=weights, atol=1e-4)


@pytest.mark.skip_device("mps")
def test_quantize_attention_weights_only_float8(device):
    _test_quantize_attention(device, weights=qfloat8_e4m3fn, atol=1e-3)


@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
def test_quantize_attention_activations_int8(weights, device):
    _test_quantize_attention(device, weights=weights, activations=qint8, atol=1e-3)


@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
@pytest.mark.parametrize(
    "activations",
    [qfloat8_e5m2, qfloat8_e4m3fn, qfloat8_e4m3fnuz],
    ids=["a-float8-e5m2", "a-float8-e4m3", "a-float8-e4m3-uz"],
)
@pytest.mark.skip_device("mps")
def test_quantize_attention_activations_float8(weights, activations, device):
    _test_quantize_attention(device, weights=weights, activations=activations, atol=1e-2)


================================================ FILE: tests/nn/test_qconv2d.py ================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from helpers import assert_similar, random_qactivation, random_tensor

from optimum.quanto import (
    ActivationQBytesTensor,
    Calibration,
    qfloat8_e4m3fn,
    qfloat8_e4m3fnuz,
    qfloat8_e5m2,
    qint4,
    qint8,
)
from optimum.quanto.nn import QConv2d


def _test_quantize_conv2d(batch_size, img_shape, out_channels, use_bias, weights, activations, dtype, device):
    """Shared driver: a QConv2d must match a float Conv2d that uses its dequantized weights."""
    conv2d = torch.nn.Conv2d(img_shape[0], out_channels, kernel_size=3, bias=use_bias).to(dtype).to(device)
    qconv2d = QConv2d.from_module(conv2d, weights=weights, activations=activations)
    assert qconv2d.qweight.qtype == weights
    inputs = random_tensor((batch_size,) + img_shape, dtype=dtype, device=device)
    # Run an inference with Calibration to get the correct output dtype
    with torch.no_grad(), Calibration():
        qout = qconv2d(inputs)
    if activations is not None:
        assert isinstance(qout, ActivationQBytesTensor)
        assert qout.qtype == activations
    # Align weights with quantized linear weights for comparison
    conv2d.weight = torch.nn.Parameter(qconv2d.qweight.dequantize())
    out = conv2d(inputs)
    # We need to increase atol for float16 dtype
    dtype_atol = {torch.float32: 1e-4, torch.float16: 1e-3}[dtype]
    # We also need to increase atol for float8 itypes
    atol = {None: dtype_atol, qint8: dtype_atol, qfloat8_e5m2: 5e-3, qfloat8_e4m3fn: 5e-3, qfloat8_e4m3fnuz: 5e-3}[
        activations
    ]
    assert_similar(out, qout, atol=atol)


@pytest.mark.parametrize("batch_size", [1, 10])
@pytest.mark.parametrize("img_shape", [(3, 32, 32), (10, 32, 32)])
@pytest.mark.parametrize("out_channels", [3, 10])
@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
@pytest.mark.parametrize("weights", [qint4, qint8], ids=["w-int4", "w-int8"])
def test_quantize_conv2d_float16_activations_int8(batch_size, img_shape, out_channels, use_bias, weights, device):
    _test_quantize_conv2d(batch_size, img_shape, out_channels, use_bias, weights, qint8, torch.float16, device)


@pytest.mark.parametrize("batch_size", [1, 10])
@pytest.mark.parametrize("img_shape", [(3, 32, 32), (10, 32, 32)])
@pytest.mark.parametrize("out_channels", [3, 10])
@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
@pytest.mark.parametrize("weights", [qint4, qint8], ids=["w-int4", "w-int8"])
def test_quantize_conv2d_float32_activations_int8(batch_size, img_shape, out_channels, use_bias, weights, device):
    _test_quantize_conv2d(batch_size, img_shape, out_channels, use_bias, weights, qint8, torch.float32, device)


@pytest.mark.parametrize("batch_size", [1, 10])
@pytest.mark.parametrize("img_shape", [(3, 32, 32), (10, 32, 32)])
@pytest.mark.parametrize("out_channels", [3, 10])
@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
@pytest.mark.parametrize("weights", [qint4, qint8], ids=["w-int4", "w-int8"])
@pytest.mark.parametrize(
    "activations",
    [qfloat8_e5m2, qfloat8_e4m3fn, qfloat8_e4m3fnuz],
    # NOTE(review): "a-float8_e4m3-uz" uses an underscore unlike the float32 sibling
    # test's "a-float8-e4m3-uz" — probably a typo in the test id; confirm before renaming.
    ids=["a-float8-e5m2", "a-float8-e4m3", "a-float8_e4m3-uz"],
)
@pytest.mark.skip_device("mps")
def test_quantize_conv2d_float16_activations_float8(
    batch_size, img_shape, out_channels, use_bias, weights, activations, device
):
    _test_quantize_conv2d(batch_size, img_shape, out_channels, use_bias, weights, activations, torch.float16, device)


@pytest.mark.parametrize("batch_size", [1, 10])
@pytest.mark.parametrize("img_shape", [(3, 32, 32), (10, 32, 32)])
@pytest.mark.parametrize("out_channels", [3, 10])
@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
@pytest.mark.parametrize("weights", [qint4, qint8], ids=["w-int4", "w-int8"])
@pytest.mark.parametrize(
    "activations",
    [qfloat8_e5m2, qfloat8_e4m3fn, qfloat8_e4m3fnuz],
    ids=["a-float8-e5m2", "a-float8-e4m3", "a-float8-e4m3-uz"],
)
@pytest.mark.skip_device("mps")
def test_quantize_conv2d_float32_activations_float8(
    batch_size, img_shape, out_channels, use_bias, weights, activations, device
):
    _test_quantize_conv2d(batch_size, img_shape, out_channels, use_bias, weights, activations, torch.float32, device)


@pytest.mark.parametrize("batch_size", [1, 10])
@pytest.mark.parametrize("img_shape", [(3, 32, 32), (10, 32, 32)])
@pytest.mark.parametrize("out_channels", [3, 10])
@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
@pytest.mark.parametrize("weights", [qint4, qint8], ids=["w-int4", "w-int8"])
def test_quantize_conv2d_float16_weight_only(batch_size, img_shape, out_channels, use_bias, weights, device):
    _test_quantize_conv2d(batch_size, img_shape, out_channels, use_bias, weights, None, torch.float16, device)


@pytest.mark.parametrize("batch_size", [1, 10])
@pytest.mark.parametrize("img_shape", [(3, 32, 32), (10, 32, 32)])
@pytest.mark.parametrize("out_channels", [3, 10])
@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
@pytest.mark.parametrize("weights", [qint4, qint8], ids=["w-int4", "w-int8"])
def test_quantize_conv2d_float32_weight_only(batch_size, img_shape, out_channels, use_bias, weights, device):
    _test_quantize_conv2d(batch_size, img_shape, out_channels, use_bias, weights, None, torch.float32, device)


@pytest.mark.parametrize("img_shape", [(3, 32, 32), (10, 32, 32)])
@pytest.mark.parametrize("out_channels", [3, 10])
@pytest.mark.parametrize("activations", [None, qint8], ids=["a-float", "a-int8"])
@pytest.mark.parametrize("weights", [qint4, qint8], ids=["w-int4", "w-int8"])
def test_qconv2d_gradient(img_shape, out_channels, activations, weights, device):
    """Gradients through a QConv2d must closely match the float Conv2d gradients."""
    batch_size = 10
    conv2d = torch.nn.Conv2d(img_shape[0], out_channels, kernel_size=3, bias=True).to(device)
    qconv2d = QConv2d.from_module(conv2d, weights=weights, activations=activations)
    assert qconv2d.weight.requires_grad is True
    assert qconv2d.bias.requires_grad is True
    # Run an inference with identical inputs
    qinputs = random_qactivation((batch_size,) + img_shape, dtype=torch.float32).to(device)
    qout = qconv2d(qinputs)
    out = conv2d(qinputs.dequantize())
    # Outputs are not identical because of the quantization
    assert not torch.equal(qout, out)
    # Compute gradients and compare
    gradient = torch.randn(qout.size()).to(device)
    qout.backward(gradient)
    out.backward(gradient)
    # Gradients are nearly identical because they depend only on the input
    atol = 1e-5
    assert_similar(qconv2d.weight.grad, conv2d.weight.grad, atol=atol)
    assert_similar(qconv2d.bias.grad, conv2d.bias.grad, atol=atol)


================================================ FILE: tests/nn/test_qlayernorm.py ================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import pytest import torch from helpers import assert_similar, random_qactivation from optimum.quanto import ActivationQBytesTensor, Calibration, qfloat8_e4m3fn, qfloat8_e4m3fnuz, qfloat8_e5m2, qint8 from optimum.quanto.nn import QLayerNorm def _test_quantize_layernorm(batch_size, tokens, embeddings, affine, dtype, activations, device): # Instantiate a normalization layer norm = torch.nn.LayerNorm(embeddings, elementwise_affine=affine).to(dtype).to(device) qnorm = QLayerNorm.from_module(norm, activations=activations) qinputs = random_qactivation((batch_size,) + (tokens, embeddings), qtype=activations, dtype=dtype).to(device) # Calibrate to avoid clipping and to set the correct dtype with torch.no_grad(), Calibration(): qout = qnorm(qinputs) qout = qnorm(qinputs) assert isinstance(qout, ActivationQBytesTensor) assert qout.dtype == dtype assert qout.qtype == activations # Compare with the float results out = norm(qinputs.dequantize()) # We need to increase atol for float16 dtype dtype_atol = {torch.float32: 1e-4, torch.float16: 1e-3}[dtype] # We also need to increase atol for float8 qtypes atol = {qint8: dtype_atol, qfloat8_e5m2: 5e-3, qfloat8_e4m3fn: 5e-3, qfloat8_e4m3fnuz: 5e-3}[activations] assert_similar(out, qout, atol=atol) @pytest.mark.parametrize("batch_size", [1, 10]) @pytest.mark.parametrize("tokens, embeddings", [(32, 32), (10, 32)]) @pytest.mark.parametrize("affine", [True, False], ids=["affine", "non-affine"]) def test_quantize_layernorm_float16_activations_int8(batch_size, tokens, embeddings, affine, device): _test_quantize_layernorm(batch_size, tokens, embeddings, affine, torch.float16, qint8, device) @pytest.mark.parametrize("batch_size", [1, 10]) @pytest.mark.parametrize("tokens, embeddings", [(32, 32), (10, 32)]) @pytest.mark.parametrize("affine", [True, False], ids=["affine", "non-affine"]) def test_quantize_layernorm_float32_activations_int8(batch_size, tokens, embeddings, affine, device): _test_quantize_layernorm(batch_size, tokens, embeddings, 
affine, torch.float32, qint8, device) @pytest.mark.parametrize("batch_size", [1, 10]) @pytest.mark.parametrize("tokens, embeddings", [(32, 32), (10, 32)]) @pytest.mark.parametrize("affine", [True, False], ids=["affine", "non-affine"]) @pytest.mark.parametrize( "activations", [qfloat8_e5m2, qfloat8_e4m3fn, qfloat8_e4m3fnuz], ids=["a-float8-e5m2", "a-float8-e4m3", "a-float8-e4m3-uz"], ) @pytest.mark.skip_device("mps") def test_quantize_layernorm_float16_activations_float8(batch_size, tokens, embeddings, affine, activations, device): _test_quantize_layernorm(batch_size, tokens, embeddings, affine, torch.float16, activations, device) @pytest.mark.parametrize("batch_size", [1, 10]) @pytest.mark.parametrize("tokens, embeddings", [(32, 32), (10, 32)]) @pytest.mark.parametrize("affine", [True, False], ids=["affine", "non-affine"]) @pytest.mark.parametrize( "activations", [qfloat8_e5m2, qfloat8_e4m3fn, qfloat8_e4m3fnuz], ids=["a-float8-e5m2", "a-float8-e4m3", "a-float8-e4m3-uz"], ) @pytest.mark.skip_device("mps") def test_quantize_layernorm_float32_activations_float8(batch_size, tokens, embeddings, affine, activations, device): _test_quantize_layernorm(batch_size, tokens, embeddings, affine, torch.float32, activations, device) def test_quantize_layernom_no_activation(): norm = torch.nn.LayerNorm(32) qnorm = QLayerNorm.from_module(norm, activations=None) assert qnorm is None ================================================ FILE: tests/nn/test_qlinear.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import io from contextlib import nullcontext import pytest import torch from helpers import assert_similar, random_qactivation, random_tensor from optimum.quanto import ( ActivationQBytesTensor, Calibration, absmax_scale, qfloat8, qfloat8_e4m3fn, qfloat8_e4m3fnuz, qfloat8_e5m2, qint4, qint8, quantize_activation, ) from optimum.quanto.nn import QLinear def _test_quantize_linear(batch_size, tokens, embeddings, use_bias, weights, activations, dtype, device, atol=None): linear = torch.nn.Linear(embeddings, embeddings, bias=use_bias).to(dtype).to(device) qlinear = QLinear.from_module(linear, weights=weights, activations=activations) assert qlinear.qweight.qtype == weights input_shape = (batch_size, tokens, embeddings) if activations is not None: qinputs = random_qactivation(input_shape, qtype=activations, dtype=dtype).to(device) inputs = qinputs.dequantize() else: inputs = random_tensor(input_shape, dtype=dtype, device=device) # Run an inference with Calibration to get the correct output dtype context = nullcontext if activations is None else Calibration with torch.no_grad(), context(): qout = qlinear(inputs if activations is None else qinputs) if activations is not None: assert isinstance(qout, ActivationQBytesTensor) assert qout.qtype == activations # Align linear weights with quantized linear weights for comparison linear.weight = torch.nn.Parameter(qlinear.qweight.dequantize()) out = linear(inputs) assert_similar(out, qout, atol=atol) @pytest.mark.parametrize("batch_size", [1, 10]) @pytest.mark.parametrize("tokens, embeddings", [(10, 32), (10, 256)]) 
@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=["bf16", "fp16"]) @pytest.mark.parametrize("weights", [qint4, qint8], ids=["w-qint4", "w-qint8"]) def test_quantize_linear_float16_activations_int8(batch_size, tokens, embeddings, use_bias, dtype, weights, device): _test_quantize_linear(batch_size, tokens, embeddings, use_bias, weights, qint8, torch.float16, device) @pytest.mark.parametrize("batch_size", [1, 10]) @pytest.mark.parametrize("tokens, embeddings", [(10, 32), (10, 256)]) @pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"]) @pytest.mark.parametrize("weights", [qint4, qint8], ids=["w-qint4", "w-qint8"]) def test_quantize_linear_float32_activations_int8(batch_size, tokens, embeddings, use_bias, weights, device): # Default atol for float32 is 1e-6 atol = 1e-4 _test_quantize_linear(batch_size, tokens, embeddings, use_bias, weights, qint8, torch.float32, device, atol) @pytest.mark.parametrize("batch_size", [1, 10]) @pytest.mark.parametrize("tokens, embeddings", [(10, 32), (10, 256)]) @pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=["bf16", "fp16"]) @pytest.mark.parametrize("weights", [qint4, qint8], ids=["w-qint4", "w-qint8"]) @pytest.mark.parametrize( "activations", [qfloat8_e4m3fn, qfloat8_e4m3fnuz], ids=["a-qfloat8-e4m3", "a-float8-e4m3-uz"], ) @pytest.mark.skip_device("mps") def test_quantize_linear_float16_activations_float8( batch_size, tokens, embeddings, use_bias, dtype, weights, activations, device ): atol = 5e-3 _test_quantize_linear(batch_size, tokens, embeddings, use_bias, weights, activations, dtype, device, atol) @pytest.mark.parametrize("batch_size", [1, 10]) @pytest.mark.parametrize("tokens, embeddings", [(32, 32), (10, 32)]) @pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"]) 
@pytest.mark.parametrize("weights", [qint4, qint8], ids=["w-qint4", "w-qint8"]) @pytest.mark.parametrize( "activations", [qfloat8_e5m2, qfloat8_e4m3fn, qfloat8_e4m3fnuz], ids=["a-qfloat8-e5m2", "a-qfloat8-e4m3", "a-float8-e4m3-uz"], ) @pytest.mark.skip_device("mps") def test_quantize_linear_float32_activations_float8( batch_size, tokens, embeddings, use_bias, weights, activations, device ): atol = 5e-3 _test_quantize_linear( batch_size, tokens, embeddings, use_bias, weights, activations, torch.float32, device, atol=atol ) @pytest.mark.parametrize("batch_size", [1, 10]) @pytest.mark.parametrize("tokens, embeddings", [(10, 32), (10, 256)]) @pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"]) @pytest.mark.parametrize("weights", [qint4, qint8, qfloat8], ids=["w-qint4", "w-qint8", "float8"]) def test_quantize_linear_float16_weight_only(batch_size, tokens, embeddings, use_bias, weights, device): if device.type in ["mps"] and weights == qfloat8: pytest.skip(f"Float8 are not supported on {device.type} device") atol = None if device.type == "cuda" and weights == qfloat8 and embeddings % 64 == 0: # FIXME: accuracy is slightly worse using MARLIN FP8 kernels atol = 1e-2 _test_quantize_linear(batch_size, tokens, embeddings, use_bias, weights, None, torch.float16, device, atol) @pytest.mark.parametrize("batch_size", [1, 10]) @pytest.mark.parametrize("tokens, embeddings", [(10, 32), (10, 256)]) @pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"]) @pytest.mark.parametrize("weights", [qint4, qint8], ids=["w-qint4", "w-qint8"]) def test_quantize_linear_float32_weight_only(batch_size, tokens, embeddings, use_bias, weights, device): _test_quantize_linear(batch_size, tokens, embeddings, use_bias, weights, None, torch.float32, device) @pytest.mark.parametrize("tokens, embeddings", [(10, 32), (10, 256)]) @pytest.mark.parametrize("activations", [None, qint8, qfloat8], ids=["a-float", "a-qint8", "a-float8"]) 
@pytest.mark.parametrize("weights", [qint4, qint8, qfloat8], ids=["w-qint4", "w-qint8", "w-float8"]) def test_qlinear_gradient(tokens, embeddings, activations, weights, device): if device.type in ["mps"] and (activations == qfloat8 or weights == qfloat8): pytest.skip(f"Float8 is not supported on {device.type} device") batch_size = 10 linear = torch.nn.Linear(embeddings, embeddings).to(device) qlinear = QLinear.from_module(linear, weights=weights, activations=activations) assert qlinear.weight.requires_grad is True assert qlinear.bias.requires_grad is True # Run an inference with dynamically quantized inputs inputs = random_tensor((batch_size, tokens, embeddings), dtype=torch.float32, device=device) inputs.requires_grad = True if activations is None: qout = qlinear(inputs) float_inputs = inputs.clone().detach() else: qinputs = quantize_activation(inputs, qtype=activations, scale=absmax_scale(inputs, activations)) qout = qlinear(qinputs) # Run an equivalent inference with float inputs float_inputs = qinputs.dequantize().clone().detach() float_inputs.requires_grad = True out = linear(float_inputs) # Outputs are not identical because of the quantization assert not torch.equal(qout, out) # Compute gradients and compare gradient = torch.randn(qout.size()).to(device) qout.backward(gradient) out.backward(gradient) # Bias gradients are identical because they don't depend on inputs and weights atol = 1e-6 assert_similar(qlinear.bias.grad, linear.bias.grad, atol=atol) # Weights gradients are nearly identical, based on identical inputs through subtly different graphs atol = 1e-5 assert_similar(qlinear.weight.grad, linear.weight.grad, atol=atol) # Inputs gradients are slightly different because they depend on the quantized weights atol = {qint8: 1e-5, qint4: 5e-3, qfloat8: 5e-3}[weights] assert_similar(inputs.grad, float_inputs.grad, atol=atol) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32], ids=["bf16", "fp16", "fp32"]) 
@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"]) @pytest.mark.parametrize("weights", [qint4, qint8, qfloat8], ids=["w-int4", "w-int8", "w-float8"]) def test_move_qlinear(dtype, use_bias, weights, device): linear = torch.nn.Linear(1024, 1024, bias=use_bias).to(dtype) qlinear = QLinear.from_module(linear, weights=weights) qlinear.freeze() qlinear.to(device) inner_tensor_names, _ = qlinear.weight.__tensor_flatten__() for name in inner_tensor_names: assert getattr(qlinear.weight, name).device.type == device.type if use_bias: assert qlinear.bias.device.type == device.type @pytest.mark.parametrize("features", [10, 256], ids=["per-axis", "per-group"]) @pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"]) @pytest.mark.parametrize("weights", [qint4, qint8, qfloat8], ids=["w-qint4", "w-qint8", "w-qfloat8"]) @pytest.mark.parametrize("activations", [None, qint8, qfloat8], ids=["a-float", "a-qint8", "a-qfloat8"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) @pytest.mark.parametrize("weights_only", [True, False], ids=["weights-only", "pickle"]) def test_qlinear_serialization(features, use_bias, activations, weights, dtype, weights_only, device): if device.type in ["mps"] and (activations == qfloat8 or weights == qfloat8): pytest.skip(f"Float8 is not supported on {device.type} device") linear = torch.nn.Linear(features, features, bias=use_bias).to(dtype).to(device) qlinear = QLinear.from_module(linear, weights=weights, activations=activations) if activations is not None: qinputs = random_qactivation((10, 10, features), qtype=activations, dtype=dtype).to(device) with Calibration(): qlinear(qinputs) qlinear.freeze() b = io.BytesIO() torch.save(qlinear.state_dict(), b) b.seek(0) state_dict = torch.load(b, weights_only=weights_only) qlinear_reloaded = QLinear(features, features, weights=weights, activations=activations, bias=use_bias).to(device) qlinear_reloaded.load_state_dict(state_dict) 
assert qlinear_reloaded.weight_qtype == weights w = qlinear.weight w_reloaded = qlinear_reloaded.weight assert torch.equal(w, w_reloaded) if activations is not None: assert qlinear_reloaded.activation_qtype == activations for attr in ["input_scale", "output_scale"]: v = getattr(qlinear, attr) v_reloaded = getattr(qlinear_reloaded, attr) assert torch.equal(v, v_reloaded) ================================================ FILE: tests/nn/test_qmodule.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import pytest import torch from optimum.quanto import QTensor, qint8, qtypes from optimum.quanto.nn import QLinear @pytest.mark.parametrize("in_features", [8, 16]) @pytest.mark.parametrize("out_features", [32, 64]) @pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["fp32", "fp16"]) def test_qmodule_freeze(in_features, out_features, use_bias, dtype): qlinear = QLinear(in_features, out_features, bias=use_bias, weights=qint8).to(dtype) assert not qlinear.frozen assert not isinstance(qlinear.weight, QTensor) assert qlinear.weight.dtype == dtype if use_bias: assert not isinstance(qlinear.bias, QTensor) assert qlinear.bias.dtype == dtype qweight = qlinear.qweight assert isinstance(qweight, QTensor) assert qweight.dtype == dtype assert qweight.qtype == qint8 qlinear.freeze() assert qlinear.frozen assert isinstance(qlinear.weight, QTensor) assert qlinear.weight.dtype == dtype assert qlinear.weight.qtype == qint8 if use_bias: assert not isinstance(qlinear.bias, QTensor) assert qlinear.bias.dtype == dtype @pytest.mark.parametrize("weights", ["qint2", "qint4", "qint8", "qfloat8"]) @pytest.mark.parametrize("activations", [None, "qint8", "qfloat8"]) def test_qmodule_qtype_as_string(weights, activations): qlinear = QLinear(16, 64, weights=weights, activations=activations) assert qlinear.weight_qtype == qtypes[weights] assert qlinear.activation_qtype is None if activations is None else qtypes[activations] ================================================ FILE: tests/quantize/test_quantize_mlp.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from contextlib import nullcontext import pytest import torch from helpers import assert_similar, get_device_memory, random_tensor from optimum.quanto import ( AbsmaxOptimizer, ActivationQBytesTensor, Calibration, MaxOptimizer, QLinear, QTensor, absmax_scale, freeze, qfloat8_e4m3fn, qfloat8_e4m3fnuz, qfloat8_e5m2, qint4, qint8, quantize, quantize_activation, ) class MLP(torch.nn.Module): def __init__(self, input_size, output_size, hidden_size): super().__init__() self.input_layer = torch.nn.Linear(input_size, hidden_size) self.mid_layer = torch.nn.Linear(hidden_size, hidden_size) self.output_layer = torch.nn.Linear(hidden_size, output_size) def forward(self, inputs): x = torch.nn.functional.relu(self.input_layer(inputs)) x = torch.nn.functional.relu(self.mid_layer(x)) return torch.nn.functional.softmax(self.output_layer(x), dim=-1) def check_mlp(model, frozen): assert isinstance(model.input_layer, QLinear) assert isinstance(model.mid_layer, QLinear) assert isinstance(model.output_layer, QLinear) if frozen: assert isinstance(model.input_layer.weight, QTensor) assert isinstance(model.mid_layer.weight, QTensor) assert isinstance(model.output_layer.weight, QTensor) def _test_quantize_mlp(weights, activations, optimizer, frozen, device, atol=1e-6): model = MLP(32, 10, 128).to(device) inputs = random_tensor((1, 32), dtype=torch.float32, device=device) output = model(inputs) quantize(model, weights=weights, activations=activations, optimizer=optimizer) if frozen: freeze(model) check_mlp(model, frozen) if activations is not None: inputs = quantize_activation(inputs, 
qtype=activations, scale=absmax_scale(inputs)) context = Calibration else: context = nullcontext with context(): qoutput = model(inputs) if activations is not None: assert isinstance(qoutput, ActivationQBytesTensor) assert_similar(output, qoutput, atol=atol) @pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"]) @pytest.mark.parametrize("frozen", [True, False], ids=["frozen", "non-frozen"]) def test_quantize_mlp_weights_only(weights, frozen, device): _test_quantize_mlp(weights, None, None, frozen, device) @pytest.mark.skip_device("mps") @pytest.mark.parametrize("weights", [qfloat8_e4m3fn], ids=["w-float8_e4m3fn"]) @pytest.mark.parametrize("frozen", [True, False], ids=["frozen", "non-frozen"]) def test_quantize_mlp_weights_only_float8(weights, frozen, device): _test_quantize_mlp(weights, None, None, frozen, device) @pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"]) @pytest.mark.parametrize("frozen", [True, False], ids=["frozen", "non-frozen"]) @pytest.mark.skip_device("mps") def test_quantize_mlp_int8_activations(weights, frozen, device): _test_quantize_mlp(weights, qint8, None, frozen, device, atol=1e-3) @pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"]) @pytest.mark.parametrize( "activations", [qfloat8_e5m2, qfloat8_e4m3fn, qfloat8_e4m3fnuz], ids=["a-qfloat8-e5m2", "a-qfloat8-e4m3", "a-float8-e4m3-uz"], ) @pytest.mark.parametrize("frozen", [True, False], ids=["frozen", "non-frozen"]) @pytest.mark.skip_device("mps") def test_quantize_mlp_float8_activations(weights, activations, frozen, device): atol = {qfloat8_e4m3fn: 1e-3, qfloat8_e4m3fnuz: 1e-3, qfloat8_e5m2: 1e-2}[activations] _test_quantize_mlp(weights, activations, None, frozen, device, atol=atol) @pytest.mark.skip_device("cpu") @pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"]) @pytest.mark.parametrize("weights_only", [True, False], ids=["weights-only", "pickle"]) def 
test_quantized_mlp_device_memory(weights, dtype, weights_only, device): # We might not start from a clean state base_memory = get_device_memory(device) input_features = 1024 hidden_features = 2048 output_features = 1024 model = MLP(input_features, hidden_features, output_features).to(dtype).to(device) full_precision_memory = get_device_memory(device) assert full_precision_memory > base_memory quantize(model, weights=weights) freeze(model) quantized_memory = get_device_memory(device) assert quantized_memory > base_memory assert quantized_memory < full_precision_memory @pytest.mark.parametrize( "weights, optimizer", [[qint8, AbsmaxOptimizer()], [qint4, MaxOptimizer()]], ids=["w-qint8", "w-qint4"] ) @pytest.mark.parametrize("frozen", [True, False], ids=["frozen", "non-frozen"]) def test_quantize_mlp_weights_only_optimizers(weights, optimizer, frozen, device): atol = {qint4: 1e-4, qint8: 1e-6}[weights] _test_quantize_mlp(weights, None, optimizer, frozen, device, atol=atol) @pytest.mark.parametrize( "weights, optimizer", [[qint8, MaxOptimizer()], [qint4, AbsmaxOptimizer()]], ids=["w-qint8", "w-qint4"] ) def test_quantize_mlp_wrong_optimizer(weights, optimizer, device): with pytest.raises(ValueError): _test_quantize_mlp(weights, None, optimizer, False, device) ================================================ FILE: tests/quantize/test_quantize_patterns.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. import torch from optimum.quanto import ( qint8, quantize, ) from optimum.quanto.nn import QModuleMixin class MLP(torch.nn.Module): def __init__(self, input_size, output_size, hidden_size): super().__init__() self.input_layer = torch.nn.Linear(input_size, hidden_size) self.mid_layer = torch.nn.Linear(hidden_size, hidden_size) self.output_layer = torch.nn.Linear(hidden_size, output_size) def forward(self, inputs): x = torch.nn.functional.relu(self.input_layer(inputs)) x = torch.nn.functional.relu(self.mid_layer(x)) return self.output_layer(x) class ClassificationModel(torch.nn.Module): def __init__(self, input_size, output_size, hidden_size, classes): super().__init__() self.model = MLP(input_size, output_size, hidden_size) self.lm_head = torch.nn.Linear(output_size, classes) def forward(self, inputs): x = self.model(inputs) return torch.nn.functional.softmax(self.classifier(x), dim=-1) def has_children(module: torch.nn.Module): return next(module.children(), None) is not None def leaf_module_names(module: torch.nn.Module): return [name for name, m in module.named_modules() if not has_children(m)] def parent_module_names(module: torch.nn.Module): return [name for name, m in module.named_children() if has_children(m)] def test_quantize_mlp_include_explicit_layers(): model = ClassificationModel(32, 10, 128, 10) include_names = leaf_module_names(model) for include in include_names: model = ClassificationModel(32, 10, 128, 10) quantize(model, weights=qint8, include=include) for name, m in model.named_modules(): if name == include: assert isinstance(m, QModuleMixin) else: assert not isinstance(m, QModuleMixin) def test_quantize_mlp_exclude_explicit_layers(): model = ClassificationModel(32, 10, 128, 10) exclude_names = leaf_module_names(model) for exclude in exclude_names: model = ClassificationModel(32, 10, 128, 10) quantize(model, weights=qint8, exclude=exclude) for 
name, m in model.named_modules(): if name == exclude: assert not isinstance(m, QModuleMixin) elif not has_children(m): assert isinstance(m, QModuleMixin) def test_quantize_mlp_include_layer_patterns(): model = ClassificationModel(32, 10, 128, 10) parent_names = parent_module_names(model) for parent_name in parent_names: model = ClassificationModel(32, 10, 128, 10) quantize(model, weights=qint8, include=f"{parent_name}*") for name, m in model.named_modules(): if name.startswith(parent_name) and not has_children(m): assert isinstance(m, QModuleMixin) else: assert not isinstance(m, QModuleMixin) def test_quantize_mlp_exclude_layer_patterns(): model = ClassificationModel(32, 10, 128, 10) parent_names = parent_module_names(model) for parent_name in parent_names: model = ClassificationModel(32, 10, 128, 10) quantize(model, weights=qint8, exclude=f"{parent_name}*") for name, m in model.named_modules(): if name.startswith(parent_name): assert not isinstance(m, QModuleMixin) elif not has_children(m): assert isinstance(m, QModuleMixin) ================================================ FILE: tests/quantize/test_requantize.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import io from tempfile import NamedTemporaryFile import pytest import torch from helpers import get_device_memory, random_tensor from safetensors.torch import load_file, save_file from test_quantize_mlp import MLP from optimum.quanto import Calibration, freeze, qint4, qint8, quantization_map, quantize, requantize from optimum.quanto.nn import QModuleMixin def save_and_reload_state_dict(state_dict, serialization): if serialization == "safetensors": with NamedTemporaryFile() as tmp_file: save_file(state_dict, tmp_file.name) return load_file(tmp_file.name) else: b = io.BytesIO() torch.save(state_dict, b) b.seek(0) weights_only = serialization == "weights_only" return torch.load(b, weights_only=weights_only) @pytest.mark.parametrize( "input_features, hidden_features, output_features", [(32, 10, 128), (1024, 1024, 1024)], ids=["small", "large"], ) @pytest.mark.parametrize("weights", [qint4, qint8], ids=["w-qint4", "w-qint8"]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32], ids=["bf16", "fp16", "fp32"]) @pytest.mark.parametrize("serialization", ["weights_only", "pickle", "safetensors"]) @pytest.mark.parametrize("activations", [None, qint8], ids=["a-none", "a-qint8"]) def test_requantize_serialized_model( input_features, hidden_features, output_features, weights, activations, dtype, serialization, device ): model = MLP(input_features, hidden_features, output_features).to(dtype).to(device) quantize(model, weights=weights, activations=activations) inputs = random_tensor((1, 10, input_features), dtype=dtype).to(device) if activations is not None: with Calibration(): model(inputs) freeze(model) qmap = quantization_map(model) model_reloaded = MLP(input_features, hidden_features, output_features).to(device) state_dict = save_and_reload_state_dict(model.state_dict(), serialization) requantize(model_reloaded, state_dict, qmap) for name, module in model.named_modules(): if isinstance(module, QModuleMixin): module_reloaded = 
getattr(model_reloaded, name) assert torch.equal(module_reloaded.weight, module.weight) assert module_reloaded.weight_qtype == module.weight_qtype assert module_reloaded.activation_qtype == module.activation_qtype assert torch.equal(module_reloaded.input_scale, module.input_scale) assert torch.equal(module_reloaded.output_scale, module.output_scale) @pytest.mark.skip_device("cpu") @pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"]) @pytest.mark.parametrize("serialization", ["weights_only", "pickle", "safetensors"]) def test_requantized_model_device_memory(weights, dtype, serialization, device): input_features = 1024 hidden_features = 2048 output_features = 1024 model = MLP(input_features, hidden_features, output_features).to(dtype).to(device) full_precision_memory = get_device_memory(device) quantize(model, weights=weights) freeze(model) qmap = quantization_map(model) quantized_memory = get_device_memory(device) assert quantized_memory < full_precision_memory state_dict = save_and_reload_state_dict(model.state_dict(), serialization) # Free device memory del model with torch.device("meta"): reloaded_model = MLP(input_features, hidden_features, output_features).to(dtype) requantize(reloaded_model, state_dict, qmap, device) # Free device memory del state_dict requantized_memory = get_device_memory(device) assert requantized_memory <= quantized_memory ================================================ FILE: tests/tensor/activations/test_activations_compile.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import pytest import torch from helpers import random_tensor from optimum.quanto import ActivationQBytesTensor, absmax_scale, qint8, quantize_activation def compile_for_device(f, device): # Remove any side-effects form previous compilation torch.compiler.reset() # Inductor relies on Triton for inference which does not support MPS backend = "aot_eager" if device == torch.device("mps") else "inductor" return torch.compile(f, backend=backend) @pytest.mark.skip("Disabled as it is not working (yet ?)") @pytest.mark.parametrize("input_shape", [(2, 10), (10, 32, 32)]) @pytest.mark.parametrize("qtype", [qint8], ids=["qint8"]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"]) def test_compile_quantize_tensor(input_shape, qtype, dtype, device): if device == torch.device("mps") and dtype == torch.bfloat16: pytest.skip("BFloat16 is not supported on MPS") a = random_tensor(input_shape, dtype=dtype).to(device) def f(x, qtype): scale = absmax_scale(x) return quantize_activation(x, qtype=qtype, scale=scale) compiled_f = compile_for_device(f, device) qa = compiled_f(a, qtype) assert isinstance(qa, ActivationQBytesTensor) assert qa.qtype == qtype assert qa._scale.dtype == dtype assert qa.axis is None def test_compile_qtensor_to(device): input_shape = (10, 32, 32) a = random_tensor(input_shape).to(device) def f(x, dtype): return x.to(dtype) compiled_f = compile_for_device(f, device) scale = absmax_scale(a) qa = quantize_activation(a, qtype=qint8, scale=scale) cqa = compiled_f(qa, torch.float16) assert isinstance(cqa, 
ActivationQBytesTensor) assert cqa.qtype == qint8 assert cqa._scale.dtype == torch.float16 ================================================ FILE: tests/tensor/activations/test_activations_dispatch.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import pytest import torch from helpers import assert_similar, random_qactivation, random_tensor from optimum.quanto import ActivationQBytesTensor, quantize_activation @pytest.mark.parametrize("input_shape", [(10,), (1, 10), (10, 32, 32)]) @pytest.mark.parametrize("scalar", [1, 0.5, torch.tensor(0.12)], ids=["int", "float", "tensor"]) def test_qactivation_mul_scalar(input_shape, scalar, device): qa = random_qactivation(input_shape, dtype=torch.float32).to(device) if isinstance(scalar, torch.Tensor): scalar = scalar.to(device) qprod = qa * scalar assert isinstance(qprod, ActivationQBytesTensor) prod = qa.dequantize() * scalar assert_similar(prod, qprod) qprod = scalar * qa assert isinstance(qprod, ActivationQBytesTensor) prod = scalar * qa.dequantize() assert_similar(prod, qprod) @pytest.mark.parametrize("batch_size", [1, 10]) @pytest.mark.parametrize("tokens, embeddings", [(5, 5), (32, 32), (10, 32)]) def test_qactivation_relu(batch_size, tokens, embeddings, device): qinputs = random_qactivation((batch_size,) + (tokens, embeddings), dtype=torch.float32).to(device) qout = torch.nn.functional.relu(qinputs) assert isinstance(qout, 
ActivationQBytesTensor) assert torch.equal(qout._data, torch.maximum(qinputs._data, torch.zeros((1,)).to(device))) @pytest.mark.parametrize("batch_size", [1, 10]) @pytest.mark.parametrize("tokens, embeddings", [(5, 5), (32, 32), (10, 32)]) def test_qactivation_softmax(batch_size, tokens, embeddings, device): qinputs = random_qactivation((batch_size,) + (tokens, embeddings), dtype=torch.float32).to(device) qout = torch.nn.functional.softmax(qinputs, dim=-1) assert isinstance(qout, ActivationQBytesTensor) assert torch.min(qout.dequantize()) >= 0 assert torch.max(qout.dequantize()) <= 1 @pytest.mark.parametrize("input_shape", [(10,), (10, 32)]) def test_qactivation_view(input_shape, device): qinputs = random_qactivation(input_shape, dtype=torch.float32).to(device) qview = qinputs.view((1,) + input_shape) assert isinstance(qview, ActivationQBytesTensor) @pytest.mark.parametrize("input_shape", [(10,), (10, 32)]) def test_qactivation_cat(input_shape, device): qinputs = random_qactivation(input_shape, dtype=torch.float32).to(device) other = random_tensor(input_shape, dtype=torch.float32).to(device) # First, quantize other with the same scale qother = quantize_activation(other, qtype=qinputs.qtype, scale=qinputs._scale) qcat = torch.cat([qinputs, qother]) assert isinstance(qcat, ActivationQBytesTensor) assert_similar(torch.cat([qinputs.dequantize(), qother.dequantize()]), qcat) def test_qactivation_transpose_2d(device): input_shape = (4, 6) qinputs = random_qactivation(input_shape).to(device) qtransposed = qinputs.t() assert qtransposed.qtype == qinputs.qtype assert qtransposed.shape == input_shape[::-1] assert torch.equal(qtransposed.dequantize(), qinputs.dequantize().t()) def test_qactivation_transpose(device): input_shape = (10, 32, 64) qinputs = random_qactivation(input_shape).to(device) qtransposed = torch.transpose(qinputs, 1, 2) assert qtransposed.qtype == qinputs.qtype assert torch.equal(qtransposed.dequantize(), torch.transpose(qinputs.dequantize(), 1, 2)) 
================================================ FILE: tests/tensor/activations/test_activations_quantize.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import pytest import torch from helpers import assert_similar, device_eq, random_tensor from optimum.quanto import ( ActivationQBytesTensor, absmax_scale, qfloat8, qfloat8_e4m3fn, qfloat8_e4m3fnuz, qfloat8_e5m2, qint8, ) @pytest.mark.parametrize("input_shape", [(32, 32), (32, 10, 32)]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"]) @pytest.mark.parametrize("qtype", [qint8], ids=["qint8"]) def test_symmetric_quantize_int(input_shape, dtype, qtype, device): a = random_tensor(input_shape, dtype=dtype).to(device) scale = absmax_scale(a, qtype=qtype, axis=None) qa = ActivationQBytesTensor.quantize(a, qtype, scale) assert isinstance(qa, ActivationQBytesTensor) assert qa.dtype == dtype assert qa.qtype == qtype assert device_eq(qa.device, device) assert_similar(a, qa) @pytest.mark.skip_device("mps") @pytest.mark.parametrize("input_shape", [(32, 32), (32, 10, 32)]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"]) @pytest.mark.parametrize( "qtype", [qfloat8, qfloat8_e4m3fn, qfloat8_e4m3fnuz, qfloat8_e5m2], ids=["qfloat8", "qfloat8_e4m3fn", "qfloat8_e4m3fnuz", "qfloat8_e5m2"], ) def test_symmetric_quantize_float8(input_shape, dtype, qtype, device): a = 
random_tensor(input_shape, dtype=dtype).to(device) scale = absmax_scale(a, qtype=qtype, axis=None) qa = ActivationQBytesTensor.quantize(a, qtype, scale) assert isinstance(qa, ActivationQBytesTensor) assert qa.dtype == dtype assert qa.qtype == qtype assert device_eq(qa.device, device) assert_similar(a, qa, atol=5e-3) ================================================ FILE: tests/tensor/ops/test_linear_dispatch.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import pytest import torch from helpers import assert_similar, random_qactivation, random_qweight, random_tensor from optimum.quanto import qint2, qint4, qint8 @pytest.mark.parametrize("batch_size", [1, 10]) @pytest.mark.parametrize("tokens, embeddings", [(5, 5), (32, 32), (10, 32)]) @pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["fp32", "fp16"]) @pytest.mark.parametrize("activation_qtype", [None, qint8], ids=["a-none", "a-qint8"]) @pytest.mark.parametrize("weight_qtype", [qint2, qint4, qint8], ids=["w-qint2", "w-qint4", "w-qint8"]) def test_qactivation_qweight_linear( batch_size, tokens, embeddings, use_bias, dtype, activation_qtype, weight_qtype, device ): input_shape = (batch_size, tokens, embeddings) if activation_qtype is None: inputs = random_tensor(input_shape, dtype=dtype).to(device) else: inputs = random_qactivation(input_shape, qtype=activation_qtype, dtype=dtype).to(device) qweight = random_qweight((embeddings, embeddings), qtype=weight_qtype, dtype=dtype, axis=0).to(device) bias = random_tensor((embeddings,), dtype=dtype).to(device) if use_bias else None qout = torch.nn.functional.linear(inputs, qweight, bias) if activation_qtype is not None: inputs = inputs.dequantize() out = torch.nn.functional.linear(inputs, qweight.dequantize(), bias) assert_similar(out, qout) @pytest.mark.parametrize("batch_size", [1, 10]) @pytest.mark.parametrize("tokens, embeddings", [(256, 256)]) @pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"]) def test_linear_fp16_int4(batch_size, tokens, embeddings, use_bias, device): dtype = torch.float16 weight_qtype = qint4 inputs = torch.rand((batch_size,) + (tokens, embeddings), dtype=dtype, device=device) qweight = random_qweight((embeddings, embeddings), weight_qtype, dtype=dtype, axis=0, group_size=128).to(device) bias = random_tensor((embeddings,), dtype=dtype).to(device) if use_bias else None qout = 
torch.nn.functional.linear(inputs, qweight, bias) out = torch.nn.functional.linear(inputs, qweight.dequantize(), bias) assert_similar(out, qout) @pytest.mark.skip_device("mps") # Only available with pytorch 2.4 @pytest.mark.parametrize("batch_size", [1, 10]) @pytest.mark.parametrize("tokens, embeddings", [(256, 256)]) @pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"]) def test_linear_bf16_int4(batch_size, tokens, embeddings, use_bias, device): dtype = torch.bfloat16 weight_qtype = qint4 input_shape = (batch_size, tokens, embeddings) inputs = torch.rand(input_shape, dtype=dtype, device=device) weight_shape = (embeddings, embeddings) qweight = random_qweight(weight_shape, weight_qtype, dtype=dtype, axis=0, group_size=128, device=device) bias = random_tensor((embeddings,), dtype=dtype).to(device) if use_bias else None qout = torch.nn.functional.linear(inputs, qweight, bias) out = torch.nn.functional.linear(inputs, qweight.dequantize(), bias) assert_similar(out, qout) ================================================ FILE: tests/tensor/ops/test_mm_dispatch.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import pytest import torch from helpers import assert_similar, random_qactivation, random_qweight from optimum.quanto import qint8 @pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["fp32", "fp16"]) @pytest.mark.parametrize("in_features", [5, 16, 24]) @pytest.mark.parametrize("hidden", [5, 16, 24]) @pytest.mark.parametrize("out_features", [5, 16, 24]) def test_qactivation_qweight_matmul(dtype, in_features, hidden, out_features, device): qa = random_qactivation((in_features, hidden), qint8, dtype=dtype).to(device) qb = random_qweight((hidden, out_features), qint8, dtype=dtype, axis=-1).to(device) qmatmul = torch.matmul(qa, qb) # The outputs should be almost identical if we use the dequantized inputs matmul = torch.matmul(qa.dequantize(), qb.dequantize()) assert_similar(matmul, qmatmul) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["fp32", "fp16"]) @pytest.mark.parametrize("batch_size", [1, 10]) @pytest.mark.parametrize("a_shape, b_shape", [[(16, 32), (32, 24)], [(5, 10), (10, 6)]]) def test_qactivation_qactivation_bmm(dtype, batch_size, a_shape, b_shape, device): qa = random_qactivation((batch_size,) + a_shape, qint8, dtype=dtype).to(device) qb = random_qactivation((batch_size,) + b_shape, qint8, dtype=dtype).to(device) qbmm = torch.bmm(qa, qb) # The outputs should be almost identical if we use the dequantized inputs bmm = torch.bmm(qa.dequantize(), qb.dequantize()) assert_similar(bmm, qbmm) ================================================ FILE: tests/tensor/optimizers/test_hqq_optimizer.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import pytest import torch from helpers import random_tensor from optimum.quanto import ( HqqOptimizer, MaxOptimizer, WeightQBitsTensor, qint2, qint4, ) def compare_quantized_tensor(a, qtype, axis, group_size, scale, shift): qa = WeightQBitsTensor.quantize(a, qtype, axis, group_size, scale, shift) # Evaluate mean absolute error mean_error = torch.mean(torch.abs(a - qa)) # Also evaluate cosine similarity sim = torch.nn.functional.cosine_similarity(a.flatten(), qa.flatten(), dim=0) return mean_error, sim @pytest.mark.parametrize("input_shape", [(1024, 1024)]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=["bf16", "fp16"]) @pytest.mark.parametrize("qtype", [qint2, qint4], ids=["qint2", "qint4"]) @pytest.mark.parametrize("axis", [0, -1], ids=["first-axis", "last-axis"]) @pytest.mark.parametrize("group_size", [32, 64, 128]) def test_hqq_optimizer(input_shape, dtype, qtype, axis, group_size, device): a = random_tensor(input_shape, dtype=dtype).to(device) max_scale, max_shift = MaxOptimizer()(a, qtype=qtype, axis=axis, group_size=group_size) max_mean_error, max_sim = compare_quantized_tensor(a, qtype, axis, group_size, max_scale, max_shift) hqq_scale, hqq_shift = HqqOptimizer()(a, qtype=qtype, axis=axis, group_size=group_size) hqq_mean_error, hqq_sim = compare_quantized_tensor(a, qtype, axis, group_size, hqq_scale, hqq_shift) # HQQ optimizes the mean error, so it should be lower assert hqq_mean_error <= max_mean_error # FIXME: HQQ cosine similarity should be also closer to 1 ================================================ FILE: 
tests/tensor/test_absmax.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import pytest import torch from helpers import random_tensor from optimum.quanto import absmax_scale, qfloat8, qint8 @pytest.mark.parametrize("input_shape", [(10,), (1, 10), (2, 10), (10, 32, 32)]) @pytest.mark.parametrize("qtype", [qint8, qfloat8], ids=["qint8", "qfloat8"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"]) @pytest.mark.parametrize("axis", [None, 0, -1], ids=["per-tensor", "first-axis", "last-axis"]) def test_absmax_scale(input_shape, axis, dtype, qtype, device): if device.type == "mps" and qtype.is_floating_point: pytest.skip("Float8 are not supported on MPS device") a = random_tensor(input_shape, dtype=dtype).to(device) scale = absmax_scale(a, qtype, axis) assert scale.dtype == dtype if axis is None: assert scale.ndim == 0 else: assert scale.ndim == a.ndim sscale = torch.squeeze(scale) if a.ndim == 1 or a.shape[axis] == 1: # Quantization is actually per-tensor as the axis dim is 1 assert sscale.ndim == 0 else: assert sscale.ndim == 1 ================================================ FILE: tests/tensor/test_packed_tensor.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import io import pytest import torch from helpers import device_eq from optimum.quanto.tensor.packed import PackedTensor @pytest.mark.parametrize("shape", [(10,), (12,), (10, 10), (12, 10), (32, 32)]) @pytest.mark.parametrize("bits", [2, 4], ids=["int2", "int4"]) def test_pack_tensor(shape, bits, device): """This test verifies that an integer tensor in the correct range is preserved.""" qmax = 2**bits t = torch.randint(0, qmax, shape, dtype=torch.uint8).to(device) packed = PackedTensor.pack(t, bits=bits) assert isinstance(packed, PackedTensor) assert packed.dtype == torch.uint8 assert device_eq(packed.device, device) assert torch.equal(t, packed.unpack()) @pytest.mark.parametrize("bits", [2, 4], ids=["int2", "int4"]) def test_packed_tensor_serialization(bits, device): qmax = 2**bits shape = (10, 32) t = torch.randint(0, qmax, shape, dtype=torch.uint8).to(device) packed = PackedTensor.pack(t, bits=bits) b = io.BytesIO() torch.save(packed, b) b.seek(0) packed_reloaded = torch.load(b, weights_only=False) assert isinstance(packed_reloaded, PackedTensor) assert packed_reloaded.shape == packed.shape assert packed_reloaded.dtype == packed.dtype assert packed_reloaded.bits == packed.bits assert torch.equal(packed_reloaded._data, packed._data) assert torch.equal(t, packed_reloaded.unpack()) ================================================ FILE: tests/tensor/weights/optimized/test_awq_packed_tensor.py 
================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import numpy as np import pytest import torch from helpers import device_eq from optimum.quanto.tensor.weights.awq import AWQPackedTensor, AWQPacking @pytest.mark.skip_device("cpu") @pytest.mark.skip_device("mps") @pytest.mark.parametrize("in_features", [128, 256, 512, 1024]) @pytest.mark.parametrize("out_features", [128, 256, 512, 1024]) @pytest.mark.parametrize("random", [True, False]) @pytest.mark.parametrize("packing, reorder", [(AWQPacking.V1, True), (AWQPacking.V1, False), (AWQPacking.V2, False)]) def test_pack_awq_tensor(in_features, out_features, random, packing, reorder, device): bits = 4 qmax = 2**bits shape = (out_features, in_features) if random: t = torch.randint(0, qmax, shape, dtype=torch.uint8).to(device) else: numel = np.prod(shape) t = torch.tensor(range(numel), dtype=torch.int32) t = (t % qmax).reshape(shape).to(torch.uint8).to(device) packed = AWQPackedTensor.pack(t, packing=packing, reorder=reorder) assert isinstance(packed, AWQPackedTensor) assert packed._packing == packing assert packed._reorder == reorder assert device_eq(packed.device, device) assert torch.equal(t, packed.unpack()) @pytest.mark.skip_device("cpu") @pytest.mark.skip_device("mps") @pytest.mark.parametrize("packing, reorder", [(AWQPacking.V1, True), (AWQPacking.V2, False)]) def test_move_awq_tensor(packing, reorder, device): shape = (256, 
256) bits = 4 qmax = 2**bits numel = np.prod(shape) t = torch.tensor(range(numel), dtype=torch.int32) t = (t % qmax).reshape(shape).to(torch.uint8).to(device) packed = AWQPackedTensor.pack(t, packing=packing, reorder=reorder) assert packed._packing == packing assert packed._reorder == reorder moved = packed.to(device) assert isinstance(moved, AWQPackedTensor) assert moved._packing == packing assert moved._reorder == reorder # TensorRT tensors are unpacked when moved out of CUDA or XPU device moved = packed.to("cpu") assert type(moved) is torch.Tensor ================================================ FILE: tests/tensor/weights/optimized/test_awq_weight_qbits_tensor.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import pytest
import torch
from helpers import device_eq, random_qweight
from tensor.weights.weight_helpers import check_weight_qtensor_linear

from optimum.quanto import qint4
from optimum.quanto.library.extensions import is_extension_available
from optimum.quanto.tensor.weights import WeightQBitsTensor
from optimum.quanto.tensor.weights.awq import AWQWeightQBitsTensor


@pytest.mark.skip_device("cpu")
@pytest.mark.skip_device("mps")
@pytest.mark.parametrize("in_features", [128, 256, 512, 1024])
@pytest.mark.parametrize("out_features", [128, 256, 512, 1024])
def test_awq_weight_qbits_tensor_from_qbits_tensor(in_features, out_features, device):
    """An AWQ tensor built from WeightQBitsTensor members must round-trip losslessly."""
    qtype = qint4
    group_size = 128
    dtype = torch.float16
    shape = (out_features, in_features)
    qbt = random_qweight(shape, qtype, dtype, group_size=group_size, device=device)
    # Create a AWQWeightQBitsTensor from the WeightQBitsTensor members
    awqbt = AWQWeightQBitsTensor(
        qtype=qbt.qtype,
        axis=qbt.axis,
        group_size=qbt._group_size,
        size=qbt.size(),
        stride=qbt.stride(),
        data=qbt._data.unpack(),
        scale=qbt._scale,
        shift=qbt._shift,
    )
    assert awqbt.dtype == dtype
    assert awqbt.qtype == qtype
    assert awqbt.shape == shape
    assert device_eq(awqbt.device, device)
    # Verify the dequantized tensors are identical
    assert torch.equal(awqbt.dequantize(), qbt.dequantize())
    # Now verify that we can reconstruct the WeightQBitsTensor
    new_qbt = awqbt.weight_qbits_tensor()
    assert type(new_qbt) is WeightQBitsTensor
    assert new_qbt.dtype == dtype
    assert new_qbt.qtype == qtype
    assert new_qbt.shape == shape
    assert torch.equal(new_qbt._data, qbt._data)
    assert torch.equal(new_qbt._scale, qbt._scale)
    assert torch.equal(new_qbt._shift, qbt._shift)


@pytest.mark.skip_device("cpu")
@pytest.mark.skip_device("mps")
def test_awq_weight_qbits_tensor_move(device):
    """Moving an AWQ tensor off CUDA/XPU must fall back to a plain WeightQBitsTensor."""
    qtype = qint4
    group_size = 128
    dtype = torch.float16
    shape = (1024, 1024)
    # Create an AWQWeightQBitsTensor from a QBitsTensor on CUDA or XPU
    qbt = random_qweight(shape, qtype, dtype, group_size=group_size, device=device)
    awqbt = AWQWeightQBitsTensor(
        qtype=qbt.qtype,
        axis=qbt.axis,
        group_size=qbt._group_size,
        size=qbt.size(),
        stride=qbt.stride(),
        data=qbt._data.unpack(),
        scale=qbt._scale,
        shift=qbt._shift,
    )
    # Move to device, dequantize and compare
    moved_qbt = awqbt.to(device)
    assert isinstance(moved_qbt, WeightQBitsTensor)
    if device.type not in ["cuda", "xpu"]:
        assert type(moved_qbt) is not AWQWeightQBitsTensor
    assert awqbt.dtype == moved_qbt.dtype
    assert awqbt.qtype == moved_qbt.qtype
    assert awqbt.shape == moved_qbt.shape
    assert torch.equal(awqbt.dequantize().to(device), moved_qbt.dequantize())


def _test_awq_weight_qbits_tensor_linear(
    dtype, weight_qtype, group_size, batch_size, tokens, in_features, out_features, use_bias
):
    """Shared driver: build an AWQ weight on the default accelerator and check linear."""
    # Create an AWQWeightQBitsTensor from a QBitsTensor on CUDA
    qbt = random_qweight(
        (out_features, in_features), weight_qtype, dtype, group_size=group_size, device=torch.device(0)
    )
    awq_qweight = AWQWeightQBitsTensor(
        qtype=qbt.qtype,
        axis=qbt.axis,
        group_size=qbt._group_size,
        size=qbt.size(),
        stride=qbt.stride(),
        data=qbt._data.unpack(),
        scale=qbt._scale,
        shift=qbt._shift,
    )
    check_weight_qtensor_linear(awq_qweight, batch_size, tokens, use_bias)


@pytest.mark.skipif(
    (not is_extension_available("quanto_cuda") or torch.cuda.get_device_capability()[0] < 8)
    and not torch.xpu.is_available(),
    reason="The test requires CUDA device >= sm80 or Intel XPU",
)
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("tokens", [16, 32, 48, 64])
@pytest.mark.parametrize("in_features", [256, 512, 1024, 4096, 16384])
@pytest.mark.parametrize("out_features", [256, 512, 1024, 2048, 4096])
@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
def test_awq_weight_qbits_tensor_linear(batch_size, tokens, in_features, out_features, use_bias):
    """linear through the AWQ kernel must match the reference implementation."""
    dtype = torch.float16
    weight_qtype = qint4
    group_size = 128
    _test_awq_weight_qbits_tensor_linear(
        dtype, weight_qtype, group_size, batch_size, tokens, in_features, out_features, use_bias
    )
================================================ FILE: tests/tensor/weights/optimized/test_marlin_fp8_packed_tensor.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import numpy as np import pytest import torch from helpers import device_eq from optimum.quanto.library.extensions import is_extension_available from optimum.quanto.tensor.weights.marlin.fp8 import MarlinF8PackedTensor def get_fp8_tensor(shape, device, random=False): # We will initialize float8 from an uint8 tensor qmax = 2**8 if random: t = torch.randint(0, qmax, shape, dtype=torch.uint8).to(device) else: numel = np.prod(shape) t = torch.tensor(range(numel), dtype=torch.int32) t = (t % qmax).reshape(shape).to(torch.uint8).to(device) # Remove values that would be interpreted as nans in float8. 
t[t == 127] = 0 t[t == 255] = 0 return t.view(torch.float8_e4m3fn).to(device) @pytest.mark.skipif(not is_extension_available("quanto_cuda"), reason="CUDA extension is not available") @pytest.mark.parametrize("in_features", [128, 256, 512, 1024]) @pytest.mark.parametrize("out_features", [128, 256, 512, 1024]) @pytest.mark.parametrize("random", [True, False]) def test_pack_marlin_fp8_tensor(in_features, out_features, random): shape = (out_features, in_features) device = torch.device("cuda") t = get_fp8_tensor(shape, device, random) packed = MarlinF8PackedTensor.pack(t) assert isinstance(packed, MarlinF8PackedTensor) assert device_eq(packed.device, device) assert torch.equal(t, packed.unpack()) @pytest.mark.skipif(not is_extension_available("quanto_cuda"), reason="CUDA extension is not available") def test_move_marlin_fp8_tensor(): shape = (256, 256) device = torch.device("cuda") t = get_fp8_tensor(shape, device) packed = MarlinF8PackedTensor.pack(t) moved = packed.to("cuda") assert isinstance(moved, MarlinF8PackedTensor) # Marlin FP8 tensors are unpacked when moved out of CUDA device moved = packed.to("cpu") assert type(moved) is torch.Tensor assert torch.equal(t, moved.to("cuda")) ================================================ FILE: tests/tensor/weights/optimized/test_marlin_int4_packed_tensor.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import numpy as np import pytest import torch from helpers import device_eq from optimum.quanto.tensor.weights.marlin.int4 import MarlinInt4PackedTensor def get_uint4_tensor(shape, device, random=False): qmax = 2**4 if random: t = torch.randint(0, qmax, shape, dtype=torch.uint8).to(device) else: numel = np.prod(shape) t = torch.tensor(range(numel), dtype=torch.int32) t = (t % qmax).reshape(shape).to(torch.uint8).to(device) return t @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("in_features", [128, 256, 512, 1024]) @pytest.mark.parametrize("out_features", [128, 256, 512, 1024]) @pytest.mark.parametrize("random", [True, False]) def test_pack_marlin_int4_tensor(in_features, out_features, random): shape = (out_features, in_features) device = torch.device("cuda") t = get_uint4_tensor(shape, device, random) packed = MarlinInt4PackedTensor.pack(t) assert isinstance(packed, MarlinInt4PackedTensor) assert device_eq(packed.device, device) assert torch.equal(t, packed.unpack()) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_move_marlin_int4_packed_tensor(device): shape = (256, 256) device = torch.device("cuda") t = get_uint4_tensor(shape, device) packed = MarlinInt4PackedTensor.pack(t) moved = packed.to("cuda") assert isinstance(moved, MarlinInt4PackedTensor) # Marlin int4 tensors are unpacked when moved out of CUDA device moved = packed.to("cpu") assert type(moved) is torch.Tensor assert torch.equal(t, moved.to("cuda")) ================================================ FILE: tests/tensor/weights/optimized/test_marlin_int4_weight_qbits_tensor.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import pytest import torch from helpers import device_eq, random_qweight from tensor.weights.weight_helpers import check_weight_qtensor_linear from optimum.quanto import qint4 from optimum.quanto.library.extensions import is_extension_available from optimum.quanto.tensor.weights import WeightQBitsTensor from optimum.quanto.tensor.weights.marlin.int4 import MarlinInt4WeightQBitsTensor @pytest.mark.skipif( not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 8, reason="CUDA >= sm80 not available" ) @pytest.mark.parametrize("in_features", [128, 256, 512, 1024]) @pytest.mark.parametrize("out_features", [128, 256, 512, 1024]) def test_marlin_int4_weight_qbits_tensor_from_qbits_tensor(in_features, out_features): qtype = qint4 group_size = 128 dtype = torch.float16 shape = (out_features, in_features) device = torch.device("cuda") qbt = random_qweight(shape, qtype, dtype, group_size=group_size, device=device) # Create a MarlinInt4WeightQBitsTensor from the WeightQBitsTensor members marlinqbt = MarlinInt4WeightQBitsTensor( qtype=qbt.qtype, axis=qbt.axis, group_size=qbt._group_size, size=qbt.size(), stride=qbt.stride(), data=qbt._data.unpack(), scale=qbt._scale, shift=qbt._shift, ) assert marlinqbt.dtype == dtype assert marlinqbt.qtype == qtype assert marlinqbt.shape == shape assert device_eq(marlinqbt.device, device) # Verify the dequantized tensors are identical assert torch.equal(marlinqbt.dequantize(), qbt.dequantize()) # Now verify that we can reconstruct the WeightQBitsTensor new_qbt = marlinqbt.weight_qbits_tensor() assert type(new_qbt) is 
WeightQBitsTensor assert new_qbt.dtype == dtype assert new_qbt.qtype == qtype assert new_qbt.shape == shape assert torch.equal(new_qbt._data, qbt._data) assert torch.equal(new_qbt._scale, qbt._scale) assert torch.equal(new_qbt._shift, qbt._shift) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_marlin_int4_weight_qbits_tensor_move(device): qtype = qint4 group_size = 128 dtype = torch.float16 shape = (1024, 1024) device = torch.device("cuda") # Create an MarlinInt4WeightQBitsTensor from a QBitsTensor on CUDA qbt = random_qweight(shape, qtype, dtype, group_size=group_size, device=torch.device("cuda")) marlinqbt = MarlinInt4WeightQBitsTensor( qtype=qbt.qtype, axis=qbt.axis, group_size=qbt._group_size, size=qbt.size(), stride=qbt.stride(), data=qbt._data.unpack(), scale=qbt._scale, shift=qbt._shift, ) # Move to device, dequantize and compare moved_qbt = marlinqbt.to(device) assert isinstance(moved_qbt, WeightQBitsTensor) if device.type != "cuda": assert type(moved_qbt) is not MarlinInt4WeightQBitsTensor assert marlinqbt.dtype == moved_qbt.dtype assert marlinqbt.qtype == moved_qbt.qtype assert marlinqbt.shape == moved_qbt.shape assert torch.equal(marlinqbt.dequantize().to(device), moved_qbt.dequantize()) def _test_marlin_int4_weight_qbits_tensor_linear( dtype, weight_qtype, group_size, batch_size, tokens, in_features, out_features, use_bias ): # Create an MarlinInt4WeightQBitsTensor from a QBitsTensor on CUDA qbt = random_qweight( (out_features, in_features), weight_qtype, dtype, group_size=group_size, device=torch.device("cuda") ) marlin_qweight = MarlinInt4WeightQBitsTensor( qtype=qbt.qtype, axis=qbt.axis, group_size=qbt._group_size, size=qbt.size(), stride=qbt.stride(), data=qbt._data.unpack(), scale=qbt._scale, shift=qbt._shift, ) check_weight_qtensor_linear(marlin_qweight, batch_size, tokens, use_bias) @pytest.mark.skipif( not is_extension_available("quanto_cuda") or torch.cuda.get_device_capability()[0] < 8, reason="CUDA >= 
sm80 not available", ) @pytest.mark.parametrize("batch_size", [1, 2]) @pytest.mark.parametrize("tokens", [16, 32]) @pytest.mark.parametrize("in_features", [1024]) @pytest.mark.parametrize("out_features", [1024, 2048, 4096]) @pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"]) def test_marlin_int4_weight_qbits_tensor_linear(batch_size, tokens, in_features, out_features, use_bias): dtype = torch.float16 weight_qtype = qint4 group_size = 128 _test_marlin_int4_weight_qbits_tensor_linear( dtype, weight_qtype, group_size, batch_size, tokens, in_features, out_features, use_bias ) @pytest.mark.xfail(reason="Bug in Marlin kernel", strict=False) @pytest.mark.skipif( not is_extension_available("quanto_cuda") or torch.cuda.get_device_capability()[0] < 8, reason="CUDA >= sm80 not available", ) @pytest.mark.parametrize("batch_size", [1, 2]) @pytest.mark.parametrize("tokens", [48, 64]) # @pytest.mark.parametrize("in_features", [1024, 2048, 4096, 16384]) @pytest.mark.parametrize("in_features", [4096, 16384]) @pytest.mark.parametrize("out_features", [2048, 4096]) def test_marlin_int4_weight_qbits_tensor_linear_failing(batch_size, tokens, in_features, out_features): dtype = torch.float16 weight_qtype = qint4 group_size = 128 _test_marlin_int4_weight_qbits_tensor_linear( dtype, weight_qtype, group_size, batch_size, tokens, in_features, out_features, use_bias=False ) ================================================ FILE: tests/tensor/weights/optimized/test_marlin_qbytes_tensor.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import pytest import torch from optimum.quanto import qfloat8_e4m3fn from optimum.quanto.library.extensions import is_extension_available from optimum.quanto.tensor.weights.marlin import MarlinF8QBytesTensor @pytest.mark.skipif( not is_extension_available("quanto_cuda") or torch.cuda.get_device_capability()[0] < 8, reason="CUDA >= sm80 not available", ) @pytest.mark.parametrize("in_features", [16, 32, 48, 64]) @pytest.mark.parametrize("out_features", [64, 128, 192, 256]) def test_pack_unpack(in_features: int, out_features: int): data = torch.randint(0, 256, size=(out_features, in_features), dtype=torch.uint8, device="cuda") # Remove nans. data[data == 127] = 0 data[data == 255] = 0 data = data.view(torch.float8_e4m3fn) qtype = qfloat8_e4m3fn axis = 0 size = data.shape stride = data.stride() scale = torch.rand((out_features, 1), dtype=torch.float16, device="cuda") marlin_tensor = MarlinF8QBytesTensor(qtype, axis, size, stride, data, scale) data_dequantized = marlin_tensor.dequantize() assert torch.all((data.to(torch.float16) * scale - data_dequantized).abs() < 1e-4) ================================================ FILE: tests/tensor/weights/optimized/test_tinygemm_packed_tensor.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest
import torch
from helpers import device_eq
from packaging import version

from optimum.quanto.tensor.weights.tinygemm import TinyGemmPackedTensor


@pytest.mark.skip_device("mps")  # Only available with pytorch 2.4
@pytest.mark.parametrize("in_features", [128, 256, 512, 1024])
@pytest.mark.parametrize("out_features", [128, 256, 512, 1024])
@pytest.mark.parametrize("random", [True, False])
def test_pack_tinygemm_tensor(in_features, out_features, random, device):
    """Pack a 4-bit (0..15) uint8 tensor with TinyGemm and verify a lossless unpack."""
    # TinyGemm has hard runtime requirements on CUDA: no ROCm, CUDA >= 12.1, sm80+.
    if device.type == "cuda":
        if torch.version.hip:
            pytest.skip(reason="TinyGemm is not supported on ROCm devices")
        if version.parse(torch.version.cuda).release < (12, 1):
            pytest.skip(reason="CUDA runtime must be at least 12.1")
        if torch.cuda.get_device_capability()[0] < 8:
            pytest.skip(reason="CUDA device >= sm80 not available")
    bits = 4
    qmax = 2**bits
    shape = (out_features, in_features)
    if random:
        t = torch.randint(0, qmax, shape, dtype=torch.uint8).to(device)
    else:
        # Deterministic modulo ramp: easier to localize pack/unpack mismatches.
        numel = np.prod(shape)
        t = torch.tensor(range(numel), dtype=torch.int32)
        t = (t % qmax).reshape(shape).to(torch.uint8).to(device)
    packed = TinyGemmPackedTensor.pack(t)
    assert isinstance(packed, TinyGemmPackedTensor)
    assert device_eq(packed.device, device)
    assert torch.equal(t, packed.unpack())


@pytest.mark.skip_device("mps")  # Only available with pytorch 2.4
def test_move_tinygemm_packed_tensor(device):
    """Pack on CPU, move to the fixture device, and verify the unpacked values survive."""
    # Same TinyGemm CUDA runtime requirements as above.
    if device.type == "cuda":
        if torch.version.hip:
            pytest.skip(reason="TinyGemm is not supported on ROCm devices")
        if version.parse(torch.version.cuda).release < (12, 1):
            pytest.skip(reason="CUDA runtime must be at least 12.1")
        if torch.cuda.get_device_capability()[0] < 8:
            pytest.skip(reason="CUDA device >= sm80 not available")
    shape = (256, 256)
    bits = 4
    qmax = 2**bits
    numel = np.prod(shape)
    t = torch.tensor(range(numel), dtype=torch.int32)
    t = (t % qmax).reshape(shape).to(torch.uint8)
    packed = TinyGemmPackedTensor.pack(t)
    moved = packed.to(device)
    assert torch.equal(t.to(device), moved.unpack())


================================================
FILE: tests/tensor/weights/optimized/test_tinygemm_weight_qbits_tensor.py
================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from helpers import assert_similar, device_eq, random_qweight, random_tensor
from packaging import version

from optimum.quanto import qint4
from optimum.quanto.tensor.weights import WeightQBitsTensor
from optimum.quanto.tensor.weights.tinygemm import TinyGemmWeightQBitsTensor


@pytest.mark.skip_device("mps")  # Only available with pytorch 2.4
@pytest.mark.parametrize("in_features", [128, 256, 512, 1024])
@pytest.mark.parametrize("out_features", [128, 256, 512, 1024])
def test_tinygemm_weight_qbits_tensor_from_qbits_tensor(in_features, out_features, device):
    """Round-trip: WeightQBitsTensor -> TinyGemmWeightQBitsTensor -> WeightQBitsTensor."""
    # TinyGemm has hard runtime requirements on CUDA: no ROCm, CUDA >= 12.1, sm80+.
    if device.type == "cuda":
        if torch.version.hip:
            pytest.skip(reason="TinyGemm not available for ROCm devices")
        if version.parse(torch.version.cuda).release < (12, 1):
            pytest.skip(reason="CUDA runtime must be at least 12.1")
        if torch.cuda.get_device_capability()[0] < 8:
            pytest.skip(reason="CUDA device >= sm80 not available")
    qtype = qint4
    group_size = 128
    dtype = torch.bfloat16
    shape = (out_features, in_features)
    qbt = random_qweight(shape, qtype, dtype, group_size=group_size, device=device)
    # Create a TinyGemmWeightQBitsTensor from the WeightQBitsTensor members
    tgqbt = TinyGemmWeightQBitsTensor(
        qtype=qbt.qtype,
        axis=qbt.axis,
        group_size=qbt._group_size,
        size=qbt.size(),
        stride=qbt.stride(),
        data=qbt._data.unpack(),
        scale_shift=(qbt._scale, qbt._shift),
    )
    assert tgqbt.dtype == dtype
    assert tgqbt.qtype == qtype
    assert tgqbt.shape == shape
    assert device_eq(tgqbt.device, device)
    # Verify that we can reconstruct the WeightQBitsTensor
    new_qbt = tgqbt.weight_qbits_tensor()
    assert type(new_qbt) is WeightQBitsTensor
    assert new_qbt.dtype == dtype
    assert new_qbt.qtype == qtype
    assert new_qbt.shape == shape
    assert torch.equal(new_qbt._data, qbt._data)
    assert torch.equal(new_qbt._scale, qbt._scale)
    # FIXME: we cannot guarantee an exact match because of the addition/removal of the mid-point
    # which is lossy in bfloat16 (a + b - b != a)
    assert_similar(new_qbt._shift, qbt._shift)
    # Verify the dequantized tensors are similar
    assert_similar(tgqbt.dequantize(), qbt.dequantize())


@pytest.mark.skip_device("mps")  # Only available with pytorch 2.4
def test_tinygemm_weight_qbits_tensor_move(device):
    """Build a TinyGemm tensor on CPU, move it to the fixture device, compare dequantized values."""
    qtype = qint4
    group_size = 128
    dtype = torch.bfloat16
    shape = (1024, 1024)
    # Create a TinyGemmWeightQBitsTensor from a QBitsTensor on CPU
    qbt = random_qweight(shape, qtype, dtype, group_size=group_size, device=torch.device("cpu"))
    tgqbt_cpu = TinyGemmWeightQBitsTensor(
        qtype=qbt.qtype,
        axis=qbt.axis,
        group_size=qbt._group_size,
        size=qbt.size(),
        stride=qbt.stride(),
        data=qbt._data.unpack(),
        scale_shift=(qbt._scale, qbt._shift),
    )
    # Move to device, dequantize and compare
    tgqbt = tgqbt_cpu.to(device)
    assert isinstance(tgqbt, WeightQBitsTensor)
    assert tgqbt.dtype == tgqbt_cpu.dtype
    assert tgqbt.qtype == tgqbt_cpu.qtype
    assert tgqbt.shape == tgqbt_cpu.shape
    assert torch.equal(tgqbt.dequantize().cpu(), tgqbt_cpu.dequantize())


@pytest.mark.skip_device("mps")  # Only available with pytorch 2.4
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("tokens", [256, 512])
@pytest.mark.parametrize("embeddings", [256, 512, 1024, 4096])
@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
def test_tinygemm_weight_qbits_tensor_linear(batch_size, tokens, embeddings, use_bias, device):
    """Compare F.linear on the TinyGemm weight against F.linear on the dequantized weight."""
    # Same TinyGemm CUDA runtime requirements as above.
    if device.type == "cuda":
        if torch.version.hip:
            pytest.skip(reason="TinyGemm not available for ROCm devices")
        if version.parse(torch.version.cuda).release < (12, 1):
            pytest.skip(reason="CUDA runtime must be at least 12.1")
        if torch.cuda.get_device_capability()[0] < 8:
            pytest.skip(reason="CUDA device >= sm80 not available")
    qtype = qint4
    group_size = 128
    dtype = torch.bfloat16
    inputs = torch.rand((batch_size,) + (tokens, embeddings), dtype=dtype, device=device)
    # Create a TinyGemmWeightQBitsTensor from a QBitsTensor
    qbt = random_qweight((tokens, embeddings), qtype, dtype, group_size=group_size, device=device)
    tinygemm_qweight = TinyGemmWeightQBitsTensor(
        qtype=qbt.qtype,
        axis=qbt.axis,
        group_size=qbt._group_size,
        size=qbt.size(),
        stride=qbt.stride(),
        data=qbt._data.unpack(),
        scale_shift=(qbt._scale, qbt._shift),
    )
    bias = random_tensor((tokens,), dtype=dtype).to(device) if use_bias else None
    qout = torch.nn.functional.linear(inputs, tinygemm_qweight, bias)
    out = torch.nn.functional.linear(inputs, qbt.dequantize(), bias)
    assert_similar(out, qout)


================================================
FILE: tests/tensor/weights/test_weight_qbits_tensor.py
================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io

import pytest
import torch
from helpers import random_qweight, random_tensor

from optimum.quanto import MaxOptimizer, WeightQBitsTensor, qint2, qint4, quantize_weight


@pytest.mark.parametrize("qtype", [qint2, qint4], ids=["int2", "int4"])
@pytest.mark.parametrize("axis", [0, -1], ids=["first-axis", "last-axis"])
def test_weight_qbits_tensor_serialization(qtype, axis):
    """torch.save/torch.load must preserve type, qtype, dtype and all inner tensors."""
    qa = random_qweight((5, 5), qtype=qtype, axis=axis)
    b = io.BytesIO()
    torch.save(qa, b)
    b.seek(0)
    # weights_only=False is required to rebuild the custom tensor subclass
    qa_reloaded = torch.load(b, weights_only=False)
    assert isinstance(qa_reloaded, WeightQBitsTensor)
    assert qa_reloaded.qtype == qa.qtype
    assert qa_reloaded.dtype == qa.dtype
    assert torch.equal(qa_reloaded._data, qa._data)
    assert torch.equal(qa_reloaded._scale, qa._scale)
    assert torch.equal(qa_reloaded._shift, qa._shift)


@pytest.mark.parametrize("qtype", [qint2, qint4], ids=["int2", "int4"])
@pytest.mark.parametrize("axis", [0, -1], ids=["first-axis", "last-axis"])
@pytest.mark.parametrize("group_size", [None, 16], ids=["channel-wise", "group-wise"])
def test_weight_qbits_tensor_requires_grad(qtype, axis, group_size, device):
    """Quantizing a weight that requires grad must yield a tensor that requires grad."""
    weight = random_tensor((32, 32), dtype=torch.float32).to(device)
    weight.requires_grad = True
    scale, shift = MaxOptimizer()(weight, qtype=qtype, axis=axis, group_size=group_size)
    qweight = quantize_weight(weight, qtype=qtype, axis=axis, scale=scale, shift=shift, group_size=group_size)
    assert qweight.requires_grad is True


@pytest.mark.parametrize("qtype", [qint2, qint4], ids=["int2", "int4"])
@pytest.mark.parametrize("axis", [0, -1], ids=["first-axis", "last-axis"])
@pytest.mark.parametrize("group_size", [None, 16], ids=["channel-wise", "group-wise"])
def test_weight_qbits_tensor_backward(qtype, axis, group_size, device):
    """The gradient must flow unchanged through dequantize to the float weights."""
    weight = random_tensor((32, 32), dtype=torch.float32).to(device)
    weight.requires_grad = True
    scale, shift = MaxOptimizer()(weight, qtype=qtype, axis=axis, group_size=group_size)
    qweight = quantize_weight(weight, qtype=qtype, axis=axis, scale=scale, shift=shift, group_size=group_size)
    gradient = torch.randn((32, 32)).to(device)
    # Backpropagate gradient to the inner float weights
    qweight.dequantize().backward(gradient)
    assert torch.equal(weight.grad, gradient)


================================================
FILE: tests/tensor/weights/test_weight_qbits_tensor_dispatch.py
================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch
from helpers import assert_similar, random_qweight, random_tensor
from tensor.weights.weight_helpers import check_weight_qtensor_linear

from optimum.quanto import MaxOptimizer, QBitsTensor, qint2, qint4, quantize_weight


@pytest.mark.parametrize("group_size", [None, 128], ids=["channel-wise", "group-wise"])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["fp32", "fp16"])
def test_qbitstensor_to_device(dtype, group_size, device):
    """Moving a QBitsTensor must move all inner tensors and preserve dequantized values."""
    qa = random_qweight((256, 512), dtype=dtype, qtype=qint4, group_size=group_size, device="cpu")
    # Keep a copy of the dequantized Tensor as a reference
    dqa = qa.dequantize()
    # Move to the target device
    moved_qa = qa.to(device)
    assert isinstance(moved_qa, QBitsTensor)
    assert moved_qa.device.type == device.type
    assert moved_qa._data.device.type == device.type
    assert moved_qa._scale.device.type == device.type
    assert moved_qa._shift.device.type == device.type
    moved_dqa = moved_qa.dequantize().to("cpu")
    if type(moved_qa) is not QBitsTensor:
        # Since we use an optimized packing, the order of operations during
        # dequantization might differ, but the moved dequantized Tensor should be nearly identical
        assert_similar(moved_dqa, dqa)
    else:
        assert torch.equal(moved_dqa, dqa)


def test_qbitstensor_detach():
    """detach() must return a QBitsTensor, not decay to a plain tensor."""
    qa = random_qweight((32, 32), qtype=qint4)
    dqa = qa.detach()
    assert isinstance(dqa, QBitsTensor)


@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32], ids=["bf16", "fp16", "fp32"])
@pytest.mark.parametrize("qtype", [qint2, qint4])
@pytest.mark.parametrize("axis", [0, -1], ids=["first-axis", "last-axis"])
def test_qbitstensor_equal(dtype, qtype, axis, device):
    """Two quantizations of the same tensor with the same parameters must compare equal."""
    a = random_tensor((1024, 1024), dtype=dtype, device=device)
    scale, shift = MaxOptimizer()(a, qtype=qtype, axis=axis, group_size=128)
    qa1 = quantize_weight(a, qtype=qtype, axis=axis, scale=scale, shift=shift, group_size=128)
    qa2 = quantize_weight(a, qtype=qtype, axis=axis, scale=scale, shift=shift, group_size=128)
    assert torch.equal(qa1, qa2)


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("tokens", [16, 32])
@pytest.mark.parametrize("in_features", [256, 512])
@pytest.mark.parametrize("out_features", [256, 512])
@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
def test_weight_qbits_tensor_linear(dtype, batch_size, tokens, in_features, out_features, use_bias, device):
    """Small-shape linear check on the fixture device."""
    weight_qtype = qint4
    group_size = 128
    # Create a QBitsTensor
    qbt = random_qweight((out_features, in_features), weight_qtype, dtype, group_size=group_size, device=device)
    check_weight_qtensor_linear(qbt, batch_size, tokens, use_bias)


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("tokens", [16, 32, 48, 64])
@pytest.mark.parametrize("in_features", [1024, 4096, 16384])
@pytest.mark.parametrize("out_features", [1024, 2048, 4096])
@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
def test_weight_qbits_tensor_linear_gpu(dtype, batch_size, tokens, in_features, out_features, use_bias):
    """Large-shape linear check, restricted to CUDA/XPU because it is too slow elsewhere."""
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.xpu.is_available():
        device = torch.device("xpu")
    else:
        pytest.skip(reason="Test is too slow on non-GPU devices")
    weight_qtype = qint4
    group_size = 128
    # Create a QBitsTensor
    qbt = random_qweight((out_features, in_features), weight_qtype, dtype, group_size=group_size, device=device)
    check_weight_qtensor_linear(qbt, batch_size, tokens, use_bias)


================================================
FILE: tests/tensor/weights/test_weight_qbits_tensor_instantiate.py
================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest import torch from optimum.quanto import qint2, qint4 from optimum.quanto.tensor.weights import WeightQBitsTensor def random_data_scale_shift(input_shape, dtype, qtype, axis, group_size): out_features, in_features = input_shape n_groups = in_features * out_features // group_size data_shape = (n_groups, group_size) if axis == 0 else (group_size, n_groups) scale_shape = (n_groups, 1) if axis == 0 else (1, n_groups) min_value = -(2 ** (qtype.bits - 1)) max_value = 2 ** (qtype.bits - 1) - 1 data = torch.randint(max_value - min_value + 1, data_shape, dtype=torch.uint8) scale = torch.full(scale_shape, 1.0 / -min_value, dtype=dtype) shift = torch.ones(scale_shape, dtype=dtype) return data, scale, shift @pytest.mark.parametrize("input_shape, group_size", [[(32, 32), 16], [(1024, 1024), 128]]) @pytest.mark.parametrize("axis", [0, -1], ids=["first-axis", "last-axis"]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32], ids=["bf16", "fp16", "fp32"]) @pytest.mark.parametrize("qtype", [qint2, qint4], ids=["qint2", "qint4"]) def test_weight_qbits_tensor_instantiate(input_shape, dtype, qtype, axis, group_size, device): data, scale, shift = random_data_scale_shift(input_shape, dtype, qtype, axis, group_size) input_stride = torch.ones(input_shape).stride() qa = WeightQBitsTensor(qtype, axis, group_size, input_shape, input_stride, data, scale=scale, shift=shift).to( device ) assert torch.max(torch.abs(qa.dequantize())) <= 1 assert qa.dtype == dtype assert qa.qtype == qtype assert qa.shape == input_shape @pytest.mark.parametrize("input_shape, group_size", [[(32, 32), 16], [(1024, 1024), 128]]) @pytest.mark.parametrize("axis", [0, -1], ids=["first-axis", "last-axis"]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32], ids=["bf16", "fp16", "fp32"]) @pytest.mark.parametrize("qtype", [qint2, qint4], ids=["qint2", "qint4"]) def test_weight_qbits_tensor_equal(input_shape, dtype, qtype, axis, group_size, device): 
data, scale, shift = random_data_scale_shift(input_shape, dtype, qtype, axis, group_size) qa = WeightQBitsTensor(qtype, axis, group_size, data.size(), data.stride(), data, scale=scale, shift=shift).to( device ) qb = WeightQBitsTensor( qtype, axis, group_size, data.size(), data.stride(), data.clone(), scale=scale.clone(), shift=shift.clone() ).to(device) assert qa.equal(qb) ================================================ FILE: tests/tensor/weights/test_weight_qbits_tensor_quantize.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import pytest
import torch
from helpers import assert_similar, device_eq, random_tensor

from optimum.quanto import (
    MaxOptimizer,
    qint2,
    qint4,
)
from optimum.quanto.tensor.weights import WeightQBitsTensor


@pytest.mark.parametrize("input_shape", [(32, 32), (32, 10, 32)])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
@pytest.mark.parametrize("qtype", [qint2, qint4], ids=["qint2", "qint4"])
@pytest.mark.parametrize("axis", [0, -1], ids=["first-axis", "last-axis"])
@pytest.mark.parametrize("group_size", [None, 8], ids=["channel-wise", "group-wise"])
@pytest.mark.parametrize("shift_mode", ["zeropoint", "float"])
def test_weight_qbits_tensor_quantize(input_shape, dtype, qtype, axis, group_size, shift_mode, device):
    """Quantize a random tensor and check metadata plus round-trip error bounds.

    shift_mode selects either an integer zeropoint (shift rounded into int8) or the
    raw float shift produced by MaxOptimizer.
    """
    a = random_tensor(input_shape, dtype=dtype).to(device)
    scale, shift = MaxOptimizer()(a, qtype=qtype, axis=axis, group_size=group_size)
    if shift_mode == "zeropoint":
        # Convert the float shift to an integer zeropoint expressed in quantized units.
        shift = torch.round(shift / scale).to(torch.int8)
    qa = WeightQBitsTensor.quantize(a, qtype, axis, group_size, scale, shift)
    assert isinstance(qa, WeightQBitsTensor)
    assert qa.dtype == dtype
    assert qa.qtype == qtype
    assert device_eq(qa.device, device)
    # Tolerances are per-qtype and per-shift-mode: fewer bits means larger error.
    atol = {
        qint4: {
            "zeropoint": 4e-3,
            "float": 3e-3,
        },
        qint2: {
            "zeropoint": 6e-2,
            "float": 5e-2,
        },
    }[qtype][shift_mode]
    assert_similar(a, qa, atol=atol)


@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
@pytest.mark.parametrize("qtype", [qint2, qint4], ids=["qint2", "qint4"])
def test_weight_qbits_tensor_quantize_integer_tensor(dtype, qtype, device):
    """This test verifies that an integer tensor in the correct range is preserved."""
    bits = qtype.bits
    qmin = -(2 ** (bits - 1))
    qmax = 2 ** (bits - 1) - 1
    # One exact representative per quantized level: [qmin, qmax].
    a = torch.tensor(range(qmin, qmax + 1), dtype=dtype).to(device)
    scale, shift = MaxOptimizer()(a, qtype=qtype, axis=0, group_size=None)
    zeropoint = torch.round(shift / scale)
    qa = WeightQBitsTensor.quantize(a, qtype, 0, None, scale, zeropoint)
    # Packed storage is unsigned bytes.
    assert qa._data.dtype == torch.uint8
    assert isinstance(qa, WeightQBitsTensor)
    assert qa.dtype == dtype
    assert qa.qtype == qtype
    assert device_eq(qa.device, device)
    # Every representable integer must round-trip exactly.
    assert torch.equal(a, qa.dequantize())


================================================
FILE: tests/tensor/weights/test_weight_qbytes_tensor_backward.py
================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from helpers import random_tensor

from optimum.quanto import AbsmaxOptimizer, qint8, quantize_weight


def test_weight_qbytes_tensor_requires_grad(device):
    # Quantizing a weight that requires grad must yield a tensor that also requires grad.
    w = random_tensor((10, 10), dtype=torch.float32).to(device)
    w.requires_grad = True
    scale = AbsmaxOptimizer()(w, qtype=qint8, axis=0)
    qw = quantize_weight(w, qtype=qint8, axis=0, scale=scale)
    assert qw.requires_grad is True


def test_weight_qbytes_tensor_backward(device):
    # Gradients flow through dequantize() back to the original float weights unchanged.
    w = random_tensor((10, 10), dtype=torch.float32).to(device)
    w.requires_grad = True
    scale = AbsmaxOptimizer()(w, qtype=qint8, axis=0)
    qw = quantize_weight(w, qtype=qint8, axis=0, scale=scale)
    gradient = torch.randn((10, 10)).to(device)
    # Backpropagate gradient to the inner float weights
    qw.dequantize().backward(gradient)
    assert torch.equal(w.grad, gradient)


def test_weight_qbytes_tensor_chained_backward(device):
    # Gradients propagate correctly through an op chaining two quantized tensors.
    a = random_tensor((10, 10), dtype=torch.float32).to(device)
    a.requires_grad = True
    scale = AbsmaxOptimizer()(a, qtype=qint8, axis=0)
    qa = quantize_weight(a, qtype=qint8, axis=0, scale=scale)
    b = random_tensor((10, 10), dtype=torch.float32).to(device)
    b.requires_grad = True
    scale = AbsmaxOptimizer()(b, qtype=qint8, axis=0)
    qb = quantize_weight(b, qtype=qint8, axis=0, scale=scale)
    # Evaluate the product
    prod = qa * qb
    # Backpropagate
    gradient = torch.randn((10, 10)).to(device)
    prod.backward(gradient)
    # d(qa*qb)/dqa = qb (and vice-versa), evaluated on the dequantized values.
    assert torch.allclose(a.grad, qb.dequantize() * gradient)
    assert torch.allclose(b.grad, qa.dequantize() * gradient)


================================================
FILE: tests/tensor/weights/test_weight_qbytes_tensor_dispatch.py
================================================
import pytest
import torch
from helpers import random_qweight, random_tensor

from optimum.quanto import AbsmaxOptimizer, WeightQBytesTensor, qint8, quantize_weight


# NOTE(review): "qytes" looks like a typo for "qbytes" in this test name —
# worth renaming in a dedicated change (pytest discovers it either way).
def test_weight_qytes_tensor_to_device(device):
    # .to(device) must move both inner buffers (data and scale) along with the wrapper.
    qa = random_qweight((32, 32), qtype=qint8, dtype=torch.float)
    qa = qa.to(device)
    assert isinstance(qa, WeightQBytesTensor)
    assert qa.device.type == device.type
    assert qa._data.device.type == device.type
    assert qa._scale.device.type == device.type


@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32], ids=["bf16", "fp16", "fp32"])
@pytest.mark.parametrize("qtype", [qint8])
@pytest.mark.parametrize("axis", [0, -1], ids=["first-axis", "last-axis"])
def test_weight_qbytes_tensor_equal(dtype, qtype, axis, device):
    # Quantizing the same tensor twice with the same scale yields equal results.
    a = random_tensor((32, 32), dtype=dtype, device=device)
    scale = AbsmaxOptimizer()(a, qtype=qtype, axis=axis)
    qa1 = quantize_weight(a, qtype=qtype, axis=axis, scale=scale)
    qa2 = quantize_weight(a, qtype=qtype, axis=axis, scale=scale)
    assert torch.equal(qa1, qa2)


@pytest.mark.parametrize("axis", [0, -1], ids=["first-axis", "last-axis"])
@pytest.mark.parametrize("qtype", [qint8])
def test_weight_qbytes_tensor_transpose_contiguous(axis, qtype, device):
    # t() produces a (non-contiguous) WeightQBytesTensor; contiguous() restores contiguity.
    input_shape = (16, 32)
    qa = random_qweight(input_shape, axis=axis, qtype=qtype, dtype=torch.float32).to(device)
    assert qa.is_contiguous()
    tqa = qa.t()
    assert isinstance(tqa, WeightQBytesTensor)
    assert not tqa.is_contiguous()
    tqa = tqa.contiguous()
    assert tqa.is_contiguous()


@pytest.mark.parametrize("axis", [0, -1], ids=["first-axis", "last-axis"])
@pytest.mark.parametrize("qtype", [qint8])
def test_weight_qbytes_tensor_transposed_stride(axis, qtype, device):
    # A quantized tensor mirrors the strides of its float counterpart, before and after t().
    input_shape = (16, 32)
    a = random_tensor(input_shape, dtype=torch.float32).to(device)
    scale = AbsmaxOptimizer()(a, qtype=qtype, axis=axis)
    qa = quantize_weight(a, qtype=qtype, axis=axis, scale=scale)
    assert qa.stride() == a.stride()
    ta = a.t()
    tqa = qa.t()
    assert isinstance(tqa, WeightQBytesTensor)
    assert tqa.stride() == ta.stride()


================================================
FILE: tests/tensor/weights/test_weight_qbytes_tensor_instantiate.py
================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest import torch from optimum.quanto import WeightQBytesTensor, qfloat8, qint8 def random_data_scale(input_shape, dtype, qtype): if qtype.is_floating_point: min_value = torch.finfo(qtype.dtype).min max_value = torch.finfo(qtype.dtype).max data = (torch.rand(input_shape) * max_value + min_value).to(qtype.dtype) else: max_value = torch.iinfo(qtype.dtype).max data = torch.randint(-max_value, max_value, input_shape, dtype=qtype.dtype) scale = torch.tensor(1.0 / max_value, dtype=dtype) return data, scale @pytest.mark.parametrize("input_shape", [(10,), (1, 10), (10, 32, 32)]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32], ids=["bf16", "fp16", "fp32"]) @pytest.mark.parametrize("qtype", [qint8, qfloat8], ids=["qint8", "qfloat8"]) def test_qbytestensor_instantiate(input_shape, dtype, qtype, device): if qtype.is_floating_point and device.type == "mps": pytest.skip("float8 types are not supported on MPS device") data, scale = random_data_scale(input_shape, dtype, qtype) qa = WeightQBytesTensor(qtype, None, data.size(), data.stride(), data, scale=scale, activation_qtype=None).to( device ) assert torch.max(torch.abs(qa.dequantize())) <= 1 assert qa.dtype == dtype assert qa.qtype == qtype assert qa.shape == input_shape @pytest.mark.parametrize("input_shape", [(10,), (1, 10), (10, 32, 32)]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32], ids=["bf16", "fp16", "fp32"]) @pytest.mark.parametrize("qtype", [qint8], ids=["qint8"]) def test_qbytestensor_equal(input_shape, dtype, qtype, device): data, scale = random_data_scale(input_shape, dtype, qtype) qa = WeightQBytesTensor(qtype, None, data.size(), data.stride(), data, scale=scale, activation_qtype=None).to( device ) qb = WeightQBytesTensor( qtype, None, data.size(), data.stride(), data.clone(), scale=scale, activation_qtype=None ).to(device) assert qa.equal(qb) ================================================ FILE: 
tests/tensor/weights/test_weight_qbytes_tensor_quantize.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import pytest import torch from helpers import assert_similar, device_eq, random_qweight, random_tensor from optimum.quanto import ( WeightQBytesTensor, absmax_scale, qfloat8, qfloat8_e4m3fn, qfloat8_e4m3fnuz, qfloat8_e5m2, qint8, ) @pytest.mark.parametrize("input_shape", [(32, 32), (32, 10, 32)]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"]) @pytest.mark.parametrize("qtype", [qint8], ids=["qint8"]) @pytest.mark.parametrize( "axis", [None, 0, -1], ids=["per-tensor", "first-axis", "last-axis"], ) def test_symmetric_quantize_int(input_shape, dtype, qtype, axis, device): a = random_tensor(input_shape, dtype=dtype).to(device) scale = absmax_scale(a, qtype=qtype, axis=axis) qa = WeightQBytesTensor.quantize(a, qtype, axis, scale) assert isinstance(qa, WeightQBytesTensor) assert qa.dtype == dtype assert qa.qtype == qtype assert device_eq(qa.device, device) assert_similar(a, qa) @pytest.mark.skip_device("mps") @pytest.mark.parametrize("input_shape", [(32, 32), (32, 10, 32)]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"]) @pytest.mark.parametrize( "qtype", [qfloat8, qfloat8_e4m3fn, qfloat8_e4m3fnuz, qfloat8_e5m2], ids=["qfloat8", "qfloat8_e4m3fn", "qfloat8_e4m3fnuz", "qfloat8_e5m2"], ) 
@pytest.mark.parametrize( "axis", [None, 0, -1], ids=["per-tensor", "first-axis", "last-axis"], ) def test_symmetric_quantize_float8(input_shape, dtype, qtype, axis, device): a = random_tensor(input_shape, dtype=dtype).to(device) scale = absmax_scale(a, qtype=qtype, axis=axis) qa = WeightQBytesTensor.quantize(a, qtype, axis, scale) assert isinstance(qa, WeightQBytesTensor) assert qa.dtype == dtype assert qa.qtype == qtype assert device_eq(qa.device, device) assert_similar(a, qa, atol=5e-3) @pytest.mark.parametrize("axis", [0, -1], ids=["first-axis", "last-axis"]) def test_quantize_weight_axis_dim_1(axis, device): input_shape = (1, 32) if axis == 0 else (32, 1) qa = random_qweight(input_shape, dtype=torch.float32, qtype=qint8, axis=axis, device=device) # Quantizing along an axis of dimension 1 actually means per-tensor assert qa.axis is None ================================================ FILE: tests/tensor/weights/test_weight_qbytes_tensor_serialization.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import io import pytest import torch from helpers import random_qweight from optimum.quanto import qfloat8, qint8 @pytest.mark.parametrize("input_shape", [(10, 10), (10, 32, 32)]) @pytest.mark.parametrize("qtype", [qint8, qfloat8], ids=["qint8", "qfloat8"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"]) @pytest.mark.parametrize("axis", [0, -1], ids=["first-axis", "last-axis"]) def test_weights_qbytes_tensor_serialization(input_shape, qtype, dtype, axis): qinputs = random_qweight(input_shape, dtype=dtype, qtype=qtype, axis=axis) b = io.BytesIO() torch.save(qinputs, b) b.seek(0) qinputs_reloaded = torch.load(b, weights_only=False) assert qinputs_reloaded.qtype == qtype assert torch.equal(qinputs_reloaded._scale, qinputs._scale) if qtype.is_floating_point: # Equality is not supported for float8 assert torch.equal(qinputs_reloaded._data.to(torch.float32), qinputs._data.to(torch.float32)) else: assert torch.equal(qinputs_reloaded._data, qinputs._data) # We cannot test dtype directly as it is not correctly set by torch.load assert qinputs_reloaded._scale.dtype == dtype assert qinputs_reloaded.axis == qinputs.axis ================================================ FILE: tests/tensor/weights/weight_helpers.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import torch from helpers import assert_similar, random_tensor def check_weight_qtensor_linear(qweight, batch_size, tokens, use_bias, rel_max_err=0.0): dtype = qweight.dtype device = qweight.device out_features, in_features = qweight.shape inputs = torch.rand((batch_size, tokens, in_features), dtype=dtype, device=device) bias = random_tensor((out_features,), dtype=dtype, device=device) if use_bias else None qout = torch.nn.functional.linear(inputs, qweight, bias) out = torch.nn.functional.linear(inputs, qweight.dequantize(), bias) # Verify global alignment assert_similar(out, qout) # Also look for outliers mean_val = out.abs().max() max_err = (out - qout).abs().max() rel_max_err = max_err / mean_val # These values were evaluated empirically without any optimized kernels. rtol = {"cpu": 1e-2, "cuda": 2e-2, "mps": 1e-2, "xpu": 2e-2}[device.type] assert rel_max_err < rtol, ( f"Maximum error {max_err:.2f} is too high for input of mean value {mean_val:.2f} ({rel_max_err * 100:.2f} %)" )